o
    ߥi                     @   s8  d Z ddlmZmZmZ ddlZddlZddlZddlZddl	Z	ddl
Z
ddlZddlZddlmZ ddlmZ ddlmZ ddlmZ e ZejZejZdd	 Zd
d ZeejjjedZG dd dejZ G dd dejZ!G dd dejZ"G dd dejZ#G dd dejZ$G dd dejZ%G dd dejZ&G dd dejZ'G dd dejZ(G dd  d ejZ)G d!d" d"ejZ*G d#d$ d$ejZ+G d%d& d&ejZ,G d'd( d(ejZ-G d)d* d*ejZ.G d+d, d,ejZ/G d-d. d.ejZ0G d/d0 d0ejZ1G d1d2 d2ejZ2G d3d4 d4e2Z3G d5d6 d6ejZ4dS )7zPyTorch BERT model.    )absolute_importdivisionprint_functionN)nn)SpaceTCnConfig)	ModelFile)
get_loggerc                 C   s    | d dt | td   S )zImplementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    g      ?      ?g       @)torcherfmathsqrtx r   ]/home/ubuntu/.local/lib/python3.10/site-packages/modelscope/models/nlp/space_T_cn/backbone.pygelu(   s    r   c                 C   s   | t |  S N)r
   sigmoidr   r   r   r   swish0   s   r   )r   relur   c                       s&   e Zd Zd fdd	Zdd Z  ZS )BertLayerNorm-q=c                    s<   t t|   tt|| _tt|| _	|| _
dS )zWConstruct a layernorm module in the TF style (epsilon inside the square root).
        N)superr   __init__r   	Parameterr
   onesweightzerosbiasvariance_epsilon)selfhidden_sizeeps	__class__r   r   r   9   s   
zBertLayerNorm.__init__c                 C   sN   |j ddd}|| dj ddd}|| t|| j  }| j| | j S )NT)keepdim   )meanpowr
   r   r    r   r   )r!   r   usr   r   r   forwardA   s   zBertLayerNorm.forward)r   __name__
__module____qualname__r   r-   __classcell__r   r   r$   r   r   7   s    r   c                       sF   e Zd ZdZ fddZ														dddZ  ZS )BertEmbeddingszLConstruct the embeddings from word, position and token_type embeddings.
    c                    s   t t|   t|j|j| _t|j|j| _	t|j
|j| _td|j| _td|j| _t|jdd| _t|j| _d S )N      r   r#   )r   r3   r   r   	Embedding
vocab_sizer"   word_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddingsmatch_type_embeddingstype_embeddingsr   	LayerNormDropouthidden_dropout_probdropoutr!   configr$   r   r   r   L   s   zBertEmbeddings.__init__Nc              	   C   sr  | d}tj|tj|jd}|d|}|d u r t|}| |}| |}|d ur|d urt	
||	    }t	j
|td|	    }t|D ]0\}}| D ]'\}}|| | }|dkriqZtj|||d |d d f dd|||d d f< qZqR| |}| |}|| | }|d ur| |}||7 }|d ur| |}||7 }| |}| |}|S )N   )dtypedevicer   rG   dim)sizer
   arangelongrH   	unsqueeze	expand_as
zeros_liker9   nparraycpunumpytolistobject	enumerateitemsr)   r;   r=   r>   r?   r@   rC   )r!   	input_ids
header_idstoken_type_idsmatch_type_idsl_hs
header_lentype_idxcol_dict_listidsheader_flatten_tokensheader_flatten_indexheader_flatten_outputtoken_column_idtoken_column_maskcolumn_start_indexheaders_length
seq_lengthposition_idswords_embeddingsheader_embeddingsbicol_dictkivilengthr;   r=   
embeddingsr>   r?   r   r   r   r-   \   sN   











zBertEmbeddings.forward)NNNNNNNNNNNNNNr/   r0   r1   __doc__r   r-   r2   r   r   r$   r   r3   H   s$    r3   c                       s.   e Zd Z fddZdd ZdddZ  ZS )	BertSelfAttentionc                    s   t t|   |j|j dkrtd|j|jf |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _d S )Nr   LThe hidden size (%d) is not a multiple of the number of attention heads (%d))r   rv   r   r"   num_attention_heads
ValueErrorintattention_head_sizeall_head_sizer   LinearquerykeyvaluerA   attention_probs_dropout_probrC   rD   r$   r   r   r      s    
zBertSelfAttention.__init__c                 C   6   |  d d | j| jf }|j| }|ddddS Nr&   r   r(   rF      rL   rx   r{   viewpermuter!   r   new_x_shaper   r   r   transpose_for_scores   
   
z&BertSelfAttention.transpose_for_scoresNc                 C   s   |  |}| |}| |}| |}| |}| |}	t||dd}
|
t| j	 }
|
| }
t
jdd|
}| |}t||	}|dddd }| d d | jf }|j| }|S )Nr&   rJ   r   r(   rF   r   )r~   r   r   r   r
   matmul	transposer   r   r{   r   SoftmaxrC   r   
contiguousrL   r|   r   )r!   hidden_statesattention_maskschema_link_matrixmixed_query_layermixed_key_layermixed_value_layerquery_layer	key_layervalue_layerattention_scoresattention_probscontext_layernew_context_layer_shaper   r   r   r-      s,   








zBertSelfAttention.forwardr   r/   r0   r1   r   r   r-   r2   r   r   r$   r   rv      s    rv   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )!BertSelfAttentionWithRelationsRATzd
    Adapted from https://github.com/microsoft/rat-sql/blob/master/ratsql/models/transformer.py
    c                    s   t t|   |j|j dkrtd|j|jf |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t	d|j|j | _t	d|j|j | _d S Nr   rw      )r   r   r   r"   rx   ry   rz   r{   r|   r   r}   r~   r   r   rA   r   rC   r7   relation_k_embrelation_v_embrD   r$   r   r   r      s,   

z*BertSelfAttentionWithRelationsRAT.__init__c                 C   r   r   r   r   r   r   r   r      r   z6BertSelfAttentionWithRelationsRAT.transpose_for_scoresc                 C   sL  |  |}| |}| |}| |}| |}| |}	| |}
| |}t|	|
dd}|dd}|		dddd}t||}|	dddd}|| t
| j }|| }tjdd|}| |}t||}|	dddd}t||}|	dddd}|| }|	dddd }| dd | jf }|j| }|S )	7
        relation is [batch, seq len, seq len]
        r&   r   r   r(   rF   r   rJ   N)r~   r   r   r   r   r   r
   r   r   r   r   r   r{   r   r   rC   r   rL   r|   r   )r!   r   r   relationr   r   r   
relation_k
relation_vr   r   r   r   relation_k_tquery_layer_trelation_attention_scoresrelation_attention_scores_tmerged_attention_scoresr   r   attention_probs_tcontext_relationcontext_relation_tmerged_context_layerr   r   r   r   r-      sp   




z)BertSelfAttentionWithRelationsRAT.forward)r/   r0   r1   ru   r   r   r-   r2   r   r   r$   r   r      s
    r   c                       ,   e Zd Z fddZdd Zdd Z  ZS ))BertSelfAttentionWithRelationsTableformerc                    s   t t|   |j|j dkrtd|j|jf |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t	d| j| _d S r   )r   r   r   r"   rx   ry   rz   r{   r|   r   r}   r~   r   r   rA   r   rC   r7   schema_link_embeddingsrD   r$   r   r   r   ;  s"   
z2BertSelfAttentionWithRelationsTableformer.__init__c                 C   r   r   r   r   r   r   r   r   N  r   z>BertSelfAttentionWithRelationsTableformer.transpose_for_scoresc                 C   s   |  |}| |}| |}| |}|dddd}| |}| |}	| |}
t||	dd}|t	
| j }|| }|| }tjdd|}| |}t||
}|dddd }| dd | jf }|j| }|S )	r   r   r   rF   r(   r&   r   rJ   N)r~   r   r   r   r   r   r
   r   r   r   r   r{   r   r   rC   r   rL   r|   r   )r!   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r-   T  sF   




z1BertSelfAttentionWithRelationsTableformer.forwardr   r   r   r$   r   r   9  s    r   c                       $   e Zd Z fddZdd Z  ZS )BertSelfOutputc                    sB   t t|   t|j|j| _t|jdd| _t	|j
| _d S Nr   r6   )r   r   r   r   r}   r"   denser   r@   rA   rB   rC   rD   r$   r   r   r        zBertSelfOutput.__init__c                 C   &   |  |}| |}| || }|S r   r   rC   r@   r!   r   input_tensorr   r   r   r-        

zBertSelfOutput.forwardr.   r   r   r$   r   r         r   c                       (   e Zd Zd fdd	ZdddZ  ZS )	BertAttentionnonec                    sR   t t|   |dkrt|| _|dkrt|| _|dkr"t|| _t|| _d S )Nr   ratadd)	r   r   r   rv   r!   r   r   r   outputr!   rE   schema_link_moduler$   r   r   r     s   


zBertAttention.__init__Nc                 C   s   |  |||}| ||}|S r   )r!   r   )r!   r   r   r   self_outputattention_outputr   r   r   r-     s
   zBertAttention.forwardr   r   r.   r   r   r$   r   r     s    
r   c                       r   )BertIntermediatec                    sH   t t|   t|j|j| _t|j	t
rt|j	 | _d S |j	| _d S r   )r   r   r   r   r}   r"   intermediate_sizer   
isinstance
hidden_actstrACT2FNintermediate_act_fnrD   r$   r   r   r     s   
zBertIntermediate.__init__c                 C   s   |  |}| |}|S r   )r   r   r!   r   r   r   r   r-     s   

zBertIntermediate.forwardr.   r   r   r$   r   r     r   r   c                       r   )
BertOutputc                    sB   t t|   t|j|j| _t|jdd| _	t
|j| _d S r   )r   r   r   r   r}   r   r"   r   r   r@   rA   rB   rC   rD   r$   r   r   r     r   zBertOutput.__init__c                 C   r   r   r   r   r   r   r   r-     r   zBertOutput.forwardr.   r   r   r$   r   r     r   r   c                       r   )		BertLayerr   c                    s4   t t|   t||d| _t|| _t|| _d S Nr   )	r   r   r   r   	attentionr   intermediater   r   r   r$   r   r   r     s   
zBertLayer.__init__Nc                 C   s(   |  |||}| |}| ||}|S r   )r   r   r   )r!   r   r   r   r   intermediate_outputlayer_outputr   r   r   r-     s   
zBertLayer.forwardr   r   r.   r   r   r$   r   r     s    r   c                       s(   e Zd Z fddZ	dddZ  ZS )SqlBertEncoderc                    s8   t t|   t| t fddt|D | _d S )Nc                       g | ]}t  qS r   copydeepcopy.0_layerr   r   
<listcomp>      z+SqlBertEncoder.__init__.<locals>.<listcomp>)r   r   r   r   r   
ModuleListranger   )r!   layersrE   r$   r   r   r     s
   
zSqlBertEncoder.__init__Tc                 C   s:   g }| j D ]}|||}|r|| q|s|| |S r   r   append)r!   r   r   output_all_encoded_layersall_encoder_layerslayer_moduler   r   r   r-     s   



zSqlBertEncoder.forward)Tr.   r   r   r$   r   r     s    	r   c                       s.   e Zd Zd fdd	Z			d	ddZ  ZS )
BertEncoderr   c                    s>   t t|   t||d t fddt|jD | _d S )Nr   c                    r   r   r   r   r   r   r   r     r   z(BertEncoder.__init__.<locals>.<listcomp>)	r   r   r   r   r   r   r   num_hidden_layersr   r   r$   r   r   r     s
   
zBertEncoder.__init__NTc                 C   s<   g }| j D ]}||||}|r|| q|s|| |S r   r   )r!   r   r   all_schema_link_matrixall_schema_link_maskr   r   r   r   r   r   r-     s   


zBertEncoder.forwardr   )NNTr.   r   r   r$   r   r     s    	r   c                       r   )
BertPoolerc                    s.   t t|   t|j|j| _t | _d S r   )	r   r   r   r   r}   r"   r   Tanh
activationrD   r$   r   r   r     s   zBertPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r!   r   first_token_tensorpooled_outputr   r   r   r-   	  s   

zBertPooler.forwardr.   r   r   r$   r   r         r   c                       r   )BertPredictionHeadTransformc                    sR   t t|   t|j|j| _t|jt	rt
|j n|j| _t|jdd| _d S r   )r   r   r   r   r}   r"   r   r   r   r   r   transform_act_fnr   r@   rD   r$   r   r   r     s   
z$BertPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r   r   r@   r   r   r   r   r-     s   


z#BertPredictionHeadTransform.forwardr.   r   r   r$   r   r     s    r   c                       r   )BertLMPredictionHeadc                    sZ   t t|   t|| _tj|d|ddd| _|| j_	t
t|d| _d S )NrF   r   F)r   )r   r   r   r   	transformr   r}   rL   decoderr   r   r
   r   r   r!   rE   bert_model_embedding_weightsr$   r   r   r   $  s   

zBertLMPredictionHead.__init__c                 C   s   |  |}| || j }|S r   )r   r   r   r   r   r   r   r-   2  s   
zBertLMPredictionHead.forwardr.   r   r   r$   r   r   "  s    r   c                       r   )BertOnlyMLMHeadc                    s   t t|   t||| _d S r   )r   r   r   r   predictionsr   r$   r   r   r   :  s   
zBertOnlyMLMHead.__init__c                 C      |  |}|S r   )r   )r!   sequence_outputprediction_scoresr   r   r   r-   ?     
zBertOnlyMLMHead.forwardr.   r   r   r$   r   r   8  r   r   c                       r   )BertOnlyNSPHeadc                    s"   t t|   t|jd| _d S Nr(   )r   r  r   r   r}   r"   seq_relationshiprD   r$   r   r   r   F  s   zBertOnlyNSPHead.__init__c                 C   r  r   )r  )r!   r   seq_relationship_scorer   r   r   r-   J  r  zBertOnlyNSPHead.forwardr.   r   r   r$   r   r  D  s    r  c                       r   )BertPreTrainingHeadsc                    s.   t t|   t||| _t|jd| _d S r  )	r   r	  r   r   r   r   r}   r"   r  r   r$   r   r   r   Q  s
   zBertPreTrainingHeads.__init__c                 C   s   |  |}| |}||fS r   )r   r  )r!   r  r   r  r  r   r   r   r-   W  s   

zBertPreTrainingHeads.forwardr.   r   r   r$   r   r	  O  r   r	  c                       s:   e Zd ZdZ fddZdd Ze		d	ddZ  ZS )
PreTrainedBertModelz An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    c                    s:   t t|   t|tstd| jj| jj|| _	d S )NzParameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. To create a model from a Google pretrained model use `model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`)
r   r
  r   r   r   ry   formatr%   r/   rE   )r!   rE   inputskwargsr$   r   r   r   b  s   

zPreTrainedBertModel.__init__c                 C   s|   t |tjtjfr|jjjd| jjd nt |t	r'|j
j  |jjd t |tjr:|j
dur<|j
j  dS dS dS )z! Initialize the weights.
        g        )r)   stdr	   N)r   r   r}   r7   r   datanormal_rE   initializer_ranger   r   zero_fill_)r!   moduler   r   r   init_bert_weightsl  s   
z%PreTrainedBertModel.init_bert_weightsNc                    sZ  |}d}t j|r|}n*t }td|| t	|d}	|	
| W d   n1 s0w   Y  |}t j|t}
t|
}td| | |g|R i |}du rft j|t}t|g }g } D ]$}d}d|v r||dd}d|v r|dd}|r|| || qnt||D ]\}}||< qg g g  td	d dur_d fdd	|t|drd
ndd tdkrtd|jj t  tddd td|jj t  tdkr#td|jj t  tddd td|jj t  |r+t | |S )a  
        Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object)
                to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        Nz)extracting archive file {} to temp dir {}zr:gzzModel config {}gammar   betar   	_metadata c              	      sh   d u ri n	 |d d i }| ||d  | j D ]\}}|d ur1||| d  q d S )Nr&   T.)get_load_from_state_dict_modulesrY   )r  prefixlocal_metadatanamechild
error_msgsloadmetadatamissing_keys
state_dictunexpected_keysr   r   r$    s   
z1PreTrainedBertModel.from_pretrained.<locals>.loadbertzbert.)r  r   z7Weights of {} not initialized from pretrained model: {}z
**********zWARNING missing weightsz0Weights from pretrained model not used in {}: {}zWARNING unexpected weights)r  )!ospathisdirtempfilemkdtemploggerinfor  tarfileopen
extractalljoinCONFIG_NAMEr   from_json_fileWEIGHTS_NAMEr
   r$  keysreplacer   zippopgetattrr   r  hasattrlenr%   r/   printshutilrmtree)clspretrained_model_namer'  	cache_dirr  r  resolved_archive_filetempdirserialization_dirarchiveconfig_filerE   modelweights_pathold_keysnew_keysr   new_keyold_keyr   r"  r   from_pretrainedz  s   





z#PreTrainedBertModel.from_pretrained)NN)	r/   r0   r1   ru   r   r  classmethodrP  r2   r   r   r$   r   r
  ]  s    
r
  c                       sR   e Zd ZdZd	 fdd	Z																			d
ddZ  ZS )SpaceTCnModelaJ  SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN").

    Params:
        config: a SpaceTCnConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output
            as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLF`) to train on the Next-Sentence task (see BERT's paper).

    Example:
        >>> # Already been converted into WordPiece token ids
        >>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
        >>> input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
        >>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

        >>> config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        >>>     num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        >>> model = modeling.SpaceTCnModel(config=config)
        >>> all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    r   c                    sB   t t| | t|| _t||d| _t|| _| 	| j
 d S r   )r   rR  r   r3   rs   r   encoderr   poolerapplyr  r   r$   r   r   r     s   

zSpaceTCnModel.__init__NTc                 C   s   |d u r	t |}|d u rt |}|dd}d| d }| |||||||	|
||||||||}| j|||||d}|d }| |}|sM|d }||fS )NrF   r(   r	   g     )r   r   r   r&   )r
   	ones_likerQ   rO   rs   rS  rT  )r!   rZ   r[   token_order_idsr\   r   r]   r^   r_   type_idsra   rb   rc   rd   re   rf   rg   rh   ri   r   r   r   extended_attention_maskembedding_outputencoded_layersr  r   r   r   r   r-     s0   




zSpaceTCnModel.forwardr   )NNNNNNNNNNNNNNNNNNTrt   r   r   r$   r   rR    s.    +rR  c                       r   )Seq2SQLc                    s   t t|   || _|| _|| _|| _|
| _|| _|| _	|| _
|| _|	| _t||	| _t||	| _t||| _t||	| _t|| || | _t||	 ||	 | _t|d| _t||d | _t||d | _t||	d | _d S )Nr   rF   )r   r\  r   iShSlsdrrH   	n_agg_ops
n_cond_opsn_action_opsmax_select_nummax_where_numr   r}   w_sss_modelw_sse_model
s_ht_modelwc_ht_modelselect_agg_model
w_op_model
conn_modelaction_model
slen_model
wlen_model)r!   r]  r^  lSr`  rb  ra  rc  rd  re  rH   r$   r   r   r   ^  s2   

zSeq2SQL.__init__c                 C   s
   || _ d S r   )rH   )r!   rH   r   r   r   
set_device|  s   
zSeq2SQL.set_devicec           *   	      s"  t |}t |}	t||    }t||    }t||    }t |     g }
g }g }g }g }g }g }g }t|D ]\|d  |
d  |d  fddtj	D }|d j	  fddtj
D }|| || dd t| d D | d g||    }||  fd	dtd
| dD fddt|	t D 7 | qPtj|
tjdj}
tj|tjdj}tj|tjdj}tj|tjdj}tj|tjdj}tj|tjdj}tj|tjdj}tj|tjdj}t|}t|jgj}t|jgj}t|jgj}t|jgj}t|j	jgj}t|j
jgj}t||d jgj}t||	jgj}t|D ]}|| d
|
| ||d d f< || d
|| ||d d f< || d
|| ||d d f< || d
|| ||d d f< || d
||d d f ||d d d d f< || d
||d d f ||d d d d f< || d
||d d f ||d d d d f< || d
||d d f ||d d d d f< q|dj|d} |dj|j
d }!|dj|j	d }"|dj|jd }#|djj	 |dj}$|dj|dj	dd}%|dj|dj	dd}&|dj|dj	dd}' |dj|dj
dd}(!|djj
 |dj"})|#|(|)| |%|$|&|'f|!|"ffS )NrF   r(   r   c                    s   g | ]} d  | qS )   r   r   i)elemr   r   r         z#Seq2SQL.forward.<locals>.<listcomp>rr  c                    s   g | ]} d  j  | qS )   )re  rs  )ru  r!   r   r   r     s    c                 S   s   g | ]}|qS r   r   rs  r   r   r   r     s    c                    s   g | ]}  | qS r   r   rs  )column_indexibr   r   r     rv  r   c                    s   g | ]} d  qS )r   r   r   )indexr   r   r     s    rI   r&   )#maxrR   rS   rT   rU   rV   rX   r   r   re  rd  r>  r
   tensorrN   torH   r   r]  index_selectrl  reshapern  ro  rm  rc  rk  rb  ri  r   rf  rg  rh  rj  ra  )*r!   
wemb_layerl_nr^   start_indexrx  tokensrb   max_l_nmax_l_hs
conn_index
slen_index
wlen_indexaction_indexwhere_op_indexselect_agg_indexheader_pos_indexquery_indexwoisaiqilistheader_indexbSconn_embslen_embwlen_emb
action_embwo_embsa_embqv_embht_embrt  s_ccos_slens_wlens_action	wo_output	wc_outputwv_sswv_se	sc_output	sa_outputr   )rx  ru  ry  rz  r!   r   r-     s  

$

 "
    .2


zSeq2SQL.forward)r/   r0   r1   r   rq  r-   r2   r   r   r$   r   r\  \  s    r\  )5ru   
__future__r   r   r   r   r   r*  r@  r1  r-  rU   rR   r
   r   .modelscope.models.nlp.space_T_cn.configurationr   modelscope.utils.constantr   modelscope.utils.loggerr   r/  CONFIGURATIONr5  TORCH_MODEL_BIN_FILEr7  r   r   
functionalr   r   Moduler   r3   rv   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r	  r
  rR  r\  r   r   r   r   <module>   sT   L:kJ s