
    i                         d Z ddlZddlmZ ddlZddlmZ ddlmZmZm	Z	 ddl
mZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZmZ ddlmZ ddlmZ ddlmZmZmZm Z m!Z!m"Z"m#Z# ddl$m%Z%m&Z& ddl'm(Z( ddl)m*Z*m+Z+m,Z,m-Z-m.Z.m/Z/m0Z0 ddl1m2Z2 ddl3m4Z4m5Z5 ddl6m7Z7  e-       r	  e/jp                  e9      Z:dejv                  de<fdZ= G d dej|                        Z? G d dej|                        Z@	 	 dDdej                  dejv                  dejv                  d ejv                  d!ejv                  dz  d"eBdz  d#eBd$e(e*   fd%ZC G d& d'ej                        ZD G d( d)e      ZE G d* d+e      ZF G d, d-ej                        ZGe+ G d. d/e&             ZH G d0 d1eH      ZI G d2 d3eH      ZJe+ G d4 d5eH             ZK e+d67       G d8 d9eHe             ZL e+d:7       G d; d<eH             ZMe+ G d= d>eH             ZN G d? d@eH      ZO G dA dBeHe      ZPg dCZQy)EzPyTorch MBART model.    N)Callable)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )initialization)ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)create_bidirectional_maskcreate_causal_mask)FlashAttentionKwargs)GradientCheckpointingLayer)BaseModelOutput)BaseModelOutputWithPastAndCrossAttentions!CausalLMOutputWithCrossAttentionsSeq2SeqLMOutputSeq2SeqModelOutput#Seq2SeqQuestionAnsweringModelOutputSeq2SeqSequenceClassifierOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tupleis_torch_flex_attn_availableis_torchdynamo_compilingloggingtorch_compilable_check)merge_with_config_defaults)OutputRecordercapture_outputs   )MBartConfig	input_idspad_token_idc                 f   | j                         }|t        d      |j                  |dk(  |       |j                  |      j	                  d      dz
  j                  d      }|j                  d|      j                         }|ddddf   j                         |ddddf<   ||dddf<   |S )z
    Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
    have a single `decoder_start_token_id` in contrast to other Bart-like models.
    Nz1self.model.config.pad_token_id has to be defined.ir'   dimr   )clone
ValueErrormasked_fill_nesum	unsqueezegathersqueeze)r)   r*   prev_output_tokensindex_of_eosdecoder_start_tokenss        y/var/www/vps2.regionflexible.com/Desarrollo/venv/lib/python3.12/site-packages/transformers/models/mbart/modeling_mbart.pyshift_tokens_rightr;   @   s    
 #*LMM##$6$$>M&)),7;;;BQFQQRTUL-44QEMMO 21crc6 : @ @ Bq!"u3q!t    c                   v     e Zd ZdZdedef fdZ	 d
dej                  dedej                  dz  f fd	Z xZ	S )MBartLearnedPositionalEmbeddingzN
    This module learns positional embeddings up to a fixed maximum size.
    num_embeddingsembedding_dimc                 N    d| _         t        | 	  || j                   z   |       y N   )offsetsuper__init__)selfr?   r@   	__class__s      r:   rF   z(MBartLearnedPositionalEmbedding.__init__Z   s$     $++5}Er<   Nr)   past_key_values_lengthposition_idsc                 $   |a|j                   dd \  }}t        j                  |||z   t        j                  | j                  j
                        j                  |d      }n|j                  d      }t        | %  || j                  z         S )z3`input_ids' shape is expected to be [bsz x seqlen].NrC   )dtypedevicer.   r   )shapetorcharangelongweightrM   expandr4   rE   forwardrD   )rG   r)   rI   rJ   bszseq_lenrH   s         r:   rT   z'MBartLearnedPositionalEmbedding.forward`   s    
 $??2A.LC <<&(>(HPUPZPZcgcncncucufS"o  (11!4Lw|dkk9::r<   )r   N)
__name__
__module____qualname____doc__intrF   rO   TensorrT   __classcell__rH   s   @r:   r>   r>   U   sW    Fs F3 F mq;;?B;V[VbVbeiVi; ;r<   r>   c            
       `     e Zd ZdZd
dededededz  f fdZdej                  f fd	Z	 xZ
S )MBartScaledWordEmbeddingz\
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    r?   r@   padding_idxembed_scaleNc                 6    t         |   |||       || _        y N)rE   rF   rb   )rG   r?   r@   ra   rb   rH   s        r:   rF   z!MBartScaledWordEmbedding.__init__v   s    D&r<   r)   c                 <    t         |   |      | j                  z  S rd   )rE   rT   rb   )rG   r)   rH   s     r:   rT   z MBartScaledWordEmbedding.forwardz   s    wy)D,<,<<<r<   )      ?rW   rX   rY   rZ   r[   floatrF   rO   r\   rT   r]   r^   s   @r:   r`   r`   q   sE    's '3 'S '_dgk_k '= = =r<   r`   modulequerykeyvalueattention_maskscalingdropoutkwargsc                    ||j                  d      dz  }t        j                  ||j                  dd            |z  }|||z   }t        j
                  j                  |d      }t        j
                  j                  ||| j                        }t        j                  ||      }	|	j                  dd      j                         }	|	|fS )Nr.         rC   r   r,   ptrainingr'   )
sizerO   matmul	transposer   
functionalsoftmaxro   ru   
contiguous)
ri   rj   rk   rl   rm   rn   ro   rp   attn_weightsattn_outputs
             r:   eager_attention_forwardr~      s     **R.D( <<s}}Q':;gEL!#n4==((2(>L==((6??([L,,|U3K''1-88:K$$r<   c                       e Zd ZdZ	 	 	 	 	 	 ddedededededed	edz  d
edz  f fdZ	 	 	 dde	j                  de	j                  dz  dedz  de	j                  dz  dee   dee	j                  e	j                  dz  f   fdZ xZS )MBartAttentionz=Multi-headed attention from 'Attention Is All You Need' paperN	embed_dim	num_headsro   
is_decoderbias	is_causalconfig	layer_idxc	                    t         	|           || _        || _        || _        ||z  | _        || _        | j
                  |z  | j                  k7  rt        d| j                   d| d      | j
                  dz  | _        || _	        || _
        || _        |9| j                  r-t        j                  d| j                  j                   d       t!        j"                  |||      | _        t!        j"                  |||      | _        t!        j"                  |||      | _        t!        j"                  |||      | _        y )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).rr   zInstantiating a decoder z without passing `layer_idx` is not recommended and will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.r   )rE   rF   r   r   ro   head_dimr   r0   rn   r   r   r   loggerwarning_oncerH   rW   r   Lineark_projv_projq_projout_proj)
rG   r   r   ro   r   r   r   r   r   rH   s
            r:   rF   zMBartAttention.__init__   s$    	""!Y.MMI%$..8MdnnM]$YKr3  }}d*$""*4>>+B+B*C D, , ii	94@ii	94@ii	94@		)YTBr<   hidden_stateskey_value_statespast_key_valuesrm   rp   returnc                    |du}|j                   dd }g |d| j                  }| j                  |      j                  |      j	                  dd      }	d}
|St        |t              rA|j                  j                  | j                        }
|r|j                  }n|j                  }n|}|r|n|}|rK|I|
rGj                  | j                     j                  }|j                  | j                     j                  }n| j                  |      }| j!                  |      }g |j                   dd d| j                  }|j                  |      j	                  dd      }|j                  |      j	                  dd      }|Kj#                  ||| j                        \  }}|r)t        |t              rd|j                  | j                  <   t%        j&                  | j(                  j*                  t,              } || |	|||f| j.                  sdn| j0                  | j2                  d|\  }} |j4                  g |d j7                         }| j9                  |      }||fS )	z#Input shape: Batch x Time x ChannelNr.   r'   rC   FT        )ro   rn   )rN   r   r   viewrx   
isinstancer   
is_updatedgetr   cross_attention_cacheself_attention_cachelayerskeysvaluesr   r   updater   get_interfacer   _attn_implementationr~   ru   ro   rn   reshaper{   r   )rG   r   r   r   rm   rp   is_cross_attentioninput_shapehidden_shapequery_statesr   curr_past_key_valuescurrent_states
key_statesvalue_stateskv_shapeattention_interfacer}   r|   s                      r:   rT   zMBartAttention.forward   sd    .T9 $))#2.88b8$--8 {{=166|DNNqRST
&/+>?,77;;DNNK
%+:+P+P(+:+O+O('6$-?)]/"=*-44T^^DIIJ/66t~~FMML^4J;;~6LF--cr2FBFFH#2<<QBJ',,X6@@AFL*+?+F+FzS_aeaoao+p(
L%*_FY*ZAEO..t~~>(?(M(MKK,,.E)
 %8	%
  $}}C$,,LL	%
 	%
!\ *k));;;;FFHmmK0L((r<   )r   FTFNNNNN)rW   rX   rY   rZ   r[   rh   boolr(   rF   rO   r\   r   r   r   tuplerT   r]   r^   s   @r:   r   r      s   G  %) $%C%C %C 	%C
 %C %C %C d"%C :%CT 15(,.2H)||H)  ,,-H) 	H)
 t+H) -.H) 
u||U\\D00	1H)r<   r   c                   ~     e Zd Zdef fdZdej                  dej                  dee   dej                  fdZ	 xZ
S )MBartEncoderLayerr   c                 h   t         |           |j                  | _        t	        | j                  |j
                  |j                  |      | _        t        j                  | j                        | _
        |j                  | _        t        |j                     | _        |j                  | _        t        j                   | j                  |j"                        | _        t        j                   |j"                  | j                        | _        t        j                  | j                        | _        y )N)r   r   ro   r   )rE   rF   d_modelr   r   encoder_attention_headsattention_dropout	self_attnr   	LayerNormself_attn_layer_normro   r
   activation_functionactivation_fnactivation_dropoutr   encoder_ffn_dimfc1fc2final_layer_normrG   r   rH   s     r:   rF   zMBartEncoderLayer.__init__  s    'nn44,,	
 %'LL$@!~~#F$>$>?"(";";99T^^V-C-CD99V33T^^D "T^^ <r<   r   rm   rp   r   c                     |}| j                  |      } | j                  d||d|\  }}t        j                  j	                  || j                  | j
                        }||z   }|}| j                  |      }| j                  | j                  |            }t        j                  j	                  || j                  | j
                        }| j                  |      }t        j                  j	                  || j                  | j
                        }||z   }|j                  t        j                  k(  rEt        j                  |j                        j                  dz
  }t        j                   || |      }|S )a>  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        )r   rm   rs   i  )minmax )r   r   r   ry   ro   ru   r   r   r   r   r   rL   rO   float16finfor   clamp)rG   r   rm   rp   residual_clamp_values          r:   rT   zMBartEncoderLayer.forward$  sT    !11-@)4>> 
')
 
q
 --mt||VZVcVc-d =0 --m<**488M+BC--mt?V?Vaeanan-o/--mt||VZVcVc-d =0%--/++m&9&9:>>EK!KKK<[YMr<   )rW   rX   rY   r(   rF   rO   r\   r   r   rT   r]   r^   s   @r:   r   r     sL    ={ =$"||" " +,	"
 
"r<   r   c                        e Zd Zddededz  f fdZ	 	 	 	 	 ddej                  dej                  dz  dej                  dz  dej                  dz  d	edz  d
e	dz  de
e   dej                  fdZ xZS )MBartDecoderLayerNr   r   c           	         t         |           |j                  | _        t	        | j                  |j
                  |j                  dd||      | _        |j                  | _        t        |j                     | _        |j                  | _        t        j                  | j                        | _        t	        | j                  |j
                  |j                  d||      | _        t        j                  | j                        | _        t        j$                  | j                  |j&                        | _        t        j$                  |j&                  | j                        | _        t        j                  | j                        | _        y )NT)r   r   ro   r   r   r   r   )ro   r   r   r   )rE   rF   r   r   r   decoder_attention_headsr   r   ro   r
   r   r   r   r   r   r   encoder_attnencoder_attn_layer_normr   decoder_ffn_dimr   r   r   )rG   r   r   rH   s      r:   rF   zMBartDecoderLayer.__init__J  s    'nn44,,
 ~~#F$>$>?"(";";$&LL$@!*NN**,,
 (*||DNN'C$99T^^V-C-CD99V33T^^D "T^^ <r<   r   rm   encoder_hidden_statesencoder_attention_maskr   	use_cacherp   r   c                    |}| j                  |      } | j                  d|||d|\  }}	t        j                  j	                  || j                  | j
                        }||z   }|h|}| j                  |      } | j                  d||||d|\  }}	t        j                  j	                  || j                  | j
                        }||z   }|}| j                  |      }| j                  | j                  |            }t        j                  j	                  || j                  | j
                        }| j                  |      }t        j                  j	                  || j                  | j
                        }||z   }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            past_key_values (`Cache`): cached past key and value projection states
        )r   r   rm   rs   )r   r   rm   r   r   )r   r   r   ry   ro   ru   r   r   r   r   r   r   r   )
rG   r   rm   r   r   r   r   rp   r   r   s
             r:   rT   zMBartDecoderLayer.forwardi  s   * !11-@ *4>> 
'+)
 	
q --mt||VZVcVc-d =0 !,$H 88GM0t00  +!65 /	 
  M1 MM11-4<<Z^ZgZg1hM$}4M !--m<**488M+BC--mt?V?Vaeanan-o/--mt||VZVcVc-d =0r<   rd   )NNNNT)rW   rX   rY   r(   r[   rF   rO   r\   r   r   r   r   rT   r]   r^   s   @r:   r   r   I  s    ={ =sTz =D /3596:(,!%:||: t+:  %||d2	:
 !&t 3: : $;: +,: 
:r<   r   c                   l     e Zd ZdZdedededef fdZdej                  dej                  fd	Z	 xZ
S )
MBartClassificationHeadz-Head for sentence-level classification tasks.	input_dim	inner_dimnum_classespooler_dropoutc                     t         |           t        j                  ||      | _        t        j
                  |      | _        t        j                  ||      | _        y )N)rt   )rE   rF   r   r   denseDropoutro   r   )rG   r   r   r   r   rH   s        r:   rF   z MBartClassificationHead.__init__  sD     	YYy)4
zzN3		)[9r<   r   r   c                     | j                  |      }| j                  |      }t        j                  |      }| j                  |      }| j	                  |      }|S rd   )ro   r   rO   tanhr   )rG   r   s     r:   rT   zMBartClassificationHead.forward  sN    ]3

=1

=1]3m4r<   rg   r^   s   @r:   r   r     sL    7
:
: 
: 	
:
 
:U\\ ell r<   r   c                   Z     e Zd ZU eed<   dZdZg dZdZdZ	dZ
dZ fdZed        Z xZS )MBartPreTrainedModelr   modelT)r   r   r   c                     t         |   |       t        |t              r t	        j
                  |j                         y y rd   )rE   _init_weightsr   MBartForConditionalGenerationinitzeros_final_logits_bias)rG   ri   rH   s     r:   r   z"MBartPreTrainedModel._init_weights  s2    f%f;<KK001 =r<   c                     | j                   j                  }t        j                  g ddddd|gg| j                        }|j                  |      |d}|S )N)r      
      rC   r         rC   rM   )rm   r)   )r   r*   rO   tensorrM   r2   )rG   	pad_tokenr)   dummy_inputss       r:   r   z!MBartPreTrainedModel.dummy_inputs  sW    KK,,	LL"2Q2q)4L!MVZVaVab	'll95"
 r<   )rW   rX   rY   r(   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modules_supports_flash_attn_supports_sdpa_supports_flex_attn_can_compile_fullgraphr   propertyr   r]   r^   s   @r:   r   r     sK    &*#TN!2
  r<   r   c                        e Zd ZdZe eedd      dZdef fdZ	d Z
ee	 	 	 dd
ej                  d	z  dej                  d	z  dej                   d	z  dee   deez  f
d              Z xZS )MBartEncoderz
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`MBartEncoderLayer`].

    Args:
        config: MBartConfig
        embed_tokens (nn.Embedding): output embedding
    r'   r   index
layer_name)r   
attentionsr   c                    t         |   |       |j                  | _        |j                  | _        |j
                  }|j                  | _        |j                  | _	        |j                  rt        j                  |      nd}t        |j                  || j                  |      | _        t!        |j                  |      | _        t%        j&                  t)        |j*                        D cg c]  }t-        |       c}      | _        || _        t%        j2                  |      | _        t%        j2                  |j
                        | _        d| _        | j;                          y c c}w )Nrf   rb   F)rE   rF   ro   encoder_layerdrop	layerdropr   r*   ra   max_position_embeddingsmax_source_positionsscale_embeddingmathsqrtr`   
vocab_sizeembed_tokensr>   embed_positionsr   
ModuleListrangeencoder_layersr   r   r   r   layernorm_embedding
layer_normgradient_checkpointing	post_init)rG   r   r   rb   r   rH   s        r:   rF   zMBartEncoder.__init__  s    ~~11NN	!..$*$B$B!.4.D.Ddii	*#4y$*:*:
  ?** 
 mmfNcNcHd$e1%6v%>$ef#%<<	#: ,,v~~6&+# %fs   -E(c                 n    | j                   r)t        | j                  dd      r| j                          y y y )Nr  F)r   getattrr   gradient_checkpointing_enablerG   s    r:   ._backward_compatibility_gradient_checkpointingz;MBartEncoder._backward_compatibility_gradient_checkpointing  s1    //GDKKIach4i..0 5j/r<   Nr)   rm   inputs_embedsrp   r   c                 h   |du |duz  rt        d      || j                  |      }| j                  |d         }||j                  |j                        z   }| j                  |      }t        j                  j                  || j                  | j                        }t        | j                  ||      }t        | j                        D ]F  \  }}d}	| j                  r&t        j                  g       }
|
| j                   k  rd}	|	r= |||fi |}H | j#                  |      }t%        |      S )	a  
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
        Nz:You must specify exactly one of input_ids or inputs_embeds).r.   rs   )r   r  rm   FT)last_hidden_state)r0   r  r  torM   r  r   ry   ro   ru   r   r   	enumerater   rO   randr  r  r   )rG   r)   rm   r  rp   	embed_posr   idxencoder_layerto_dropdropout_probabilitys              r:   rT   zMBartEncoder.forward
  s5   > -t";<YZZ  --i8M((w)?@	%	]5I5I(JJ00?--mt||VZVcVc-d2;;')
 #,DKK"8 	CG}}&+jjn#&7"G -!"! !	 6??r<   r   )rW   rX   rY   rZ   r   r%   r   _can_record_outputsr(   rF   r  r$   r&   rO   
LongTensorr\   FloatTensorr   r   r   r   rT   r]   r^   s   @r:   r   r     s     +$^1U
{ 81
   .2.226	@@##d*@@ t+@@ ((4/	@@
 +,@@ 
	 @@   @@r<   r   c                   B    e Zd ZdZe eedd       eedd      dZdef fdZ	e
e	 	 	 	 	 	 	 dd
ej                  d	z  dej                  d	z  dej                  d	z  dej                  d	z  ded	z  dej                  d	z  ded	z  dee   deez  fd              Z xZS )MBartDecoderz
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]

    Args:
        config: MBartConfig
        embed_tokens (nn.Embedding): output embedding
    r'   r   r  r   )r   r  cross_attentionsr   c           	         t         |   |       |j                  | _        |j                  | _        |j
                  | _        |j                  | _        |j                  rt        j                  |j                        nd}t        |j                  |j                  | j                  |      | _        t!        |j                  |j                        | _        t%        j&                  t)        |j*                        D cg c]  }t-        ||       c}      | _        || _        t%        j2                  |j                        | _        t%        j2                  |j                        | _        d| _        | j;                          y c c}w )Nrf   r  )r   F)rE   rF   ro   decoder_layerdropr  r*   ra   r	  max_target_positionsr  r  r  r   r`   r  r  r>   r  r   r  r  decoder_layersr   r   r   r   r  r  r  r  )rG   r   rb   irH   s       r:   rF   zMBartDecoder.__init__^  s     ~~11!..$*$B$B!393I3Idii/s4v~~t/?/?[
  ?**NN 
 mmUZ[a[p[pUq$rPQ%6v%K$rs#%<<#? ,,v~~6&+# %ss   ?FNr)   rm   r   r   r   r  r   rp   r   c                    |du |duz  rt        d      || j                  |      }|rd|b|| j                  j                  r4t	        t        | j                        t        | j                              nt        | j                        }|j                         dd \  }	}
||j                         nd}t        j                  |
|j                        |z   }|1t               s'||
z   }t        j                  |	||j                        }t        |t              r|j                  n|}t        | j                  |||      }t!        | j                  |||      }| j#                  t$        ||	      }||j'                  |j                        z   }| j)                  |      }t*        j,                  j/                  || j.                  | j0                  
      }t3        | j4                        D ]E  \  }}| j0                  r%t        j6                  g       }|| j8                  k  r7 ||||f|||d|}G | j;                  |      }t=        ||      S )a(  
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
        NzTYou cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time)r   r.   r   r   )r   r  rm   r   )r   r  rm   r   )rJ   rs   )r   r   r   )r  r   )r0   r  r   is_encoder_decoderr   r   rv   get_seq_lengthrO   rP   rM   r!   onesr   r   r   r   r  inputr   r  r   ry   ro   ru   r!  r   r"  r  r  r   )rG   r)   rm   r   r   r   r  r   rp   
batch_size
seq_lengthrI   rJ   mask_seq_lengthself_attn_cachecausal_maskr   r$  decoder_layerr'  s                       r:   rT   zMBartDecoder.forwardx  sZ   n -t";<stt  --i8M 0 )48V8V $L$DlZ^ZeZeFfg!5  "/!3!3!5cr!:
JETE`!?!?!Afg||J}7K7KLOee!*B*D4zAO"ZZ
OML`L`aN /+>? 00  	 );;')+	
 ";;;'1"7	"
 ++E3IXd+e%8L8L(MM00?--mt||VZVcVc-d"+DKK"8 	C}}&+jjn#&7)% (> /# M	" 68++
 	
r<   )NNNNNNN)rW   rX   rY   rZ   r   r%   r   r(  r(   rF   r$   r&   rO   r)  r\   r*  r   r   r   r   r   r   rT   r]   r^   s   @r:   r,  r,  O  s    +$^1U*>~^{ 4   .2.2:>:>(,26!%}
##d*}
 t+}
  %0047	}

 !& 0 04 7}
 }
 ((4/}
 $;}
 +,}
 
:	:}
   }
r<   r,  c                       e Zd ZdddZdef fdZd Zd Zee		 	 	 	 	 	 	 	 	 	 dde
j                  dz  d	e
j                  dz  d
e
j                  dz  de
j                  dz  deee
j                        dz  dedz  de
j                  dz  de
j                  dz  dedz  dedz  dee   deee
j                     z  fd              Z xZS )
MBartModelzshared.weight)zdecoder.embed_tokens.weightzencoder.embed_tokens.weightr   c                 J   t         |   |       |j                  |j                  }}|j                  rt        j                  |j                        nd}t        ||j                  ||      | _	        t        |      | _        t        |      | _        | j                          y )Nrf   r  )rE   rF   r*   r  r  r  r  r   r`   sharedr   encoderr,  decoderr  )rG   r   ra   r  rb   rH   s        r:   rF   zMBartModel.__init__  s|     "("5"5v7H7HZ393I3Idii/s.z6>>;dop#F+#F+ 	r<   c                     | j                   S rd   )rA  r  s    r:   get_input_embeddingszMBartModel.get_input_embeddings  s    {{r<   c                 ~    || _         | j                   | j                  _        | j                   | j                  _        y rd   )rA  rB  r  rC  rG   rl   s     r:   set_input_embeddingszMBartModel.set_input_embeddings  s)    $(KK!$(KK!r<   Nr)   rm   decoder_input_idsdecoder_attention_maskencoder_outputsr   r  decoder_inputs_embedsr   return_dictrp   r   c                 T   |
|
n| j                   j                  }
|"| t        || j                   j                        }| | j                  d	||||
d|}nI|
rGt        |t              s7t        |d   t        |      dkD  r|d   ndt        |      dkD  r|d   nd      } | j                  d	|||d   ||||	|
d|}|
s||z   S t        |j                  |j                  |j                  |j                  |j                  |j                  |j                  |j                        S )
a5  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
            varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        N)r)   rm   r  rM  r   r'   rC   )r  r   r  )r)   rm   r   r   r   r  r   rM  )r  r   decoder_hidden_statesdecoder_attentionsr-  encoder_last_hidden_stater   encoder_attentionsr   )r   use_return_dictr;   r*   rB  r   r   lenrC  r   r  r   r   r  r-  )rG   r)   rm   rI  rJ  rK  r   r  rL  r   rM  rp   decoder_outputss                r:   rT   zMBartModel.forward  sa   J &1%<k$++B]B] $)>)F 29dkk>V>V W"*dll #-+'	
 O O_!M-"1!"4474H14Loa0RV14_1E1I?1-tO '$,, 

'1"1!"4#1+/#

 

 "_44!-??+;;"1"?"?.99,==&5&G&G"1"?"?.99	
 		
r<   
NNNNNNNNNN)rW   rX   rY   _tied_weights_keysr(   rF   rE  rH  r   r   rO   r)  r\   r   r*  r   r   r   r   r   rT   r]   r^   s   @r:   r?  r?    s`    (7'6
{ 0
  .2.259:>BF(,26:>!%#'S
##d*S
 t+S
 !++d2	S

 !& 0 04 7S
 uU%6%6784?S
 S
 ((4/S
  %0047S
 $;S
 D[S
 +,S
 
eE$5$56	6S
  S
r<   r?  z
    The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.
    )custom_introc                       e Zd ZdZdgZddiZdef fdZ	 dded	edz  d
e	de
j                  f fdZdeddfdZe	 	 	 	 	 	 	 	 	 	 	 ddej                   dz  dej"                  dz  dej                   dz  dej                   dz  deeej&                        dz  dedz  dej&                  dz  dej&                  dz  dej                   dz  de	dz  de	dz  dee   deeej&                     z  fd       Zdej"                  fdZ xZS )r   r   r   lm_head.weightzmodel.shared.weightr   c                 x   t         |   |       t        |      | _        | j	                  dt        j                  d| j                  j                  j                  f             t        j                  |j                  | j                  j                  j                  d      | _        | j                          y )Nr   r'   Fr   )rE   rF   r?  r   register_bufferrO   zerosrA  r?   r   r   r   lm_headr  r   s     r:   rF   z&MBartForConditionalGeneration.__init__x  s     '
0%++q$**BSBSBbBb>c2deyy1B1B1Q1QX]^ 	r<   Nnew_num_tokenspad_to_multiple_ofmean_resizingr   c                 z    t         |   |||      }| j                  |j                  j                  d          |S )Nr   )rE   resize_token_embeddings_resize_final_logits_biasrR   rN   )rG   r_  r`  ra  new_embeddingsrH   s        r:   rc  z5MBartForConditionalGeneration.resize_token_embeddings  s?     8I[]jk&&~'<'<'B'B1'EFr<   c                 6   | j                   j                  d   }||k  r| j                   d d d |f   }nSt        j                  d||z
  f| j                   j                        }t        j
                  | j                   |gd      }| j                  d|       y )Nr.   r'   r   r,   r   )r   rN   rO   r]  rM   catr\  )rG   r_  old_num_tokensnew_bias
extra_biass        r:   rd  z7MBartForConditionalGeneration._resize_final_logits_bias  s    //55b9^+--a..@AHa.)H%IRVRhRhRoRopJyy$"8"8*!E1MH0(;r<   r)   rm   rI  rJ  rK  r   r  rL  labelsr   rM  rp   c                    ||n| j                   j                  }|	=|
rt        j                  d       d}
|"| t	        |	| j                   j
                        } | j                  |f||||||||
|d	|}| j                  |d         | j                  z   }d}|	Ft               } ||j                  d| j                   j                        |	j                  d            }|s|f|dd z   }||f|z   S |S t        |||j                  |j                  |j                  |j                   |j"                  |j$                  |j&                  	      S )	u6  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
            varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example Translation:

        ```python
        >>> from transformers import AutoTokenizer, MBartForConditionalGeneration

        >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")

        >>> example_english_phrase = "42 is the answer"
        >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")

        >>> # Translate
        >>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5)
        >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        '42 este răspuns'
        ```

        Mask filling example:

        ```python
        >>> from transformers import AutoTokenizer, MBartForConditionalGeneration

        >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

        >>> # de_DE is the language symbol id <LID> for German
        >>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"

        >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt")["input_ids"]
        >>> logits = model(input_ids).logits

        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
        >>> probs = logits[0, masked_index].softmax(dim=0)
        >>> values, predictions = probs.topk(5)

        >>> tokenizer.decode(predictions).split()
        ['nett', 'sehr', 'ganz', 'nicht', 'so']
        ```
        NzJThe `use_cache` argument is changed to `False` since `labels` is provided.F)	rm   rI  rK  rJ  r   r  rL  r   rM  r   r.   r'   	losslogitsr   rO  rP  r-  rQ  r   rR  )r   rM  r   warningr;   r*   r   r^  r   r   r   r  r   r   rO  rP  r-  rQ  r   rR  )rG   r)   rm   rI  rJ  rK  r   r  rL  rk  r   rM  rp   outputs	lm_logitsmasked_lm_lossloss_fctoutputs                     r:   rT   z%MBartForConditionalGeneration.forward  sv   ` &1%<k$++BYBYklI (-B-J$6vt{{?W?W$X!$**
)/+#9+'"7#
 
 LL,t/E/EE	')H%innR9O9O&PRXR]R]^`RabN\GABK/F3A3M^%.YSYY#33")"?"?&99$55&-&G&G")"?"?&99

 
	
r<   c                 B    t        || j                  j                        S rd   )r;   r   r*   )rG   rk  s     r:   %prepare_decoder_input_ids_from_labelszCMBartForConditionalGeneration.prepare_decoder_input_ids_from_labels  s    !&$++*B*BCCr<   )NT)NNNNNNNNNNN)rW   rX   rY   r   _keys_to_ignore_on_load_missingrW  r(   rF   r[   r   r   	Embeddingrc  rd  r   rO   r)  r\   r   r*  r   r   r   r   rT   rw  r]   r^   s   @r:   r   r   n  s     ':&;#*,AB{  ae!7:TzY]	< < <  .2.259:>BF(,26:>*.!%#'z
##d*z
 t+z
 !++d2	z

 !& 0 04 7z
 uU%6%6784?z
 z
 ((4/z
  %0047z
   4'z
 $;z
 D[z
 +,z
 
5!2!23	3z
 z
xDELL Dr<   r   z
    MBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    c                   h    e Zd Zdef fdZee	 	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  dej                  dz  dej                  dz  de
ej                     dz  d	ej                  dz  d
ej                  dz  dej                  dz  dedz  dee   deez  fd              Z xZS )MBartForSequenceClassificationr   c                     t        |   |fi | t        |      | _        t	        |j
                  |j
                  |j                  |j                        | _        | j                          y rd   )
rE   rF   r?  r   r   r   
num_labelsclassifier_dropoutclassification_headr  )rG   r   rp   rH   s      r:   rF   z'MBartForSequenceClassification.__init__  sZ    *6*'
#:NNNN%%	$
  	r<   Nr)   rm   rI  rJ  rK  r  rL  rk  r   rp   r   c
                    |d}	|$|"t        d| j                  j                          | j                  |f|||||||	d|
}|d   }|j	                  | j
                  j                        j                  |j                        }t        t        j                  |j                  d            j                         dk(  d       ||ddf   j                  |j                  d      d|j                  d            dddddf   }| j!                  |      }d}||j                  |j                        }| j
                  j"                  | j
                  j$                  dk(  rd	| j
                  _        nv| j
                  j$                  dkD  rL|j&                  t        j(                  k(  s|j&                  t        j*                  k(  rd
| j
                  _        nd| j
                  _        | j
                  j"                  d	k(  rSt-               }| j
                  j$                  dk(  r& ||j/                         |j/                               }n |||      }n| j
                  j"                  d
k(  rGt1               } ||j                  d| j
                  j$                        |j                  d            }n,| j
                  j"                  dk(  rt3               } |||      }t5        |||j6                  |j8                  |j:                  |j<                  |j>                  |j@                  |jB                  	      S )a  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        NFz8Passing input embeddings is currently not supported for rm   rI  rJ  rK  r  rL  r   r   r'   z7All examples must have the same number of <eos> tokens.r.   
regressionsingle_label_classificationmulti_label_classificationrm  )"NotImplementedErrorrH   rW   r   eqr   eos_token_idr   rM   r#   rO   unique_consecutiver3   numelr   rv   r  problem_typer}  rL   rQ   r[   r   r6   r   r   r   r   rO  rP  r-  rQ  r   rR  )rG   r)   rm   rI  rJ  rK  r  rL  rk  r   rp   rq  r   eos_masksentence_representationro  rn  rt  s                     r:   rT   z&MBartForSequenceClassification.forward&  s   T I!:%J4>>KbKbJcd  '1djj
'
)/#9+'"7
'
 
'
  
<< 8 89<<]=Q=QR$$X\\!_5;;=BE	
 #0!"<"A"A-BTBTUVBWY[]j]o]opr]s"tr1H#
 ))*ABYYv}}-F{{''/;;))Q./;DKK,[[++a/V\\UZZ5OSYS_S_chclclSl/LDKK,/KDKK,{{''<7"9;;))Q.#FNN$4fnn6FGD#FF3D))-JJ+-B0F0F GUWY))-II,./.#33")"?"?&99$55&-&G&G")"?"?&99

 
	
r<   )	NNNNNNNNN)rW   rX   rY   r(   rF   r   r   rO   r)  r\   listr*  r   r   r   r   r   rT   r]   r^   s   @r:   r{  r{    s,   {   .2.259:>:>26:>*.!%i
##d*i
 t+i
 !++d2	i

 !& 0 04 7i
 e//047i
 ((4/i
  %0047i
   4'i
 $;i
 +,i
 
0	0i
  i
r<   r{  c                       e Zd Z fdZee	 	 	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  dej                  dz  dej                  dz  de	ej                     dz  dej                  dz  d	ej                  dz  d
ej                  dz  dej                  dz  dedz  dee   deez  fd              Z xZS )MBartForQuestionAnsweringc                     t         |   |       d|_        |j                  | _        t        |      | _        t        j                  |j                  |j                        | _        | j                          y rB   )
rE   rF   r}  r?  r   r   r   hidden_size
qa_outputsr  r   s     r:   rF   z"MBartForQuestionAnswering.__init__  s[      ++'
))F$6$68I8IJ 	r<   Nr)   rm   rI  rJ  rK  start_positionsend_positionsr  rL  r   rp   r   c                 D   ||d}
 | j                   |f||||||	|
d|}|d   }| j                  |      }|j                  dd      \  }}|j                  d      j	                         }|j                  d      j	                         }d}||t        |j                               dkD  r|j                  d      }t        |j                               dkD  r|j                  d      }|j                  d      }|j                  d|      }|j                  d|      }t        |      } |||      } |||      }||z   d	z  }t        ||||j                  |j                  |j                  |j                  |j                  |j                  |j                   

      S )a  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        NFr  r   r'   r.   r,   )ignore_indexrC   )
rn  start_logits
end_logitsr   rO  rP  r-  rQ  r   rR  )r   r  splitr6   r{   rT  rv   r   r   r   r   rO  rP  r-  rQ  r   rR  )rG   r)   rm   rI  rJ  rK  r  r  r  rL  r   rp   rq  sequence_outputro  r  r  
total_lossignored_indexrt  
start_lossend_losss                         r:   rT   z!MBartForQuestionAnswering.forward  s   P &=+DI&0djj
'
)/#9+'"7
'
 
'
 "!*1#)<<r<#: j#++B/::<''+668

&=+D?'')*Q."1"9"9""==%%'(1, - 5 5b 9(--a0M-33A}EO)//=AM']CH!,@J
M:H$x/14J2%!#33")"?"?&99$55&-&G&G")"?"?&99
 	
r<   rV  )rW   rX   rY   rF   r   r   rO   r\   r)  r  r*  r   r   r   r   r   rT   r]   r^   s   @r:   r  r    s<   
  *..259:>:>371526:>!%W
<<$&W
 t+W
 !++d2	W

 !& 0 04 7W
 e//047W
 ))D0W
 ''$.W
 ((4/W
  %0047W
 $;W
 +,W
 
4	4W
  W
r<   r  c                   (     e Zd ZdZ fdZd Z xZS )MBartDecoderWrapperz
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    c                 d    t         |   |       t        |      | _        | j	                          y rd   )rE   rF   r,  rC  r  r   s     r:   rF   zMBartDecoderWrapper.__init__  s&     #F+r<   c                 &     | j                   |i |S rd   )rC  )rG   argsrp   s      r:   rT   zMBartDecoderWrapper.forward  s    t||T,V,,r<   )rW   rX   rY   rZ   rF   rT   r]   r^   s   @r:   r  r    s    

-r<   r  c                   \    e Zd ZddiZ fdZd Zd Zee	 	 	 	 	 	 	 	 	 dde	j                  dz  de	j                  dz  d	e	j                  dz  d
e	j                  dz  dedz  de	j                  dz  de	j                  dz  dedz  dee	j                  z  dee   deez  fd              Z xZS )MBartForCausalLMrZ  z!model.decoder.embed_tokens.weightc                     d|_         d|_        t        |   |       t	        |      | _        t        j                  |j                  |j                  d      | _
        | j                          y )NTFr   )r   r4  rE   rF   r  r   r   r   r  r  r^  r  r   s     r:   rF   zMBartForCausalLM.__init__  sX     $)! (0
yy!3!3V5F5FUS 	r<   c                 B    | j                   j                  j                  S rd   r   rC  r  r  s    r:   rE  z%MBartForCausalLM.get_input_embeddings!  s    zz!!...r<   c                 :    || j                   j                  _        y rd   r  rG  s     r:   rH  z%MBartForCausalLM.set_input_embeddings$  s    */

'r<   Nr)   rm   r   r   r   r  rk  r   logits_to_keeprp   r   c
                     | j                   j                  d|||||||d|
}|d   }t        |	t              rt	        |	 d      n|	}| j                  |dd|ddf         }d}|a|j                  |j                        }t               } ||j                  d| j                  j                        |j                  d            }t        |||j                  |j                  |j                  |j                         S )aP  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MBartForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
        >>> model = MBartForCausalLM.from_pretrained("facebook/mbart-large-cc25")
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```)r)   rm   r   r   r   r  r   r   Nr.   )rn  ro  r   r   r  r-  r   )r   rC  r   r[   slicer^  r   rM   r   r   r   r  r   r   r   r  r-  )rG   r)   rm   r   r   r   r  rk  r   r  rp   rq  r   slice_indicesro  rn  rt  s                    r:   rT   zMBartForCausalLM.forward'  s   L >PTZZ=O=O 	>
)"7#9+'	>
 	>
  
8B>SV8W~ot4]kmA}a,?@AYYv}}-F')HFKKDKK,B,BCV[[QS_UD0#33!//))$55
 	
r<   )	NNNNNNNNr   )rW   rX   rY   rW  rF   rE  rH  r   r   rO   r)  r\   r*  r   r   r[   r   r   r   r   rT   r]   r^   s   @r:   r  r    s/   =	/0  .2.2:>;?(,26*.!%-.A
##d*A
 t+A
  %0047	A

 !& 1 1D 8A
 A
 ((4/A
   4'A
 $;A
 ell*A
 +,A
 
2	2A
  A
r<   r  )r  r   r  r{  r?  r   )Nr   )RrZ   r  collections.abcr   rO   r   torch.nnr   r   r    r	   r   activationsr
   cache_utilsr   r   r   
generationr   masking_utilsr   r   modeling_flash_attention_utilsr   modeling_layersr   modeling_outputsr   r   r   r   r   r   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   r    r!   r"   r#   utils.genericr$   utils.output_capturingr%   r&   configuration_mbartr(   
get_loggerrW   r   r\   r[   r;   ry  r>   r`   Modulerh   r~   r   r   r   r   r   r   r,  r?  r   r{  r  r  r  __all__r   r<   r:   <module>r     sy     $   A A & ! C C ) J :   G &   8 E ,  ! 
		H	%%,, c *;bll ;8
=r|| 
=( !%II%<<% 
% <<	%
 LL4'% T\% % '(%:r)RYY r)j52 5pZ2 Z|bii 0 ?  4r@' r@jh
' h
V p
% p
 p
f 
\D$8/ \D
\D~ z
%9 z
z
z g
 4 g
 g
V-. - Y
+_ Y
xr<   