#!/usr/bin/env python3
"""
Fantastic Sound Factory — Agente de voz en tiempo real
Pipecat + LiveKit + Whisper STT + DashScope Qwen (LLM + TTS)
"""

import asyncio
import base64
import json
import os
import uuid
from typing import AsyncGenerator

import aiohttp
from dotenv import load_dotenv
from loguru import logger

load_dotenv("/opt/pipecat/.env")

# ── Pipecat ───────────────────────────────────────────────────────────────────
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
    Frame,
    TTSAudioRawFrame,
    TTSStartedFrame,
    TTSStoppedFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.qwen.llm import QwenLLMService
from pipecat.services.tts_service import TTSService
from pipecat.services.whisper.stt import WhisperSTTService, Model
from pipecat.transports.livekit.transport import LiveKitParams, LiveKitTransport

# ── LiveKit token ─────────────────────────────────────────────────────────────
from livekit import api as livekit_api

# ── Config ────────────────────────────────────────────────────────────────────
# Required settings: read with os.environ[...], so a missing variable raises
# KeyError as soon as this module is loaded.
DASHSCOPE_API_KEY  = os.environ["DASHSCOPE_API_KEY"]
LIVEKIT_URL        = os.environ["LIVEKIT_URL"]
LIVEKIT_API_KEY    = os.environ["LIVEKIT_API_KEY"]
LIVEKIT_API_SECRET = os.environ["LIVEKIT_API_SECRET"]
# Optional settings: fall back to the defaults below when the variable is unset.
N8N_BASE_URL       = os.getenv("N8N_BASE_URL", "https://n8n.fantasticsoundfactory.com")
ROOM_NAME          = os.getenv("LIVEKIT_ROOM_NAME", "fsf-test")

# DashScope endpoints: the WebSocket URL is used by the TTS service and the
# HTTP base URL is handed to the Qwen LLM service in main().
DASHSCOPE_WS_URL  = "wss://dashscope-intl.aliyuncs.com/api-ws/v1/inference/"
DASHSCOPE_LLM_URL = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"

# System prompt (in Spanish) used as the first message of the LLM context.
SYSTEM_PROMPT = """Eres el narrador de Fantastic Sound Factory, una novela interactiva de ficción con voz.
Tu rol es narrar la historia de forma inmersiva, responder a las acciones del jugador y mantener la narrativa.
Habla de forma descriptiva y teatral en español. Usa las tools disponibles para obtener el estado del juego
y resolver las acciones del jugador antes de narrar el resultado."""


# ── DashScope TTS (WebSocket real-time) ───────────────────────────────────────
class DashScopeTTSService(TTSService):
    """Real-time TTS over WebSocket — DashScope qwen3-tts-vc-realtime.

    Every ``run_tts`` call opens its own WebSocket connection to
    ``DASHSCOPE_WS_URL``, submits the whole text as a single task and yields
    the audio it receives as ``TTSAudioRawFrame`` objects.
    """

    def __init__(self, *, api_key: str, **kwargs):
        """Store the credentials and the fixed synthesis settings.

        Args:
            api_key: DashScope API key, sent as a Bearer token on connect.
            **kwargs: Forwarded unchanged to ``TTSService.__init__``.
        """
        super().__init__(
            push_stop_frames=True,
            **kwargs,
        )
        self._api_key = api_key
        self._model   = "qwen3-tts-vc-realtime-2026-01-15"
        self._voice   = "Chelsie"
        self._sample_rate = 24000

    async def run_tts(self, text: str, context_id: str = None) -> AsyncGenerator[Frame, None]:
        """Synthesize ``text`` and stream the resulting audio frames.

        The frame sequence is always ``TTSStartedFrame``, zero or more
        ``TTSAudioRawFrame``, then ``TTSStoppedFrame``. Errors are logged and
        never propagated; the start/stop pair is still emitted exactly once.

        Args:
            text: Text to synthesize.
            context_id: Not used by this implementation.

        Yields:
            The start frame, mono audio frames at ``self._sample_rate`` and
            the stop frame.
        """
        # Local import: websockets is only needed by this method.
        from websockets.asyncio.client import connect

        task_id = uuid.uuid4().hex
        logger.debug(f"DashScope TTS: síntesis de '{text[:60]}...' task_id={task_id}")
        # Remembers whether TTSStartedFrame was already emitted, so the error
        # path below does not emit a second one when the failure happens
        # mid-stream.
        started = False
        try:
            async with connect(
                DASHSCOPE_WS_URL,
                additional_headers={"Authorization": f"Bearer {self._api_key}"},
            ) as ws:
                # 1) Open the task and declare the synthesis parameters.
                await ws.send(json.dumps({
                    "header": {
                        "action": "run-task",
                        "task_id": task_id,
                        "streaming": "duplex",
                    },
                    "payload": {
                        "task_group": "audio",
                        "task": "tts",
                        "function": "SpeechSynthesizer",
                        "model": self._model,
                        "parameters": {
                            "text_type": "PlainText",
                            "voice": self._voice,
                            "format": "pcm",
                            "sample_rate": self._sample_rate,
                            "volume": 50,
                            "rate": 1.0,
                            "pitch": 1.0,
                        },
                        "input": {},
                    },
                }))
                # 2) Send the full text in one message.
                await ws.send(json.dumps({
                    "header": {"action": "continue-task", "task_id": task_id, "streaming": "duplex"},
                    "payload": {"input": {"text": text}},
                }))
                # 3) Signal that no more text will follow.
                await ws.send(json.dumps({
                    "header": {"action": "finish-task", "task_id": task_id, "streaming": "duplex"},
                    "payload": {"input": {}},
                }))

                started = True
                yield TTSStartedFrame()
                audio_chunks = 0
                async for msg in ws:
                    if isinstance(msg, bytes) and msg:
                        # Binary messages are forwarded as raw audio.
                        audio_chunks += 1
                        yield TTSAudioRawFrame(
                            audio=msg,
                            sample_rate=self._sample_rate,
                            num_channels=1,
                        )
                    elif isinstance(msg, str):
                        data = json.loads(msg)
                        event = data.get("header", {}).get("event", "")
                        # Some models return base64 audio inside the JSON message.
                        audio_b64 = (
                            data.get("payload", {})
                                .get("output", {})
                                .get("audio")
                        )
                        if audio_b64:
                            audio_bytes = base64.b64decode(audio_b64)
                            if audio_bytes:
                                audio_chunks += 1
                                yield TTSAudioRawFrame(
                                    audio=audio_bytes,
                                    sample_rate=self._sample_rate,
                                    num_channels=1,
                                )
                        # Both terminal events end the read loop; a failure is
                        # logged with the full server message.
                        if event in ("task-finished", "task-failed"):
                            if event == "task-failed":
                                logger.error(f"DashScope TTS task-failed: {data}")
                            break
                logger.debug(f"DashScope TTS: {audio_chunks} chunks de audio recibidos")
                yield TTSStoppedFrame()
        except Exception as e:
            logger.error(f"DashScope TTS error: {e}")
            # Keep the frame stream balanced: emit the start frame only if
            # the failure happened before it was sent, then always close.
            if not started:
                yield TTSStartedFrame()
            yield TTSStoppedFrame()


# ── N8N tool handlers (nuevo API: FunctionCallParams) ─────────────────────────
async def _call_n8n(path: str, session_id: str, extra: dict = None) -> str:
    """POST a JSON payload to an N8N webhook and return the reply as a JSON string.

    Args:
        path: Webhook name appended to ``{N8N_BASE_URL}/webhook/``.
        session_id: Value sent under the ``session_id`` key.
        extra: Optional additional keys merged into the payload.

    Returns:
        The response body re-serialized as JSON text, or a JSON object with
        an ``error`` key when the request or the decoding fails.
    """
    endpoint = f"{N8N_BASE_URL}/webhook/{path}"
    body = {"session_id": session_id}
    body.update(extra or {})
    try:
        # The whole request is capped at 10 seconds.
        limit = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession() as session, session.post(
            endpoint, json=body, timeout=limit
        ) as response:
            decoded = await response.json()
            return json.dumps(decoded, ensure_ascii=False)
    except Exception as e:
        # Best effort: report the failure to the caller as JSON instead of raising.
        logger.error(f"N8N {path} error: {e}")
        return json.dumps({"error": str(e)})


async def handle_resolve_action(params: FunctionCallParams) -> None:
    """Forward a ``resolve_action`` tool call to the ``resolve-action`` webhook.

    Reads ``session_id`` and ``action`` from the call arguments (empty string
    when absent) and passes the webhook reply to ``params.result_callback``.
    """
    call_args = params.arguments
    session_id = call_args.get("session_id", "")
    body = {"action": call_args.get("action", "")}
    reply = await _call_n8n("resolve-action", session_id, body)
    await params.result_callback(reply)

async def handle_search_story(params: FunctionCallParams) -> None:
    """Forward a ``search_story`` tool call to the ``agent-rag`` webhook.

    Reads ``session_id`` and ``query`` from the call arguments (empty string
    when absent) and passes the webhook reply to ``params.result_callback``.
    """
    call_args = params.arguments
    session_id = call_args.get("session_id", "")
    body = {"query": call_args.get("query", "")}
    reply = await _call_n8n("agent-rag", session_id, body)
    await params.result_callback(reply)

async def handle_get_config(params: FunctionCallParams) -> None:
    """Forward a ``get_config`` tool call to the ``get-config`` webhook.

    Only ``session_id`` (empty string when absent) is sent; the webhook reply
    is passed to ``params.result_callback``.
    """
    session_id = params.arguments.get("session_id", "")
    reply = await _call_n8n("get-config", session_id)
    await params.result_callback(reply)

async def handle_get_game_state(params: FunctionCallParams) -> None:
    """Forward a ``get_game_state`` tool call to the ``get-game-state`` webhook.

    Only ``session_id`` (empty string when absent) is sent; the webhook reply
    is passed to ``params.result_callback``.
    """
    session_id = params.arguments.get("session_id", "")
    reply = await _call_n8n("get-game-state", session_id)
    await params.result_callback(reply)


# ── Tools schema ───────────────────────────────────────────────────────────────
def _string_args_tool(name: str, description: str, *arg_names: str) -> dict:
    """Build one function-tool schema whose arguments are all required strings."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {arg: {"type": "string"} for arg in arg_names},
                "required": list(arg_names),
            },
        },
    }


# Tool schemas offered to the LLM; every tool takes the session id, and two of
# them take one extra string argument.
TOOLS = [
    _string_args_tool(
        "resolve_action",
        "Resuelve una acción del jugador y devuelve el resultado narrativo.",
        "session_id",
        "action",
    ),
    _string_args_tool(
        "search_story",
        "Busca información en la base de conocimiento de la historia.",
        "session_id",
        "query",
    ),
    _string_args_tool(
        "get_config",
        "Obtiene la configuración del juego para la sesión.",
        "session_id",
    ),
    _string_args_tool(
        "get_game_state",
        "Obtiene el estado actual del juego.",
        "session_id",
    ),
]


# ── Main ────────────────────────────────────────────────────────────────────────
async def main():
    """Assemble the voice pipeline for ROOM_NAME and run it.

    Creates a LiveKit access token for the narrator, wires transport, STT,
    LLM (with the N8N tool handlers) and TTS into one pipeline, and hands the
    resulting task to a ``PipelineRunner``.
    """
    # Access token for the narrator identity, granted room-join on ROOM_NAME.
    room_grants = livekit_api.VideoGrants(room_join=True, room=ROOM_NAME)
    access_token = livekit_api.AccessToken(LIVEKIT_API_KEY, LIVEKIT_API_SECRET)
    access_token.with_identity("fsf-narrator").with_name("FSF Narrator").with_grants(room_grants)
    signed_jwt = access_token.to_jwt()

    # Audio in/out over LiveKit, with Silero VAD on the input.
    livekit_params = LiveKitParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        vad_analyzer=SileroVADAnalyzer(),
    )
    transport = LiveKitTransport(
        url=LIVEKIT_URL,
        token=signed_jwt,
        room_name=ROOM_NAME,
        params=livekit_params,
    )

    stt = WhisperSTTService(model=Model.SMALL, language="es", device="cpu")

    llm = QwenLLMService(
        api_key=DASHSCOPE_API_KEY,
        base_url=DASHSCOPE_LLM_URL,
        model="qwen-plus",
    )
    # Tool name -> handler, registered in the same order as declared here.
    tool_handlers = {
        "resolve_action": handle_resolve_action,
        "search_story": handle_search_story,
        "get_config": handle_get_config,
        "get_game_state": handle_get_game_state,
    }
    for tool_name, tool_handler in tool_handlers.items():
        llm.register_function(tool_name, tool_handler)

    tts = DashScopeTTSService(api_key=DASHSCOPE_API_KEY)

    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    context = OpenAILLMContext(messages=[system_message], tools=TOOLS)
    context_aggregator = llm.create_context_aggregator(context)

    stages = [
        transport.input(),
        stt,
        context_aggregator.user(),
        llm,
        tts,
        transport.output(),
        context_aggregator.assistant(),
    ]
    task = PipelineTask(
        Pipeline(stages),
        params=PipelineParams(allow_interruptions=True),
    )

    @transport.event_handler("on_first_participant_joined")
    async def _player_joined(_transport, participant_id):
        logger.info(f"Jugador conectado: {participant_id}")
        # Queue the current context frame once the first player is in the room.
        opening_frame = context_aggregator.user().get_context_frame()
        await task.queue_frames([opening_frame])

    @transport.event_handler("on_participant_disconnected")
    async def _player_left(_transport, participant_id):
        logger.info(f"Jugador desconectado: {participant_id}")

    await PipelineRunner().run(task)


# Script entry point: start the agent's event loop only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
