# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import Tensor, nn

from ... import initialization as init
from ...backbone_utils import BackboneMixin
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import is_timm_available, requires_backends
from .configuration_timm_backbone import TimmBackboneConfig


if is_timm_available():
    import timm


class TimmBackbone(BackboneMixin, PreTrainedModel):
    """
    Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the
    other models in the library keeping the same API.
    """

    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = False
    config: TimmBackboneConfig

    def __init__(self, config, **kwargs):
        requires_backends(self, "timm")

        if config.backbone is None:
            raise ValueError("backbone is not set in the config. Please set it to a timm model name.")

        # We just take the final layer by default. This matches the default for the transformers models.
        out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)
        pretrained = kwargs.pop("pretrained", False)
        in_chans = kwargs.pop("in_chans", config.num_channels)

        backbone = timm.create_model(
            config.backbone,
            pretrained=pretrained,
            # This is currently not possible for transformer architectures.
            features_only=config.features_only,
            in_chans=in_chans,
            out_indices=out_indices,
            output_stride=config.output_stride,
            **kwargs,
        )

        # Needs to be called after creating timm model, because `super()` will try to infer
        # `stage_names` from model architecture
        super().__init__(config, timm_backbone=backbone)
        self._backbone = backbone

        # Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers of
        # provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively
        if getattr(config, "freeze_batch_norm_2d", False):
            self.freeze_batch_norm_2d()

        # These are used to control the output of the model when called. If output_hidden_states is True, then
        # return_layers is modified to include all layers.
        self._return_layers = {
            layer["module"]: str(layer["index"]) for layer in self._backbone.feature_info.get_dicts()
        }
        self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}

        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        requires_backends(cls, ["vision", "timm"])

        config = kwargs.pop("config", TimmBackboneConfig())
        num_channels = kwargs.pop("num_channels", config.num_channels)
        features_only = kwargs.pop("features_only", config.features_only)
        out_indices = kwargs.pop("out_indices", config.out_indices)
        config = TimmBackboneConfig(
            backbone=pretrained_model_name_or_path,
            num_channels=num_channels,
            features_only=features_only,
            out_indices=out_indices,
        )
        return super()._from_config(config, pretrained=True, **kwargs)

    def freeze_batch_norm_2d(self):
        timm.utils.model.freeze_batch_norm_2d(self._backbone)

    def unfreeze_batch_norm_2d(self):
        timm.utils.model.unfreeze_batch_norm_2d(self._backbone)

    @torch.no_grad()
    def _init_weights(self, module):
        """We need to at least re-init the non-persistent buffers if the model was initialized on meta device (we
        assume weights and persistent buffers will be part of checkpoint as we have no way to control timm inits)"""
        if hasattr(module, "init_non_persistent_buffers"):
            module.init_non_persistent_buffers()
        elif isinstance(module, nn.BatchNorm2d):
            # For non-pretrained models, always initialize buffers (handles both meta device and to_empty() cases)
            running_mean = getattr(module, "running_mean", None)
            if running_mean is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> BackboneOutput | tuple[Tensor, ...]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if output_attentions:
            raise ValueError("Cannot output attentions for timm backbones at the moment")

        if output_hidden_states:
            # We modify the return layers to include all the stages of the backbone
            self._backbone.return_layers = self._all_layers
            hidden_states = self._backbone(pixel_values, **kwargs)
            self._backbone.return_layers = self._return_layers
            feature_maps = tuple(hidden_states[i] for i in self.out_indices)
        else:
            feature_maps = self._backbone(pixel_values, **kwargs)
            hidden_states = None

        feature_maps = tuple(feature_maps)
        hidden_states = tuple(hidden_states) if hidden_states is not None else None

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output = output + (hidden_states,)
            return output

        return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None)


__all__ = ["TimmBackbone"]
