
    qi                         d dl mZ d dlmZ d dlmZmZmZmZ  e       rddl	Z	ddl
mZ ddlZddlmZ ddl
mZ  ej$                  e      Zdad Zdee   dz  d	e	j2                  d
efdZe	j8                  Z e	j<                  e      j>                  Z  e	j<                  e      jB                  Z"ejF                  dejH                  fd       Z%d,de	jL                  ded
e'e	jL                  e	jL                  f   fdZ(ejF                  dejH                  dejH                  dejH                  dejH                  fd       Z)ejF                  dejH                  dejH                  dejH                  dejH                  fd       Z*e	jV                  fde	jL                  de	jL                  de	jL                  de	jL                  dee   d	e	j2                  d
e	jL                  fdZ,e	jV                  fde	jL                  de	jL                  de	jL                  de	jL                  dee   d	e	j2                  d
e	jL                  fdZ-e	j\                  de	jV                  fde	jL                  de	jL                  de	jL                  de	jL                  de'eef   dz  d	e	j2                  d
e	jL                  fd        Z/ G d! d"ej`                        Z1d# Z2 G d$ d%ejf                        Z4	 d-d&ee5   dz  fd'Z6 G d( d)e      Z7 G d* d+e      Z8y).   )ConversionOps)should_convert_module)is_kernels_availableis_torch_accelerator_availableis_torch_availablelogging    N)
functionalc                      t         	 ddlm}   | d      a t         rt         S dS # t        $ r%}t        j                  d| d       da Y d}~7d}~ww xY w)zALazily load the CUTLASS quantization kernel from HuggingFace Hub.N   )
get_kernelzRedHatAI/quantizationz,Failed to load CUTLASS quantization kernel: . Falling back to Triton.F)_quantization_kernelhub_kernelsr   	Exceptionloggerwarning_once)r   es     [/opt/pipecat/venv/lib/python3.12/site-packages/transformers/integrations/finegrained_fp8.py_get_quantization_kernelr   "   s_     #	)/#-.E#F  $8ATA  	)"NqcQj kl#( 	)s   $ 	AAA
block_sizeoutput_dtypereturnc                    t               r(t        j                  j                         r
t	               sy|t        j
                  t        j                  fvry| yt        |       dk7  s| d   dk7  s| d   dk7  ryt        j                  j                         }|d   dz  |d   z   }t               }|y	 |j                  |      S # t        $ r Y yw xY w)a;  
    Check if CUTLASS blockwise FP8 matmul is supported for the given block size and output dtype.

    CUTLASS blockwise kernels require:
    - SM90+ (Hopper or newer)
    - Block size [128, 128] for weights
    - Block size [1, 128] for activations (handled implicitly)
    - Output dtype bfloat16 or float16
    Fr   r	      r   
   )r   torchcudais_availabler   bfloat16float16lenget_device_capabilityr   $cutlass_scaled_mm_supports_block_fp8r   )r   r   
capabilitycuda_capabilitykernels        r   _supports_cutlassr(   0   s     uzz'>'>'@H\H^ ENNEMM:: 
:!z!}3z!}7K 113J mb(:a=8O &'F~::?KK s   3C 	CC
BLOCK_SIZEc                    t        j                  d      }||z  t        j                  d|      z   }t        j                  | |z         j	                  t         j
                        }t        j                  t        j                  |            dz  }||z  }|j	                  |j                  j                        }t        j                  ||z   |       t        j                  ||z   |       y )Nr	   axisg      |@)tl
program_idarangeloadtofloat32maxabsdtype
element_tystore)	x_ptry_ptrs_ptrr)   pidoffsxsys	            r   act_quant_kernelr@   ]   s    
--Q
Cbii:66D
  ,A
rvvayE!A	AA	U[[##$AHHUT\1HHUS[!    r=   c                 f     j                         sJ  j                  d   |z  dk(  sJ t        j                   t        j                        }  j
                  g  j                         d d  j                  d      |z  dt        j                  i} fd}t        |    |||       ||fS )Nr	   r5   r5   c                 T    t        j                  j                         | d         fS )Nr)   )tritoncdivnumel)metar=   s    r   gridzact_quant.<locals>.grido   s"    AGGItL'9:<<rA   )r)   )	is_contiguousshaper   
empty_likefloat8_e4m3fn	new_emptysizer2   r@   )r=   r   r?   r>   rJ   s   `    r   	act_quantrQ   i   s    ??772;#q(((%"5"56ARQVVXcr]RAFF2J*$<REMMRA= T1az:a4KrA   BLOCK_SIZE_MBLOCK_SIZE_NBLOCK_SIZE_KGROUP_SIZE_Mc                    t        j                  d      }t        j                  ||      }t        j                  ||      }||z  }||z  }||z  }t        ||z
  |      }|||z  z   }||z  |z  } ||z  t        j                  d|      z   |z  }!| |z  t        j                  d|      z   |z  }"t        j                  d|      }#| |!dddf   |
z  |#dddf   |z  z   z   }$||#dddf   |z  |"dddf   |z  z   z   }%||!|z  z   }&|"|z  }'||'|z  z   }(t        j
                  ||ft         j                        })t        dt        j                  ||            D ]  }*t        j                  |$|#dddf   ||*|z  z
  k  d      }+t        j                  |%|#dddf   ||*|z  z
  k  d      },|*|z  }-|-|	z  }.t        j                  |&|.|z  z         }/t        j                  |(|.|z  z         }0|)t        j                  |+|,      |/dddf   z  |0dddf   z  z  })|$||z  z  }$|%||z  z  }% |j                  j                  t         j                  k(  r |)j                  t         j                        }1nf|j                  j                  t         j                  k(  r |)j                  t         j                        }1n|)j                  t         j                        }1||z  t        j                  d|      z   }2| |z  t        j                  d|      z   }3|||2dddf   z  z   ||3dddf   z  z   }4|2dddf   |k  |3dddf   |k  z  }5t        j                  |4|1|5       y)zTriton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and
    store the result in output tensor `C`.
    r	   r+   NrD           maskotherrY   )r-   r.   rG   minr/   zerosr2   ranger0   dotr5   r6   r    r1   r!   r7   )6ABCAsBsMNKgroup_ngroup_k	stride_am	stride_ak	stride_bk	stride_bn	stride_cm	stride_cnstride_As_mstride_As_kstride_Bs_kstride_Bs_nrR   rS   rT   rU   r;   	num_pid_m	num_pid_nnum_pid_in_groupgroup_idfirst_pid_mgroup_size_mpid_mpid_noffs_amoffs_bnoffs_ka_ptrsb_ptrsAs_ptrsoffs_bsnBs_ptrsaccumulatorkabk_startoffs_ksa_sb_scoffs_cmoffs_cnc_ptrsc_masks6                                                         r   _w8a8_block_fp8_matmulr   w   sb   J --Q
C<(I<(I#i/&&H\)Ky;.=L3-.E##4E|#bii<&@@AEG|#bii<&@@AEGYYq,'F'!T'"Y.a91LLMF&D/I-a0@90LLMF7[((G'!H8k))G((L,7rzzJK1bgga./ +GGFa1q<7G3G!GsSGGF41q<7G3G!GsSl"W$ggg+ 556ggg+ 556rvva|c!T'l2Sq\AA,**,**+ 	wwR[[(NN2;;'	
		rzz	)NN2::&NN2::&l"RYYq,%??Gl"RYYq,%??GWQW---	GD!G<L0LLFag"wtQw'7!';<FHHVQV$rA   c                 J   t        j                  d      }t        j                  ||      }t        j                  ||      }||z  }||z  }||z  }t        ||z
  |      }|||z  z   }||z  |z  }||z  t        j                  d|      z   |z  }||z  t        j                  d|      z   |z  }t        j                  d|      }| |dddf   |
z  |dddf   |z  z   z   } ||dddf   |z  |dddf   |z  z   z   }!t        j
                  |      }"t        j
                  |      }#t        j                  ||ft         j                        }$t        dt        j                  ||            D ]  }%t        j
                  | |dddf   ||%|z  z
  k  d      }&t        j
                  |!|dddf   ||%|z  z
  k  d      }'|$t        j                  |&|'      |"z  |#z  z  }$| ||z  z  } |!||z  z  }! |j                  j                  t         j                  k(  r |$j                  t         j                        }(nf|j                  j                  t         j                  k(  r |$j                  t         j                        }(n|$j                  t         j                        }(||z  t        j                  d|      z   })||z  t        j                  d|      z   }*|||)dddf   z  z   ||*dddf   z  z   }+|)dddf   |k  |*dddf   |k  z  },t        j                  |+|(|,       y)zTriton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with per-tensor quantization, and
    store the result in output tensor `C`.
    r	   r+   NrD   rW   rX   r[   )r-   r.   rG   r\   r/   r0   r]   r2   r^   r_   r5   r6   r    r1   r!   r7   )-r`   ra   rb   rc   rd   re   rf   rg   rh   ri   rj   rk   rl   rm   rn   ro   rR   rS   rT   rU   r;   rt   ru   rv   rw   rx   ry   rz   r{   r|   r}   r~   r   r   scale_ascale_br   r   r   r   r   r   r   r   r   s-                                                r   !_w8a8_block_fp8_matmul_per_tensorr      s    B --Q
C<(I<(I#i/&&H\)Ky;.=L3-.E##4E|#bii<&@@AEG|#bii<&@@AEGYYq,'F'!T'"Y.a91LLMF&D/I-a0@90LLMFggbkGggbkG((L,7rzzJK1bgga./ +GGFa1q<7G3G!GsSGGF41q<7G3G!GsSrvva|g-77,**,**+ 	wwR[[(NN2;;'	
		rzz	)NN2::&NN2::&l"RYYq,%??Gl"RYYq,%??GWQW---	GD!G<L0LLFag"wtQw'7!';<FHHVQV$rA   r`   ra   rc   rd   c                 :   |d\  }}nt        |      dk(  sJ |d   |d   }}||j                  d   k(  r||j                  d   k(  rd}d}| j                  d   |j                  d   k(  sJ |j                         dk7  rf| j                  dd |j                  dd k(  r| j                         sJ t	        j
                  | j                  d   |      |j                  d   k(  sJ | j                         | j                  d   z  |j                  \  }|j                  dk(  r|j                         sJ |j                         dk7  r|j                  dk(  sJ t	        j
                  |      |j                  d   k(  sJ  d	| d	|j                          t	        j
                  ||      |j                  d   k(  sJ | d	| d	|j                          | j                  dd fz   }	| j                  |	|
      }
d}|k  r!t	        j                        }t        |d      }|}||z  dk(  sJ |}fd}|j                         dk(  r|j                         dk(  r~t        |   | ||
|||||| j                  d      | j                  d      |j                  d      |j                  d      |
j                  d      |
j                  d      |||d       |
S t        |   | ||
|||||| j                  d      | j                  d      |j                  d      |j                  d      |
j                  d      |
j                  d      |j                  d      |j                  d      |j                  d      |j                  d      |||d       |
S )a  This function performs matrix multiplication with block-wise
    quantization.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should
        be 2-dim, e.g., [128, 128].
        output_dytpe: The dtype of the returned tensor.
    Returns:
        torch.Tensor: The result of matmul.
    N)r   r   r   r	   r   rC   r   , rD      c                 l    t        j                  | d         t        j                  | d         z  fS )NrR   rS   )rF   rG   )METAre   rf   s    r   rJ   z*w8a8_block_fp8_matmul_triton.<locals>.gridQ  s1    AtN34v{{1d>FZ7[[]]rA      )rR   rS   rT   rU   )r"   rL   rH   rK   rF   rG   ndimrO   next_power_of_2r3   r   strider   )r`   ra   rc   rd   r   r   block_nblock_krg   C_shaperb   rR   rT   rS   rJ   re   rf   s                  @@r   w8a8_block_fp8_matmul_tritonr     sr   . #:!###%a=*Q- !''"+'QWWR["8772;!''"+%%%	xxzQwws|rxx},1BBB{{1772;0BHHRL@@@		QWWR[ A77DAq66Q;1??,,,	xxzQww!||{{1g&"((1+5T!Bwir"((7TT5{{1g&"((1+5T!Bwir"((7TT5ggcrlaT!G	G<0AL<--a0<,L\!Q&&&L^ 
xxzQ288:?)$/HHRLHHRLHHQKHHQKHHRLHHRL%%%)	
d H7 	t$HHRLHHRLHHQKHHQKHHRLHHRLIIbMIIbMIIaLIIaL%%%1	
6 HrA   c                    t        ||      r~t               }|q	 | j                  }| j                         | j                  d   z  }| j                  d   }	|j                  d   }
|	dz  dk7  s|
dz  dk7  rt	        d|	 d|
 d      | j                  ||	      j                         }|j                         j                         }|j                  |d      j                         }|j                         j                         j                         }|j                         j                         }|j                         j                         j                         }|j                  |||||d      }|dd |
fz   }|j                  |      S t        | |||||      S # t        $ r#}t        j                  d| d	       Y d}~7d}~ww xY w)
a  
    Dispatch to CUTLASS or Triton for block-wise FP8 matmul.

    Uses CUTLASS when:
    - Block size is [128, 128] (the only size CUTLASS supports)
    - Running on SM90+ (Hopper or newer)
    - The CUTLASS kernel is available
    - Output dtype is bfloat16 or float16 (CUTLASS requirement)
    - Tensor dimensions are compatible (divisible by 16)

    Otherwise falls back to Triton.
    NrC   r	   r   zCUTLASS requires K (z	) and N (z) divisible by 16zCUTLASS kernel failed: r   )r(   r   rL   rH   
ValueErrorview
contiguoustcutlass_scaled_mmr   r   r   r   )r`   ra   rc   rd   r   r   r'   original_shapere   rg   rf   A_2dB_col_majorAs_2dBs_kmrb   r   r   s                     r   w8a8_block_fp8_matmulr     s   * \2)+-\ "#GGI,GGBKGGAJ r6Q;!b&A+$';A3isJ[%\]]vva|..0  lln..0 2113	,,.002 ))+	,,.002 ,,T;ul\`a("-4vvg&
 (1b"j,OO	  \##&=aS@Y$Z[[\s   E/F 	G%GGinput_qweight_qinput_scaleweight_scalec                 ~   | j                   dk(  r| j                  nd| j                  d   | j                  d   f\  }}}|j                  d   }	| j                  d|      }
|j                  |j                  d   d      }|	|d   z  }||d   z  }t        j                  ||z  |	ft        j
                  | j                        }t        |      D ]  }||d   z  }||d   z   }t        |      D ]  }||d   z  }||d   z   }|
dd||f   }|||||f   }|dd||dz   f   }|||f   }t        j                  ||j                         t        j                  dt        j
                  | j                        ||      |z  }|dd||fxx   |z  cc<     |j                  |||	      }|j                  |      S )a  
    Performs blocked matrix multiplication with FP8 quantized matrices.

    Args:
        input_q: Quantized input tensor with 1x128 block quantization
        weight_q: Quantized weight tensor with 128x128 block quantization
        input_scale: Scaling factors for input blocks
        weight_scale: Scaling factors for weight blocks
        block_size: Tuple of (M, N) for weight block dimensions
        output_dtype: Desired output dtype
       r   r	   rC   )r5   deviceN)r   r   	out_dtype)r   rL   r   r   r]   r2   r   r^   
_scaled_mmr   tensorr1   )r   r   r   r   r   r   
batch_sizeseq_len
hidden_dimout_featuresinput_reshapedinput_scale_reshapednum_weight_blocks_mnum_weight_blocks_noutputim_startm_endjn_startn_endinput_blockweight_blockcurr_input_scalecurr_weight_scaleblock_results                             r   w8a8_block_fp8_matmul_compiler     s   ( 8?||q7HgmmqRYR_R_`aRbdkdqdqrsdtNu#J>>!$L \\"j1N&++K,=,=a,@"E&*Q-7$
15[[*w.=U]][b[i[ijF&' 5jm#*Q-'*+ 	5A*Q-'Gjm+E )GEM)9:K#GEM75=$@AL  4Aq1q5yLA ,QT 2    NN$!LL%--W-* ##  1gem#$4$/	5	5: [[Wl;F99\""rA   c                        e Zd Zdej                  ddfdedededeeef   dz  f fdZd	ej                  d
ej                  fdZ
 xZS )	FP8LinearFNdynamicin_featuresr   biasr   c                    t         	|   ||       || _        || _        t        j
                  j                  t	        j                  |||            | _        | j                  >t        j                  t	        j                  dt        j                              | _        n|| j                  d   z   dz
  | j                  d   z  }|| j                  d   z   dz
  | j                  d   z  }t        j                  t	        j                  ||t        j                              | _        | j                  dk(  r=t        j                  t	        j                  dt        j                              | _        |r8t        j                  t	        j                  | j                              | _        y | j                  dd        y )NrD         ?r	   r   staticr   )super__init__r   activation_schemer   nn	Parameteremptyweightr   r2   weight_scale_invactivation_scaler   r   register_parameter)
selfr   r   r   r5   r   r   scale_out_featuresscale_in_features	__class__s
            r   r   zFP8Linear.__init__  sL    	l3 %!2hh((\;V[)\]??"$&LLc1W$XD!".1C"Ca"GDOO\]L^!^!,tq/A!AA!E$//Z[J\ \$&LL.0AW%D! !!X-$&LLc1W$XD!U[[1B1B%CDDI##FD1rA   inputr   c           	      0   | j                   j                         dkD  r+t        j                  || j                   | j                        S t        | j                   t        j                  j                  j                        rI| j                   j                  j                         }| j                  j                  j                         }n4| j                   j                         }| j                  j                         }t               r(t        j                  j                         j                   nd}t#        t        |t        j$                        }|j'                  |j&                        5  | j(                  dk(  rt+        || j,                  d         \  }}n| j(                  dk(  re| j.                  j1                  t        j2                        }||z  j5                  t6        t8              j1                  t        j:                        }nt=        d      t?        ||||| j,                  |j@                        }d d d        |jC                          | j                  | j                  z   }j1                  |j@                        S # 1 sw Y   PxY w)	Nr   r   r   r   r\   r3   zNot supportedr   rD   )"r   element_sizeFlinearr   
isinstancer   distributedr   DTensor_local_tensorr   r   r   acceleratorcurrent_acceleratortypegetattrr   r   r   rQ   r   r   r1   r2   clamp_FP8_MIN_FP8_MAXrN   NotImplementedErrorr   r5   synchronize)	r   r   r   	scale_invdevice_typetorch_accelerator_moduleqinputscaler   s	            r   forwardzFP8Linear.forward;  s   ;;##%)88E4;;		::$++u'8'8'?'?'G'GH22==? 11??JJL	//1 11<<>	JhJj%++??AFFpvK'.uk5::'N$)00> ))Y6$-eT__Q5G$HMFE++x7 1144U]]CE#em22xX2NQQRWReRefF .o>>.OO!&* %002yy$$))+995;;9//3 s   ,CJJ)__name__
__module____qualname__r   rN   intbooltupler   Tensorr   __classcell__r   s   @r   r   r     sk    
 !!-1# 2 2  2 	 2 #s(Od* 2D&0U\\ &0ell &0rA   r   c                     | |z   dz
  |z  S )Nr    )r   r   s     r   	_ceil_divr
  d  s    EAI!rA   c                       e Zd Zej                  f fd	Zdej                  dej                  dej                  dej                  fdZdej                  dej                  d	ej                  dej                  fd
Z xZ	S )	FP8Expertc                 f   t         |           ddlm} || _        t        |d      r|j                  n|j                  | _        |j                  | _	        t        |d      r|j                  n|j                  | _        d| j                  z  | j                  }}| j                  | j                  }}t        j                  t        j                   | j                  |||            | _        t        j                  t        j                   | j                  |||            | _        | j                  \  }	}
t'        ||	      }t'        ||
      }t        j                  t        j                   | j                  ||t        j(                              | _        t'        ||	      }t'        ||
      }t        j                  t        j                   | j                  ||t        j(                              | _        | j/                  dd        | j/                  dd        ||j0                     | _        y )Nr   )ACT2FNnum_local_expertsmoe_intermediate_sizerD   gate_up_bias	down_bias)r   r   activationsr  r   hasattrr  num_expertshidden_sizer   r  intermediate_sizeintermediate_dimr   r   r   r]   gate_up_proj	down_projr
  r2   gate_up_proj_scale_invdown_proj_scale_invr   
hidden_actact_fn)r   configr   r5   r  Wg_outWg_inWd_outWd_inbobi
gu_scale_o
gu_scale_i
dp_scale_o
dp_scale_ir   s                  r   r   zFP8Expert.__init__i  s   ($7>vGZ7[633agasas ,,,3F<S,TF((Z`ZrZr 	 D1114??)>)>LLT5E5Evu\a)bcekk$2B2BFEY^&_`B vr*
ub)
&(llKK((*jV'
#
 vr*
ub)
#%<<KK((*jV$
 
 	5T2 V../rA   hidden_statestop_k_indextop_k_weightsr   c                    t        j                  |      }t        j                         5  t         j                  j                  j                  || j                        }|j                  ddd      }t        j                  |j                  d      d      j                         }d d d        D ]  }|d   }|t        | j                        k(  r"t        j                  |         \  }}	||	   }
| j                  |
| j                  |   | j                  |         j!                  dd      \  }}| j#                  |      |z  }| j                  || j$                  |   | j&                  |         }||	|d f   }||j)                  |j*                        z  }|j-                  d|	|j)                  |j*                                |S # 1 sw Y   'xY w)N)num_classesr   r   r	   )rC   r   dimrC   )r   
zeros_likeno_gradr   r
   one_hotr  permutegreatersumnonzeror"   r  wherer   r  chunkr  r  r  r1   r5   
index_add_)r   r*  r+  r,  final_hidden_statesexpert_mask
expert_hit
expert_idx	top_k_pos	token_idxcurrent_stategateupcurrent_hidden_statesrouting_weightss                  r   r   zFP8Expert.forward  s    $..}=]]_ 	S((--55ktO_O_5`K%--aA6K{8'DaHPPRJ	S
 % 	nJ#AJS!2!233#(;;{:/F#G Iy))4M{{t00<d>Y>YZd>eeA2e D" %)KK$5$:!$(KK%t~~j'A4C[C[\fCg%! ,Iy$,FGO$9O<N<NOdOjOj<k$k!**1i9N9Q9QReRkRk9lm!	n$ #"/	S 	Ss   A=GGr   r   r   c           	      6   |j                         dkD  rt        j                  ||d       S t               r(t        j
                  j                         j                  nd}t        t        |t        j                        }|j                  |j                        5  t        || j                  d         \  }}t        ||||| j                  |j                        }d d d        |j                          j!                  |j                        S # 1 sw Y   5xY w)Nr   r   r   rD   )r   r   r   r   r   r   r   r   r   r   r   rQ   r   r   r5   r   r1   )	r   r   r   r   r   r   r   r   r   s	            r   r   zFP8Expert.linear  s     1$88E6400 KiJj%++??AFFpvK'.uk5::'N$)00> 	 )%1C D.$OO!&	 %002995;;9//	 	s   ADD)
r   r   r  r   rN   r   r  r   r   r  r  s   @r   r  r  h  s    161D1D )0\#||# \\# ||	#
 
#@0ELL 0%,, 0RWR^R^ 0chcoco 0rA   r  modules_to_not_convertc                 v   |j                   r| S d}| j                         D ]  \  }}t        ||      s|ri nddi}d}t        j                  d      5  |j                  d      r1t        d
| j                  j                         |j                  d|}n_t        |t        j                        rEt        d
|j                  |j                  |j                   du|j"                  |j                  d|}|| j%                  ||       d}ddd        |st&        j)                  d	       | S # 1 sw Y   xY w)a  
    A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
            Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons.
        quantization_config (`FbgemmFp8Config`):
            The quantization config object that contains the quantization parameters.
        pre_quantized (`book`, defaults to `False`):
            Whether the model is pre-quantized or not
    Fr5   NrI   z.experts)r  r   )r   r   r   r   r   TzYou are loading your model using fp8 but no linear modules were found in your model. Please double check your model architecture.r	  )
dequantizenamed_modulesr   r   r   endswithr  r  get_text_configweight_block_sizer   r   Linearr   r   r   r   r   set_submoduler   warning)	modelrG  quantization_configpre_quantizedhas_been_replacedmodule_namemodulemodule_kwargs
new_modules	            r   replace_with_fp8_linearrY    sO   " %%$224 )V$[2HI+'4
\\&! 	)##J/&  <<7792DD $

 FBII.&  & 2 2!'!4!4D0&9&K&K2DD $
 %##K<$(!%	) 	))4 <	
 L3	) 	)s   B8D..D8	c                   X    e Zd ZdZd Zdej                  deeej                  f   fdZ	y)Fp8Quantizez^
    A quantization operation that creates two tensors, weight and scale out of a weight.
    c                     || _         y Nhf_quantizerr   r_  s     r   r   zFp8Quantize.__init__
  
    (rA   
input_dictr   c                 l   t        |j                               d   \  }}|d   }d }| j                  j                  kt	        | j                  j                  t
              r&| j                  j                  j                  d      }n!t        | j                  j                  dd       }||j                  d   |j                  d   f}|\  }}|j                  d   |j                  d   }	}||z  dk7  s|	|z  dk7  rt        d| d|	 d| d| d| 
      |j                  d d }
||z  }|	|z  }|j                  }|j                  t        j                        } |j                  g |
|||| }|j                         j                  d	
      }t        j                   |dkD  |t        j"                  |            }t$        |z  }t        j                   |dkD  |t        j"                  |            }|j'                  d      j'                  d      }||z  }t        j(                  |t*        t$              j                  t,              }|j                  |      }d|z  j                  t        j                        }|j/                  d      r|j1                  dd      d   dz   }n|dz   }||||iS )Nr	   rM  r   rC   Matrix dimensions (r   $) must be divisible by block sizes (z). for )rC   r/  rf  r   r   r   .r   z.weight_scale_inv
_scale_inv)r  itemsr_  rR  r   dictgetr   rL   r   r1   r   r2   reshaper4   amaxr8  	ones_liker   	unsqueezer   r   
_FP8_DTYPErK  rsplit)r   rb  kwargstarget_keysvaluer   block_mr   rowscolsleading_shape
rows_tiles
cols_tilesr   
value_fp32reshapedmax_abssafe_max_absscalesscales_broadcastscaled	quantized
inv_scales	scale_keys                           r   convertzFp8Quantize.convert  s   ":#3#3#56q9Ua 
00<$++??F!..BBFFGZ[
$T%6%6%J%JL_aef
++b/5;;r?;J%[[_ekk"od '>Q$.A"5%dV2dV3WX_W``bcjbkkrs~r  A 
 CR(W_
W_
XXemm,
 &:%%_}_j_'_:_W^_ ,,.%%(%3{{7Q;9QR L(Wq[&%//&2IJ "++B/99"=,,KKH(CFFzR	%%n5	Fl&&u}}5
)#**32158KKI#l2I z
 	
rA   N)
r   r   r  __doc__r   r   r  rj  strr  r	  rA   r   r[  r[    s1    )?
%,, ?
T#u||BS=T ?
rA   r[  c            	       p    e Zd ZdZd Z	 ddeeej                  f   dedz  deeej                  f   fdZ	y)	Fp8DequantizeziInverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor.c                     || _         y r]  r^  r`  s     r   r   zFp8Dequantize.__init__R  ra  rA   Nrb  full_layer_namer   c                 l   t        |      dk  r||d   iS |d   d   }|d   d   }|j                  dd  \  }}| j                  j                  j                  }||j                  d   |j                  d   f}|\  }	}
||	z  dk7  s||
z  dk7  rt        d| d| d	|	 d|
 d
	      |j                  |j                        }|j                  d||	z  |	||
z  |
      }|j                  d||	z  ||
z        }|j                  d      j                  d      }||z  }||j                  |j                        iS )Nr   zweight$r	   r   r   rC   rd  r   re  z).)
r"   rL   r_  rR  rM  r   r1   r5   rl  ro  )r   rb  r  rr  r  r  rv  rw  r   ru  r   r|  expanded_scalesdequantizeds                 r   r  zFp8Dequantize.convertU  ss    z?Q#Z	%:;;y)!,	./2__RS)
d&&::LL
#//"-yr/BCJ%'>Q$.A"5%dV2dV3WX_W``bcjbkkmn  LL.	$$R'47?T[\ ..TW_dgoN)33B7AA!D0 [00A
 	
rA   r]  )
r   r   r  r  r   rj  r  r   r  r  r	  rA   r   r  r  O  sP    s) '+ 
ell*+ 
 t 

 
c5<<	  
rA   r  )r   )NNF)9core_model_loadingr   quantizers.quantizers_utilsr   utilsr   r   r   r   r   torch.nnr   rF   triton.languagelanguager-   r
   r   
get_loggerr   r   r   r   listr  r5   r  r(   rN   rp  finfor\   r   r3   r   jit	constexprr@   r  r  rQ   r   r   r2   r   r   compiler   rN  r   r
  Moduler  r  rY  r[  r  r	  rA   r   <module>r     s   / ? e e  ( 
		H	%  B$$s)d"2 $%++ $RV $N   
5;;z"&&5;;z"&& bll  
 
3 
u||U\\?Y9Z 
 Q%4 ,,5Q%6 ,,7Q%8 ,,9Q%: ,,;Q% Q%h E%, ,,-E%. ,,/E%0 ,,1E%2 ,,3E% E%\ !&r||r||r 	r 		r
 S	r ++r \\rv !&HP||HP||HP 	HP 		HP
 S	HP ++HP \\HPX  *. %>#\\>#ll># ># ,,	>#
 c3h$&># ++># \\># >#BI0		 I0Xc0		 c0N ej4#'9t#34nG
- G
T&
M &
rA   