
    qi                        d dl Z d dlmZ d dlZd dlmZ d dlmZmZmZ ddl	m
Z
 ddlmZmZmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZ ddlmZmZ ddlm Z  ddl!m"Z"m#Z#m$Z$ ddl%m&Z&  e$jN                  e(      Z) G d dejT                        Z+ G d dejT                        Z,	 	 d0dejZ                  dej\                  dej\                  dej\                  dej\                  dz  de/dz  de/de e"   fdZ0 G d dejZ                        Z1 G d  d!e      Z2e# G d" d#e             Z3e# G d$ d%e3             Z4 e#d&'       G d( d)e3e             Z5e# G d* d+e3             Z6 e#d,'       G d- d.e3             Z7g d/Z8y)1    N)Callable)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)create_causal_mask)FlashAttentionKwargs)GradientCheckpointingLayer))BaseModelOutputWithPastAndCrossAttentions!CausalLMOutputWithCrossAttentions SequenceClassifierOutputWithPastTokenClassifierOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringlogging   )BioGptConfigc                   x     e Zd ZdZdedef fdZ	 	 d
dej                  dedej                  dz  f fd	Z xZ	S ) BioGptLearnedPositionalEmbeddingzN
    This module learns positional embeddings up to a fixed maximum size.
    num_embeddingsembedding_dimc                 N    d| _         t        | 	  || j                   z   |       y )N   )offsetsuper__init__)selfr   r   	__class__s      \/opt/pipecat/venv/lib/python3.12/site-packages/transformers/models/biogpt/modeling_biogpt.pyr$   z)BioGptLearnedPositionalEmbedding.__init__6   s$     $++5}E    Nattention_maskpast_key_values_lengthposition_idsc                     |8t        j                  |d      }||z  dz
  j                         }|dd|df   }t        |   || j
                  z         S )z3`input_ids_shape` is expected to be [bsz x seqlen].Nr   dim)torchcumsumlongr#   forwardr"   )r%   r)   r*   r+   r&   s       r'   r2   z(BioGptLearnedPositionalEmbedding.forward<   s^      <<A>L(>9A=CCEL'+A+B(BCLw|dkk9::r(   )r   N)
__name__
__module____qualname____doc__intr$   r/   
LongTensorr2   __classcell__r&   s   @r'   r   r   1   s]    Fs F3 F '(04	;((; !$; &&-	; ;r(   r   c            
       `     e Zd ZdZd
dededededz  f fdZdej                  f fd	Z	 xZ
S )BioGptScaledWordEmbeddingz\
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    r   r   padding_idxembed_scaleNc                 6    t         |   |||       || _        y N)r#   r$   r>   )r%   r   r   r=   r>   r&   s        r'   r$   z"BioGptScaledWordEmbedding.__init__R   s    D&r(   	input_idsc                 <    t         |   |      | j                  z  S r@   )r#   r2   r>   )r%   rA   r&   s     r'   r2   z!BioGptScaledWordEmbedding.forwardV   s    wy)D,<,<<<r(   )      ?)r3   r4   r5   r6   r7   floatr$   r/   Tensorr2   r9   r:   s   @r'   r<   r<   M   sE    's '3 'S '_dgk_k '= = =r(   r<   modulequerykeyvaluer)   scalingdropoutkwargsc                    ||j                  d      dz  }t        j                  ||j                  dd            |z  }|||z   }t        j
                  j                  |d      }t        j
                  j                  ||| j                        }t        j                  ||      }	|	j                  dd      j                         }	|	|fS )N      r!   r   r-   ptrainingr   )
sizer/   matmul	transposenn
functionalsoftmaxrK   rR   
contiguous)
rF   rG   rH   rI   r)   rJ   rK   rL   attn_weightsattn_outputs
             r'   eager_attention_forwardr\   Z   s     **R.D( <<s}}Q':;gEL!#n4==((2(>L==((6??([L,,|U3K''1-88:K$$r(   c                   Z    e Zd ZdZ	 	 	 	 	 	 ddedededededed	edz  d
edz  f fdZ	 	 	 	 	 dde	j                  de	j                  dz  dedz  de	j                  dz  dede	j                  dz  dee   dee	j                  e	j                  dz  ee	j                     dz  f   fdZ xZS )BioGptAttentionz=Multi-headed attention from 'Attention Is All You Need' paperN	embed_dim	num_headsrK   
is_decoderbias	is_causalconfig	layer_idxc	                    t         	|           || _        || _        || _        ||z  | _        || _        | j
                  |z  | j                  k7  rt        d| j                   d| d      | j
                  dz  | _        || _	        || _
        || _        |9| j                  r-t        j                  d| j                  j                   d       t!        j"                  |||      | _        t!        j"                  |||      | _        t!        j"                  |||      | _        t!        j"                  |||      | _        y )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).rO   zInstantiating a decoder z without passing `layer_idx` is not recommended and will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.rb   )r#   r$   r_   r`   rK   head_dimrd   
ValueErrorrJ   ra   rc   re   loggerwarning_oncer&   r3   rV   Lineark_projv_projq_projout_proj)
r%   r_   r`   rK   ra   rb   rc   rd   re   r&   s
            r'   r$   zBioGptAttention.__init__y   s$    	""!Y.MMI%$..8MdnnM]$YKr3  }}d*$""*4>>+B+B*C D, , ii	94@ii	94@ii	94@		)YTBr(   hidden_stateskey_value_statespast_key_valuesr)   output_attentionscache_positionrL   returnc                    |du}|j                   dd \  }	}
|r|j                   d   n|
}|	|
d| j                  f}|	|d| j                  f} | j                  |      j                  | j	                  dd      }d}|St        |t              rA|j                  j                  | j                        }|r|j                  }n|j                  }n|}|r|n|}|rK|I|rGj                  | j                     j                  }|j                  | j                     j                  }n| j                  |      }| j!                  |      } |j                  | j	                  dd      } |j                  | j	                  dd      }|T|s|nd}j#                  ||| j                  d|i      \  }}|r)t        |t              rd|j                  | j                  <   t%        j&                  | j(                  j*                  t,              } || ||||f| j.                  sdn| j0                  | j2                  |d	|\  }}|j5                  |	|
d      j7                         }| j9                  |      }||fS )
z#Input shape: Batch x Time x ChannelNrN   r   r!   Fru   T        )rK   rJ   rt   )shaperh   ro   viewrU   
isinstancer   
is_updatedgetre   cross_attention_cacheself_attention_cachelayerskeysvaluesrm   rn   updater   get_interfacerd   _attn_implementationr\   rR   rK   rJ   reshaperY   rp   )r%   rq   rr   rs   r)   rt   ru   rL   is_cross_attentionbsztgt_lensrc_lenq_input_shapekv_input_shapequery_statesr|   curr_past_key_valuescurrent_states
key_statesvalue_statesattention_interfacer[   rZ   s                          r'   r2   zBioGptAttention.forward   s{     .T9 %**3B/W/A"((+wgr4==9wDMM: 7t{{=166FPPQRTUV
&/+>?,77;;DNNK
%+:+P+P(+:+O+O('6$-?)]/"=*-44T^^DIIJ/66t~~FMML^4J;;~6L(.9CCAqIJ,<,,n=GG1ML*7It+?+F+Fdnn?OQ_>`,(
L &*_FY*ZAEO..t~~>(?(M(MKK,,.E)
 %8
%
  $}}C$,,LL/
%
 
%
!\ "))#w;FFHmmK0L((r(   )rx   FTFNN)NNNFN)r3   r4   r5   r6   r7   rD   boolr   r$   r/   rE   r	   r   r   tupler2   r9   r:   s   @r'   r^   r^   v   s=   G  &* $%C%C %C 	%C
 %C %C %C t#%C :%CT 15(,.2"'.2P)||P)  ,,-P) 	P)
 t+P)  P) t+P) -.P) 
u||U\\D0%2E2LL	MP)r(   r^   c                   :    e Zd Zddededz  f fdZ	 	 	 	 	 	 ddej                  dej                  dz  dedz  de	dz  d	e	dz  d
ej                  dz  dej                  dz  dee   deej                  eej                  ej                  f   dz  f   fdZ xZS )BioGptDecoderLayerNrd   re   c           	      n   t         |           |j                  | _        t	        | j                  |j
                  |j                  dd||      | _        |j                  | _	        t        |j                     | _        |j                  | _        t        j                  | j                        | _        t        j"                  | j                  |j$                        | _        t        j"                  |j$                  | j                        | _        t        j                  | j                        | _        y )NT)r_   r`   rK   ra   rc   rd   re   )r#   r$   hidden_sizer_   r^   num_attention_headsattention_probs_dropout_prob	self_attnhidden_dropout_probrK   r   
hidden_actactivation_fnactivation_dropoutrV   	LayerNormself_attn_layer_normrl   intermediate_sizefc1fc2final_layer_norm)r%   rd   re   r&   s      r'   r$   zBioGptDecoderLayer.__init__   s    ++(nn0077
 11#F$5$56"(";";$&LL$@!99T^^V-E-EF99V55t~~F "T^^ <r(   rq   r)   rs   rt   	use_cacher+   ru   rL   rv   c           
      ^   |}	| j                  |      } | j                  d||||||d|\  }}
t        j                  j	                  || j                  | j
                        }|	|z   }|}	| j                  |      }| j                  |      }| j                  |      }t        j                  j	                  || j                  | j
                        }| j                  |      }t        j                  j	                  || j                  | j
                        }|	|z   }|f}|r||
fz  }|S )a\  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            past_key_values (`Cache`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
                cache in the correct position and to infer the complete sequence length.
        )rq   rs   r)   rt   r+   ru   rP    )r   r   rV   rW   rK   rR   r   r   r   r   r   )r%   rq   r)   rs   rt   r   r+   ru   rL   residualself_attn_weightsoutputss               r'   r2   zBioGptDecoderLayer.forward  s=   6 !11-@ ,:4>> ,
'+)/%),
 ,
(( --mt||VZVcVc-d =0 !--m</**=9--mt?V?Vaeanan-o/--mt||VZVcVc-d =0 ")++Gr(   r@   )NNFTNN)r3   r4   r5   r   r7   r$   r/   rE   r	   r   r8   r   r   r   FloatTensorr2   r9   r:   s   @r'   r   r      s    =| =d
 =4 /3(,).!%04.2;||; t+; 	;
  $;; $;; &&-; t+; +,; 
u  %(9(95;L;L(L"MPT"TT	U;r(   r   c                   0    e Zd ZU eed<   dZdZdZdZdZ	dZ
y)BioGptPreTrainedModelrd   biogptTN)r3   r4   r5   r   __annotations__base_model_prefixsupports_gradient_checkpointing_supports_flash_attn_supports_sdpa_supports_flex_attn_can_compile_fullgraphr   r(   r'   r   r   I  s+     &*#N!r(   r   c                   (    e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  dej                  dz  de	dz  de
dz  d	ej                  dz  d
e
dz  de
dz  de
dz  dej                  dz  dee   deez  fd       Z xZS )BioGptModelrd   c           	         t         |   |       || _        |j                  | _        |j                  | _        |j                  | _        |j                  | _	        |j                  rt        j                  |j                        nd}t        |j                  | j                  | j                  |      | _        t!        |j"                  | j                        | _        t'        j(                  t+        |j,                        D cg c]  }t/        ||       c}      | _        t'        j2                  | j                        | _        d| _        | j9                          y c c}w )NrC   )r>   )re   F)r#   r$   rd   	layerdropr   rK   r   r_   pad_token_idr=   scale_embeddingmathsqrtr<   
vocab_sizeembed_tokensr   max_position_embeddingsembed_positionsrV   
ModuleListrangenum_hidden_layersr   r   r   
layer_normgradient_checkpointing	post_init)r%   rd   r>   ir&   s       r'   r$   zBioGptModel.__init__W  s    ))11++!..7=7M7Mdii 2 23SV5t~~t/?/?[
  @@^@^`d`n`nommV[\b\t\tVu$vQR%7!%L$vw,,t~~6&+# %ws   E"NrA   r)   inputs_embedsrs   r   r+   rt   output_hidden_statesreturn_dictru   rL   rv   c                 <   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|	|	n| j                   j                  }	|d u |d uz  rt        d      |$|}|j                  }|j                  d|d         }n-| |j                         d d }|d d d d df   }nt        d      || j                  |      }| j                  r%| j                  r|rt        j                  d       d}|r|t        | j                         }|j                         d d \  }}||j                         nd}|
%t!        j"                  |||z   |j$                        }
|'||z   }t!        j&                  |||j$                        }|}t)        | j                   |||
|	      }||
j+                  d      }| j-                  |||
      }||z   }t.        j0                  j3                  || j2                  | j                        }| j                  r%| j                  r|rt        j                  d       d}|rdnd }|rdnd }d }t5        | j6                        D ]_  \  }}|r||fz  }| j                  r%t!        j8                  g       }|| j:                  k  r? ||f||||||
d|}|d   }|sW||d   fz  }a |r||fz  }| j=                  |      }|	st?        d |||||fD              S tA        |||||      S )NzTYou cannot specify both decoder_input_ids and decoder_inputs_embeds at the same timerN   zEYou have to specify either decoder_input_ids or decoder_inputs_embedsz[`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`...F)rd   r   device)rd   r   r)   ru   rs   )r+   rP   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...r   )r)   rs   rt   r   r+   ru   r   c              3   $   K   | ]  }|| 
 y wr@   r   ).0vs     r'   	<genexpr>z&BioGptModel.forward.<locals>.<genexpr>  s      = s   )last_hidden_staters   rq   
attentionscross_attentions)!rd   rt   r   r   use_return_dictri   ry   rz   rS   r   r   rR   rj   rk   r
   get_seq_lengthr/   aranger   onesr   	unsqueezer   rV   rW   rK   	enumerater   randr   r   r   r   )r%   rA   r)   r   rs   r   r+   rt   r   r   ru   rL   inputinput_shape
batch_size
seq_lengthr*   mask_seq_lengthself_attn_cachecausal_mask	positionsrq   all_hidden_statesall_self_attnsall_cross_attentionsidxdecoder_layerdropout_probabilitylayer_outputss                                r'   r2   zBioGptModel.forwardl  s    2C1N-TXT_T_TqTq$8$D $++JjJj 	 "+!6IDKK<Q<Q	%0%<k$++B]B] -t";<stt"E++K!r;r?;I&',,.s3K!!Q(+Edee  --e4M&&4==##q "	 0*$++>O!.!3!3!5cr!:
JETE`!?!?!Afg!"\\&(>(KTaThThN !4zAO"ZZ
OML`L`aN)(;;'))+
 )33A6L((9O^j(k	%	1--mt||VZVcVc-d&&4==##p "	"6BD0d#"+DKK"8 	6C#!m%55!}}&+jjn#&7)	* /"3#)-	 	M *!,M =#3"55/	64  -!116 ':K^]qr  
 9+++%1
 	
r(   )
NNNNNNNNNN)r3   r4   r5   r   r$   r   r/   r8   r   r	   r   rE   r   r   r   r   r2   r9   r:   s   @r'   r   r   U  s   | *  .23726(,!%04)-,0#'.2D
##d*D
 ))D0D
 ((4/	D

 D
 $;D
 &&-D
  $;D
 #TkD
 D[D
 t+D
 +,D
 
:	:D
 D
r(   r   zR
    BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
    )custom_introc                   v    e Zd ZddiZ fdZd Zd Ze	 	 	 	 	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  d	ej                  dz  d
edz  dej                  dz  dedz  dej                  dz  dedz  dedz  dedz  dej                  dz  deej                  z  dee   deez  fd       Z xZS )BioGptForCausalLMzoutput_projection.weightzbiogpt.embed_tokens.weightc                     t         |   |       t        |      | _        t	        j
                  |j                  |j                  d      | _        | j                          y NFrg   )
r#   r$   r   r   rV   rl   r   r   output_projectionr   r%   rd   r&   s     r'   r$   zBioGptForCausalLM.__init__  sJ     !&)!#6+=+=v?P?PW\!] 	r(   c                     | j                   S r@   r   r%   s    r'   get_output_embeddingsz'BioGptForCausalLM.get_output_embeddings  s    %%%r(   c                     || _         y r@   r   )r%   new_embeddingss     r'   set_output_embeddingsz'BioGptForCausalLM.set_output_embeddings  s
    !/r(   NrA   r)   r   rs   labelsr   r+   rt   r   r   ru   logits_to_keeprL   rv   c                    |
|
n| j                   j                  }
 | j                  |f|||||||	|
|d	|}|d   }t        |t              rt        | d      n|}| j                  |dd|ddf         }d}|* | j                  d||| j                   j                  d|}|
s|f|dd z   }||f|z   S |S t        |||j                  |j                  |j                  |j                        S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        N)	r)   r   rs   r   r+   rt   r   r   ru   r   )logitsr   r   r   )lossr   rs   rq   r   r   r   )rd   r   r   r{   r7   slicer   loss_functionr   r   rs   rq   r   r   )r%   rA   r)   r   rs   r   r   r+   rt   r   r   ru   r   rL   r   rq   slice_indicesr   r   outputs                       r'   r2   zBioGptForCausalLM.forward  s0   . &1%<k$++B]B]$++
)'+%/!5#)
 
  
8B>SV8W~ot4]k''a6I(JK%4%%pVFt{{OeOepiopDY,F)-)9TGf$EvE0#33!//))$55
 	
r(   NNNNNNNNNNNr   )r3   r4   r5   _tied_weights_keysr$   r   r   r   r/   r8   r   r	   r   rE   r7   r   r   r   r   r2   r9   r:   s   @r'   r   r     sM    56RS&0  .23726(,*.!%04)-,0#'.2-.:
##d*:
 ))D0:
 ((4/	:

 :
   4':
 $;:
 &&-:
  $;:
 #Tk:
 D[:
 t+:
 ell*:
 +,:
 
2	2:
 :
r(   r   c                   X    e Zd Z fdZe	 	 	 	 	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  dej                  dz  dedz  dej                  dz  dej                  dz  d	e	dz  d
ej                  dz  de	dz  de	dz  de	dz  dej                  dz  deez  fd       Z xZS )BioGptForTokenClassificationc                 z   t         |   |       |j                  | _        t        |      | _        t        |d      r|j                  |j                  }n|j                  }t        j                  |      | _
        t        j                  |j                  |j                        | _        | j                          y )Nclassifier_dropout)r#   r$   
num_labelsr   r   hasattrr	  r   rV   DropoutrK   rl   r   
classifierr   )r%   rd   r	  r&   s      r'   r$   z%BioGptForTokenClassification.__init__K  s      ++!&)6/0V5N5N5Z!'!:!:!'!;!;zz"45))F$6$68I8IJr(   NrA   token_type_idsr)   rs   r   r   r   r+   rt   r   r   ru   rv   c                    ||n| j                   j                  }| j                  |||||||	|
||
      }|d   }| j                  |      }| j	                  |      }d}|t               }||j                  d      dk(  }|j                  d| j                        }t        j                  ||j                  d      t        j                  |j                        j                  |            } |||      }n2 ||j                  d| j                        |j                  d            }|s|f|dd z   }||f|z   S |S t        |||j                  |j                        S )  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N	rs   r)   r   r   r+   rt   r   r   ru   r   rN   r   r!   )r   r   rq   r   )rd   r   r   rK   r  r   rz   r
  r/   wheretensorignore_indextype_asr   rq   r   )r%   rA   r  r)   rs   r   r   r   r+   rt   r   r   ru   rL   transformer_outputsrq   r   r   loss_fctactive_lossactive_logitsactive_labelsr  s                          r'   r2   z$BioGptForTokenClassification.forwardY  sr   . &1%<k$++B]B]"kk+)'%/!5#) * 
 ,A.]3/')H),11"5: &B @ %R%,,x?T?T2U2]2]^d2e!  }=B @&++b/RY!4QR!88F)-)9TGf$EvE$-;;*55	
 	
r(   )NNNNNNNNNNNN)r3   r4   r5   r$   r   r/   r8   r   r	   r   rE   r   r   r2   r9   r:   s   @r'   r  r  I  s6     .22637(,26*.!%04)-,0#'.2@
##d*@
 ((4/@
 ))D0	@

 @
 ((4/@
   4'@
 $;@
 &&-@
  $;@
 #Tk@
 D[@
 t+@
 
&	&@
 @
r(   r  a  
    The BioGpt Model transformer with a sequence classification head on top (linear layer).

    [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it is required to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    c                   j    e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  de	dz  dej                  dz  dej                  dz  d	e
dz  d
ej                  dz  de
dz  de
dz  de
dz  dej                  dz  deej                  z  deez  fd       Zd Zd Z xZS )BioGptForSequenceClassificationrd   c                     t         |   |       |j                  | _        t        |      | _        t        j                  |j                  | j                  d      | _        | j                          y r   )
r#   r$   r
  r   r   rV   rl   r   scorer   r   s     r'   r$   z(BioGptForSequenceClassification.__init__  sS      ++!&)YYv114??O
 	r(   NrA   r)   rs   r   r   r   r+   rt   r   r   ru   r   rv   c                 `   |
|
n| j                   j                  }
| j                  ||||||||	|
|
      }|d   }t        |t              rt        | d      n|}| j                  |dd|ddf         }||j                  dd \  }}n|j                  dd \  }}| j                   j                  d}n|Vt        j                  || j                   j                        j                  d      dz
  j                  |j                        }n.d}t        j                  | j                   j"                   d       |t        j$                  ||j                        |f   }d}|| j                   j&                  | j(                  dk(  rd	| j                   _        nl| j(                  dkD  rL|j*                  t        j,                  k(  s|j*                  t        j                  k(  rd
| j                   _        nd| j                   _        | j                   j&                  d	k(  rIt/               }| j(                  dk(  r& ||j1                         |j1                               }n |||      }n| j                   j&                  d
k(  r=t3               } ||j5                  d| j(                        |j5                  d            }n,| j                   j&                  dk(  rt7               } |||      }|
s|f|dd z   }||f|z   S |S t9        |||j:                  |j<                  |j>                        S )r  Nr  r   r!   rN   r   z will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`r   
regressionsingle_label_classificationmulti_label_classification)r   r   rs   rq   r   ) rd   r   r   r{   r7   r   r  ry   r   r/   nesumtor   rj   rk   r&   r3   r   problem_typer
  dtyper1   r   squeezer   rz   r   r   rs   rq   r   )r%   rA   r)   rs   r   r   r   r+   rt   r   r   ru   r   rL   r  rq   r  r   r   sequence_lengthpooled_logitsr   r  r  s                           r'   r2   z'BioGptForSequenceClassification.forward  s   . &1%<k$++B]B]"kk+)'%/!5#) * 
 ,A.8B>SV8W~ot4]kM!]A*=>? *3//"1*='J*7*=*=bq*A'J;;##+ O$#(88It{{7O7O#P#T#TUW#X[\#\"`"`aganan"o"$##~~../ 0^ ^
 u||Jv}}M^_{{''/??a'/;DKK,__q(fllejj.HFLL\a\e\eLe/LDKK,/KDKK,{{''<7"9??a'#M$9$9$;V^^=MND#M6:D))-JJ+- 2 22t GUWY))-II,.v6#%(;AB(??F)-)9TGf$EvE/ /??-;;*55
 	
r(   c                 .    | j                   j                  S r@   r   r   r   s    r'   get_input_embeddingsz4BioGptForSequenceClassification.get_input_embeddings  s    {{'''r(   c                 &    || j                   _        y r@   r,  )r%   rI   s     r'   set_input_embeddingsz4BioGptForSequenceClassification.set_input_embeddings  s    #( r(   r  )r3   r4   r5   r   r$   r   r/   r8   r   r	   r   rE   r7   r   r   r2   r-  r/  r9   r:   s   @r'   r  r    sF   |   .237(,26*.!%04)-,0#'.2-.[
##d*[
 ))D0[
 	[

 ((4/[
   4'[
 $;[
 &&-[
  $;[
 #Tk[
 D[[
 t+[
 ell*[
 
1	1[
 [
z()r(   r  )r   r  r  r   r   )Nrx   )9r   collections.abcr   r/   torch.nnrV   r   r   r   activationsr   cache_utilsr	   r
   r   
generationr   masking_utilsr   modeling_flash_attention_utilsr   modeling_layersr   modeling_outputsr   r   r   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   configuration_biogptr   
get_loggerr3   rj   	Embeddingr   r<   ModulerE   rD   r\   r^   r   r   r   r   r  r  __all__r   r(   r'   <module>rA     s  *  $   A A ! C C ) / B 9  G & @ @ . 
		H	%;r|| ;8
= 
=& !%II%<<% 
% <<	%
 LL4'% T\% % '(%8z)bii z)zS3 Sl "O " " [
' [
 [
| 
M
- M

M
` P
#8 P
 P
f l)&; l)l)^r(   