#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Moondream vision service implementation.

This module provides integration with the Moondream vision-language model
for image analysis and description generation.
"""

import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, Optional

from loguru import logger
from PIL import Image

from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    UserImageRawFrame,
    VisionFullResponseEndFrame,
    VisionFullResponseStartFrame,
    VisionTextFrame,
)
from pipecat.services.settings import VisionSettings, _warn_deprecated_param
from pipecat.services.vision_service import VisionService

try:
    import torch
    from transformers import AutoModelForCausalLM
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use Moondream, you need to `pip install pipecat-ai[moondream]`.")
    raise Exception(f"Missing module(s): {e}")


def detect_device():
    """Pick the best available torch device and a matching dtype.

    Probe order is: Intel XPU (only when ``intel_extension_for_pytorch`` can
    be imported), then CUDA, then Apple MPS, falling back to CPU.

    Returns:
        tuple: ``(device, dtype)`` where ``device`` is a ``torch.device`` and
        ``dtype`` is ``torch.float16`` for CUDA/MPS and ``torch.float32`` for
        XPU/CPU.
    """
    # The XPU probe stays inside the try so that an ImportError raised either
    # by the extension import or by the availability check just means "no XPU".
    try:
        import intel_extension_for_pytorch  # noqa: F401

        xpu_ready = torch.xpu.is_available()
    except ImportError:
        xpu_ready = False

    if xpu_ready:
        return torch.device("xpu"), torch.float32

    if torch.cuda.is_available():
        name, precision = "cuda", torch.float16
    elif torch.backends.mps.is_available():
        name, precision = "mps", torch.float16
    else:
        name, precision = "cpu", torch.float32
    return torch.device(name), precision


@dataclass
class MoondreamSettings(VisionSettings):
    """Runtime settings for the Moondream vision service.

    Declares no fields of its own; all fields come from ``VisionSettings``.
    It exists as a distinct type so the Moondream service can be configured
    with ``settings=MoondreamSettings(...)``.

    Parameters:
        model: Hugging Face identifier of the Moondream model to load
            (the service defaults it to ``"vikhyatk/moondream2"``).
    """


class MoondreamService(VisionService):
    """Moondream vision-language model service.

    Provides image analysis and description generation using the Moondream
    vision-language model. Supports various hardware acceleration options
    including CUDA, MPS, and Intel XPU.

    The model is loaded synchronously in ``__init__``. Each inference call is
    executed in a worker thread via ``asyncio.to_thread`` so the blocking model
    calls do not stall the event loop.
    """

    Settings = MoondreamSettings
    _settings: MoondreamSettings

    def __init__(
        self,
        *,
        model: Optional[str] = None,
        revision: str = "2025-01-09",
        use_cpu: bool = False,
        settings: Optional[MoondreamSettings] = None,
        **kwargs,
    ):
        """Initialize the Moondream service and load the model.

        Args:
            model: Hugging Face model identifier for the Moondream model.

                .. deprecated:: 0.0.105
                    Use ``settings=MoondreamSettings(model=...)`` instead.

            revision: Model repository revision passed to ``from_pretrained``.
            use_cpu: If True, load the model on CPU in float32 instead of using
                the device chosen by ``detect_device()``.
            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to the parent VisionService.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = MoondreamSettings(model="vikhyatk/moondream2")

        # 2. Apply direct init arg overrides (deprecated)
        if model is not None:
            _warn_deprecated_param("model", MoondreamSettings, "model")
            default_settings.model = model

        # 4. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(settings=default_settings, **kwargs)

        if not use_cpu:
            device, dtype = detect_device()
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        logger.debug("Loading Moondream model...")

        # trust_remote_code=True executes Python code shipped with the model
        # repository; pinning ``revision`` keeps that code fixed to a known version.
        self._model = AutoModelForCausalLM.from_pretrained(
            self._settings.model,
            trust_remote_code=True,
            revision=revision,
            device_map={"": device},
            dtype=dtype,
        ).eval()

        logger.debug("Loaded Moondream model")

    async def run_vision(self, frame: UserImageRawFrame) -> AsyncGenerator[Frame, None]:
        """Analyze an image and generate a description.

        Args:
            frame: The image frame to process. Its raw bytes are decoded using
                ``frame.format`` and ``frame.size``; ``frame.text`` is passed to
                the model as the query prompt.

        Yields:
            Frame: On success, ``VisionFullResponseStartFrame``, a
            ``VisionTextFrame`` carrying the model's answer, and
            ``VisionFullResponseEndFrame``. If the model is unavailable or
            inference raises, a single ``ErrorFrame`` instead.
        """
        if getattr(self, "_model", None) is None:
            yield ErrorFrame("Moondream model not available")
            return

        logger.debug(f"Analyzing image (bytes length: {len(frame.image)})")

        def get_image_description(image_bytes: bytes, text: Optional[str]) -> str:
            # Runs in a worker thread: decoding, encoding and querying are all
            # blocking calls.
            image = Image.frombytes(frame.format, frame.size, image_bytes)
            image_embeds = self._model.encode_image(image)
            # NOTE(review): ``text`` comes from frame.text and may be None;
            # confirm the model's ``query`` accepts a None prompt.
            return self._model.query(image_embeds, text)["answer"]

        try:
            description = await asyncio.to_thread(get_image_description, frame.image, frame.text)
        except Exception as e:
            # Report inference failures the same way as an unavailable model,
            # instead of letting the exception escape the generator.
            logger.error(f"Moondream inference failed: {e}")
            yield ErrorFrame(f"Moondream inference failed: {e}")
            return

        yield VisionFullResponseStartFrame()
        yield VisionTextFrame(text=description)
        yield VisionFullResponseEndFrame()
