#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Soniox speech-to-text service implementation."""

import json
import time
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, List, Optional

from loguru import logger
from pydantic import BaseModel

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    Frame,
    InterimTranscriptionFrame,
    StartFrame,
    TranscriptionFrame,
    VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.stt_latency import SONIOX_TTFS_P99
from pipecat.services.stt_service import WebsocketSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt

try:
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use Soniox, you need to `pip install pipecat-ai[soniox]`.")
    raise Exception(f"Missing module: {e}")


# JSON control message sent to keep the websocket alive while no audio is streaming.
KEEPALIVE_MESSAGE = '{"type": "keepalive"}'

# JSON control message asking Soniox to finalize pending tokens immediately.
FINALIZE_MESSAGE = '{"type": "finalize"}'

# Token text marking an endpoint (end of an utterance) in the token stream.
END_TOKEN = "<end>"

# Token text acknowledging a finalize request.
FINALIZED_TOKEN = "<fin>"


class SonioxContextGeneralItem(BaseModel):
    """Represents a key-value pair for structured general context information."""

    # Context attribute name.
    key: str
    # Context attribute value.
    value: str


class SonioxContextTranslationTerm(BaseModel):
    """Represents a custom translation mapping for ambiguous or domain-specific terms."""

    # Term as it appears in the source language.
    source: str
    # Desired rendering of the term in the target language.
    target: str


class SonioxContextObject(BaseModel):
    """Context object for models with context_version 2, for Soniox stt-rt-v3-preview and higher.

    Learn more about context in the documentation:
    https://soniox.com/docs/stt/concepts/context
    """

    # Structured key-value context entries.
    general: Optional[List[SonioxContextGeneralItem]] = None
    # Free-form context text.
    text: Optional[str] = None
    # Individual context terms.
    terms: Optional[List[str]] = None
    # Custom translation mappings for ambiguous or domain-specific terms.
    translation_terms: Optional[List[SonioxContextTranslationTerm]] = None


class SonioxInputParams(BaseModel):
    """Real-time transcription settings.

    .. deprecated:: 0.0.105
        Use ``settings=SonioxSTTSettings(...)`` instead.

    See Soniox WebSocket API documentation for more details:
    https://soniox.com/docs/speech-to-text/api-reference/websocket-api#configuration-parameters

    Parameters:
        model: Model to use for transcription.
        audio_format: Audio format to use for transcription.
        num_channels: Number of channels to use for transcription.
        language_hints: List of language hints to use for transcription.
        language_hints_strict: If true, strictly enforce language hints (only transcribe in provided languages).
        context: Customization for transcription. String for models with context_version 1 and ContextObject for models with context_version 2.
        enable_speaker_diarization: Whether to enable speaker diarization. Tokens are annotated with speaker IDs.
        enable_language_identification: Whether to enable language identification. Tokens are annotated with language IDs.
        client_reference_id: Client reference ID to use for transcription.
    """

    model: str = "stt-rt-v4"

    # Raw audio encoding and channel count for the websocket audio stream.
    audio_format: Optional[str] = "pcm_s16le"
    num_channels: Optional[int] = 1

    # Language biasing: hints (with optional strict enforcement) and context.
    language_hints: Optional[List[Language]] = None
    language_hints_strict: Optional[bool] = None
    context: Optional[SonioxContextObject | str] = None

    # Optional per-token annotations (speaker IDs / language IDs).
    enable_speaker_diarization: Optional[bool] = False
    enable_language_identification: Optional[bool] = False

    client_reference_id: Optional[str] = None


def is_end_token(token: dict) -> bool:
    """Determine if a token is an end token.

    Args:
        token: A Soniox token dict containing at least a ``"text"`` key.

    Returns:
        True if the token's text marks an utterance endpoint (``<end>``) or a
        finalize acknowledgement (``<fin>``), False otherwise.
    """
    # Membership test instead of chained equality comparisons.
    return token["text"] in (END_TOKEN, FINALIZED_TOKEN)


def language_to_soniox_language(language: Language) -> str:
    """Map a Pipecat ``Language`` to the corresponding Soniox language code.

    Pipecat's Language enum uses the same ISO 2-letter codes as Soniox, except
    with added regional variants, so any region suffix is stripped here.

    For a list of all supported languages, see: https://soniox.com/docs/speech-to-text/core-concepts/supported-languages
    """
    code = str(language.value).lower()
    # Drop any regional suffix ("en-us" -> "en"). partition() returns the
    # original string unchanged when no "-" is present.
    return code.partition("-")[0]


def _prepare_language_hints(
    language_hints: Optional[List[Language]],
) -> Optional[List[str]]:
    """Convert Pipecat languages to Soniox language codes, removing duplicates.

    Args:
        language_hints: Languages to hint to the transcription model, or None.

    Returns:
        Deduplicated Soniox language codes in first-occurrence order, or None
        if no hints were provided.
    """
    if language_hints is None:
        return None

    prepared_languages = [language_to_soniox_language(lang) for lang in language_hints]
    # Remove duplicates (in case of language_hints with multiple regions) while
    # preserving the caller's hint order; list(set(...)) would yield a
    # nondeterministic ordering and lose the hint priority.
    return list(dict.fromkeys(prepared_languages))


@dataclass
class SonioxSTTSettings(STTSettings):
    """Settings for SonioxSTTService.

    Fields default to the ``NOT_GIVEN`` sentinel so that a settings delta can
    distinguish "leave unchanged" (NOT_GIVEN) from an explicit ``None``.

    Parameters:
        language_hints: List of language hints to use for transcription.
        language_hints_strict: If true, strictly enforce language hints.
        context: Customization for transcription. String for models with
            context_version 1 and SonioxContextObject for models with
            context_version 2.
        enable_speaker_diarization: Whether to enable speaker diarization.
        enable_language_identification: Whether to enable language identification.
        client_reference_id: Client reference ID to use for transcription.
    """

    # default_factory is used so each field independently defaults to the
    # NOT_GIVEN sentinel instance.
    language_hints: List[Language] | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    language_hints_strict: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    context: SonioxContextObject | str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    enable_speaker_diarization: bool | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    enable_language_identification: bool | None | _NotGiven = field(
        default_factory=lambda: NOT_GIVEN
    )
    client_reference_id: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)


class SonioxSTTService(WebsocketSTTService):
    """Speech-to-Text service using Soniox's WebSocket API.

    This service connects to Soniox's WebSocket API for real-time transcription
    with support for multiple languages, custom context, speaker diarization,
    and more.

    For complete API documentation, see: https://soniox.com/docs/speech-to-text/api-reference/websocket-api
    """

    Settings = SonioxSTTSettings
    _settings: SonioxSTTSettings

    def __init__(
        self,
        *,
        api_key: str,
        url: str = "wss://stt-rt.soniox.com/transcribe-websocket",
        sample_rate: Optional[int] = None,
        model: Optional[str] = None,
        audio_format: str = "pcm_s16le",
        num_channels: int = 1,
        params: Optional[SonioxInputParams] = None,
        vad_force_turn_endpoint: bool = True,
        settings: Optional[SonioxSTTSettings] = None,
        ttfs_p99_latency: Optional[float] = SONIOX_TTFS_P99,
        **kwargs,
    ):
        """Initialize the Soniox STT service.

        Args:
            api_key: Soniox API key.
            url: Soniox WebSocket API URL.
            sample_rate: Audio sample rate.
            model: Soniox model to use for transcription.

                .. deprecated:: 0.0.105
                    Use ``settings=SonioxSTTSettings(model=...)`` instead.

            audio_format: Audio format for transcription. Defaults to ``"pcm_s16le"``.
            num_channels: Number of audio channels. Defaults to 1.
            params: Additional configuration parameters, such as language hints, context and
                speaker diarization.

                .. deprecated:: 0.0.105
                    Use ``settings=SonioxSTTSettings(...)`` instead.

            vad_force_turn_endpoint: Listen to `VADUserStoppedSpeakingFrame` to send finalize message to Soniox.
                If disabled, Soniox will detect the end of the speech. Defaults to True.
            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
                Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
            **kwargs: Additional arguments passed to the STTService.
        """
        # --- 1. Hardcoded defaults ---
        default_settings = SonioxSTTSettings(
            model="stt-rt-v4",
            language=None,
            language_hints=None,
            language_hints_strict=None,
            context=None,
            enable_speaker_diarization=False,
            enable_language_identification=False,
            client_reference_id=None,
        )

        # --- 2. Deprecated direct-arg overrides ---
        if model is not None:
            _warn_deprecated_param("model", SonioxSTTSettings, "model")
            default_settings.model = model

        # --- 3. Deprecated params overrides ---
        if params is not None:
            _warn_deprecated_param("params", SonioxSTTSettings)
            if not settings:
                default_settings.model = params.model
                if params.audio_format is not None:
                    audio_format = params.audio_format
                if params.num_channels is not None:
                    num_channels = params.num_channels
                default_settings.language_hints = params.language_hints
                default_settings.language_hints_strict = params.language_hints_strict
                default_settings.context = params.context
                default_settings.enable_speaker_diarization = params.enable_speaker_diarization
                default_settings.enable_language_identification = (
                    params.enable_language_identification
                )
                default_settings.client_reference_id = params.client_reference_id

        # --- 4. Settings delta (canonical API, always wins) ---
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            ttfs_p99_latency=ttfs_p99_latency,
            keepalive_timeout=1,
            keepalive_interval=5,
            settings=default_settings,
            **kwargs,
        )

        self._api_key = api_key
        self._url = url
        self._vad_force_turn_endpoint = vad_force_turn_endpoint

        # Init-only audio config
        self._audio_format = audio_format
        self._num_channels = num_channels

        self._final_transcription_buffer = []
        self._last_tokens_received: Optional[float] = None

        self._receive_task = None

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as Soniox STT supports metrics generation.
        """
        return True

    async def start(self, frame: StartFrame):
        """Start the Soniox STT websocket connection.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        await self._connect()

    async def _update_settings(self, delta: SonioxSTTSettings) -> dict[str, Any]:
        """Apply settings delta and reconnect if anything changed.

        Args:
            delta: A settings delta.

        Returns:
            Dict mapping changed field names to their previous values.
        """
        changed = await super()._update_settings(delta)

        if changed:
            await self._disconnect()
            await self._connect()

        return changed

    async def stop(self, frame: EndFrame):
        """Stop the Soniox STT websocket connection.

        Stopping waits for the server to close the connection as we might receive
        additional final tokens after sending the stop recording message.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._send_stop_recording()
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the Soniox STT websocket connection.

        Compared to stop, this method closes the connection immediately without waiting
        for the server to close it. This is useful when we want to stop the connection
        immediately without waiting for the server to send any final tokens.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Send audio data to Soniox STT Service.

        Args:
            audio: Raw audio bytes to transcribe.

        Yields:
            Frame: None (transcription results come via WebSocket callbacks).
        """
        if self._websocket and self._websocket.state is State.OPEN:
            await self._websocket.send(audio)

        yield None

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Optional[Language] = None
    ):
        """Handle a transcription result with tracing."""
        pass

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Processes a frame of audio data, either buffering or transcribing it.

        Args:
            frame: The frame to process.
            direction: The direction of frame processing.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, VADUserStoppedSpeakingFrame) and self._vad_force_turn_endpoint:
            # Send finalize message to Soniox so we get the final tokens asap.
            if self._websocket and self._websocket.state is State.OPEN:
                await self._websocket.send(FINALIZE_MESSAGE)
                logger.debug(f"Triggered finalize event on: {frame.name=}, {direction=}")

    async def _send_stop_recording(self):
        """Send stop recording message to Soniox."""
        if self._websocket and self._websocket.state is State.OPEN:
            # Send stop recording message
            await self._websocket.send("")

    async def _connect(self):
        """Connect to the Soniox service.

        Establishes websocket connection and starts receive and keepalive tasks.
        """
        await self._connect_websocket()

        await super()._connect()

        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Disconnect from the Soniox service.

        Cleans up tasks and closes websocket connection.
        """
        await super()._disconnect()

        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _connect_websocket(self):
        """Establish the websocket connection to Soniox."""
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return

            logger.debug("Connecting to Soniox STT")

            self._websocket = await websocket_connect(self._url)

            if not self._websocket:
                await self.push_error(error_msg=f"Unable to connect to Soniox API at {self._url}")
                raise Exception(f"Unable to connect to Soniox API at {self._url}")

            # If vad_force_turn_endpoint is not enabled, we need to enable endpoint detection.
            # Either one or the other is required.
            enable_endpoint_detection = not self._vad_force_turn_endpoint

            s = self._settings

            context = s.context
            if isinstance(context, SonioxContextObject):
                context = context.model_dump()

            # Send the initial configuration message.
            config = {
                "api_key": self._api_key,
                "model": s.model,
                "audio_format": self._audio_format,
                "num_channels": self._num_channels,
                "enable_endpoint_detection": enable_endpoint_detection,
                "sample_rate": self.sample_rate,
                "language_hints": _prepare_language_hints(s.language_hints),
                "language_hints_strict": s.language_hints_strict,
                "context": context,
                "enable_speaker_diarization": s.enable_speaker_diarization,
                "enable_language_identification": s.enable_language_identification,
                "client_reference_id": s.client_reference_id,
            }

            # Send the configuration message.
            await self._websocket.send(json.dumps(config))

            await self._call_event_handler("on_connected")
            logger.debug("Connected to Soniox STT")
        except Exception as e:
            await self.push_error(error_msg=f"Unable to connect to Soniox: {e}", exception=e)
            raise

    async def _disconnect_websocket(self):
        """Close the websocket connection to Soniox."""
        try:
            if self._websocket:
                logger.debug("Disconnecting from Soniox STT")
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
        finally:
            # Always clear the reference and notify, even if close() failed.
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def _get_websocket(self):
        """Get the current WebSocket connection.

        Returns:
            The WebSocket connection.

        Raises:
            Exception: If WebSocket is not connected.
        """
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _receive_messages(self):
        """Receive and process websocket messages.

        Continuously processes messages from the websocket connection.
        """
        # Transcription frame will be only sent after we get the "endpoint" event.
        self._final_transcription_buffer = []

        async def send_endpoint_transcript():
            # Flush buffered final tokens as one TranscriptionFrame; no-op when
            # the buffer is empty.
            if self._final_transcription_buffer:
                text = "".join(map(lambda token: token["text"], self._final_transcription_buffer))
                # Soniox only pushes TranscriptionFrame when an end token is received,
                # so every TranscriptionFrame is inherently finalized
                await self.push_frame(
                    TranscriptionFrame(
                        text=text,
                        user_id=self._user_id,
                        timestamp=time_now_iso8601(),
                        result=self._final_transcription_buffer,
                        finalized=True,
                    )
                )
                await self._handle_transcription(text, is_final=True)
                await self.stop_processing_metrics()
                self._final_transcription_buffer = []

        async for message in self._get_websocket():
            try:
                content = json.loads(message)

                tokens = content["tokens"]

                if tokens:
                    if len(tokens) == 1 and tokens[0]["text"] == FINALIZED_TOKEN:
                        # Ignore finalized token, prevent auto-finalize cycling.
                        pass
                    else:
                        # Got at least one token, so we can reset the auto finalize delay.
                        self._last_tokens_received = time.time()

                # We will only send the final tokens after we get the "endpoint" event.
                non_final_transcription = []

                for token in tokens:
                    if token["is_final"]:
                        if is_end_token(token):
                            # Found an endpoint, tokens until here will be sent as transcript,
                            # the rest will be sent as interim tokens (even final tokens).
                            await send_endpoint_transcript()
                        else:
                            # First buffered final token starts the metrics window.
                            if not self._final_transcription_buffer:
                                await self.start_processing_metrics()
                            self._final_transcription_buffer.append(token)
                    else:
                        non_final_transcription.append(token)

                if self._final_transcription_buffer or non_final_transcription:
                    final_text = "".join(
                        map(lambda token: token["text"], self._final_transcription_buffer)
                    )
                    non_final_text = "".join(
                        map(lambda token: token["text"], non_final_transcription)
                    )

                    await self.push_frame(
                        InterimTranscriptionFrame(
                            # Even final tokens are sent as interim tokens as we want to send
                            # nicely formatted messages - therefore waiting for the endpoint.
                            text=final_text + non_final_text,
                            user_id=self._user_id,
                            timestamp=time_now_iso8601(),
                            result=self._final_transcription_buffer + non_final_transcription,
                        )
                    )

                error_code = content.get("error_code")
                error_message = content.get("error_message")
                if error_code or error_message:
                    # In case of error, still send the final transcript (if any remaining in the buffer).
                    await send_endpoint_transcript()
                    await self.push_error(
                        error_msg=f"Error: {error_code} (_receive_messages) - {error_message}"
                    )

                finished = content.get("finished")
                if finished:
                    # When finished, still send the final transcript (if any remaining in the buffer).
                    await send_endpoint_transcript()
                    logger.debug("Transcription finished.")
                    return

            except json.JSONDecodeError:
                logger.warning(f"Received non-JSON message: {message}")
            except Exception as e:
                logger.warning(f"Error processing message: {e}")

    async def _send_keepalive(self, silence: bytes):
        """Send a Soniox protocol-level keepalive message.

        Args:
            silence: Silent PCM audio bytes (unused, Soniox uses a protocol message).
        """
        await self._websocket.send(KEEPALIVE_MESSAGE)
