#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Google AI image generation service implementation.

This module provides integration with Google's Imagen model for generating
images from text prompts using the Google AI API.
"""

import io
import os

# Suppress gRPC fork warnings
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"

from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Optional

from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field

from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
from pipecat.services.google.utils import update_google_client_http_options
from pipecat.services.image_service import ImageGenService
from pipecat.services.settings import NOT_GIVEN, ImageGenSettings, _NotGiven, _warn_deprecated_param

try:
    from google import genai
    from google.genai import types
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
    raise Exception(f"Missing module: {e}")


@dataclass
class GoogleImageGenSettings(ImageGenSettings):
    """Settings for the Google image generation service.

    The two fields declared in this class default to the ``NOT_GIVEN``
    sentinel, which marks a value as unset. ``GoogleImageGenService`` applies
    an instance of this class as a delta over its own defaults (via
    ``apply_update``). Presumably, only fields that were explicitly set
    override the service defaults; confirm in ``ImageGenSettings.apply_update``.

    Parameters:
        model: Google Imagen model identifier. Inherited; it is not declared
            in this class.
        number_of_images: Number of images to generate per request. The
            deprecated ``InputParams`` constrained this to 1-8. No range
            check is performed here.
        negative_prompt: Text describing what not to include in generated
            images. ``None`` explicitly means "no negative prompt", which is
            distinct from ``NOT_GIVEN`` (leave the current value unchanged).
    """

    # Both fields start as the shared NOT_GIVEN sentinel ("not specified").
    number_of_images: int | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    # str | None | _NotGiven: None is a real value (disable the negative prompt).
    negative_prompt: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)


class GoogleImageGenService(ImageGenService):
    """Google AI image generation service using Imagen models.

    Provides text-to-image generation capabilities using Google's Imagen models
    through the Google AI API. Supports multiple image generation and negative
    prompting for enhanced control over generated content.
    """

    Settings = GoogleImageGenSettings
    _settings: GoogleImageGenSettings

    class InputParams(BaseModel):
        """Configuration parameters for Google image generation.

        .. deprecated:: 0.0.105
            Use ``settings=GoogleImageGenSettings(...)`` instead.

        Parameters:
            number_of_images: Number of images to generate (1-8). Defaults to 1.
            model: Google Imagen model to use. Defaults to "imagen-3.0-generate-002".
            negative_prompt: Optional negative prompt to guide what not to include.
        """

        number_of_images: int = Field(default=1, ge=1, le=8)
        model: str = Field(default="imagen-3.0-generate-002")
        negative_prompt: Optional[str] = Field(default=None)

    def __init__(
        self,
        *,
        api_key: str,
        params: Optional[InputParams] = None,
        http_options: Optional[Any] = None,
        settings: Optional[GoogleImageGenSettings] = None,
        **kwargs,
    ):
        """Initialize the GoogleImageGenService with API key and parameters.

        Args:
            api_key: Google AI API key for authentication.
            params: Configuration parameters for image generation.

                .. deprecated:: 0.0.105
                    Use ``settings=GoogleImageGenSettings(...)`` instead.

            http_options: HTTP options for the client.
            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to the parent ImageGenService.
        """
        # 1. Start from hardcoded defaults.
        default_settings = GoogleImageGenSettings(
            model="imagen-3.0-generate-002",
            number_of_images=1,
            negative_prompt=None,
        )

        # 2. Apply deprecated params, but only when no canonical settings were
        #    supplied (settings always win).
        if params is not None:
            _warn_deprecated_param("params", GoogleImageGenSettings)
            if settings is None:
                default_settings.model = params.model
                default_settings.number_of_images = params.number_of_images
                default_settings.negative_prompt = params.negative_prompt

        # 3. Apply the settings delta (canonical API, always wins).
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(settings=default_settings, **kwargs)

        # Add client header
        http_options = update_google_client_http_options(http_options)

        self._client = genai.Client(api_key=api_key, http_options=http_options)

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Google image generation service supports metrics.
        """
        return True

    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
        """Generate images from a text prompt using Google's Imagen model.

        Sends one ``generate_images`` request using the current model, image
        count and negative prompt from settings. Each returned image is then
        decoded and yielded as raw pixel data. Errors are reported as frames
        rather than raised.

        Args:
            prompt: The text description to generate images from.

        Yields:
            Frame: One URLImageRawFrame per successfully decoded image. An
                ErrorFrame is yielded instead if the request fails, if an image
                cannot be decoded, or if no usable image data is returned.
        """
        logger.debug(f"Generating image from prompt: {prompt}")
        await self.start_ttfb_metrics()

        try:
            response = await self._client.aio.models.generate_images(
                model=self._settings.model,
                prompt=prompt,
                config=types.GenerateImagesConfig(
                    number_of_images=self._settings.number_of_images,
                    negative_prompt=self._settings.negative_prompt,
                ),
            )
        except Exception as e:
            # Stop TTFB tracking on the failure path as well, so the
            # measurement started above is not left running.
            await self.stop_ttfb_metrics()
            yield ErrorFrame(f"Image generation error: {str(e)}")
            return

        await self.stop_ttfb_metrics()

        if not response or not response.generated_images:
            yield ErrorFrame("Image generation failed")
            return

        frames_yielded = 0
        for img_response in response.generated_images:
            # Both GeneratedImage.image and Image.image_bytes are Optional in
            # the SDK types. Skip entries without data instead of crashing.
            image_data = img_response.image
            image_bytes = image_data.image_bytes if image_data is not None else None
            if not image_bytes:
                # NOTE(review): presumably a safety-filtered result; the reason
                # is only populated when the API is asked to include it.
                reason = getattr(img_response, "rai_filtered_reason", None)
                logger.warning(f"{self}: skipping generated image without data (reason: {reason})")
                continue

            try:
                # Google returns the image data directly; decode it and close
                # the PIL image once its pixels have been extracted.
                with Image.open(io.BytesIO(image_bytes)) as image:
                    frame = URLImageRawFrame(
                        url=None,  # Google doesn't provide URLs, only image data
                        image=image.tobytes(),
                        size=image.size,
                        # tobytes() returns raw pixels, so describe them by the
                        # PIL mode (e.g. "RGB"), not the container format ("PNG").
                        format=image.mode,
                    )
            except Exception as e:
                yield ErrorFrame(f"Image generation error: {str(e)}")
                return

            frames_yielded += 1
            yield frame

        if frames_yielded == 0:
            # Every returned entry lacked image data.
            yield ErrorFrame("Image generation failed")
