#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Base classes for Large Language Model services with function calling support."""

import asyncio
import inspect
import warnings
from dataclasses import dataclass
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Mapping,
    Optional,
    Protocol,
    Sequence,
    Type,
)

from loguru import logger

from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
from pipecat.adapters.schemas.direct_function import DirectFunction, DirectFunctionWrapper
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
from pipecat.frames.frames import (
    CancelFrame,
    EndFrame,
    Frame,
    FunctionCallCancelFrame,
    FunctionCallFromLLM,
    FunctionCallInProgressFrame,
    FunctionCallResultFrame,
    FunctionCallResultProperties,
    FunctionCallsStartedFrame,
    InterruptionFrame,
    LLMConfigureOutputFrame,
    LLMContextSummaryRequestFrame,
    LLMContextSummaryResultFrame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMTextFrame,
    LLMUpdateSettingsFrame,
    StartFrame,
    UserImageRequestFrame,
)
from pipecat.processors.aggregators.llm_context import (
    LLMContext,
    LLMSpecificMessage,
)
from pipecat.processors.aggregators.llm_response import (
    LLMAssistantAggregatorParams,
    LLMUserAggregatorParams,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_service import AIService
from pipecat.services.settings import LLMSettings
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin
from pipecat.utils.context.llm_context_summarization import (
    DEFAULT_SUMMARIZATION_TIMEOUT,
    LLMContextSummarizationUtil,
)

# Type alias for an async function call handler. It receives a single
# FunctionCallParams and returns nothing; the result is reported through
# FunctionCallParams.result_callback instead of being returned.
FunctionCallHandler = Callable[["FunctionCallParams"], Awaitable[None]]


# Protocol (structural type) for the callback that receives the result of an LLM function call.
class FunctionCallResultCallback(Protocol):
    """Structural type for callbacks that receive a function call's result.

    Any async callable that accepts a positional ``result`` and an optional,
    keyword-only ``properties`` argument satisfies this protocol.
    """

    async def __call__(
        self, result: Any, *, properties: Optional[FunctionCallResultProperties] = None
    ) -> None:
        """Deliver the outcome of a function call.

        Args:
            result: Whatever the function call produced.
            properties: Optional extra properties attached to the result.
        """


@dataclass
class FunctionCallParams:
    """Everything a function call handler needs to execute one call.

    Parameters:
        function_name: Name of the function the LLM asked to call.
        tool_call_id: Identifier of this particular call.
        arguments: Arguments supplied by the LLM.
        llm: The LLMService that issued the call.
        context: The LLM context the call belongs to.
        result_callback: Async callback the handler uses to report its result.
    """

    # Which call this is.
    function_name: str
    tool_call_id: str
    # Inputs for the handler.
    arguments: Mapping[str, Any]
    # Where the call is running.
    llm: "LLMService"
    context: OpenAILLMContext | LLMContext
    # How the handler reports back.
    result_callback: FunctionCallResultCallback


@dataclass
class FunctionCallRegistryItem:
    """A registered function handler, as stored by the registration methods.

    Parameters:
        function_name: Name the handler is registered under, or None for the
            catch-all handler that serves otherwise-unregistered functions.
        handler: The callable that executes the function call; either a plain
            handler or a DirectFunctionWrapper.
        cancel_on_interruption: If True, in-flight calls of this function are
            cancelled when an interruption arrives.
        handler_deprecated: True when the handler uses the legacy
            multi-parameter signature instead of a single FunctionCallParams.
        timeout_secs: Per-function timeout in seconds that overrides the global
            ``function_call_timeout_secs``; None means the global value is used.
    """

    # Lookup key (None == catch-all).
    function_name: Optional[str]
    # What runs, and how it is invoked.
    handler: FunctionCallHandler | "DirectFunctionWrapper"
    cancel_on_interruption: bool
    handler_deprecated: bool
    # Optional per-function override of the global timeout.
    timeout_secs: Optional[float] = None


@dataclass
class FunctionCallRunnerItem:
    """One scheduled function call, as handed to the function call runner.

    In sequential mode these items are queued and executed one at a time; in
    parallel mode each item gets its own task.

    Parameters:
        registry_item: Registry entry whose handler serves this call.
        function_name: Name of the function the LLM requested.
        tool_call_id: Identifier of this particular call.
        arguments: Arguments supplied by the LLM.
        context: The LLM context the call belongs to.
        run_llm: Optional flag to control LLM execution after the function
            call (None when unspecified).
    """

    # Handler that will serve the call.
    registry_item: FunctionCallRegistryItem
    # Details of the requested call.
    function_name: str
    tool_call_id: str
    arguments: Mapping[str, Any]
    context: OpenAILLMContext | LLMContext
    # Post-call behavior override.
    run_llm: Optional[bool] = None


class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
    """Base class for all LLM services.

    Handles function calling registration and execution with support for both
    parallel and sequential execution modes. Provides event handlers for
    completion timeouts and function call lifecycle events.

    The service supports the following event handlers:

    - on_completion_timeout: Called when an LLM completion timeout occurs
    - on_function_calls_started: Called when function calls are received and
      execution is about to start

    Example::

        @llm.event_handler("on_completion_timeout")
        async def on_completion_timeout(service):
            logger.warning("LLM completion timed out")

        @llm.event_handler("on_function_calls_started")
        async def on_function_calls_started(service, function_calls):
            logger.info(f"Starting {len(function_calls)} function calls")
    """

    _settings: LLMSettings

    # OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations.
    # However, subclasses should override this with a more specific adapter when necessary.
    adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter

    def __init__(
        self,
        run_in_parallel: bool = True,
        function_call_timeout_secs: float = 10.0,
        settings: Optional[LLMSettings] = None,
        **kwargs,
    ):
        """Set up function-calling state and register the service's events.

        Args:
            run_in_parallel: Run function calls concurrently (True) or one at a
                time through a background queue (False). Defaults to True.
            function_call_timeout_secs: Timeout in seconds for deferred function
                calls. Defaults to 10.0 seconds.
            settings: Runtime-updatable settings for the LLM service. A generic
                LLMSettings is used when omitted.
            **kwargs: Forwarded to the parent AIService.
        """
        # Subclasses normally pass their own, more specific settings object;
        # generic LLMSettings is only a (hopefully rare) fallback.
        effective_settings = settings or LLMSettings()
        super().__init__(settings=effective_settings, **kwargs)

        # Function-calling configuration.
        self._run_in_parallel = run_in_parallel
        self._function_call_timeout_secs = function_call_timeout_secs

        # Turn-completion filtering state.
        self._filter_incomplete_user_turns: bool = False
        self._base_system_instruction: Optional[str] = None

        # Registries: legacy start callbacks, the adapter, and function handlers.
        self._start_callbacks = {}
        self._adapter = self.adapter_class()
        self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}

        # In-flight work.
        self._function_call_tasks: Dict[Optional[asyncio.Task], FunctionCallRunnerItem] = {}
        self._sequential_runner_task: Optional[asyncio.Task] = None
        self._skip_tts: Optional[bool] = None
        self._summary_task: Optional[asyncio.Task] = None

        for event_name in ("on_function_calls_started", "on_completion_timeout"):
            self._register_event_handler(event_name)

    def get_llm_adapter(self) -> BaseLLMAdapter:
        """Return the adapter this service uses to talk to its LLM.

        Returns:
            The adapter instance created from ``adapter_class`` at construction.
        """
        adapter = self._adapter
        return adapter

    def create_llm_specific_message(self, message: Any) -> LLMSpecificMessage:
        """Wrap a provider-specific message so it can live in an LLMContext.

        Args:
            message: The provider-specific message content.

        Returns:
            An LLMSpecificMessage built by this service's adapter.
        """
        adapter = self.get_llm_adapter()
        return adapter.create_llm_specific_message(message)

    async def run_inference(
        self,
        context: LLMContext | OpenAILLMContext,
        max_tokens: Optional[int] = None,
        system_instruction: Optional[str] = None,
    ) -> Optional[str]:
        """Run a single out-of-pipeline inference over the given context.

        The base implementation always raises; services that support one-shot
        inference override it.

        Args:
            context: Context holding the conversation history.
            max_tokens: Optional generation limit that overrides the service's
                default max_tokens/max_completion_tokens setting.
            system_instruction: Optional system instruction that overrides any
                system instruction in the context.

        Returns:
            The generated text, or None when nothing is produced.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        service_name = type(self).__name__
        raise NotImplementedError(f"run_inference() not supported by {service_name}")

    def create_context_aggregator(
        self,
        context: OpenAILLMContext,
        *,
        user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
        assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
    ) -> Any:
        """Create a context aggregator pair for this service (deprecated).

        This base implementation only emits a DeprecationWarning and returns
        None; subclasses provide the actual aggregator.

        Args:
            context: The LLM context the aggregator will manage.
            user_params: Parameters for user message aggregation.
            assistant_params: Parameters for assistant message aggregation.

        Returns:
            A context aggregator instance (None in this base implementation).

        .. deprecated:: 0.0.99
            `create_context_aggregator()` is deprecated and will be removed in a future version.
            Use the universal `LLMContext` and `LLMContextAggregatorPair` instead.
            See `OpenAILLMContext` docstring for migration guide.
        """
        deprecation_message = (
            "create_context_aggregator() is deprecated and will be removed in a future version. "
            "Use the universal LLMContext and LLMContextAggregatorPair directly instead. "
            "See OpenAILLMContext docstring for migration guide."
        )
        with warnings.catch_warnings():
            # Show the warning every time, regardless of any global filters.
            warnings.simplefilter("always")
            warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        return None

    async def start(self, frame: StartFrame):
        """Start the service and, in sequential mode, its function call runner.

        Args:
            frame: The start frame.
        """
        await super().start(frame)
        sequential = not self._run_in_parallel
        if sequential:
            await self._create_sequential_runner_task()

    async def stop(self, frame: EndFrame):
        """Stop the service and tear down its background tasks.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        sequential = not self._run_in_parallel
        if sequential:
            await self._cancel_sequential_runner_task()
        # Any summarization still in progress is abandoned.
        await self._cancel_summary_task()

    async def cancel(self, frame: CancelFrame):
        """Cancel the service and tear down its background tasks.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        sequential = not self._run_in_parallel
        if sequential:
            await self._cancel_sequential_runner_task()
        # Any summarization still in progress is abandoned.
        await self._cancel_summary_task()

    def _compose_system_instruction(self):
        """Write base instruction + turn-completion instructions into settings.

        When a non-empty base system instruction is saved, the completion
        instructions are appended after a blank line; otherwise the completion
        instructions alone become ``self._settings.system_instruction``.
        """
        base_instruction = self._base_system_instruction
        completion_text = self._user_turn_completion_config.completion_instructions
        self._settings.system_instruction = (
            f"{base_instruction}\n\n{completion_text}" if base_instruction else completion_text
        )

    async def _update_settings(self, delta: LLMSettings) -> dict[str, Any]:
        """Apply a settings delta, keeping turn-completion state consistent.

        While incomplete-turn filtering is enabled, ``system_instruction`` holds
        the user's base instruction with the turn-completion instructions
        appended; the unmodified base is kept in ``_base_system_instruction``.

        Args:
            delta: An LLM settings delta.

        Returns:
            Dict mapping changed field names to their previous values.
        """
        changed = await super()._update_settings(delta)

        if "filter_incomplete_user_turns" in changed:
            self._filter_incomplete_user_turns = (
                self._settings.filter_incomplete_user_turns or False
            )
            logger.info(
                f"{self}: Incomplete turn filtering "
                f"{'enabled' if self._filter_incomplete_user_turns else 'disabled'}"
            )
            if self._filter_incomplete_user_turns:
                # The current system_instruction (including one set by this same
                # delta) becomes the base the completion instructions extend.
                self._base_system_instruction = self._settings.system_instruction
                self._compose_system_instruction()
            else:
                # Restore the pre-filtering instruction -- unless this same delta
                # supplied a new system_instruction, in which case that explicit
                # value must win instead of being overwritten by the stale base.
                if "system_instruction" not in changed:
                    self._settings.system_instruction = self._base_system_instruction
                self._base_system_instruction = None

        if "user_turn_completion_config" in changed and self._filter_incomplete_user_turns:
            self.set_user_turn_completion_config(self._settings.user_turn_completion_config)
            self._compose_system_instruction()

        if (
            "system_instruction" in changed
            and self._filter_incomplete_user_turns
            and "filter_incomplete_user_turns" not in changed
        ):
            # system_instruction changed while turn completion is active.
            # Treat the new value as the new base and recompose.
            self._base_system_instruction = self._settings.system_instruction
            self._compose_system_instruction()

        return changed

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Handle an incoming frame.

        After the base class has processed the frame, reacts to the frame types
        this service cares about:

        - InterruptionFrame: cancels interruptible function calls.
        - LLMConfigureOutputFrame: updates the skip-TTS override.
        - LLMUpdateSettingsFrame: applies a settings delta (legacy dict
          payloads are converted, with a DeprecationWarning).
        - LLMContextSummaryRequestFrame: starts background summarization.

        Args:
            frame: The frame being processed.
            direction: The direction the frame is travelling in.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, InterruptionFrame):
            await self._handle_interruptions(frame)
            return

        if isinstance(frame, LLMConfigureOutputFrame):
            self._skip_tts = frame.skip_tts
            return

        if isinstance(frame, LLMUpdateSettingsFrame):
            if frame.delta is not None:
                await self._update_settings(frame.delta)
                return
            if not frame.settings:
                return
            # Legacy path: the frame carries a plain dict rather than a delta.
            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warnings.warn(
                    "Passing a dict via LLMUpdateSettingsFrame(settings={...}) is deprecated "
                    "since 0.0.104, use LLMUpdateSettingsFrame(delta=LLMSettings(...)) instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
            legacy_delta = type(self._settings).from_mapping(frame.settings)
            await self._update_settings(legacy_delta)
            return

        if isinstance(frame, LLMContextSummaryRequestFrame):
            await self._handle_summary_request(frame)

    async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
        """Push a frame, applying any active skip-TTS override to LLM output.

        Args:
            frame: The frame to push.
            direction: The direction to push the frame in.
        """
        is_llm_output = isinstance(
            frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)
        )
        # Only override when an LLMConfigureOutputFrame has set a value.
        if is_llm_output and self._skip_tts is not None:
            frame.skip_tts = self._skip_tts

        await super().push_frame(frame, direction)

    async def _push_llm_text(self, text: str):
        """Emit LLM text, routing it through turn-completion handling if enabled.

        Lets LLM implementations push text without checking the
        incomplete-turn filtering flag themselves.

        Args:
            text: Text produced by the LLM.
        """
        if not self._filter_incomplete_user_turns:
            await self.push_frame(LLMTextFrame(text))
            return
        await self._push_turn_text(text)

    async def _handle_interruptions(self, _: InterruptionFrame):
        """Cancel in-flight calls of functions registered with cancel_on_interruption.

        Args:
            _: The interruption frame (unused).
        """
        # Snapshot the registry before awaiting: other tasks may register or
        # unregister functions while _cancel_function_call is suspended, which
        # would raise "dictionary changed size during iteration" on the live dict.
        interruptible = [
            function_name
            for function_name, entry in self._functions.items()
            if entry.cancel_on_interruption
        ]
        for function_name in interruptible:
            await self._cancel_function_call(function_name)

    async def _handle_summary_request(self, frame: LLMContextSummaryRequestFrame):
        """Start summarizing the conversation in response to an aggregator request.

        Summarization runs in a background task (see _generate_summary_task),
        which broadcasts an LLMContextSummaryResultFrame when it finishes so the
        aggregator can rebuild its context.

        Args:
            frame: The request frame carrying the context and parameters.
        """
        logger.debug(f"{self}: Processing summarization request {frame.request_id}")

        # Running in the background keeps frame processing unblocked.
        summary_coroutine = self._generate_summary_task(frame)
        self._summary_task = self.create_task(summary_coroutine)

    async def _generate_summary_task(self, frame: LLMContextSummaryRequestFrame):
        """Generate a summary in the background and broadcast the result.

        A result frame is always broadcast. On failure or timeout it carries an
        empty summary, ``last_summarized_index=-1`` and a non-None ``error``, so
        the receiver can tell a failed attempt from a successful one.

        Args:
            frame: The summary request frame containing context and parameters.
        """
        summary = ""
        last_index = -1
        error: Optional[str] = None

        # A falsy summarization_timeout (None or 0) falls back to the default.
        timeout = frame.summarization_timeout or DEFAULT_SUMMARIZATION_TIMEOUT

        try:
            summary, last_index = await asyncio.wait_for(
                self._generate_summary(frame),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            # Record the timeout in the result frame as well; previously error
            # stayed None and the empty summary looked like a success.
            error = f"Context summarization timed out after {timeout}s"
            await self.push_error(error_msg=error)
        except Exception as e:
            error = f"Error generating context summary: {e}"
            await self.push_error(error, exception=e)

        await self.broadcast_frame(
            LLMContextSummaryResultFrame,
            request_id=frame.request_id,
            summary=summary,
            last_summarized_index=last_index,
            error=error,
        )

        self._summary_task = None

    async def _generate_summary(self, frame: LLMContextSummaryRequestFrame) -> tuple[str, int]:
        """Ask the LLM to summarize the older part of the conversation.

        Selects the messages to summarize, renders them as a transcript inside
        a single user message, and calls run_inference() with the request's
        summarization prompt as the system instruction.

        Args:
            frame: The summary request frame containing context and configuration.

        Returns:
            Tuple of (stripped summary text, index of the last summarized message).

        Raises:
            RuntimeError: If there is nothing to summarize, the service does not
                implement run_inference(), or the LLM returns an empty summary.
        """
        selection = LLMContextSummarizationUtil.get_messages_to_summarize(
            frame.context, frame.min_messages_to_keep
        )
        messages = selection.messages

        if not messages:
            logger.debug(f"{self}: No messages to summarize")
            raise RuntimeError("No messages to summarize")

        logger.debug(
            f"{self}: Generating summary for {len(messages)} messages "
            f"(index 0 to {selection.last_summarized_index}), "
            f"target_context_tokens={frame.target_context_tokens}"
        )

        transcript = LLMContextSummarizationUtil.format_messages_for_summary(messages)
        transcript_message = {"role": "user", "content": f"Conversation history:\n{transcript}"}
        summary_context = LLMContext(messages=[transcript_message])

        # run_inference() is provided by concrete services; the base raises.
        try:
            raw_summary = await self.run_inference(
                summary_context,
                max_tokens=frame.target_context_tokens,
                system_instruction=frame.summarization_prompt,
            )
        except NotImplementedError:
            raise RuntimeError(
                f"LLM service {self.__class__.__name__} does not implement run_inference"
            )

        if not raw_summary:
            raise RuntimeError("LLM returned empty summary")

        summary_text = raw_summary.strip()
        logger.info(
            f"{self}: Generated summary of {len(summary_text)} characters "
            f"for {len(messages)} messages"
        )

        return summary_text, selection.last_summarized_index

    def register_function(
        self,
        function_name: Optional[str],
        handler: Any,
        start_callback=None,
        *,
        cancel_on_interruption: bool = True,
        timeout_secs: Optional[float] = None,
    ):
        """Register a function handler for LLM function calls.

        Registering a name that is already registered replaces its handler. A
        start callback left over from an earlier registration of the same name
        is dropped unless a new ``start_callback`` is supplied.

        Args:
            function_name: The name of the function to handle. Use None to handle
                all function calls with a catch-all handler.
            handler: The function handler. Should accept a single FunctionCallParams
                parameter.
            start_callback: Legacy callback function (deprecated). Put initialization
                code at the top of your handler instead.

                .. deprecated:: 0.0.59
                    The `start_callback` parameter is deprecated and will be removed in a future version.

            cancel_on_interruption: Whether to cancel this function call when an
                interruption occurs. Defaults to True.
            timeout_secs: Optional per-tool timeout in seconds. Overrides the global
                ``function_call_timeout_secs`` for this specific function. Defaults to
                None, which uses the global timeout.
        """
        # Handlers with more than one parameter use the legacy signature that
        # receives each value separately instead of a single FunctionCallParams.
        signature = inspect.signature(handler)
        handler_deprecated = len(signature.parameters) > 1
        if handler_deprecated:
            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warnings.warn(
                    "Function calls with parameters `(function_name, tool_call_id, arguments, llm, context, result_callback)` are deprecated, use a single `FunctionCallParams` parameter instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )

        # Registering with function_name=None installs a catch-all handler that
        # serves any function without a dedicated registration.
        self._functions[function_name] = FunctionCallRegistryItem(
            function_name=function_name,
            handler=handler,
            cancel_on_interruption=cancel_on_interruption,
            handler_deprecated=handler_deprecated,
            timeout_secs=timeout_secs,
        )

        # Start callbacks are now deprecated.
        if start_callback:
            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warnings.warn(
                    "Parameter 'start_callback' is deprecated, just put your code on top of the actual function call instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )

            self._start_callbacks[function_name] = start_callback
        else:
            # Don't keep firing a start callback that belonged to the handler
            # this registration just replaced.
            self._start_callbacks.pop(function_name, None)

    def register_direct_function(
        self,
        handler: DirectFunction,
        *,
        cancel_on_interruption: bool = True,
        timeout_secs: Optional[float] = None,
    ):
        """Register a direct function as an LLM function call handler.

        The function's name and schema come from its own signature and
        docstring via DirectFunctionWrapper, so no separate FunctionSchema or
        provider-specific definition is needed.

        Args:
            handler: The direct function to register. Must follow DirectFunction protocol.
            cancel_on_interruption: Whether to cancel this function call when an
                interruption occurs. Defaults to True.
            timeout_secs: Optional per-tool timeout in seconds. Overrides the global
                ``function_call_timeout_secs`` for this specific function. Defaults to
                None, which uses the global timeout.
        """
        wrapped = DirectFunctionWrapper(handler)
        registered_name = wrapped.name
        entry = FunctionCallRegistryItem(
            function_name=registered_name,
            handler=wrapped,
            cancel_on_interruption=cancel_on_interruption,
            handler_deprecated=False,
            timeout_secs=timeout_secs,
        )
        self._functions[registered_name] = entry

    def unregister_function(self, function_name: Optional[str]):
        """Remove a registered function handler and any legacy start callback.

        Args:
            function_name: The name of the function handler to remove.

        Raises:
            KeyError: If no handler is registered under ``function_name``.
        """
        del self._functions[function_name]
        self._start_callbacks.pop(function_name, None)

    def unregister_direct_function(self, handler: Any):
        """Remove a registered direct function handler.

        Args:
            handler: The direct function handler to remove.

        Raises:
            KeyError: If the function is not registered.
        """
        registered_name = DirectFunctionWrapper(handler).name
        del self._functions[registered_name]
        # Direct functions never have start callbacks, so none to remove here.

    def has_function(self, function_name: str):
        """Tell whether a call to ``function_name`` would find a handler.

        Args:
            function_name: The name of the function to check.

        Returns:
            True if the name is registered or a catch-all (None) handler exists.
        """
        registry = self._functions
        return None in registry or function_name in registry

    async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
        """Dispatch a batch of function calls requested by the LLM.

        Fires ``on_function_calls_started``, broadcasts a
        FunctionCallsStartedFrame, resolves each call to its handler (falling
        back to the catch-all handler) and hands the resolved calls to the
        parallel or sequential runner. Calls with no handler are logged and
        skipped.

        Args:
            function_calls: The function calls to execute.
        """
        if not function_calls:
            return

        await self._call_event_handler("on_function_calls_started", function_calls)

        await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)

        runner_items: list[FunctionCallRunnerItem] = []
        for function_call in function_calls:
            registry_item = self._functions.get(function_call.function_name)
            if registry_item is None:
                registry_item = self._functions.get(None)
            if registry_item is None:
                logger.warning(
                    f"{self} is calling '{function_call.function_name}', but it's not registered."
                )
                continue

            runner_items.append(
                FunctionCallRunnerItem(
                    registry_item=registry_item,
                    function_name=function_call.function_name,
                    tool_call_id=function_call.tool_call_id,
                    arguments=function_call.arguments,
                    context=function_call.context,
                )
            )

        run = (
            self._run_parallel_function_calls
            if self._run_in_parallel
            else self._run_sequential_function_calls
        )
        await run(runner_items)

    async def request_image_frame(
        self,
        user_id: str,
        *,
        function_name: Optional[str] = None,
        tool_call_id: Optional[str] = None,
        text_content: Optional[str] = None,
        video_source: Optional[str] = None,
        timeout: Optional[float] = 10.0,
    ):
        """Ask upstream for an image from a user (deprecated).

        Pushes a UserImageRequestFrame upstream with ``append_to_context=True``
        so the captured image can be processed by the LLM. Call this from a
        function call handler; when a vision service should handle the image,
        push a UserImageRequestFrame upstream directly instead.

        .. deprecated:: 0.0.92
            This method is deprecated, push a `UserImageRequestFrame` instead.

        Args:
            user_id: The ID of the user to request an image from.
            function_name: Optional function name associated with the request.
            tool_call_id: Optional tool call ID associated with the request.
            text_content: Optional text content/context for the image request.
            video_source: Optional video source identifier (not used by this method).
            timeout: Optional timeout for the requested image to be added to the
                LLM context (not used by this method).
        """
        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                "Method `request_image_frame()` is deprecated, push a `UserImageRequestFrame` instead.",
                DeprecationWarning,
            )

        request = UserImageRequestFrame(
            user_id=user_id,
            text=text_content,
            append_to_context=True,
            function_name=function_name,
            tool_call_id=tool_call_id,
            # Deprecated field, still populated for backward compatibility.
            context=text_content,
        )
        await self.push_frame(request, FrameDirection.UPSTREAM)

    async def _create_sequential_runner_task(self):
        """Create the queue and background runner for sequential calls (once)."""
        if self._sequential_runner_task:
            return
        self._sequential_runner_queue = asyncio.Queue()
        self._sequential_runner_task = self.create_task(self._sequential_runner_handler())

    async def _cancel_sequential_runner_task(self):
        """Cancel the sequential runner task, if one is running."""
        if not self._sequential_runner_task:
            return
        await self.cancel_task(self._sequential_runner_task)
        self._sequential_runner_task = None

    async def _cancel_summary_task(self):
        """Cancel the background summarization task, if one is running."""
        if not self._summary_task:
            return
        await self.cancel_task(self._summary_task)
        self._summary_task = None

    async def _sequential_runner_handler(self):
        """Execute queued function calls one at a time, for the service's lifetime.

        Each call runs in its own task that is tracked in
        ``_function_call_tasks`` while it is in flight. Cancelling one call
        (for example from the interruption path) only ends that call; the
        runner keeps serving the queue.
        """
        while True:
            runner_item = await self._sequential_runner_queue.get()
            task = self.create_task(self._run_function_call(runner_item))
            self._function_call_tasks[task] = runner_item
            # Since we run tasks sequentially we don't need to call
            # task.add_done_callback(self._function_call_task_finished).
            try:
                # A bare ``await task`` re-raises CancelledError here when the
                # call is cancelled elsewhere, which would kill this runner and
                # stall every later sequential call. gather(return_exceptions=True)
                # just waits for the call; if this runner itself is cancelled,
                # gather still cancels the call and re-raises CancelledError.
                await asyncio.gather(task, return_exceptions=True)
            finally:
                # The entry may already have been removed elsewhere (e.g. by the
                # code that cancelled the call), so tolerate its absence.
                self._function_call_tasks.pop(task, None)

    async def _run_parallel_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
        """Start one task per function call without waiting for any of them.

        Each task is tracked in ``_function_call_tasks`` and removed by the
        ``_function_call_task_finished`` done-callback.

        Args:
            runner_items: The resolved function calls to start.
        """
        for item in runner_items:
            call_task = self.create_task(self._run_function_call(item))
            self._function_call_tasks[call_task] = item
            call_task.add_done_callback(self._function_call_task_finished)

    async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
        """Queue function calls for the background sequential runner.

        Args:
            runner_items: The resolved function calls, in execution order.
        """
        for pending in runner_items:
            await self._sequential_runner_queue.put(pending)

    async def _call_start_function(
        self, context: OpenAILLMContext | LLMContext, function_name: str
    ):
        """Invoke the legacy start callback for ``function_name``, if any.

        A callback registered for the exact name takes precedence over the
        catch-all (None) callback. Only the catch-all branch returns the
        callback's result; otherwise None is returned.

        Args:
            context: The LLM context passed through to the callback.
            function_name: Name of the function about to be called.
        """
        callbacks = self._start_callbacks
        if function_name in callbacks:
            await callbacks[function_name](function_name, self, context)
            return None
        if None in callbacks:
            return await callbacks[None](function_name, self, context)
        return None

    async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
        """Execute one LLM function call and report its outcome.

        The handler is looked up by ``runner_item.function_name`` in
        ``self._functions``, falling back to the catch-all handler registered under
        ``None``; if neither exists the call is skipped. A
        ``FunctionCallInProgressFrame`` is broadcast before the handler runs, and the
        handler reports its result through a callback that broadcasts a
        ``FunctionCallResultFrame``.

        While the handler is running, a watchdog task delivers a ``None`` result once
        the effective timeout elapses (the registration's ``timeout_secs`` or, if that
        is ``None``, ``self._function_call_timeout_secs``). The watchdog is cancelled
        when a result is delivered or when the handler returns. Exceptions raised by
        the handler are logged and pushed as non-fatal errors.

        Args:
            runner_item: The function call to execute.
        """
        if runner_item.function_name in self._functions:
            item = self._functions[runner_item.function_name]
        elif None in self._functions:
            item = self._functions[None]
        else:
            # No handler for this function and no catch-all handler.
            return

        logger.debug(
            f"{self} Calling function [{runner_item.function_name}:{runner_item.tool_call_id}] with arguments {runner_item.arguments}"
        )

        # NOTE(aleix): This needs to be removed after we remove the deprecation.
        await self._call_start_function(runner_item.context, runner_item.function_name)

        # Broadcast function call in-progress. This frame will let our assistant
        # context aggregator know that we are in the middle of a function
        # call. Some contexts/aggregators may not need this. But some definitely
        # do (Anthropic, for example).
        await self.broadcast_frame(
            FunctionCallInProgressFrame,
            function_name=runner_item.function_name,
            tool_call_id=runner_item.tool_call_id,
            arguments=runner_item.arguments,
            cancel_on_interruption=item.cancel_on_interruption,
        )

        timeout_task: Optional[asyncio.Task] = None

        # Callback handed to the handler: pushes a FunctionCallResultFrame upstream &
        # downstream.
        async def function_call_result_callback(
            result: Any, *, properties: Optional[FunctionCallResultProperties] = None
        ):
            # Stop the timeout watchdog. Skip this when the watchdog itself is the one
            # delivering the (timeout) result: a task must not cancel and then await
            # itself.
            if (
                timeout_task
                and not timeout_task.done()
                and timeout_task is not asyncio.current_task()
            ):
                await self.cancel_task(timeout_task)

            await self.broadcast_frame(
                FunctionCallResultFrame,
                function_name=runner_item.function_name,
                tool_call_id=runner_item.tool_call_id,
                arguments=runner_item.arguments,
                result=result,
                run_llm=runner_item.run_llm,
                properties=properties,
            )

        # Watchdog: deliver a `None` result if the handler takes too long.
        async def timeout_handler():
            effective_timeout = (
                item.timeout_secs
                if item.timeout_secs is not None
                else self._function_call_timeout_secs
            )
            await asyncio.sleep(effective_timeout)
            logger.warning(
                f"{self} Function call [{runner_item.function_name}:{runner_item.tool_call_id}] timed out after {effective_timeout} seconds."
                f" You can increase this timeout by passing `timeout_secs` to `register_function()`,"
                f" or set a global default via `function_call_timeout_secs` on the LLM constructor."
            )
            await function_call_result_callback(None)

        timeout_task = self.create_task(timeout_handler())

        def make_params():
            # Parameters shared by direct functions and non-deprecated handlers.
            return FunctionCallParams(
                function_name=runner_item.function_name,
                tool_call_id=runner_item.tool_call_id,
                arguments=runner_item.arguments,
                llm=self,
                context=runner_item.context,
                result_callback=function_call_result_callback,
            )

        try:
            # Yield to the event loop so the timeout task coroutine gets entered
            # before it could be cancelled. Without this, cancelling the task before
            # it starts would leave the coroutine in a "never awaited" state.
            await asyncio.sleep(0)
            if isinstance(item.handler, DirectFunctionWrapper):
                # Handler is a DirectFunctionWrapper
                await item.handler.invoke(args=runner_item.arguments, params=make_params())
            elif item.handler_deprecated:
                # Handler uses the deprecated positional-argument signature.
                await item.handler(
                    runner_item.function_name,
                    runner_item.tool_call_id,
                    runner_item.arguments,
                    self,
                    runner_item.context,
                    function_call_result_callback,
                )
            else:
                # Handler is a FunctionCallHandler
                await item.handler(make_params())
        except Exception as e:
            error_message = f"Error executing function call [{runner_item.function_name}]: {e}"
            logger.error(f"{self} {error_message}")
            await self.push_error(error_msg=error_message, exception=e, fatal=False)
        finally:
            if timeout_task and not timeout_task.done():
                await self.cancel_task(timeout_task)

    async def _cancel_function_call(self, function_name: Optional[str]):
        """Cancel every in-flight function call registered under ``function_name``.

        Matching is done on ``runner_item.registry_item.function_name``, so ``None``
        selects calls dispatched to the catch-all handler. Each matching task is
        cancelled, a ``FunctionCallCancelFrame`` is broadcast for it, and it is
        removed from ``self._function_call_tasks``.

        Args:
            function_name: Registered function name whose calls should be cancelled.
        """
        cancelled_tasks = set()
        # Iterate over a snapshot: `cancel_task` and `broadcast_frame` yield to the
        # event loop, and meanwhile other function call tasks may finish and delete
        # themselves from `self._function_call_tasks` (done callback or sequential
        # runner). Iterating the live dict would then raise
        # "RuntimeError: dictionary changed size during iteration".
        for task, runner_item in list(self._function_call_tasks.items()):
            if runner_item.registry_item.function_name != function_name:
                continue
            if task not in self._function_call_tasks:
                # Finished and cleaned up while we were cancelling an earlier call.
                continue

            name = runner_item.function_name
            tool_call_id = runner_item.tool_call_id

            logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")

            # Detach the done callback; the task is removed from the dict below,
            # after the loop.
            task.remove_done_callback(self._function_call_task_finished)
            await self.cancel_task(task)
            cancelled_tasks.add(task)

            await self.broadcast_frame(
                FunctionCallCancelFrame, function_name=name, tool_call_id=tool_call_id
            )

            logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")

        # Remove all cancelled tasks from our tracking dict.
        for task in cancelled_tasks:
            self._function_call_task_finished(task)

    def _function_call_task_finished(self, task: asyncio.Task):
        if task in self._function_call_tasks:
            del self._function_call_tasks[task]
