#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_vibevoice_asr.py file directly. One of our CI checks enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_vibevoice_asr import VibeVoiceAsrConfig


@use_kernel_forward_from_hub("RMSNorm")
class VibeVoiceAsrRMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm): no mean centering, learned scale only."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Compute the normalization in float32 for numerical stability, then cast back.
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class VibeVoiceAsrMultiModalProjector(nn.Module):
    """Projects acoustic and semantic tokenizer latents into the LM hidden size and fuses them by addition."""

    def __init__(self, config: VibeVoiceAsrConfig):
        super().__init__()
        text_hidden = config.text_config.hidden_size

        # Acoustic branch: linear -> RMSNorm -> linear
        self.acoustic_linear_1 = nn.Linear(config.acoustic_tokenizer_encoder_config.hidden_size, text_hidden)
        self.acoustic_norm = VibeVoiceAsrRMSNorm(text_hidden, eps=1e-6)
        self.acoustic_linear_2 = nn.Linear(text_hidden, text_hidden)

        # Semantic branch mirrors the acoustic one.
        self.semantic_linear_1 = nn.Linear(config.semantic_tokenizer_encoder_config.hidden_size, text_hidden)
        self.semantic_norm = VibeVoiceAsrRMSNorm(text_hidden, eps=1e-6)
        self.semantic_linear_2 = nn.Linear(text_hidden, text_hidden)

    def forward(self, acoustic_latents, semantic_latents):
        acoustic_features = self.acoustic_linear_2(self.acoustic_norm(self.acoustic_linear_1(acoustic_latents)))
        semantic_features = self.semantic_linear_2(self.semantic_norm(self.semantic_linear_1(semantic_latents)))
        # Fuse the two modality streams with an elementwise sum.
        return acoustic_features + semantic_features


class VibeVoiceAsrFeedForward(nn.Module):
    """Position-wise MLP: expand by `config.ffn_expansion`, apply the configured activation, project back."""

    def __init__(self, config, hidden_size):
        super().__init__()
        expanded_size = config.ffn_expansion * hidden_size
        self.linear1 = nn.Linear(hidden_size, expanded_size)
        self.activation = ACT2FN[config.hidden_act]
        self.linear2 = nn.Linear(expanded_size, hidden_size)

    def forward(self, hidden_states):
        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        return self.linear2(hidden_states)


class VibeVoiceAsrConv1dCacheLayer:
    """Holds the trailing `left_pad` input frames of one causal conv layer between streaming chunks."""

    def __init__(self):
        # Cache tensor of shape (batch, in_channels, left_pad); allocated lazily on first update.
        self.cache: torch.Tensor | None = None
        self.is_initialized: bool = False

    def lazy_initialization(self, hidden_states, conv_module):
        """Allocate a zero-filled cache matching the batch size, device, and dtype of the first chunk."""
        self.left_pad = conv_module.left_pad
        self.in_channels = conv_module.in_channels
        self.cache = torch.zeros(
            hidden_states.shape[0],
            self.in_channels,
            self.left_pad,
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

        # Pin the cache to a static address so subsequent in-place `copy_` updates are
        # compile/cudagraph friendly (skipped when already tracing under dynamo).
        if not is_torchdynamo_compiling():
            torch._dynamo.mark_static_address(self.cache)

        self.is_initialized = True

    def update(self, hidden_states, conv_module=None):
        """Return the cached left context for this chunk and store the new trailing context in place.

        `hidden_states` is channel-first, shape (batch, in_channels, time) — TODO confirm against callers.
        Raises `ValueError` if called before initialization without a `conv_module`.
        """
        if not self.is_initialized and conv_module is not None:
            self.lazy_initialization(hidden_states, conv_module)
        elif not self.is_initialized:
            raise ValueError(
                "VibeVoiceAsrConv1dCacheLayer is not initialized. Make sure to provide conv_module to the update method."
            )

        # get the padding states
        if self.left_pad > 0:
            # If the chunk is shorter than `left_pad`, part of the old cache must be carried over.
            shortfall = max(0, self.left_pad - hidden_states.shape[-1])
            if shortfall > 0:
                padding_states = torch.cat([self.cache[:, :, -shortfall:], hidden_states], dim=-1)
            else:
                padding_states = hidden_states[:, :, -self.left_pad :]
        else:
            # No left context needed: use a zero-length tensor so downstream concatenation is a no-op.
            padding_states = torch.empty(
                hidden_states.shape[0], self.in_channels, 0, dtype=hidden_states.dtype, device=hidden_states.device
            )

        # Clone before the in-place overwrite so the *previous* context is what gets returned.
        current_cache = self.cache.clone()
        self.cache.copy_(padding_states)

        return current_cache


class VibeVoiceAsrConv1dPaddingCache:
    """Maps a cache key to a per-layer left-context cache for streaming causal convolutions."""

    def __init__(self):
        # One VibeVoiceAsrConv1dCacheLayer per conv module, created on first use.
        self.layers = {}

    def update(self, hidden_states, cache_key, conv_module):
        if cache_key not in self.layers:
            self.layers[cache_key] = VibeVoiceAsrConv1dCacheLayer()
        layer_cache = self.layers[cache_key]

        # Fetch the stored left context and prepend it so the conv stays causal across chunks.
        padding_states = layer_cache.update(hidden_states, conv_module)
        return torch.cat([padding_states, hidden_states], dim=-1)


# TODO: @eustlb, @ebezzam this should be latter factorized with other causalconv1d (e.g. VoxtralRealtimeCausalConv1d)
class VibeVoiceAsrCausalConv1d(nn.Module):
    """Conv1d with built-in causal (left-only) padding and optional streaming support through a cache."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        cache_key: str,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
    ):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation, groups=groups)
        # Left padding required so that no output frame depends on future input samples.
        self.causal_padding = (kernel_size - 1) * dilation - (stride - 1)
        if self.causal_padding < 0:
            raise ValueError(
                f"Invalid causal padding {self.causal_padding} for kernel_size={kernel_size}, "
                f"dilation={dilation}, stride={stride}."
            )
        self.cache_key = cache_key
        self.in_channels = in_channels
        self.left_pad = self.causal_padding

    def forward(
        self,
        hidden_states: torch.Tensor,
        padding_cache: VibeVoiceAsrConv1dPaddingCache | None = None,
    ) -> torch.Tensor:
        if padding_cache is None:
            # Offline mode: zero-pad on the left.
            hidden_states = nn.functional.pad(hidden_states, (self.left_pad, 0))
        else:
            # Streaming mode: prepend cached context from the previous chunk instead of zeros.
            hidden_states = padding_cache.update(hidden_states, self.cache_key, self)
        return self.conv(hidden_states)


class VibeVoiceAsrConvNext1dLayer(nn.Module):
    """ConvNeXt-like block adapted for 1D convolutions: depthwise causal conv branch + FFN branch,
    each with RMSNorm, layer scale, and a residual connection."""

    def __init__(self, config, hidden_size, dilation=1, stride=1, layer_idx=None):
        super().__init__()
        self.norm = VibeVoiceAsrRMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.ffn_norm = VibeVoiceAsrRMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.ffn = VibeVoiceAsrFeedForward(config, hidden_size)
        # Per-channel layer-scale parameters, one per residual branch.
        self.gamma = nn.Parameter(config.layer_scale_init_value * torch.ones(hidden_size), requires_grad=True)
        self.ffn_gamma = nn.Parameter(config.layer_scale_init_value * torch.ones(hidden_size), requires_grad=True)
        self.mixer = VibeVoiceAsrCausalConv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=config.kernel_size,
            cache_key=f"convnext_layer_{layer_idx}",
            groups=hidden_size,
            dilation=dilation,
            stride=stride,
        )

    def forward(self, hidden_states, padding_cache=None):
        # Conv branch: RMSNorm operates channel-last, so transpose around it; conv is channel-first.
        conv_branch = self.norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        conv_branch = self.mixer(conv_branch, padding_cache=padding_cache)
        hidden_states = hidden_states + conv_branch * self.gamma.unsqueeze(-1)

        # FFN branch: norm and FFN both run channel-last, transpose back afterwards.
        ffn_branch = self.ffn(self.ffn_norm(hidden_states.transpose(1, 2))).transpose(1, 2)
        return hidden_states + ffn_branch * self.ffn_gamma.unsqueeze(-1)


@auto_docstring
class VibeVoiceAsrPreTrainedModel(PreTrainedModel):
    # Configuration class and Transformers capability flags shared by all VibeVoice ASR models.
    config: VibeVoiceAsrConfig
    base_model_prefix = "model"
    main_input_name = "input_ids"
    _no_split_modules = None
    input_modalities = ("audio", "text")
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Run the generic initialization, then reset layer-scale params to their configured constant."""
        super()._init_weights(module)
        if isinstance(module, VibeVoiceAsrConvNext1dLayer):
            init.constant_(module.gamma, self.config.layer_scale_init_value)
            init.constant_(module.ffn_gamma, self.config.layer_scale_init_value)


@auto_docstring(
    custom_intro="""
    The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model.
    """
)
class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin):
    _keep_in_fp32_modules_strict = None
    _tp_plan = None
    _pp_plan = None

    def __init__(self, config: VibeVoiceAsrConfig):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        # Decoder-only LM that consumes interleaved text and audio embeddings.
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
        # Fuses acoustic + semantic tokenizer latents into LM-sized embeddings.
        self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config)
        self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config)
        self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config)

        # Initialize weights and apply final processing
        self.post_init()

    # The accessors below delegate straight to the wrapped language model.
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    @can_return_tuple
    @auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.")
    def get_audio_features(
        self,
        input_values: torch.FloatTensor,
        padding_mask: torch.BoolTensor | None = None,
        acoustic_tokenizer_chunk_size: int | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`):
            Input audio tensor. Audio should be sampled at 24kHz.
        padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing operations on padding feature indices.
        acoustic_tokenizer_chunk_size (`int`, *optional*):
            Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`,
            but can be modified to fit the available memory.
        """

        if acoustic_tokenizer_chunk_size is None:
            acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size
        else:
            # A chunk must contain a whole number of tokenizer frames; otherwise compute the
            # nearest valid (smaller) size and surface it in the error message as a suggestion.
            if acoustic_tokenizer_chunk_size % self.config.acoustic_tokenizer_encoder_config.hop_length != 0:
                acoustic_tokenizer_chunk_size = int(
                    (acoustic_tokenizer_chunk_size // self.config.acoustic_tokenizer_encoder_config.hop_length)
                    * self.config.acoustic_tokenizer_encoder_config.hop_length
                )
                raise ValueError(
                    f"`acoustic_tokenizer_chunk_size` must be a multiple of hop length ({self.config.acoustic_tokenizer_encoder_config.hop_length}), {acoustic_tokenizer_chunk_size} is a valid option."
                )

        # Tokenizer encoders are frozen feature extractors here: no gradients are tracked.
        with torch.no_grad():
            acoustic_encoder_cache, semantic_encoder_cache = None, None
            acoustic_latents, semantic_latents = [], []

            # Process the audio in chunks, threading each encoder's padding cache through so the
            # chunked result matches a single full-length pass while bounding peak memory.
            for chunk in torch.split(input_values, acoustic_tokenizer_chunk_size, dim=-1):
                acoustic_encoder_output = self.acoustic_tokenizer_encoder(
                    chunk,
                    padding_cache=acoustic_encoder_cache,
                    use_cache=True,
                )
                acoustic_latents.append(acoustic_encoder_output.latents)
                acoustic_encoder_cache = acoustic_encoder_output.padding_cache

                semantic_encoder_output = self.semantic_tokenizer_encoder(
                    chunk,
                    padding_cache=semantic_encoder_cache,
                    use_cache=True,
                )
                semantic_latents.append(semantic_encoder_output.latents)
                semantic_encoder_cache = semantic_encoder_output.padding_cache

            acoustic_latents = torch.cat(acoustic_latents, dim=1)
            semantic_latents = torch.cat(semantic_latents, dim=1)

            # Sample acoustic tokens
            # VAE-style sampling: one noise scale per batch element (scaled by `vae_std`),
            # applied to elementwise Gaussian noise added to the latents.
            noise_std = self.config.acoustic_tokenizer_encoder_config.vae_std * torch.randn(
                acoustic_latents.shape[0], device=acoustic_latents.device, dtype=acoustic_latents.dtype
            )
            acoustic_latents = acoustic_latents + noise_std[:, None, None] * torch.randn_like(acoustic_latents)

        combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents)
        if padding_mask is not None:
            # Adjust padding mask according to tokenizer compression
            # (hop_length audio samples collapse into one audio token).
            num_audio_tokens = torch.ceil(
                padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length
            ).to(torch.int64)
            padding_mask = torch.arange(num_audio_tokens.max(), device=combined_features.device) < num_audio_tokens[
                :, None
            ].to(combined_features.device)
            # Boolean indexing flattens the batch: only non-padded audio token embeddings remain.
            combined_features = combined_features[padding_mask]

        return BaseModelOutputWithPooling(last_hidden_state=acoustic_latents, pooler_output=combined_features)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        input_values: torch.FloatTensor | None = None,
        padding_mask: torch.BoolTensor | None = None,
        acoustic_tokenizer_chunk_size: int | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing operations on padding feature indices.
        acoustic_tokenizer_chunk_size (`int`, *optional*):
            Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to
            `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory.

        Example:

        ```python
        >>> from transformers import VibeVoiceAsrForConditionalGeneration, AutoProcessor

        >>> model_id = "microsoft/VibeVoice-ASR-HF"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto")

        >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
        >>> inputs = inputs.to(model.device, dtype=model.dtype)
        >>> outputs = model.generate(**inputs)

        >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
        >>> print(decoded_outputs)
        ```"""

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if input_values is not None and input_ids is not None:
            audio_embeds = self.get_audio_features(
                input_values=input_values,
                padding_mask=padding_mask,
                acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size,
            ).pooler_output

            # Replace text-audio token placeholders with audio embeddings
            # (masked_scatter requires the number of True positions to equal the number of
            # audio embeddings produced above).
            audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
            inputs_embeds = inputs_embeds.masked_scatter(
                audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
            )

        return self.language_model(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs
        )

    def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs):
        # Pop audio-specific kwargs so the base implementation does not see them.
        input_values = kwargs.pop("input_values", None)
        padding_mask = kwargs.pop("padding_mask", None)
        acoustic_tokenizer_chunk_size = kwargs.pop("acoustic_tokenizer_chunk_size", None)

        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

        # Audio is only forwarded on the prefill step; later decoding steps reuse the KV cache
        # and must not re-encode (or re-scatter) the audio embeddings.
        if is_first_iteration:
            if input_values is not None:
                model_inputs["input_values"] = input_values
            if padding_mask is not None:
                model_inputs["padding_mask"] = padding_mask
            if acoustic_tokenizer_chunk_size is not None:
                model_inputs["acoustic_tokenizer_chunk_size"] = acoustic_tokenizer_chunk_size

        return model_inputs


# Public API of this module.
__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrPreTrainedModel"]
