
    qi3              	          d Z ddlZddlZddlmZ ddlZddlmZ ddlm	Z
 ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZmZ ddlmZ  ej6                  e      Ze ed       G d de                    Ze ed       G d de                    Ze ed       G d de                    Z e ed       G d de                    Z! G d dejD                        Z# G d dejD                        Z$d?d ejJ                  d!e&d"e'd#ejJ                  fd$Z( G d% d&ejD                        Z) G d' d(ejD                        Z* G d) d*ejD                        Z+ G d+ d,ejD                        Z, G d- d.e      Z- G d/ d0ejD                        Z.e G d1 d2e             Z/e G d3 d4e/             Z0 ed5       G d6 d7e/             Z1 ed8       G d9 d:e/             Z2 ed;       G d< d=ee/             Z3g d>Z4y)@zPyTorch FocalNet model.    N)	dataclass)nn   )initialization)ACT2FN)BackboneMixin)GradientCheckpointingLayer)BackboneOutput)PreTrainedModel)ModelOutputauto_docstringlogging   )FocalNetConfigzC
    FocalNet encoder's outputs, with potential hidden states.
    )custom_introc                       e Zd ZU dZdZej                  dz  ed<   dZe	ej                     dz  ed<   dZ
e	ej                     dz  ed<   y)FocalNetEncoderOutputa  
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nlast_hidden_statehidden_statesreshaped_hidden_states)__name__
__module____qualname____doc__r   torchFloatTensor__annotations__r   tupler        `/opt/pipecat/venv/lib/python3.12/site-packages/transformers/models/focalnet/modeling_focalnet.pyr   r   $   sT     37u((4/659M5**+d29>BE%"3"34t;Br    r   zZ
    FocalNet model's outputs that also contains a pooling of the last hidden states.
    c                       e Zd ZU dZdZej                  dz  ed<   dZej                  dz  ed<   dZ	e
ej                     dz  ed<   dZe
ej                     dz  ed<   y)FocalNetModelOutputa  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nr   pooler_outputr   r   )r   r   r   r   r   r   r   r   r$   r   r   r   r   r    r!   r#   r#   9   si    	 37u((4/6.2M5$$t+259M5**+d29>BE%"3"34t;Br    r#   z.
    FocalNet masked image model outputs.
    c                       e Zd ZU dZdZej                  dz  ed<   dZej                  dz  ed<   dZ	e
ej                     dz  ed<   dZe
ej                     dz  ed<   y)!FocalNetMaskedImageModelingOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MLM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nlossreconstructionr   r   )r   r   r   r   r'   r   r   r   r(   r   r   r   r   r    r!   r&   r&   Q   sh     &*D%

d
")/3NE%%,359M5**+d29>BE%"3"34t;Br    r&   z4
    FocalNet outputs for image classification.
    c                       e Zd ZU dZdZej                  dz  ed<   dZej                  dz  ed<   dZ	e
ej                     dz  ed<   dZe
ej                     dz  ed<   y)FocalNetImageClassifierOutputa7  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    Nr'   logitsr   r   )r   r   r   r   r'   r   r   r   r+   r   r   r   r   r    r!   r*   r*   k   sh     &*D%

d
")'+FE$+59M5**+d29>BE%"3"34t;Br    r*   c                        e Zd ZdZd fd	Z	 d	dej                  dz  dej                  dz  deej                     fdZ
 xZS )
FocalNetEmbeddingszX
    Construct the patch embeddings and layernorm. Optionally, also the mask token.
    c           	         t         |           t        ||j                  |j                  |j
                  |j                  |j                  d      | _        | j                  j                  | _
        |r4t        j                  t        j                  dd|j                              nd | _        t        j                   |j                  |j"                        | _        t        j&                  |j(                        | _        y )NT)config
image_size
patch_sizenum_channels	embed_dimuse_conv_embedis_stemr   eps)super__init__FocalNetPatchEmbeddingsr0   r1   r2   r3   r4   patch_embeddings	grid_size
patch_gridr   	Parameterr   zeros
mask_token	LayerNormlayer_norm_epsnormDropouthidden_dropout_probdropout)selfr/   use_mask_token	__class__s      r!   r9   zFocalNetEmbeddings.__init__   s     7((((,,&&!00!
 //99O]",,u{{1a9I9I'JKcgLL!1!1v7L7LM	zz&"<"<=r    Npixel_valuesbool_masked_posreturnc                 8   | j                  |      \  }}| j                  |      }|j                         \  }}}|K| j                  j	                  ||d      }|j                  d      j                  |      }	|d|	z
  z  ||	z  z   }| j                  |      }||fS )N      ?)r;   rC   sizer@   expand	unsqueezetype_asrF   )
rG   rJ   rK   
embeddingsoutput_dimensions
batch_sizeseq_len_mask_tokensmasks
             r!   forwardzFocalNetEmbeddings.forward   s     )-(=(=l(K%
%YYz*
!+!2
GQ&//00WbIK",,R088ED#sTz2[45GGJ\\*-
,,,r    )FN)r   r   r   r   r9   r   r   
BoolTensorr   Tensorr[   __classcell__rI   s   @r!   r-   r-      sQ    >& bf-!--4-GLGWGWZ^G^-	u||	-r    r-   c                   z     e Zd Z	 	 	 d fd	Zd Zdej                  dz  deej                  ee	   f   fdZ
 xZS )r:   c	                 d   t         |           t        |t        j                  j
                        r|n||f}t        |t        j                  j
                        r|n||f}|d   |d   z  |d   |d   z  z  }	|| _        || _        || _        |	| _	        |d   |d   z  |d   |d   z  f| _
        |r/|rd}
d}d}nd}
d}d}t        j                  |||
||      | _        nt        j                  ||||      | _        |r't        j                  ||j                  	      | _        y d | _        y )
Nr   r            r   )kernel_sizestridepadding)rf   rg   r6   )r8   r9   
isinstancecollectionsabcIterabler0   r1   r2   num_patchesr<   r   Conv2d
projectionrA   rB   rC   )rG   r/   r0   r1   r2   r3   add_normr4   r5   rm   rf   rh   rg   rI   s                r!   r9   z FocalNetPatchEmbeddings.__init__   s7    	#-j+//:R:R#SZZdfpYq
#-j+//:R:R#SZZdfpYq
!!}
15*Q-:VW=:XY$$(&$Q-:a=8*Q-:VW=:XY iii[Y`DO !iiiZ`jkDOYF4I4IJDIDIr    c                 n   || j                   d   z  dk7  rDd| j                   d   || j                   d   z  z
  f}t        j                  j                  ||      }|| j                   d   z  dk7  rFddd| j                   d   || j                   d   z  z
  f}t        j                  j                  ||      }|S )Nr   r   )r1   r   
functionalpad)rG   rJ   heightwidth
pad_valuess        r!   	maybe_padz!FocalNetPatchEmbeddings.maybe_pad   s    4??1%%*T__Q/%$//!:L2LLMJ==,,\:FLDOOA&&!+Q4??1#5QRAS8S#STJ==,,\:FLr    rJ   NrL   c                 N   |j                   \  }}}}|| j                  k7  rt        d      | j                  |||      }| j	                  |      }|j                   \  }}}}||f}|j                  d      j                  dd      }| j                  | j                  |      }||fS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.rd   r   )shaper2   
ValueErrorrw   ro   flatten	transposerC   )rG   rJ   rX   r2   rt   ru   rT   rU   s           r!   r[   zFocalNetPatchEmbeddings.forward   s    )5););&<4,,,w  ~~lFEB__\2
(..1fe#UO''*44Q:
99 :.J,,,r    )FFF)r   r   r   r9   rw   r   r   r   r^   intr[   r_   r`   s   @r!   r:   r:      sL     (T-E$5$5$< -u||UZ[^U_G_A` -r    r:   input	drop_probtrainingrL   c                    |dk(  s|s| S d|z
  }| j                   d   fd| j                  dz
  z  z   }|t        j                  || j                  | j
                        z   }|j                          | j                  |      |z  }|S )zc
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

            r   r   )r   )dtypedevice)ry   ndimr   randr   r   floor_div)r~   r   r   	keep_probry   random_tensoroutputs          r!   	drop_pathr      s    
 CxII[[^

Q 77E

5ELL YYMYYy!M1FMr    c                   x     e Zd ZdZd	dedz  ddf fdZdej                  dej                  fdZde	fdZ
 xZS )
FocalNetDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr   rL   c                 0    t         |           || _        y r\   )r8   r9   r   )rG   r   rI   s     r!   r9   zFocalNetDropPath.__init__  s    "r    r   c                 D    t        || j                  | j                        S r\   )r   r   r   )rG   r   s     r!   r[   zFocalNetDropPath.forward  s    FFr    c                      d| j                    S )Nzp=)r   rG   s    r!   
extra_reprzFocalNetDropPath.extra_repr  s    DNN#$$r    r\   )r   r   r   r   floatr9   r   r^   r[   strr   r_   r`   s   @r!   r   r     sG    b#%$, #$ #GU\\ Gell G%C %r    r   c                   &     e Zd Zd fd	Zd Z xZS )FocalNetModulationc                    t         	|           || _        |j                  |   | _        |j
                  |   | _        || _        |j                  | _        |j                  | _	        t        j                  |d|z  | j                  dz   z   |      | _        t        j                  ||dd|      | _        t        j                         | _        t        j                  ||      | _        t        j$                  |      | _        t        j(                         | _        g | _        t/        | j                        D ]  }| j                  |z  | j                  z   }| j*                  j1                  t        j2                  t        j                  |||d||dz  d      t        j                                      | j,                  j1                  |        | j                  r't        j4                  ||j6                        | _        y y )Nrd   r   )bias)rf   rg   r   F)rf   rg   groupsrh   r   r6   )r8   r9   dimfocal_windowsfocal_windowfocal_levelsfocal_levelfocal_factor use_post_layernorm_in_modulationnormalize_modulatorr   Linearprojection_inrn   projection_contextGELU
activationprojection_outrD   projection_dropout
ModuleListfocal_layerskernel_sizesrangeappend
SequentialrA   rB   	layernorm)
rG   r/   indexr   r   r   r   krf   rI   s
            r!   r9   zFocalNetModulation.__init__  s   "007!..u5(060W0W-#)#=#= YYsAGt7G7G!7K,LSWX"$))C!ATX"Y'') iiS1"$**-?"@MMOt''( 
	2A++a/$2C2CCK$$IISk!CYdhiYipu GGI	 $$[1
	2 00\\#63H3HIDN 1r    c                 ,   |j                   d   }| j                  |      j                  dddd      j                         }t	        j
                  |||| j                  dz   fd      \  }}}d}t        | j                        D ]+  } | j                  |   |      }|||dd||dz   f   z  z   }- | j                  |j                  dd      j                  dd            }	||	|dd| j                  df   z  z   }| j                  r|| j                  dz   z  }| j                  |      }
||
z  }|j                  dddd      j                         }| j                  r| j                  |      }| j                  |      }| j!                  |      }|S )	z
        Args:
            hidden_state:
                Input features with shape of (batch_size, height, width, num_channels)
        rN   r   r   r   rd   NT)keepdim)ry   r   permute
contiguousr   splitr   r   r   r   meanr   r   r   r   r   r   )rG   hidden_stater2   xqctxgatesctx_alllevel
ctx_global	modulatorx_outs               r!   r[   zFocalNetModulation.forward5  s    $))"- |,44Q1a@KKMAlDDTDTWXDX'Y[\]3 4++, 	BE*$##E*3/CeAuuqy/@,@&A AAG	B __SXXaX%>%C%CAt%C%TU
Jq$2B2B2D/D)EEE ##!1!1A!56G ++G4	IaAq)44600NN5)E ##E*''.r    )rd   Tr   r   r   r   r9   r[   r_   r`   s   @r!   r   r     s    JB"r    r   c                   &     e Zd Zd fd	Zd Z xZS )FocalNetMlpc                 
   t         |           |xs |}|xs |}t        j                  ||      | _        t
        |j                     | _        t        j                  ||      | _        t        j                  |      | _
        y r\   )r8   r9   r   r   fc1r   
hidden_actr   fc2rD   drop)rG   r/   in_featureshidden_featuresout_featuresr   rI   s         r!   r9   zFocalNetMlp.__init__[  sh    #2{)8[99[/: !2!2399_l;JJt$	r    c                     | j                  |      }| j                  |      }| j                  |      }| j                  |      }| j                  |      }|S r\   )r   r   r   r   )rG   r   s     r!   r[   zFocalNetMlp.forwardd  sN    xx-|4yy.xx-yy.r    )NNr   r   r`   s   @r!   r   r   Z  s    %r    r   c                   *     e Zd ZdZd fd	Zd Z xZS )FocalNetLayera  Focal Modulation Network layer (block).

    Args:
        config (`FocalNetConfig`):
            Model config.
        index (`int`):
            Layer index.
        dim (`int`):
            Number of input channels.
        input_resolution (`tuple[int]`):
            Input resolution.
        drop_path (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate.
    c                 H   t         |           || _        || _        || _        |j
                  | _        |j                  | _        t        j                  ||j                        | _        t        |||| j                        | _        |dkD  rt        |      nt        j                         | _        t        j                  ||j                        | _        t%        ||j&                  z        }t)        |||| j                        | _        d| _        d| _        |j0                  ryt        j2                  |j4                  t7        j8                  |      z  d      | _        t        j2                  |j4                  t7        j8                  |      z  d      | _        y y )Nr6   )r/   r   r   r   r   )r/   r   r   r   rO   T)requires_grad)r8   r9   r/   r   input_resolutionrE   r   use_post_layernormr   rA   rB   norm1r   
modulationr   Identityr   norm2r}   	mlp_ratior   mlpgamma_1gamma_2use_layerscaler>   layerscale_valuer   ones)rG   r/   r   r   r   r   mlp_hidden_dimrI   s          r!   r9   zFocalNetLayer.__init__}  sC     0 ..	"(";";\\#6+@+@A
,#yy	
 9BC))4R[[]\\#6+@+@A
S6#3#334f#~dhdmdmn  <<(?(?%**S/(QaefDL<<(?(?%**S/(QaefDL !r    c           	      :   |\  }}|j                   \  }}}|}| j                  r|n| j                  |      }|j                  ||||      }| j	                  |      j                  |||z  |      }| j                  s|n| j                  |      }|| j                  | j                  |z        z   }|| j                  | j                  | j                  r | j                  | j                  |            n| j                  | j                  |            z        z   }|S r\   )
ry   r   r   viewr   r   r   r   r   r   )	rG   r   input_dimensionsrt   ru   rV   rX   r2   shortcuts	            r!   r[   zFocalNetLayer.forward  s   (&2&8&8#
A| (,'>'>|DJJ|D\#((VULQ|499*funVbc+/+B+B|

S_H`  $..1L"MM#dnnLL595L5Ltzz$((<01RVRZRZ[_[e[efr[sRtv'
 

 r    )r   )r   r   r   r   r9   r[   r_   r`   s   @r!   r   r   m  s    g@r    r   c                   j     e Zd Z fdZdej
                  deeef   deej
                     fdZ xZ	S )FocalNetStagec                    t         |           || _        t        |j                        | _        t        | j
                        D cg c]  }|j                  d|z  z   }}||   }|| j
                  dz
  k  r||dz      nd }|| j
                  dz
  k  rt        nd }t        j                  d|j                  t        |j                        d      D 	cg c]  }	|	j                          }
}	|
t        |j                  d |       t        |j                  d |dz           }t        j                  t        |j                  |         D cg c]'  }t!        ||||t#        |t$              r||   n|      ) c}      | _        |' |||d||d|j(                  d	      | _        d| _        y d | _        d| _        y c c}w c c}	w c c}w )
Nrd   r   r   cpu)r   )r/   r   r   r   r   TF)r/   r0   r1   r2   r3   rp   r4   r5   )r8   r9   r/   lendepths
num_stagesr   r3   r:   r   linspacedrop_path_ratesumitemr   r   r   ri   listlayersr4   
downsamplepointing)rG   r/   r   r   ir3   r   out_dimr   r   dprr   rI   s               r!   r9   zFocalNetStage.__init__  s   fmm,8=doo8NO1V%%A.O	O+04??Q3F+F)EAI&T1619L1L,SW
 "'63H3H#fmmJ\ej!klAqvvxllFMM&512S{QR9S5TU	mm v}}U34	  !%5.8D.Iily	
 !(+ !%44	DO  #DOI P m	s   F<G,Gr   r   rL   c                    |\  }}| j                   D ]  } |||      } |}| j                  K|\  }}|j                  dd      j                  |j                  d   d||      }| j                  |      \  }}n||||f}|||f}|S )Nr   rd   r   rN   )r   r   r|   reshapery   )	rG   r   r   rt   ru   layer_module!hidden_states_before_downsamplingrU   stage_outputss	            r!   r[   zFocalNetStage.forward  s    ( KK 	JL(8HIM	J -:)??&,MFE)33Aq9AA177:BM 04}/M,M, "( >&(IK\]r    )
r   r   r   r9   r   r^   r   r}   r[   r_   r`   s   @r!   r   r     s=    *XU\\ U3PS8_ Y^_d_k_kYl r    r   c                   |     e Zd Z fdZ	 	 	 d
dej
                  deeef   dedz  dedz  dedz  dee	z  fd	Z
 xZS )FocalNetEncoderc                 2   t         |           t        |j                        | _        || _        t        j                  t        | j                        D cg c]$  }t        |||d   d|z  z  |d   d|z  z  f      & c}      | _
        d| _        y c c}w )Nr   rd   r   )r/   r   r   F)r8   r9   r   r   r   r/   r   r   r   r   stagesgradient_checkpointing)rG   r/   r<   i_layerrI   s       r!   r9   zFocalNetEncoder.__init__  s    fmm,mm  %T__5  !!&/lq'z&BIaLUVX_U_D`%a	
 ',#s   )Br   r   output_hidden_statesN(output_hidden_states_before_downsamplingreturn_dictrL   c                    |rdnd }|rdnd }|rE|j                   \  }}	}
 |j                  |g||
 }|j                  dddd      }||fz  }||fz  }t        | j                        D ]  \  }} |||      }|d   }|d   }|d   }|d   |d   f}|rP|rN|j                   \  }}	}
 |j                  |g|d   |d   f|
 }|j                  dddd      }||fz  }||fz  }z|s}|r|j                   \  }}	}
 |j                  |g||
 }|j                  dddd      }||fz  }||fz  } |st        d ||fD              S t        |||	      S )
Nr   r   r   r   rd   rN   c              3   &   K   | ]	  }||  y wr\   r   ).0vs     r!   	<genexpr>z*FocalNetEncoder.forward.<locals>.<genexpr>6  s     Xq!-Xs   )r   r   r   )ry   r   r   	enumerater   r   r   )rG   r   r   r  r  r  all_hidden_statesall_reshaped_hidden_statesrV   rX   hidden_sizereshaped_hidden_stater   stage_moduler   r   rU   s                    r!   r[   zFocalNetEncoder.forward  s    #7BD+?RT")6)<)<&J;$6M$6$6z$bDT$bVa$b!$9$A$A!Q1$M!-!11&+@*BB&(5 	GOA|(8HIM)!,M0=a0@- -a 0 1" 57H7LM#(P-N-T-T*
A{ )O(I(N(N)"3A"68I!8L!M)OZ)% )>(E(EaAq(Q%!&G%II!*/D.FF*%.V-:-@-@*
A{(:(:(::(fHX(fZe(f%(=(E(EaAq(Q%!m%55!*/D.FF*3	G6 X]4E$FXXX$++#=
 	
r    )FFT)r   r   r   r9   r   r^   r   r}   boolr   r[   r_   r`   s   @r!   r   r     sp    ,, -2@E#'5
||5
  S/5
 #Tk	5

 37+5
 D[5
 
&	&5
r    r   c                   d     e Zd ZU eed<   dZdZdZdgZ e	j                          fd       Z xZS )FocalNetPreTrainedModelr/   focalnetrJ   Tr   c                    t         |   |       t        |t              r-|j                   t        j                  |j                         yyt        |t              r| j                  j                  rit        j                  |j                  | j                  j                         t        j                  |j                  | j                  j                         yyy)zInitialize the weightsN)r8   _init_weightsri   r-   r@   initzeros_r   r/   r   	constant_r   r   r   )rG   modulerI   s     r!   r  z%FocalNetPreTrainedModel._init_weightsG  s     	f%f01  ,F--. -.{{))v~~t{{/K/KLv~~t{{/K/KL * /r    )r   r   r   r   r   base_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modulesr   no_gradr  r_   r`   s   @r!   r  r  ?  s?    "$O&*#()U]]_	M 	Mr    r  c                        e Zd Zd
 fd	Zd Ze	 	 	 	 ddej                  dz  dej                  dz  de	dz  de	dz  de
ez  f
d	       Z xZS )FocalNetModelc                    t         |   |       || _        t        |j                        | _        t        |j                  d| j
                  dz
  z  z        | _        t        ||      | _
        t        || j                  j                        | _        t        j                  | j                  |j                         | _        |rt        j$                  d      nd| _        | j)                          y)z
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        rd   r   )rH   r6   N)r8   r9   r/   r   r   r   r}   r3   num_featuresr-   rT   r   r=   encoderr   rA   rB   r   AdaptiveAvgPool1dpooler	post_init)rG   r/   add_pooling_layerrH   rI   s       r!   r9   zFocalNetModel.__init__V  s     	 fmm, 0 0119L3M MN,VNS&vt/I/IJd&7&7V=R=RS1Bb**1- 	r    c                 .    | j                   j                  S r\   )rT   r;   r   s    r!   get_input_embeddingsz"FocalNetModel.get_input_embeddingsk  s    ///r    NrJ   rK   r  r  rL   c                    ||n| j                   j                  }||n| j                   j                  }|t        d      | j	                  ||      \  }}| j                  ||||      }|d   }	| j                  |	      }	d}
| j                  7| j                  |	j                  dd            }
t        j                  |
d      }
|s|	|
f|dd z   }|S t        |	|
|j                  |j                        S )	z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)rK   r  r  r   r   rd   )r   r$   r   r   )r/   r  use_return_dictrz   rT   r$  r   r&  r|   r   r{   r#   r   r   )rG   rJ   rK   r  r  kwargsembedding_outputr   encoder_outputssequence_outputpooled_outputr   s               r!   r[   zFocalNetModel.forwardn  s    %9$D $++JjJj 	 &1%<k$++B]B]?@@-1__\[j_-k**,,!5#	 ' 
 *!,..9;;" KK(A(A!Q(GHM!MM-;M%}58KKFM"-')77#2#I#I	
 	
r    )TFNNNN)r   r   r   r9   r*  r   r   r   r]   r  r   r#   r[   r_   r`   s   @r!   r!  r!  T  s    *0  2637,0#'/
''$./
 ))D0/
 #Tk	/

 D[/
 
$	$/
 /
r    r!  a  
    FocalNet Model with a decoder on top for masked image modeling.

    This follows the same implementation as in [SimMIM](https://huggingface.co/papers/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    c                        e Zd Z fdZe	 	 	 	 d	dej                  dz  dej                  dz  dedz  dedz  de	e
z  f
d       Z xZS )
FocalNetForMaskedImageModelingc                    t         |   |       t        |dd      | _        t	        |j
                        | _        t        |j                  d| j                  dz
  z  z        }t        j                  t        j                  ||j                  dz  |j                  z  d      t        j                  |j                              | _        | j!                          y )NFT)r(  rH   rd   r   )in_channelsout_channelsrf   )r8   r9   r!  r  r   r   r   r}   r3   r   r   rn   encoder_strider2   PixelShuffledecoderr'  )rG   r/   r#  rI   s      r!   r9   z'FocalNetForMaskedImageModeling.__init__  s     %fVZ[fmm,6++aDOOa4G.HHI}}II(v7L7La7ORXReRe7est OOF112	
 	r    NrJ   rK   r  r  rL   c                    ||n| j                   j                  }| j                  ||||      }|d   }|j                  dd      }|j                  \  }}	}
t        j                  |
dz        x}}|j                  ||	||      }| j                  |      }d}|| j                   j                  | j                   j                  z  }|j                  d||      }|j                  | j                   j                  d      j                  | j                   j                  d      j                  d      j                         }t        j                  j!                  ||d	      }||z  j#                         |j#                         d
z   z  | j                   j$                  z  }|s|f|dd z   }||f|z   S |S t'        |||j(                  |j*                        S )a  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
        >>> config = FocalNetConfig()
        >>> model = FocalNetForMaskedImageModeling(config)

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```N)rK   r  r  r   r   rd   g      ?rN   none)	reductiongh㈵>)r'   r(   r   r   )r/   r-  r  r|   ry   mathfloorr   r;  r0   r1   repeat_interleaverR   r   r   rr   l1_lossr   r2   r&   r   r   )rG   rJ   rK   r  r  r.  outputsr1  rV   r2   sequence_lengthrt   ru   reconstructed_pixel_valuesmasked_im_lossrP   rZ   reconstruction_lossr   s                      r!   r[   z&FocalNetForMaskedImageModeling.forward  s   N &1%<k$++B]B]--+!5#	   
 "!*)33Aq94C4I4I1
L/OS$899)11*lFTYZ &*\\/%B"&;;))T[[-C-CCD-55b$EO11$++2H2H!L""4;;#9#91=1	  #%--"7"7F`lr"7"s1D8==?488:PTCTUX\XcXcXpXppN02WQR[@F3A3M^%.YSYY05!//#*#A#A	
 	
r    r3  )r   r   r   r9   r   r   r   r]   r  r   r&   r[   r_   r`   s   @r!   r5  r5    s    "  2637,0#'O
''$.O
 ))D0O
 #Tk	O

 D[O
 
2	2O
 O
r    r5  z
    FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for
    ImageNet.
    c                        e Zd Z fdZe	 	 	 	 d	dej                  dz  dej                  dz  dedz  dedz  de	e
z  f
d       Z xZS )
FocalNetForImageClassificationc                 >   t         |   |       |j                  | _        t        |      | _        |j                  dkD  r4t        j                  | j                  j                  |j                        nt        j                         | _	        | j                          y )Nr   )r8   r9   
num_labelsr!  r  r   r   r#  r   
classifierr'  rG   r/   rI   s     r!   r9   z'FocalNetForImageClassification.__init__  sx      ++%f- IOHYHY\]H]BIIdmm00&2C2CDcecncncp 	
 	r    NrJ   labelsr  r  rL   c                 <   ||n| j                   j                  }| j                  |||      }|d   }| j                  |      }d}	|| j	                  ||| j                         }	|s|f|dd z   }
|	|	f|
z   S |
S t        |	||j                  |j                        S )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr,  r   rd   )r'   r+   r   r   )r/   r-  r  rL  loss_functionr*   r   r   )rG   rJ   rN  r  r  r.  rC  r2  r+   r'   r   s              r!   r[   z&FocalNetForImageClassification.forward*  s     &1%<k$++B]B]--!5#   
  
/%%ffdkkBDY,F)-)9TGf$EvE,!//#*#A#A	
 	
r    r3  )r   r   r   r9   r   r   r   
LongTensorr  r   r*   r[   r_   r`   s   @r!   rI  rI    s~      26*.,0#''
''$.'
   4''
 #Tk	'

 D['
 
.	.'
 '
r    rI  zG
    FocalNet backbone, to be used with frameworks like X-Decoder.
    c            
       p     e Zd ZdZdef fdZe	 	 d
dej                  de	dz  de	dz  de
fd	       Z xZS )FocalNetBackboneFr/   c                     t         |   |       |j                  g|j                  z   | _        t        |      | _        | j                          y r\   )r8   r9   r3   hidden_sizesr#  r!  r  r'  rM  s     r!   r9   zFocalNetBackbone.__init__]  sD     #--.1D1DD%f- 	r    NrJ   r  r  rL   c                    ||n| j                   j                  }||n| j                   j                  }| j                  |dd      }|j                  }d}t        | j                        D ]  \  }}	|	| j                  v s|||   fz  } |s|f}
|r|
|j                  fz  }
|
S t        ||r|j                  d      S dd      S )a  
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
        >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```NTr,  r   )feature_mapsr   
attentions)
r/   r-  r  r  r   r  stage_namesr   r   r
   )rG   rJ   r  r  r.  rC  r   rW  idxstager   s              r!   r[   zFocalNetBackbone.forwardf  s    8 &1%<k$++B]B]$8$D $++JjJj 	 --4UY-Z66#D$4$45 	6JC)))s!3 55	6 "_F#70022M%3G'//
 	
MQ
 	
r    )NN)r   r   r   has_attentionsr   r9   r   r   r^   r  r
   r[   r_   r`   s   @r!   rS  rS  U  sd     N~   -1#'	3
ll3
 #Tk3
 D[	3
 
3
 3
r    rS  )rI  r5  rS  r!  r  )r   F)5r   collections.abcrj   r?  dataclassesr   r   r    r   r  activationsr   backbone_utilsr   modeling_layersr	   modeling_outputsr
   modeling_utilsr   utilsr   r   r   configuration_focalnetr   
get_loggerr   loggerr   r#   r&   r*   Moduler-   r:   r^   r   r  r   r   r   r   r   r   r   r  r!  r5  rI  rS  __all__r   r    r!   <module>rk     s      !   & ! + 9 . - 9 9 2 
		H	% 
CK C C 
C+ C C$ 
C C C( 
CK C C(%- %-PD-bii D-PU\\ e T V[VbVb  %ryy %D DN")) &BBII BJ?. ?DH
bii H
V Mo M M( I
+ I
 I
X b
%< b
b
J 8
%< 8
8
v 
@
}&= @

@
Fr    