#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Cartesia Speech-to-Text service implementation.

This module provides a WebSocket-based STT service that integrates with
the Cartesia Live transcription API for real-time speech recognition.
"""

import json
import urllib.parse
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Optional

from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    Frame,
    InterimTranscriptionFrame,
    StartFrame,
    TranscriptionFrame,
    VADUserStartedSpeakingFrame,
    VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import STTSettings, _warn_deprecated_param
from pipecat.services.stt_latency import CARTESIA_TTFS_P99
from pipecat.services.stt_service import WebsocketSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt

try:
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`.")
    raise Exception(f"Missing module: {e}")


@dataclass
class CartesiaSTTSettings(STTSettings):
    """Runtime-updatable settings for ``CartesiaSTTService``.

    Declares no fields of its own: ``model`` and ``language`` (and any other
    fields) are inherited unchanged from ``STTSettings``. The subclass exists
    so the service has a dedicated settings type (exposed as
    ``CartesiaSTTService.Settings``).
    """


class CartesiaLiveOptions:
    """Configuration options for Cartesia Live STT service.

    .. deprecated:: 0.0.105
        Use ``settings=CartesiaSTTSettings(...)`` for model/language and
        direct ``__init__`` parameters for encoding/sample_rate instead.
    """

    def __init__(
        self,
        *,
        model: str = "ink-whisper",
        language: str = Language.EN.value,
        encoding: str = "pcm_s16le",
        sample_rate: int = 16000,
        **kwargs,
    ):
        """Initialize CartesiaLiveOptions with default or provided parameters.

        Args:
            model: The transcription model to use. Defaults to "ink-whisper".
            language: Target language for transcription. Defaults to English.
                Non-string values are converted via ``.value`` in ``to_dict``.
            encoding: Audio encoding format. Defaults to "pcm_s16le".
            sample_rate: Audio sample rate in Hz. Defaults to 16000.
            **kwargs: Additional parameters, stored in ``additional_params``
                and retrievable through ``get``. They are not included in
                ``to_dict``/``items``.
        """
        self.model = model
        self.language = language
        self.encoding = encoding
        self.sample_rate = sample_rate
        self.additional_params = kwargs

    def to_dict(self):
        """Convert the core options to dictionary format.

        Returns:
            Dictionary with ``model``, ``language``, ``encoding`` and
            ``sample_rate`` (stringified). Extra ``**kwargs`` given at
            construction are not included.
        """
        language = self.language if isinstance(self.language, str) else self.language.value
        return {
            "model": self.model,
            "language": language,
            "encoding": self.encoding,
            "sample_rate": str(self.sample_rate),
        }

    def items(self):
        """Get the core configuration items as key-value pairs.

        Returns:
            Iterator of (key, value) tuples for the entries of ``to_dict()``.
        """
        return self.to_dict().items()

    def get(self, key, default=None):
        """Get a configuration value by key.

        Instance configuration attributes (``model``, ``language``,
        ``encoding``, ``sample_rate``) are checked first, then the extra
        ``**kwargs`` given at construction.

        Args:
            key: The configuration parameter name to retrieve.
            default: Default value if key is not found.

        Returns:
            The configuration value or default if not found.
        """
        # Only instance attributes count as configuration. A plain ``hasattr``
        # check would also match methods (``get("items")`` would return a bound
        # method) and the ``additional_params`` container itself.
        attrs = vars(self)
        if key != "additional_params" and key in attrs:
            return attrs[key]
        return self.additional_params.get(key, default)

    @classmethod
    def from_json(cls, json_str: str) -> "CartesiaLiveOptions":
        """Create options from JSON string.

        Args:
            json_str: JSON object string whose keys become keyword arguments;
                unknown keys end up in ``additional_params``.

        Returns:
            New CartesiaLiveOptions instance with parsed parameters.
        """
        return cls(**json.loads(json_str))


class CartesiaSTTService(WebsocketSTTService):
    """Speech-to-text service using Cartesia Live API.

    Provides real-time speech transcription through WebSocket connection
    to Cartesia's Live transcription service. Supports both interim and
    final transcriptions with configurable models and languages.

    Cartesia disconnects WebSocket connections after 3 minutes of inactivity.
    The timeout resets with each message (audio data or text command) sent to
    the server. Silence-based keepalive is enabled by default to prevent this.
    See: https://docs.cartesia.ai/api-reference/stt/stt
    """

    Settings = CartesiaSTTSettings
    _settings: CartesiaSTTSettings

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = "",
        encoding: str = "pcm_s16le",
        sample_rate: Optional[int] = None,
        live_options: Optional[CartesiaLiveOptions] = None,
        settings: Optional[CartesiaSTTSettings] = None,
        ttfs_p99_latency: Optional[float] = CARTESIA_TTFS_P99,
        **kwargs,
    ):
        """Initialize CartesiaSTTService with API key and options.

        Args:
            api_key: Authentication key for Cartesia API.
            base_url: Custom API host (no scheme; it is placed in a ``wss://``
                URL). If empty, uses ``api.cartesia.ai``.
            encoding: Audio encoding format. Defaults to "pcm_s16le".
            sample_rate: Audio sample rate in Hz. If None, uses the pipeline
                sample rate.
            live_options: Configuration options for transcription service.
                Ignored entirely when ``settings`` is provided. Otherwise its
                ``sample_rate`` applies only if ``sample_rate`` is None, and its
                ``encoding`` replaces the ``encoding`` argument.

                .. deprecated:: 0.0.105
                    Use ``settings=CartesiaSTTSettings(...)`` for model/language
                    and direct init parameters for encoding/sample_rate instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
                Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
            **kwargs: Additional arguments passed to parent STTService.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = CartesiaSTTSettings(
            model="ink-whisper",
            language=Language.EN.value,
        )

        # 2. Apply live_options overrides — only if settings not provided
        if live_options is not None:
            _warn_deprecated_param("live_options", CartesiaSTTSettings)
            if not settings:
                if live_options.sample_rate and sample_rate is None:
                    sample_rate = live_options.sample_rate
                if live_options.encoding:
                    encoding = live_options.encoding
                if live_options.model:
                    default_settings.model = live_options.model
                if live_options.language:
                    lang = live_options.language
                    default_settings.language = lang.value if isinstance(lang, Language) else lang

        # 3. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            ttfs_p99_latency=ttfs_p99_latency,
            keepalive_timeout=120,
            keepalive_interval=30,
            settings=default_settings,
            **kwargs,
        )

        self._api_key = api_key
        self._base_url = base_url or "api.cartesia.ai"
        self._receive_task = None

        # Init-only audio config (not runtime-updatable).
        self._encoding = encoding

    def can_generate_metrics(self) -> bool:
        """Check if the service can generate processing metrics.

        Returns:
            True, indicating metrics are supported.
        """
        return True

    async def start(self, frame: StartFrame):
        """Start the STT service and establish connection.

        Args:
            frame: Frame indicating service should start.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the STT service and close connection.

        Args:
            frame: Frame indicating service should stop.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the STT service and close connection.

        Args:
            frame: Frame indicating service should be cancelled.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def _start_metrics(self):
        """Start performance metrics collection for transcription processing."""
        await self.start_processing_metrics()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process incoming frames and handle speech events.

        Starts processing metrics when the user starts speaking and sends a
        ``finalize`` command when the user stops speaking.

        Args:
            frame: The frame to process.
            direction: Direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, VADUserStartedSpeakingFrame):
            await self._start_metrics()
        elif isinstance(frame, VADUserStoppedSpeakingFrame):
            # Send finalize command to flush the transcription session
            if self._websocket and self._websocket.state is State.OPEN:
                await self._websocket.send("finalize")

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Process audio data for speech-to-text transcription.

        Reconnects first if there is no WebSocket or it has been closed. If
        there is still no open connection afterwards, the chunk is dropped
        instead of raising on every call; a failed connection attempt is
        already reported by ``_connect_websocket``.

        Args:
            audio: Raw audio bytes to transcribe.

        Yields:
            None - transcription results are handled via WebSocket responses.
        """
        # If the connection is closed, due to timeout, we need to reconnect when the user starts speaking again
        if not self._websocket or self._websocket.state is State.CLOSED:
            await self._connect()

        # _connect_websocket swallows connection errors, leaving self._websocket
        # None or still non-open; sending would then raise AttributeError or
        # ConnectionClosed.
        if self._websocket and self._websocket.state is State.OPEN:
            await self._websocket.send(audio)
        yield None

    async def _connect(self):
        """Open the WebSocket, call the base ``_connect``, and start the receive task.

        The receive task is only created when a WebSocket exists and no
        receive task is already running.
        """
        await self._connect_websocket()

        await super()._connect()

        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Call the base ``_disconnect``, cancel the receive task, then close the WebSocket."""
        await super()._disconnect()

        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
        """Apply a settings delta.

        Changed values are stored but not applied to an already-open
        connection; they are reported via ``_warn_unhandled_updated_settings``.

        Args:
            delta: A :class:`STTSettings` (or ``CartesiaSTTSettings``) delta.

        Returns:
            Dict mapping changed field names to their previous values.
        """
        changed = await super()._update_settings(delta)

        # TODO: someday we could reconnect here to apply updated settings.
        # Code might look something like the below:
        # if changed:
        #     await self._disconnect()
        #     await self._connect()

        self._warn_unhandled_updated_settings(changed)

        return changed

    async def _connect_websocket(self):
        """Open the Cartesia STT WebSocket unless one is already open.

        Model and language come from the current settings; encoding and
        sample rate are passed as query parameters. Any failure is reported
        through ``push_error`` rather than raised, leaving ``self._websocket``
        unchanged.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return
            logger.debug("Connecting to Cartesia STT")

            params = {
                "model": self._settings.model,
                "language": self._settings.language,
                "encoding": self._encoding,
                "sample_rate": str(self.sample_rate),
            }
            ws_url = f"wss://{self._base_url}/stt/websocket?{urllib.parse.urlencode(params)}"
            headers = {"Cartesia-Version": "2025-04-16", "X-API-Key": self._api_key}

            self._websocket = await websocket_connect(ws_url, additional_headers=headers)
            await self._call_event_handler("on_connected")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)

    async def _disconnect_websocket(self):
        """Close the WebSocket if it is open.

        Close errors are reported via ``push_error``. ``self._websocket`` is
        always cleared and ``on_disconnected`` is always fired.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                logger.debug("Disconnecting from Cartesia STT")
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Return the current WebSocket, raising ``Exception`` if there is none."""
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _receive_messages(self):
        """Process incoming WebSocket messages.

        Each message is JSON-decoded and passed to ``_process_response``.
        Errors for a single message are logged and do not stop the loop.
        """
        async for message in self._get_websocket():
            try:
                data = json.loads(message)
                await self._process_response(data)
            except json.JSONDecodeError:
                logger.warning(f"Received non-JSON message: {message}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")

    async def _process_response(self, data):
        """Dispatch a decoded Cartesia message by its ``type`` field.

        ``transcript`` messages go to ``_on_transcript``; ``error`` messages
        are pushed as errors. Other types are ignored.
        """
        # json.loads may return a list/str/number; indexing those by "type"
        # would raise TypeError.
        if not isinstance(data, dict):
            logger.warning(f"Unexpected Cartesia STT message: {data}")
            return

        msg_type = data.get("type")
        if msg_type == "transcript":
            await self._on_transcript(data)
        elif msg_type == "error":
            error_msg = data.get("message", "Unknown error")
            await self.push_error(error_msg=error_msg)

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Optional[Language] = None
    ):
        """Handle a transcription result with tracing.

        Intentionally empty: it exists so ``@traced_stt`` can observe final
        transcriptions.
        """
        pass

    async def _on_transcript(self, data):
        """Push a frame for a Cartesia ``transcript`` message.

        Final results are pushed as ``TranscriptionFrame`` (then traced, and
        processing metrics are stopped); non-final results are pushed as
        ``InterimTranscriptionFrame``. Missing, null, or empty text is ignored.
        """
        # ``or ""`` also covers an explicit JSON null, which would otherwise make
        # the emptiness check raise TypeError.
        transcript = data.get("text") or ""
        if not transcript:
            return

        is_final = data.get("is_final", False)

        language = None
        if "language" in data:
            try:
                language = Language(data["language"])
            except ValueError:
                # Unrecognized language codes are reported as None.
                pass

        if is_final:
            await self.push_frame(
                TranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    language,
                    result=data,
                )
            )
            await self._handle_transcription(transcript, is_final, language)
            await self.stop_processing_metrics()
        else:
            # For interim transcriptions, just push the frame without tracing
            await self.push_frame(
                InterimTranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    language,
                    result=data,
                )
            )
