import torch

from ...utils import TensorType
from ..superglue.image_processing_superglue_fast import SuperGlueImageProcessorFast
from .modeling_efficientloftr import EfficientLoFTRKeypointMatchingOutput


class EfficientLoFTRImageProcessorFast(SuperGlueImageProcessorFast):
    """
    Fast image processor for EfficientLoFTR.

    Inherits from [`SuperGlueImageProcessorFast`] and overrides the keypoint-matching post-processing step so that it
    consumes an [`EfficientLoFTRKeypointMatchingOutput`].
    """

    def post_process_keypoint_matching(
        self,
        outputs: "EfficientLoFTRKeypointMatchingOutput",
        target_sizes: TensorType | list[tuple],
        threshold: float = 0.0,
    ) -> list[dict[str, torch.Tensor]]:
        """
        Converts the raw output of [`EfficientLoFTRKeypointMatchingOutput`] into lists of matched keypoints and their
        scores, with keypoint coordinates scaled to the original image sizes.

        Args:
            outputs ([`EfficientLoFTRKeypointMatchingOutput`]):
                Raw outputs of the model. The `keypoints`, `matches` and `matching_scores` attributes are read.
            target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`):
                Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
                target size `(height, width)` of each image in the batch. This must be the original image size (before
                any processing). A tensor is moved to the device of `outputs.keypoints` if needed.
            threshold (`float`, *optional*, defaults to 0.0):
                Threshold to filter out the matches with low scores. Only matches with a score strictly greater than
                this value are kept.

        Returns:
            `List[Dict]`: A list with one dictionary per image pair, each containing:
                - `"keypoints0"`: `torch.int32` keypoints kept in the first image of the pair.
                - `"keypoints1"`: `torch.int32` keypoints kept in the second image of the pair.
                - `"matching_scores"`: scores of the matches kept for the first image.

        Raises:
            ValueError: If the number of target sizes differs from the batch size of `outputs.matches`, or if
                `target_sizes` cannot be interpreted as `(batch_size, 2, 2)` sizes.
        """
        batch_size = outputs.matches.shape[0]
        if batch_size != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")

        size_error = "Each element of target_sizes must contain the size (h, w) of each image of the batch"
        keypoints = outputs.keypoints
        try:
            # Handles lists and tensors uniformly and guarantees the sizes live on the same device as the
            # keypoints they are multiplied with (no copy is made when a tensor is already on that device).
            image_pair_sizes = torch.as_tensor(target_sizes, device=keypoints.device)
        except (TypeError, ValueError) as err:
            # Ragged or non-numeric input: report it with the same message as any other malformed size.
            raise ValueError(size_error) from err
        # One shape check for both input kinds: every pair must provide two (h, w) sizes.
        if batch_size > 0 and tuple(image_pair_sizes.shape[1:]) != (2, 2):
            raise ValueError(size_error)

        # (height, width) -> (width, height), then add a broadcast axis over the keypoints of each image.
        # The batch size is passed explicitly because `-1` cannot be inferred for an empty batch.
        scale = image_pair_sizes.flip(-1).reshape(batch_size, 2, 1, 2)
        # Out-of-place multiply, so `outputs.keypoints` itself is left untouched.
        keypoints = (keypoints * scale).to(torch.int32)

        results = []
        for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
            # Keep only entries that are matched (index > -1) and score above the threshold
            valid_matches = torch.logical_and(scores > threshold, matches > -1)

            results.append(
                {
                    "keypoints0": keypoints_pair[0][valid_matches[0]],
                    "keypoints1": keypoints_pair[1][valid_matches[1]],
                    "matching_scores": scores[0][valid_matches[0]],
                }
            )

        return results


# Names exported by `from <module> import *`: only the fast image processor class.
__all__ = ["EfficientLoFTRImageProcessorFast"]
