# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable

import torch
from torch import nn

from ... import initialization as init
from ...cache_utils import Cache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, logging
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..bamba.configuration_bamba import BambaConfig
from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache
from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding
from ..granitemoeshared.modeling_granitemoeshared import (
    GraniteFlashAttentionKwargs,
    GraniteMoeSharedAttention,
    GraniteMoeSharedDecoderLayer,
    GraniteMoeSharedForCausalLM,
    GraniteMoeSharedMLP,
    GraniteMoeSharedModel,
    GraniteMoeSharedMoE,
    GraniteMoeSharedPreTrainedModel,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from .configuration_granitemoehybrid import GraniteMoeHybridConfig


logger = logging.get_logger(__name__)


class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
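    """Standard multi-head attention block of the hybrid stack.

    It mirrors `GraniteMoeSharedAttention`, except that rotary position embeddings are optional: when
    `position_embeddings` is `None` the layer runs without any positional encoding (NoPE).
    """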
    def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        super().__init__(config, layer_idx)

    def forward(  # FIXME: @ARTHUR this forward is also the classic attention forward, plus optional NoPE (no position embeddings)
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,  # None or rope embeddings
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class GraniteMoeHybridMambaLayer(BambaMixer):
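    """Mamba mixer block reused from Bamba; the hybrid config is converted to a `BambaConfig` so that
    `BambaMixer` can be instantiated unchanged."""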
    def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        super().__init__(BambaConfig(config), layer_idx)


class GraniteMoeHybridRMSNormGated(BambaRMSNormGated):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__(hidden_size, eps)


class GraniteMoeHybridMLP(GraniteMoeSharedMLP):
    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__(config)


class GraniteMoeHybridRotaryEmbedding(Gemma2RotaryEmbedding):
    pass


class GraniteMoeHybridMoE(GraniteMoeSharedMoE):
    pass


class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
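    """Decoder layer hosting either a self-attention block or a Mamba mixer, selected per layer through
    `config.layers_block_type`, followed by an optional MoE block combined with a shared MLP."""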
    def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.shared_mlp = GraniteMoeHybridMLP(config)
        # Either attention or mamba will be initialized, depending on the layer type.
        self.self_attn = None
        self.mamba = None

        if config.layers_block_type[layer_idx] == "mamba":
            self.mamba = GraniteMoeHybridMambaLayer(config, layer_idx)
        else:
            self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
        self.layer_type = config.layers_block_type[layer_idx]

        # Allow a dense (non-MoE) configuration: when `num_local_experts` is 0 the MoE block is skipped
        # and only the shared MLP is applied.
        self.block_sparse_moe = GraniteMoeHybridMoE(config) if config.num_local_experts > 0 else None
        self.has_experts = getattr(config, "num_local_experts", 0) > 0

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[GraniteFlashAttentionKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        if self.mamba is not None:
            hidden_states = self.mamba(
                hidden_states=hidden_states,
                cache_position=cache_position,
                cache_params=past_key_values,
                attention_mask=attention_mask,
                **kwargs,
            )
        else:
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = residual + hidden_states * self.residual_multiplier
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        if self.has_experts:
            moe_hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
        else:
            hidden_states = self.shared_mlp(hidden_states)

        hidden_states = residual + hidden_states * self.residual_multiplier
        return hidden_states


class GraniteMoeHybridPreTrainedModel(GraniteMoeSharedPreTrainedModel):
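    """Adds Mamba-specific weight initialization (`dt_bias`, `A_log`, `D`) and gated RMSNorm initialization
    on top of the shared Granite pre-trained model plumbing."""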
    config: GraniteMoeHybridConfig
    _no_split_modules = ["GraniteMoeHybridDecoderLayer"]
    _is_stateful = True

    @torch.no_grad()
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, GraniteMoeHybridMambaLayer):
            init.ones_(module.dt_bias)
            init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
            init.ones_(module.D)
        elif isinstance(module, GraniteMoeHybridRMSNormGated):
            init.ones_(module.weight)


class GraniteMoeHybridModel(GraniteMoeSharedModel):
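    """Hybrid decoder stack interleaving Mamba and attention layers according to `config.layers_block_type`;
    rotary embeddings are only created when `config.position_embedding_type == "rope"`."""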
    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.embedding_multiplier = config.embedding_multiplier
        self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if config.position_embedding_type == "rope" else None

    @auto_docstring
    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[GraniteFlashAttentionKwargs],
    ) -> tuple | MoeModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        inputs_embeds = inputs_embeds * self.embedding_multiplier

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            self.config,
            inputs_embeds,
            attention_mask,
            cache_position,
            past_key_values,
        )
        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)

        hidden_states = inputs_embeds

        # Rotary position embeddings are computed once and shared across all attention layers
        # (only when RoPE is enabled for this model).
        position_embeddings = None
        if self.rotary_emb is not None:
            position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers:
            # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=layer_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)

        if past_key_values and not past_key_values.has_previous_state:
            past_key_values.has_previous_state = True

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_mamba_mask(self, attention_mask, cache_position):
        """
        No need for zeroing states when
            1. Cached forward
            2. Attending to all inputs
        """
        mamba_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            mamba_mask = None
        return mamba_mask


class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
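    """GraniteMoeHybrid model with a language-modeling head on top, tied to the input embeddings."""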
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__(config)
        self.model = GraniteMoeHybridModel(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GraniteMoeHybridForCausalLM

        >>> model = GraniteMoeHybridForCausalLM.from_pretrained("ibm-granite/granite-4.0-h-tiny")
        >>> tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-tiny")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        return super().forward(**super_kwargs)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`

        if past_key_values is None and use_cache:
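            # The hybrid cache tracks KV states for attention layers and conv/SSM states for Mamba layers,
            # so it is created here instead of letting `generate` allocate a default `DynamicCache`.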
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        return model_inputs


__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"]
