#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Deepgram speech-to-text service for AWS SageMaker.

This module provides a Pipecat STT service that connects to Deepgram models
deployed on AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for
low-latency real-time transcription with support for interim results, multiple
languages, and various Deepgram features.
"""

import asyncio
import json
from dataclasses import dataclass, fields
from typing import Any, AsyncGenerator, Optional
from urllib.parse import urlencode

from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    InterimTranscriptionFrame,
    StartFrame,
    TranscriptionFrame,
    VADUserStartedSpeakingFrame,
    VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
from pipecat.services.deepgram.stt import DeepgramSTTSettings, LiveOptions
from pipecat.services.settings import STTSettings, _warn_deprecated_param, is_given
from pipecat.services.stt_latency import DEEPGRAM_SAGEMAKER_TTFS_P99
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt


@dataclass
class DeepgramSageMakerSTTSettings(DeepgramSTTSettings):
    """Runtime-updatable settings for the Deepgram SageMaker STT service.

    A thin alias over :class:`DeepgramSTTSettings`: every field is inherited
    unchanged. It exists so the SageMaker service exposes a settings type of
    its own that can diverge from the base class later if needed.
    """


class DeepgramSageMakerSTTService(STTService):
    """Deepgram speech-to-text service for AWS SageMaker.

    Provides real-time speech recognition using Deepgram models deployed on
    AWS SageMaker endpoints. Uses HTTP/2 bidirectional streaming for low-latency
    transcription with support for interim results, speaker diarization, and
    multiple languages.

    Requirements:

    - AWS credentials configured (via environment variables, AWS CLI, or instance metadata)
    - A deployed SageMaker endpoint with Deepgram model: https://developers.deepgram.com/docs/deploy-amazon-sagemaker

    Example::

        stt = DeepgramSageMakerSTTService(
            endpoint_name="my-deepgram-endpoint",
            region="us-east-2",
            settings=DeepgramSageMakerSTTSettings(
                model="nova-3",
                language="en",
                interim_results=True,
                punctuate=True,
            ),
        )
    """

    Settings = DeepgramSageMakerSTTSettings
    _settings: DeepgramSageMakerSTTSettings

    def __init__(
        self,
        *,
        endpoint_name: str,
        region: str,
        encoding: str = "linear16",
        channels: int = 1,
        multichannel: bool = False,
        sample_rate: Optional[int] = None,
        mip_opt_out: Optional[bool] = None,
        live_options: Optional[LiveOptions] = None,
        settings: Optional[DeepgramSageMakerSTTSettings] = None,
        ttfs_p99_latency: Optional[float] = DEEPGRAM_SAGEMAKER_TTFS_P99,
        **kwargs,
    ):
        """Initialize the Deepgram SageMaker STT service.

        Args:
            endpoint_name: Name of the SageMaker endpoint with Deepgram model
                deployed (e.g., "my-deepgram-nova-3-endpoint").
            region: AWS region where the endpoint is deployed (e.g., "us-east-2").
            encoding: Audio encoding format. Defaults to "linear16".
            channels: Number of audio channels. Defaults to 1.
            multichannel: Transcribe each audio channel independently.
                Defaults to False.
            sample_rate: Audio sample rate in Hz. If None, uses the pipeline
                sample rate.
            mip_opt_out: Opt out of Deepgram model improvement program.
            live_options: Legacy configuration options.

                .. deprecated:: 0.0.105
                    Use ``settings=DeepgramSageMakerSTTSettings(...)`` for
                    runtime-updatable fields and direct init parameters for
                    connection-level config.

            settings: Runtime-updatable settings. When provided alongside
                ``live_options``, ``settings`` values take precedence (applied
                after the ``live_options`` merge).
            ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
                Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
            **kwargs: Additional arguments passed to the parent STTService.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = DeepgramSageMakerSTTSettings(
            model="nova-3",
            language=Language.EN,
            detect_entities=False,
            diarize=False,
            dictation=False,
            endpointing=None,
            interim_results=True,
            keyterm=None,
            keywords=None,
            numerals=False,
            profanity_filter=True,
            punctuate=True,
            redact=None,
            replace=None,
            search=None,
            smart_format=False,
            utterance_end_ms=None,
            vad_events=False,
        )

        # 2. Apply live_options overrides — only if settings not provided.
        # live_options is deprecated; settings (when given) wins wholesale.
        if live_options is not None:
            _warn_deprecated_param("live_options", DeepgramSageMakerSTTSettings)
            if not settings:
                # Extract init-only (connection-level) fields from live_options.
                # An explicit sample_rate argument still takes precedence.
                if live_options.sample_rate is not None and sample_rate is None:
                    sample_rate = live_options.sample_rate
                if live_options.encoding is not None:
                    encoding = live_options.encoding
                if live_options.channels is not None:
                    channels = live_options.channels
                if live_options.multichannel is not None:
                    multichannel = live_options.multichannel
                if live_options.mip_opt_out is not None:
                    mip_opt_out = live_options.mip_opt_out

                # Build settings delta from the remaining (runtime-updatable) fields
                init_only = {
                    "sample_rate",
                    "encoding",
                    "channels",
                    "multichannel",
                    "mip_opt_out",
                }
                lo_dict = {k: v for k, v in live_options.to_dict().items() if k not in init_only}
                delta = DeepgramSageMakerSTTSettings.from_mapping(lo_dict)
                default_settings.apply_update(delta)

        # 3. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        # Sync extra to top-level fields so self._settings is unambiguous
        default_settings._sync_extra_to_fields()

        super().__init__(
            sample_rate=sample_rate,
            ttfs_p99_latency=ttfs_p99_latency,
            settings=default_settings,
            **kwargs,
        )

        self._endpoint_name = endpoint_name
        self._region = region

        # Init-only connection config (not runtime-updatable).
        self._encoding = encoding
        self._channels = channels
        self._multichannel = multichannel
        self._mip_opt_out = mip_opt_out

        # BiDi client and background tasks; created in _connect(), torn down
        # in _disconnect().
        self._client: Optional[SageMakerBidiClient] = None
        self._response_task: Optional[asyncio.Task] = None
        self._keepalive_task: Optional[asyncio.Task] = None

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Deepgram SageMaker service supports metrics generation.
        """
        return True

    async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
        """Apply a settings delta and warn about unhandled changes."""
        changed = await super()._update_settings(delta)

        if not changed:
            return changed

        # Sync extra to fields after the update so self._settings stays unambiguous
        if isinstance(self._settings, DeepgramSTTSettings):
            self._settings._sync_extra_to_fields()

        # TODO: someday we could reconnect here to apply updated settings.
        # Code might look something like the below:
        # await self._disconnect()
        # await self._connect()

        self._warn_unhandled_updated_settings(changed)

        return changed

    async def start(self, frame: StartFrame):
        """Start the Deepgram SageMaker STT service.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the Deepgram SageMaker STT service.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the Deepgram SageMaker STT service.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Send audio data to Deepgram for transcription.

        Args:
            audio: Raw audio bytes to transcribe.

        Yields:
            Frame: None (transcription results come via BiDi stream callbacks).
        """
        if self._client and self._client.is_active:
            try:
                await self._client.send_audio_chunk(audio)
            except Exception as e:
                yield ErrorFrame(error=f"Unknown error occurred: {e}")
        yield None

    def _build_query_string(self) -> str:
        """Build the URL query string from settings and init-only connection config.

        Keys and values are percent-encoded so values containing reserved URL
        characters (e.g. ``keywords`` boosts like ``word:10`` or ``search``
        phrases with spaces) produce a valid query string.
        """
        params = {}
        s = self._settings

        # Declared Deepgram-specific fields from settings
        for f in fields(s):
            if f.name in ("model", "language", "extra") or f.name.startswith("_"):
                continue
            value = getattr(s, f.name)
            if not is_given(value) or value is None:
                continue
            # Booleans must be lowercase "true"/"false" in the query string.
            params[f.name] = str(value).lower() if isinstance(value, bool) else str(value)

        # model and language
        if is_given(s.model) and s.model is not None:
            params["model"] = str(s.model)
        if is_given(s.language) and s.language is not None:
            params["language"] = str(s.language)

        # Init-only connection config
        params["encoding"] = self._encoding
        params["channels"] = str(self._channels)
        params["multichannel"] = str(self._multichannel).lower()
        params["sample_rate"] = str(self.sample_rate)

        if self._mip_opt_out is not None:
            params["mip_opt_out"] = str(self._mip_opt_out).lower()

        # Any remaining values in extra
        if s.extra:
            for key, value in s.extra.items():
                if value is not None:
                    params[key] = str(value).lower() if isinstance(value, bool) else str(value)

        # urlencode percent-escapes reserved characters; a raw "&".join of
        # k=v pairs would emit an invalid query for values with spaces, ':',
        # '&', etc.
        return urlencode(params)

    async def _connect(self):
        """Connect to the SageMaker endpoint and start the BiDi session.

        Builds the Deepgram query string from settings, creates the BiDi client,
        starts the streaming session, and launches background tasks for processing
        responses and sending KeepAlive messages.
        """
        logger.debug("Connecting to Deepgram on SageMaker...")

        query_string = self._build_query_string()

        # Create BiDi client
        self._client = SageMakerBidiClient(
            endpoint_name=self._endpoint_name,
            region=self._region,
            model_invocation_path="v1/listen",
            model_query_string=query_string,
        )

        try:
            # Start the session
            await self._client.start_session()

            # Start processing responses in the background
            self._response_task = self.create_task(self._process_responses())

            # Start keepalive task to maintain connection
            self._keepalive_task = self.create_task(self._send_keepalive())

            logger.debug("Connected to Deepgram on SageMaker")
            await self._call_event_handler("on_connected")

        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
            await self._call_event_handler("on_connection_error", str(e))

    async def _disconnect(self):
        """Disconnect from the SageMaker endpoint.

        Sends a CloseStream message to Deepgram, cancels background tasks
        (KeepAlive and response processing), and closes the BiDi session.
        Safe to call multiple times.
        """
        if self._client and self._client.is_active:
            logger.debug("Disconnecting from Deepgram on SageMaker...")

            # Send CloseStream message to Deepgram. Best-effort: the stream
            # may already be gone, in which case we just log and proceed with
            # local teardown.
            try:
                await self._client.send_json({"type": "CloseStream"})
            except Exception as e:
                logger.warning(f"Failed to send CloseStream message: {e}")

            # Cancel keepalive task
            if self._keepalive_task and not self._keepalive_task.done():
                await self.cancel_task(self._keepalive_task)

            # Cancel response processing task
            if self._response_task and not self._response_task.done():
                await self.cancel_task(self._response_task)

            # Close the BiDi session
            await self._client.close_session()

            logger.debug("Disconnected from Deepgram on SageMaker")
            await self._call_event_handler("on_disconnected")

    async def _send_keepalive(self):
        """Send periodic KeepAlive messages to maintain the connection.

        Sends a KeepAlive JSON message to Deepgram every 5 seconds while the
        connection is active. This prevents the connection from timing out during
        periods of silence.
        """
        while self._client and self._client.is_active:
            await asyncio.sleep(5)
            # Re-check after sleeping: the client may have closed in the interim.
            if self._client and self._client.is_active:
                try:
                    await self._client.send_json({"type": "KeepAlive"})
                except Exception as e:
                    logger.warning(f"Failed to send KeepAlive: {e}")

    async def _process_responses(self):
        """Process streaming responses from Deepgram on SageMaker.

        Continuously receives responses from the BiDi stream, decodes the payload,
        parses JSON responses from Deepgram, and processes transcription results.
        Runs as a background task until the connection is closed or cancelled.
        """
        try:
            while self._client and self._client.is_active:
                result = await self._client.receive_response()

                # None signals end of stream.
                if result is None:
                    break

                # Check if this is a PayloadPart with bytes
                if hasattr(result, "value") and hasattr(result.value, "bytes_"):
                    if result.value.bytes_:
                        response_data = result.value.bytes_.decode("utf-8")

                        try:
                            # Parse JSON response from Deepgram
                            parsed = json.loads(response_data)

                            # Extract and process transcript if available
                            if "channel" in parsed:
                                await self._handle_transcript_response(parsed)

                        except json.JSONDecodeError:
                            logger.warning(f"Non-JSON response: {response_data}")

        except asyncio.CancelledError:
            logger.debug("Response processor cancelled")
        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            logger.debug("Response processor stopped")

    async def _handle_transcript_response(self, parsed: dict):
        """Handle a transcript response from Deepgram.

        Extracts the transcript text, determines if it's final or interim, extracts
        language information, and pushes the appropriate frame (TranscriptionFrame
        or InterimTranscriptionFrame) downstream.

        Args:
            parsed: The parsed JSON response from Deepgram containing channel,
                alternatives, transcript, and metadata.
        """
        alternatives = parsed.get("channel", {}).get("alternatives", [])
        if not alternatives or not alternatives[0].get("transcript"):
            return

        transcript = alternatives[0]["transcript"]
        if not transcript.strip():
            return

        is_final = parsed.get("is_final", False)

        # Extract language if available. Guard the enum conversion: an
        # unrecognized code would otherwise raise ValueError and kill the
        # response-processing loop.
        language = None
        if alternatives[0].get("languages"):
            raw_language = alternatives[0]["languages"][0]
            try:
                language = Language(raw_language)
            except ValueError:
                logger.warning(f"Unrecognized language code from Deepgram: {raw_language}")

        if is_final:
            # Check if this response is from a finalize() call.
            # Only mark as finalized when both we requested it AND Deepgram confirms it.
            from_finalize = parsed.get("from_finalize", False)
            if from_finalize:
                self.confirm_finalize()
            await self.push_frame(
                TranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    language,
                    result=parsed,
                )
            )
            await self._handle_transcription(transcript, is_final, language)
            await self.stop_processing_metrics()
        else:
            # Interim transcription
            await self.push_frame(
                InterimTranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    language,
                    result=parsed,
                )
            )

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Optional[Language] = None
    ):
        """Handle a transcription result with tracing.

        This method is decorated with @traced_stt for observability and tracing
        integration. The actual transcription processing is handled by the parent
        class and observers.

        Args:
            transcript: The transcribed text.
            is_final: Whether this is a final transcription result.
            language: The detected language of the transcription, if available.
        """
        pass

    async def _start_metrics(self):
        """Start processing metrics collection."""
        await self.start_processing_metrics()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process frames with Deepgram SageMaker-specific handling.

        Args:
            frame: The frame to process.
            direction: The direction of frame processing.
        """
        await super().process_frame(frame, direction)

        # Start metrics when user starts speaking (if VAD is not provided by Deepgram)
        if isinstance(frame, VADUserStartedSpeakingFrame):
            await self._start_metrics()
        elif isinstance(frame, VADUserStoppedSpeakingFrame):
            # https://developers.deepgram.com/docs/finalize
            # Mark that we're awaiting a from_finalize response
            self.request_finalize()
            if self._client and self._client.is_active:
                try:
                    await self._client.send_json({"type": "Finalize"})
                except Exception as e:
                    logger.warning(f"Error sending Finalize message: {e}")
            logger.trace(f"Triggered finalize event on: {frame.name=}, {direction=}")
