o
    }oiXf                     @   s   d dl Z d dlmZmZ d dlZd dlmZ d dlmZm	Z	m
Z
 d dlmZmZ d dlmZ G dd deZG d	d
 d
eZG dd deZG dd deZG dd deZdS )    N)ABCabstractmethod)	rearrange)InterleavedSampleSimilarityInterleavedSample	VQASample)ImageTextSampleMultiModalSampleConfig)loggingc                   @   s    e Zd Zdd Zedd ZdS )SampleEncoderc                 C   s   dS )an  
        Initialize the SampleEncoder class.

        This class serves as an abstract base class for encoding samples. It provides a common interface for
        different types of sample encoders. Subclasses should implement the encode method to perform the actual
        encoding process.

        Parameters:
        None

        Returns:
        None
        N selfr   r   k/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/multimodal/data/energon/sample_encoder.py__init__   s   zSampleEncoder.__init__c                 C      t d)a;  
        Abstract method to encode a sample. Must be implemented by subclasses.

        This method is responsible for encoding a given input sample into a format suitable for further processing.
        The encoded sample is then stored in the output_sample object.

        Parameters:
        input_sample (object): The input sample to be encoded. The type and structure of this object depend on the specific subclass.
        output_sample (object): The object where the encoded sample will be stored. The type and structure of this object depend on the specific subclass.

        Returns:
        None: The method does not return any value.

        Raises:
        NotImplementedError: If the method is called directly on the abstract class, it will raise this exception. Subclasses must implement this method.
        ,Subclasses must implement the encode method.NotImplementedErrorr   input_sampleoutput_sampler   r   r   encode*   s   zSampleEncoder.encodeN)__name__
__module____qualname__r   r   r   r   r   r   r   r      s    r   c                       sj   e Zd ZdZe df fdd	ZdejdejfddZd	ejdejfd
dZ	de
de
ddfddZ  ZS )BaseSampleEncodera#  
    Base class for encoding multimodal samples, specifically for handling text and image data.

    This class provides basic functionality for preprocessing images, computing loss masks,
    and managing sample configuration. It serves as a base class for more specialized encoders.

    Attributes:
    tokenizer (Tokenizer): The HF tokenizer used for tokenizing input text.
    image_processor (ImageProcessor): The HF image processor used for preprocessing input images.
    multimodal_sample_config (MultiModalSampleConfig): Configuration for multimodal samples, including tokens and placeholders.
    ignore_place_holder (int): Token ID used to ignore certain tokens during loss computation.
    image_token (Token): Token dataclass representing image placeholders in the tokenized sequence.
    Nc                    sJ   t    t|dr|j| _n|| _|| _|| _|j| _|j| _|| _dS )a  
        Initialize the BaseSampleEncoder.

        Parameters:
        tokenizer (Tokenizer): The tokenizer used for processing text.
        image_processor (ImageProcessor): The image processor used for preprocessing images.
        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
            Defaults to MultiModalSampleConfig().
        	tokenizerN)	superr   hasattrr   image_processormultimodal_sample_configignore_place_holderimage_tokenimage_tag_typer   r   r    r!   r$   	__class__r   r   r   N   s   



zBaseSampleEncoder.__init__imagereturnc                 C   s4   | j j|dddd }t|tjsJ t|d}|S )a  
        Preprocess and reshape an input image for encoding.

        The function preprocesses an image using the specified image processor and reshapes it
        to the expected format for further processing.

        Parameters:
        image (torch.Tensor): A tensor representing the input image with dimensions (channels, height, width).

        Returns:
        torch.Tensor: A preprocessed and reshaped image tensor with dimensions (1, 1, channels, height, width).
        ptF)return_tensors
do_rescalepixel_valueszF c h w -> 1 F c h w)r    
preprocess
isinstancetorchTensorr   )r   r(   r   r   r   process_imagee   s   
zBaseSampleEncoder.process_imagelabelsc                 C   s&   t j| t jd}d||| jk< |S )a2  
        Compute a binary loss mask based on the provided labels.

        The function generates a mask where the loss is computed only for tokens that are not
        equal to the `ignore_place_holder` token.

        Parameters:
        labels (torch.Tensor): A tensor containing labels for which the loss mask needs to be generated.

        Returns:
        torch.Tensor: A binary mask tensor with the same shape as the input labels. The mask has ones
            for tokens where loss should be computed and zeros for `ignore_place_holder` tokens.
        dtypeg        )r0   onessizefloatr"   )r   r3   	loss_maskr   r   r   compute_loss_maskw   s   z#BaseSampleEncoder.compute_loss_maskr   r   c                 C   r   )a  
        Abstract method to encode an input sample.

        Subclasses must implement this method to encode input samples into the desired format.

        Parameters:
        input_sample (ImageTextSample): The sample to be encoded.
        output_sample (ImageTextSample): The object to store the encoded sample.

        Returns:
        None

        Raises:
        NotImplementedError: If the method is called directly on the abstract class.
        r   r   r   r   r   r   r      s   zBaseSampleEncoder.encode)r   r   r   __doc__r	   r   r0   r1   r2   r:   r   r   __classcell__r   r   r&   r   r   ?   s    r   c                       s   e Zd ZdZe df fdd	ZddefddZd	ed
e	j
fddZde	j
ded
e	j
fddZdedefddZdd Z  ZS )VQASampleEncodera  
    Encoder specifically designed for Visual Question Answering (VQA) samples.

    This class extends the BaseSampleEncoder to handle VQA tasks, applying a specific prompt
    template and computing labels and loss masks based on the VQA input.

    Attributes:
    conversation_template_config (ConversationTemplateConfig): Configuration for conversation templates used in VQA.
    Nc                       t  |||| |j| _dS )a  
        Initialize the VQASampleEncoder.

        Parameters:
        tokenizer (Tokenizer): The HF tokenizer used for processing text.
        image_processor (ImageProcessor): The HF image processor used for preprocessing images.
        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
            Defaults to MultiModalSampleConfig().
        N)r   r   conversation_template_configr%   r&   r   r   r         zVQASampleEncoder.__init__F
input_textc                 C   s  t d|j d|j  g }| jjr|d| jjd t|jtrZt|jtrZt	t
|jt
|j}t|D ] }|| jjd |j| d || jjd |j| d q8n+t|jtrt|jtr|| jjd |jd || jjd |jd ntd| jjr| jj| j_n
| jjdu rtd	| jj|d
d
d}| jdkr|dd}n| jdu sJ d| j t d|  |S )a  
        Apply a conversation template to the input text for VQA.

        This method generates a templated prompt by combining system, user, and assistant messages.

        Parameters:
        input_text (VQASample): The VQA sample containing the context and answer.
        use_plain (bool, optional): Whether to use a plain format for the prompt. Defaults to False.

        Returns:
        str: The generated templated prompt as a string.
        z$apply_conversation_template context z answer system)rolecontentr      zYVQA Sample context/answers should either be a List[str] or str. Other types not supportedNzBoth tokenizer and conversation template does not have chat template defined. Refer to https://huggingface.co/docs/transformers/main/en/chat_templatingF)tokenizeadd_generation_promptinternvlz<image>z<img><image></img>zNot supported image_tag_type z'apply prompt template templated_prompt )r
   debugcontextanswersr?   rB   appendr/   listminlenrangerolesstr
ValueErrorchat_templater   apply_chat_templater$   replace)r   rA   	use_plainmessages
min_lengthitemplated_promptr   r   r   apply_prompt_template   s8    
z&VQASampleEncoder.apply_prompt_templatepromptr)   c                 C   s   dd dd | jjfD  d }t||}g }|D ]!}|| jjkr+|| jj qt|dkr<|| j	|ddj
 qtj|tjd	S )
a  
        Tokenize the input prompt, replacing special tokens with their IDs.

        This method splits the prompt into chunks based on the presence of special tokens (like <image>)
        and tokenizes each chunk. Special tokens are replaced with their corresponding token IDs.

        Parameters:
        prompt (str): The prompt string to tokenize.

        Returns:
        torch.Tensor: A tensor containing the tokenized prompt.
        (|c                 s   s    | ]}t |V  qd S )N)reescape).0tokenr   r   r   	<genexpr>   s    z,VQASampleEncoder.tokenize.<locals>.<genexpr>)r   Fadd_special_tokensr4   )joinr#   	token_strr`   splitrL   token_idrO   extendr   	input_idsr0   tensorlong)r   r]   regex_patternchunkstokenized_chunkschunkr   r   r   rF      s   "zVQASampleEncoder.tokenizetokenssamplec                 C   s   ddl m} t|| j }d}t| jdd}t|jt	r |jn|jg}|D ]H}| 
||}| jj|dddd }	| j|	d dkrI|	d	d }	|||	|\}
}|
dk rbtd
|j||	|  |S ||
| ||
|< |}q&|S )a)  
        Compute labels for the tokenized prompt based on the answers in the VQA sample.

        This method generates a label tensor where the tokens corresponding to the answers are marked
        with their token IDs, while other tokens are marked with the `ignore_place_holder` ID.

        Parameters:
        tokens (torch.Tensor): A tensor containing the tokenized prompt.
        sample (VQASample): The VQA sample containing the answers.

        Returns:
        torch.Tensor: A tensor containing the labels for the tokenized prompt.
        r   )_find_pattern_indicesstop_stringNFr*   )rg   r+    rE   zUnable to find a valid answer in the conversation. Details: 
- Messages: %s
- Tokens: %s
- Answer Tokens: %s
- Search Start Index: %d)nemo.collections.vlm.data.utilsrv   r0   	ones_liker"   getattrr?   r/   rK   rM   process_answer_strr   r   decoder
   warning)r   rt   ru   rv   r3   search_start_indexstop_strrK   answeranswer_tokensanswer_start
answer_endr   r   r   compute_labels  s0   zVQASampleEncoder.compute_labelsr   r   c                 C   s   |  |}td|  | |}| ||}|dd  }|dd  }td|  td|  | |}| |j}|j	|_	|j
d g|_|d|_||_||_||_|S )a  
        Encode a VQA sample into a format suitable for further processing.

        This method applies a prompt template, tokenizes the prompt, computes labels and a loss mask,
        and processes the image. The encoded sample is then stored in the output_sample object.

        Parameters:
        input_sample (VQASample): The VQA sample to be encoded.
        output_sample (ImageTextSample): The object to store the encoded sample.

        Returns:
        ImageTextSample: The encoded sample stored in output_sample.
        z/task encoder encode_sample conversation_prompt NrE   8task encoder encode_sample after tokenize prompt tokens z"task encoder encode_sample labels r   )r\   r
   rI   rF   r   
contiguousr:   r2   r(   __key__shapenum_image_tilessqueezeimagesrt   r3   r9   )r   r   r   conversation_promptrt   r3   r9   processed_imager   r   r   r   C  s"   


zVQASampleEncoder.encodec                 C   s   d| |d u r
d S | S )N rx   r   )r   r   r   r   r   r   r|   f  s   z#VQASampleEncoder.process_answer_str)F)r   r   r   r;   r	   r   r   r\   rR   r0   r1   rF   r   r   r   r|   r<   r   r   r&   r   r=      s    6=#r=   c                       sj   e Zd ZdZe df fdd	Zdeejejf fddZ	dejdejfd	d
Z
dedefddZ  ZS )InterleavedSampleEncodera  
    Encoder for handling interleaved sequences of text and images (InterleavedSample from energon).

    This class extends the BaseSampleEncoder to handle interleaved samples, where the input
    consists of a sequence of text strings and image tensors. The text and images are processed
    and encoded into a format suitable for further processing.

    Attributes:
    tokenizer (Tokenizer): The tokenizer used for processing text.
    image_processor (ImageProcessor): The image processor used for preprocessing images.
    multimodal_sample_config (MultiModalSampleConfig): Configuration for multimodal samples, including tokens and placeholders.
    Nc                    s   t  |||| dS )a  
        Initialize the InterleavedSampleEncoder.

        Parameters:
        tokenizer (Tokenizer): The tokenizer used for processing text.
        image_processor (ImageProcessor): The image processor used for preprocessing images.
        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
            Defaults to MultiModalSampleConfig().
        N)r   r   r%   r&   r   r   r   x  s   z!InterleavedSampleEncoder.__init__r)   c                 C   s   g }g }|j D ]<}t|tjr!|| jj | |}|| qt|dkr;t	
d|  || j|ddj qtdt| tj|tjd}t	
d|  tj|dd	}||fS )
a  
        Tokenize the input sequence and process images in an interleaved sample.

        This method processes a sequence that consists of text strings and image tensors.
        The text is tokenized, and the images are processed. The method returns a tensor
        of tokenized text and a concatenated tensor of processed images.

        Parameters:
        sample (InterleavedSample): The interleaved sample containing a sequence of text strings and image tensors.

        Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - A tensor with tokenized text and image token IDs.
            - A concatenated tensor of processed images.
        r   z<Multimodal datalaoder encoder interleaved sample text chunk Frf   z*Unsupported type in interleaved sequence: r4   zAMultimodal dataloader encode interleaved sample tokenized chunks rE   dim)sequencer/   r0   r1   rL   r#   rk   r2   rO   r
   rI   rl   r   rm   rS   typern   ro   concatenate)r   ru   rr   r   rs   r   rt   image_tensorr   r   r   rF     s   

z!InterleavedSampleEncoder.tokenizert   c                 C   s.   |  }| j||| jjk< |dd  }|S )a  
        Compute labels for an interleaved sample, ignoring image token IDs.

        This method generates a label tensor where the tokens corresponding to images are marked
        with the `ignore_place_holder` ID, and other tokens retain their original IDs.

        Parameters:
        tokens (torch.Tensor): A tensor containing the tokenized sequence.

        Returns:
        torch.Tensor: A tensor containing the labels for the tokenized sequence.
        rE   N)cloner"   r#   rk   r   )r   rt   r3   r   r   r   r     s   z'InterleavedSampleEncoder.compute_labelsr   r   c                 C   sr   |  |\}}| |}|dd }td|  td|  | |}|j|_||_||_||_||_	|S )a  
        Encode an interleaved sample into a format suitable for further processing.

        This method tokenizes the input sequence, computes labels and a loss mask, and processes
        the images. The encoded sample is then stored in the output_sample object.

        Parameters:
        input_sample (InterleavedSample): The interleaved sample to be encoded.
        output_sample (ImageTextSample): The object to store the encoded sample.

        Returns:
        ImageTextSample: The encoded sample stored in output_sample.
        Nr   r   z"task encoder encode_sample lables )
rF   r   r
   rI   r:   r   r   rt   r3   r9   )r   r   r   rt   r   r3   r9   r   r   r   r     s   

zInterleavedSampleEncoder.encode)r   r   r   r;   r	   r   tupler0   r1   rF   r   r   r   r   r<   r   r   r&   r   r   j  s    "r   c                       sF   e Zd ZdZe df fdd	Zdedeej	ej	f fddZ
  ZS )	SimilarityInterleavedEncodera  
    Encoder for handling similarity-based interleaved sequences of text and images.

    This class extends the InterleavedSampleEncoder to handle samples where images and text
    are interleaved based on a similarity matrix. The images are inserted into the text sequence
    based on the similarity scores (matched_text_indices), allowing for flexible interleaving of media types.

    Attributes:
    image_following_text (bool): A flag indicating whether images should follow the text they are related to.
    Nc                    r>   )a  
        Initialize the SimilarityInterleavedEncoder.

        Parameters:
        tokenizer (Tokenizer): The tokenizer used for processing text.
        image_processor (ImageProcessor): The image processor used for preprocessing images.
        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
            Defaults to MultiModalSampleConfig().
        N)r   r   image_following_textr%   r&   r   r   r     r@   z%SimilarityInterleavedEncoder.__init__ru   r)   c                    s  |j }|j}|j}g }d}dd tt||D } fdd|D }t|}tt|D ]5}	|t|k rZ|| |	krZ jsD| j	j
 |||	   jrU| j	j
 |d7 }q,|||	  q,g }
|D ] }|
rt|
d trt|tr|
d  d| 7  < qf|
| qfg }|
D ]}| j	j
kr|| q| j|dd	j qtj|tjd
}td|  tj|dd}||fS )a&  
        Tokenize the input sequence and process images based on similarity indices.

        This method processes a sequence of text strings and images, interleaving them based
        on similarity indices (matched_text_indices). The text is tokenized, and the images are processed. The method
        returns a tensor of tokenized text and a concatenated tensor of processed images.

        Parameters:
        sample (SimilarityInterleavedSample): The sample containing a sequence of text strings and images,
            along with similarity indices that determine the interleaving order.

        Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - A tensor with tokenized text and image token IDs.
            - A concatenated tensor of processed images.
        r   c                 S   s   g | ]\}}|qS r   r   )rb   _imgr   r   r   
<listcomp>
  s    z9SimilarityInterleavedEncoder.tokenize.<locals>.<listcomp>c                    s   g | ]}  |qS r   )r2   )rb   rs   r   r   r   r     s    rE   r   r   Frf   r4   zLMultimodal dataloader encode similarity interleaved sample tokenized chunks r   )r   textsmatched_text_indicessortedziprP   rO   r   rL   r#   rk   r/   rR   rl   r   rm   r0   rn   ro   r
   rI   r   )r   ru   r   r   r   interleaved_list	image_idxsorted_imagessorted_indicestext_idxfinal_sequenceitemrr   rs   rt   r   r   r   r   rF     sB   
z%SimilarityInterleavedEncoder.tokenize)r   r   r   r;   r	   r   r   r   r0   r1   rF   r<   r   r   r&   r   r     s
    &r   )r`   abcr   r   r0   einopsr   megatron.energonr   r   r   /nemo.collections.multimodal.data.energon.configr   r	   
nemo.utilsr
   r   r   r=   r   r   r   r   r   r   <module>   s   &] Om