
    qi~                        d Z ddlZddlmZmZ ddlmZ ddlZddlm	Z	 ddl
m	c mZ ddlmZ ddlmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZmZmZ ddlmZ ddlm Z  ddl!m"Z"  ejF                  e$      Z%ee G d de                    Z&ee G d de                    Z' G d de	jP                        Z) G d de	jP                        Z* G d de      Z+ G d de	jP                        Z, G d de	jP                        Z- G d d e	jP                        Z.e G d! d"e             Z/e G d# d$e/             Z0 G d% d&e/      Z1g d'Z2y)(zPyTorch TimesFM model.    N)CallableSequence)	dataclass   )initialization)FlashAttentionKwargs)BaseModelOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)auto_docstringcan_return_tuplelogging   )LlamaRMSNorm)simple_eager_attention_forward   )TimesFmConfigc                   b    e Zd ZU dZdZej                  dz  ed<   dZej                  dz  ed<   y)TimesFmOutputz
    loc (`torch.Tensor` of shape `(batch_size, )`):
        The mean of the time series inputs.
    scale (`torch.Tensor` of shape `(batch_size,)`):
        The scale of the time series inputs.
    Nlocscale)	__name__
__module____qualname____doc__r   torchTensor__annotations__r        ]/opt/pipecat/venv/lib/python3.12/site-packages/transformers/models/timesfm/modular_timesfm.pyr   r   &   s/      $C	#!%E5<<$%r!   r   c                       e Zd ZU dZdZej                  dz  ed<   dZej                  dz  ed<   dZ	ej                  e
z  dz  ed<   y)TimesFmOutputForPredictiona  
    mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
        The mean predictions of the time series.
    full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
        The full predictions of the time series including the mean and the quantiles.
    loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
        The loss of the TimesFM model.
    Nmean_predictionsfull_predictionsloss)r   r   r   r   r%   r   r   r   r&   r'   floatr    r!   r"   r$   r$   4   sI     -1ellT)0,0ellT)0(,D%,,

%,r!   r$   c                   0     e Zd ZdZdef fdZddZ xZS )
TimesFmMLPzPax MLP in pytorch.configc                     t         |           |j                  }|j                  }t	        j
                  ||      | _        t	        j
                  ||      | _        t	        j                  |d      | _	        y )Ngư>)normalized_shapeeps)
super__init__hidden_sizeintermediate_sizennLinear	gate_proj	down_proj	LayerNorm
layer_norm)selfr+   r1   r2   	__class__s       r"   r0   zTimesFmMLP.__init__H   s]    (("44;0AB#4kB,,Nr!   c                     | j                  |      }| j                  |      }t        j                  |      }| j	                  |      }||d|d d d d d f   z
  z  }||z   S )N      ?)r8   r5   Frelur6   )r9   xpaddingsgate_inpgateoutputss         r"   forwardzTimesFmMLP.forwardQ   sc    ??1%~~h'vvd|..&x1d
';!;<G{r!   Nr   r   r   r   r   r0   rD   __classcell__r:   s   @r"   r*   r*   E   s    O} Or!   r*   c                   (     e Zd ZdZ fdZd Z xZS )TimesFmResidualBlockzTimesFM residual block.c                     t         |           || _        || _        || _        t        j                  ||      | _        t        j                         | _	        t        j                  ||      | _
        t        j                  ||      | _        y rE   )r/   r0   
input_dimshidden_dimsoutput_dimsr3   r4   input_layerSiLU
activationoutput_layerresidual_layer)r9   rL   rM   rN   r:   s       r"   r0   zTimesFmResidualBlock.__init__^   sk    $&&99Z='')IIk;? ii
K@r!   c                     | j                  |      }| j                  |      }| j                  |      }| j                  |      }||z   S rE   )rO   rQ   rR   rS   )r9   r?   hiddenoutputresiduals        r"   rD   zTimesFmResidualBlock.forwardi   sK    !!!$(""6*&&q)  r!   )r   r   r   r   r0   rD   rG   rH   s   @r"   rJ   rJ   [   s    !	A!r!   rJ   c                       e Zd Zy)TimesFmRMSNormN)r   r   r   r    r!   r"   rY   rY   q   s    r!   rY   c                   0     e Zd ZdZdef fdZddZ xZS )TimesFmPositionalEmbeddingz6Generates position embedding for a given 1-d sequence.r+   c           
         t         |           |j                  }|j                  }||c| _        | _        |j                  | _        | j
                  dz  }t        j                  t        |      t        |      z        t        |dz
  d      z  }| j                  d|t        j                  t        j                  |t        j                        | z        z         y )Nr   r   inv_timescalesdtype)r/   r0   min_timescalemax_timescaler1   embedding_dimsmathlogr(   maxregister_bufferr   exparangefloat32)r9   r+   r`   ra   num_timescaleslog_timescale_incrementr:   s         r"   r0   z#TimesFmPositionalEmbedding.__init__x   s    ,,,,1>.D.$00,,1"&((5+?%BV+V"WZ]^lop^prsZt"tEIIell>&W[rZr&rss	
r!   c                 N   ||t        d      |Jt        j                  |t        j                  | j                  j
                        j                  d      }n'|j                  dk7  rt        d|j                          |j                  g |j                  d | j                  j                  ddd      z  }t        j                  t        j                  |      t        j                  |      gd      }t        j                  |ddd| j                  dz  f      }|S )	a  Generates a Tensor of sinusoids with different frequencies.

        Args:
            seq_length: an optional Python int defining the output sequence length.
              if the `position` argument is specified.
            position: [B, seq_length], optional position for each token in the
              sequence, only required when the sequence is packed.

        Returns:
            [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
        z.Either position or seq_length must be providedr_   devicer   r   z*position must be 2-dimensional, got shape r   dim)
ValueErrorr   rh   ri   r]   rn   	unsqueezendimshapeviewcatsincosr=   padrb   )r9   
seq_lengthpositionscaled_timesignals        r"   rD   z"TimesFmPositionalEmbedding.forward   s     
 2MNN||JemmDL_L_LfLfgqqrstH]]aI(..IYZ[[#hmm7X^^7Q7$:M:M:R:RSTVWY[:\\EIIk2EIIk4JKQRS v1a)<)<q)@ABr!   NNrF   rH   s   @r"   r[   r[   u   s    @
} 
r!   r[   c                        e Zd ZdZdedef fdZdej                  dej                  fdZ		 dd	ej                  d
ej                  dz  de
e   deej                  ej                  dz  f   fdZ xZS )TimesFmAttentionzlImplements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query.r+   	layer_idxc                    t         |           || _        d| _        |j                  | _        || _        |j                  | _        |j                  | _        |j                  | _	        | j                  | j                  z  | _
        | j                  | j                  z  | _        t        j                  t        j                  | j                  f            | _        t        j"                  | j                  | j                  | j                  z        | _        t        j"                  | j                  | j                  | j                  z        | _        t        j"                  | j                  | j                  | j                  z        | _        t        j"                  | j                  | j                  z  | j                        | _        y )NT)r/   r0   r+   	is_causalattention_dropoutr   num_attention_heads	num_headsr1   head_dimq_sizekv_sizer3   	Parameterr   emptyscalingr4   q_projk_projv_projo_projr9   r+   r   r:   s      r"   r0   zTimesFmAttention.__init__   s3   !'!9!9"33!--nnt}}4~~5||EKK0@$ABii 0 0$..4==2PQii 0 0$..4==2PQii 0 0$..4==2PQii >@P@PQr!   queryreturnc                     t        j                  | j                        j                  dt	        j
                  | j                        z        }||d d d d d f   z  S )Ng^$3eG?)r=   softplusr   mulrc   sqrtr   )r9   r   r   s      r"   _scale_queryzTimesFmAttention._scale_query   sJ    

4<<(,,[499T]];S-STuT4q0111r!   Nhidden_statesattention_maskkwargsc                    |j                   d d }g |d| j                  }| j                  |      j                  |      j	                  dd      }| j                  |      }| j                  |      j                  |      j	                  dd      }| j                  |      j                  |      j	                  dd      }t        j                  | j                  j                  t              }	 |	| ||||f| j                  sdn| j                  dd|\  }
} |
j                  g |d j!                         }
| j#                  |
      }
|
|fS )Nro   r   r           r<   )dropoutr   )ru   r   r   rv   	transposer   r   r   r
   get_interfacer+   _attn_implementationr   trainingr   reshape
contiguousr   )r9   r   r   r   input_shapehidden_shapequery_states
key_statesvalue_statesattention_interfaceattn_outputattn_weightss               r"   rD   zTimesFmAttention.forward   sW    $))#2.88b8$--8{{=166|DNNqRST((6[[/44\BLLQPQR
{{=166|DNNqRST(?(M(MKK,,.L)
 %8	%
  $}}C$2H2H	%
 	%
!\ *k));;;;FFHkk+.L((r!   rE   )r   r   r   r   r   intr0   r   r   r   r   r   tuplerD   rG   rH   s   @r"   r   r      s    vR} R R(2%,, 25<< 2 /3)||) t+) -.	)
 
u||U\\D00	1)r!   r   c                        e Zd ZdZdedef fdZ	 ddej                  dej                  dej                  de	d	e
ej                  d
z  ej                  f   f
dZ xZS )TimesFmDecoderLayerzTransformer layer.r+   r   c                     t         |           t        ||      | _        t	        |      | _        t        |j                  |j                        | _	        y )N)r   )r.   )
r/   r0   r   	self_attnr*   mlprY   r1   rms_norm_epsinput_layernormr   s      r"   r0   zTimesFmDecoderLayer.__init__   sC    )&IFf%-f.@.@fFYFYZr!   r   r   r@   output_attentionsr   Nc                     |}| j                  |      }| j                  |||      \  }}||z   }| j                  ||      }||fS )N)r   r   r   )r@   )r   r   r   )r9   r   r   r@   r   rW   scoress          r"   rD   zTimesFmDecoderLayer.forward   se     !,,]; $')/ !/ !
v
 !=0 B}$$r!   )F)r   r   r   r   r   r   r0   r   r   boolr   rD   rG   rH   s   @r"   r   r      sy    [} [ [ #(%||% % ,,	%
  % 
u||d"ELL0	1%r!   r   c                   h     e Zd ZU eed<   dZdgZdZdZdZ	 e
j                          fd       Z xZS )TimesFmPreTrainedModelr+   timesfmr   past_values)timeTc           
      "   t         |   |       t        |t              r t	        j
                  |j                         y t        |t              r|j                  dz  }|j                  |j                  }}t        j                  t        |      t        |      z        t        |dz
  d      z  }t	        j                  |j                   |t#        j$                  t#        j&                  |t"        j(                        | z        z         y y )Nr   r   r^   )r/   _init_weights
isinstancer   initones_r   r[   rb   ra   r`   rc   rd   r(   re   copy_r]   r   rg   rh   ri   )r9   modulerj   ra   r`   rk   r:   s         r"   r   z$TimesFmPreTrainedModel._init_weights	  s    f%f./JJv~~& :;#22a7N+1+?+?AUAU=M&*hhu]/CeMFZ/Z&[^a"A_ '# JJ%%))ELLu}}MQhPhhij <r!   )r   r   r   r   r   base_model_prefix_no_split_modulesmain_input_nameinput_modalities_supports_sdpar   no_gradr   rG   rH   s   @r"   r   r      sB    !./#O NU]]_ r!   r   c                       e Zd Zdef fdZdej                  dej                  deej                  eej                  ej                  f   f   fdZe	e
	 	 ddej                  dej                  d	ej                  d
ededefd              Ze	 ddej                  dz  dedej"                  dej$                  dedej                  dz  fd       Zedej                  dej                  deej                  ej                  f   fd       Zedej                  dej                  dej                  fd       Z xZS )TimesFmModelr+   c           	         t         |   |       || _        t        d|j                  z  |j
                  |j                        | _        t        j                  |j                  |j
                        | _        t        j                  t        |j                        D cg c]  }t        ||       c}      | _        | j                  j"                  rt%        |      | _        | j)                          y c c}w )Nr   rL   rN   rM   )num_embeddingsembedding_dim)r+   )r/   r0   r+   rJ   patch_lengthr1   r2   input_ff_layerr3   	Embedding	freq_sizefreq_emb
ModuleListrangenum_hidden_layersr   layersuse_positional_embeddingr[   position_emb	post_initr   s      r"   r0   zTimesFmModel.__init__  s     26...**00

 F4D4DTZTfTfgmmEJ6KcKcEde	 3e
 ;;// :& ID 	 fs   "C9inputspatched_padsr   c                    | j                  ||      \  }}t        j                  || j                  j                        }||ddddf   z
  |ddddf   z  }t        j
                  t        j                  || j                  j                  z
        | j                  j                  k  t        j                  | j                  j                  |j                  |j                        |      }|||ffS )zInput is of shape [B, N, P].minNrm   )_timesfm_masked_mean_stdr   clampr+   	tolerancewhereabspad_valtensorr_   rn   )r9   r   r   musigmarC   s         r"   _forward_transformzTimesFmModel._forward_transform1  s     11&,G	EEt{{'<'<= Bq$}--q$}1EE++IIft{{2223dkk6K6KKLL,,GMM'..Y

 U##r!   r   past_values_paddingfreqr   output_hidden_statesc                    |j                   d   }|j                  |d| j                  j                        }|j                  |d| j                  j                        }	t	        j
                  t	        j                  |	dz
        | j                  j                  k  t	        j                  d|j                  |j                        |      }t	        j
                  t	        j                  || j                  j                  z
        | j                  j                  k  t	        j                  d|	j                  |	j                        |	      }	| j                  ||	      \  }}
|d|	z
  z  }t	        j                  ||	gd      }| j                  |      }t	        j                  |	d      d   }| j                  j                   r]| j#                  |j                   d         }t	        j$                  |g|j                   d   z  d      }| j'                  ||      }||z  }| j)                  |      }||z  }|}| j+                  ||j                   d   |j                  |j                  d	      }g }g }| j,                  d
| j                  j.                   D ]8  } |||||      \  }}|r|j1                  |       |s(|j1                  |       : |r|g|z   }nd
}t3        |||r|nd
|
d   |
d         S )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Past values of the time series that serves as input to the model.
        past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The padding indicator of the time series.
        freq (`torch.LongTensor` of shape `(batch_size,)`):
            Frequency indices for the time series data.
        r   ro   r<   r   rm   rp   r   T)r   sequence_lengthr_   rn   r   N)r   r   r@   r   )last_hidden_stater   
attentionsr   r   )ru   rv   r+   r   r   r   r   r   r   r_   rn   r   r   rw   r   r   r   r   concat_timesfm_shift_padded_seqr   _prepare_4d_attention_maskr   r   appendr   )r9   r   r   r   r   r   r   bsizepatched_inputsr   statsconcat_inputsmodel_inputpatched_paddingpos_embf_embr   r   all_attentionsall_hidden_stateslayerr   s                         r"   rD   zTimesFmModel.forwardA  s   ( !!!$$))%T[[5M5MN*//r4;;;S;STIIlS()DKK,A,AALLN$8$8AVAVW

 {{IInt{{':'::;dkk>S>SSLLL$6$6|?R?RS

 !% 7 7 U (3+=>		><"@bI))-8  ))Lb9!<;;//''(9(9!(<=GllG9{/@/@/C#CKG44_gNG7"Kd#u $88*)//2%% '' 9 
 [[!@4;;#@#@A 
	8E$)+-("3	%!FM !%%f-#!((7
	8  !,0A A $++):~a(
 	
r!   r   Nr   r_   rn   r   c                    |j                   rt        j                  |      j                  nt        j                  |      j                  }| &| j                  | j                  d   ddd      } | |z  } |rbt        j                  t        j                  ||f||      |z  d      }|j                  dd||      }| t        j                  | |      } | S |} | S )a  
        Creates 4D attention mask and combines causal and padding masks if needed.

        Args:
            attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
            sequence_length: Length of the sequence
            dtype: Data type of the mask
            device: Device of the mask
            is_causal: Whether to apply causal masking

        Returns:
            4D attention mask of shape (batch_size, 1, seq_length, seq_length)
        r   r   ro   rm   )diagonal)
is_floating_pointr   finfor   iinforv   ru   triuonesminimum)r   r   r_   rn   r   	min_valuecausal_masks          r"   r   z'TimesFmModel._prepare_4d_attention_mask  s    , /4.E.EEKK&**5;;W\K]KaKa	 %+001E1Ea1H!QPRSN+i7N **

O_=USYZ]ffK &**1a/RK )!&~{!K  "-r!   paddingc                 F   dt         j                  fd}t        j                  d|z
  d      } ||      }t        j                  | j                  d         }| ||ddf   }|||ddf   }d|z
  }t        j                  |d      }	t        j
                  |	d	      }	t        j                  ||z  d      }
|
|	z  }||j                  d
      z
  |z  }t        j                  |dz  d      |	z  }t        j
                  |d	      }t        j                  |      }||fS )a  Calculates mean and standard deviation of `inputs` across axis 1.

        It excludes values where `padding` is 1.

        Args:
            inputs: A PyTorch tensor of shape [b, n, p].
            padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.

        Returns:
            A tuple containing the mean and standard deviation.
            We return the statistics of the first patch with more than three non-padded values.
        arrc                 (   t        j                  | dk\  j                  t         j                        d      }| dk\  j                  t         j                        j	                  d      }t        j
                  |dk(  | j                  d   dz
  |      S )Nr   r   rp   r   )r   argmaxtoint32sumr   ru   )r  indicesrow_sums      r"   _get_patch_indexz?TimesFmModel._timesfm_masked_mean_std.<locals>._get_patch_index  sk    llC1H==#=1EGaxmmEKK0444;G;;w!|SYYq\A-=wGGr!   r   r   rp   r   Nr<   r   ro   r   )r   r   r  rh   ru   r   rs   r   )r   r  r  pad_sumpatch_indicesbidxsr  rz   masknum_valid_elements
masked_summasked_meanmasked_centered_arr
masked_var
masked_stds                  r"   r   z%TimesFmModel._timesfm_masked_mean_std  s    	H%,, 	H
 ))AKQ/(1V\\!_-UM1,-e]A-. 3w #YYt3"[[);E YYsTzq1
 #55  #[%:%:2%>>$FYY2A51=@RR
[[5
ZZ
+
J&&r!   r  seqc                    |j                   \  }}}| dk(  }|j                  t        j                        j	                  d      }d||j                  d       <   t        j                  ||j                        j                  ddd      j                  |d|      }||ddddf   z
  |z  }|j                  d|      }	|	S )zShifts rows of seq based on the first 0 in each row of the mask.

        Args:
            mask: mask tensor of shape [B, N]
            seq: seq tensor of shape [B, N, P]

        Returns:
            The shifted sequence.
        r   r   rp   ro   )rn   N)ru   r  r   r  r  anyrh   rn   rv   expandgather)
r  r#  
batch_sizenum_seqfeature_dimnew_maskr  	idx_rangeshifted_idxshifted_seqs
             r"   r   z&TimesFmModel._timesfm_shift_padded_seq  s     ,/99(
G[%)QY ++ekk*11a18 )+!$$% LL<AA!RKRRS]_acno	 !71dD=#99WD jjK0r!   )FF)T)r   r   r   r   r0   r   r   r   r   r   r   
LongTensorr   r   rD   staticmethodr   r_   rn   r   r   r   rG   rH   s   @r"   r   r     s   } &$ll$27,,$	u||U5<<#=>>	?$   #(%*V
\\V
 #--V
 ll	V

  V
 #V
 
V
  V
p  +t+++ {{+ 	+
 + 
	+ +Z ,' ,' ,'QVW\WcWcejeqeqWqQr ,' ,'\  5<< ELL  r!   r   c                   `    e Zd ZdZdef fdZ	 ddeej                     dee	   dz  de	dz  de
ej                  d	f   fd
Zdej                  de
ej                  ej                  f   dej                  fdZdej                  dej                  dej                  fdZee	 	 	 	 	 	 	 	 ddeej                     deej                  e	z     dz  de	dz  dej                  dz  de	dz  dedededz  dedz  defd              Zedej                  de	deej                     fd       Z xZS )TimesFmModelForPredictionz/TimesFM model for quantile and mean prediction.r+   c                 J   t         |   |       || _        |j                  | _        |j
                  | _        t        |      | _        t        |j                  |j
                  dt        |j                        z   z  |j                        | _        | j                          y )Nr   r   )r/   r0   r+   context_lengthcontext_lenhorizon_lengthhorizon_lenr   decoderrJ   r1   len	quantilesr2   horizon_ff_layerr   )r9   r+   r:   s     r"   r0   z"TimesFmModelForPrediction.__init__  s     !00!00#F+ !5))--S9I9I5J1JK00!
 	r!   Nr   r   r5  r   .c                 X   || j                   }g g }}|D ]  }|j                  d   }t        j                  || j                  z   |j
                  |j                        }||k  r||z
  }	t        j                  t        j                  |	|j
                  |j                        |gd      }t        j                  t        j                  |	|j
                  |j                        |gd      }n||kD  r|| d }||| j                  z    d }|j                  |       |j                  |        t        j                  |d      t        j                  |d      f}
|E|
t        j                  |dt        |       t        j                        j                  dd      fz   }
|
S )a  Pad/truncate input time series to `context_len` and build a padding mask.

        Args:
            inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task.
            freq: Optional list of frequencies (returned as a tensor when provided).
            context_len: Optional context length override (defaults to `self.context_len`).

        Returns:
            Tuple of (padded_inputs, padding_mask) and optionally a freq tensor.
        Nr   rm   rp   r^   ro   r   )r5  ru   r   zerosr7  r_   rn   rw   r
  r   stackr   r9  r  r   )r9   r   r   r5  input_tsinput_paddingts	input_lenr  num_front_padresults              r"   _preprocessz%TimesFmModelForPrediction._preprocess/  s    **K"$b- 	*BIkk)d.>.>">bhhWYW`W`aG;& +i 7YYMRTR[R[ \^`aghi))UZZRXXV]VdVd%egn$ouvw[(&!K$2B2B$B"C"EFOOB  )	* ++hA.Mq0QRu||D3v;,?u{{S[[\^`abddFr!   model_outputr   c                    | j                  |      }|j                  \  }}}|j                  ||| j                  j                  t        | j                  j                        dz         }|\  }}||dddddf   z  |dddddf   z   S )z*Postprocess output of stacked transformer.r   N)r;  ru   rv   r+   r6  r9  r:  )	r9   rF  r   	output_tsbn_r   r   s	            r"   _postprocess_outputz-TimesFmModelForPrediction._postprocess_outputT  s     )),7	 //1aNN1a)C)CSI^I^E_bcEcd		E5D$!4551dD$;N8OOOr!   predictionstargetsc                 *   g }t        | j                  j                        D ]M  \  }}||d|f   z
  }t        j                  |dz
  |z  ||z        }|j                  |j                                O t        j                  |      j                         S )N.r   )	enumerater+   r:  r   re   r   meanr>  )r9   rM  rN  lossesiqerrorsr'   s           r"   _quantile_lossz(TimesFmModelForPrediction._quantile_lossc  s    dkk334 	'DAq{3622F99a!ev-q6z:DMM$))+&	' {{6"''))r!   r   window_sizefuture_valuesforecast_context_lenreturn_forecast_on_contexttruncate_negativer   r   c
           
      n	   || j                   }n|}|d   j                  }|D cg c]  }|| d 
 }}t        j                  t        j                  |D cg c]  }t        j                  |       c}            }|Yg }g }t        |      D ]A  \  }}|j                  | j                  ||             |*|j                  ||   gdz         C |}||}|$t        j                  d       dgt        |      z  }|| j                  j                  }|	| j                  j                  }	| j                  ||      \  }}}|j                  |      }|j                  |      }|j                  |      }|}|j                   d   }g }|j                   d   |j                   d   | j"                  z   k7  r8t%        d|j                   d    d|j                   d    d| j"                         | j                  j&                  }| j"                  |z   dz
  |z  }t)        |      D ]/  }|ddd|j                   d   f   }|dd| df   }|dd| df   }| j+                  |||||		      }| j-                  |j.                  |j0                  |j2                  f      }|rl|dk(  rg|dddd
d| j                  j4                  ddf   }|j7                  |j9                  d      d
|j9                  d            }|j;                  |       |ddd
d|df   } |ddd
d|ddf   }|j;                  |       t        j<                  || gd
      }2 |rHt        j<                  |d      ddd|| j                  j4                  z
  | j"                  z   ddf   }n-t        j<                  |d      ddd| j"                  ddf   }|dddddf   }!|*|!ddddf   |!ddddf   z   }!|ddddf   |ddddf   z   }|dk\  r.|r,t        j>                  |!d      }!t        j>                  |d      }d}"|9tA        jB                  |!|      }#| jE                  |ddddddf   |      }$|#|$z   }"tG        j.                  |r|jH                  nd|	r|jJ                  nd|!||"      S c c}w c c}w )aa  
        past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Past values of the time series that serves as input to the model.
        freq (`torch.LongTensor` of shape `(batch_size,)`):
            Frequency indices for the time series data.
        window_size (`int`, *optional*):
            Window size of trend + residual decomposition. If None then we do not do decomposition.
        future_values (`torch.Tensor`, *optional*):
            Optional future time series values to be used for loss computation.
        forecast_context_len (`int`, *optional*):
            Optional max context length.
        return_forecast_on_context (`bool`, *optional*):
            True to return the forecast on the context when available, i.e. after the first input patch.
        truncate_negative (`bool`, *optional*):
            Truncate to only non-negative values if any of the contexts have non-negative values,
            otherwise do nothing.
        output_attentions (`bool`, *optional*):
            Whether to output the attentions.
        output_hidden_states (`bool`, *optional*):
            Whether to output the hidden states.

        Example:

        ```python
        >>> from transformers import TimesFmModelForPrediction

        >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")

        >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
        >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)

        >>> # Generate
        >>> with torch.no_grad():
        >>>     outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
        >>>     point_forecast_conv = outputs.mean_predictions
        >>>     quantile_forecast_conv = outputs.full_predictions
        ```
        Nr   r   z6No frequency provided via `freq`. Default to high (0).r   z=Length of paddings must match length of input + horizon_len: z != z + )r   r   r   r   r   ro   r   )axis.r   )r   r   r   r%   r&   r'   )&r5  rn   r   r   r>  rP  extend_timesfm_moving_averageloggerinfor9  r+   r   r   rE  r  ru   r7  rr   r6  r   r8  rL  r   r   r   r   r   sizer   concatenatemaximumr=   mse_lossrV  r$   r   r   )%r9   r   r   rW  rX  rY  rZ  r[  r   r   r   fcontext_lenrn   rA  r   inp_min
new_inputs	new_freqsrS  r?  r@  inp_freq	final_outr5  full_outputsoutput_patch_lennum_decode_patches
step_indexcurrent_paddingdecoder_outputfprop_outputsnew_full_tsnew_tsmean_outputsr'   re  quantile_losss%                                        r"   rD   z!TimesFmModelForPrediction.forwardk  s-   j  '++L/L Q&& 0;;"l]^$;;))EKK(H22(HIJ"JI"6* 42!!$">">r;"OP#$$d1gY]34  F <KKPQ3V$D$ $ = ='#';;#C#C ,0,<,<VT,J)-;;v&%((0;;v&	ooa(q!Y__Q%7$:J:J%JJ!''*+4	0B/C3tGWGWFXZ   ;;55"..1AAAEJZZ 23 	HJ+Aq9??13E/E,EFO \MN!23H+A}~,=>M!\\$$1"3%9 * N !4400##^%9%9:M
 *jAo ,Ass4Ndkk6N6N4NPQ,QR)11+2B2B12Er;K[K[\]K^_##K0 #1b*;+;*;Q#>?F'2/@0@/@!(CDK,))9f*=BGI=	H@ & ,,\BPkDKK$<$<<t?O?OOPRSSL
 !,,\B1a$JZJZFZ\]C]^L#Aq!G,"'1c	2\!$Q$)5LLL'1c	2\!$Q$)5LLLa<- ==s;L ==s;L$zz,>H //Q12X0FVMm+D),>>4E~004:N.66TX))
 	
A <(Hs   R-R2r  c                 4   t        j                  | |dz
  dfdd      }t        j                  || j                  | j
                        |z  }t        j                  |j                  ddd      |j                  ddd            j                         }|| |z
  gS )zCCalculates the moving average using PyTorch's convolution function.r   r   constantrm   ro   )	r=   rz   r   r
  r_   rn   conv1drv   squeeze)r  rW  
arr_paddedkernelsmoothed_arrs        r"   r_  z1TimesFmModelForPrediction._timesfm_moving_average  s     UU3q! 4j!D
KsyyL{Zxx
1b 96;;q!R;PQYY[cL011r!   r   )NNNNFFNN)r   r   r   r   r   r0   r   r   r   r   r   rE  rL  rV  r   r   r   r$   rD   r0  listr_  rG   rH   s   @r"   r2  r2    s   9} ( lp#u||,#4<SMD4H#^adh^h#	u||S 	!#JP!LLP16u||U\\7Q1RP	P*%,, * *RWR^R^ *  59"&-1+/+0"')-,0c
ell+c
 u||c)*T1c
 4Z	c

 ||d*c
 "Djc
 %)c
  c
  $;c
 #Tkc
 
$c
  c
J 2U\\ 2 2U\\HZ 2 2r!   r2  )r2  r   r   )3r   rc   collections.abcr   r   dataclassesr   r   torch.nnr3   torch.nn.functional
functionalr=    r   r   modeling_flash_attention_utilsr   modeling_outputsr	   modeling_utilsr
   r   processing_utilsr   utilsr   r   r   llama.modeling_llamar   (phi4_multimodal.modeling_phi4_multimodalr   configuration_timesfmr   
get_loggerr   r`  r   r$   Moduler*   rJ   rY   r[   r   r   r   r   r2  __all__r    r!   r"   <module>r     sI     . !     & B / F & > > / U 0 
		H	% 	&O 	&  	& - -  - ,!299 !,	\ 	+ +\9)ryy 9)x%")) %B _  6 y) y yxB2 6 B2J Rr!   