o
    ߥi9                     @   s  d dl Z d dlZd dlZd dlmZ d dlmZmZmZm	Z	m
Z
 d dlZd dlZd dlm  mZ d dlmZmZ d dlmZ d dlmZmZmZmZmZmZ d dlmZ d dlmZ d d	l m!Z" d
dl#m$Z$ dZ%dZ&G dd dej'Z(G dd dej'Z)G dd dej'Z*G dd dej'Z+G dd dZ,G dd dej'Z-G dd dej'Z.G dd deZ/G dd de/Z0G d d! d!ej'Z1G d"d# d#ej'Z2G d$d% d%e/Z3dS )&    N)	dataclass)AnyDictListOptionalUnion)Tensornn)xavier_uniform_)
BertConfig	BertModelBertTokenizerRobertaConfigRobertaModelRobertaTokenizer)ACT2FN)PreTrainedModel)logger   )
PlugConfigzconfig.jsonzpytorch_model.binc                       s:   e Zd ZdZ		d
 fdd	Z					ddd	Z  ZS )MultiHeadedAttentiona  
    Multi-Head Attention module from
    "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simulataneously
    to select relevant items.

    .. mermaid::

       graph BT
          A[key]
          B[value]
          C[query]
          O[output]
          subgraph Attn
            D[Attn 1]
            E[Attn 2]
            F[Attn N]
          end
          A --> D
          C --> D
          A --> E
          C --> E
          A --> F
          C --> F
          D --> O
          E --> O
          F --> O
          B --> O

    Also includes several additional tricks.

    Args:
       head_count (int): number of parallel heads
       model_dim (int): the dimension of keys/values/queries,
           must be divisible by head_count
       dropout (float): dropout parameter
    皙?Tc                    s   || dksJ || | _ || _t   || _t||| j  | _t||| j  | _t||| j  | _	tj
dd| _t|| _|| _| jrRt||| _d S d S )Nr   dim)dim_per_head	model_dimsuper__init__
head_countr	   Linearlinear_keyslinear_valueslinear_querySoftmaxsoftmaxDropoutdropoutuse_final_linearfinal_linear)selfr   r   r'   r(   	__class__ [/home/ubuntu/.local/lib/python3.10/site-packages/modelscope/models/nlp/fid_plug/backbone.pyr   R   s$   

zMultiHeadedAttention.__init__NFc	                    s  | d | j| j fdd}	 fdd}
|dur|dkrp| || || |}}}|	|}|	|}|j}|d durStj|d 	||fd	d
}|d durgtj|d 	||fd	d
}||d< ||d< nL|dkr| |}|d du r| || |}}|	|}|	|}n	|d |d }}||d< ||d< n| |}| |}| |}|	|}|	|}|	|}|t
 }t||d	d}|dur|d|}||td}| |}|dur|dddf | }|t|d	d	d  }t|ddddf |dgd}| |}| jr8|
t||}| |}|r6||fS |S t||}|rE||fS |S )a  
        Compute the context vector and the attention vectors.

        Args:
           key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
           value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
           query (`FloatTensor`): set of `query_len`
                 query vectors  `[batch, query_len, dim]`
           mask: binary mask indicating which keys have
                 non-zero attention `[batch, query_len, key_len]`
        Returns:
           (`FloatTensor`, `FloatTensor`) :

           * output context vectors `[batch, query_len, dim]`
           * one of the attention vectors `[batch, query_len, key_len]`
        r   c                    s   |   dddS )z  projection r   r      )view	transposex
batch_sizer   r   r-   r.   shape   s   z+MultiHeadedAttention.forward.<locals>.shapec                    s   |  dd  d S )z  compute context r   r/   r   )r1   
contiguousr0   r2   r4   r-   r.   unshape   s   z-MultiHeadedAttention.forward.<locals>.unshapeNr*   	self_keysr/   r   self_valuescontextmemory_keysmemory_values   r   -infr   g&.>)sizer   r   r#   r!   r"   devicetorchcattomathsqrtmatmulr1   	unsqueeze	expand_asmasked_fillfloatr%   sumr'   r(   r)   )r*   keyvaluequerymasklayer_cachetypepredefined_graph_1return_attnr6   r8   rA   scoresattnattn_masked	drop_attnr;   outputr-   r4   r.   forwardi   s   









&

zMultiHeadedAttention.forward)r   T)NNNNF__name__
__module____qualname____doc__r   rZ   __classcell__r-   r-   r+   r.   r   (   s    ,r   c                       *   e Zd ZdZd fdd	Zdd Z  ZS )PositionwiseFeedForwarda*   A two-layer Feed-Forward-Network with residual layer norm.

    Args:
        d_model (int): the size of input for the first-layer of the FFN.
        d_ff (int): the hidden layer size of the second-layer
            of the FNN.
        dropout (float): dropout probability in :math:`[0, 1)`.
    r   c                    s\   t    tj|dd| _t||| _td | _t	|| _
t||| _t	|| _d S )Nư>epsgelu_new)r   r   r	   	LayerNorm
layer_normr    w_1r   actvr&   	dropout_1w_2	dropout_2)r*   d_modeld_ffr'   r+   r-   r.   r      s   

z PositionwiseFeedForward.__init__c              	   C   s4   |  | | | |}| | |}|| S N)rk   rj   ri   rh   rm   rl   )r*   r3   interrY   r-   r-   r.   rZ      s   zPositionwiseFeedForward.forward)r   r[   r-   r-   r+   r.   rb      s    		rb   c                       s<   e Zd ZdZdZ fddZ			d
ddZdd	 Z  ZS )TransformerDecoderLayera  
    Args:
      d_model (int): the dimension of keys/values/queries in
                       MultiHeadedAttention, also the input size of
                       the first-layer of the PositionwiseFeedForward.
      heads (int): the number of heads for MultiHeadedAttention.
      d_ff (int): the second-layer of the PositionwiseFeedForward.
      dropout (float): dropout probability(0-1.0).
      self_attn_type (string): type of self-attention scaled-dot, average
      c                    s   t    t|||d| _t|||d| _t|||| _tj|dd| _	tj|dd| _
t|| _| | j}| d| d S )N)r'   rc   rd   rP   )r   r   r   	self_attncontext_attnrb   feed_forwardr	   rg   layer_norm_1layer_norm_2r&   drop_get_attn_subsequent_maskMAX_SIZEregister_buffer)r*   rn   headsro   r'   rP   r+   r-   r.   r     s   
z TransformerDecoderLayer.__init__Nc              
   C   s   t |t j| jddd|dd|df t j d}| |}	|	}
|dur8t j||	fdd}
d}| j|
|
|	||dd}| 	|| }| 
|}| j|||||ddd	\}}| | 	|| }|||
fS )
a#  
        Args:
            inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
            memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
            src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
            tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`

        Returns:
            (`FloatTensor`, `FloatTensor`, `FloatTensor`):

            * output `[batch_size x 1 x model_dim]`
            * attn `[batch_size x 1 x src_len]`
            * all_input `[batch_size x current_step x model_dim]`

        Nr   r   r   r*   )rP   rQ   rR   r;   T)rP   rQ   rR   rT   )rB   gtrR   uint8rP   r@   rw   rC   rt   ry   rx   ru   rv   )r*   inputsmemory_banksrc_pad_masktgt_pad_maskprevious_inputrQ   stepdec_mask
input_norm	all_inputrO   
query_normmidrV   rY   r-   r-   r.   rZ     sD   
(



zTransformerDecoderLayer.forwardc                 C   s2   d||f}t jt |ddd}t|}|S )z
        Get an attention mask to avoid using the subsequent info.

        Args:
            size: int

        Returns:
            (`LongTensor`):

            * subsequent_mask `[1 x size x size]`
        r   )kr   )nptriuonesastyperB   
from_numpy)r*   r@   
attn_shapesubsequent_maskr-   r-   r.   rz   K  s   

z1TransformerDecoderLayer._get_attn_subsequent_mask)NNN)	r\   r]   r^   r_   r{   r   rZ   rz   r`   r-   r-   r+   r.   rr      s    

8rr   c                       s0   e Zd Zd	 fdd	Zd
ddZdd Z  ZS )PositionalEncodingrs   c                    s   t    t||}td|d}ttjd|dtjdt	d|   }t
| | |d d dd df< t| | |d d dd df< |d}| d| t|| _|| _d S )Nr   r   r/   )dtypeg     @pe)r   r   rB   zerosarangerH   exprK   rE   logsincosr|   r	   r&   r'   r   )r*   r'   r   max_lenr   positiondiv_termr+   r-   r.   r   _  s   
$$

zPositionalEncoding.__init__Nc                 C   sl   |t | j }|r|| jd d |f d d d d d f  }n|| jd d d |df  }| |}|S Nr   )rE   rF   r   r   r@   r'   )r*   embr   r-   r-   r.   rZ   l  s   * 
zPositionalEncoding.forwardc                 C   s   | j d d d |df S r   )r   r@   )r*   r   r-   r-   r.   get_embv  s   zPositionalEncoding.get_emb)rs   rp   )r\   r]   r^   r   rZ   r   r`   r-   r-   r+   r.   r   ]  s    

r   c                   @   s8   e Zd ZddedefddZdd Zdd	 Zd
d ZdS )TransformerDecoderStater   srccache_num_layersc                 C   s2   || _ d | _d | _d | _|dkr| | d S d S Nr   )r   r   previous_layer_inputscache_init_cache)r*   r   r   r-   r-   r.   r   |  s   z TransformerDecoderState.__init__c                 C   s   || _ || _d | _d S rp   )r   r   r   )r*   	new_inputr   r-   r-   r.   update_state  s   
z$TransformerDecoderState.update_statec                 C   sB   i | _ t|D ]}d d d}d |d< d |d< || j d|< qd S )N)r<   r=   r9   r:   layer_{})r   rangeformat)r*   
num_layerslayerrQ   r-   r-   r.   r     s   
z#TransformerDecoderState._init_cachec                    s:   d fdd	 | j d| _ | jd ur | j d S d S )Nr   c                    s@   |   D ]\}}|d urt|tr | q||| |< qd S rp   )items
isinstancedict)struct	batch_dimr   v_recursive_mapfnr-   r.   r     s   

z<TransformerDecoderState.map_batch_fn.<locals>._recursive_mapr   )r   r   )r*   r   r-   r   r.   map_batch_fn  s
   
z$TransformerDecoderState.map_batch_fnN)r   )	r\   r]   r^   r   intr   r   r   r   r-   r-   r-   r.   r   z  s
    r   c                       sH   e Zd ZdZdZ fddZ		ddededed	ed
ef
ddZ	  Z
S )TransformerDecodera  
    The Transformer decoder from "Attention is All You Need".


    .. mermaid::

       graph BT
          A[input]
          B[multi-head self-attn]
          BB[multi-head src-attn]
          C[feed forward]
          O[output]
          A --> B
          B --> BB
          BB --> C
          C --> O


    Args:
       num_layers (int): number of encoder layers.
       d_model (int): size of the model
       heads (int): number of heads
       d_ff (int): size of the inner FF layer
       dropout (float): dropout parameters
       embeddings (:obj:`onmt.modules.Embeddings`):
          embeddings to use, should have positional encodings
       attn_type (str): if using a seperate copy attention
    transformerc                    sd   t    || _|| _t| jj| _t fddt	|D | _
tjdd| _d | _d S )Nc                    s   g | ]	}t  qS r-   )rr   .0_ro   rn   r'   r}   r-   r.   
<listcomp>  s    z/TransformerDecoder.__init__.<locals>.<listcomp>rc   rd   )r   r   r   
embeddingsr   embedding_dimpos_embr	   
ModuleListr   transformer_layersrg   rh   state)r*   r   rn   r}   ro   r'   r   r+   r   r.   r     s   


zTransformerDecoder.__init__Nr   tgtr   r   memory_masksc                 C   s  |j }|}| \}}	| \}
}| |}| dksJ | ||}|}| jj}|j|d	|
||}|d urI|d}	|	|||	}n|j|d	|||	}|j
d u r^g }g }t| jD ]@}d }|j
d u rx|jd urx|j| }| j| ||||||j
d ur|j
d| nd |d\}}}|j
d u r|| || qe|j
d u rt|}| |}|j
d u r||| |||fS )Nr>   r   r   r   )r   rQ   r   )r   r@   r   r   r   padding_idxdataeqrH   expandr   r   r   r   r   r   r   appendrB   stackrh   r   )r*   r   r   r   r   r   	src_words	tgt_words	src_batchsrc_len	tgt_batchtgt_lenr   rY   src_memory_bankr   r   r   saved_inputsattnsiprev_layer_inputrV   r   r-   r-   r.   rZ     sZ   















zTransformerDecoder.forwardNN)r\   r]   r^   r_   decoder_typer   r   r   r   rZ   r`   r-   r-   r+   r.   r     s"    r   c                       s$   e Zd Z fddZdd Z  ZS )PlugPointerGeneratorc                    s(   t    t||| _td| _d S r   )r   r   r	   r    dense
LogSoftmaxgen_func)r*   hidden_size
vocab_sizer+   r-   r.   r     s   
zPlugPointerGenerator.__init__c                 C   s   |  |}| |}|S rp   )r   r   )r*   r3   r-   r-   r.   rZ     s   

zPlugPointerGenerator.forwardr\   r]   r^   r   rZ   r`   r-   r-   r+   r.   r     s    r   c                   @   s8   e Zd ZdZeZdZedee	e
ejf  fddZdS )PlugPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    plugpretrained_model_name_or_pathc                 C   sn   t j|t}t j|rt|nt }t j||j|_t j|t}t j|r0t	
|nd }| ||S rp   )ospathjoinCONFIG_NAMEisfiler   from_json_fileencoder_pthWEIGHTS_NAMErB   load)clsr   config_fileconfigcheckpoint_file
checkpointr-   r-   r.   from_pretrained#  s$   
z#PlugPreTrainedModel.from_pretrainedN)r\   r]   r^   r_   r   config_classbase_model_prefixclassmethodr   r   strr   PathLiker   r-   r-   r-   r.   r     s    r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )	PlugModelNc           	         s<  t  | || _|jdks|jdkrtt|j| _n|jdkr+t	t
|j| _|jdkrjt|j| jjjj}| jjjjj|jjd d< | jjjjjd d d d f |jd d|jjdd < || jjj_| jjj| _tj| j| jjj|jdkrdndd}|jrt| jjjjj|_t|j|j|j|j|j|d	| _ t!|j| j| _"| j jj| j"j#_|d urt$|d
 % D ]6}|&dr|d
 | |d
 |'dd< |d
 (| |&dr|d
 | |d
 |'dd< |d
 (| q| j)|d
 dd}t*| d S | j + D ]>}t,|tj-tjfr|jjj.ddd nt,|tj/r3|j0j1  |jj2d t,|tj-rF|j0d urF|j0j1  q	| j"3 D ]}|4 dkr\t5| qM|j1  qM|j6r|jdkrztj| j| jjjdd}ntj| j| jjjdd}t| jjjj|_|| j _| j jj| j"j#_d S )Nbertzh_bertrobertai   r   r   r   )r   )r}   ro   r'   r   modelzmodule. zplug.F)strict        g{Gz?)meanstd      ?)7r   r   r   encoderr   r   r   r   r   r   r   max_posr	   	Embeddingr   r   r   position_embeddingsweightr   repeatr   	share_embcopydeepcopyword_embeddingsr   
dec_layersdec_hidden_size	dec_headsdec_ff_sizedec_dropoutdecoderr   	generatorr   listkeys
startswithreplacepopload_state_dictprintmodulesr   r    normal_rg   biaszero_fill_
parametersr   r
   use_bert_emb)	r*   r   r   my_pos_embeddingstgt_embeddingsrM   msgmodulepr+   r-   r.   r   5  s   












zPlugModel.__init__c           
      C   sP   | j |||dd\}}t|}| ||d d d df |\}}	}||	d |fS )NFtoken_type_idsreturn_dictr   )r   r   r  )
r*   r   r   mask_srcr,  top_vecr   r   decoder_outputsr   r-   r-   r.   rZ     s   
$zPlugModel.forwardrp   r   r-   r-   r+   r.   r   3  s    Qr   c                       ra   )LabelSmoothingLossz
    With label smoothing,
    KL-divergence between q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.
    c                    st   d|  k rdksJ  J || _ tt|   ||d  }t|f|}d|| j < | d|d d| | _d S )Nr  r  r/   r   one_hot)	r   r   r1  r   rB   fullr|   rH   
confidence)r*   label_smoothingtgt_vocab_sizeignore_indexsmoothing_valuer3  r+   r-   r.   r     s   
zLabelSmoothingLoss.__init__c                 C   sR   | j |dd}|d|d| j ||| jkdd tj	||ddS )zf
        output (FloatTensor): batch_size x n_classes
        target (LongTensor): batch_size
        r   r   rL   )	reduction)
r3  r  r@   scatter_rH   r5  masked_fill_r   Fkl_div)r*   rY   target
model_probr-   r-   r.   rZ     s   zLabelSmoothingLoss.forward)r2  r[   r-   r-   r+   r.   r1    s    r1  c                       s:   e Zd ZdZd fdd	Zdd Zdd Zd	d
 Z  ZS )NMTLossComputez(
    Standard NMT Loss Computation.
    r  c                    sN   t    || _|d | _|dkrt||| jd| _d S tj| jdd| _d S )NPADr   )r8  rL   )r8  r:  )r   r   r  r   r1  	criterionr	   NLLLoss)r*   r  symbolsr   r6  r+   r-   r.   r     s   

zNMTLossCompute.__init__c                 C   s   | d|dS )Nr   r/   r0   r@   )r*   _vr-   r-   r.   _bottle  s   zNMTLossCompute._bottlec                 C   s   | d||dS )Nr   r   rF  )r*   rG  r5   r-   r-   r.   	_unbottle  s   zNMTLossCompute._unbottlec                 C   s   |d d dd f }| d| d}}|| j }| |}| |}| d}	| ||	}
|
	t
|}
|
|||dfS )Nr   r   r   )r@   ner   rL   rH  r  r7   r0   rC  divrK   )r*   r   rY   r?  r5   decoder_lengthnormalizationbottled_outputrU   gtruthlossr-   r-   r.   rZ     s   

zNMTLossCompute.forward)r  )	r\   r]   r^   r_   r   rH  rI  rZ   r`   r-   r-   r+   r.   rA    s    rA  c                	       s   e Zd ZeG dd dZd(def fddZd)dd	Z	
d*dddefddZ	d+ddZ
dded dfddZ									
			d,dddedefdd Zd!d" Z		d)d#ejd$ejd%eeejf fd&d'Z  ZS )-PlugForConditionalGenerationc                   @   sr   e Zd ZU eed< ejed< ejed< ejed< ejed< dZed ed< dZ	eee
  ed< dZee
 ed	< dS )
z"PlugForConditionalGeneration.Batchr5   r   r   r.  r,  Nquery_idsrc_strtgt_str)r\   r]   r^   r   __annotations__rB   r   rR  r   rS  r   rT  r-   r-   r-   r.   Batch  s   
 



rV  Ndefaultdatasetc                    s   t  | t | _|| _|jdkr'tj|j	dd}|j
|j|j|jd}n%|jdks1|jdkrLtj|j	dd}|jd |jd	 |jd
 |jd d}|| _|| _t||| _t| jj|| jj|j| _|| j_| jd | _| jd | _d S )Nr   F)do_lower_case)BOSEOSrB  EOQr   r   Tz[CLS]z[SEP]z[PAD]z	[unused2]rZ  r[  )r   r   logging
get_loggerr   r   r  r   r   r   cls_token_idsep_token_idpad_token_idunk_token_idr   vocab	tokenizerrE  r   r   rA  r  r   r6  rP  rX  start_token	end_token)r*   r   r   rX  rd  rE  r+   r-   r.   r     s>   


z%PlugForConditionalGeneration.__init__c                 C   s@   |d u r| | jd  }| ||||d }| ||}|S )NrB  r   )rJ  rE  longr   rP  )r*   r   r   r.  r,  rY   rP  r-   r-   r.   rZ     s
   z$PlugForConditionalGeneration.forwardFbatchfastc                 O   sN   | j   t  | j|g|R i |W  d   S 1 s w   Y  dS )aq  
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
           batch (:obj:`Batch`): a batch from a dataset object
           data (:obj:`Dataset`): the dataset object
           fast (bool): enables fast beam search (may not support all features)

        Todo:
           Shouldn't need the original dataset.
        N)r   evalrB   no_grad_fast_translate_batch)r*   rh  ri  argskwargsr-   r-   r.   translate_batch  s   

$z,PlugForConditionalGeneration.translate_batchr   c                 C   s   t tt| }|dkr"|| |d |d< ||< || }t | }|d  |9  < |d}||ddd|ddd j| }|dkrW|| }|S )Nr   r   r   )	r  r   lenr@   permuter7   r0   r1   r  )r*   r3   countr   permout_sizerh  r-   r-   r.   _tile  s"   

z"PlugForConditionalGeneration._tile
   r  Infr   c                 C   s   |dkrt t|||d}|t||d d k }|||< |dk rgtj|dd\}}tjtj|dddd}	|	|k}
|dkrHd|
d	d |f< |
d	d df 	 |
d	dd f< d|
d
< |

d||
}|||< |S )Nr   r   ).r   Nr  T)
descendingr   r   .).r   )minmaxr@   rB   topksortcumsumr=  r%   clonescatter)r*   logitstop_ktop_pfilter_valuemin_tokens_to_keepindices_to_removesorted_logitssorted_indicescumulative_probssorted_indices_to_remover-   r-   r.   _top_k_top_p_filtering+  s2   

z3PlugForConditionalGeneration._top_k_top_p_filteringP   Tr>   333333?   
max_length
min_lengthc           H         s    |j }|j}|j}|j}jj|||dd\}}t|jjj}|j	}|
 fdd j| dd}tj|tj|d}tjd|   tj|d}tj|  d	gjtj|d}i }tg }|d ur|D ] }t|d d
 }|d
 }||g |g ||< |t| qhtjdgtdg d	   |d|} dd t|D }!i }"dd t|D |"d< dd t|D |"d< dg| |"d< ||"d< t|D ] }#|d d d
f d	d
}$|$dd	}$jj||$||#d\}%}&}jj|%dd	d}'|'d
}(|#|k rd|'d d j f< t|dkrs|d})g }*t|)D ]0}+g },|D ]!}-t||+|#d	 |- |#d	 f ! " # }.|,||.g 7 },q(|*$t|, q"|'d|)ks]J t|)D ]}+|*|+ D ]	}/d|'|+|/f< qgqa|dkr	 %||d|	|#d	 }0t|'dD ],}+t|0|+ D ]"}1|'|+|1f dk r|'|+|1f  |9  < q|'|+|1f  |  < qq|#d	 | }2|
r|'| }3j&|3||d	d}3tj't(j)|3d
dd	d}4t(j*|3d	d}3|3| d
+d	7 }3|3|2 }3t,|3d
|4}5|4d
 }4|5d
 }5n|'| d
+d	7 }'|'|2 }6|6-d
 |( }6|6j. d
d\}5}4j/j0r|d	}7|7dkrt|dD ]e}+d}8dd ||+ D j/j1dkrYj234 5 nfddD d67d d!5 tdkrvq:fd"dtd	td	 D }9t|9d
 }:|:|9d d
 v rd#}8|8rd$|6|+< q:|5|2 } |4|( };|48|(}4|;|d |;d +d	 }<|<d
t9|:d|4d
d	gd
}|4;j }=|#d	 |kr|=<j  |=d d df ;d	}>|== r|d
 |d
}?t|=dD ]}+||+ }@|>|+ r|=|+ <j  |=|+ > d
}A|AD ]%}B|!|@ $|5|+|Bf |?|+|Bd	d f f |rFt|! krFd#|>|+< q"|>|+ rt?|!|@ d%d d#d&}Cj/j@d'kskj/j@d(krj/jAs|Cd   D ]}D|D\}E}F|"d |@ $|E |"d |@ $|F qqq|Cd \}E}F|"d |@ $|E |"d |@ $|F q|>;d> d
}Gt|Gdkr |"S | :d|G} |<:d|G}<|:d|G}|?:d|Gd
|d
}|<d
|:d}|
fd)d q|"S )*NFr+  c                    s   j |  |dS )Nr   )ru  r   r   )	num_beamsr*   r-   r.   <lambda>n      zDPlugForConditionalGeneration._fast_translate_batch.<locals>.<lambda>r   r   )r   rA   )r   r   rA   r   r   r  r?   )rA   c                 S      g | ]}g qS r-   r-   r   r-   r-   r.   r     r  zFPlugForConditionalGeneration._fast_translate_batch.<locals>.<listcomp>c                 S   r  r-   r-   r   r-   r-   r.   r     r  predictionsc                 S   r  r-   r-   r   r-   r-   r.   r     r  rU   
gold_scorerh  )r   g@xr  )r  r  r  )num_samplesr>   c                 S   s   g | ]}t |qS r-   )r   r   wr-   r-   r.   r     s    r   c                    s   g | ]} j j| qS r-   )rd  ids_to_tokensr  )r*   r-   r.   r     s     z ##r  c                    s*   g | ]} |d    |  |d   fqS )r   r-   r   r   )wordsr-   r.   r     s    "TgPKc                 S   s   | d S )Nr   r-   r2   r-   r-   r.   r  0  s    )rM   reverseqg_ranking_test
paraphrasec                    s   |  | S rp   )index_selectr  )select_indicesr-   r.   r  K  s    )Br5   r   r.  r,  r   r   r   r  r   rA   r   ru  rB   r   rg  r4  re  settuplegetaddrp  tensorrK   r  r   r0   r1   r  rZ   squeezer@   rf  cpunumpytolistr   calc_banned_tokensr  multinomialr=  r%   log_softmaxrH   gatherreshaper{  r   block_trigramr  rd  decodestripsplitr   r  fmodrC   r  r   r#  anynonzerosortedrX  sample_topk)Hr*   rh  r  r  bad_words_idsearly_stoppingr  length_penaltyrepetition_penaltyno_repeat_ngram_size	do_sampletemperaturer  r  rm  rn  r5   r   r.  r,  src_featuresr   r   rA   batch_offsetbeam_offset	alive_seqbad_words_prefix_dictbad_words_prefix_lenbw_idrM   rN   topk_log_probs
hypothesesresultsr   decoder_inputdec_outr   	log_probsr   	num_hyposbad_word_banned_tokenr   curr_banned_tokenpre_lenpre_keybanned_tokenprev_output_tokensprevious_tokencurr_length_penalty_scorestopk_idstopk_scorescurr_scorescur_lenfailtrigramstrigramtopk_beam_indexbatch_indexis_finishedend_conditionr  bfinished_hypjbest_hypeachscoreprednon_finishedr-   )r  r  r*   r  r.   rl  N  s  





















z2PlugForConditionalGeneration._fast_translate_batchc           
         s   d k rdd t |D S dd t |D t |D ]5}|    | }tfddt D  D ]}t|d d }||g |d g ||< q9qfdd  fd	dt |D }	|	S )
Nr   c                 S   r  r-   r-   r   r-   r-   r.   r   T  r  zCPlugForConditionalGeneration.calc_banned_tokens.<locals>.<listcomp>c                 S   s   g | ]}i qS r-   r-   r   r-   r-   r.   r   U  r  c                    s   g | ]} |d  qS rp   r-   r  )
gen_tokensr-   r.   r   Z  s    r   c                    s<    d  }t | | f    }|  |g S r   )r  r  r  r  r  )hypo_idx	start_idx	ngram_idx)r  generated_ngramsr  prev_input_idsr-   r.   _get_generated_ngrams_  s   zNPlugForConditionalGeneration.calc_banned_tokens.<locals>._get_generated_ngramsc                    s   g | ]} |qS r-   r-   )r   r  )r  r-   r.   r   g  s    )r   r  r  r  zipr  r  )
r*   r  r  r  r  idxgenerated_ngramngramprev_ngram_tuplebanned_tokensr-   )r  r  r  r  r  r  r.   r  O  s*   

z/PlugForConditionalGeneration.calc_banned_tokens	input_idsattention_maskreturnc           	      O   s`   |d u r| | jd  }| j| d |d ||d}| j|g|R i |}|d }d|iS )NrB  r   )r5   r   r   r,  r.  r  )rJ  rE  rg  rV  r@   ro  )	r*   r  r  r,  rm  rn  rh  translation_batchpredsr-   r-   r.   	translatel  s   
z&PlugForConditionalGeneration.translate)NrW  r   )Fr   )r  rv  NTr>   r  r  r  Fr  r   r  )r\   r]   r^   r   rV  r   r   rZ   boolro  ru  rK   r  r   rl  r  rB   r   r   r  r`   r-   r-   r+   r.   rQ    s`    

!	


%
  rQ  )4r  rE   r   dataclassesr   typingr   r   r   r   r   r  r   rB   torch.nn.functionalr	   
functionalr=  r   torch.nn.initr
   transformersr   r   r   r   r   r   transformers.activationsr   transformers.modeling_utilsr   modelscope.utilsr   r]  configurationr   r   r   Moduler   rb   rr   r   r   r   r   r   r   r1  rA  rQ  r-   r-   r-   r.   <module>   s<     6g&m["