
    i                        d Z ddlZddlmZ ddlZddlmZ ddlmZ ddl	m
Z
 ddlmZmZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlmZmZmZmZ ddlmZ ddlmZ  ej>                  e       Z!dejD                  de#fdZ$d Z%d Z&d Z' G d dej
                  jP                        Z) G d dejP                        Z* G d dejP                        Z+ G d de      Z,e G d d e             Z-e ed!"       G d# d$e                    Z.e ed%"       G d& d'e                    Z/e G d( d)e-             Z0 ed*"       G d+ d,e-e             Z1g d-Z2y).zPyTorch MAMBA2 model.    N)	dataclass)nn   )initialization)ACT2FN)CacheDynamicCache)GenerationMixin)lazy_load_kernel)GradientCheckpointingLayer)PreTrainedModel)ModelOutputauto_docstringis_torchdynamo_compilinglogging)resolve_internal_import   )Mamba2Configinput_tensorpad_sizec                     t        | j                        dk(  r
ddddd|ddfnddd|ddf}t        j                  j                  j                  | |dd      S )z
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
       r   constant)modevalue)lenshapetorchr   
functionalpad)r   r   	pad_shapes      {/var/www/vps2.regionflexible.com/Desarrollo/venv/lib/python3.12/site-packages/transformers/models/mamba2/modeling_mamba2.pypad_tensor_by_sizer#   (   sf     47|7I7I3Ja3OAq!Q!Q/VWYZ\]_gijlmUnI88""<ST"UU    c                    t        | |      } t        | j                        dk(  r.| j                  | j                  d   d|| j                  d         S | j                  | j                  d   d|| j                  d   | j                  d         S )z
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    r   r      )r#   r   r   reshape)r   r   
chunk_sizes      r"   reshape_into_chunksr*   3   s     &lH=L
<!###L$6$6q$92z<K]K]^_K`aa ##q!2z<3E3Ea3H,J\J\]^J_
 	
r$   c                 "   | j                  d      } | d   j                  g | j                         | } t        j                  t        j                  ||| j
                  t        j                        d      }| j                  | d      } t        j                  | d      }t        j                  t        j                  ||| j
                  t        j                        d      }|j                  | t        j                         }|S )zo
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    r&   .Ndevicedtype)diagonalr   dim)
sizeexpandr   trilonesr.   boolmasked_fillcumsuminf)r   r)   masktensor_segsums       r"   segment_sumr>   G   s     ""2&J 2<	*11S<3D3D3FS
SL::ejjZ@S@S[`[e[efqstD++TE15LLL26M ::ejjZ@S@S[`[e[efqrsD!--teeiiZ@Mr$   c                     |N|j                   d   dkD  r<|j                   d   dkD  r*| j                  }| |dddddf   z  j                  |      } | S )zm
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    Nr   r   )r   r/   to)hidden_statesattention_maskr/   s      r"   apply_mask_to_padding_statesrC   [   sa    
 !n&:&:1&=&AnFZFZ[\F]`aFa##&1d
)CCGGNr$   c                   (     e Zd Zd fd	ZddZ xZS )MambaRMSNormGatedc                     t         |           t        j                  t	        j
                  |            | _        || _        y Nsuper__init__r   	Parameterr   r7   weightvariance_epsilonselfhidden_sizeeps	__class__s      r"   rJ   zMambaRMSNormGated.__init__h   s/    ll5::k#:; #r$   c                    |j                   }|j                  t        j                        }|?|t        j
                  j                  |j                  t        j                              z  }|j                  d      j                  dd      }|t        j                  || j                  z         z  }| j                  |j                  |      z  S Nr'   r&   T)keepdim)r/   r@   r   float32r   r   silupowmeanrsqrtrM   rL   )rO   rA   gateinput_dtypevariances        r"   forwardzMambaRMSNormGated.forwardm   s    #))%((7)BMM,>,>twwu}}?U,VVM $$Q',,R,>%Ht?T?T4T(UU{{]--k:::r$   gư>rG   __name__
__module____qualname__rJ   r^   __classcell__rR   s   @r"   rE   rE   g   s    $
	;r$   rE   c                   2    e Zd ZdZddededef fdZ ej                         d        Z
	 	 ddej                  d	edz  d
ej                  dz  fdZ	 	 ddej                  d	edz  d
ej                  dz  fdZ	 	 dd	edz  d
ej                  dz  fdZ xZS )Mamba2Mixeru  
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    config	layer_idxinitialize_mixer_weightsc           	      n   t         |           |j                  | _        |j                  | _        |j                  | _        |j                  | _        t        |j                  | j                  z        | _
        t        |j                        | _        || _        |j                  | _        |j                  | _        t         |j                     | _        |j$                  | _        |j&                  | _        |j(                  | _        |j*                  | _        |j,                  | _        |j.                  | _        |j0                  | _        |j2                  | _        |j4                  | _        | j                  d| j(                  z  | j
                  z  z   | _        t9        j:                  | j6                  | j6                  |j                  |j                  | j6                  |j                  dz
        | _        | j                  | j6                  z   | j                  z   }t9        j>                  | j                  ||j@                        | _!        t9        jD                  tG        jH                  | j                              | _%        t9        jD                  tG        jH                  | j                              | _&        tO        | j                  | j$                        | _(        t9        jD                  tG        jH                  | j                              | _)        |r3| jJ                  jT                  jV                  dk7  r| jY                          t9        j>                  | j                  | j                  |j@                        | _-        |j@                  | _         t]        d      }t_        |dd       a0t_        |d	d       a1t]        d
      }te        |d      a3te        |d      a4te        |d      a5tm        tf        th        tj        tb        t`        f      a7tn        stp        js                  d       y y )Nr'   r   )in_channelsout_channelsbiaskernel_sizegroupspaddingrn   rQ   metazcausal-conv1dcausal_conv1d_updatecausal_conv1d_fnz	mamba-ssmz8ops.triton.selective_state_update.selective_state_update)chained_pathz1ops.triton.ssd_combined.mamba_chunk_scan_combinedz8ops.triton.ssd_combined.mamba_split_conv1d_scan_combineda  The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d):rI   rJ   	num_headsrP   
state_sizessm_state_sizeconv_kernelconv_kernel_sizeintr5   intermediate_sizetime_step_rankri   use_conv_bias
hidden_act
activationr   actlayer_norm_epsilonrms_normn_groupshead_dimr)   time_step_limittime_step_mintime_step_maxtime_step_floorconv_dimr   Conv1dconv1dLinearuse_biasin_projrK   r   emptydt_biasA_logrE   normDr.   typeinit_mamba2_weightsout_projr   getattrru   rv   r   selective_state_updatemamba_chunk_scan_combined mamba_split_conv1d_scan_combinedallis_fast_path_availableloggerwarning_once)rO   rh   ri   rj   projection_sizecausal_conv1d	mamba_ssmrR   s          r"   rJ   zMamba2Mixer.__init__   sD   ))!--$// & 2 2!$V]]T5E5E%E!F!&"7"78"#11 ++&++,"(";"; ++%55#11#11%55..T]]1BTEXEX1XXii%%**==&&*
 004==@4>>Qyy
 ||EKK$?@ \\%++dnn"=>
%d&<&<$BYBYZ	ekk$..9:#(;(;(@(@F(J$$&		$"8"8$:J:JQWQ`Q`a )9&}6LdS"=2DdK %[1	!8$^"
 %<$W%
! ,C$^,
(
 "%&)0 $"
 &> &r$   c                 t   t        j                  d| j                  dz   | j                  j                  t         j
                        }t        j                  | j                  t        j                  |             t        j                  | j                         t        j                  t        j                  | j                  | j                  j                  t         j
                        t        j                  | j                        t        j                  | j                         z
  z  t        j                  | j                         z         j#                  | j$                        }|t        j                  t        j&                  |              z   }t        j                  | j                  |       y )Nr   r-   )min)r   arangerx   r   r.   rV   initcopy_logones_r   exprandr   mathr   r   clampr   expm1)rO   Adtinv_dts       r"   r   zMamba2Mixer.init_mamba2_weights   s   LLDNNQ.tzz7H7HPUP]P]^

4::uyy|,

466YYJJt~~dll.A.AWxx**+dhht7I7I.JJLhht))*+
 %D((%
)	 	 eiibS!1 122

4<<(r$   NrA   cache_paramsrB   c                 $   t        ||      }| j                  |      }|j                  \  }}}| j                  | j                  z  }|j                  d   d| j
                  z  z
  d| j                  z  | j                  z  z
  | j                  z
  dz  }	|*|j                  | j                        r|j                  d      j                  |	|	| j
                  | j                  | j                  gd      \  }}}
}}t        ||j                  | j                     j                  | j                  j                   j                  d      | j                  j"                  | j$                        }t'        j                  || j
                  ||gd      \  }}}t'        j(                  | j*                  j-                                }|d d d df   d d d d d f   j/                  d| j0                  | j                        j3                  t&        j4                        }|d d d d d f   j/                  dd| j0                        }| j6                  d d d df   j/                  d| j0                        }| j8                  d d d df   j/                  d| j0                        }|j;                  || j                  |j                  d   | j                  z        }|j;                  || j                  |j                  d   | j                  z        }|j;                  || j                  | j0                        }t=        |j                  | j                     j>                  ||||||d |d
      }|j;                  || j                  | j0                  z        }| jA                  ||
      }| jC                  |      d d d df   }|S t'        j(                  | j*                  j-                                }| jD                  d	t-        d
      fk(  ri nd| jD                  i}| jF                  r|tI        || j                  j                   j                  d      | j                  j"                  | j6                  |f| j8                  | jJ                  d | j$                  | j@                  j                   | j@                  jL                  | jB                  j                   | jB                  j"                  | j0                  | j                  ddd|}|S |j                  |	|	| j
                  | j                  | j                  gd      \  }}}
}}|k|jO                  dd      }tP        jR                  jU                  || jV                  |j                  d   z
  df      }|jY                  || j                        }| j$                  dvrH| j[                  | j                  |jO                  dd            dd |f   jO                  dd            }npt]        |jO                  dd      | j                  j                   j                  d      | j                  j"                  | j$                        jO                  dd      }t        ||      }t'        j                  || j
                  ||gd      \  }}}t_        |j;                  ||d| j0                        |||j;                  ||| j                  d      |j;                  ||| j                  d      f| jJ                  | j8                  d d d| j6                  dd|\  }}|||ja                  || j                         |j;                  ||d      }| jA                  ||
      }| jC                  |      }|S )Nr&   r'   r   r2   .r/   T)zr   dt_softplusg        r;   dt_limitF)r   r)   seq_idxr   rmsnorm_weightrmsnorm_epsoutproj_weightoutproj_biasheaddimngroupsnorm_before_gatereturn_final_statesr   ri   )rW   swish)xrL   rn   r   )r)   r   r   r   r   r   r   )1rC   r   r   r   rz   r~   rx   has_previous_stateri   squeezesplitr   ru   layersconv_statesr   rL   rn   r   r   r   r   floatr5   r   r@   rV   r   r   viewr   recurrent_statesr   r   r   trainingr   r)   rM   	transposer   r   r    r|   update_conv_stater   rv   r   update_recurrent_state)rO   rA   r   rB   projected_states
batch_sizeseq_len_groups_time_state_sized_mlpr[   hidden_states_B_Cr   BCr   r   r   hidden_states_reshapedoutdt_limit_kwargshidden_states_B_C_transposedr   scan_output	ssm_states                            r"   cuda_kernels_forwardz Mamba2Mixer.cuda_kernels_forward   s    5]NS<<6 "/!4!4
GQ!%1D1D!D""2&$((()$--$"5"556 nn  #(G(G(W0@0H0H0K0Q0Qt55t}}dnnU[] 1R 1-Aq$)2
 !5!##DNN3??""**1-  ! #(++!'')?AWX#M1a 4::++-..A!T3,1d
+222t}}dFYFYZ]]didqdq]rAAq$J&&r2t}}=Bll1dC<077DMMJGq$|$++B>Az4==!''!*2MNAz4==!''!*2MNA%2%7%7
DNNTXTaTa%b"2##DNN3DD& M *..z4>>DMM;YZM IImT:M --.q$|<Cv 
o 4::++-..A$($8$8S%,<O$ObV`bfbvbvUwO }}!56$KK&&..q1KK$$LL ff# ##'99#3#3 $		 : :#'==#7#7!%!3!3 MM MM%*(-#$ &%d 
y 5E4J4JE4#9#94==$..Y_a 5K 511d-r  +3D3N3NqRS3T0"$--"3"34..1M1S1STV1WWYZ[#K #/"@"@X\XfXf"@"gK??*;;(,$5$?$?1$EFsHWH}U__`acde)% )9+55a;#{{1199!<![[--#'??	)
  i1o & %AARTb$c!&+kk%++-CE[\'#q! *C!&&z7BNFF:wrBFF:wrB*  $ff (, LL $* &*&Y" (\-E 77	T^^7\)..z7BG"iiT: mmK0
r$   c                 r   |j                   \  }}}|j                  }t        ||      }| j                  |      }|j                   d   d| j                  z  z
  d| j
                  z  | j                  z  z
  | j                  z
  dz  }	|j                  |	|	| j                  | j                  | j                  gd      \  }}}
}}|j                  dd      }|d uxr |j                  | j                        }|r|j                  || j                        }t        j                  || j                   j"                  j%                  d      z  d      }| j&                  r|| j                   j(                  z   }| j+                  |      }n|Yt,        j.                  j1                  || j2                  |j                   d   z
  df      }|j                  || j                         | j+                  | j!                  |      dd |f   j                  dd            }t        ||      }t        j                  || j                  | j
                  | j                  z  | j
                  | j                  z  gd      \  }}}t        j4                  | j6                  j9                                }|r|j:                  | j                     j<                  }|d d dd d f   d d d df   }|j                  dd      j?                  ||j                   d   | j@                        }| jB                  d   j?                  | jB                  j                   d   | j@                        }t        j,                  j.                  jE                  ||jG                  |j                        z         }t        jH                  || jJ                  d   | jJ                  d         }|d	   j?                  | j                  | j@                  | j                        jG                  t        jL                  
      }t        j4                  |d   |z        jG                  |      }|jO                  || j
                  d      dd d d f   }|j?                  || j
                  | j                  | j
                  z  |j                   d         jQ                         }|jO                  |d|j                   d         }|d   |dd d d f   z  }|jO                  |d| j@                        }||d   z  jG                  |      }|j:                  | j                     jR                  |z  |z   }|jU                  || j                        }|jO                  || j
                  d      dd d d f   }|j?                  || j
                  | j                  | j
                  z  |j                   d         jQ                         }|jO                  |d|j                   d         }|jG                  |j<                  |j                        }|jW                  || j                  z  | j@                  | j                        }|jW                  || j                  z  | j                  d      }t        jX                  ||      }|jW                  || j                  | j@                        }| jZ                  d   j?                  | jZ                  j                   d   | j@                        }|||z  z   jG                  |j                        }|jO                  |d      d d d df   }nt,        j.                  jE                  || jB                  z         }t        jH                  || jJ                  d   | jJ                  d         }|jO                  ||d| j@                        j9                         }|jO                  ||d| j                        j9                         }|jO                  ||d| j                        j9                         }|j]                  | j                  | j
                  z  d| j                        }|j]                  | j                  | j
                  z  d| j                        }| j^                  || j^                  z  z
  | j^                  z  }| jZ                  d   ta        ||      z  }||d   z  }|jG                  |j                        |z  }||||fD cg c]  }tc        ||| j^                         c}\  }}}}|je                  dddd      }t        jf                  |d      }t        j4                  ti        |            } |d d d d d d d d d d d f   |d d d d d d d d d d d f   z  }!|!j                  d      }"|"d   | je                  ddddd      d   z  }#|#j                  d      }$|$d   |d d d d d f   z  j                  d      }%t        j4                  |d d d d d d dd f   |z
        }&||&je                  dddd      d   z  }'|'dd d d f   |d   z  j                  d      }(t        jj                  |(d d d df         })t        jl                  |)|(gd      }(t        j4                  ti        t,        j.                  j1                  |d d d d d d df   d                  }*|*j                  dd      }*|*d	   |(d d d d d df   z  j                  d      }+|+d d d df   |+d d df   },}(t        j4                  |      }-|dd d d f   |(d d d d d df   z  }.|-je                  dddd      }/|.j                  d      |/d   z  }0|%|0z   }|jO                  |d| j                  | j@                        }||z   }|dkD  r|d d d |d d d d f   }|jO                  ||d      }|,||jU                  |,| j                         | jo                  ||
      }1| jq                  |1jG                  |            }2|2S c c}w )Nr&   r'   r2   r   r   r   .r,   ).NNr   )r.   r-   )r3   output_sizer   r   r1   )r   r   )9r   r/   rC   r   r~   r   rz   rx   r   r   r   r   ri   r   r   sumr   rL   r   r   rn   r   r   r   r    r|   r   r   r   r   r.   r5   r   r   softplusr@   r   r   rV   r(   
contiguousr   r   r   bmmr   repeat_interleaver)   r#   r*   permuter:   r>   
zeros_likecatr   r   )3rO   rA   r   rB   r   r   r   r/   r   r   r[   r   r   is_decodingr   r   r   r   cache_devicer   dAdBdBx
ssm_statesssm_states_reshaped
C_reshapedyr   r   
D_residualtA_cumsumLG_intermediateGM_intermediateMY_diagdecay_statesB_decaystatesprevious_statesdecay_chunk
new_statesr   state_decay_outC_times_statesstate_decay_out_permutedY_offr   contextualized_statess3                                                      r"   torch_forwardzMamba2Mixer.torch_forward  s)    "/!4!4
GQ## 5]NS<<6!''+a$2H2H.HH1t}}K\_c_r_rKrrsw  tB  tB  B  GH  H,<,B,Bt55t~~V\^ -C -
)1d%r .77!<"$.b<3R3RSWSaSa3b &889JVZVdVd8eK %		dkk0088;;! !!$58H8H$H! $): ; ' mm//%(=(=@Q@W@WXZ@[([]^'_ ..{dnn.U $5F)GXgX)V)`)`abde)f g89JN[#kk##T]]T5H5H%H$--Z^ZmZmJmn
q! YYtzz'')**'..t~~>EEL Aq!GQc\*Ba#**:rxx|T]]SBll9-44T\\5G5G5JDMMZG$$--b7::bhh3G.GHBR!5!5a!8$:N:Nq:QRB/"))$..$--I\I\]``glgtgt`uA))ByMA-.22,2GB
 		*dmmR8dAFAT]]DNNdmm4SUVU\U\]_U`allnA		*b!''"+6AI3a<0B *11*b$--PMi0044L4IC &,,T^^<MMPRRUXXJ%<<ZSWSaSa<bJ 		*dmmR8dAFAT]]DNNdmm4SUVU\U\]_U`allnA		*b!''"+6A $ahhaggFJ",//*t~~2Mt}}^b^q^q"r
T^^ ;T=P=PRSTJ		-z:Az4>>4==AA y!((a$--HA]Q&&**1773A 		*b)!T3,7A ''T\\(9:BR!5!5a!8$:N:Nq:QRB)11*gr4==Y__aM		*gr43F3FGMMOA		*gr43F3FGMMOA##DNNdmm$CX\XfXf#gA##DNNdmm$CX\XfXf#gA'DOO*CCtVH	*-?x-XXJ *ByM9M](()B.A cpqrtuwxay%z\]&9!Xt&W%z"M1a 		!Q1%A||A2.H 		+a.)A q!Qa23a1dAq!8K6LLN""r"*A y\AIIaAq!,DY,OON""r"*A 	l]1a:%>>CCCJF !99XaArsl%;h%FGL,..q"b!<YGGGc4l+mI.FFKKPQKRF $..va!e}=OYY8a@F))K0A0A(1aQRTV;BWY_0`$abK%//15K%o61dC9PPUUZ[U\J *1crc6 2Jq"u4EIF $ii1OT1oq!T30GGN'6'>'>q!Q'J$#''+.Fy.QQE A		*b$..$--HAJA!|a'1a'(		*gr2A $)A33I3Xii4(
 !%knnU.C D$$A &{s   p4c                     t         rId| j                  j                  j                  j                  v rt               s| j                  |||      S | j                  |||      S )Ncuda)r   r   rL   r.   r   r   r   r	  )rO   rA   r   rB   kwargss        r"   r^   zMamba2Mixer.forwardL  sT     "f0C0C0J0J0O0O&OXpXr,,]L.YY!!-~NNr$   )TNN)ra   rb   rc   __doc__r   r}   r8   rJ   r   no_gradr   Tensorr   r   r	  r^   rd   re   s   @r"   rg   rg   y   s    [| [ [W[ [z U]]_) )$ &*.2	]||] dl] t+	]F &*.2	{%||{% dl{% t+	{%B &*.2		O dl	O t+		Or$   rg   c                   &     e Zd Zd fd	Zd Z xZS )Mamba2RMSNormc                     t         |           t        j                  t	        j
                  |            | _        || _        y)zM
        Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        NrH   rN   s      r"   rJ   zMamba2RMSNorm.__init__Y  s1     	ll5::k#:; #r$   c                 "   |j                   }|j                  t        j                        }|j	                  d      j                  dd      }|t        j                  || j                  z         z  }| j                  |j                  |      z  S rT   )	r/   r@   r   rV   rX   rY   rZ   rM   rL   )rO   rA   r\   r]   s       r"   r^   zMamba2RMSNorm.forwarda  sy    #))%((7 $$Q',,R,>%Ht?T?T4T(UU{{]--k:::r$   r_   r`   re   s   @r"   r  r  X  s    $;r$   r  c                   T     e Zd Z fdZ	 	 ddedz  dej                  dz  fdZ xZS )Mamba2Blockc                     t         |           || _        || _        |j                  | _        t        |j                  |j                        | _        t        ||d      | _
        y )Nrs   F)ri   rj   )rI   rJ   rh   ri   residual_in_fp32r  rP   r   r   rg   mixer)rO   rh   ri   rR   s      r"   rJ   zMamba2Block.__init__j  sU    " & 7 7!&"4"4&:S:ST	 9W\]
r$   Nr   rB   c                    |}| j                  |j                  | j                   j                  j                              }| j                  r|j                  t
        j                        }| j                  |||      }||z   }|S )Nr   r   rB   )r   r@   rL   r/   r  r   rV   r  )rO   rA   r   rB   r  residuals         r"   r^   zMamba2Block.forwardr  su     !		-"2"29I9I9O9O"2"PQ  {{5==1H

=|\j
k =0r$   r  )	ra   rb   rc   rJ   r   r   r  r^   rd   re   s   @r"   r  r  i  s7    ^ &*.2	 dl t+	r$   r  c                   X    e Zd ZU eed<   dZdgZdZdZ e	j                         d        Zy)Mamba2PreTrainedModelrh   backboner  Tc                 L   | j                   j                  }t        |t              r#|j	                          t        j                  |j                  j                  t        j                  d             |j                  j                  )t        j                  |j                  j                         t        j                  |j                  j                  t        j                  d             | j                   j                  rB|j                  j                  }|t        j                  | j                   j                        z  }t        |t         j"                        rNt        j$                  |j                  |       |j                   t        j                  |j                         yyt        |t&        t(        f      r t        j*                  |j                         yt        |t         j,                        r"t        j$                  |j                  |       yy)zInitialize the weights.   )aN)std)rh   initializer_range
isinstancerg   r   r   kaiming_uniform_r   rL   r   sqrtrn   zeros_r   rescale_prenorm_residualnum_hidden_layersr   r   normal_r  rE   r   	Embedding)rO   moduler#  ps       r"   _init_weightsz#Mamba2PreTrainedModel._init_weights  sT    kk++fk* &&(!!&--"6"6$))A,G}}!!-FMM../!!&//"8"8DIIaLI{{33 OO**TYYt{{<<==fbii(LLC0{{&FKK( '0A BCJJv}}%-LLC0 .r$   N)ra   rb   rc   r   __annotations__base_model_prefix_no_split_modulessupports_gradient_checkpointing_is_statefulr   r  r/   r$   r"   r  r    s;    "&&*#LU]]_"1 "1r$   r  z-
    Class for the MAMBA2 model outputs.
    )custom_introc                   |    e Zd ZU dZdZej                  dz  ed<   dZe	dz  ed<   dZ
eej                     dz  ed<   y)Mamba2Outputa4  
    cache_params (`Cache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    Nlast_hidden_stater   rA   )ra   rb   rc   r  r9  r   FloatTensorr0  r   r   rA   tupler5  r$   r"   r8  r8    sG     37u((4/6!%L%$,%59M5**+d29r$   r8  zK
    Base class for causal language model (or autoregressive) outputs.
    c                       e Zd ZU dZdZej                  dz  ed<   dZej                  dz  ed<   dZ	e
dz  ed<   dZeej                     dz  ed<   y)Mamba2CausalLMOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cache_params (`Cache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    Nlosslogitsr   rA   )ra   rb   rc   r  r>  r   r:  r0  r?  r   r   rA   r;  r5  r$   r"   r=  r=    s[    
 &*D%

d
")'+FE$+!%L%$,%59M5**+d29r$   r=  c                        e Zd Z fdZd Zd Zd Ze	 	 	 	 	 	 	 ddej                  dz  dej                  dz  de
dz  d	edz  d
edz  dedz  dej                  dz  deez  fd       Z xZS )Mamba2Modelc           	         t         |   |       t        j                  |j                  |j
                        | _        t        j                  t        |j                        D cg c]  }t        ||       c}      | _        d| _        t        |j
                  |j                        | _        | j!                  | j"                         | j%                          y c c}w )Nr   Frs   )rI   rJ   r   r,  
vocab_sizerP   
embeddings
ModuleListranger*  r  r   gradient_checkpointingr  r   norm_f"_register_load_state_dict_pre_hook	load_hook	post_init)rO   rh   idxrR   s      r"   rJ   zMamba2Model.__init__  s     ,,v'8'8&:L:LMmmSXY_YqYqSr$sC[3%G$st&+##F$6$6F<U<UV//? %ts   &Cc                 f    |D ],  }d|v s|j                  |      ||j                  dd      <    y  y )Nz
embedding.zembeddings.)popreplace)rO   
state_dictprefixargsks        r"   rJ  zMamba2Model.load_hook  s;     	Aq EO^^TUEV
199\=AB	r$   c                     | j                   S rG   rD  rO   s    r"   get_input_embeddingsz Mamba2Model.get_input_embeddings  s    r$   c                     || _         y rG   rU  rO   new_embeddingss     r"   set_input_embeddingsz Mamba2Model.set_input_embeddings  s	    (r$   N	input_idsinputs_embedsr   	use_cacheoutput_hidden_statesreturn_dictrB   returnc                 ^   ||n| j                   j                  }||n#| j                  s| j                   j                  nd}||n| j                   j                  }|du |duz  rt        d      || j                  |      }| j                  r| j                  r|rd}|r|t        | j                         }|}	|rdnd}
| j                  D ]  } ||	||      }	|s|
|	fz   }
 | j                  |	      }	|r|
|	fz   }
|st        d |	||
fD              S t        |	|r||
      S d|
      S )	a  
        cache_params (`Cache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        NFz:You must specify exactly one of input_ids or inputs_embeds)rh   r5  r  c              3   &   K   | ]	  }||  y wrG   r5  ).0vs     r"   	<genexpr>z&Mamba2Model.forward.<locals>.<genexpr>1  s     fqXYXefs   )r9  r   rA   )rh   r_  r   r^  r`  
ValueErrorrD  rG  r	   r   rH  r;  r8  )rO   r\  r]  r   r^  r_  r`  rB   r  rA   all_hidden_statesmixer_blocks               r"   r^   zMamba2Model.forward  se   ( %9$D $++JjJj 	 "+!6IZ^ZgZgT[[=R=Rmr	%0%<k$++BYBY-t";<YZZ  OOI6M&&4==YI-'t{{;L%"6BD;; 	IK')-M $$58H$H!	I M2 1]4D Df]LBS$Tfff+)2+
 	
8<+
 	
r$   )NNNNNNN)ra   rb   rc   rJ   rJ  rW  r[  r   r   
LongTensorr   r8   r  r;  r8  r^   rd   re   s   @r"   rA  rA    s    
)  .215%)!%,0#'.2<
##d*<
 ''$.<
 dl	<

 $;<
 #Tk<
 D[<
 t+<
 
	<
 <
r$   rA  z
    The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
    embeddings).
    c                   j    e Zd ZddiZ fdZd Zd Z	 	 	 	 	 ddedz  dej                  dz  d	e
dz  f fd
Ze	 	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  dedz  dej                  dz  de
dz  de
dz  de
dz  dej                  dz  deej                  z  deez  fd       Z xZS )Mamba2ForCausalLMzlm_head.weightzbackbone.embeddings.weightc                     t         |   |       t        |      | _        t	        j
                  |j                  |j                  d      | _        | j                          y )NFrr   )
rI   rJ   rA  r  r   r   rP   rC  lm_headrK  )rO   rh   rR   s     r"   rJ   zMamba2ForCausalLM.__init__C  sF     #F+yy!3!3V5F5FUSr$   c                 6    | j                   j                         S rG   )r  rW  rV  s    r"   rW  z&Mamba2ForCausalLM.get_input_embeddingsJ  s    }}1133r$   c                 8    | j                   j                  |      S rG   )r  r[  rY  s     r"   r[  z&Mamba2ForCausalLM.set_input_embeddingsM  s    }}11.AAr$   Nr   rB   is_first_iterationc           	      F    t        	|   |f|||||d|}|r|sd |d<   |S )N)r]  r^  r   rB   rq  rB   )rI   prepare_inputs_for_generation)
rO   r\  r]  r^  r   rB   rq  r  model_inputsrR   s
            r"   rs  z/Mamba2ForCausalLM.prepare_inputs_for_generationP  sN     w<
'%)1
 
 /-1L)*r$   r\  r]  labelsr_  r`  r^  logits_to_keepra  c
           	      &   ||n| j                   j                  }| j                  |||||||      }|d   }t        |	t              rt        |	 d      n|	}| j                  |dd|ddf   j                  | j                  j                  j                              j                         }d}|* | j                  d||| j                   j                  d|
}|s|f|dd z   }||f|z   S |S t        |||j                  |j                        S )aN  
        cache_params (`Cache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        N)r   r]  r_  r`  r^  rB   r   )r?  ru  rC  r   )r>  r?  r   rA   r5  )rh   r`  r  r%  r}   slicern  r@   rL   r/   r   loss_functionrC  r=  r   rA   )rO   r\  r]  r   ru  r_  r`  r^  rB   rv  r  mamba2_outputsrA   slice_indicesr?  r>  outputs                    r"   r^   zMamba2ForCausalLM.forwardi  s-   2 &1%<k$++BYBY%'!5#) ' 
 'q)8B>SV8W~ot4]kmA}a,?@CCDLLDWDWD]D]^_eeg%4%%pVFt{{OeOepiopDY!33F)-)9TGf$EvE#'44(66	
 	
r$   )NNNNF)	NNNNNNNNr   )ra   rb   rc   _tied_weights_keysrJ   rW  r[  r   r   r  r8   rs  r   rj  r:  r}   r;  r=  r^   rd   re   s   @r"   rl  rl  :  sB    +,HI4B %).2*/
 dl t+ !4K2  .226%)*.,0#'!%.2-.6
##d*6
 ((4/6
 dl	6

   4'6
 #Tk6
 D[6
 $;6
 t+6
 ell*6
 
%	%6
 6
r$   rl  )rl  rA  r  )3r  r   dataclassesr   r   r    r   r   activationsr   cache_utilsr   r	   
generationr
   integrationsr   modeling_layersr   modeling_utilsr   utilsr   r   r   r   utils.import_utilsr   configuration_mamba2r   
get_loggerra   r   r  r}   r#   r*   r>   rC   ModulerE   rg   r  r  r  r8  r=  rA  rl  __all__r5  r$   r"   <module>r     s     !   & ! . ) , 9 - S S 9 . 
		H	%VU\\ VS V
((	; ;$\O")) \O~;BII ;", 4 *1O *1 *1Z :; : : :; : :& V
' V
 V
r `
- `
`
F Hr$   