o
    	۷io                     @   s  d Z ddlZddlmZmZmZ ddlZddlmZ ddlm	Z	m
Z
mZ ddlmZ ddlmZ dd	lmZmZmZmZmZmZ dd
lmZmZ ddlmZmZmZ ddlmZm Z m!Z! ddl"m#Z# e!$e%Z&G dd dej'Z(dDddZ)G dd dej'Z*G dd dej'Z+G dd dej'Z,G dd dej'Z-G dd dej'Z.G dd dej'Z/G dd  d ej'Z0G d!d" d"ej'Z1	#	dEd$ej'd%ej2d&ej2d'ej2d(eej2 d)e3d*e3d+eej2 fd,d-Z4G d.d/ d/ej'Z5G d0d1 d1ej'Z6G d2d3 d3eZ7G d4d5 d5ej'Z8eG d6d7 d7eZ9eG d8d9 d9e9Z:eG d:d; d;e9Z;ed<d=G d>d? d?e9Z<ed@d=G dAdB dBe9Z=g dCZ>dS )FzPyTorch MarkupLM model.    N)CallableOptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPoolingMaskedLMOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringcan_return_tuplelogging   )MarkupLMConfigc                       s*   e Zd ZdZ fddZdddZ  ZS )XPathEmbeddingszConstruct the embeddings from xpath tags and subscripts.

    We drop tree-id in this version, as its info can be covered by xpath.
    c                    s   t     j| _t j| j  j| _t j	| _
t | _t j| j d j | _td j  j| _t fddt| jD | _t fddt| jD | _d S )N   c                       g | ]
}t  j jqS  )r   	Embeddingmax_xpath_tag_unit_embeddingsxpath_unit_hidden_size.0_configr   d/home/ubuntu/vllm_env/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py
<listcomp>>       z,XPathEmbeddings.__init__.<locals>.<listcomp>c                    r   r   )r   r    max_xpath_subs_unit_embeddingsr"   r#   r&   r   r(   r)   E   r*   )super__init__	max_depthr   Linearr"   hidden_sizexpath_unitseq2_embeddingsDropouthidden_dropout_probdropoutReLU
activationxpath_unitseq2_inner	inner2emb
ModuleListrangexpath_tag_sub_embeddingsxpath_subs_sub_embeddingsselfr'   	__class__r&   r(   r-   1   s"   




zXPathEmbeddings.__init__Nc              	   C   s   g }g }t | jD ](}|| j| |d d d d |f  || j| |d d d d |f  q	tj|dd}tj|dd}|| }| | | 	| 
|}|S )Ndim)r:   r.   appendr;   r<   torchcatr8   r4   r6   r7   )r>   xpath_tags_seqxpath_subs_seqxpath_tags_embeddingsxpath_subs_embeddingsixpath_embeddingsr   r   r(   forwardK   s   &(zXPathEmbeddings.forward)NN)__name__
__module____qualname____doc__r-   rM   __classcell__r   r   r?   r(   r   +   s    r   c                 C   s6   |  | }tj|dd|| | }| | S )a  
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    r   rB   )neintrE   cumsumtype_aslong)	input_idspadding_idxpast_key_values_lengthmaskincremental_indicesr   r   r(   "create_position_ids_from_input_ids^   s   r]   c                       s@   e Zd ZdZ fddZdd Z							d
dd	Z  ZS )MarkupLMEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    || _tj|j|j|jd| _t|j	|j| _
|j| _t|| _t|j|j| _tj|j|jd| _t|j| _| jdt|j	ddd |j| _tj|j	|j| jd| _
d S )N)rY   epsposition_ids)r   rA   F)
persistent)r,   r-   r'   r   r    
vocab_sizer0   pad_token_idword_embeddingsmax_position_embeddingsposition_embeddingsr.   r   rL   type_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsr2   r3   r4   register_bufferrE   arangeexpandrY   r=   r?   r   r(   r-   q   s    

zMarkupLMEmbeddings.__init__c                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        NrA   r   dtypedevicer   )sizerE   rm   rY   rW   rq   	unsqueezern   )r>   inputs_embedsinput_shapesequence_lengthra   r   r   r(   &create_position_ids_from_inputs_embeds   s   	z9MarkupLMEmbeddings.create_position_ids_from_inputs_embedsNr   c                 C   s<  |d ur	|  }n|  d d }|d ur|jn|j}	|d u r0|d ur+t|| j|}n| |}|d u r=tj|tj|	d}|d u rF| |}|d u r_| j	j
tjtt|| jg tj|	d }|d u rx| j	jtjtt|| jg tj|	d }|}
| |}| |}| ||}|
| | | }| |}| |}|S )NrA   ro   )rr   rq   r]   rY   rw   rE   zerosrW   re   r'   
tag_pad_idonestuplelistr.   subs_pad_idrg   ri   rL   rj   r4   )r>   rX   rG   rH   token_type_idsra   rt   rZ   ru   rq   words_embeddingsrg   ri   rL   
embeddingsr   r   r(   rM      s8   









zMarkupLMEmbeddings.forward)NNNNNNr   )rN   rO   rP   rQ   r-   rw   rM   rR   r   r   r?   r(   r^   n   s    r^   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )MarkupLMSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr_   )r,   r-   r   r/   r0   denserj   rk   r2   r3   r4   r=   r?   r   r(   r-         
zMarkupLMSelfOutput.__init__hidden_statesinput_tensorreturnc                 C   &   |  |}| |}| || }|S Nr   r4   rj   r>   r   r   r   r   r(   rM         

zMarkupLMSelfOutput.forwardrN   rO   rP   r-   rE   TensorrM   rR   r   r   r?   r(   r          $r   c                       2   e Zd Z fddZdejdejfddZ  ZS )MarkupLMIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r   )r,   r-   r   r/   r0   intermediate_sizer   
isinstance
hidden_actstrr
   intermediate_act_fnr=   r?   r   r(   r-      s
   
zMarkupLMIntermediate.__init__r   r   c                 C      |  |}| |}|S r   )r   r   r>   r   r   r   r(   rM         

zMarkupLMIntermediate.forwardr   r   r   r?   r(   r      s    r   c                       r   )MarkupLMOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r,   r-   r   r/   r   r0   r   rj   rk   r2   r3   r4   r=   r?   r   r(   r-      r   zMarkupLMOutput.__init__r   r   r   c                 C   r   r   r   r   r   r   r(   rM      r   zMarkupLMOutput.forwardr   r   r   r?   r(   r      r   r   c                       r   )MarkupLMPoolerc                    s*   t    t|j|j| _t | _d S r   )r,   r-   r   r/   r0   r   Tanhr6   r=   r?   r   r(   r-      s   
zMarkupLMPooler.__init__r   r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r6   )r>   r   first_token_tensorpooled_outputr   r   r(   rM     s   

zMarkupLMPooler.forwardr   r   r   r?   r(   r      s    r   c                       r   )MarkupLMPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S r   )r,   r-   r   r/   r0   r   r   r   r   r
   transform_act_fnrj   rk   r=   r?   r   r(   r-     s   
z(MarkupLMPredictionHeadTransform.__init__r   r   c                 C   s"   |  |}| |}| |}|S r   )r   r   rj   r   r   r   r(   rM     s   


z'MarkupLMPredictionHeadTransform.forwardr   r   r   r?   r(   r     s    	r   c                       s,   e Zd Z fddZdd Zdd Z  ZS )MarkupLMLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)bias)r,   r-   r   	transformr   r/   r0   rc   decoder	ParameterrE   rx   r   r=   r?   r   r(   r-   "  s
   

z!MarkupLMLMPredictionHead.__init__c                 C   s   | j | j_ d S r   )r   r   r>   r   r   r(   _tie_weights/  s   z%MarkupLMLMPredictionHead._tie_weightsc                 C   r   r   )r   r   r   r   r   r(   rM   2  r   z MarkupLMLMPredictionHead.forward)rN   rO   rP   r-   r   rM   rR   r   r   r?   r(   r   !  s    r   c                       r   )MarkupLMOnlyMLMHeadc                    s   t    t|| _d S r   )r,   r-   r   predictionsr=   r?   r   r(   r-   :  s   
zMarkupLMOnlyMLMHead.__init__sequence_outputr   c                 C   s   |  |}|S r   )r   )r>   r   prediction_scoresr   r   r(   rM   >  s   
zMarkupLMOnlyMLMHead.forwardr   r   r   r?   r(   r   9  s    r           modulequerykeyvalueattention_maskscalingr4   	head_maskc                 K   s   t ||dd| }	|d ur'|d d d d d d d |jd f }
|	|
 }	tjj|	dt jd|j	}	tjj
|	|| jd}	|d urM|	|dddd }	t |	|}|dd }||	fS )N   r	   rA   )rC   rp   )ptrainingr   )rE   matmul	transposeshaper   
functionalsoftmaxfloat32torp   r4   r   view
contiguous)r   r   r   r   r   r   r4   r   kwargsattn_weightscausal_maskattn_outputr   r   r(   eager_attention_forwardD  s   &r   c                       sZ   e Zd Z fddZ			ddejdeej deej dee d	e	ej f
d
dZ
  ZS )MarkupLMSelfAttentionc                    s   t    |j|j dkrt|dstd|j d|j d|| _|j| _t|j|j | _| j| j | _	t
|j| j	| _t
|j| j	| _t
|j| j	| _t
|j| _|j| _| jd | _d S )Nr   embedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()g      )r,   r-   r0   num_attention_headshasattr
ValueErrorr'   rT   attention_head_sizeall_head_sizer   r/   r   r   r   r2   attention_probs_dropout_probr4   attention_dropoutr   r=   r?   r   r(   r-   a  s"   

zMarkupLMSelfAttention.__init__NFr   r   r   output_attentionsr   c                 K   s   |j d d }g |d| jR }| ||dd}| ||dd}	| ||dd}
t}| jj	dkrCt
| jj	 }|| ||	|
|f| jsOdn| j| j|d|\}}|jg |dR   }|rp||f}|S |f}|S )NrA   r   r   eagerr   )r4   r   r   )r   r   r   r   r   r   r   r   r'   _attn_implementationr   r   r   r   reshaper   )r>   r   r   r   r   r   ru   hidden_shapequery_states
key_statesvalue_statesattention_interfacer   r   outputsr   r   r(   rM   v  s4   	
zMarkupLMSelfAttention.forwardNNF)rN   rO   rP   r-   rE   r   r   FloatTensorboolr{   rM   rR   r   r   r?   r(   r   `  s     r   c                       sb   e Zd Z fddZdd Z			ddejdeej d	eej d
ee	 de
ej f
ddZ  ZS )MarkupLMAttentionc                    s*   t    t|| _t|| _t | _d S r   )r,   r-   r   r>   r   outputsetpruned_headsr=   r?   r   r(   r-     s   


zMarkupLMAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rB   )lenr   r>   r   r   r   r   r   r   r   r   r   r   union)r>   headsindexr   r   r(   prune_heads  s   zMarkupLMAttention.prune_headsNFr   r   r   r   r   c           	      K   s@   | j |f|||d|}| |d |}|f|dd   }|S N)r   r   r   r   r   )r>   r   )	r>   r   r   r   r   r   self_outputsattention_outputr   r   r   r(   rM     s   zMarkupLMAttention.forwardr   )rN   rO   rP   r-   r   rE   r   r   r   r   r{   rM   rR   r   r   r?   r(   r     s"    r   c                       sb   e Zd Z fddZ			ddejdeej deej dee d	e	ej f
d
dZ
dd Z  ZS )MarkupLMLayerc                    s:   t    |j| _d| _t|| _t|| _t|| _	d S )Nr   )
r,   r-   chunk_size_feed_forwardseq_len_dimr   	attentionr   intermediater   r   r=   r?   r   r(   r-     s   


zMarkupLMLayer.__init__NFr   r   r   r   r   c           
      K   sP   | j |f|||d|}|d }|dd  }t| j| j| j|}	|	f| }|S r   )r   r   feed_forward_chunkr   r   )
r>   r   r   r   r   r   self_attention_outputsr   r   layer_outputr   r   r(   rM     s    
zMarkupLMLayer.forwardc                 C   s   |  |}| ||}|S r   )r   r   )r>   r   intermediate_outputr   r   r   r(   r     s   
z MarkupLMLayer.feed_forward_chunkr   )rN   rO   rP   r-   rE   r   r   r   r   r{   rM   r   rR   r   r   r?   r(   r     s"    
r   c                       sz   e Zd Z fddZe					ddejdeej deej d	ee	 d
ee	 dee	 de
eej ef fddZ  ZS )MarkupLMEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r   )r   )r$   rK   r&   r   r(   r)     s    z,MarkupLMEncoder.__init__.<locals>.<listcomp>F)	r,   r-   r'   r   r9   r:   num_hidden_layerslayergradient_checkpointingr=   r?   r&   r(   r-     s   
 
zMarkupLMEncoder.__init__NFTr   r   r   r   output_hidden_statesreturn_dictr   c                 K   s   |rdnd }|r
dnd }	t | jD ].\}
}|r||f }|d ur$||
 nd }|d||||d|}|d }|r?|	|d f }	q|rG||f }t|||	dS )Nr   )r   r   r   r   r   r   )last_hidden_stater   
attentions)	enumerater   r   )r>   r   r   r   r   r   r   r   all_hidden_statesall_self_attentionsrK   layer_modulelayer_head_masklayer_outputsr   r   r(   rM     s2   

zMarkupLMEncoder.forward)NNFFT)rN   rO   rP   r-   r   rE   r   r   r   r   r   r{   r   rM   rR   r   r   r?   r(   r     s.    	r   c                       sJ   e Zd ZU eed< dZdd Zedee	e
ejf  f fddZ  ZS )MarkupLMPreTrainedModelr'   markuplmc                 C   s   t |tjr |jjjd| jjd |jdur|jj	  dS dS t |tj
rC|jjjd| jjd |jdurA|jj|j 	  dS dS t |tjrX|jj	  |jjd dS t |tre|jj	  dS dS )zInitialize the weightsr   )meanstdN      ?)r   r   r/   weightdatanormal_r'   initializer_ranger   zero_r    rY   rj   fill_r   )r>   r   r   r   r(   _init_weights)  s    


z%MarkupLMPreTrainedModel._init_weightspretrained_model_name_or_pathc                    s   t  j|g|R i |S r   )r,   from_pretrained)clsr  
model_argsr   r?   r   r(   r  ;  s   z'MarkupLMPreTrainedModel.from_pretrained)rN   rO   rP   r   __annotations__base_model_prefixr  classmethodr   r   r   osPathLiker  rR   r   r   r?   r(   r  #  s   
 *r  c                       s   e Zd Zd fdd	Zdd Zdd Zdd	 Zee	
	
	
	
	
	
	
	
	
	
	
dde	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e de	e de	e deeef fddZ  ZS )MarkupLMModelTc                    sD   t  | || _t|| _t|| _|rt|nd| _| 	  dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)
r,   r-   r'   r^   r   r   encoderr   pooler	post_init)r>   r'   add_pooling_layerr?   r   r(   r-   C  s   

zMarkupLMModel.__init__c                 C   s   | j jS r   r   re   r   r   r   r(   get_input_embeddingsS  s   z"MarkupLMModel.get_input_embeddingsc                 C   s   || j _d S r   r  )r>   r   r   r   r(   set_input_embeddingsV  s   z"MarkupLMModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  r   r   r   )r>   heads_to_pruner   r   r   r   r(   _prune_headsY  s   zMarkupLMModel._prune_headsNrX   rG   rH   r   r~   ra   r   rt   r   r   r   r   c                 C   s  |	dur|	n| j j}	|
dur|
n| j j}
|dur|n| j j}|dur*|dur*td|dur9| || | }n|durF| dd }ntd|durQ|jn|j}|du r_tj	||d}|du rltj
|tj|d}|dd}|j| jd	}d
| d }|dur| dkr|dddd}|| j jdddd}n| dkr|ddd}|jt|  jd	}ndg| j j }| j||||||d}| j||||	|
dd}|d }| jdur| |nd}t|||j|jdS )a  
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, MarkupLMModel

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")

        >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"

        >>> encoding = processor(html_string, return_tensors="pt")

        >>> outputs = model(**encoding)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 4, 768]
        ```NzDYou cannot specify both input_ids and inputs_embeds at the same timerA   z5You have to specify either input_ids or inputs_embeds)rq   ro   r   r   )rp   r  g     r   )rX   rG   rH   ra   r~   rt   T)r   r   r   r   )r   pooler_outputr   r   )r'   r   r   use_return_dictr   %warn_if_padding_and_no_attention_maskrr   rq   rE   rz   rx   rW   rs   r   rp   rC   rn   r   next
parametersr   r  r  r   r   r   )r>   rX   rG   rH   r   r~   ra   r   rt   r   r   r   ru   rq   extended_attention_maskembedding_outputencoder_outputsr   r   r   r   r(   rM   a  sh   '
zMarkupLMModel.forward)T)NNNNNNNNNNN)rN   rO   rP   r-   r  r  r   r   r   r   rE   
LongTensorr   r   r   r{   r   rM   rR   r   r   r?   r(   r  @  sV    	

r  c                !       s   e Zd Z fddZee													ddeej deej deej deej deej d	eej d
eej deej deej deej dee	 dee	 dee	 de
eej ef fddZ  ZS )MarkupLMForQuestionAnsweringc                    s@   t  | |j| _t|dd| _t|j|j| _| 	  d S NF)r  )
r,   r-   
num_labelsr  r  r   r/   r0   
qa_outputsr  r=   r?   r   r(   r-     s
   z%MarkupLMForQuestionAnswering.__init__NrX   rG   rH   r   r~   ra   r   rt   start_positionsend_positionsr   r   r   r   c                 C   s  |dur|n| j j}| j||||||||||dd}|d }| |}|jddd\}}|d }|d }d}|	dur|
durt|	 dkrQ|	d}	t|
 dkr^|
d}
|d}|		d| |
	d| t
|d}|||	}|||
}|| d	 }t||||j|jd
S )ae  
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, MarkupLMForQuestionAnswering
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
        >>> model = MarkupLMForQuestionAnswering.from_pretrained("microsoft/markuplm-base-finetuned-websrc")

        >>> html_string = "<html> <head> <title>My name is Niels</title> </head> </html>"
        >>> question = "What's his name?"

        >>> encoding = processor(html_string, questions=question, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()

        >>> predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
        >>> processor.decode(predict_answer_tokens).strip()
        'Niels'
        ```NT
rG   rH   r   r~   ra   r   rt   r   r   r   r   r   rA   rB   )ignore_indexr   )lossstart_logits
end_logitsr   r   )r'   r"  r  r-  splitsqueezer   r   rr   clamp_r   r   r   r   )r>   rX   rG   rH   r   r~   ra   r   rt   r.  r/  r   r   r   r   r   logitsr3  r4  
total_lossignored_indexloss_fct
start_lossend_lossr   r   r(   rM     sN   0






z$MarkupLMForQuestionAnswering.forward)NNNNNNNNNNNNN)rN   rO   rP   r-   r   r   r   rE   r   r   r   r{   r   rM   rR   r   r   r?   r(   r*    s\    
	
r*  zC
    MarkupLM Model with a `token_classification` head on top.
    )custom_introc                          e Zd Z fddZee												ddeej deej deej deej deej d	eej d
eej deej deej dee	 dee	 dee	 de
eej ef fddZ  ZS )MarkupLMForTokenClassificationc                    sb   t  | |j| _t|dd| _|jd ur|jn|j}t|| _	t
|j|j| _|   d S r+  )r,   r-   r,  r  r  classifier_dropoutr3   r   r2   r4   r/   r0   
classifierr  r>   r'   rA  r?   r   r(   r-   >  s   z'MarkupLMForTokenClassification.__init__NrX   rG   rH   r   r~   ra   r   rt   labelsr   r   r   r   c                 C   s   |dur|n| j j}| j|||||||||
|dd}|d }| |}d}|	dur:t }||d| j j|	d}t|||j|j	dS )a  
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForTokenClassification
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> processor.parse_html = False
        >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)

        >>> nodes = ["hello", "world"]
        >>> xpaths = ["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"]
        >>> node_labels = [1, 2]
        >>> encoding = processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```NTr0  r   rA   r2  r8  r   r   )
r'   r"  r  rB  r   r   r,  r   r   r   )r>   rX   rG   rH   r   r~   ra   r   rt   rD  r   r   r   r   r   r   r2  r;  r   r   r(   rM   L  s:   .
z&MarkupLMForTokenClassification.forwardNNNNNNNNNNNN)rN   rO   rP   r-   r   r   r   rE   r   r   r   r{   r   rM   rR   r   r   r?   r(   r@  7  sV    	
r@  z
    MarkupLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    c                       r?  )!MarkupLMForSequenceClassificationc                    sd   t  | |j| _|| _t|| _|jd ur|jn|j}t	|| _
t|j|j| _|   d S r   )r,   r-   r,  r'   r  r  rA  r3   r   r2   r4   r/   r0   rB  r  rC  r?   r   r(   r-     s   
z*MarkupLMForSequenceClassification.__init__NrX   rG   rH   r   r~   ra   r   rt   rD  r   r   r   r   c                 C   sJ  |dur|n| j j}| j|||||||||
|dd}|d }| |}| |}d}|	dur| j jdu rX| jdkr>d| j _n| jdkrT|	jtj	ksO|	jtj
krTd| j _nd| j _| j jdkrvt }| jdkrp|| |	 }n+|||	}n%| j jdkrt }||d| j|	d}n| j jdkrt }|||	}t|||j|jd	S )
a  
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
        xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Subscript IDs for each token in the input sequence, padded up to config.max_depth.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
        >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)

        >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"
        >>> encoding = processor(html_string, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**encoding)

        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```NTr0  r   
regressionsingle_label_classificationmulti_label_classificationrA   rE  )r'   r"  r  r4   rB  problem_typer,  rp   rE   rW   rT   r   r6  r   r   r   r   r   r   )r>   rX   rG   rH   r   r~   ra   r   rt   rD  r   r   r   r   r   r8  r2  r;  r   r   r(   rM     sT   -



"


z)MarkupLMForSequenceClassification.forwardrF  )rN   rO   rP   r-   r   r   r   rE   r   r   r   r{   r   rM   rR   r   r   r?   r(   rG    sV    	
rG  )r*  rG  r@  r  r  )r   )r   N)?rQ   r  typingr   r   r   rE   r   torch.nnr   r   r   activationsr
   modeling_layersr   modeling_outputsr   r   r   r   r   r   modeling_utilsr   r   pytorch_utilsr   r   r   utilsr   r   r   configuration_markuplmr   
get_loggerrN   loggerModuler   r]   r^   r   r   r   r   r   r   r   r   floatr   r   r   r   r   r  r  r*  r@  rG  __all__r   r   r   r(   <module>   s~    

3c
;.)1 	mar