# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from pathlib import Path
from typing import Any

import numpy as np

from ...utils import auto_docstring, is_soundfile_available, is_torch_available


if is_torch_available():
    import torch

if is_soundfile_available():
    import soundfile as sf

from ...audio_utils import AudioInput, make_list_of_audio
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput


class CsmAudioKwargs(AudioKwargs, total=False):
    """
    encoded_length_kwargs (`dict[str, Any]`, *optional*):
        Dictionary of keyword arguments used to compute the encoded audio sequence length. This includes parameters
        such as `kernel_sizes`, `strides`, `dilations`, and `use_causal_conv` that define the convolutional layers
        used in audio encoding. The encoded length is used to determine how many audio tokens to generate for each
        audio input in the text sequence.
    """

    encoded_length_kwargs: dict[str, Any] | None


class CsmProcessorKwargs(ProcessingKwargs, total=False):
    audio_kwargs: CsmAudioKwargs
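    # These defaults mirror the audio encoder's convolutional downsampling stack; the strides
    # multiply to 1920, i.e. one audio token per 1920 samples (~12.5 frames/s at 24 kHz)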
    _defaults = {
        "text_kwargs": {
            "padding": True,
            "padding_side": "left",
            "add_special_tokens": False,
        },
        "audio_kwargs": {
            "encoded_length_kwargs": {
                "kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
                "strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
                "dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                "use_causal_conv": True,
            },
            "sampling_rate": 24000,
        },
        "common_kwargs": {"return_tensors": "pt"},
    }


@auto_docstring
class CsmProcessor(ProcessorMixin):
    def __init__(
        self,
        feature_extractor,
        tokenizer,
        chat_template=None,
    ):
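        # fall back to the default special tokens if the tokenizer does not already define them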
        if not hasattr(tokenizer, "audio_token"):
            self.audio_token = "<|AUDIO|>"
            self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
        else:
            self.audio_token = tokenizer.audio_token
            self.audio_token_id = tokenizer.audio_token_id

        if not hasattr(tokenizer, "audio_eos_token"):
            self.audio_eos_token = "<|audio_eos|>"
            self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)
        else:
            self.audio_eos_token = tokenizer.audio_eos_token
            self.audio_eos_token_id = tokenizer.audio_eos_token_id

        super().__init__(feature_extractor, tokenizer, chat_template=chat_template)

    @staticmethod
    def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
        """
        Compute the length of the encoded audio sequence, i.e. the number of audio tokens the audio
        encoder produces for a raw waveform of `audio_length` samples. If any of the layer parameters
        is `None`, `audio_length` is returned unchanged.

        Args:
            audio_length (int): The length of the raw audio sequence, in samples.
            kernel_sizes (list[int], *optional*): The kernel sizes of the encoder's convolutional layers.
            strides (list[int], *optional*): The strides of the encoder's convolutional layers.
            dilations (list[int], *optional*): The dilations of the encoder's convolutional layers.
            use_causal_conv (bool, *optional*): Whether the convolutional layers are causal.
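
        Example (illustrative; uses the module's default encoder layer parameters from
        `CsmProcessorKwargs._defaults`):

        ```python
        >>> kwargs = CsmProcessorKwargs._defaults["audio_kwargs"]["encoded_length_kwargs"]
        >>> # 1 second of 24 kHz audio: the strides multiply to 1920, so 24000 samples -> ceil(12.5) = 13 frames
        >>> CsmProcessor._get_encoded_length(24000, **kwargs)
        13
        ```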
        """
        cur_length = audio_length

        if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
            return cur_length

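        # replay the padding/striding arithmetic of each convolutional layer to get the output length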
        for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
            effective_kernel_size = (kernel_size - 1) * dilation + 1
            padding_total = kernel_size - stride
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right

            n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
            n_frames = math.ceil(n_frames) - 1
            ideal_length = n_frames * stride + kernel_size - padding_total
            extra_padding = ideal_length - cur_length

            if use_causal_conv:
                padding_left = padding_total
                padding_right = extra_padding
            else:
                padding_right = padding_right + extra_padding

            cur_length = cur_length + padding_left + padding_right
            cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1

        return cur_length

    def save_audio(
        self,
        audio: AudioInput,
        saving_path: str | Path | list[str | Path],
        **kwargs: Unpack[CsmProcessorKwargs],
    ):
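        """
        Save audio waveforms to disk at the processor's sampling rate (24000 Hz by default).

        Args:
            audio (`AudioInput`):
                A single waveform or a list of waveforms (NumPy arrays or PyTorch tensors).
            saving_path (`str`, `Path`, or a list of them):
                Target file path(s); must match the number of audio inputs.

        Example (illustrative sketch, assuming an instantiated `processor` and the `soundfile` backend):

        ```python
        >>> import numpy as np

        >>> waveform = np.zeros(24000)  # 1 second of silence at the default 24 kHz
        >>> processor.save_audio(waveform, "silence.wav")
        ```
        """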
        # TODO: @eustlb, this should be in AudioProcessor
        if not is_soundfile_available():
            raise ImportError("Please install `soundfile` to save audio files.")

        # ensure correct audio input
        audio = make_list_of_audio(audio)

        # ensure correct saving path
        if isinstance(saving_path, (str, Path)):
            saving_path = [saving_path]
        elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
            raise ValueError("Invalid input path. Please provide a string, or a list of strings")

        if len(audio) != len(saving_path):
            raise ValueError("The number of audio and saving paths must be the same")

        output_kwargs = self._merge_kwargs(
            CsmProcessorKwargs,
            **kwargs,
        )
        audio_kwargs = output_kwargs["audio_kwargs"]
        sampling_rate = audio_kwargs["sampling_rate"]

        for audio_value, p in zip(audio, saving_path):
            if isinstance(audio_value, torch.Tensor):
                audio_value = audio_value.cpu().float().numpy()
            sf.write(p, audio_value, sampling_rate)

    @auto_docstring
    def __call__(
        self,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None,
        audio: AudioInput | None = None,
        output_labels: bool | None = False,
        depth_decoder_labels_ratio: float | None = 1.0,
        **kwargs: Unpack[CsmProcessorKwargs],
    ):
        r"""
        output_labels (bool, *optional*, defaults to `False`):
            Whether to return labels for training. Label values will be in `[config.audio_token_id, -100, -101]`:
            - `config.audio_token_id` marks a position holding an audio frame (sequence elements are treated as frames)
            - `-100` marks positions that are ignored in the loss computation
            - `-101` marks audio frames used only by the backbone model (with the frame's first codebook token as label)
        depth_decoder_labels_ratio (float, *optional*, defaults to `1.0`):
            The ratio of audio frames to keep for the depth decoder labels.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
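
        Example (illustrative sketch; assumes a CSM checkpoint such as `sesame/csm-1b` and 24 kHz audio):

        ```python
        >>> import numpy as np
        >>> from transformers import CsmProcessor

        >>> processor = CsmProcessor.from_pretrained("sesame/csm-1b")
        >>> speech = np.zeros(24000)  # 1 second of 24 kHz audio
        >>> # the single `<|AUDIO|>` placeholder below is expanded to one token per encoded audio frame
        >>> inputs = processor(text="Hello<|AUDIO|>", audio=speech)
        ```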
        """

        output_kwargs = self._merge_kwargs(
            CsmProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        text_kwargs = output_kwargs["text_kwargs"]
        audio_kwargs = output_kwargs["audio_kwargs"]
        return_tensors = text_kwargs.get("return_tensors", None)
        if return_tensors != "pt":
            raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")

        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
        n_audio_in_text = [t.count(self.audio_token) for t in text]

        n_audio = 0
        if audio is not None:
            audio = make_list_of_audio(audio)
            n_audio = len(audio)

        if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text):
            if audio is None:
                raise ValueError("No audio was provided, but there are audio tokens in the prompt")
            else:
                raise ValueError(
                    f"The number of audio tokens in each text ({n_audio_in_text}) should sum to the "
                    f"number of provided audio inputs ({n_audio})."
                )

        if audio is not None:
            encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {})
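            # number of placeholder tokens each audio input expands to (one per encoded audio frame)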
            num_audio_tokens_list = [
                self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
            ]

            # expand the text to repeat the audio token for the corresponding number of frames
            expanded_text = []
            for sample in text:
                replace_str = []
                while self.audio_token in sample:
                    num_audio_tokens = num_audio_tokens_list.pop(0)
                    expanded_audio_token = self.audio_token * num_audio_tokens

                    replace_str.append(expanded_audio_token)
                    sample = sample.replace(self.audio_token, "<placeholder>", 1)

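                # second pass: substitute the expansions, so already-expanded audio tokens are not re-matched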
                while "<placeholder>" in sample:
                    sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
                expanded_text.append(sample)

            text = expanded_text

        encoding = self.tokenizer(text, **text_kwargs)
        data = {}
        data.update(encoding)

        if audio is not None:
            audio_kwargs.pop("return_attention_mask", None)  # not supported by the feature extractor

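            # concatenate all clips belonging to one text sample into a single waveform; cumulative
            # cutoffs record where each clip ends (-1 pads samples with no audio)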
            concatenated_audio, input_values_cutoffs = [], []
            offset = 0
            for n_audio_in_sample in n_audio_in_text:
                if n_audio_in_sample == 0:
                    concatenated_audio.append(np.zeros(0))
                    input_values_cutoffs.append(torch.tensor([-1]))
                else:
                    concatenated_audio.append(
                        np.concatenate(
                            [
                                el.cpu().numpy() if isinstance(el, torch.Tensor) else el
                                for el in audio[offset : offset + n_audio_in_sample]
                            ],
                            axis=-1,
                        )
                    )
                    input_values_cutoffs.append(
                        torch.tensor(
                            [el.shape[-1] for el in audio[offset : offset + n_audio_in_sample]]
                        ).cumsum(dim=-1)
                    )
                    offset += n_audio_in_sample

            audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
            audio_inputs.pop("padding_mask", None)  # not applicable here
            data.update(audio_inputs)

            # pad and stack the audio cut idxs
            max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
            input_values_cutoffs = [
                torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
                for cut_idxs in input_values_cutoffs
            ]
            data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)

        if output_labels:
            audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
            n_audio_frames = audio_frame_idxs.shape[0]

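            # randomly pick the fraction (1 - depth_decoder_labels_ratio) of audio frames whose
            # depth decoder labels will be skipped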
            if depth_decoder_labels_ratio <= 1.0:
                rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
                skip_frames_idxs = audio_frame_idxs[rand_idxs]
            else:
                skip_frames_idxs = audio_frame_idxs

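            # keep audio token and audio EOS ids as labels; every other position is ignored (-100)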
            labels = torch.where(
                (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
                data["input_ids"],
                -100,
            )
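            # mark skipped frames with -101: they supervise only the backbone (via the first codebook token)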
            labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101

            data["labels"] = labels

        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names

        # Remove `padding_mask`: it is popped and not used when processing. Build a new list rather than
        # mutating in place, otherwise `self.feature_extractor.model_input_names` would also be modified
        feature_extractor_input_names = [name for name in feature_extractor_input_names if name != "padding_mask"]
        return list(tokenizer_input_names + feature_extractor_input_names + ["input_values_cutoffs"])


__all__ = ["CsmProcessor"]
