#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Mistral LLM service implementation using OpenAI-compatible interface."""

from dataclasses import dataclass
from typing import List, Optional, Sequence

from loguru import logger
from openai.types.chat import ChatCompletionMessageParam

from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.frames.frames import FunctionCallFromLLM
from pipecat.services.openai.base_llm import OpenAILLMSettings
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.settings import _warn_deprecated_param


@dataclass
class MistralLLMSettings(OpenAILLMSettings):
    """Runtime-updatable settings for :class:`MistralLLMService`.

    Currently identical to ``OpenAILLMSettings``; exists as a distinct type so
    Mistral-specific settings can be added later without breaking callers.
    """


class MistralLLMService(OpenAILLMService):
    """A service for interacting with Mistral's API using the OpenAI-compatible interface.

    This service extends OpenAILLMService to connect to Mistral's API endpoint while
    maintaining full compatibility with OpenAI's interface and functionality. It also
    applies Mistral-specific message fixups and parameter mappings (e.g. ``random_seed``
    instead of ``seed``).
    """

    # Canonical settings type for this service; used by the base class machinery.
    Settings = MistralLLMSettings
    _settings: MistralLLMSettings

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = "https://api.mistral.ai/v1",
        model: Optional[str] = None,
        settings: Optional[MistralLLMSettings] = None,
        **kwargs,
    ):
        """Initialize the Mistral LLM service.

        Args:
            api_key: The API key for accessing Mistral's API.
            base_url: The base URL for Mistral API. Defaults to "https://api.mistral.ai/v1".
            model: The model identifier to use. Defaults to "mistral-small-latest".

                .. deprecated:: 0.0.105
                    Use ``settings=MistralLLMSettings(model=...)`` instead.

            settings: Runtime-updatable settings. When provided alongside deprecated
                parameters, ``settings`` values take precedence.
            **kwargs: Additional keyword arguments passed to OpenAILLMService.
        """
        # 1. Start from hardcoded defaults.
        default_settings = MistralLLMSettings(model="mistral-small-latest")

        # 2. Apply deprecated direct init-arg overrides.
        if model is not None:
            _warn_deprecated_param("model", MistralLLMSettings, "model")
            default_settings.model = model

        # 3. Apply the settings delta last: the canonical API always wins over
        #    deprecated parameters.
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(api_key=api_key, base_url=base_url, settings=default_settings, **kwargs)

    def create_client(self, api_key=None, base_url=None, **kwargs):
        """Create OpenAI-compatible client for Mistral API endpoint.

        Args:
            api_key: The API key for authentication. If None, uses instance key.
            base_url: The base URL for the API. If None, uses instance URL.
            **kwargs: Additional arguments passed to the client constructor.

        Returns:
            An OpenAI-compatible client configured for Mistral API.
        """
        logger.debug(f"Creating Mistral client with base URL {base_url}")
        return super().create_client(api_key, base_url, **kwargs)

    def _apply_mistral_fixups(
        self, messages: List[ChatCompletionMessageParam]
    ) -> List[ChatCompletionMessageParam]:
        """Apply fixups to messages to meet Mistral-specific requirements.

        1. A "tool"-role message must be followed by an assistant message.

        2. "system"-role messages must only appear at the start of a
           conversation.

        3. Assistant messages must have prefix=True when they are the final
           message in a conversation (but at no other point).

        Args:
            messages: The original list of messages.

        Returns:
            A new message list with the Mistral requirements above applied. The
            input list and its dicts are not mutated.
        """
        if not messages:
            return messages

        # Shallow-copy each message so we never mutate the caller's context.
        fixed_messages = [dict(msg) for msg in messages]

        # Fixup 1: every tool response must be followed by an assistant message.
        # Collect insertion points first, then insert in reverse so earlier
        # indices remain valid.
        assistant_insert_indices = []
        for i, msg in enumerate(fixed_messages):
            if msg.get("role") == "tool":
                # Last message overall, or next message isn't an assistant turn.
                if i == len(fixed_messages) - 1 or fixed_messages[i + 1].get("role") != "assistant":
                    assistant_insert_indices.append(i + 1)
        for idx in reversed(assistant_insert_indices):
            # Single-space content: Mistral rejects empty assistant messages.
            fixed_messages.insert(idx, {"role": "assistant", "content": " "})

        # Fixup 2: demote any "system" message that appears after the initial
        # contiguous system block to a "user" message.
        first_non_system_idx = next(
            (i for i, msg in enumerate(fixed_messages) if msg.get("role") != "system"),
            len(fixed_messages),
        )
        for i, msg in enumerate(fixed_messages):
            if msg.get("role") == "system" and i >= first_non_system_idx:
                msg["role"] = "user"

        # Fixup 3: a trailing assistant message needs prefix=True or Mistral
        # rejects the request. Respect an explicitly-set prefix value.
        last_message = fixed_messages[-1]
        if last_message.get("role") == "assistant" and "prefix" not in last_message:
            last_message["prefix"] = True

        return fixed_messages

    async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
        """Execute function calls, filtering out already-completed ones.

        Mistral and OpenAI have different function call detection patterns:

        OpenAI (Stream-based detection):

        - Detects function calls only from streaming chunks as the LLM generates them
        - Second LLM completion doesn't re-detect existing tool_calls in message history
        - Function calls execute exactly once

        Mistral (Message-based detection):

        - Detects function calls from the complete message history on each completion
        - Second LLM completion with the response re-detects the same tool_calls from
          previous messages
        - Without filtering, function calls would execute twice

        This method prevents duplicate execution by:

        1. Checking message history for existing tool result messages
        2. Filtering out function calls that already have corresponding results
        3. Only executing function calls that haven't been completed yet

        Note: This filtering prevents duplicate function execution, but the
        on_function_calls_started event may still fire twice due to the detection
        pattern difference. This is expected behavior.

        Args:
            function_calls: The function calls to potentially execute.
        """
        if not function_calls:
            return

        # All calls share the same context, so the first one's message history
        # tells us which tool_call_ids already have results recorded.
        messages = function_calls[0].context.get_messages()
        executed_call_ids = {
            msg.get("tool_call_id")
            for msg in messages
            if msg.get("role") == "tool" and msg.get("tool_call_id")
        }

        # Keep only calls that haven't produced a tool result yet.
        calls_to_execute = []
        for call in function_calls:
            if call.tool_call_id in executed_call_ids:
                logger.trace(
                    f"Skipping already-executed function call: {call.function_name}:{call.tool_call_id}"
                )
            else:
                calls_to_execute.append(call)

        # Call parent method with filtered list.
        if calls_to_execute:
            await super().run_function_calls(calls_to_execute)

    def build_chat_completion_params(self, params_from_context: OpenAILLMInvocationParams) -> dict:
        """Build parameters for Mistral chat completion request.

        Handles Mistral-specific requirements including:
        - Assistant message prefix requirement for API compatibility
        - Parameter mapping (random_seed instead of seed)
        - Core completion settings

        Args:
            params_from_context: Messages, tools, and tool choice derived from
                the LLM context.

        Returns:
            A dict of keyword arguments for the chat completion request.
        """
        # Apply Mistral's message fixups (tool/system ordering, assistant prefix).
        fixed_messages = self._apply_mistral_fixups(params_from_context["messages"])

        params = {
            "model": self._settings.model,
            "stream": True,
            "messages": fixed_messages,
            "tools": params_from_context["tools"],
            "tool_choice": params_from_context["tool_choice"],
            "frequency_penalty": self._settings.frequency_penalty,
            "presence_penalty": self._settings.presence_penalty,
            "temperature": self._settings.temperature,
            "top_p": self._settings.top_p,
            "max_tokens": self._settings.max_tokens,
        }

        # Mistral uses "random_seed" instead of OpenAI's "seed". Compare against
        # None explicitly: 0 is a valid seed and must not be dropped.
        if self._settings.seed is not None:
            params["random_seed"] = self._settings.seed

        # Merge any caller-supplied extra parameters (may override the above).
        params.update(self._settings.extra)

        # Prepend the system instruction, if configured, as the first message.
        if self._settings.system_instruction:
            params["messages"] = [
                {"role": "system", "content": self._settings.system_instruction}
            ] + params["messages"]

        return params
