#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""AWS Transcribe Speech-to-Text service implementation.

This module provides a WebSocket-based connection to AWS Transcribe for real-time
speech-to-text transcription with support for multiple languages and audio formats.
"""

import json
import os
import random
import string
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Optional

from loguru import logger

from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    ErrorFrame,
    Frame,
    InterimTranscriptionFrame,
    StartFrame,
    TranscriptionFrame,
)
from pipecat.services.aws.utils import build_event_message, decode_event, get_presigned_url
from pipecat.services.settings import STTSettings, _warn_deprecated_param
from pipecat.services.stt_latency import AWS_TRANSCRIBE_TTFS_P99
from pipecat.services.stt_service import WebsocketSTTService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt

try:
    from websockets.asyncio.client import connect as websocket_connect
    from websockets.protocol import State
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use AWS services, you need to `pip install pipecat-ai[aws]`.")
    raise Exception(f"Missing module: {e}")


@dataclass
class AWSTranscribeSTTSettings(STTSettings):
    """Runtime-updatable settings for ``AWSTranscribeSTTService``.

    This subclass declares no fields of its own; every setting (for example
    ``language`` and ``model``) is inherited from ``STTSettings``.
    """


class AWSTranscribeSTTService(WebsocketSTTService):
    """AWS Transcribe Speech-to-Text service using WebSocket streaming.

    Provides real-time speech transcription using AWS Transcribe's streaming API.
    Supports multiple languages, configurable sample rates, and both interim and
    final transcription results.
    """

    Settings = AWSTranscribeSTTSettings
    _settings: AWSTranscribeSTTSettings

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region: Optional[str] = None,
        sample_rate: Optional[int] = None,
        language: Optional[Language] = None,
        settings: Optional[AWSTranscribeSTTSettings] = None,
        ttfs_p99_latency: Optional[float] = AWS_TRANSCRIBE_TTFS_P99,
        **kwargs,
    ):
        """Initialize the AWS Transcribe STT service.

        Args:
            api_key: AWS secret access key. If None, uses AWS_SECRET_ACCESS_KEY environment variable.
            aws_access_key_id: AWS access key ID. If None, uses AWS_ACCESS_KEY_ID environment variable.
            aws_session_token: AWS session token for temporary credentials. If None, uses AWS_SESSION_TOKEN environment variable.
            region: AWS region for the service. If None, uses the AWS_REGION
                environment variable, falling back to ``us-east-1``.
            sample_rate: Audio sample rate in Hz. If None, uses the pipeline sample rate.
                AWS Transcribe only supports 8000 or 16000 Hz; other values are
                clamped to 16000 Hz at connect time.
            language: Language for transcription.

                .. deprecated:: 0.0.105
                    Use ``settings=AWSTranscribeSTTSettings(language=...)`` instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            ttfs_p99_latency: P99 latency from speech end to final transcript in seconds.
                Override for your deployment. See https://github.com/pipecat-ai/stt-benchmark
            **kwargs: Additional arguments passed to parent STTService class.
        """
        # 1. Initialize default_settings with hardcoded defaults
        default_settings = AWSTranscribeSTTSettings(
            model=None,
            language=self.language_to_service_language(Language.EN),
        )

        # 2. Apply direct init arg overrides (deprecated)
        if language is not None:
            _warn_deprecated_param("language", AWSTranscribeSTTSettings, "language")
            default_settings.language = self.language_to_service_language(language)

        # 3. (No step 3, as there's no params object to apply)

        # 4. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            ttfs_p99_latency=ttfs_p99_latency,
            settings=default_settings,
            **kwargs,
        )

        # Init-only connection config (not runtime-updatable).
        self._media_encoding = "linear16"
        self._number_of_channels = 1
        self._show_speaker_label = False
        self._enable_channel_identification = False

        self._credentials = {
            "aws_access_key_id": aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            "aws_secret_access_key": api_key or os.getenv("AWS_SECRET_ACCESS_KEY"),
            "aws_session_token": aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
            "region": region or os.getenv("AWS_REGION", "us-east-1"),
        }

        self._receive_task = None

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as AWS Transcribe STT supports metrics generation.
        """
        return True

    def get_service_encoding(self, encoding: str) -> str:
        """Convert internal encoding format to AWS Transcribe format.

        Args:
            encoding: Internal encoding format string.

        Returns:
            AWS Transcribe compatible encoding format. Unknown encodings are
            passed through unchanged.
        """
        encoding_map = {
            "linear16": "pcm",  # AWS expects "pcm" for 16-bit linear PCM
        }
        return encoding_map.get(encoding, encoding)

    async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
        """Apply a settings delta and reconnect if anything changed.

        The language is baked into the presigned URL, so a live connection must
        be re-established for changes to take effect.

        Args:
            delta: Settings delta to apply.

        Returns:
            The changes reported by the parent implementation.
        """
        changed = await super()._update_settings(delta)

        if changed and self._websocket:
            await self._disconnect()
            await self._connect()

        return changed

    async def start(self, frame: StartFrame):
        """Initialize the connection when the service starts.

        Args:
            frame: Start frame signaling service initialization.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the service and disconnect from AWS Transcribe.

        Args:
            frame: End frame signaling service shutdown.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the service and disconnect from AWS Transcribe.

        Args:
            frame: Cancel frame signaling service cancellation.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Process audio data and send to AWS Transcribe.

        Audio is silently dropped when the websocket is not open.

        Args:
            audio: Raw audio bytes to transcribe.

        Yields:
            ErrorFrame: If processing fails or connection issues occur.
        """
        if self._websocket and self._websocket.state is State.OPEN:
            try:
                # Format the audio data according to AWS event stream format
                event_message = build_event_message(audio)

                # Send the formatted event message
                await self._websocket.send(event_message)
                # Processing metrics are (re)started after every chunk sent and
                # stopped when a final transcript arrives.
                await self.start_processing_metrics()
            except Exception as e:
                yield ErrorFrame(error=f"Error sending audio: {e}")

        yield None

    async def _connect(self):
        """Connect to the AWS Transcribe service.

        Establishes websocket connection and starts receive task.
        """
        await super()._connect()

        await self._connect_websocket()

        if self._websocket and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Disconnect from the AWS Transcribe service.

        Cancels the receive task, signals end-of-stream by sending an empty
        event-stream audio event, then closes the websocket.
        """
        await super()._disconnect()

        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        # AWS Transcribe's streaming protocol only understands event-stream
        # encoded binary messages; the end of the audio stream is signalled with
        # an empty audio event (same encoder as run_stt), not a JSON text frame.
        if self._websocket and self._websocket.state is State.OPEN:
            try:
                await self._websocket.send(build_event_message(b""))
            except Exception as e:
                await self.push_error(error_msg=f"Error sending end-stream: {e}", exception=e)

        await self._disconnect_websocket()

    async def _connect_websocket(self):
        """Establish the websocket connection to AWS Transcribe.

        Does nothing if a connection is already open.

        Raises:
            ValueError: If no language code is configured.
            Exception: Any connection error is reported via ``push_error`` and
                then re-raised.
        """
        try:
            if self._websocket and self._websocket.state is State.OPEN:
                return

            logger.debug("Connecting to AWS Transcribe WebSocket")

            language_code = self._settings.language
            if not language_code:
                raise ValueError(f"Unsupported language: {language_code}")

            # Validate sample rate — AWS Transcribe only supports 8000 or 16000 Hz
            connect_sample_rate = self.sample_rate
            if connect_sample_rate not in (8000, 16000):
                logger.warning(
                    f"AWS Transcribe only supports 8000 Hz or 16000 Hz sample rates. "
                    f"Converting from {connect_sample_rate} Hz to 16000 Hz."
                )
                connect_sample_rate = 16000

            # Generate random websocket key
            websocket_key = "".join(
                random.choices(
                    string.ascii_uppercase + string.ascii_lowercase + string.digits, k=20
                )
            )

            # Add required headers
            additional_headers = {
                "Origin": "https://localhost",
                "Sec-WebSocket-Key": websocket_key,
                "Sec-WebSocket-Version": "13",
                "Connection": "keep-alive",
            }

            # Get presigned URL
            presigned_url = get_presigned_url(
                region=self._credentials["region"],
                credentials={
                    "access_key": self._credentials["aws_access_key_id"],
                    "secret_key": self._credentials["aws_secret_access_key"],
                    "session_token": self._credentials["aws_session_token"],
                },
                language_code=language_code,
                media_encoding=self.get_service_encoding(
                    self._media_encoding
                ),  # Convert to AWS format
                sample_rate=connect_sample_rate,
                number_of_channels=self._number_of_channels,
                enable_partial_results_stabilization=True,
                partial_results_stability="high",
                show_speaker_label=self._show_speaker_label,
                enable_channel_identification=self._enable_channel_identification,
            )

            logger.debug(f"{self} Connecting to WebSocket with URL: {presigned_url[:100]}...")

            # Connect with the required headers and settings
            self._websocket = await websocket_connect(
                presigned_url,
                additional_headers=additional_headers,
                subprotocols=["mqtt"],
                ping_interval=None,
                ping_timeout=None,
                compression=None,
            )

            await self._call_event_handler("on_connected")
            logger.info(f"{self} Successfully connected to AWS Transcribe")
        except Exception as e:
            await self.push_error(
                error_msg=f"Unable to connect to AWS Transcribe: {e}", exception=e
            )
            raise

    async def _disconnect_websocket(self):
        """Close the websocket connection to AWS Transcribe.

        Always clears ``self._websocket`` and fires ``on_disconnected``, even
        if closing fails.
        """
        try:
            if self._websocket:
                logger.debug("Disconnecting from AWS Transcribe WebSocket")
                await self._websocket.close()
        except Exception as e:
            await self.push_error(error_msg=f"Error closing websocket: {e}", exception=e)
        finally:
            self._websocket = None
            await self._call_event_handler("on_disconnected")

    def language_to_service_language(self, language: Language) -> str | None:
        """Convert internal language enum to AWS Transcribe language code.

        Source:
        https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html
        All language codes that support streaming are included.

        Args:
            language: Internal language enumeration value.

        Returns:
            AWS Transcribe compatible language code, or None if unsupported.
        """
        LANGUAGE_MAP = {
            # Afrikaans
            Language.AF: "af-ZA",
            Language.AF_ZA: "af-ZA",
            # Arabic
            Language.AR: "ar-SA",  # Default to Modern Standard Arabic
            Language.AR_AE: "ar-AE",  # Gulf Arabic
            Language.AR_SA: "ar-SA",  # Modern Standard Arabic
            # Basque
            Language.EU: "eu-ES",
            Language.EU_ES: "eu-ES",
            # Catalan
            Language.CA: "ca-ES",
            Language.CA_ES: "ca-ES",
            # Chinese
            Language.ZH: "zh-CN",  # Default to Simplified
            Language.ZH_CN: "zh-CN",  # Simplified
            Language.ZH_TW: "zh-TW",  # Traditional
            Language.ZH_HK: "zh-HK",  # Cantonese (also yue-HK)
            Language.YUE: "zh-HK",  # Cantonese fallback
            # Croatian
            Language.HR: "hr-HR",
            Language.HR_HR: "hr-HR",
            # Czech
            Language.CS: "cs-CZ",
            Language.CS_CZ: "cs-CZ",
            # Danish
            Language.DA: "da-DK",
            Language.DA_DK: "da-DK",
            # Dutch
            Language.NL: "nl-NL",
            Language.NL_NL: "nl-NL",
            # English
            Language.EN: "en-US",  # Default to US
            Language.EN_AU: "en-AU",  # Australian
            Language.EN_GB: "en-GB",  # British
            Language.EN_IN: "en-IN",  # Indian
            Language.EN_IE: "en-IE",  # Irish
            Language.EN_NZ: "en-NZ",  # New Zealand
            # Note: Scottish (en-AB) and Welsh (en-WL) don't have direct Language enum matches
            Language.EN_ZA: "en-ZA",  # South African
            Language.EN_US: "en-US",  # US
            # Persian/Farsi
            Language.FA: "fa-IR",
            Language.FA_IR: "fa-IR",
            # Finnish
            Language.FI: "fi-FI",
            Language.FI_FI: "fi-FI",
            # French
            Language.FR: "fr-FR",  # Default to France
            Language.FR_FR: "fr-FR",
            Language.FR_CA: "fr-CA",  # Canadian
            # Galician
            Language.GL: "gl-ES",
            Language.GL_ES: "gl-ES",
            # Georgian
            Language.KA: "ka-GE",
            Language.KA_GE: "ka-GE",
            # German
            Language.DE: "de-DE",  # Default to Germany
            Language.DE_DE: "de-DE",
            Language.DE_CH: "de-CH",  # Swiss
            # Greek
            Language.EL: "el-GR",
            Language.EL_GR: "el-GR",
            # Hebrew
            Language.HE: "he-IL",
            Language.HE_IL: "he-IL",
            # Hindi
            Language.HI: "hi-IN",
            Language.HI_IN: "hi-IN",
            # Indonesian
            Language.ID: "id-ID",
            Language.ID_ID: "id-ID",
            # Italian
            Language.IT: "it-IT",
            Language.IT_IT: "it-IT",
            # Japanese
            Language.JA: "ja-JP",
            Language.JA_JP: "ja-JP",
            # Korean
            Language.KO: "ko-KR",
            Language.KO_KR: "ko-KR",
            # Latvian
            Language.LV: "lv-LV",
            Language.LV_LV: "lv-LV",
            # Malay
            Language.MS: "ms-MY",
            Language.MS_MY: "ms-MY",
            # Norwegian
            Language.NB: "no-NO",  # Norwegian Bokmål
            Language.NB_NO: "no-NO",
            Language.NO: "no-NO",
            # Polish
            Language.PL: "pl-PL",
            Language.PL_PL: "pl-PL",
            # Portuguese
            Language.PT: "pt-PT",  # Default to Portugal
            Language.PT_PT: "pt-PT",
            Language.PT_BR: "pt-BR",  # Brazilian
            # Romanian
            Language.RO: "ro-RO",
            Language.RO_RO: "ro-RO",
            # Russian
            Language.RU: "ru-RU",
            Language.RU_RU: "ru-RU",
            # Serbian
            Language.SR: "sr-RS",
            Language.SR_RS: "sr-RS",
            # Slovak
            Language.SK: "sk-SK",
            Language.SK_SK: "sk-SK",
            # Somali
            Language.SO: "so-SO",
            Language.SO_SO: "so-SO",
            # Spanish
            Language.ES: "es-ES",  # Default to Spain
            Language.ES_ES: "es-ES",
            Language.ES_US: "es-US",  # US Spanish
            # Swedish
            Language.SV: "sv-SE",
            Language.SV_SE: "sv-SE",
            # Tagalog/Filipino
            Language.TL: "tl-PH",
            Language.FIL: "tl-PH",  # Filipino maps to Tagalog
            Language.FIL_PH: "tl-PH",
            # Thai
            Language.TH: "th-TH",
            Language.TH_TH: "th-TH",
            # Ukrainian
            Language.UK: "uk-UA",
            Language.UK_UA: "uk-UA",
            # Vietnamese
            Language.VI: "vi-VN",
            Language.VI_VN: "vi-VN",
            # Zulu
            Language.ZU: "zu-ZA",
            Language.ZU_ZA: "zu-ZA",
        }

        return resolve_language(language, LANGUAGE_MAP, use_base_code=False)

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Optional[str] = None
    ):
        """No-op hook wrapped by ``traced_stt`` for tracing final transcripts.

        The body intentionally does nothing; the call exists so that the
        decorator can observe its arguments.

        Args:
            transcript: The transcribed text.
            is_final: Whether the transcript is final.
            language: Language code the transcript was produced with.
        """
        pass

    def _get_websocket(self):
        """Get the current WebSocket connection.

        Returns:
            The WebSocket connection.

        Raises:
            Exception: If WebSocket is not connected.
        """
        if self._websocket:
            return self._websocket
        raise Exception("Websocket not connected")

    async def _receive_messages(self):
        """Receive and process websocket messages.

        Continuously processes messages from the websocket connection. A failure
        while handling one message is logged and does not stop the loop.
        """
        async for response in self._get_websocket():
            try:
                headers, payload = decode_event(response)
                message_type = headers.get(":message-type")

                if message_type == "event":
                    await self._process_transcript_event(payload)
                elif message_type == "exception":
                    error_msg = payload.get("Message", "Unknown error")
                    await self.push_error(error_msg=f"AWS Transcribe error: {error_msg}")
                else:
                    logger.debug(f"{self} Other message type received: {headers}")
                    logger.debug(f"{self} Payload: {payload}")
            except Exception as e:
                logger.warning(f"Error processing message: {e}")

    async def _process_transcript_event(self, payload):
        """Push an interim or final transcription frame for a transcript event.

        Only the first result and its first alternative are used; events with
        no results, no alternatives, or an empty transcript are ignored.

        Args:
            payload: Decoded event payload containing ``Transcript.Results``.
        """
        results = payload.get("Transcript", {}).get("Results", [])
        if not results:
            return

        result = results[0]
        alternatives = result.get("Alternatives", [])
        if not alternatives:
            return

        transcript = alternatives[0].get("Transcript", "")
        if not transcript:
            return

        # A missing IsPartial flag is treated as partial (not final).
        is_final = not result.get("IsPartial", True)

        if is_final:
            await self.push_frame(
                TranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    self._settings.language,
                    result=result,
                )
            )
            await self._handle_transcription(
                transcript,
                is_final,
                self._settings.language,
            )
            await self.stop_processing_metrics()
        else:
            await self.push_frame(
                InterimTranscriptionFrame(
                    transcript,
                    self._user_id,
                    time_now_iso8601(),
                    self._settings.language,
                    result=result,
                )
            )
