#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/timesfm2_5/modular_timesfm2_5.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_timesfm2_5.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ... import initialization as init
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_timesfm2_5 import TimesFm2_5Config


@dataclass
@auto_docstring
class TimesFm2_5Output(BaseModelOutput):
    r"""
    loc (`torch.Tensor` of shape `(batch_size,)`):
        Location statistic (running mean of the final patch) used to denormalize model outputs.
    scale (`torch.Tensor` of shape `(batch_size,)`):
        Scale statistic (running standard deviation of the final patch, clamped to a minimum tolerance).
    context_mu (`torch.Tensor` of shape `(batch_size, num_patches)`):
        Running means computed per input patch during normalization.
    context_sigma (`torch.Tensor` of shape `(batch_size, num_patches)`):
        Running standard deviations computed per input patch during normalization.
    """

    # Denormalization statistics taken from the last patch (see `TimesFm2_5Model.forward`).
    loc: torch.Tensor | None = None
    scale: torch.Tensor | None = None

    # Causal per-patch running statistics over the full context.
    context_mu: torch.Tensor | None = None
    context_sigma: torch.Tensor | None = None


@dataclass
@auto_docstring
class TimesFm2_5OutputForPrediction(BaseModelOutput):
    r"""
    mean_predictions (`torch.Tensor` of shape `(batch_size, horizon_length)`):
        Deterministic forecasts after denormalization.
    full_predictions (`torch.Tensor` of shape `(batch_size, horizon_length, quantiles)`):
        Quantile forecasts including the median after denormalization.
    loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
        Training loss combining MSE and quantile losses when targets are supplied.
    """

    # Point (mean) forecasts, already denormalized.
    mean_predictions: torch.Tensor | None = None
    # Per-quantile forecasts (including the median), already denormalized.
    full_predictions: torch.Tensor | None = None
    # Combined training loss, only populated when targets were provided.
    loss: torch.Tensor | float | None = None


class TimesFm2_5MLP(nn.Module):
    """Standard two-layer transformer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(self, config: TimesFm2_5Config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.activation]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.use_bias)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.use_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand to the intermediate size, apply the nonlinearity, project back down.
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))


class TimesFm2_5ResidualBlock(nn.Module):
    """[`TimesFmResidualBlock`] variant with configurable `use_bias` and `activation`.

    Computes ``output_layer(activation(input_layer(x))) + residual_layer(x)``.

    Args:
        config: Model config providing `activation` and the default `use_bias`.
        input_dims (`int`): Size of the input features.
        hidden_dims (`int`): Size of the hidden projection.
        output_dims (`int`): Size of the output features.
        use_bias (`bool`, *optional*): Whether the linear layers carry bias terms.
            Falls back to `config.use_bias` when `None`.
    """

    def __init__(self, config, input_dims: int, hidden_dims: int, output_dims: int, use_bias: bool | None = None):
        super().__init__()
        # Resolve the bias flag BEFORE constructing the linear layers. Previously this
        # fallback ran after construction, so callers passing `use_bias=None` got
        # `nn.Linear(bias=None)` (falsy -> bias silently disabled) instead of the
        # configured default, and the late assignment was dead code.
        use_bias = use_bias if use_bias is not None else config.use_bias
        self.input_dims = input_dims
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims
        self.input_layer = nn.Linear(input_dims, hidden_dims, bias=use_bias)
        self.activation = ACT2FN[config.activation]
        self.output_layer = nn.Linear(hidden_dims, output_dims, bias=use_bias)
        self.residual_layer = nn.Linear(input_dims, output_dims, bias=use_bias)

    def forward(self, x):
        # Align activations to block parameter dtype for mixed precision stability
        x = x.to(self.input_layer.weight.dtype)
        hidden = self.input_layer(x)
        hidden = self.activation(hidden)
        output = self.output_layer(hidden)
        residual = self.residual_layer(x)
        return output + residual


@use_kernel_forward_from_hub("RMSNorm")
class TimesFm2_5RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back.
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"


class TimesFm2_5RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) module producing cos/sin tables for attention."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: TimesFm2_5Config, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Non-default rope types (e.g. dynamic/scaled variants) come from the shared registry.
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Pristine copy of the initial frequencies; presumably kept so dynamic rope
        # variants can restore them after length-dependent rescaling -- confirm.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: TimesFm2_5Config | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # Fall back to hidden_size // num_heads when the config carries no explicit head_dim.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return the (cos, sin) rotation tables for `position_ids`, cast to `x`'s dtype."""
        # [batch, dim/2, 1] x [batch, 1, seq] -> per-position angles below.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute in float32 with autocast disabled; "mps" is routed to "cpu" for the
        # autocast context since disabling autocast is not supported there.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the angles so cos/sin cover the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotate the two halves of the last dimension: ``(a, b) -> (-b, a)``."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)


@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast
            against `q` and `k`. With cos/sin of shape [batch_size, seq_len, head_dim],
            use 1 when q/k are laid out as [batch_size, heads, seq_len, head_dim] and 2
            when they are laid out as [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_broadcast = cos.unsqueeze(unsqueeze_dim)
    sin_broadcast = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard RoPE rotation in the complex plane, expressed over tensor halves.
        return (states * cos_broadcast) + (rotate_half(states) * sin_broadcast)

    return _rotate(q), _rotate(k)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        # No grouped-query expansion needed; return the input unchanged.
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    # Insert a repeat axis, broadcast-expand it (no copy), then flatten into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention with GQA key/value expansion."""
    # Expand grouped key/value heads to match the number of query heads.
    expanded_keys = repeat_kv(key, module.num_key_value_groups)
    expanded_values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = F.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values)
    context = context.transpose(1, 2).contiguous()

    return context, probs


@use_kernelized_func(apply_rotary_pos_emb)
class TimesFm2_5Attention(nn.Module):
    """TimesFM 2.5 attention with learnable per-dimension query scaling."""

    def __init__(self, config: TimesFm2_5Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # GQA: how many query heads share each key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # Learnable per-dimension query scale; filled with ones in `_init_weights`.
        self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # Per-head-dim RMS norms applied to queries and keys.
        self.q_norm = TimesFm2_5RMSNorm(self.head_dim, config.rms_norm_eps)
        self.k_norm = TimesFm2_5RMSNorm(self.head_dim, config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values=None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run self-attention over `hidden_states`; returns (attn_output, attn_weights)."""
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project and reshape to [batch, heads, seq, head_dim].
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # NOTE(review): RoPE is applied before the q/k RMS norms in this model.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        # Per-dimension query scale: softplus keeps it positive; 1.442695041 is log2(e),
        # presumably matching the reference TimesFM kernel's scaling convention -- confirm.
        scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
        query_states = query_states * scale[None, None, None, :]

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            # scaling=1.0 because per-dimension learnable scaling is already applied to query_states above
            scaling=1.0,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class TimesFm2_5DecoderLayer(GradientCheckpointingLayer):
    """TimesFM 2.5 Transformer decoder layer with pre/post RMS normalization and no KV cache."""

    def __init__(self, config: TimesFm2_5Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = TimesFm2_5Attention(config=config, layer_idx=layer_idx)

        self.mlp = TimesFm2_5MLP(config)
        # Sandwich norms: each sub-block is normalized both on entry and on exit.
        self.input_layernorm = TimesFm2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = TimesFm2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = TimesFm2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = TimesFm2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Attention sub-block with a residual connection around it.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_output)

        # Feed-forward sub-block with a residual connection around it.
        mlp_input = self.pre_feedforward_layernorm(hidden_states)
        mlp_output = self.mlp(mlp_input)
        return hidden_states + self.post_feedforward_layernorm(mlp_output)


class TimesFm2_5PositionalEmbedding(nn.Module):
    """Generates position embedding for a given 1-d sequence."""

    def __init__(self, config: TimesFm2_5Config):
        super().__init__()
        self.min_timescale, self.max_timescale = config.min_timescale, config.max_timescale
        self.embedding_dims = config.hidden_size

        # Geometric progression of timescales, identical to the classic sinusoidal scheme.
        half_dims = self.embedding_dims // 2
        log_increment = math.log(float(self.max_timescale) / float(self.min_timescale)) / max(half_dims - 1, 1)
        frequencies = self.min_timescale * torch.exp(
            torch.arange(half_dims, dtype=torch.float32) * -log_increment
        )
        self.register_buffer("inv_timescales", frequencies)

    def forward(self, seq_length=None, position=None):
        """Generates a Tensor of sinusoids with different frequencies.

        Args:
            seq_length: an optional Python int defining the output sequence length.
              if the `position` argument is specified.
            position: [B, seq_length], optional position for each token in the
              sequence, only required when the sequence is packed.

        Returns:
            [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
        """
        if position is None and seq_length is None:
            raise ValueError("Either position or seq_length must be provided")

        if position is None:
            # Default to consecutive positions 0..seq_length-1 with a leading batch axis.
            position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
        elif position.ndim != 2:
            raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")

        # [B, seqlen, half_dims] angles, then sin/cos concatenated along the feature axis.
        scaled_time = position.unsqueeze(-1) * self.inv_timescales.view(1, 1, -1)
        signal = torch.cat((scaled_time.sin(), scaled_time.cos()), dim=2)

        # Pad one zero feature when the embedding dimension is odd.
        return F.pad(signal, (0, 0, 0, self.embedding_dims % 2))


@auto_docstring
class TimesFm2_5PreTrainedModel(PreTrainedModel):
    # Shared weight-loading/initialization plumbing for all TimesFM 2.5 models.
    config: TimesFm2_5Config
    base_model_prefix = "model"
    # Modules that must be kept on a single device when sharding the model.
    _no_split_modules = ["TimesFm2_5DecoderLayer"]
    main_input_name = "past_values"
    input_modalities = ("time",)
    _supports_sdpa = True
    config_class = TimesFm2_5Config
    _supports_flash_attn = True
    _supports_flex_attn = True
    # Modules whose outputs the output-capturing machinery can record.
    _can_record_outputs = {
        "hidden_states": TimesFm2_5DecoderLayer,
        "attentions": TimesFm2_5Attention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Apply TimesFM-specific initialization on top of the default scheme."""
        super()._init_weights(module)
        if isinstance(module, TimesFm2_5Attention):
            # Initialize scaling parameter
            init.ones_(module.scaling)
        elif isinstance(module, TimesFm2_5PositionalEmbedding):
            # Recompute the deterministic sinusoidal inverse timescales; mirrors the
            # buffer construction in `TimesFm2_5PositionalEmbedding.__init__`.
            num_timescales = module.embedding_dims // 2
            max_timescale, min_timescale = module.max_timescale, module.min_timescale
            log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
                num_timescales - 1, 1
            )
            init.copy_(
                module.inv_timescales,
                min_timescale
                * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
            )


class TimesFm2_5Model(TimesFm2_5PreTrainedModel):
    """Decoder-only TimesFM 2.5 backbone.

    Splits the input series into patches, normalizes each patch with causal running
    statistics (RevIN), embeds the (patch, mask) pairs, and runs the embeddings
    through the decoder stack under a causal attention mask.
    """

    def __init__(self, config: TimesFm2_5Config):
        super().__init__(config)
        self.config = config
        # Scales below this threshold are treated as degenerate in `_revin`.
        self.tolerance = 1e-6

        # Patch tokenizer: input is the normalized patch values concatenated with the
        # per-timestep padding mask, hence `2 * patch_length` input features.
        self.input_ff_layer = TimesFm2_5ResidualBlock(
            config,
            input_dims=2 * config.patch_length,
            hidden_dims=config.hidden_size,
            output_dims=config.hidden_size,
            use_bias=True,
        )

        self.layers = nn.ModuleList(
            [TimesFm2_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_emb = TimesFm2_5RotaryEmbedding(config)
        self.gradient_checkpointing = False

        self.post_init()

    def _revin(
        self,
        hidden_states: torch.Tensor,
        loc: torch.Tensor,
        scale: torch.Tensor,
        reverse: bool = False,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Reversible instance normalization (RevIN).

        Normalizes or denormalizes `hidden_states` using the provided location and scale statistics.
        When `mask` is provided during normalization (reverse=False), masked positions are zeroed out.
        """
        # Append trailing singleton dims so the statistics broadcast over the patch axes.
        if len(loc.shape) == len(hidden_states.shape) - 1:
            loc = loc[..., None]
            scale = scale[..., None]
        elif len(loc.shape) == len(hidden_states.shape) - 2:
            loc = loc[..., None, None]
            scale = scale[..., None, None]

        loc = loc.to(hidden_states.device)
        scale = scale.to(hidden_states.device)
        # Guard against division by (near-)zero scales in the normalization path.
        safe_scale = torch.where(scale < self.tolerance, torch.ones_like(scale), scale)

        if reverse:
            # NOTE(review): denormalization intentionally uses the raw `scale`, not
            # `safe_scale`; `forward` clamps the scale it returns for this purpose.
            return hidden_states * scale + loc

        normed = (hidden_states - loc) / safe_scale
        if mask is not None:
            # Zero out masked (padded) positions after normalization.
            normed = torch.where(mask, torch.zeros_like(normed), normed)
        return normed

    @staticmethod
    def _update_running_stats(
        count: torch.Tensor,
        mean: torch.Tensor,
        std: torch.Tensor,
        new_values: torch.Tensor,
        mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Update running mean/std using Welford's online algorithm.

        Combines existing statistics (`count`, `mean`, `std`) with a new batch of values,
        respecting the boolean `mask` (True = masked/invalid).
        """
        # Statistics of the incoming batch over its valid entries only; the `_safe`
        # denominators avoid 0/0, and the follow-up `where` restores exact zeros.
        is_valid = (~mask).to(new_values.dtype)
        inc_count = is_valid.sum(dim=-1)

        inc_count_safe = torch.where(inc_count == 0, torch.ones_like(inc_count), inc_count)
        inc_mean = (new_values * is_valid).sum(dim=-1) / inc_count_safe
        inc_mean = torch.where(inc_count == 0, torch.zeros_like(inc_mean), inc_mean)

        centered = new_values - inc_mean.unsqueeze(-1)
        inc_var = ((centered * is_valid) ** 2).sum(dim=-1) / inc_count_safe
        inc_var = torch.where(inc_count == 0, torch.zeros_like(inc_var), inc_var)
        inc_std = torch.sqrt(torch.clamp(inc_var, min=0.0))

        new_count = count + inc_count
        new_count_safe = torch.where(new_count == 0, torch.ones_like(new_count), new_count)

        # Merge old and incoming means weighted by their counts.
        new_mean = (count * mean + inc_mean * inc_count) / new_count_safe
        new_mean = torch.where(new_count == 0, torch.zeros_like(new_mean), new_mean)

        # Parallel variance combination: within-group variances (term1, term2) plus
        # between-group mean-shift corrections (term3, term4).
        term1 = count * std.pow(2)
        term2 = inc_count * inc_std.pow(2)
        term3 = count * (mean - new_mean).pow(2)
        term4 = inc_count * (inc_mean - new_mean).pow(2)

        new_var = (term1 + term2 + term3 + term4) / new_count_safe
        new_var = torch.where(new_count == 0, torch.zeros_like(new_var), new_var)
        new_std = torch.sqrt(torch.clamp(new_var, min=0.0))

        return new_count, new_mean, new_std

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        past_values: torch.Tensor,
        past_values_padding: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TimesFm2_5Output:
        r"""
        past_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
            Past values of the time series used as input to the model.
        past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Padding mask for the input. `1` indicates padded (masked) time steps, `0` indicates valid values.
        """
        batch_size, seq_len = past_values.shape
        patch_len = self.config.patch_length

        if past_values_padding is None:
            past_values_padding = torch.zeros_like(past_values, dtype=torch.long)

        # Split the series into non-overlapping patches; assumes `seq_len` is a
        # multiple of `patch_length` -- TODO confirm callers pad accordingly.
        patched_inputs = past_values.view(batch_size, -1, patch_len)
        patched_masks = past_values_padding[:, :seq_len].view(batch_size, -1, patch_len)
        # `>= 0.5` tolerates float-valued padding masks.
        patched_masks_bool = patched_masks >= 0.5

        # Causal running statistics: the stats recorded at patch i cover patches 0..i.
        count = past_values.new_zeros(batch_size)
        mean = past_values.new_zeros(batch_size)
        std = past_values.new_zeros(batch_size)
        mean_history: list[torch.Tensor] = []
        std_history: list[torch.Tensor] = []

        for i in range(patched_inputs.shape[1]):
            count, mean, std = self._update_running_stats(
                count, mean, std, patched_inputs[:, i, :], patched_masks_bool[:, i, :]
            )
            mean_history.append(mean)
            std_history.append(std)

        if mean_history:
            context_mu = torch.stack(mean_history, dim=1)
            context_sigma = torch.stack(std_history, dim=1)
        else:
            # Zero-patch input: fall back to the (all-zero) initial statistics.
            context_mu = mean.unsqueeze(1)
            context_sigma = std.unsqueeze(1)

        normed_inputs = self._revin(patched_inputs, context_mu, context_sigma, reverse=False, mask=patched_masks_bool)

        # Tokenizer input: normalized values concatenated with the padding mask.
        tokenizer_inputs = torch.cat(
            [normed_inputs, patched_masks_bool.to(dtype=normed_inputs.dtype)],
            dim=-1,
        )
        input_embeddings = self.input_ff_layer(tokenizer_inputs)

        # A patch is treated as padding iff its final timestep is masked.
        patch_padding = patched_masks_bool[..., -1]

        sequence_length = input_embeddings.shape[1]
        # Shift positions by the number of padded patches so the first valid patch
        # gets position 0 (assumes padding sits at the front -- TODO confirm).
        num_masked = patch_padding.to(torch.int32).sum(dim=-1, keepdim=True)
        position_ids = torch.arange(sequence_length, device=input_embeddings.device).unsqueeze(0) - num_masked

        # `create_causal_mask` expects 1 for valid positions, hence the inversion.
        padding_mask = (~patch_padding).to(torch.int64)
        cache_position = torch.arange(sequence_length, device=input_embeddings.device)
        attention_mask = create_causal_mask(
            self.config, input_embeddings, padding_mask, cache_position, past_key_values=None
        )
        position_embeddings = self.rotary_emb(input_embeddings, position_ids)

        hidden_states = input_embeddings

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **kwargs,
            )

        # Denormalization statistics come from the final (most complete) patch; the
        # clamped scale keeps `_revin(reverse=True)` well-defined downstream.
        loc = context_mu[:, -1]
        scale = torch.clamp(context_sigma[:, -1], min=self.tolerance)

        return TimesFm2_5Output(
            last_hidden_state=hidden_states,
            loc=loc,
            scale=scale,
            context_mu=context_mu,
            context_sigma=context_sigma,
        )


class TimesFm2_5ModelForPrediction(TimesFm2_5PreTrainedModel):
    """TimesFm2_5 model for quantile and mean prediction.

    Wraps the decoder-only `TimesFm2_5Model` with two residual projection heads: one
    projecting the last patch's hidden state to point forecasts over the horizon and one
    projecting to quantile spreads, each with `len(config.quantiles) + 1` output channels
    (point/mean channel plus one channel per configured quantile).
    """

    def __init__(self, config: TimesFm2_5Config):
        super().__init__(config)
        self.config = config
        self.context_len = config.context_length
        self.horizon_len = config.horizon_length

        self.model = TimesFm2_5Model(config)

        # One channel for the point forecast plus one per configured quantile.
        num_quantiles = len(config.quantiles) + 1
        self.output_projection_point = TimesFm2_5ResidualBlock(
            config,
            input_dims=config.hidden_size,
            hidden_dims=config.hidden_size,
            output_dims=config.horizon_length * num_quantiles,
        )
        self.output_projection_quantiles = TimesFm2_5ResidualBlock(
            config,
            input_dims=config.hidden_size,
            hidden_dims=config.hidden_size,
            output_dims=config.output_quantile_len * num_quantiles,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _preprocess(
        self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None
    ) -> tuple[torch.Tensor, ...]:
        """Pad/truncate input time series to `context_len` and build a padding mask.

        Args:
            inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task.
            freq: Optional list of frequencies (returned as a tensor when provided).
            context_len: Optional context length override (defaults to `self.context_len`).

        Returns:
            Tuple of (padded_inputs, padding_mask) and optionally a freq tensor. The padding mask
            covers context plus horizon positions and is 1 for front-padded (invalid) positions,
            0 elsewhere.
        """
        if context_len is None:
            context_len = self.context_len

        input_ts, input_padding = [], []

        for ts in inputs:
            input_len = ts.shape[0]
            # Horizon positions are appended as zeros (valid, not padding) so the mask has a
            # uniform `context_len + horizon_len` length downstream.
            padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
            if input_len < context_len:
                # Front-pad short series with zeros and flag those positions in the mask.
                num_front_pad = context_len - input_len
                ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
                padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
            elif input_len > context_len:
                # Keep only the most recent `context_len` points.
                ts = ts[-context_len:]
                padding = padding[-(context_len + self.horizon_len) :]

            input_ts.append(ts)
            input_padding.append(padding)

        result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))
        if freq is not None:
            result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),)
        return result

    def _postprocess_output(
        self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Postprocess output of stacked transformer: project, reshape, and de-normalize.

        Args:
            model_output: Hidden states of shape `(batch, num_patches, hidden)`.
            stats: `(mu, sigma)` normalization statistics, broadcast back onto the output.

        Returns:
            Tensor of shape `(batch, num_patches, horizon_length, num_quantiles + 1)`.
        """
        # NOTE(review): `self.horizon_ff_layer` is not created in `__init__`; presumably it is
        # provided by a parent class or this method is unused dead code — verify against the
        # modular source file.
        # B x N x (H.Q)
        output_ts = self.horizon_ff_layer(model_output)

        # Reshape using view
        b, n, _ = output_ts.shape
        output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)

        # De-normalize with the provided statistics (broadcast over patches/horizon/quantiles).
        mu, sigma = stats
        return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]

    def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Mean pinball (quantile) loss over all configured quantiles.

        Args:
            predictions: Quantile predictions with the quantile index on the last axis,
                ordered to match `config.quantiles`.
            targets: Ground-truth values broadcastable to `predictions[..., i]`.
        """
        losses = []
        for i, q in enumerate(self.config.quantiles):
            errors = targets - predictions[..., i]
            # Pinball loss: penalize under-prediction by q and over-prediction by (1 - q).
            loss = torch.max((q - 1) * errors, q * errors)
            losses.append(loss.mean())
        return torch.stack(losses).mean()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        past_values: Sequence[torch.Tensor],
        window_size: int | None = None,
        future_values: torch.Tensor | None = None,
        forecast_context_len: int | None = None,
        truncate_negative: bool | None = None,
        force_flip_invariance: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TimesFm2_5OutputForPrediction:
        r"""
        past_values (`Sequence[torch.Tensor]`):
            Past values of the time series that serves as input to the model. Each tensor is a 1D time series.
        window_size (`int`, *optional*):
            Window size of trend + residual decomposition. If `None`, decomposition is not applied.
        future_values (`torch.Tensor`, *optional*):
            Optional future values used to compute the loss.
        forecast_context_len (`int`, *optional*):
            Optional context length override used during forecasting.
        truncate_negative (`bool`, *optional*):
            Whether to clamp outputs to non-negative values. If `None`, defaults to `config.infer_is_positive`.
        force_flip_invariance (`bool`, *optional*):
            Whether to apply the flip-invariance combination. If `None`, defaults to
            `config.force_flip_invariance`.
        """
        forecast_context_len = forecast_context_len or self.context_len
        device = past_values[0].device

        inputs = [ts[-forecast_context_len:] for ts in past_values]
        # Global minimum over all series: used later to decide whether clamping to
        # non-negative outputs is safe.
        input_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))

        if window_size is not None:
            # Decompose each series into (trend, residual); the two halves are forecast
            # independently and recombined by summation below.
            # Fix: the helper defined on this class is `_timesfm2_5_moving_average`;
            # the previous call targeted a nonexistent `_timesfm_moving_average`.
            new_inputs: list[torch.Tensor] = []
            for ts in inputs:
                new_inputs.extend(self._timesfm2_5_moving_average(ts, window_size))
            inputs = new_inputs

        if truncate_negative is None:
            truncate_negative = self.config.infer_is_positive
        if force_flip_invariance is None:
            force_flip_invariance = self.config.force_flip_invariance

        input_ts, input_padding = self._preprocess(inputs, context_len=forecast_context_len)
        input_ts = input_ts.to(device)
        input_padding = input_padding.to(device)

        # Global per-series normalization before decoding (reversed on the outputs).
        mu_global = input_ts.mean(dim=1, keepdim=True)
        sigma_global = input_ts.std(dim=1, keepdim=True)

        normalized_ts = self.model._revin(input_ts, mu_global, sigma_global, reverse=False)

        pf_outputs, quantile_spreads, model_outputs = self._decode_and_project(normalized_ts, input_padding, **kwargs)

        if force_flip_invariance:
            # Decode the sign-flipped series and average with the sign-flipped, quantile-
            # reversed result so the forecast is invariant to flipping the input sign.
            flipped_pf, flipped_qs, _ = self._decode_and_project(-normalized_ts, input_padding, **kwargs)

            def _flip_quantiles(x: torch.Tensor) -> torch.Tensor:
                # Keep the point channel (index 0); reverse the quantile channel order.
                return torch.cat([x[..., :1], torch.flip(x[..., 1:], dims=(-1,))], dim=-1)

            pf_outputs = (pf_outputs - _flip_quantiles(flipped_pf)) / 2
            quantile_spreads = (quantile_spreads - _flip_quantiles(flipped_qs)) / 2

        horizon = min(self.horizon_len, pf_outputs.shape[1])
        full_forecast = pf_outputs[:, :horizon, :].clone()

        median_index = min(self.config.decode_index, full_forecast.shape[-1] - 1)
        if self.config.use_continuous_quantile_head:
            # Rebase each quantile channel onto the point forecast via the quantile head's
            # spread relative to its own median channel.
            max_quantile_horizon = min(horizon, quantile_spreads.shape[1])
            for idx, _ in enumerate(self.config.quantiles, start=1):
                if idx == median_index or idx >= full_forecast.shape[-1]:
                    continue
                full_forecast[:, :max_quantile_horizon, idx] = (
                    quantile_spreads[:, :max_quantile_horizon, idx]
                    - quantile_spreads[:, :max_quantile_horizon, median_index]
                    + full_forecast[:, :max_quantile_horizon, median_index]
                )

        full_predictions = self.model._revin(full_forecast, mu_global, sigma_global, reverse=True)
        decode_index = min(self.config.decode_index, full_predictions.shape[-1] - 1)
        mean_predictions = full_predictions[:, :, decode_index]

        if window_size is not None:
            # Recombine the interleaved (trend, residual) forecasts by summation.
            mean_predictions = mean_predictions[0::2, ...] + mean_predictions[1::2, ...]
            full_predictions = full_predictions[0::2, ...] + full_predictions[1::2, ...]

        if truncate_negative:
            # Only clamp when every input value was non-negative, so strictly-positive data
            # never yields negative forecasts but signed data is left untouched.
            zero = torch.zeros(1, device=mean_predictions.device, dtype=mean_predictions.dtype)
            clamped_mean = torch.maximum(mean_predictions, zero)
            clamped_full = torch.maximum(full_predictions, zero)
            should_clamp = (input_min >= 0).to(mean_predictions.device)
            mean_predictions = torch.where(should_clamp, clamped_mean, mean_predictions)
            full_predictions = torch.where(should_clamp, clamped_full, full_predictions)

        loss = None
        if future_values is not None:
            # MSE on the point channel plus pinball loss on the remaining quantile channels.
            target_len = future_values.shape[1]
            valid_mean_predictions = mean_predictions[:, :target_len]
            valid_full_predictions = full_predictions[:, :target_len]
            mse_loss = F.mse_loss(valid_mean_predictions, future_values)
            quantile_indices = [i for i in range(valid_full_predictions.shape[-1]) if i != decode_index]
            if quantile_indices:
                index_tensor = torch.tensor(quantile_indices, device=valid_full_predictions.device, dtype=torch.long)
                quantile_tensor = torch.index_select(valid_full_predictions, dim=-1, index=index_tensor)
                quantile_loss = self._quantile_loss(quantile_tensor, future_values)
                loss = mse_loss + quantile_loss
            else:
                loss = mse_loss

        return TimesFm2_5OutputForPrediction(
            last_hidden_state=model_outputs.last_hidden_state,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
            mean_predictions=mean_predictions,
            full_predictions=full_predictions,
            loss=loss,
        )

    @staticmethod
    def _timesfm2_5_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
        """Calculates the moving average using PyTorch's convolution function.

        Returns:
            `[trend, residual]` where `trend` is the moving average of `arr` and
            `residual = arr - trend`.
        """
        # Pad with zeros to handle initial window positions
        arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
        # Create a convolution kernel
        kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
        # Apply convolution to calculate the moving average
        smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
        return [smoothed_arr, arr - smoothed_arr]

    def _decode_and_project(
        self,
        normalized_ts: torch.Tensor,
        input_padding: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor, TimesFm2_5Output]:
        """Run the decoder and project to point/quantile outputs.

        Returns:
            Tuple of (point_forecast, quantile_spreads, model_outputs). The first two have
            shape `(batch, length, num_quantiles)` and come from the last patch position;
            `model_outputs` is the raw decoder output.
        """
        model_outputs = self.model(
            past_values=normalized_ts,
            past_values_padding=input_padding,
            **kwargs,
        )

        hidden_states = model_outputs.last_hidden_state
        context_mu = model_outputs.context_mu
        context_sigma = model_outputs.context_sigma

        # De-normalize both heads with the per-patch statistics from the decoder.
        point_output = self.model._revin(
            self.output_projection_point(hidden_states), context_mu, context_sigma, reverse=True
        )
        quantile_output = self.model._revin(
            self.output_projection_quantiles(hidden_states), context_mu, context_sigma, reverse=True
        )

        batch_size, num_patches = point_output.shape[:2]
        num_quantiles = len(self.config.quantiles) + 1

        # Keep only the forecast emitted at the final patch position.
        point_forecast = point_output.view(batch_size, num_patches, self.config.horizon_length, num_quantiles)[
            :, -1, :, :
        ]
        quantile_spreads = quantile_output.view(
            batch_size, num_patches, self.config.output_quantile_len, num_quantiles
        )[:, -1, :, :]

        # Ensure both outputs are on the same device for model parallelism
        quantile_spreads = quantile_spreads.to(point_forecast.device)

        return point_forecast, quantile_spreads, model_outputs


# Public API of this module; transformers' export machinery reads this list.
__all__ = ["TimesFm2_5ModelForPrediction", "TimesFm2_5PreTrainedModel", "TimesFm2_5Model"]
