# Copyright 2025 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections.abc import Callable
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from ... import initialization as init
from ...activations import ACT2FN
from ...configuration_utils import PreTrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import (
    auto_docstring,
)
from ...utils.output_capturing import OutputRecorder
from ..auto import CONFIG_MAPPING, AutoConfig
from ..sam2.modeling_sam2 import eager_attention_forward, window_partition
from ..sam2_video.configuration_sam2_video import (
    Sam2VideoConfig,
    Sam2VideoMaskDecoderConfig,
    Sam2VideoPromptEncoderConfig,
)
from ..sam2_video.modeling_sam2_video import (
    Sam2VideoAttention,
    Sam2VideoFeedForward,
    Sam2VideoImageSegmentationOutput,
    Sam2VideoInferenceSession,
    Sam2VideoLayerNorm,
    Sam2VideoMemoryAttention,
    Sam2VideoMemoryEncoder,
    Sam2VideoMemoryFuserCXBlock,
    Sam2VideoModel,
    Sam2VideoPositionEmbeddingSine,
    Sam2VideoPreTrainedModel,
    Sam2VideoSegmentationOutput,
    Sam2VideoTwoWayAttentionBlock,
    Sam2VideoVisionEncoderOutput,
    Sam2VideoVisionRotaryEmbedding,
    rotate_pairwise,
)


class EdgeTamVideoPromptEncoderConfig(Sam2VideoPromptEncoderConfig):
    pass


class EdgeTamVideoMaskDecoderConfig(Sam2VideoMaskDecoderConfig):
    pass


class EdgeTamVideoConfig(Sam2VideoConfig):
    r"""
    [`EdgeTamVideoConfig`] is the configuration class to store the configuration of an [`EdgeTamVideoModel`]. It is used to instantiate an
    EDGETAM model according to the specified arguments, defining the memory attention, memory encoder, and image encoder
    configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
    [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        vision_config (Union[`dict`, `EdgeTamVideoVisionConfig`], *optional*):
            Dictionary of configuration options used to initialize [`EdgeTamVideoVisionConfig`].
        prompt_encoder_config (Union[`dict`, `EdgeTamVideoPromptEncoderConfig`], *optional*):
            Dictionary of configuration options used to initialize [`EdgeTamVideoPromptEncoderConfig`].
        mask_decoder_config (Union[`dict`, `EdgeTamVideoMaskDecoderConfig`], *optional*):
            Dictionary of configuration options used to initialize [`EdgeTamVideoMaskDecoderConfig`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation for parameter initialization.
        num_maskmem (`int`, *optional*, defaults to 7):
            The number of memory slots for the mask memory.
        image_size (`int`, *optional*, defaults to 1024):
            The size of the input images.
        sigmoid_scale_for_mem_enc (`float`, *optional*, defaults to 20.0):
            Scale factor for the sigmoid function in the memory encoder.
        sigmoid_bias_for_mem_enc (`float`, *optional*, defaults to -10.0):
            Bias for the sigmoid function in the memory encoder.
        enable_occlusion_spatial_embedding (`bool`, *optional*, defaults to `True`):
            Whether to enable spatial embedding for occlusions.
        multimask_output_in_sam (`bool`, *optional*, defaults to `True`):
            Whether to output multiple masks from the SAM head.
        multimask_min_pt_num (`int`, *optional*, defaults to 0):
            The minimum number of points to trigger multimask output.
        multimask_max_pt_num (`int`, *optional*, defaults to 1):
            The maximum number of points to trigger multimask output.
        multimask_output_for_tracking (`bool`, *optional*, defaults to `True`):
            Whether to use multimask output for tracking.
        max_object_pointers_in_encoder (`int`, *optional*, defaults to 16):
            The maximum number of object pointers in the encoder.
        max_cond_frame_num (`int`, *optional*, defaults to -1):
            Maximum number of conditioning frames to use in memory attention. Set to -1 to use all conditioning frames.
        enable_temporal_pos_encoding_for_object_pointers (`bool`, *optional*, defaults to `True`):
            Whether to enable temporal positional encoding for object pointers.
        memory_attention_hidden_size (`int`, *optional*, defaults to 256):
            Dimensionality of the memory attention hidden states.
        memory_attention_num_layers (`int`, *optional*, defaults to 2):
            The number of layers in the memory attention module.
        memory_attention_num_attention_heads (`int`, *optional*, defaults to 1):
            Number of attention heads for each attention layer in the memory attention.
        memory_attention_downsample_rate (`int`, *optional*, defaults to 1):
            The downsample rate for the attention layers.
        memory_attention_mlp_hidden_size (`int`, *optional*, defaults to 2048):
            The dimension of the feedforward network in the memory attention module.
        memory_attention_mlp_hidden_act (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function in the feedforward network in the memory attention module.
        memory_attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout rate for the memory attention module.
        memory_attention_rope_theta (`float`, *optional*, defaults to 10000):
            The Rope theta parameter.
        memory_attention_rope_feat_sizes (`List[int]`, *optional*, defaults to `[64, 64]`):
            The feature sizes for the Rope positional encoding.
        memory_attention_rope_k_sizes (`List[int]`, *optional*, defaults to `[16, 16]`):
            The key feature sizes for the RoPE positional encoding in memory attention.
        memory_attention_rope_dropout (`float`, *optional*, defaults to 0.1):
            The dropout rate for the Rope positional encoding.
        perceiver_resampler_num_latents (`int`, *optional*, defaults to 256):
            The number of 1D latent tokens in the perceiver resampler.
        perceiver_resampler_num_latents_2d (`int`, *optional*, defaults to 256):
            The number of 2D latent tokens in the perceiver resampler.
        perceiver_resampler_hidden_size (`int`, *optional*, defaults to 64):
            The hidden size of the perceiver resampler.
        perceiver_resampler_mlp_intermediate_size (`int`, *optional*, defaults to 256):
            The intermediate size of the feedforward network in the perceiver resampler.
        perceiver_resampler_num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads in the perceiver resampler.
        perceiver_resampler_attention_head_dim (`int`, *optional*, defaults to 64):
            The dimension of each attention head in the perceiver resampler.
        perceiver_resampler_num_layers (`int`, *optional*, defaults to 2):
            The number of layers in the perceiver resampler.
        perceiver_resampler_hidden_dropout (`float`, *optional*, defaults to 0.0):
            The dropout rate for the hidden layers in the perceiver resampler.
        perceiver_resampler_attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout rate for the attention layers in the perceiver resampler.
        memory_encoder_hidden_size (`int`, *optional*, defaults to 256):
            Dimensionality of the memory encoder hidden states.
        memory_encoder_output_channels (`int`, *optional*, defaults to 64):
            The number of output channels for the memory encoder.
        mask_downsampler_embed_dim (`int`, *optional*, defaults to 256):
            The dimension of the mask downsampler embedding.
        memory_fuser_intermediate_dim (`int`, *optional*, defaults to 1024):
            The intermediate dimension of the memory fuser feedforward network.
        mask_downsampler_kernel_size (`int`, *optional*, defaults to 3):
            The kernel size for the mask downsampler.
        mask_downsampler_stride (`int`, *optional*, defaults to 2):
            The stride for the mask downsampler.
        mask_downsampler_padding (`int`, *optional*, defaults to 1):
            The padding for the mask downsampler.
        mask_downsampler_total_stride (`int`, *optional*, defaults to 16):
            The total stride for the mask downsampler.
        mask_downsampler_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the mask downsampler.
        memory_fuser_num_layers (`int`, *optional*, defaults to 2):
            The number of layers in the memory fuser.
        memory_fuser_embed_dim (`int`, *optional*, defaults to 256):
            The dimension of the memory fuser embedding.
        memory_fuser_kernel_size (`int`, *optional*, defaults to 7):
            The kernel size for the memory fuser.
        memory_fuser_padding (`int`, *optional*, defaults to 3):
            The padding for the memory fuser.
        memory_fuser_layer_scale_init_value (`float`, *optional*, defaults to 1e-06):
            The initial value for the layer scale in the memory fuser.
        memory_fuser_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the memory fuser.

    Example:

    ```python
    >>> from transformers import (
    ...     EdgeTamVisionConfig,
    ...     EdgeTamVideoPromptEncoderConfig,
    ...     EdgeTamVideoMaskDecoderConfig,
    ...     EdgeTamVideoModel,
    ...     EdgeTamVideoConfig,
    ... )

    >>> # Initializing an EdgeTamVideoConfig with the `"facebook/EdgeTAM"` style configuration
    >>> configuration = EdgeTamVideoConfig()

    >>> # Initializing an EdgeTamVideoModel (with random weights) from the `"facebook/EdgeTAM"` style configuration
    >>> model = EdgeTamVideoModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize an EdgeTamVideoConfig from an EdgeTamVisionConfig, EdgeTamVideoPromptEncoderConfig, and EdgeTamVideoMaskDecoderConfig

    >>> # Initializing EDGETAM vision encoder, prompt encoder, and mask decoder configurations
    >>> vision_config = EdgeTamVisionConfig()
    >>> prompt_encoder_config = EdgeTamVideoPromptEncoderConfig()
    >>> mask_decoder_config = EdgeTamVideoMaskDecoderConfig()

    >>> config = EdgeTamVideoConfig(vision_config, prompt_encoder_config, mask_decoder_config)
    ```"""

    model_type = "edgetam_video"
    sub_configs = {
        "vision_config": AutoConfig,
        "prompt_encoder_config": EdgeTamVideoPromptEncoderConfig,
        "mask_decoder_config": EdgeTamVideoMaskDecoderConfig,
    }

    def __init__(
        self,
        vision_config=None,
        prompt_encoder_config=None,
        mask_decoder_config=None,
        initializer_range=0.02,
        num_maskmem=7,
        image_size=1024,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        enable_occlusion_spatial_embedding=True,
        multimask_output_in_sam=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        multimask_output_for_tracking=True,
        max_object_pointers_in_encoder=16,
        max_cond_frame_num=-1,
        enable_temporal_pos_encoding_for_object_pointers=True,
        # memory attention
        memory_attention_hidden_size=256,
        memory_attention_num_layers=2,
        memory_attention_num_attention_heads=1,
        memory_attention_downsample_rate=1,
        memory_attention_mlp_hidden_size=2048,
        memory_attention_mlp_hidden_act="relu",
        memory_attention_dropout=0.1,
        memory_attention_rope_theta=10000,
        memory_attention_rope_feat_sizes=None,
        memory_attention_rope_k_sizes=None,
        memory_attention_rope_dropout=0.1,
        # spatial perceiver resampler
        perceiver_resampler_num_latents=256,
        perceiver_resampler_num_latents_2d=256,
        perceiver_resampler_hidden_size=64,
        perceiver_resampler_mlp_intermediate_size=256,
        perceiver_resampler_num_attention_heads=1,
        perceiver_resampler_attention_head_dim=64,
        perceiver_resampler_num_layers=2,
        perceiver_resampler_hidden_dropout=0.0,
        perceiver_resampler_attention_dropout=0.0,
        # memory encoder
        memory_encoder_hidden_size=256,
        memory_encoder_output_channels=64,
        mask_downsampler_embed_dim=256,
        memory_fuser_intermediate_dim=1024,
        mask_downsampler_kernel_size=3,
        mask_downsampler_stride=2,
        mask_downsampler_padding=1,
        mask_downsampler_total_stride=16,
        mask_downsampler_hidden_act="gelu",
        memory_fuser_num_layers=2,
        memory_fuser_embed_dim=256,
        memory_fuser_kernel_size=7,
        memory_fuser_padding=3,
        memory_fuser_layer_scale_init_value=1e-6,
        memory_fuser_hidden_act="gelu",
        **kwargs,
    ):
        PreTrainedConfig.__init__(self, **kwargs)
        vision_config = vision_config if vision_config is not None else {}
        prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
        mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
        memory_attention_rope_feat_sizes = (
            [64, 64] if memory_attention_rope_feat_sizes is None else memory_attention_rope_feat_sizes
        )
        memory_attention_rope_k_sizes = (
            [16, 16] if memory_attention_rope_k_sizes is None else memory_attention_rope_k_sizes
        )

        if isinstance(vision_config, dict):
            vision_config["model_type"] = vision_config.get("model_type", "sam2_vision_model")
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        if isinstance(prompt_encoder_config, EdgeTamVideoPromptEncoderConfig):
            prompt_encoder_config = prompt_encoder_config.to_dict()
        if isinstance(mask_decoder_config, EdgeTamVideoMaskDecoderConfig):
            mask_decoder_config = mask_decoder_config.to_dict()

        self.vision_config = vision_config
        self.prompt_encoder_config = EdgeTamVideoPromptEncoderConfig(**prompt_encoder_config)
        self.mask_decoder_config = EdgeTamVideoMaskDecoderConfig(**mask_decoder_config)

        self.initializer_range = initializer_range
        self.num_maskmem = num_maskmem  # default 1 input frame + 6 previous frames
        self.image_size = image_size
        self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc  # scale factor for mask sigmoid prob
        self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc  # bias factor for mask sigmoid prob
        self.enable_occlusion_spatial_embedding = enable_occlusion_spatial_embedding
        self.multimask_output_in_sam = multimask_output_in_sam
        self.multimask_min_pt_num = multimask_min_pt_num
        self.multimask_max_pt_num = multimask_max_pt_num
        self.multimask_output_for_tracking = multimask_output_for_tracking
        self.max_object_pointers_in_encoder = max_object_pointers_in_encoder
        self.max_cond_frame_num = max_cond_frame_num
        self.enable_temporal_pos_encoding_for_object_pointers = enable_temporal_pos_encoding_for_object_pointers

        # memory attention
        self.memory_attention_hidden_size = memory_attention_hidden_size
        self.memory_attention_num_layers = memory_attention_num_layers
        self.memory_attention_num_attention_heads = memory_attention_num_attention_heads
        self.memory_attention_downsample_rate = memory_attention_downsample_rate
        self.memory_attention_mlp_hidden_size = memory_attention_mlp_hidden_size
        self.memory_attention_mlp_hidden_act = memory_attention_mlp_hidden_act
        self.memory_attention_dropout = memory_attention_dropout
        self.memory_attention_rope_theta = memory_attention_rope_theta
        self.memory_attention_rope_feat_sizes = memory_attention_rope_feat_sizes
        self.memory_attention_rope_k_sizes = memory_attention_rope_k_sizes
        self.memory_attention_rope_dropout = memory_attention_rope_dropout

        # spatial perceiver resampler
        self.perceiver_resampler_num_latents = perceiver_resampler_num_latents
        self.perceiver_resampler_num_latents_2d = perceiver_resampler_num_latents_2d
        self.perceiver_resampler_hidden_size = perceiver_resampler_hidden_size
        self.perceiver_resampler_mlp_intermediate_size = perceiver_resampler_mlp_intermediate_size
        self.perceiver_resampler_attention_head_dim = perceiver_resampler_attention_head_dim
        self.perceiver_resampler_num_attention_heads = perceiver_resampler_num_attention_heads
        self.perceiver_resampler_num_layers = perceiver_resampler_num_layers
        self.perceiver_resampler_hidden_dropout = perceiver_resampler_hidden_dropout
        self.perceiver_resampler_attention_dropout = perceiver_resampler_attention_dropout

        # memory encoder
        self.memory_encoder_hidden_size = memory_encoder_hidden_size
        self.memory_encoder_output_channels = memory_encoder_output_channels
        self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
        self.mask_downsampler_kernel_size = mask_downsampler_kernel_size
        self.mask_downsampler_stride = mask_downsampler_stride
        self.mask_downsampler_padding = mask_downsampler_padding
        self.mask_downsampler_total_stride = mask_downsampler_total_stride
        self.mask_downsampler_hidden_act = mask_downsampler_hidden_act
        self.memory_fuser_num_layers = memory_fuser_num_layers
        self.memory_fuser_embed_dim = memory_fuser_embed_dim
        self.memory_fuser_intermediate_dim = memory_fuser_intermediate_dim
        self.memory_fuser_kernel_size = memory_fuser_kernel_size
        self.memory_fuser_padding = memory_fuser_padding
        self.memory_fuser_layer_scale_init_value = memory_fuser_layer_scale_init_value
        self.memory_fuser_hidden_act = memory_fuser_hidden_act


class EdgeTamVideoLayerNorm(Sam2VideoLayerNorm):
    pass


class EdgeTamVideoMemoryFuserCXBlock(Sam2VideoMemoryFuserCXBlock):
    pass


class EdgeTamVideoVisionEncoderOutput(Sam2VideoVisionEncoderOutput):
    pass


class EdgeTamVideoVisionRotaryEmbedding(Sam2VideoVisionRotaryEmbedding):
    def __init__(self, config: EdgeTamVideoConfig, end_x: int | None = None, end_y: int | None = None):
        nn.Module.__init__(self)
        self.dim = config.memory_attention_hidden_size // (
            config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
        )
        # The head dimension is split across the x and y axes, each with cos/sin pairs,
        # so it must be divisible by 4 for axial RoPE
        if self.dim % 4 != 0:
            raise ValueError("Dimension must be divisible by 4 for axial RoPE")
        self.end_x, self.end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
        self.memory_attention_rope_theta = config.memory_attention_rope_theta

        # directly register the cos and sin embeddings as we have a fixed feature shape
        inv_freq = self.create_inv_freq()
        self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
        self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)


class EdgeTamVideoAttention(Sam2VideoAttention):
    pass


def apply_rotary_pos_emb_2d_self_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to query and key tensors for self-attention.

    Args:
        q: Query tensor of shape (..., seq_len, head_dim)
        k: Key tensor of shape (..., seq_len, head_dim)
        cos: Cosine position embedding of shape (seq_len, head_dim)
        sin: Sine position embedding of shape (seq_len, head_dim)

    Returns:
        Rotated (q, k) tensors
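
    Example (a shape-only sketch with illustrative values; `cos`/`sin` broadcast over the batch and
    head dimensions):

    ```python
    >>> q = torch.randn(1, 1, 16, 8)  # (batch, num_heads, seq_len, head_dim)
    >>> k = torch.randn(1, 1, 16, 8)
    >>> cos, sin = torch.ones(16, 8), torch.zeros(16, 8)
    >>> q_rot, k_rot = apply_rotary_pos_emb_2d_self_attn(q, k, cos, sin)
    >>> q_rot.shape
    torch.Size([1, 1, 16, 8])
    ```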
    """
    # Apply RoPE to queries
    q_embed = q.float()  # force upscale to float32 as in the original implementation
    q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)

    # Apply RoPE to keys (same embeddings as queries for self-attention)
    k_embed = k.float()  # force upscale to float32 as in the original implementation
    k_embed = (k_embed * cos) + (rotate_pairwise(k_embed) * sin)

    return q_embed.type_as(q), k_embed.type_as(k)


def apply_rotary_pos_emb_2d_cross_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    cos_k: torch.Tensor,
    sin_k: torch.Tensor,
    num_k_exclude_rope: int = 0,
    repeat_freqs_k: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to query and key tensors for cross-attention.

    Args:
        q: Query tensor of shape (..., seq_len, head_dim)
        k: Key tensor of shape (..., seq_len, head_dim)
        cos: Cosine position embedding of shape (seq_len, head_dim)
        sin: Sine position embedding of shape (seq_len, head_dim)
        cos_k: Cosine position embedding for keys of shape (seq_len, head_dim)
        sin_k: Sine position embedding for keys of shape (seq_len, head_dim)
        num_k_exclude_rope: Number of tokens at end of k to exclude from RoPE (e.g., object pointer tokens)
        repeat_freqs_k: Frequency repetition for keys in cross-attention (e.g., for spatial memory tokens)

    Returns:
        Rotated (q, k) tensors
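
    Example (a shape-only sketch with illustrative values, not taken from a real checkpoint: 2 memory
    frames, each contributing 2 temporal + 4 spatial tokens, followed by 3 object-pointer tokens that
    skip RoPE):

    ```python
    >>> q = torch.randn(1, 1, 16, 8)  # (batch, num_heads, q_seq_len, head_dim)
    >>> k = torch.randn(1, 1, 2 * (2 + 4) + 3, 8)
    >>> cos, sin = torch.ones(16, 8), torch.zeros(16, 8)
    >>> cos_k, sin_k = torch.ones(4, 8), torch.zeros(4, 8)  # one frame's 4 spatial tokens
    >>> q_rot, k_rot = apply_rotary_pos_emb_2d_cross_attn(
    ...     q, k, cos, sin, cos_k, sin_k, num_k_exclude_rope=3, repeat_freqs_k=2
    ... )
    >>> k_rot.shape  # object-pointer tokens pass through unchanged
    torch.Size([1, 1, 15, 8])
    ```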
    """
    # Apply RoPE to queries (always straightforward)
    q_embed = q.float()
    q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)

    # Split keys: RoPE tokens and excluded tokens (e.g., object pointers)
    num_total_k_tokens = k.shape[-2]
    k_for_rope = k[..., : num_total_k_tokens - num_k_exclude_rope, :]
    k_excluded = k[..., num_total_k_tokens - num_k_exclude_rope :, :]

    # Early return if no keys need RoPE
    if k_for_rope.shape[-2] == 0:
        return q_embed.type_as(q), k_excluded

    batch_size, num_heads, k_seq_len, channels_per_head = k_for_rope.shape

    # Handle temporal/spatial token structure for memory
    # Keys have temporal + spatial structure, only spatial tokens get RoPE
    tokens_per_group = k_seq_len // repeat_freqs_k
    spatial_tokens = cos_k.shape[-2]
    temporal_tokens = tokens_per_group - spatial_tokens

    # Reshape and separate temporal/spatial tokens
    k_grouped = k_for_rope.view(batch_size, num_heads, repeat_freqs_k, tokens_per_group, channels_per_head)
    k_temporal = k_grouped[..., :temporal_tokens, :].reshape(batch_size, num_heads, -1, channels_per_head)
    k_spatial = k_grouped[..., temporal_tokens:, :].reshape(batch_size, num_heads, -1, channels_per_head)

    # Only apply RoPE to spatial tokens
    k_rope_input = k_spatial

    # Prepare position embeddings for repeated groups
    if repeat_freqs_k > 1:
        cos_k = cos_k.repeat(1, 1, repeat_freqs_k, 1)
        sin_k = sin_k.repeat(1, 1, repeat_freqs_k, 1)

    # Apply RoPE to spatial tokens
    k_spatial_embed = k_rope_input.float()
    k_spatial_embed = (k_spatial_embed * cos_k) + (rotate_pairwise(k_spatial_embed) * sin_k)

    # Reconstruct: temporal + spatial tokens back to original structure
    k_spatial_reshaped = k_spatial_embed.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head)
    k_temporal_reshaped = k_temporal.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head)
    k_final = torch.cat([k_temporal_reshaped, k_spatial_reshaped], dim=3)
    k_final = k_final.view(batch_size, num_heads, k_seq_len, channels_per_head)

    # Combine RoPE-processed keys with excluded tokens
    k_embed = torch.cat([k_final.type_as(k), k_excluded], dim=-2)
    return q_embed.type_as(q), k_embed


class EdgeTamVideoRoPESelfAttention(nn.Module):
    """Self-attention with rotary position encoding."""

    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.memory_attention_hidden_size
        self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate
        self.num_attention_heads = config.memory_attention_num_attention_heads
        self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
        self.dropout_p = config.memory_attention_rope_dropout

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Input projections
        batch_size, point_batch_size = query.shape[:2]
        new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)

        query = self.q_proj(query).view(*new_shape).transpose(1, 2)
        key = self.k_proj(key).view(*new_shape).transpose(1, 2)
        value = self.v_proj(value).view(*new_shape).transpose(1, 2)

        cos, sin = position_embeddings
        # Apply rotary position encoding for self-attention
        query, key = apply_rotary_pos_emb_2d_self_attn(query, key, cos=cos, sin=sin)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            dropout=0.0 if not self.training else self.dropout_p,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )
        attn_output = attn_output.reshape(
            batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class EdgeTamVideoRoPECrossAttention(nn.Module):
    """Cross-attention with rotary position encoding."""

    def __init__(self, config: EdgeTamVideoConfig, kv_in_dim: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.memory_attention_hidden_size
        self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate
        self.num_attention_heads = config.memory_attention_num_attention_heads
        self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        self.kv_in_dim = kv_in_dim

        self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
        self.dropout_p = config.memory_attention_rope_dropout

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        position_embeddings_k: tuple[torch.Tensor, torch.Tensor],
        num_k_exclude_rope: int = 0,
        rope_k_repeat: int = 0,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Input projections
        batch_size, point_batch_size = query.shape[:2]
        new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)

        query = self.q_proj(query).view(*new_shape).transpose(1, 2)
        key = self.k_proj(key).view(*new_shape).transpose(1, 2)
        value = self.v_proj(value).view(*new_shape).transpose(1, 2)

        cos, sin = position_embeddings
        cos_k, sin_k = position_embeddings_k
        # Apply rotary position encoding for cross-attention
        query, key = apply_rotary_pos_emb_2d_cross_attn(
            query,
            key,
            cos=cos,
            sin=sin,
            cos_k=cos_k,
            sin_k=sin_k,
            repeat_freqs_k=rope_k_repeat,
            num_k_exclude_rope=num_k_exclude_rope,
        )

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            dropout=0.0 if not self.training else self.dropout_p,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )
        attn_output = attn_output.reshape(
            batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class EdgeTamVideoTwoWayAttentionBlock(Sam2VideoTwoWayAttentionBlock):
    pass


class EdgeTamVideoPositionEmbeddingSine(Sam2VideoPositionEmbeddingSine):
    # maxsize=2 because we need to cache the forward method for both memory encoder and perceiver resampler
    @compile_compatible_method_lru_cache(maxsize=2)
    def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)


class EdgeTamVideoMemoryEncoder(Sam2VideoMemoryEncoder):
    pass


class EdgeTamVideoFeedForward(Sam2VideoFeedForward):
    pass


class EdgeTamVideoPreTrainedModel(Sam2VideoPreTrainedModel):
    def _init_weights(self, module):
        super()._init_weights(module)
        if isinstance(module, EdgeTamVideoVisionRotaryEmbedding):
            inv_freq = module.create_inv_freq()
            init.copy_(module.rope_embeddings_cos, inv_freq.cos())
            init.copy_(module.rope_embeddings_sin, inv_freq.sin())


class EdgeTamVideoInferenceSession(Sam2VideoInferenceSession):
    pass


class EdgeTamVideoMemoryAttentionMLP(nn.Module):
    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.memory_attention_hidden_size
        self.intermediate_size = config.memory_attention_mlp_hidden_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
        self.dropout = nn.Dropout(config.memory_attention_dropout)
        self.act_fn = ACT2FN[config.memory_attention_mlp_hidden_act]

    def forward(self, x):
        return self.down_proj(self.dropout(self.act_fn(self.up_proj(x))))


class EdgeTamVideoMemoryAttentionLayer(nn.Module):
    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__()
        hidden_size = config.memory_attention_hidden_size
        self.self_attn = EdgeTamVideoRoPESelfAttention(config)
        self.cross_attn_image = EdgeTamVideoRoPECrossAttention(config, kv_in_dim=64)

        # MLP module
        self.mlp = EdgeTamVideoMemoryAttentionMLP(config)

        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.layer_norm3 = nn.LayerNorm(hidden_size)
        self.dropout1 = nn.Dropout(config.memory_attention_dropout)
        self.dropout2 = nn.Dropout(config.memory_attention_dropout)
        self.dropout3 = nn.Dropout(config.memory_attention_dropout)

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        key_point_embedding: Tensor,
        rope_position_embeddings: tuple[Tensor, Tensor],
        rope_position_embeddings_k: tuple[Tensor, Tensor] | None = None,
        num_k_exclude_rope: int = 0,
        rope_k_repeat: int = 0,
    ) -> torch.Tensor:
        # Self-Attention
        query = self.layer_norm1(queries)
        query, _ = self.self_attn(query=query, key=query, value=query, position_embeddings=rope_position_embeddings)
        queries = queries + self.dropout1(query)

        # Cross-Attention
        query = self.layer_norm2(queries)
        query, _ = self.cross_attn_image(
            query=query,
            key=keys + key_point_embedding,
            value=keys,
            position_embeddings=rope_position_embeddings,
            position_embeddings_k=rope_position_embeddings_k,
            num_k_exclude_rope=num_k_exclude_rope,
            rope_k_repeat=rope_k_repeat,
        )
        queries = queries + self.dropout2(query)
        # MLP
        query = self.layer_norm3(queries)
        query = self.mlp(query)
        queries = queries + self.dropout3(query)
        return queries


class EdgeTamVideoMemoryAttention(Sam2VideoMemoryAttention):
    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__()
        self.rotary_emb_k = EdgeTamVideoVisionRotaryEmbedding(
            config, end_x=config.memory_attention_rope_k_sizes[0], end_y=config.memory_attention_rope_k_sizes[1]
        )

    def forward(
        self,
        current_vision_features: torch.Tensor,
        memory: torch.Tensor,
        current_vision_position_embeddings: Tensor | None = None,
        memory_position_embeddings: Tensor | None = None,
        num_object_pointer_tokens: int = 0,
        num_spatial_memory_tokens: int = -1,
    ):
        """
        Args:
            current_vision_features (`torch.FloatTensor`):
                The current vision features used for self-attention.
            memory (`torch.FloatTensor`):
                The memory features used for cross-attention.
            current_vision_position_embeddings (`torch.FloatTensor`, *optional*):
                The position embeddings for the current vision features.
            memory_position_embeddings (`torch.FloatTensor`, *optional*):
                The position embeddings for the memory features.
            num_object_pointer_tokens (`int`, *optional*, defaults to 0):
                The number of object pointer tokens appended at the end of the memory sequence.
            num_spatial_memory_tokens (`int`, *optional*, defaults to -1):
                The number of memory frames whose spatial tokens are present in `memory`, used as the
                RoPE repeat count for the keys.
        """
        output = current_vision_features
        if current_vision_position_embeddings is not None:
            output = output + 0.1 * current_vision_position_embeddings

        # Convert to batch first
        output = output.transpose(0, 1)
        memory = memory.transpose(0, 1).unsqueeze(1)
        memory_position_embeddings = memory_position_embeddings.transpose(0, 1).unsqueeze(1)
        rope_position_embeddings = self.rotary_emb()
        rope_position_embeddings_k = self.rotary_emb_k()
        for layer in self.layers:
            output = layer(
                queries=output.unsqueeze(1) if output.ndim == 3 else output,
                keys=memory,
                key_point_embedding=memory_position_embeddings,
                rope_position_embeddings=rope_position_embeddings,
                rope_position_embeddings_k=rope_position_embeddings_k,
                num_k_exclude_rope=num_object_pointer_tokens,
                rope_k_repeat=num_spatial_memory_tokens,
            )

        normed_output = self.layer_norm(output)

        # Convert back to seq first
        normed_output = normed_output.transpose(0, 1)

        return normed_output


class EdgeTamVideoPerceiverMLP(nn.Module):
    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__()
        self.hidden_size = config.perceiver_resampler_hidden_size
        self.intermediate_size = config.perceiver_resampler_mlp_intermediate_size

        self.layer_norm = nn.LayerNorm(self.hidden_size)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.down_proj(self.act_fn(self.up_proj(hidden_states)))
        return hidden_states


class EdgeTamVideoPerceiverAttention(nn.Module):
    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.perceiver_resampler_hidden_size
        self.num_attention_heads = config.perceiver_resampler_num_attention_heads
        self.head_dim = config.perceiver_resampler_attention_head_dim
        self.attention_dropout = config.perceiver_resampler_attention_dropout

        self.inner_dim = self.head_dim * self.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        positional_encoding: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        # Project queries, keys, and values
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # Reshape for multi-head attention
        batch_size, seq_len_q = query.shape[:2]
        query = query.view(batch_size, seq_len_q, self.num_attention_heads, self.head_dim).transpose(1, 2)
        seq_len_kv = key.shape[1]
        key = key.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2)

        # Add positional encoding if provided
        if positional_encoding is not None:
            pos_encoding = positional_encoding.view(
                batch_size, seq_len_kv, self.num_attention_heads, self.head_dim
            ).transpose(1, 2)
            key = key + pos_encoding
            value = value + pos_encoding

        # Apply attention
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, _ = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Reshape output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.inner_dim)
        return self.o_proj(attn_output)


class EdgeTamVideoPerceiverEncoderLayer(nn.Module):
    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__()

        self.cross_attention = EdgeTamVideoPerceiverAttention(config)
        self.mlp = EdgeTamVideoPerceiverMLP(config)
        self.dropout = nn.Dropout(config.perceiver_resampler_hidden_dropout)

        self.self_attention = EdgeTamVideoPerceiverAttention(config)
        self.self_mlp = EdgeTamVideoPerceiverMLP(config)

        # Layer norms moved from attention classes to here
        self.layer_norm_input = nn.LayerNorm(config.perceiver_resampler_hidden_size)
        self.layer_norm_latents = nn.LayerNorm(config.perceiver_resampler_hidden_size)
        self.layer_norm_self = nn.LayerNorm(config.perceiver_resampler_hidden_size)

    def forward(
        self,
        latents: torch.Tensor,
        input_features: torch.Tensor,
        positional_encoding: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Cross attention with layer norms
        normalized_latents = self.layer_norm_latents(latents)
        normalized_input = self.layer_norm_input(input_features)
        cross_attention_output = self.cross_attention(
            query=normalized_latents,
            key=normalized_input,
            value=normalized_input,
            positional_encoding=positional_encoding,
        )
        latents = latents + self.dropout(cross_attention_output)

        mlp_output = self.mlp(latents)
        latents = latents + mlp_output

        # Self attention with layer norm
        normalized_latents_self = self.layer_norm_self(latents)
        self_attention_output = self.self_attention(
            query=normalized_latents_self, key=normalized_latents_self, value=normalized_latents_self
        )
        latents = latents + self_attention_output

        self_mlp_output = self.self_mlp(latents)
        latents = latents + self_mlp_output

        return latents


class EdgeTamVideoPerceiverResampler(nn.Module):
    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.perceiver_resampler_hidden_size
        self.num_latents_1d = config.perceiver_resampler_num_latents
        self.num_latents_2d = config.perceiver_resampler_num_latents_2d
        self.num_layers = config.perceiver_resampler_num_layers

        if self.num_latents_1d > 0:
            self.latents_1d = nn.Parameter(torch.randn(self.num_latents_1d, self.hidden_size))
        if self.num_latents_2d > 0:
            self.latents_2d = nn.Parameter(torch.randn(self.num_latents_2d, self.hidden_size))

        self.positional_encoding = EdgeTamVideoPositionEmbeddingSine(
            num_pos_feats=self.hidden_size // 2, normalize=True
        )

        self.layers = nn.ModuleList([EdgeTamVideoPerceiverEncoderLayer(config) for _ in range(self.num_layers)])

        self.layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positional_encoding: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
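        """
        Resample a spatial memory-feature map into a fixed set of 1D and 2D latent tokens.

        Example (a minimal sketch with illustrative shapes, assuming the default config of 256 1D
        latents, 256 2D latents, and hidden size 64):

        ```python
        >>> resampler = EdgeTamVideoPerceiverResampler(EdgeTamVideoConfig())
        >>> features = torch.randn(1, 64, 64, 64)  # (batch, channels, height, width)
        >>> position = torch.randn(1, 64, 64, 64)
        >>> latents, latent_pos = resampler(features, position)
        >>> latents.shape  # 256 1D latents + 16 * 16 window latents
        torch.Size([1, 512, 64])
        ```
        """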
        output_latents = []
        output_positional_encodings = []

        if self.num_latents_1d > 0:
            latents_1d, pos_1d = self._forward_1d(hidden_states, positional_encoding)
            output_latents.append(latents_1d)
            output_positional_encodings.append(pos_1d)

        if self.num_latents_2d > 0:
            latents_2d, pos_2d = self._forward_2d(hidden_states)
            output_latents.append(latents_2d)
            output_positional_encodings.append(pos_2d)

        combined_latents = torch.cat(output_latents, dim=1)

        combined_positional_encoding = None
        if positional_encoding is not None and output_positional_encodings:
            combined_positional_encoding = torch.cat(output_positional_encodings, dim=1)

        return combined_latents, combined_positional_encoding

    def _forward_1d(
        self,
        hidden_states: torch.Tensor,
        positional_encoding: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        batch_size = hidden_states.shape[0]

        latents = self.latents_1d.unsqueeze(0).expand(batch_size, -1, -1)
        flattened_features = hidden_states.permute(0, 2, 3, 1).flatten(1, 2)

        positional_features = None
        if positional_encoding is not None:
            positional_features = positional_encoding.permute(0, 2, 3, 1).flatten(1, 2)

        for layer in self.layers:
            latents = layer(latents, flattened_features, positional_features)

        latents = self.layer_norm(latents)

        output_positional_encoding = None
        if positional_encoding is not None:
            output_positional_encoding = torch.zeros_like(latents)

        return latents, output_positional_encoding

    def _forward_2d(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, channels, height, width = hidden_states.shape

        latents_2d = self.latents_2d.unsqueeze(0).expand(batch_size, -1, -1).view(-1, 1, channels)

        num_windows_per_dim = int(math.sqrt(self.num_latents_2d))
        window_size = height // num_windows_per_dim
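        # e.g. with the default 256 2D latents and a 64x64 feature map, this yields a 16x16 grid of
        # 4x4 windows, one latent per window (illustrative sizes)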

        windowed_input = hidden_states.permute(0, 2, 3, 1)
        windowed_features, _ = window_partition(windowed_input, window_size)
        windowed_features = windowed_features.flatten(1, 2)

        for layer in self.layers:
            latents_2d = layer(latents_2d, windowed_features, positional_encoding=None)

        latents_2d = latents_2d.view(batch_size, num_windows_per_dim, num_windows_per_dim, channels).permute(
            0, 3, 1, 2
        )

        positional_encoding_2d = self.positional_encoding(latents_2d.shape, latents_2d.device, latents_2d.dtype).to(
            dtype=hidden_states.dtype
        )
        positional_encoding_2d = positional_encoding_2d.permute(0, 2, 3, 1).flatten(1, 2)

        latents_2d = latents_2d.permute(0, 2, 3, 1).flatten(1, 2)
        latents_2d = self.layer_norm(latents_2d)

        return latents_2d, positional_encoding_2d


class EdgeTamVideoImageSegmentationOutput(Sam2VideoImageSegmentationOutput):
    pass


class EdgeTamVideoSegmentationOutput(Sam2VideoSegmentationOutput):
    pass


@auto_docstring
class EdgeTamVideoModel(Sam2VideoModel):
    _keys_to_ignore_on_load_unexpected = []
    _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}

    def __init__(self, config: EdgeTamVideoConfig):
        super().__init__(config)
        self.spatial_perceiver = EdgeTamVideoPerceiverResampler(config)

        self.post_init()

    def _build_memory_attention_inputs(
        self,
        temporal_positions_and_previous_outputs: list[tuple[int, dict]],
        device: torch.device,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """
        Concatenate memory features and positional embeddings from previous frames.

        Returns:
            Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate).
        """
        memories_to_concatenate = []
        memory_positional_embeddings_to_concatenate = []

        for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
            if prev_output_data is None:
                continue  # Skip if no output data for this temporal position (e.g., padding frames)

            # Load memory features (potentially from CPU to GPU)
            # Perceiver latents: (Batch, NumTokens, Channels) -> (NumTokens, Batch, Channels)
            memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
            memories_to_concatenate.append(memory_features.permute(1, 0, 2))

            # Spatial positional encoding (potentially from CPU to GPU)
            spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
            spatial_memory_pos_embed = spatial_memory_pos_embed.squeeze(1).permute(1, 0, 2)

            # Add temporal positional encoding
            # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
            combined_memory_pos_embed = (
                spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
            )
            memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)

        return memories_to_concatenate, memory_positional_embeddings_to_concatenate

    def _prepare_memory_conditioned_features(
        self,
        inference_session: EdgeTamVideoInferenceSession,
        frame_idx: int,
        obj_idx: int,
        is_initial_conditioning_frame: bool,
        current_vision_features: torch.Tensor,
        current_vision_positional_embeddings: torch.Tensor,
        num_total_frames: int,
        track_in_reverse_time: bool = False,
        streaming: bool = False,
    ) -> torch.Tensor:
        """
        Fuse current frame's visual features with memory from previous frames for enhanced object tracking.

        This method conditions the current frame's visual features on temporal memory from previous frames,
        enabling consistent object tracking across video sequences. For initial conditioning frames, it uses
        no-memory embeddings. For subsequent frames, it retrieves and integrates memory features from both
        conditioning frames (user interactions) and non-conditioning frames (tracked results) via cross-attention.

        Args:
            inference_session (`EdgeTamVideoInferenceSession`):
                The video inference session object.
            frame_idx (`int`):
                Index of the current frame being processed.
            obj_idx (`int`):
                Index of the object being processed.
            is_initial_conditioning_frame (`bool`):
                Whether this is an initial conditioning frame with user inputs (True) or a subsequent
                tracking frame (False).
            current_vision_features (`torch.Tensor`):
                Highest-level vision features of shape `(seq_len, batch_size, channels)`.
            current_vision_positional_embeddings (`torch.Tensor`):
                Positional embedding tensors corresponding to the highest-level vision features.
            num_total_frames (`int`):
                Total number of frames in the video sequence.
            track_in_reverse_time (`bool`, *optional*, defaults to `False`):
                Whether tracking is performed in reverse temporal order.
            streaming (`bool`, *optional*, defaults to `False`):
                Whether this is streaming inference mode.

        Returns:
            `torch.Tensor`: Memory-conditioned feature tensor of shape `(batch_size, channels, height, width)`
                suitable for input to the SAM decoder.
        """
        # Get dimensions from the highest-level (lowest-resolution) feature map
        batch_size = current_vision_features.size(1)
        num_channels = self.hidden_dim
        height, width = self.backbone_feature_sizes[-1]
        device = current_vision_features.device

        # If memory is disabled (e.g., for single image SAM), return current features directly.
        if self.num_maskmem == 0:
            # Permute (SeqLen, Batch, Channels) -> (Batch, Channels, SeqLen) then view as (Batch, Channels, Height, Width)
            # Assuming SeqLen = Height * Width for the last feature map
            current_feature_map = current_vision_features.permute(1, 2, 0).view(
                batch_size, num_channels, height, width
            )
            return current_feature_map

        # Step 1: Handle initial conditioning frames
        if is_initial_conditioning_frame:
            # For initial conditioning frames, no prior memory is used directly in this block.
            # If configured, directly add a learnable "no memory" embedding.
            # current_vision_features has shape (SeqLen, Batch, Channels)
            conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding
            # Reshape to (Batch, Channels, Height, Width)
            conditioned_feature_map = conditioned_feature_map_flat.permute(1, 2, 0).view(
                batch_size, num_channels, height, width
            )
            return conditioned_feature_map

        # Step 2: Get memory frames and concatenate their features
        temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs(
            inference_session, obj_idx, frame_idx, track_in_reverse_time
        )

        memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs(
            temporal_positions_and_previous_outputs, device
        )
        num_spatial_memory_tokens = len(memories_to_concatenate)
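        # This counts memory frames (token groups), not individual tokens; it is passed to memory
        # attention as the RoPE repeat count for the memory keys.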

        # Step 3: Get and process object pointers
        temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers(
            inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming
        )

        num_object_pointer_tokens = 0
        if pointer_tokens:
            object_pointers, object_pointers_pos_embed = self._process_object_pointers(
                temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device
            )

            if object_pointers is not None:
                memories_to_concatenate.append(object_pointers)
                memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
                num_object_pointer_tokens = object_pointers.shape[0]

        # Step 4: Concatenate all retrieved memories and their positional embeddings
        combined_memory = torch.cat(memories_to_concatenate, dim=0)
        combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0)

        # Step 5: Forward through the memory attention mechanism
        conditioned_feature_map_flat = self.memory_attention(
            current_vision_features=current_vision_features,
            current_vision_position_embeddings=current_vision_positional_embeddings,
            memory=combined_memory,
            memory_position_embeddings=combined_memory_positional_embeddings,
            num_object_pointer_tokens=num_object_pointer_tokens,
            num_spatial_memory_tokens=num_spatial_memory_tokens,
        )

        # Reshape from (Batch, H*W, Channels) to (Batch, Channels, Height, Width)
        conditioned_feature_map = (
            conditioned_feature_map_flat.squeeze(1).permute(0, 2, 1).view(batch_size, num_channels, height, width)
        )
        return conditioned_feature_map

    def _encode_new_memory(
        self,
        current_vision_feats: torch.Tensor,
        pred_masks_high_res: torch.Tensor,
        object_score_logits: torch.Tensor,
        is_mask_from_pts: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode the current image and its prediction into a memory feature."""
        batch_size = current_vision_feats.size(1)  # batch size on this frame
        channels = self.hidden_dim
        height, width = self.backbone_feature_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats.permute(1, 2, 0).view(batch_size, channels, height, width)
        if is_mask_from_pts and not self.training:
            # binarize the mask logits
            mask_for_mem = (pred_masks_high_res > 0).to(pred_masks_high_res.dtype)
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        mask_for_mem = mask_for_mem * self.config.sigmoid_scale_for_mem_enc
        mask_for_mem = mask_for_mem + self.config.sigmoid_bias_for_mem_enc
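        # e.g. with the default scale (20.0) and bias (-10.0), probabilities in (0, 1) map to (-10, 10)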

        maskmem_features, maskmem_pos_enc = self.memory_encoder(
            pix_feat,
            mask_for_mem,
        )
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.occlusion_spatial_embedding_parameter is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (1 - is_obj_appearing[..., None]) * self.occlusion_spatial_embedding_parameter[
                ..., None, None
            ].expand(*maskmem_features.shape)

        maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype)
        maskmem_features, maskmem_pos_enc = self.spatial_perceiver(maskmem_features, maskmem_pos_enc)
        maskmem_features = maskmem_features.to(pred_masks_high_res.dtype)
        maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype)

        return maskmem_features, maskmem_pos_enc

    def forward(
        self,
        inference_session: EdgeTamVideoInferenceSession,
        frame_idx: int | None = None,
        frame: torch.Tensor | None = None,
        reverse: bool = False,
        **kwargs,
    ) -> EdgeTamVideoSegmentationOutput:
        r"""
        inference_session (`EdgeTamVideoInferenceSession`):
            The video inference session object.
        frame_idx (`int`, *optional*):
            The index of the frame on which to run inference. Does not need to be provided when
            streaming a new frame.
        frame (`torch.Tensor`, *optional*):
            The frame to process. Provide when streaming.
        reverse (`bool`, *optional*, defaults to `False`):
            Whether to propagate in reverse.
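
        Example (a minimal streaming sketch; the checkpoint name, the `video_frames` iterable, and the
        session-setup step are illustrative assumptions rather than guaranteed APIs):

        ```python
        >>> model = EdgeTamVideoModel.from_pretrained("facebook/EdgeTAM")  # checkpoint name is illustrative
        >>> inference_session = ...  # an EdgeTamVideoInferenceSession with object prompts already added; setup elided
        >>> for frame in video_frames:  # preprocessed frame tensors
        ...     output = model(inference_session=inference_session, frame=frame)
        ...     masks = output.pred_masks  # low-res mask logits for each tracked object
        ```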
        """
        if frame is not None:
            frame_idx = inference_session.add_new_frame(frame, frame_idx)
            if inference_session.get_obj_num() == 0:
                raise ValueError("No objects are provided for tracking; please add inputs first.")

        num_objects = inference_session.get_obj_num()
        pred_masks_per_obj = [None] * num_objects
        object_score_logits_per_obj = [None] * num_objects
        # Note: We avoid batched inference here because per-object inputs (clicks/masks)
        # can differ across objects.
        for obj_idx in range(num_objects):
            obj_id = inference_session.obj_idx_to_id(obj_idx)
            has_new_inputs = obj_id in inference_session.obj_with_new_inputs
            has_cond_output = frame_idx in inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
            # If this object has no new inputs and this frame already has a
            # conditioning output, reuse the cached masks instead of recomputing.
            if (not has_new_inputs) and has_cond_output:
                pred_masks = inference_session.get_output(obj_idx, frame_idx, "pred_masks", is_conditioning_frame=True)
                object_score_logits = inference_session.get_output(
                    obj_idx, frame_idx, "object_score_logits", is_conditioning_frame=True
                )
                is_init_cond_frame = True
            else:
                # Defaults when there are no new inputs
                is_init_cond_frame = False
                point_inputs = None
                mask_inputs = None

                if has_new_inputs:
                    is_init_cond_frame = frame_idx not in inference_session.frames_tracked_per_obj[obj_idx]
                    if is_init_cond_frame:
                        reverse = False
                    point_inputs = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None)
                    mask_inputs = inference_session.mask_inputs_per_obj[obj_idx].get(frame_idx, None)
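                    # once the point/mask inputs are consumed below, drop the object from the new-inputs
                    # set so subsequent passes over this frame reuse the cached conditioning output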
                    if point_inputs is not None or mask_inputs is not None:
                        inference_session.obj_with_new_inputs.remove(obj_id)

                current_out = self._run_single_frame_inference(
                    inference_session=inference_session,
                    obj_idx=obj_idx,
                    frame_idx=frame_idx,
                    batch_size=1,  # run on the slice of a single object
                    is_init_cond_frame=is_init_cond_frame,
                    point_inputs=point_inputs,
                    mask_inputs=mask_inputs,
                    reverse=reverse,
                    run_mem_encoder=True,
                    streaming=frame is not None,
                )
                inference_session.store_output(
                    obj_idx, frame_idx, output_value=current_out, is_conditioning_frame=is_init_cond_frame
                )
                pred_masks = current_out["pred_masks"]
                object_score_logits = current_out["object_score_logits"]

            pred_masks_per_obj[obj_idx] = pred_masks
            object_score_logits_per_obj[obj_idx] = object_score_logits.squeeze(-1)
            if not is_init_cond_frame:
                # only for tracked frames, not for initial conditioning frames
                inference_session.frames_tracked_per_obj[obj_idx][frame_idx] = {"reverse": reverse}

        # Gather the per-object outputs (we directly use the mask scores on GPU for output to
        # avoid any CPU conversion in between)
        if len(pred_masks_per_obj) > 1:
            all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
            all_object_score_logits = torch.cat(object_score_logits_per_obj, dim=0)
        else:
            all_pred_masks = pred_masks_per_obj[0]
            all_object_score_logits = object_score_logits_per_obj[0]

        return EdgeTamVideoSegmentationOutput(
            object_ids=inference_session.obj_ids.copy(),
            pred_masks=all_pred_masks,
            object_score_logits=all_object_score_logits,
            frame_idx=frame_idx,
        )

    def _use_mask_as_output(
        self,
        backbone_features: torch.Tensor,
        high_res_features: list[torch.Tensor],
        mask_inputs: torch.Tensor,
    ) -> EdgeTamVideoImageSegmentationOutput:
        """
        Directly turn binary `mask_inputs` into output mask logits without using SAM.
        (same input and output shapes as in forward above).
        """
        # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
        out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
        mask_inputs_float = mask_inputs.to(backbone_features[0].dtype)
        high_res_masks = mask_inputs_float * out_scale + out_bias
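        # worked example of the affine map above: mask 0 -> 0 * 20 - 10 = -10, mask 1 -> 1 * 20 - 10 = +10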
        low_res_masks = F.interpolate(
            high_res_masks.float(),
            size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
            align_corners=False,
            mode="bilinear",
            antialias=True,  # use antialias for downsampling
        ).to(backbone_features[0].dtype)
        # a dummy IoU prediction of all 1's under mask input
        iou_scores = mask_inputs.new_ones(mask_inputs.size(0), 1).to(backbone_features[0].dtype)
        # produce an object pointer using the SAM decoder from the mask input
        object_pointer = self._single_frame_forward(
            input_masks=self.mask_downsample(mask_inputs_float),
            image_embeddings=high_res_features + [backbone_features],
        ).object_pointer
        # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
        # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
        # on the object_scores from the SAM decoder.
        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
        is_obj_appearing = is_obj_appearing[..., None]
        lambda_is_obj_appearing = is_obj_appearing.to(backbone_features[0].dtype)
        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
        object_pointer = lambda_is_obj_appearing * object_pointer
        object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer
        return EdgeTamVideoImageSegmentationOutput(
            iou_scores=iou_scores,
            pred_masks=low_res_masks,
            high_res_masks=high_res_masks,
            object_pointer=object_pointer,
            object_score_logits=object_score_logits,
            image_embeddings=high_res_features + [backbone_features],
        )

    def _run_single_frame_inference(
        self,
        inference_session: EdgeTamVideoInferenceSession,
        frame_idx: int,
        obj_idx: int,
        batch_size: int,
        is_init_cond_frame: bool,
        point_inputs: dict[str, torch.Tensor] | None,
        mask_inputs: torch.Tensor | None,
        reverse: bool,
        run_mem_encoder: bool,
        prev_sam_mask_logits: torch.Tensor | None = None,
        streaming: bool = False,
    ) -> dict[str, Any]:
        """
        Perform a single tracking step for video object segmentation.

        Args:
            inference_session (`EdgeTamVideoInferenceSession`):
                The video inference session object.
            frame_idx (`int`):
                Index of the current frame.
            obj_idx (`int`):
                Index of the current object.
            batch_size (`int`):
                Batch size of the current frame.
            is_init_cond_frame (`bool`):
                Whether this is an initial conditioning frame with user inputs.
            point_inputs (`dict`, *optional*):
                Point prompt inputs for the current frame.
            mask_inputs (`torch.Tensor`, *optional*):
                Mask prompt inputs for the current frame.
            reverse (`bool`):
                Whether to track in reverse time order.
            run_mem_encoder (`bool`):
                Whether to run the memory encoder on predicted masks.
            prev_sam_mask_logits (`torch.Tensor`, *optional*):
                Previously predicted SAM mask logits that can be fed with new clicks.
            streaming (`bool`, *optional*, defaults to `False`):
                Whether this is streaming inference.

        Returns:
            `dict`: Dictionary containing the tracking results for the current frame, including:
                - pred_masks: Predicted low-resolution masks.
                - object_pointer: Object pointer for memory.
                - object_score_logits: Object score logits (inference only).
                - maskmem_features: Memory features for future frames.
                - maskmem_pos_enc: Memory positional encodings.
        """
        # Retrieve correct image features
        current_vision_feats, current_vision_pos_embeds = self._prepare_vision_features(
            inference_session, frame_idx, batch_size
        )
        # point and mask should not appear as input simultaneously on the same frame
        if point_inputs is not None and mask_inputs is not None:
            raise ValueError(
                "point_inputs and mask_inputs should not appear as input simultaneously on the same frame"
            )
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], self.backbone_feature_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None:
            # We directly output the mask input (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *self.backbone_feature_sizes[-1])
            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
        else:
            # fuse the visual features with previous memory features from the memory bank
            pix_feat = self._prepare_memory_conditioned_features(
                inference_session=inference_session,
                frame_idx=frame_idx,
                obj_idx=obj_idx,
                is_initial_conditioning_frame=is_init_cond_frame,
                current_vision_features=current_vision_feats[-1],
                current_vision_positional_embeddings=current_vision_pos_embeds[-1],
                num_total_frames=inference_session.num_frames,
                track_in_reverse_time=reverse,
                streaming=streaming,
            )
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                mask_inputs = prev_sam_mask_logits
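            # note (SAM-style heuristic): multimask output is generally enabled only for ambiguous
            # prompts, e.g. a single click on an initial conditioning frame; `_use_multimask` decides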
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._single_frame_forward(
                pixel_values=None,  # Vision features already computed
                input_points=point_inputs["point_coords"] if point_inputs is not None else None,
                input_labels=point_inputs["point_labels"] if point_inputs is not None else None,
                input_masks=mask_inputs,
                image_embeddings=high_res_features + [pix_feat],
                multimask_output=multimask_output,
            )

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (which will be used to condition vision features in future frames)
        maskmem_features = None
        maskmem_pos_enc = None
        if run_mem_encoder and self.num_maskmem > 0:
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats[-1],
                pred_masks_high_res=sam_outputs.high_res_masks,
                object_score_logits=sam_outputs.object_score_logits,
                is_mask_from_pts=(point_inputs is not None or mask_inputs is not None),
            )

        current_out = {
            "pred_masks": sam_outputs.pred_masks,
            "object_pointer": sam_outputs.object_pointer,
            "maskmem_features": maskmem_features if maskmem_features is not None else None,
            "maskmem_pos_enc": maskmem_pos_enc,
        }
        if not self.training:
            current_out["object_score_logits"] = sam_outputs.object_score_logits

        return current_out

    def _batch_encode_memories(self):
        # TODO: implement batch memory encoding for EdgeTamVideo
        raise NotImplementedError("Batch memory encoding is not implemented for EdgeTamVideo yet.")


__all__ = [
    "EdgeTamVideoMaskDecoderConfig",
    "EdgeTamVideoPromptEncoderConfig",
    "EdgeTamVideoConfig",
    "EdgeTamVideoModel",
    "EdgeTamVideoInferenceSession",
    "EdgeTamVideoPreTrainedModel",
]
