#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Deepgram text-to-speech service implementation.

This module provides integration with Deepgram's text-to-speech API
for generating speech from text using various voice models.
"""

import json
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Optional

import aiohttp
from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    StartFrame,
    TTSAudioRawFrame,
    TTSStoppedFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import TTSService, WebsocketTTSService
from pipecat.utils.tracing.service_decorators import traced_tts

try:
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use DeepgramWebsocketTTSService, you need to `pip install pipecat-ai[deepgram]`."
    )
    raise Exception(f"Missing module: {e}")


@dataclass
class DeepgramTTSSettings(TTSSettings):
    """Settings for DeepgramTTSService and DeepgramHttpTTSService.

    Adds no fields beyond the base :class:`TTSSettings`; Deepgram uses the
    base ``voice`` field as its model identifier (see the services below).
    """


class DeepgramTTSService(WebsocketTTSService):
    """Deepgram WebSocket-based text-to-speech service.

    Provides real-time text-to-speech synthesis using Deepgram's WebSocket API.
    Supports streaming audio generation with interruption handling via the Clear
    message for conversational AI use cases.
    """

    # Settings class the base service uses for runtime setting updates.
    Settings = DeepgramTTSSettings
    _settings: DeepgramTTSSettings

    # Encodings accepted by Deepgram's WebSocket Speak endpoint; validated in __init__.
    SUPPORTED_ENCODINGS = ("linear16", "mulaw", "alaw")

    def __init__(
        self,
        *,
        api_key: str,
        voice: Optional[str] = None,
        base_url: str = "wss://api.deepgram.com",
        sample_rate: Optional[int] = None,
        encoding: str = "linear16",
        settings: Optional[DeepgramTTSSettings] = None,
        **kwargs,
    ):
        """Initialize the Deepgram WebSocket TTS service.

        Args:
            api_key: Deepgram API key for authentication.
            voice: Voice model to use for synthesis.

                .. deprecated:: 0.0.105
                    Use ``settings=DeepgramTTSSettings(voice=...)`` instead.

            base_url: WebSocket base URL for Deepgram API. Defaults to "wss://api.deepgram.com".
            sample_rate: Audio sample rate in Hz. If None, uses service default.
            encoding: Audio encoding format. Defaults to "linear16". Must be one of SUPPORTED_ENCODINGS.
            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to parent InterruptibleTTSService class.

        Raises:
            ValueError: If encoding is not in SUPPORTED_ENCODINGS.
        """
        # Validate early: an unsupported encoding would only fail later at connect time.
        if encoding.lower() not in self.SUPPORTED_ENCODINGS:
            raise ValueError(
                f"Unsupported encoding '{encoding}'. Must be one of {', '.join(self.SUPPORTED_ENCODINGS)} for WebSocket TTS."
            )

        # 1. Initialize default_settings with hardcoded defaults
        default_settings = DeepgramTTSSettings(
            model=None,
            voice="aura-2-helena-en",
            language=None,
        )

        # 2. Apply direct init arg overrides (deprecated)
        if voice is not None:
            _warn_deprecated_param("voice", DeepgramTTSSettings, "voice")
            # Deepgram identifies the model by the voice name, so keep both in sync.
            default_settings.model = voice
            default_settings.voice = voice

        # 3. (No step 3, as there's no params object to apply)

        # 4. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            pause_frame_processing=True,
            push_stop_frames=False,
            push_start_frame=True,
            append_trailing_space=True,
            settings=default_settings,
            **kwargs,
        )

        self._api_key = api_key
        self._base_url = base_url
        self._encoding = encoding

        # Background task that drains websocket messages; created in _connect().
        self._receive_task = None

    def can_generate_metrics(self) -> bool:
        """Check if the service can generate metrics.

        Returns:
            True, as Deepgram WebSocket TTS service supports metrics generation.
        """
        return True

    async def start(self, frame: StartFrame):
        """Start the Deepgram WebSocket TTS service.

        Connects to the websocket eagerly so first synthesis has no connect latency.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the Deepgram WebSocket TTS service.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the Deepgram WebSocket TTS service.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def _connect(self):
        """Connect to Deepgram WebSocket and start receive task."""
        await super()._connect()

        await self._connect_websocket()

        # Only spawn the reader once the socket exists; _connect_websocket leaves
        # self._websocket as None on failure.
        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Disconnect from Deepgram WebSocket and clean up tasks."""
        await super()._disconnect()

        # Cancel the reader before closing the socket so it doesn't observe the close.
        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _update_settings(self, delta: TTSSettings) -> dict[str, Any]:
        """Apply a settings delta.

        Settings are baked into the websocket URL, so any change forces a
        reconnect.

        Args:
            delta: A :class:`TTSSettings` (or ``DeepgramTTSSettings``) delta.

        Returns:
            Dict mapping changed field names to their previous values.
        """
        changed = await super()._update_settings(delta)

        # Deepgram uses voice as the model, so keep them in sync for metrics
        if "voice" in changed:
            self._settings.model = self._settings.voice
            self._sync_model_name_to_metrics()

        if changed:
            await self._disconnect()
            await self._connect()

        return changed

    async def _connect_websocket(self):
        """Connect to Deepgram WebSocket API with configured settings."""
        try:
            # Idempotent: reuse an already-open connection.
            if self._websocket and self._websocket.state is State.OPEN:
                return

            logger.debug("Connecting to Deepgram WebSocket")

            # Build WebSocket URL with query parameters
            params = []
            params.append(f"model={self._settings.voice}")
            params.append(f"encoding={self._encoding}")
            params.append(f"sample_rate={self.sample_rate}")

            url = f"{self._base_url}/v1/speak?{'&'.join(params)}"

            headers = {"Authorization": f"Token {self._api_key}"}

            self._websocket = await websocket_connect(url, additional_headers=headers)

            # Log only Deepgram's diagnostic response headers (request-id, etc.).
            headers = {
                k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
            }
            logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')

            await self._call_event_handler("on_connected")
        except Exception as e:
            logger.error(f"{self} exception: {e}")
            await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
            # Reset so later calls know there is no usable connection.
            self._websocket = None
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Close WebSocket connection and reset state."""
        try:
            await self.stop_all_metrics()

            if self._websocket:
                logger.debug("Disconnecting from Deepgram WebSocket")
                # Send Close message to gracefully close the connection
                await self._websocket.send(json.dumps({"type": "Close"}))
                await self._websocket.close()
        except Exception as e:
            logger.error(f"{self} exception: {e}")
            await self.push_error(ErrorFrame(error=f"{self} error: {e}"))
        finally:
            # Always clear the handle and notify, even if close itself failed.
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Get active websocket connection or raise exception."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def on_audio_context_interrupted(self, context_id: str):
        """Send Clear message to Deepgram when an audio context is interrupted.

        The Clear message will clear Deepgram's internal text buffer and stop
        sending audio, allowing for a new response to be generated.

        Args:
            context_id: The ID of the audio context that was interrupted.
        """
        await self.stop_all_metrics()
        if self._websocket:
            try:
                await self._websocket.send(json.dumps({"type": "Clear"}))
            except Exception as e:
                # Best-effort: failing to Clear shouldn't crash the pipeline.
                logger.error(f"{self} error sending Clear message: {e}")

    async def _receive_messages(self):
        """Receive and process messages from Deepgram WebSocket.

        Binary frames carry raw audio; text frames carry JSON control/metadata
        messages (Metadata, Flushed, Cleared, Warning).
        """
        async for message in self._get_websocket():
            if isinstance(message, bytes):
                # Binary message contains audio data
                # NOTE(review): assumes an active audio context exists whenever
                # audio arrives — confirm against base-class context lifecycle.
                ctx_id = self.get_active_audio_context_id()
                frame = TTSAudioRawFrame(message, self.sample_rate, 1, context_id=ctx_id)
                await self.append_to_audio_context(ctx_id, frame)
            elif isinstance(message, str):
                # Text message contains metadata or control messages
                try:
                    msg = json.loads(message)
                    msg_type = msg.get("type")

                    if msg_type == "Metadata":
                        logger.trace(f"Received metadata: {msg}")
                    elif msg_type == "Flushed":
                        logger.trace(f"Received Flushed: {msg}")
                        # Flushed marks the end of the current utterance's audio:
                        # close out the active context with a stop frame.
                        ctx_id = self.get_active_audio_context_id()
                        await self.append_to_audio_context(
                            ctx_id, TTSStoppedFrame(context_id=ctx_id)
                        )
                        await self.remove_audio_context(ctx_id)
                    elif msg_type == "Cleared":
                        logger.trace(f"Received Cleared: {msg}")
                        # Buffer has been cleared after interruption.
                        # The on_audio_context_interrupted handler already cleaned up.
                    elif msg_type == "Warning":
                        logger.warning(
                            f"{self} warning: {msg.get('description', 'Unknown warning')}"
                        )
                    else:
                        logger.debug(f"Received unknown message type: {msg}")
                except json.JSONDecodeError:
                    logger.error(f"Invalid JSON message: {message}")

    async def flush_audio(self, context_id: Optional[str] = None):
        """Flush any pending audio synthesis by sending Flush command.

        This should be called when the LLM finishes a complete response to force
        generation of audio from Deepgram's internal text buffer.

        Args:
            context_id: Unused by this implementation; accepted for interface
                compatibility (presumably matches the base-class signature —
                TODO confirm).
        """
        if self._websocket:
            try:
                flush_msg = {"type": "Flush"}
                await self._websocket.send(json.dumps(flush_msg))
            except Exception as e:
                # Best-effort: a failed Flush just delays audio, don't crash.
                logger.error(f"{self} error sending Flush message: {e}")

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Deepgram's WebSocket TTS API.

        Args:
            text: The text to synthesize into speech.
            context_id: The context ID for tracking audio frames.

        Yields:
            Frame: Audio frames containing the synthesized speech, plus start/stop frames.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        try:
            # Reconnect if the websocket is closed
            if not self._websocket or self._websocket.state is State.CLOSED:
                await self._connect()

            # Send text message to Deepgram
            # Note: We don't send Flush here - that should only be sent when the
            # LLM finishes a complete response via flush_audio()
            speak_msg = {"type": "Speak", "text": text}
            await self._get_websocket().send(json.dumps(speak_msg))

            # The audio frames will be handled in _receive_messages
            # yield None keeps this an async generator while producing no
            # frames directly; audio is delivered via the audio context.
            yield None

        except Exception as e:
            yield ErrorFrame(error=f"Unknown error occurred: {e}")


class DeepgramHttpTTSService(TTSService):
    """Deepgram HTTP text-to-speech service.

    Provides text-to-speech synthesis using Deepgram's HTTP TTS API.
    Supports various voice models and audio encoding formats with
    configurable sample rates and quality settings.
    """

    # Settings class the base service uses for runtime setting updates.
    Settings = DeepgramTTSSettings
    _settings: DeepgramTTSSettings

    def __init__(
        self,
        *,
        api_key: str,
        voice: Optional[str] = None,
        aiohttp_session: aiohttp.ClientSession,
        base_url: str = "https://api.deepgram.com",
        sample_rate: Optional[int] = None,
        encoding: str = "linear16",
        settings: Optional[DeepgramTTSSettings] = None,
        **kwargs,
    ):
        """Initialize the Deepgram TTS service.

        Args:
            api_key: Deepgram API key for authentication.
            voice: Voice model to use for synthesis.

                .. deprecated:: 0.0.105
                    Use ``settings=DeepgramTTSSettings(voice=...)`` instead.

            aiohttp_session: Shared aiohttp session for HTTP requests with connection pooling.
            base_url: Custom base URL for Deepgram API. Defaults to "https://api.deepgram.com".
            sample_rate: Audio sample rate in Hz. If None, uses service default.
            encoding: Audio encoding format. Defaults to "linear16".
            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to parent TTSService class.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = DeepgramTTSSettings(
            model=None,
            voice="aura-2-helena-en",
            language=None,
        )

        # 2. Apply direct init arg overrides (deprecated)
        if voice is not None:
            _warn_deprecated_param("voice", DeepgramTTSSettings, "voice")
            # Deepgram identifies the model by the voice name, so keep both in sync.
            default_settings.model = voice
            default_settings.voice = voice

        # 3. (No step 3, as there's no params object to apply)

        # 4. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            push_start_frame=True,
            push_stop_frames=True,
            settings=default_settings,
            **kwargs,
        )

        self._api_key = api_key
        self._session = aiohttp_session
        self._base_url = base_url
        self._encoding = encoding

    def can_generate_metrics(self) -> bool:
        """Check if the service can generate metrics.

        Returns:
            True, as Deepgram TTS service supports metrics generation.
        """
        return True

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Deepgram's TTS API.

        Args:
            text: The text to synthesize into speech.
            context_id: The context ID for tracking audio frames.

        Yields:
            Frame: Audio frames containing the synthesized speech, plus start/stop frames.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        # Build URL with parameters
        url = f"{self._base_url}/v1/speak"

        headers = {"Authorization": f"Token {self._api_key}", "Content-Type": "application/json"}

        params = {
            # Deepgram uses the voice name as its model identifier.
            "model": self._settings.voice,
            "encoding": self._encoding,
            "sample_rate": self.sample_rate,
            # Raw audio frames — no file container around the PCM data.
            "container": "none",
        }

        payload = {
            "text": text,
        }

        try:
            await self.start_ttfb_metrics()

            async with self._session.post(
                url, headers=headers, json=payload, params=params
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"HTTP {response.status}: {error_text}")

                await self.start_tts_usage_metrics(text)

                first_chunk = True
                async for chunk in response.content.iter_chunked(self.chunk_size):
                    if first_chunk:
                        # Time-to-first-byte ends when the first audio chunk arrives.
                        await self.stop_ttfb_metrics()
                        first_chunk = False

                    if chunk:
                        yield TTSAudioRawFrame(
                            audio=chunk,
                            sample_rate=self.sample_rate,
                            num_channels=1,
                            context_id=context_id,
                        )

                if first_chunk:
                    # Empty response body: don't leave the TTFB timer running.
                    await self.stop_ttfb_metrics()

        except Exception as e:
            # Stop the TTFB timer on failure so the metric isn't left running.
            await self.stop_ttfb_metrics()
            yield ErrorFrame(error=f"Error getting audio: {str(e)}")
