# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License

"""Hume Text-to-Speech service implementation."""

import base64
import os
import warnings
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Optional

import httpx
from loguru import logger
from pydantic import BaseModel

from pipecat import version as pipecat_version
from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    Frame,
    InterruptionFrame,
    StartFrame,
    TTSAudioRawFrame,
    TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import NOT_GIVEN, TTSSettings, _NotGiven, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
from pipecat.utils.tracing.service_decorators import traced_tts

try:
    from hume import AsyncHumeClient
    from hume.tts import FormatPcm, PostedUtterance, PostedUtteranceVoiceWithId
    from hume.tts.types import TimestampMessage
except ModuleNotFoundError as e:  # pragma: no cover - import-time guidance
    logger.error(f"Exception: {e}")
    logger.error("In order to use Hume, you need to `pip install pipecat-ai[hume]`.")
    raise Exception(f"Missing module: {e}")


# Native output rate of Hume TTS streaming (48 kHz). This module does not resample.
HUME_SAMPLE_RATE = 48000

# Client-identification headers. They are attached to the shared httpx client that
# HumeTTSService.__init__ hands to the Hume SDK, so every SDK request carries them.
DEFAULT_HEADERS = dict(
    (
        ("X-Hume-Client-Name", "pipecat"),
        ("X-Hume-Client-Version", pipecat_version()),
    )
)


@dataclass
class HumeTTSSettings(TTSSettings):
    """Settings for HumeTTSService.

    Every Hume-specific field defaults to ``NOT_GIVEN`` rather than ``None``.
    ``None`` is a real value: ``HumeTTSService.run_tts`` omits ``None`` fields
    from the request so Hume applies its own defaults. ``NOT_GIVEN`` presumably
    marks "leave unchanged" when an instance is applied as an update delta via
    ``apply_update``. NOTE(review): confirm against ``TTSSettings.apply_update``.

    Parameters:
        description: Natural-language acting directions (up to 100 characters).
        speed: Speaking-rate multiplier (0.5-2.0).
        trailing_silence: Seconds of silence to append at the end (0-5).
    """

    # Acting directions. When not None, run_tts requests Octave version "1" instead of "2".
    description: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    # Speaking-rate multiplier, forwarded to Hume only when not None.
    speed: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
    # Seconds of trailing silence, forwarded to Hume only when not None.
    trailing_silence: float | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)


class HumeTTSService(TTSService):
    """Hume Octave Text-to-Speech service.

    Streams PCM audio via Hume's HTTP output streaming (JSON chunks) endpoint
    using the Python SDK and emits ``TTSAudioRawFrame`` frames suitable for Pipecat transports.

    Supported features:

    - Generates speech from text using Hume TTS.
    - Streams PCM audio.
    - Supports word-level timestamps for precise audio-text synchronization.
    - Supports dynamic updates of voice and synthesis parameters at runtime.
    - Provides metrics for Time To First Byte (TTFB) and TTS usage.
    """

    Settings = HumeTTSSettings
    _settings: HumeTTSSettings

    class InputParams(BaseModel):
        """Optional synthesis parameters for Hume TTS.

        .. deprecated:: 0.0.105
            Use ``settings=HumeTTSSettings(...)`` instead.

        Parameters:
            description: Natural-language acting directions (up to 100 characters).
            speed: Speaking-rate multiplier (0.5-2.0).
            trailing_silence: Seconds of silence to append at the end (0-5).
        """

        description: Optional[str] = None
        speed: Optional[float] = None
        trailing_silence: Optional[float] = None

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        voice_id: Optional[str] = None,
        params: Optional[InputParams] = None,
        sample_rate: Optional[int] = HUME_SAMPLE_RATE,
        settings: Optional[HumeTTSSettings] = None,
        **kwargs,
    ) -> None:
        """Initialize the HumeTTSService.

        Args:
            api_key: Hume API key. If omitted, reads the ``HUME_API_KEY`` environment variable.
            voice_id: ID of the voice to use. Only voice IDs are supported; voice names are not.

                .. deprecated:: 0.0.105
                    Use ``settings=HumeTTSSettings(voice=...)`` instead.

            params: Optional synthesis controls (acting instructions, speed, trailing silence).

                .. deprecated:: 0.0.105
                    Use ``settings=HumeTTSSettings(...)`` instead.

            sample_rate: Sample rate tagged on emitted PCM frames. Defaults to 48_000 (Hume).
                This service does not resample Hume's 48 kHz audio, so any other value
                only triggers a warning. NOTE(review): confirm downstream handling.
            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to the parent class.

        Raises:
            ValueError: If no API key is given and ``HUME_API_KEY`` is unset.
        """
        api_key = api_key or os.getenv("HUME_API_KEY")
        if not api_key:
            raise ValueError("HumeTTSService requires an API key (env HUME_API_KEY or api_key=)")

        # Hume audio is forwarded as-is and frames are tagged with self.sample_rate,
        # so a mismatch here means the frame's declared rate differs from the real one.
        if sample_rate != HUME_SAMPLE_RATE:
            logger.warning(
                f"Hume TTS streams at {HUME_SAMPLE_RATE} Hz; configured sample_rate={sample_rate}"
            )

        # 1. Initialize default_settings with hardcoded defaults
        default_settings = HumeTTSSettings(
            model=None,
            voice=None,
            language=None,  # Not applicable here
            description=None,
            speed=None,
            trailing_silence=None,
        )

        # 2. Apply direct init arg overrides (deprecated)
        if voice_id is not None:
            _warn_deprecated_param("voice_id", HumeTTSSettings, "voice")
            default_settings.voice = voice_id

        # 3. Apply params overrides — only if settings not provided.
        # An explicit ``is None`` check matches step 4 and does not depend on the
        # truthiness of a settings object.
        if params is not None:
            _warn_deprecated_param("params", HumeTTSSettings)
            if settings is None:
                default_settings.description = params.description
                default_settings.speed = params.speed
                default_settings.trailing_silence = params.trailing_silence

        # 4. Apply settings delta (canonical API, always wins)
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            push_text_frames=False,
            push_stop_frames=True,
            push_start_frame=True,
            settings=default_settings,
            **kwargs,
        )

        # Create a custom httpx.AsyncClient with tracking headers
        # Headers are included in all requests made by the Hume SDK
        self._http_client = httpx.AsyncClient(headers=DEFAULT_HEADERS)

        self._client = AsyncHumeClient(api_key=api_key, httpx_client=self._http_client)

        # PCM bytes received from Hume but not yet emitted as a frame.
        self._audio_bytes = b""

        # End time (seconds) of the last word of previous utterances. It is used as
        # the offset for word timestamps, which Hume reports per utterance.
        self._cumulative_time = 0.0

    def can_generate_metrics(self) -> bool:
        """Can generate metrics.

        Returns:
            True if metrics can be generated, False otherwise.
        """
        return True

    async def start(self, frame: StartFrame) -> None:
        """Start the service.

        Args:
            frame: The start frame.
        """
        await super().start(frame)
        self._reset_state()

    def _reset_state(self):
        """Reset the word-timestamp offset to zero."""
        self._cumulative_time = 0.0

    async def _close_http_client(self) -> None:
        """Close the shared httpx client used by the Hume SDK, if it was created."""
        http_client = getattr(self, "_http_client", None)
        if http_client:
            await http_client.aclose()

    async def stop(self, frame: EndFrame) -> None:
        """Stop the service and cleanup resources.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._close_http_client()

    async def cancel(self, frame: CancelFrame) -> None:
        """Cancel the service and cleanup resources.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._close_http_client()

    async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        """Push a frame and reset word timing on interruption or end of speech.

        Args:
            frame: The frame to push.
            direction: The direction to push the frame.
        """
        await super().push_frame(frame, direction)
        if isinstance(frame, (InterruptionFrame, TTSStoppedFrame)):
            # Reset timing on interruption or stop
            self._reset_state()

            if isinstance(frame, TTSStoppedFrame):
                # Sentinel entry; presumably tells the base class to reset its word
                # timing. NOTE(review): confirm against TTSService.add_word_timestamps.
                await self.add_word_timestamps([("Reset", 0)])

    async def update_setting(self, key: str, value: Any) -> None:
        """Runtime updates via key/value pair.

        .. deprecated:: 0.0.104
            Use ``TTSUpdateSettingsFrame(delta=HumeTTSSettings(...))`` instead.

        Args:
            key: The name of the setting to update (case-insensitive). Recognized keys are:
                - "voice_id" (or "voice")
                - "description"
                - "speed"
                - "trailing_silence"
                Unrecognized keys are logged and ignored.
            value: The new value for the setting. ``None`` clears the setting.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                "'update_setting' is deprecated, use "
                "'TTSUpdateSettingsFrame(delta=HumeTTSSettings(...))' instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        key_l = (key or "").lower()

        # ``None`` is passed through unchanged for every key. Previously str(None)
        # turned a cleared voice into the literal voice ID "None".
        if key_l in ("voice_id", "voice"):
            delta = HumeTTSSettings(voice=None if value is None else str(value))
        elif key_l == "description":
            delta = HumeTTSSettings(description=None if value is None else str(value))
        elif key_l == "speed":
            delta = HumeTTSSettings(speed=None if value is None else float(value))
        elif key_l == "trailing_silence":
            delta = HumeTTSSettings(trailing_silence=None if value is None else float(value))
        else:
            logger.warning(f"{self}: ignoring unknown Hume TTS setting '{key}'")
            return

        await self._update_settings(delta)

    def _build_utterance(self, text: str) -> PostedUtterance:
        """Build the Hume utterance for ``text`` from the current settings.

        Optional fields are sent only when they are not None, so Hume applies
        its own defaults for anything unset.
        """
        utterance_kwargs: dict[str, Any] = {"text": text}
        # Omit ``voice`` entirely when no ID is configured instead of sending id=None.
        # NOTE(review): Hume documents ``voice`` as optional (it can design a voice
        # from ``description``); confirm against the Hume TTS API reference.
        if self._settings.voice is not None:
            utterance_kwargs["voice"] = PostedUtteranceVoiceWithId(id=self._settings.voice)
        if self._settings.description is not None:
            utterance_kwargs["description"] = self._settings.description
        if self._settings.speed is not None:
            utterance_kwargs["speed"] = self._settings.speed
        if self._settings.trailing_silence is not None:
            utterance_kwargs["trailing_silence"] = self._settings.trailing_silence
        return PostedUtterance(**utterance_kwargs)

    def _make_audio_frame(self, audio: bytes, context_id: str) -> TTSAudioRawFrame:
        """Wrap buffered mono PCM bytes in a frame tagged with ``self.sample_rate``."""
        return TTSAudioRawFrame(
            audio=audio,
            sample_rate=self.sample_rate,
            num_channels=1,
            context_id=context_id,
        )

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Hume TTS with word timestamps.

        Word timestamps are offset by the end time of the last word of earlier
        utterances. The offset is reset on ``TTSStoppedFrame`` or
        ``InterruptionFrame`` (see ``push_frame``).

        Args:
            text: The text to be synthesized.
            context_id: Unique identifier for this TTS context.

        Yields:
            ``TTSAudioRawFrame`` objects. Each holds at least ``chunk_size`` bytes,
            except the final flush. Start/stop frames are not yielded here, because
            this service enables ``push_start_frame`` / ``push_stop_frames`` on the
            base class. Failures are reported via ``push_error`` rather than raised.
        """
        logger.debug(f"{self}: Generating Hume TTS: [{text}]")

        try:
            # Built inside the try so invalid settings are reported via push_error
            # instead of escaping the generator.
            utterance = self._build_utterance(text)

            # Request raw PCM chunks in the streaming JSON
            pcm_fmt = FormatPcm(type="pcm")

            await self.start_tts_usage_metrics(text)

            # Instant mode is always enabled here (not user-configurable)
            # Hume emits mono PCM at 48 kHz; downstream can resample if needed.
            # We buffer audio bytes before sending to prevent glitches.
            self._audio_bytes = b""

            # Use version "2" by default if no description is provided
            # Version "1" is needed when description is used
            version = "1" if self._settings.description is not None else "2"

            # Largest absolute word end time seen in this utterance
            utterance_duration = 0.0

            async for chunk in self._client.tts.synthesize_json_streaming(
                utterances=[utterance],
                format=pcm_fmt,
                instant_mode=True,
                version=version,
                include_timestamp_types=["word"],  # Request word-level timestamps
            ):
                # Process audio chunks
                audio_b64 = getattr(chunk, "audio", None)
                if audio_b64:
                    await self.stop_ttfb_metrics()
                    self._audio_bytes += base64.b64decode(audio_b64)

                    # Buffer audio until we have enough to avoid glitches
                    if len(self._audio_bytes) >= self.chunk_size:
                        yield self._make_audio_frame(self._audio_bytes, context_id)
                        self._audio_bytes = b""

                # Process word timestamp messages
                if isinstance(chunk, TimestampMessage) and chunk.timestamp.type == "word":
                    timestamp = chunk.timestamp
                    # Convert milliseconds to seconds and add cumulative offset
                    word_start_time = self._cumulative_time + (timestamp.time.begin / 1000.0)
                    word_end_time = self._cumulative_time + (timestamp.time.end / 1000.0)

                    # Track the maximum end time for this utterance
                    utterance_duration = max(utterance_duration, word_end_time)

                    await self.add_word_timestamps([(timestamp.text, word_start_time)], context_id)

            # Flush any remaining audio bytes
            if self._audio_bytes:
                yield self._make_audio_frame(self._audio_bytes, context_id)
                self._audio_bytes = b""

            # utterance_duration already includes the previous offset, so it
            # becomes the new absolute offset for the next utterance.
            if utterance_duration > 0:
                self._cumulative_time = utterance_duration

        except Exception as e:
            await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
        finally:
            # Drop any partial audio left behind by a failed stream.
            self._audio_bytes = b""
            # Ensure TTFB timer is stopped even on early failures
            await self.stop_ttfb_metrics()