#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""AWS Polly text-to-speech service implementation.

This module provides integration with Amazon Polly for text-to-speech synthesis,
supporting multiple languages, voices, and SSML features.
"""

import os
from dataclasses import dataclass, field
from typing import AsyncGenerator, List, Optional

from loguru import logger
from pydantic import BaseModel

from pipecat.audio.utils import create_stream_resampler
from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    TTSAudioRawFrame,
)
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts

try:
    import aioboto3
    from botocore.exceptions import BotoCoreError, ClientError
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
    raise Exception(f"Missing module: {e}")


def language_to_aws_language(language: Language) -> Optional[str]:
    """Translate a pipecat ``Language`` value into an AWS Polly language code.

    The lookup is delegated to ``resolve_language`` with
    ``use_base_code=False``, using the table defined below.

    Args:
        language: The Language enum value to convert.

    Returns:
        The corresponding AWS Polly language code, or None if not supported.
    """
    # Keys are pipecat Language members; values are the codes Polly expects.
    # Bare language members (e.g. Language.EN) point at a default regional
    # variant.
    polly_codes = {
        # --- Arabic ---
        Language.AR: "arb",
        Language.AR_AE: "ar-AE",
        # --- Catalan ---
        Language.CA: "ca-ES",
        # --- Chinese (Mandarin and Cantonese) ---
        Language.ZH: "cmn-CN",
        Language.YUE: "yue-CN",
        Language.YUE_CN: "yue-CN",
        # --- Czech ---
        Language.CS: "cs-CZ",
        # --- Danish ---
        Language.DA: "da-DK",
        # --- Dutch ---
        Language.NL: "nl-NL",
        Language.NL_BE: "nl-BE",
        # --- English (bare EN falls back to US English) ---
        Language.EN: "en-US",
        Language.EN_AU: "en-AU",
        Language.EN_GB: "en-GB",
        Language.EN_IN: "en-IN",
        Language.EN_NZ: "en-NZ",
        Language.EN_US: "en-US",
        Language.EN_ZA: "en-ZA",
        # --- Finnish ---
        Language.FI: "fi-FI",
        # --- French ---
        Language.FR: "fr-FR",
        Language.FR_BE: "fr-BE",
        Language.FR_CA: "fr-CA",
        # --- German ---
        Language.DE: "de-DE",
        Language.DE_AT: "de-AT",
        Language.DE_CH: "de-CH",
        # --- Hindi ---
        Language.HI: "hi-IN",
        # --- Icelandic ---
        Language.IS: "is-IS",
        # --- Italian ---
        Language.IT: "it-IT",
        # --- Japanese ---
        Language.JA: "ja-JP",
        # --- Korean ---
        Language.KO: "ko-KR",
        # --- Norwegian (all variants map to Bokmal) ---
        Language.NO: "nb-NO",
        Language.NB: "nb-NO",
        Language.NB_NO: "nb-NO",
        # --- Polish ---
        Language.PL: "pl-PL",
        # --- Portuguese ---
        Language.PT: "pt-PT",
        Language.PT_BR: "pt-BR",
        Language.PT_PT: "pt-PT",
        # --- Romanian ---
        Language.RO: "ro-RO",
        # --- Russian ---
        Language.RU: "ru-RU",
        # --- Spanish ---
        Language.ES: "es-ES",
        Language.ES_MX: "es-MX",
        Language.ES_US: "es-US",
        # --- Swedish ---
        Language.SV: "sv-SE",
        # --- Turkish ---
        Language.TR: "tr-TR",
        # --- Welsh ---
        Language.CY: "cy-GB",
        Language.CY_GB: "cy-GB",
    }

    polly_code = resolve_language(language, polly_codes, use_base_code=False)
    return polly_code


@dataclass
class AWSPollyTTSSettings(TTSSettings):
    """Settings for AWSPollyTTSService.

    Extends ``TTSSettings`` with the Polly-specific options that
    ``AWSPollyTTSService`` reads when building a synthesis request. Every
    field defaults to the ``NOT_GIVEN`` sentinel.

    Parameters:
        engine: TTS engine to use ('standard', 'neural', etc.). Sent as the
            ``Engine`` request parameter.
        pitch: Voice pitch adjustment (for standard engine only). Emitted as
            an SSML ``<prosody>`` attribute.
        rate: Speech rate adjustment. Emitted as an SSML ``<prosody>``
            attribute.
        volume: Voice volume adjustment. Emitted as an SSML ``<prosody>``
            attribute.
        lexicon_names: List of pronunciation lexicons to apply. Sent as the
            ``LexiconNames`` request parameter.
    """

    # A default_factory is used for each field so the default is produced at
    # instantiation time rather than stored as a class-level default.
    engine: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    pitch: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    rate: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    volume: str | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    lexicon_names: List[str] | _NotGiven = field(default_factory=lambda: NOT_GIVEN)


class AWSPollyTTSService(TTSService):
    """AWS Polly text-to-speech service.

    Provides text-to-speech synthesis using Amazon Polly with support for
    multiple languages, voices, SSML features, and voice customization
    options including prosody controls.

    Each request is sent to Polly as SSML, synthesized as 16 kHz PCM,
    resampled to the service sample rate, and emitted as
    ``TTSAudioRawFrame`` chunks.
    """

    Settings = AWSPollyTTSSettings
    _settings: AWSPollyTTSSettings

    class InputParams(BaseModel):
        """Input parameters for AWS Polly TTS configuration.

        .. deprecated:: 0.0.105
            Use ``AWSPollyTTSSettings`` directly via the ``settings`` parameter instead.

        Parameters:
            engine: TTS engine to use ('standard', 'neural', etc.).
            language: Language for synthesis. Defaults to English.
            pitch: Voice pitch adjustment (for standard engine only).
            rate: Speech rate adjustment.
            volume: Voice volume adjustment.
            lexicon_names: List of pronunciation lexicons to apply.
        """

        engine: Optional[str] = None
        language: Optional[Language] = Language.EN
        pitch: Optional[str] = None
        rate: Optional[str] = None
        volume: Optional[str] = None
        lexicon_names: Optional[List[str]] = None

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region: Optional[str] = None,
        voice_id: Optional[str] = None,
        sample_rate: Optional[int] = None,
        params: Optional[InputParams] = None,
        settings: Optional[AWSPollyTTSSettings] = None,
        **kwargs,
    ):
        """Initializes the AWS Polly TTS service.

        Args:
            api_key: AWS secret access key. If None, uses AWS_SECRET_ACCESS_KEY environment variable.
            aws_access_key_id: AWS access key ID. If None, uses AWS_ACCESS_KEY_ID environment variable.
            aws_session_token: AWS session token for temporary credentials. If None,
                uses AWS_SESSION_TOKEN environment variable.
            region: AWS region for Polly service. If None, uses the AWS_REGION
                environment variable, falling back to 'us-east-1'.
            voice_id: Voice ID to use for synthesis. Defaults to 'Joanna'.

                .. deprecated:: 0.0.105
                    Use ``settings=AWSPollyTTSSettings(voice=...)`` instead.

            sample_rate: Audio sample rate. If None, uses service default.
            params: Additional input parameters for voice customization.
                Ignored (apart from the deprecation warning) when ``settings``
                is also provided.

                .. deprecated:: 0.0.105
                    Use ``settings=AWSPollyTTSSettings(...)`` instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to parent TTSService class.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = AWSPollyTTSSettings(
            model=None,
            voice="Joanna",
            language="en-US",
            engine=None,
            pitch=None,
            rate=None,
            volume=None,
            lexicon_names=None,
        )

        # 2. Apply direct init arg overrides (deprecated)
        if voice_id is not None:
            _warn_deprecated_param("voice_id", AWSPollyTTSSettings, "voice")
            default_settings.voice = voice_id

        # 3. Apply params overrides — only if settings not provided
        if params is not None:
            _warn_deprecated_param("params", AWSPollyTTSSettings)
            if not settings:
                default_settings.engine = params.engine
                # language_to_service_language() does not read instance state,
                # so it is safe to call before super().__init__() has run.
                default_settings.language = (
                    self.language_to_service_language(params.language)
                    if params.language
                    else "en-US"
                )
                default_settings.pitch = params.pitch
                default_settings.rate = params.rate
                default_settings.volume = params.volume
                default_settings.lexicon_names = params.lexicon_names

        # 4. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            push_start_frame=True,
            push_stop_frames=True,
            settings=default_settings,
            **kwargs,
        )

        # Explicit arguments win; otherwise fall back to the standard AWS
        # environment variables. These are passed to every Polly client that
        # run_tts() creates.
        self._aws_params = {
            "aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            "aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
            "aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
            "region_name": region or os.getenv("AWS_REGION", "us-east-1"),
        }

        self._aws_session = aioboto3.Session()

        # Polly audio is requested at 16 kHz and converted to self.sample_rate
        # with this resampler in run_tts().
        self._resampler = create_stream_resampler()

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as AWS Polly service supports metrics generation.
        """
        return True

    def language_to_service_language(self, language: Language) -> Optional[str]:
        """Convert a Language enum to AWS Polly language format.

        Args:
            language: The language to convert.

        Returns:
            The AWS Polly-specific language code, or None if not supported.
        """
        return language_to_aws_language(language)

    def _construct_ssml(self, text: str) -> str:
        """Wrap ``text`` in the SSML document sent to Polly.

        The result has the shape
        ``<speak><lang ...><prosody ...>text</prosody></lang></speak>``.
        The ``<lang>`` element is emitted only when a language is configured,
        and the ``<prosody>`` element only when at least one prosody attribute
        applies.

        Args:
            text: The text to speak. It is inserted verbatim (no XML
                escaping), so any markup it contains becomes part of the SSML.

        Returns:
            The complete SSML string.
        """
        settings = self._settings
        parts = ["<speak>"]

        # The language can be None when an unsupported Language was supplied
        # (language_to_service_language() returns None). Omit the wrapper in
        # that case instead of emitting the literal string 'None' as a
        # language code.
        language = settings.language
        if language:
            parts.append(f"<lang xml:lang='{language}'>")

        prosody_attrs = []
        # Pitch is only emitted for the standard engine; rate and volume are
        # emitted for any engine whenever they are set.
        if settings.engine == "standard" and settings.pitch:
            prosody_attrs.append(f"pitch='{settings.pitch}'")
        if settings.rate:
            prosody_attrs.append(f"rate='{settings.rate}'")
        if settings.volume:
            prosody_attrs.append(f"volume='{settings.volume}'")

        if prosody_attrs:
            parts.append(f"<prosody {' '.join(prosody_attrs)}>")

        # NOTE(review): text is not XML-escaped; characters such as '&' or '<'
        # end up in the SSML unchanged. Escaping here would also break callers
        # that intentionally pass SSML tags, so it is left as-is.
        parts.append(text)

        if prosody_attrs:
            parts.append("</prosody>")
        if language:
            parts.append("</lang>")
        parts.append("</speak>")

        ssml = "".join(parts)

        logger.trace(f"{self} SSML: {ssml}")

        return ssml

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using AWS Polly.

        The text is wrapped in SSML, synthesized as 16 kHz PCM, resampled to
        ``self.sample_rate`` and yielded in chunks of ``self.chunk_size``.

        Args:
            text: The text to synthesize into speech.
            context_id: The context ID for tracking audio frames.

        Yields:
            Frame: ``TTSAudioRawFrame`` chunks containing the synthesized
            speech, or a single ``ErrorFrame`` if the Polly request fails or
            the response carries no audio stream.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        try:
            # Construct the parameters dictionary
            ssml = self._construct_ssml(text)

            params = {
                "Text": ssml,
                "TextType": "ssml",
                "OutputFormat": "pcm",
                "VoiceId": self._settings.voice,
                "Engine": self._settings.engine,
                # AWS only supports 8000 and 16000 for PCM. We select 16000.
                "SampleRate": "16000",
                "LexiconNames": self._settings.lexicon_names,
            }

            # Optional settings default to None; drop them so they are not
            # sent in the request at all.
            filtered_params = {k: v for k, v in params.items() if v is not None}

            async with self._aws_session.client("polly", **self._aws_params) as polly:
                response = await polly.synthesize_speech(**filtered_params)

                stream = response.get("AudioStream")
                if stream is None:
                    # Without audio there is nothing to resample or chunk.
                    # Report the problem and stop instead of passing None on
                    # to the resampler.
                    logger.error(f"{self} No audio stream in response")
                    yield ErrorFrame(error="AWS Polly TTS error: no audio stream in response")
                    return

                # Read the whole streaming body before resampling.
                audio_data = await stream.read()

                audio_data = await self._resampler.resample(audio_data, 16000, self.sample_rate)

                await self.start_tts_usage_metrics(text)

                chunk_size = self.chunk_size

                for start in range(0, len(audio_data), chunk_size):
                    chunk = audio_data[start : start + chunk_size]
                    if len(chunk) > 0:
                        await self.stop_ttfb_metrics()
                        frame = TTSAudioRawFrame(chunk, self.sample_rate, 1, context_id=context_id)
                        yield frame

        except (BotoCoreError, ClientError) as error:
            error_message = f"AWS Polly TTS error: {str(error)}"
            yield ErrorFrame(error=error_message)


class PollyTTSService(AWSPollyTTSService):
    """Deprecated alias for AWSPollyTTSService.

    Behaves exactly like ``AWSPollyTTSService`` but emits a
    ``DeprecationWarning`` on construction.

    .. deprecated:: 0.0.67
        `PollyTTSService` is deprecated, use `AWSPollyTTSService` instead.

    """

    Settings = AWSPollyTTSSettings

    def __init__(self, **kwargs):
        """Initialize the deprecated PollyTTSService.

        Args:
            **kwargs: All arguments passed to AWSPollyTTSService.
        """
        super().__init__(**kwargs)

        import warnings

        # Force the warning to be shown even if DeprecationWarning is filtered
        # out by default; the filter change is scoped to this block.
        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                "'PollyTTSService' is deprecated, use 'AWSPollyTTSService' instead.",
                DeprecationWarning,
                # Attribute the warning to the code instantiating the class
                # rather than to this module.
                stacklevel=2,
            )
