o
    eiXU                     @   s  d Z ddlmZ ddlZddlmZmZ ddlmZ ddlm	Z	 ddl
mZmZmZ d	d
lmZ ddlmZ eeZeeddG dd deZeeddG dd deZeeddG dd deZeG dd de	ZG dd deZG dd deZG dd deZG dd deZG dd  d eZed!dG d"d# d#eZed$dG d%d& d&eZed'dG d(d) d)eZ g d*Z!dS )+z5PyTorch DPR model for Open Domain Question Answering.    )	dataclassN)Tensornn   )BaseModelOutputWithPooling)PreTrainedModel)ModelOutputauto_docstringlogging   )	BertModel   )	DPRConfigz6
    Class for outputs of [`DPRQuestionEncoder`].
    )custom_introc                   @   P   e Zd ZU dZejed< dZeejdf dB ed< dZ	eejdf dB ed< dS )DPRContextEncoderOutputa  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
        The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer
        hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
        This output is to be used to embed contexts for nearest neighbors queries with questions embeddings.
    pooler_outputN.hidden_states
attentions
__name__
__module____qualname____doc__torchFloatTensor__annotations__r   tupler    r   r   b/home/ubuntu/transcripts/venv/lib/python3.10/site-packages/transformers/models/dpr/modeling_dpr.pyr   (   
   
 
r   c                   @   r   )DPRQuestionEncoderOutputa  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
        The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer
        hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
        This output is to be used to embed questions for nearest neighbors queries with context embeddings.
    r   N.r   r   r   r   r   r   r   r!   ;   r    r!   c                   @   st   e Zd ZU dZejed< dZejdB ed< dZejdB ed< dZ	e
ejdf dB ed< dZe
ejdf dB ed< dS )	DPRReaderOutputa  
    start_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
        Logits of the start index of the span for each passage.
    end_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
        Logits of the end index of the span for each passage.
    relevance_logits (`torch.FloatTensor` of shape `(n_passages, )`):
        Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the
        question, compared to all the other passages.
    start_logitsN
end_logitsrelevance_logits.r   r   )r   r   r   r   r   r   r   r$   r%   r   r   r   r   r   r   r   r"   N   s   
 

r"   c                   @   s   e Zd ZdZdS )DPRPreTrainedModelTN)r   r   r   _supports_sdpar   r   r   r   r&   f   s    r&   c                       s   e Zd ZdZdef fddZ						ddededB d	edB d
edB dedededee	edf B fddZ
edefddZ  ZS )
DPREncoder
bert_modelconfigc                    sd   t  | t|dd| _| jjjdkrtd|j| _| jdkr,t	| jjj|j| _
|   d S )NF)add_pooling_layerr   z!Encoder hidden_size can't be zero)super__init__r   r)   r*   hidden_size
ValueErrorprojection_dimr   Linearencode_proj	post_initselfr*   	__class__r   r   r-   n   s   
zDPREncoder.__init__NF	input_idsattention_masktoken_type_idsinputs_embedsoutput_attentionsoutput_hidden_statesreturn_dictreturn.c              	   K   sv   | j |||||||d}	|	d }
|
d d dd d f }| jdkr%| |}|s1|
|f|	dd   S t|
||	j|	jdS )Nr8   r9   r:   r;   r<   r=   r>   r   r   )last_hidden_stater   r   r   )r)   r0   r2   r   r   r   )r5   r8   r9   r:   r;   r<   r=   r>   kwargsoutputssequence_outputpooled_outputr   r   r   forwardy   s*   	

zDPREncoder.forwardc                 C   s   | j dkr	| jjS | jjjS )Nr   )r0   r2   out_featuresr)   r*   r.   )r5   r   r   r   embeddings_size   s   

zDPREncoder.embeddings_size)NNNFFF)r   r   r   base_model_prefixr   r-   r   boolr   r   rF   propertyintrH   __classcell__r   r   r6   r   r(   k   s8    

$r(   c                       sf   e Zd ZdZdef fddZ				ddeded	edB d
edededee	edf B fddZ
  ZS )DPRSpanPredictorencoderr*   c                    sF   t  | t|| _t| jjd| _t| jjd| _| 	  d S )Nr   r   )
r,   r-   r(   rO   r   r1   rH   
qa_outputsqa_classifierr3   r4   r6   r   r   r-      s
   
zDPRSpanPredictor.__init__NFr8   r9   r;   r<   r=   r>   r?   .c                 K   s   |d ur|  n|  d d \}}	| j||||||d}
|
d }| |}|jddd\}}|d }|d }| |d d dd d f }|||	}|||	}||}|si|||f|
dd   S t||||
j	|
j
dS )Nr   )r9   r;   r<   r=   r>   r   r   )dim)r#   r$   r%   r   r   )sizerO   rP   splitsqueeze
contiguousrQ   viewr"   r   r   )r5   r8   r9   r;   r<   r=   r>   rB   
n_passagessequence_lengthrC   rD   logitsr#   r$   r%   r   r   r   rF      s6   $

zDPRSpanPredictor.forward)NFFF)r   r   r   rI   r   r-   r   rJ   r"   r   rF   rM   r   r   r6   r   rN      s,    	rN   c                   @      e Zd ZU dZeed< dZdS )DPRPretrainedContextEncoder
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    r*   ctx_encoderNr   r   r   r   r   r   rI   r   r   r   r   r]         
 r]   c                   @   r\   )DPRPretrainedQuestionEncoderr^   r*   question_encoderNr`   r   r   r   r   rb      ra   rb   c                   @   r\   )DPRPretrainedReaderr^   r*   span_predictorNr`   r   r   r   r   rd      ra   rd   zf
    The bare DPRContextEncoder transformer outputting pooler outputs as context representations.
    c                          e Zd Zdef fddZe							ddedB dedB dedB dedB d	edB d
edB dedB dee	edf B fddZ
  ZS )DPRContextEncoderr*   c                    (   t  | || _t|| _|   d S N)r,   r-   r*   r(   r_   r3   r4   r6   r   r   r-        
zDPRContextEncoder.__init__Nr8   r9   r:   r;   r<   r=   r>   r?   .c              	   K   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}|dur*|dur*td|dur3| }	n|dur@| dd }	ntd|durK|jn|j}
|du rc|du r]tj|	|
dn|| j j	k}|du rptj
|	tj|
d}| j|||||||d}|s|dd S t|j|j|jd	S )
aS  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
            formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs (for a pair title+text for example):

            ```
            tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
            ```

            (b) For single sequences (for a question for example):

            ```
            tokens:         [CLS] the dog is hairy . [SEP]
            token_type_ids:   0   0   0   0  0     0   0
            ```

            DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
            rather than the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

        Examples:

        ```python
        >>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

        >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        >>> model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
        >>> embeddings = model(input_ids).pooler_output
        ```NDYou cannot specify both input_ids and inputs_embeds at the same timerR   5You have to specify either input_ids or inputs_embedsdevicedtypern   r@   r   r   r   r   )r*   r<   r=   use_return_dictr/   rT   rn   r   onespad_token_idzeroslongr_   r   r   r   r   r5   r8   r9   r:   r;   r<   r=   r>   rB   input_shapern   rC   r   r   r   rF     sB   2


zDPRContextEncoder.forwardNNNNNNN)r   r   r   r   r-   r	   r   rJ   r   r   rF   rM   r   r   r6   r   rg     6    
rg   zh
    The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.
    c                       rf   )DPRQuestionEncoderr*   c                    rh   ri   )r,   r-   r*   r(   rc   r3   r4   r6   r   r   r-   v  rj   zDPRQuestionEncoder.__init__Nr8   r9   r:   r;   r<   r=   r>   r?   .c              	   K   s(  |dur|n| j j}|dur|n| j j}|dur|n| j j}|dur*|dur*td|dur9| || | }	n|durF| dd }	ntd|durQ|jn|j}
|du ri|du rctj	|	|
dn|| j j
k}|du rvtj|	tj|
d}| j|||||||d}|s|dd S t|j|j|jd	S )
aj  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
            formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs (for a pair title+text for example):

            ```
            tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
            ```

            (b) For single sequences (for a question for example):

            ```
            tokens:         [CLS] the dog is hairy . [SEP]
            token_type_ids:   0   0   0   0  0     0   0
            ```

            DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
            rather than the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

        Examples:

        ```python
        >>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

        >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        >>> model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
        >>> embeddings = model(input_ids).pooler_output
        ```
        Nrk   rR   rl   rm   ro   r@   r   rq   )r*   r<   r=   rr   r/   %warn_if_padding_and_no_attention_maskrT   rn   r   rs   rt   ru   rv   rc   r!   r   r   r   rw   r   r   r   rF   }  sD   2


zDPRQuestionEncoder.forwardry   )r   r   r   r   r-   r	   r   rJ   r!   r   rF   rM   r   r   r6   r   r{   p  rz   r{   zE
    The bare DPRReader transformer outputting span predictions.
    c                       s~   e Zd Zdef fddZe						ddedB dedB dedB dedB d	edB d
edB dee	edf B fddZ
  ZS )	DPRReaderr*   c                    rh   ri   )r,   r-   r*   rN   re   r3   r4   r6   r   r   r-     rj   zDPRReader.__init__Nr8   r9   r;   r<   r=   r>   r?   .c           
      K   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}|dur*|dur*td|dur9| || | }n|durF| dd }ntd|durQ|jn|j}	|du r_tj	||	d}| j
||||||dS )a  
        input_ids (`tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question
            and 2) the passages titles and 3) the passages texts To match pretraining, DPR `input_ids` sequence should
            be formatted with [CLS] and [SEP] with the format:

            `[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>`

            DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
            rather than the left.

            Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details.

            [What are input IDs?](../glossary#input-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

        Examples:

        ```python
        >>> from transformers import DPRReader, DPRReaderTokenizer

        >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
        >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
        >>> encoded_inputs = tokenizer(
        ...     questions=["What is love ?"],
        ...     titles=["Haddaway"],
        ...     texts=["'What Is Love' is a song recorded by the artist Haddaway"],
        ...     return_tensors="pt",
        ... )
        >>> outputs = model(**encoded_inputs)
        >>> start_logits = outputs.start_logits
        >>> end_logits = outputs.end_logits
        >>> relevance_logits = outputs.relevance_logits
        ```
        Nrk   rR   rl   rm   )r;   r<   r=   r>   )r*   r<   r=   rr   r/   r|   rT   rn   r   rs   re   )
r5   r8   r9   r;   r<   r=   r>   rB   rx   rn   r   r   r   rF     s.   1
zDPRReader.forward)NNNNNN)r   r   r   r   r-   r	   r   rJ   r"   r   rF   rM   r   r   r6   r   r}     s0    	r}   )rg   r]   r&   rb   rd   r{   r}   )"r   dataclassesr   r   r   r   modeling_outputsr   modeling_utilsr   utilsr   r	   r
   bert.modeling_bertr   configuration_dprr   
get_loggerr   loggerr   r!   r"   r&   r(   rN   r]   rb   rd   rg   r{   r}   __all__r   r   r   r   <module>   sZ   
9?

efX