# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Nougat.
"""

from typing import Optional, Union

from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy

from ...processing_utils import ProcessorMixin
from ...utils import PaddingStrategy, TensorType, auto_docstring


# Processor bundling a Nougat image processor and a tokenizer behind a single `__call__`, plus a passthrough to the
# tokenizer's generation post-processing.
@auto_docstring
class NougatProcessor(ProcessorMixin):
    def __init__(self, image_processor, tokenizer):
        # Storage/registration of both components is delegated to `ProcessorMixin.__init__`.
        super().__init__(image_processor, tokenizer)

    @auto_docstring
    def __call__(
        self,
        images=None,
        text=None,
        do_crop_margin: bool | None = None,
        do_resize: bool | None = None,
        size: dict[str, int] | None = None,
        resample: Optional["PILImageResampling"] = None,  # noqa: F821
        do_thumbnail: bool | None = None,
        do_align_long_axis: bool | None = None,
        do_pad: bool | None = None,
        do_rescale: bool | None = None,
        rescale_factor: int | float | None = None,
        do_normalize: bool | None = None,
        image_mean: float | list[float] | None = None,
        image_std: float | list[float] | None = None,
        data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
        input_data_format: Union[str, "ChannelDimension"] | None = None,  # noqa: F821
        text_pair: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        text_target: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        text_pair_target: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
        add_special_tokens: bool = True,
        padding: bool | str | PaddingStrategy = False,
        truncation: bool | str | TruncationStrategy | None = None,
        max_length: int | None = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_token_type_ids: bool | None = None,
        return_attention_mask: bool | None = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
    ):
        r"""
        do_crop_margin (`bool`, *optional*):
            Whether to crop the blank (near-white) margins surrounding the page content of each document image. This
            is useful for scanned pages or PDFs rendered with wide margins. Passed through unchanged to the image
            processor.
        do_thumbnail (`bool`, *optional*):
            Whether to shrink each image with the image processor's thumbnail method, scaling it down (never up)
            toward `size`. The shrunken image replaces the input image; no additional thumbnail is produced. Passed
            through unchanged to the image processor.
        do_align_long_axis (`bool`, *optional*):
            Whether to rotate an image by 90 degrees when its long axis does not match the long axis of `size` (for
            example, a landscape page when `size` is portrait). Passed through unchanged to the image processor.
        """
        # At least one modality is required; there is nothing to return otherwise.
        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        # Image options are forwarded verbatim. `return_tensors` is shared with the tokenizer call below, so both
        # outputs use the same tensor type.
        if images is not None:
            inputs = self.image_processor(
                images,
                do_crop_margin=do_crop_margin,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_thumbnail=do_thumbnail,
                do_align_long_axis=do_align_long_axis,
                do_pad=do_pad,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                return_tensors=return_tensors,
                data_format=data_format,
                input_data_format=input_data_format,
            )
        if text is not None:
            encodings = self.tokenizer(
                text,
                text_pair=text_pair,
                text_target=text_target,
                text_pair_target=text_pair_target,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
            )

        # Which object is returned depends on the supplied inputs:
        #   images only -> the image processor's output
        #   text only   -> the tokenizer's output
        #   both        -> the image processor's output with the token ids attached under "labels"; every other
        #                  tokenizer output (e.g. an attention mask) is discarded in this case.
        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def post_process_generation(self, *args, **kwargs):
        """
        This method forwards all its arguments to the tokenizer's `post_process_generation` method and returns its
        result unchanged. Please refer to the docstring of that method for more information.
        """
        return self.tokenizer.post_process_generation(*args, **kwargs)


# Public names exported by `from <this module> import *`.
__all__ = [
    "NougatProcessor",
]
