#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""XTTS text-to-speech service implementation.

This module provides integration with Coqui XTTS streaming server for
text-to-speech synthesis using local Docker deployment.
"""

from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional

import aiohttp
from loguru import logger

from pipecat.audio.utils import create_stream_resampler
from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    StartFrame,
    TTSAudioRawFrame,
)
from pipecat.services.settings import TTSSettings, _warn_deprecated_param
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts

# The service below connects to an XTTS streaming server running in a
# local Docker container.
#
# Docker command: $ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest-cuda121
#
# You can find more information on the official repo:
# https://github.com/coqui-ai/xtts-streaming-server


def language_to_xtts_language(language: Language) -> Optional[str]:
    """Map a Language enum value onto the language code XTTS expects.

    Args:
        language: The Language enum value to convert.

    Returns:
        The matching XTTS language code, or None when XTTS does not
        support the given language.
    """
    supported_languages = {
        Language.CS: "cs",
        Language.DE: "de",
        Language.EN: "en",
        Language.ES: "es",
        Language.FR: "fr",
        Language.HI: "hi",
        Language.HU: "hu",
        Language.IT: "it",
        Language.JA: "ja",
        Language.KO: "ko",
        Language.NL: "nl",
        Language.PL: "pl",
        Language.PT: "pt",
        Language.RU: "ru",
        Language.TR: "tr",
        # XTTS uses the regional "zh-cn" code for the Chinese base language.
        Language.ZH: "zh-cn",
    }

    # use_base_code lets regional variants (e.g. EN_US) fall back to the
    # base language entry in the map.
    return resolve_language(language, supported_languages, use_base_code=True)


@dataclass
class XTTSTTSSettings(TTSSettings):
    """Runtime-updatable settings for :class:`XTTSService`.

    Inherits all fields from ``TTSSettings`` and adds none of its own.
    """


class XTTSService(TTSService):
    """Coqui XTTS text-to-speech service.

    Provides text-to-speech synthesis using a locally running Coqui XTTS
    streaming server. Supports multiple languages and voice cloning through
    studio speakers configuration.
    """

    Settings = XTTSTTSSettings
    _settings: XTTSTTSSettings

    # Sample rate (Hz) of the raw audio produced by the XTTS server.
    _XTTS_SAMPLE_RATE = 24000
    # Number of raw bytes to accumulate before resampling a slice. At
    # 24000 Hz this corresponds to roughly one second of 16-bit mono PCM.
    # NOTE(review): 16-bit output is assumed from the add_wav_header=False
    # raw stream — confirm against the XTTS streaming server docs.
    _BUFFER_BYTES = 48000

    def __init__(
        self,
        *,
        voice_id: Optional[str] = None,
        base_url: str,
        aiohttp_session: aiohttp.ClientSession,
        language: Language = Language.EN,
        sample_rate: Optional[int] = None,
        settings: Optional[XTTSTTSSettings] = None,
        **kwargs,
    ):
        """Initialize the XTTS service.

        Args:
            voice_id: ID of the voice/speaker to use for synthesis.

                .. deprecated:: 0.0.105
                    Use ``settings=XTTSTTSSettings(voice=...)`` instead.

            base_url: Base URL of the XTTS streaming server.
            aiohttp_session: HTTP session for making requests to the server.
            language: Language for synthesis. Defaults to English.
            sample_rate: Audio sample rate. If None, uses default.
            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional arguments passed to parent TTSService.
        """
        # 1. Initialize default_settings with hardcoded defaults.
        default_settings = XTTSTTSSettings(
            model=None,
            voice=None,
            language=self.language_to_service_language(language),
        )

        # 2. Apply direct init arg overrides (deprecated).
        if voice_id is not None:
            _warn_deprecated_param("voice_id", XTTSTTSSettings, "voice")
            default_settings.voice = voice_id

        # 3. Apply settings delta (canonical API, always wins).
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            push_start_frame=True,
            push_stop_frames=True,
            settings=default_settings,
            **kwargs,
        )

        # Init-only fields (not runtime-updatable).
        self._base_url = base_url

        # Loaded lazily in start(); maps speaker name -> embeddings dict.
        self._studio_speakers: Optional[Dict[str, Any]] = None
        self._aiohttp_session = aiohttp_session

        self._resampler = create_stream_resampler()

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as XTTS service supports metrics generation.
        """
        return True

    def language_to_service_language(self, language: Language) -> Optional[str]:
        """Convert a Language enum to XTTS service language format.

        Args:
            language: The language to convert.

        Returns:
            The XTTS-specific language code, or None if not supported.
        """
        return language_to_xtts_language(language)

    async def start(self, frame: StartFrame):
        """Start the XTTS service and load studio speakers.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)

        # Speakers are fetched once; subsequent starts reuse the cache.
        if self._studio_speakers:
            return

        async with self._aiohttp_session.get(self._base_url + "/studio_speakers") as r:
            if r.status != 200:
                error_body = await r.text()
                await self.push_error(
                    error_msg=f"Error getting studio speakers (status: {r.status}, error: {error_body})"
                )
                return
            self._studio_speakers = await r.json()

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using XTTS streaming server.

        Args:
            text: The text to synthesize into speech.
            context_id: The context ID for tracking audio frames.

        Yields:
            Frame: Audio frames containing the synthesized speech, or an
            ErrorFrame if the voice is unknown or the server request fails.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        if not self._studio_speakers:
            logger.error(f"{self} no studio speakers available")
            return

        # Guard the lookup: an unknown (or unset) voice would otherwise
        # raise an unhandled KeyError instead of a graceful error.
        embeddings = self._studio_speakers.get(self._settings.voice)
        if embeddings is None:
            logger.error(f"{self} voice [{self._settings.voice}] not found in studio speakers")
            yield ErrorFrame(error=f"Voice [{self._settings.voice}] not found in studio speakers")
            return

        url = self._base_url + "/tts_stream"

        payload = {
            # Periods and asterisks are stripped before sending — presumably
            # to avoid synthesis artifacts; confirm against XTTS behavior.
            "text": text.replace(".", "").replace("*", ""),
            "language": self._settings.language,
            "speaker_embedding": embeddings["speaker_embedding"],
            "gpt_cond_latent": embeddings["gpt_cond_latent"],
            "add_wav_header": False,
            "stream_chunk_size": 20,
        }

        async with self._aiohttp_session.post(url, json=payload) as r:
            if r.status != 200:
                # Use a distinct name so the `text` parameter isn't shadowed.
                error_body = await r.text()
                yield ErrorFrame(
                    error=f"Error getting audio (status: {r.status}, error: {error_body})"
                )
                return

            await self.start_tts_usage_metrics(text)

            buffer = bytearray()
            async for chunk in r.content.iter_chunked(self.chunk_size):
                if len(chunk) > 0:
                    await self.stop_ttfb_metrics()
                    # Append new chunk to the buffer.
                    buffer.extend(chunk)

                    # Emit fixed-size slices once enough audio accumulated.
                    while len(buffer) >= self._BUFFER_BYTES:
                        process_data = bytes(buffer[: self._BUFFER_BYTES])
                        # Remove processed data from buffer.
                        del buffer[: self._BUFFER_BYTES]

                        # XTTS outputs 24000 Hz audio, so resample to the
                        # pipeline's desired rate.
                        resampled_audio = await self._resampler.resample(
                            process_data, self._XTTS_SAMPLE_RATE, self.sample_rate
                        )
                        yield TTSAudioRawFrame(
                            resampled_audio, self.sample_rate, 1, context_id=context_id
                        )

            # Flush any remaining data in the buffer.
            if len(buffer) > 0:
                resampled_audio = await self._resampler.resample(
                    bytes(buffer), self._XTTS_SAMPLE_RATE, self.sample_rate
                )
                yield TTSAudioRawFrame(
                    resampled_audio, self.sample_rate, 1, context_id=context_id
                )
