o
    i.                    @   s  d dl Z d dlmZ d dlmZ d dlmZmZmZ d dl	Z
d dlZd dlm  mZ d dlZd dlmZ d dlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZ ddlmZ ddl m!Z!m"Z" ddl#m$Z$m%Z%m&Z& ddl'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z-m.Z.m/Z/m0Z0m1Z1 ddl2m3Z3 ddl4m5Z5m6Z6 ddl7m8Z8 ddl9m:Z:m;Z;m<Z<m=Z=m>Z>m?Z?m@Z@ ddlAmBZBmCZCmDZD ddlEmFZF ddlGmHZH ddlImJZJmKZKmLZLmMZMmNZN ddlOmPZPmQZQ ddlRmSZS ddlTmUZU ddlVmWZWmXZXmYZY e? rd dlZZZe@[e\Z]G dd deUZ^G dd  d eHZ_G d!d" d"eZ`e<G d#d$ d$e6Zaee<d%d&G d'd( d(e3ZbG d)d* d*ePZcG d+d, d,eQZdG d-d. d.eYZeG d/d0 d0ejfZgG d1d2 d2ejfZhG d3d4 d4eXZiG d5d6 d6eWZjG d7d8 d8eFZkG d9d: d:ejfZlG d;d< d<eNZmG d=d> d>eMZnG d?d@ d@eKZoG dAdB dBeLZpG dCdD dDejfZqG dEdF dFejfZrG dGdH dHejfZsG dIdJ dJejfZtG dKdL dLeJZuG dMdN dNejfZvG dOdP dPejfZwe<dQd&G dRdS dSeaZxG dTdU dUeaeZyG dVdW dWeZzg dXZ{dS )Y    N)Iterable)	dataclass)CallableOptionalUnion)nn)BlipImageProcessor   )ACT2FN)Cache)PretrainedConfig)%ClassifierFreeGuidanceLogitsProcessorGenerationMixinGenerationModeLogitsProcessorList)GenerateDecoderOnlyOutput)BatchFeatureget_size_dict)convert_to_rgbresizeto_channel_dimension_format)
ChannelDimension
ImageInputPILImageResamplingget_image_sizeinfer_channel_dimension_formatis_scaled_imagemake_flat_list_of_imagesto_numpy_arrayvalid_imagesvalidate_preprocess_arguments)ModelOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)
TensorTypeTransformersKwargsauto_docstringcan_return_tuplefilter_out_non_signature_kwargsis_vision_availablelogging   )CONFIG_MAPPING
AutoConfig	AutoModel)Blip2VisionModel)ChameleonVQVAEConfig)ChameleonVQVAEChameleonVQVAEEncoderAttnBlock#ChameleonVQVAEEncoderConvDownsample ChameleonVQVAEEncoderResnetBlockChameleonVQVAEVectorQuantizer)IdeficsBaseModelOutputWithPastIdeficsCausalLMOutputWithPast)eager_attention_forward)SiglipVisionConfig)SiglipEncoderSiglipEncoderLayerSiglipVisionEmbeddingsc                       sN   e Zd ZdZdZdZ									
												d fdd	Z  ZS )JanusVisionConfiga
  
    This is the configuration class to store the configuration of a [`JanusVisionModel`]. It is used to instantiate a
    `JanusVisionModel` according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 384):
            The size (resolution) of each image.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, and `"gelu_new"` are supported.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys, and values in the attention layers.
        hidden_dropout_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for fully connected layers in the encoder.
        projection_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the MLP projection head.
        projection_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for the projection layer.
        use_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether to normalize the query and key matrices.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated normal initializer for initializing all weight matrices.
        depth (`int`, *optional*, defaults to 2):
            Number of hidden layers in the aligner module.
        num_image_tokens (`int`, *optional*, defaults to 576):
            Number of image tokens.
    janus_vision_modelvision_config         r	             ư>gelu      @T   F{Gz?r,   @  c                    sd   t  jd|||||||||	d	| | `|
| _|| _|| _|| _|| _|| _|| _	|| _
|| _d S )N)	hidden_sizenum_hidden_layersnum_attention_headsnum_channels
patch_size
image_sizeattention_dropoutlayer_norm_eps
hidden_act )super__init__intermediate_size	mlp_ratioattention_biashidden_dropout_rateprojection_dimprojection_dropoutuse_qk_norminitializer_rangedepthnum_image_tokens)selfrL   rM   rN   rO   rP   rQ   rR   rS   rT   rY   rZ   r[   r\   r]   r^   r_   r`   ra   kwargs	__class__rU   [/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/janus/modular_janus.pyrW      s.   

zJanusVisionConfig.__init__)rA   rB   rC   r	   rC   rD   rE   rF   rG   rH   TrE   rI   rE   FrJ   r,   rK   )__name__
__module____qualname____doc__
model_typebase_config_keyrW   __classcell__rU   rU   rd   rf   r>   T   s.    .r>   c                       sx   e Zd ZdZddddddddg d	d
dddd
ddfdededededededededee dedef fddZ  Z	S )JanusVQVAEConfiga:
  
    This is the configuration class to store the configuration of a [`JanusVQVAEModel`]. It is used to instantiate a
    `JanusVQVAEModel` according to the specified arguments, defining the model architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a
    configuration with the defaults will yield a similar configuration to the VQModel of the
    [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B).

    Args:
        embed_dim (`int`, *optional*, defaults to 8):
            Dimensionality of each embedding vector.
        num_embeddings (`int`, *optional*, defaults to 16384):
            Number of codebook embeddings.
        double_latent (`bool`, *optional*, defaults to `False`):
            Whether to use double z channels.
        latent_channels (`int`, *optional*, defaults to 256):
            Number of channels for the latent space.
        num_patches (`int`, *optional*, defaults to 32):
            Num of patches the input images can be divided into.
        in_channels (`int`, *optional*, defaults to 3):
            Number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            Number of out channels.
        base_channels (`int`, *optional*, defaults to 128):
            Base channel count.
        channel_multiplier (`list[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
            Channel multipliers for each resolution.
        num_res_blocks (`int`, *optional*, defaults to 2):
            Number of residual blocks.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        projection_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the MLP projection head.
        num_hidden_layers (`int`, *optional*, defaults to 2):
            Number of hidden layers in VAVAE MLP Connecter module.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        image_token_embed_dim (`int`, *optional*, defaults to 2048):
            Dimension of image embeddings. It should be same as the dimensionality of text embeddings.
       i @  F       r	      )   rs   r,   r,      r,   rE   rJ   rI   rG   	embed_dimnum_embeddingsdouble_latentlatent_channelsnum_patchesin_channelsout_channelsbase_channelschannel_multipliernum_res_blocksdropoutc                    s\   t  jd|||||||	|
||d
| || _|| _|| _|| _|| _|| _| `| `	| `
d S )N)
ru   rv   rw   rx   rz   r|   r}   r~   r   r_   rU   )rV   rW   ry   r{   r\   rM   rT   image_token_embed_dim
resolutionattn_resolutions	attn_type)rb   ru   rv   rw   rx   ry   rz   r{   r|   r}   r~   r   r_   r\   rM   rT   r   rc   rd   rU   rf   rW      s.   zJanusVQVAEConfig.__init__)
rg   rh   ri   rj   intboollistfloatrW   rm   rU   rU   rd   rf   rn      sR    .	
rn   c                       s:   e Zd ZdZdZeeedZ				d fdd	Z	  Z
S )	JanusConfiga;  
    This is the configuration class to store the configuration of a [`JanusModel`]. It is used to instantiate an
    Janus model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Janus-1B or Janus-7B models.

    e.g. [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B) or
    [deepseek-community/Janus-Pro-7B](https://huggingface.co/deepseek-community/Janus-Pro-7B)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[AutoConfig, dict]`,  *optional*, defaults to `JanusVisionConfig`):
            The config object or dictionary of the vision backbone.
        vq_config (`Union[AutoConfig, dict]`,  *optional*, defaults to `JanusVQVAEConfig`):
            The config object or dictionary of the VQVAE backbone.
        image_token_id (`int`, *optional*, defaults to 100581):
            Token index of a placeholder image token.

    Example:

    ```python
    >>> from transformers import JanusForConditionalGeneration, JanusConfig, JanusVisionConfig, JanusVQVAEConfig, LlamaConfig

    >>> # Initializing a Janus vision config
    >>> vision_config = JanusVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a VQ config
    >>> vq_config = JanusVQVAEConfig()

    >>> # Initializing a Janus Pro 1B style configuration
    >>> configuration = JanusConfig(vision_config=vision_config, text_config=text_config, vq_config=vq_config)

    >>> # Initializing a model from the Janus Pro 1B style configuration
    >>> model = JanusForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```janus)text_configr@   	vq_configN c                    sj  t |tr|dd|d< t|d  d	i || _n"|d u r*td td  | _nt |tr3|| _n	tdt	| |d u rJtd t
 | _n t |trXt
d	i || _nt |t
ra|| _n	tdt	| |d u rxtd t | _n t |trtd	i || _nt |tr|| _n	tdt	| | jj| _| jj| jj | j_|| _t jd	i | d S )
Nrk   llamaz7`text_config` is None. Initializing with default valueszTInvalid type for `text_config`. Must be either `dict` or `LlamaConfig`. Type found: zK`vision_config` is None. Initializing with default JanusVisionConfig valuesz\Invalid type for `vision_config`. Must be either `dict` or `JanusVisionConfig`. Type found: zF`vq_config` is None. Initializing with default JanusVQVAEConfig valueszWInvalid type for `vq_config`. Must be either `dict` or `JanusVQVAEConfig`. Type found: rU   )
isinstancedictgetr-   r   loggerinfor   
ValueErrortyper>   r@   rn   r   r_   rQ   rP   ry   image_token_idrV   rW   )rb   r   r@   r   r   rc   rd   rU   rf   rW   D  sR   











zJanusConfig.__init__)NNNr   )rg   rh   ri   rj   rk   r.   r>   rn   sub_configsrW   rm   rU   rU   rd   rf   r     s    -r   c                   @   s>   e Zd ZU eed< dZdZddgZddgZdZ	dZ
dZdZd	S )
JanusPreTrainedModelconfigmodelTLlamaDecoderLayerJanusVisionEncoderLayerpast_key_valuescausal_maskFN)rg   rh   ri   r   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modules_skip_keys_device_placement_supports_flash_attn_supports_sdpa_can_compile_fullgraph!_supports_param_buffer_assignmentrU   rU   rU   rf   r   }  s   
 r   z9
    Base class for Janus VQ-VAE mode model outputs.
    )custom_introc                   @   s6   e Zd ZU dZdZeej ed< dZ	eej ed< dS )JanusVQVAEOutputz
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`):
        Embedding loss.
    Ndecoded_pixel_valuesembedding_loss)
rg   rh   ri   rj   r   r   torchFloatTensorr   r   rU   rU   rU   rf   r     s   
 r   c                   @      e Zd ZdS )JanusBaseModelOutputWithPastNrg   rh   ri   rU   rU   rU   rf   r         r   c                   @   r   )JanusCausalLMOutputWithPastNr   rU   rU   rU   rf   r     r   r   c                   @   s(   e Zd ZddejdedejfddZdS )	JanusVisionEmbeddingsFpixel_valuesinterpolate_pos_encodingreturnc           
      C   sh   |j \}}}}| jjj}| |j|d}|ddd}|r(| |||}	n| | j	}	||	 }|S )Ndtyper,   rs   )
shapepatch_embeddingweightr   toflatten	transposer   position_embeddingposition_ids)
rb   r   r   _heightwidthtarget_dtypepatch_embeds
embeddings
pos_embedsrU   rU   rf   forward  s   
zJanusVisionEmbeddings.forwardN)F)rg   rh   ri   r   Tensorr   r   rU   rU   rU   rf   r     s     r   c                       sL   e Zd ZdZdef fddZ	ddejdeej de	e
 fd	d
Z  ZS )JanusVisionAttentionz(Attention Class for Janus Vision Encoderr   c                    sL  t    || _|j| _|j| _| j| j | _| j| j | jkr-td| j d| j d| jd | _	|j
| _
|j}|j}d| _d| _tj| j| j| j |jd| _tj| j| j| j |jd| _tj| j| j| j |jd| _t| j| j| _|dkrt|nt | _|rt| jnt | _|rt| j| _d S t | _d S )	Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      Frs   biasr   )rV   rW   r   rL   ru   rN   	num_headshead_dimr   scalerR   r]   r^   	is_causalnum_key_value_groupsr   LinearrZ   q_projk_projv_projprojection_layerDropoutIdentity	LayerNormq_normk_norm)rb   r   proj_dropoutqk_normrd   rU   rf   rW     s0   

$zJanusVisionAttention.__init__Nhidden_statesattention_maskrc   c                 K   s4  |  \}}}| |}| |}| |}	|d| j| j}| |}|d| j| j}| |}|||| j| j	dd}|||| j| j	dd}|	
||| j| j	dd}	t}
| jjdkrjt| jj }
|
| |||	|f| jsvdn| j| j| jd|\}}|||| j}| |}| |}||fS )Nrs   r,   eagerrE   )r   scalingr   )sizer   r   r   reshaper   r   r   r   r   viewr9   r   _attn_implementationr"   trainingrR   r   r   ru   r   r]   )rb   r   r   rc   
batch_sizeseq_lenr   query_states
key_statesvalue_statesattention_interfaceattn_outputattn_weightsoutputrU   rU   rf   r     s>   




	


zJanusVisionAttention.forwardN)rg   rh   ri   rj   r>   rW   r   r   r   r$   r&   r   rm   rU   rU   rd   rf   r     s     r   c                       s8   e Zd Zdef fddZdejdejfddZ  ZS )JanusVisionMLPr   c                    sr   t    || _t|j|j | _t|j | _	t
|j| j| _t
| j|j| _t
|j| _t
|j| _d S r   )rV   rW   r   r   rL   rY   rX   r
   rT   activation_fnr   r   fc1fc2r   r[   dropout1dropout2rb   r   rd   rU   rf   rW     s   
zJanusVisionMLP.__init__r   r   c                 C   s6   |  |}| |}| |}| |}| |}|S r   )r   r   r   r   r   rb   r   rU   rU   rf   r     s   




zJanusVisionMLP.forward)	rg   rh   ri   r>   rW   r   r   r   rm   rU   rU   rd   rf   r     s    
r   c                       "   e Zd Zdef fddZ  ZS )r   r   c                    sZ   t  | || _|j| _t|| _tj| j|j	d| _
tj| j|j	d| _t|| _d S )N)eps)rV   rW   r   rL   ru   r   	self_attnr   r   rS   layer_norm1layer_norm2r   mlpr   rd   rU   rf   rW     s   
z JanusVisionEncoderLayer.__init__rg   rh   ri   r>   rW   rm   rU   rU   rd   rf   r         r   c                       r   )JanusVisionEncoderr   c                    s0   t    t fddt jD | _d S )Nc                    s   g | ]}t  qS rU   )r   .0r   r   rU   rf   
<listcomp>$      z/JanusVisionEncoder.__init__.<locals>.<listcomp>)rV   rW   r   
ModuleListrangerM   layersr   rd   r  rf   rW   "  s   $zJanusVisionEncoder.__init__r   rU   rU   rd   rf   r  !  r   r  c                       r   )JanusVisionModelr   c                    s   t  | t|| _d S r   )rV   rW   r  encoderr   rd   rU   rf   rW   (  s   zJanusVisionModel.__init__r   rU   rU   rd   rf   r
  '  r   r
  c                       *   e Zd Zdef fddZdd Z  ZS )JanusVisionAlignerMLPr   c                    N   t    t j j| _t fddtd j	D | _
t j | _d S )Nc                       g | ]
}t  j jqS rU   r   r   r\   r  r  rU   rf   r  3      z2JanusVisionAlignerMLP.__init__.<locals>.<listcomp>rs   )rV   rW   r   r   rL   r\   r   r  r  r`   hidden_layersr
   rT   r   r   rd   r  rf   rW   .     
zJanusVisionAlignerMLP.__init__c                 C   ,   |  |}| jD ]}| |}||}q|S r   r   r  r   rb   r   layerrU   rU   rf   r   7  
   



zJanusVisionAlignerMLP.forward)rg   rh   ri   r>   rW   r   rm   rU   rU   rd   rf   r  -      	r  c                       s8   e Zd Zdef fddZdejdejfddZ  Z	S )JanusVQVAEVectorQuantizerr   c                    s   t  | |jgd | _d S )Nr,   )rV   rW   ry   quant_state_dimsr   rd   rU   rf   rW   @  s   z"JanusVQVAEVectorQuantizer.__init__image_tokensr   c                 C   sb   |j d }| jjj d }| |}tj|ddd}||g| j|R }|dddd }|S )Nr   r   r,   )pdimr	   rs   )	r   	embeddingr   F	normalizer   r  permute
contiguous)rb   r  r   emb_dimhidden_state_quantrU   rU   rf   get_codebook_entryD  s   

z,JanusVQVAEVectorQuantizer.get_codebook_entry)
rg   rh   ri   rn   rW   r   
LongTensorr   r&  rm   rU   rU   rd   rf   r  ?  s    r  c                   @   r   )JanusVQVAEResnetBlockNr   rU   rU   rU   rf   r(  T  r   r(  c                   @   r   )JanusVQVAEAttnBlockNr   rU   rU   rU   rf   r)  X  r   r)  c                   @   r   )JanusVQVAEConvDownsampleNr   rU   rU   rU   rf   r*  \  r   r*  c                       s$   e Zd Z fddZdd Z  ZS )JanusVQVAEConvUpsamplec                    s&   t    tjj||dddd| _d S )Nr	   rs   kernel_sizestridepadding)rV   rW   r   r   Conv2dconv)rb   rz   rd   rU   rf   rW   a  s   
zJanusVQVAEConvUpsample.__init__c                 C   s   t j|ddd}| |}|S )Ng       @nearest)scale_factormode)r   interpolater1  r   rU   rU   rf   r   e  s   
zJanusVQVAEConvUpsample.forward)rg   rh   ri   rW   r   rm   rU   rU   rd   rf   r+  `  s    r+  c                       s<   e Zd Zdedef fddZdejdejfddZ  Z	S )	JanusVQVAEMidBlockr   channelsc                    s8   t    t|||d| _t|| _t|||d| _d S )Nr   rz   r{   )rV   rW   r(  block_1r)  attn_1block_2)rb   r   r7  rd   rU   rf   rW   l  s   

zJanusVQVAEMidBlock.__init__r   r   c                 C   "   |  |}| |}| |}|S r   )r9  r:  r;  r   rU   rU   rf   r   z     


zJanusVQVAEMidBlock.forward)
rg   rh   ri   rn   r   rW   r   r   r   rm   rU   rU   rd   rf   r6  k  s    r6  c                       s,   e Zd Z fddZdejfddZ  ZS )JanusVQVAEEncoderc              	      sn  t    t|j| _|j| _|j}|j}|j}|j	}|j}t
jj||dddd| _dt| }|| _t | _t| jD ]T}t }	t }
|||  }|||  }t| jD ]}|	t|||d |}|| jd krt|
t| qXt }|	|_|
|_|| jd krt||_| j| q=t||| _t
jjd|ddd	| _t
jj||rd
| n|dddd| _d S )Nr	   rs   r,  )rs   r8  rq   rF   T
num_groupsrO   r   affiner,   ) rV   rW   lenr}   num_resolutionsr~   r|   rz   rw   rx   r   r   r0  conv_intuplein_channel_multiplierr  downr  appendr(  r)  Moduleblockattnr*  
downsampler6  mid	GroupNormnorm_outconv_out)rb   r   r|   rz   rw   rx   r}   rF  i_levelrJ  rK  block_in	block_outi_blockrG  rd   rU   rf   rW     sX   


zJanusVQVAEEncoder.__init__r   c                 C   s   |  |g}t| jD ]C}t| jD ]'}| j| j| |d }t| j| jdkr4| j| j| |}|| q|| jd krN|| j| 	|d  q|d }| 
|}| |}|t|9 }| |}|S )Nr   r   rs   )rD  r  rC  r~   rG  rJ  rB  rK  rH  rL  rM  rO  r   sigmoidrP  )rb   r   r   rQ  rT  hidden_statelast_hidden_staterU   rU   rf   r     s$   


zJanusVQVAEEncoder.forward)rg   rh   ri   rW   r   r'  r   rm   rU   rU   rd   rf   r>    s    3r>  c                       s2   e Zd Z fddZdejdejfddZ  ZS )JanusVQVAEDecoderc              	      sP  t    t|j| _|j| _|j}|j}|j}||j| jd   }t	j
j||dddd| _t||| _t
 | _tt| jD ]N}t
 }t
 }||j|  }	t| jd D ]}
|t|||	d |	}|| jd krt|t| qXt
 }||_||_|dkrt||_| j| q@t	j
jd|ddd	| _t	j
j||dddd| _d S )
Nrs   r	   r,  r8  r   rq   rF   Tr?  )rV   rW   rB  r}   rC  r~   r|   rx   r{   r   r   r0  rD  r6  rM  r  upreversedr  rH  r(  r)  rI  rJ  rK  r+  upsamplerN  rO  rP  )rb   r   r|   rx   r{   rR  rQ  rJ  rK  rS  rT  rY  rd   rU   rf   rW     sD   


zJanusVQVAEDecoder.__init__rV  r   c                 C   s   |  |}| |}t| jD ]9}t| jd D ] }| j| j| |}t| j| jdkr8| j| j| |}q|| jd krH| j| 	|}q| 
|}|t|9 }| |}|S )Nrs   r   )rD  rM  r  rC  r~   rY  rJ  rB  rK  r[  rO  r   rU  rP  )rb   rV  rQ  rT  rU   rU   rf   r     s   



zJanusVQVAEDecoder.forward)rg   rh   ri   rW   r   r   r   rm   rU   rU   rd   rf   rX    s    .rX  c                       sl   e Zd Zg dZdZdef fddZdejdej	fdd	Z
eedej	deej	ej	f fd
dZ  ZS )
JanusVQVAE)r)  r(  r  r   r   c                    s(   t  | t|| _d| _|   d S )NF)rV   rW   rX  decodergradient_checkpointing	post_initr   rd   rU   rf   rW     s   
zJanusVQVAE.__init__r  r   c                 C   sr   |j d | jjd | jjd  kr'td| jjd | jjd   d|j  d| j|}| |}| |}|S )aG  
        Decodes quantized token IDs into pixel values.
        Args:
            image_tokens (torch.LongTensor): Batch of token IDs.
        Returns:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                Pixel values decoded from the token IDs.
        rs   r   z4Expected `image_tokens` to have shape `(batch_size, z)`, but got shape `z`.)r   quantizer  r   r&  post_quant_convr]  )rb   r  codebook_entryr   r   rU   rU   rf   decode"  s   "	

zJanusVQVAE.decodec                 C   s6   |j d }| |\}}}| ||d}t||S )Nr   r   )r   encoderc  r   r   )rb   r   r   quantr   indicesr   rU   rU   rf   r   5  s   

zJanusVQVAE.forward)rg   rh   ri   r   main_input_namern   rW   r   r'  r   rc  r(   r'   rE  r   rm   rU   rU   rd   rf   r\    s    r\  c                       r  )JanusVQVAEAlignerMLPr   c                    r  )Nc                    r  rU   r  r  r  rU   rf   r  H  r  z1JanusVQVAEAlignerMLP.__init__.<locals>.<listcomp>rs   )rV   rW   r   r   ru   r\   r   r  r  rM   r  r
   rT   r   r   rd   r  rf   rW   C  r  zJanusVQVAEAlignerMLP.__init__c                 C   r  r   r  r  rU   rU   rf   r   L  r  zJanusVQVAEAlignerMLP.forward)rg   rh   ri   rn   rW   r   rm   rU   rU   rd   rf   rh  B  r  rh  c                       s<   e Zd ZdZdef fddZdejdejfddZ	  Z
S )	JanusVQVAEHeadzOHead used for sampling tokens in image generation, replacing the usual lm head.r   c                    s>   t    t|j|j| _t|j | _	t|j|j
| _d S r   )rV   rW   r   r   r   r\   proj_outr
   rT   r   rv   vision_headr   rd   rU   rf   rW   W  s   
zJanusVQVAEHead.__init__r   r   c                 C   r<  r   )rj  r   rk  r   rU   rU   rf   r   ]  r=  zJanusVQVAEHead.forward)rg   rh   ri   rj   rn   rW   r   r   tensorr   rm   rU   rU   rd   rf   ri  T  s    ri  zl
    The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.
    c                       s   e Zd Zdef fddZdd Zdd Zdd	 Zd
ej	dej
dej
fddZee									dd
eej	 deej
 deej deej	 dee deej	 deej
 dee deeejf fddZ  ZS )
JanusModelr   c                    s   t  | || _t|j| _t| jj| _t	|j
| _t| jjj| jjj| _t| jj| _t| jj| _tj|jd| _d| _|   d S )Nr  F)rV   rW   r   r
  _from_configr@   vision_modelr  alignerr\  r   vqmodelr   	Embeddingrv   ru   generation_embeddingsrh  generation_alignerri  generation_headr/   from_configr   language_modelr^  r_  r   rd   rU   rf   rW   j  s   zJanusModel.__init__c                 C   s
   | j  S r   )rw  get_input_embeddingsrb   rU   rU   rf   rx    s   
zJanusModel.get_input_embeddingsc                 C   s   | j | d S r   )rw  set_input_embeddingsrb   valuerU   rU   rf   rz    s   zJanusModel.set_input_embeddingsc                 C   s   |  |}| |j}|S r   )ro  rp  rW  )rb   r   image_embedsrU   rU   rf   get_image_features  s   
zJanusModel.get_image_features	input_idsinputs_embedsimage_featuresc                 C   s   |du r||   tj| jjtj|jdk}|d}n|| jjk}| }|	d
||j}||  | krP|jd |jd  }td| d| |S )z
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        Nr   devicer   r   rs   z6Image features and image tokens do not match: tokens: z, features )rx  r   rl  r   r   longr  allsum	unsqueeze	expand_asr   numelr   r   )rb   r  r  r  special_image_maskn_image_tokensn_image_featuresrU   rU   rf   get_placeholder_mask  s   zJanusModel.get_placeholder_maskNr   r   r   r   r   cache_position	use_cachelogits_to_keepc
              
   K   s   |d u |d uA rt d|d u r|  |}|d ur>| |}|d|jd }||j|j}| j|||d}|	||}| j
d|||||||	d|
}t|j|j|j|j|d ur_|dS d dS )NzaYou cannot specify both input_ids and inputs_embeds at the same time, and must specify either oner   )r  r  )r  r   r   r   r  r  r  )rW  r   r   
attentionsimage_hidden_statesrU   )r   rx  r~  r   r   r   r  r   r  masked_scatterrw  r   rW  r   r   r  )rb   r  r   r   r   r   r  r  r  r  rc   r}  r  image_attention_mask	lm_outputrU   rU   rf   r     sD   

zJanusModel.forward)	NNNNNNNNr   )rg   rh   ri   r   rW   rx  rz  r~  r   r'  r   r  r(   r'   r   r   r   r   r   r   r   rm   rU   rU   rd   rf   rm  d  sT    
	
rm  c                       sJ  e Zd ZddgZdZdef fddZdd Zd	d
 Zde	j
de	j
fddZee										d&dee	j dee	j dee	j
 dee	j dee dee	j dee	j dee	j dee deee	j
f dee fddZ						d' fdd	Zd e	j
fd!d"Ze	j			d(dee	j
 dee	j d#ee f fd$d%Z  ZS ))JanusForConditionalGenerationz(model.language_model.embed_tokens.weightzlm_head.weightTr   c                    sB   t  | || _t|| _tj|jj|jj	dd| _
|   d S )NFr   )rV   rW   r   rm  r   r   r   r   rL   
vocab_sizelm_headr_  r   rd   rU   rf   rW     s
   
z&JanusForConditionalGeneration.__init__c                 C   s   | j j S r   )r   rw  rx  ry  rU   rU   rf   rx    s   z2JanusForConditionalGeneration.get_input_embeddingsc                 C   s   | j j| d S r   )r   rw  rz  r{  rU   rU   rf   rz    s   z2JanusForConditionalGeneration.set_input_embeddingsinputsr   c                 C   s   | j |}| j |}|S r   )r   rs  rt  )rb   r  rV  rU   rU   rf   'prepare_embeddings_for_image_generation  s   zEJanusForConditionalGeneration.prepare_embeddings_for_image_generationNr   r  r   r   r   r   r  r  labelsr  r  rc   c                 K   s   | j d|||||||	|d|}|j}t|
trt|
 dn|
}| |dd|ddf }d}|durD| jd||| jjj	d|}t
|||j|j|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        )r  r   r   r   r   r  r  r  N)logitsr  r  )lossr  r   r   r  r  rU   )r   rW  r   r   slicer  loss_functionr   r   r  r   r   r   r  r  )rb   r  r   r   r   r   r  r  r  r  r  rc   outputsr   slice_indicesr  r  rU   rU   rf   r     s<   	z%JanusForConditionalGeneration.forwardc           
         s8   t  j|f|||||d|}	|d dkr||	d< |	S )N)r   r  r   r  r  r   r   )rV   prepare_inputs_for_generation)
rb   r  r   r   r   r  r  r  rc   model_inputsrd   rU   rf   r  "  s   z;JanusForConditionalGeneration.prepare_inputs_for_generationr  c                 C   s"   | j j|}|dddd}|S )a,  
        Decodes generated image tokens from language model to continuous pixel values
        with VQGAN module via upsampling.
        Args:
            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                The tensors corresponding to the input images.
        r   r,   r	   rs   )r   rq  rc  r"  )rb   r  decoded_imagerU   rU   rf   decode_image_tokens@  s   z1JanusForConditionalGeneration.decode_image_tokenslogits_processorc           %         sx  | d| j}t|}| dd}|dkr$t jd|||d d|S |jdi |}| tj	tj
fvr:td|  | |  |d urK|nt }d|d< |jd u r_td d	|_|j|d
< | ||j|\}}	}|j|j}
}t|jdkrtd|j d|d u}| j|||jd |jr|jdkr|t|j d |_| j||jd |d ||d}| jd|||jd|\}}| jjj j!}|j\}}|"dd}| dd }|"dd}||d< ||d d d f |jk||d d d f |j#d k@ }||d d d f $||j% | & |}| '|||}|(dd d u r<| j)|j*p,d|d t+|j,|| |d|d< t-j.||f|
|d}|j/}|j0}|j1}|j2}|j3}|r]|r]dnd }|rg|rgdnd }|rq|rqdnd }|r{|r{dnd }t4|D ]}| j5d||d|}|d 6|j|d< |d 6|j|d< | jj7di |||d}| 8||}|j9d d dd d f : } | j;| }!|||!}"|j<rt-j=|"dd}#t-j>|#dd?d}$nt-j@|"dd}$|$|d d |f< t-A|$|$g}$|$Bd}$| C|$}q|r,|r||!f7 }|r|| D f7 }|r$||jE7 }|r,||jF7 }|r:tG||!||||jHdS |S ) Ngeneration_configgeneration_modetext)r  r   r  guidance_scalezGot incompatible mode for Image Generation, should be one of greedy or sampling. Ensure that beam search is de-activated by setting `num_beams=1`.Tr  zU`guidance_scale` is required for CFG but not provided. Setting to default value of 5.   r  r,   z;Expected input ids of shape (batch_size, seq_len), but got z3Passing `inputs embeds` is not supported currently.)r  rs   )r  input_ids_seq_lengthencoder_input_idsprefix_allowed_tokens_fnr  r  )r  r   expand_sizer   boi_token_idr   static)cache_implementationr   max_cache_lenmodel_kwargsr  rU   )r  r  r  )output_attentionsoutput_hidden_statesr   )r  )num_samples)	sequencesscoresr  r  r   r   )Ipopr  copydeepcopyrV   generateupdateget_generation_moder   SAMPLEGREEDY_SEARCHr   validate_validate_model_kwargsr   r  r   warning_prepare_model_inputsbos_token_idr   r  rB  r   _prepare_special_tokensrH  r   _get_logits_processor_expand_inputs_for_generationnum_return_sequencesr   ro  r   ra   repeatgeneration_kwargsmasked_fill_pad_token_idrx  _get_initial_cache_positionr   
_get_cacher  max
max_lengthr   zerosr  r  output_scoresoutput_logitsreturn_dict_in_generater  r  r   rw  #_update_model_kwargs_for_generationrW  cloneru  	do_samplesoftmaxmultinomialsqueezeargmaxcatr  r  r   r  r   r   r   )%rb   r  r   r  rc   r  r  r  r  model_input_namer   r  kwargs_has_attention_maskra   r   r   input_tokensmaskr  generated_tokensr  r  r  r  r  
raw_scores
raw_logitsdecoder_hidden_statesdecoder_attentionsir  r  rV  r  next_token_scoresprobs
next_tokenrd   rU   rf   r  L  s   	

















	z&JanusForConditionalGeneration.generate)
NNNNNNNNNr   )NNNNNN)NNN)rg   rh   ri   _tied_weights_keysr   r   rW   rx  rz  r   r   r  r(   r'   r   r'  r   r   r   r   r   r$   r&   r   r  r  no_gradr   r  rm   rU   rU   rd   rf   r    sz    		
6r  c                #       s  e Zd ZdZdddejdddddddfdedeee	e
f  de
d	ed
edee
ef dedeeeee f  deeeee f  dee dee f fddZ			d%dejdee
ee
e
e
f f deee	ef  deee	ef  dejf
ddZejddfdejdeee	e
f e
f d	edeee	ef  deee	ef  dejfddZe ddddddddddddejdfdedee deee	e
f  d	ee d
ee dee dee deeeee f  deeeee f  deee	ef  dee deee
ee
e
e
f f  dee dedeee	ef  dejjf dd Z							d&ded
ee dee dee deee  deee  dee	 dee	 fd!d"Z	d'dejdeeee f deeee f deee	ef  dejf
d#d$Z  ZS )(JanusImageProcessora  
    Constructs a JANUS image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        min_size (`int`, *optional*, defaults to 14):
            The minimum allowed size for the resized image. Ensures that neither the height nor width
            falls below this value after resizing.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
            overridden by the `resample` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
            overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
            Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to square or not.
    TN   gp?	do_resizer   min_sizeresample
do_rescalerescale_factordo_normalize
image_mean	image_stddo_convert_rgbdo_padc                    sH   t  jdi | || _|| _|d u rd| _d S tdd |D | _d S )N)   r  r  c                 s   s    | ]	}t |d  V  qdS )   N)r   )r  xrU   rU   rf   	<genexpr>J  s    z/JanusImageProcessor.__init__.<locals>.<genexpr>rU   )rV   rW   r  r  background_colorrE  )rb   r  r   r  r  r  r  r  r  r  r  r  rc   rd   rU   rf   rW   4  s   
zJanusImageProcessor.__init__r   imager  data_formatinput_data_formatr   c                 C   s  t ||\}}|tjkr|jd n|jd }||kr*|dur&t|||}|S |}|S t||}t|tr8|g}nt||krFt	d| d|tjkrt
j|||f|jd}	t|D ]\}
}||	|
ddddf< qZ||kr|| d }||	dd||| ddf< |	S || d }||	dddd||| f< |	S t
j|||f|jd}	t|D ]\}
}||	dddd|
f< q||kr|| d }||	||| ddddf< |	S || d }||	dd||| ddf< |	S )a}  
        Pads an image to a square based on the longest edge.

        Args:
            image (`np.ndarray`):
                The image to pad.
            background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
                The color to use for the padding. Can be an integer for single channel or a
                tuple of integers representing for multi-channel images. If passed as integer
                in multi-channel mode, it will default to `0` in subsequent channels.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use same as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The padded image.
        r   r   Nz(background_color must have no more than z) elements to match the number of channelsr   r,   )r   r   FIRSTr   r   r  r   r   rB  r   npr  r   	enumerate)rb   r  r  r   r  r   r   rO   max_dimresultr  colorstartrU   rU   rf   pad_to_squareL  sL   




z!JanusImageProcessor.pad_to_squarec                 K   s   |du rt |}t||\}}t||}	t|dd}|d |d kr0td|d  d|d  |d }||	 }
tt||
 | jtt||
 | jg}t|f||||d|}|S )	an  
        Resize an image to dynamically calculated size.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`dict[str, int]` or `int`):
                The size to resize the image to. If a dictionary, it should have the keys `"height"` and `"width"`.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `None`: will be inferred from input
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        NTdefault_to_squarer   r   z5Output height and width must be the same. Got height=z and width=)r   r  r   r  )r   r   r  r   r   r   r  r   )rb   r  r   r  r   r  rc   r   r   max_sizedeltaoutput_size_nonpaddedrU   rU   rf   r     s2   #
zJanusImageProcessor.resizeimagesreturn_tensorsc              
      s  |dur|nj }durnj|dur|nj}dur!nj|dur*|nj}dur3njdur<nj|durE|nj}|durN|nj} durW nj	 dur`nj
tdd|}t|}t|sztdt|||d |rdd |D }dd |D }|rt|d	 rtd
 du rt|d	 |rfdd|D }|rȇ fdd|D }|rՇfdd|D }|rfdd|D }fdd|D }td|i|
d}|S )a`  
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`dict[str, int]`, *optional*, defaults to `self.size`):
                Controls the size of the image after `resize`. The shortest edge of the image is resized to
                `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
                is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
                edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to normalize the image by if `do_normalize` is set to `True`.
            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            background_color (`tuple[int, int, int]`):
                The background color to use for the padding.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image to square or not.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        NFr
  zkInvalid image type. Must be of type PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray.)r  r  r  r  r  r  r   r  c                 S      g | ]}t |qS rU   )r   r  r  rU   rU   rf   r  ?  r  z2JanusImageProcessor.preprocess.<locals>.<listcomp>c                 S   r  rU   )r   r  rU   rU   rf   r  B  r  r   zIt looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.c                    s   g | ]}j | d qS ))r  r   r  r  )r   r  )r  r  rb   r   rU   rf   r  O      c                    s   g | ]
}j | d qS ))r  r  r  )r	  r  )r  r  rb   rU   rf   r  V  s    c                    s   g | ]
}j | d qS ))r  r   r  )rescaler  )r  r  rb   rU   rf   r  `  s    c                    s   g | ]}j | d qS )r  meanstdr  )r!  r  )r  r  r  rb   rU   rf   r  f  r  c                    s   g | ]	}t | d qS )input_channel_dim)r   r  )r   r  rU   rf   r  k  s    r   datatensor_type)r  r  r  r  r  r  r  r  r  r  r   r   fetch_imagesr   r   r   r    r   r   warning_oncer   r   )rb   r  r  r   r  r  r  r  r  r  r  r  r  r  r   r  encoded_outputsrU   )	r  r   r  r  r  r  r  rb   r   rf   
preprocess  st   F
	zJanusImageProcessor.preprocessc	                 C   sR  |dur|n| j }|du rd| j n|}|dur|n| j}|dur#|n| j}|dur,|n| j}t|}t|d tjjrHt	|dkrD|S |d S |du rRt
|d }g }	|D ]@}
t|
}
|rg| j|
|||d}
|r{| j|
||d}
|
ddtj}
|r|r|dkrt|
tj|d	}
tj|
}
|	|
 qVd
|	i}|dkr|nd}t||dS )znApplies post-processing to the decoded image tokens by reversing transformations applied during preprocessing.Ng      ?r   rs   )r  r  r  r  )r   r  r  zPIL.Image.Imager  r   r  )r  r  r  r  r  r   r   PILImagerB  r   r   unnormalizer  clipastyper  uint8r   r   LAST	fromarrayrH  r   )rb   r  r  r  r  r  r  r  r  r   r  r  rU   rU   rf   postprocesss  s6   zJanusImageProcessor.postprocessc                 C   s   d}t |trt||krtd| dt| n|g| }t |tr7t||kr6td| dt| n|g| }tdd t||D }tdd |D }| j||||d}|S )	a~  
        Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`.
        image = (image * image_std) + image_mean
        Args:
            image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`):
                Batch of pixel values to postprocess.
            image_mean (`float` or `Iterable[float]`):
                The mean to use for unnormalization.
            image_std (`float` or `Iterable[float]`):
                The standard deviation to use for unnormalization.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        r	   zmean must have z$ elements if it is an iterable, got zstd must have c                 s   s    | ]
\}}| | V  qd S r   rU   )r  r  r  rU   rU   rf   r    s    z2JanusImageProcessor.unnormalize.<locals>.<genexpr>c                 s   s    | ]}d | V  qdS )rs   NrU   )r  r  rU   rU   rf   r    s    r  )r   r   rB  r   rE  zipr!  )rb   r  r  r  r  rO   rev_image_meanrev_image_stdrU   rU   rf   r#    s"   



zJanusImageProcessor.unnormalize)r   NN)NNNNNNNr   ) rg   rh   ri   rj   r   BICUBICr   r   r   strr   r   r   r   rW   r  ndarrayrE  r   r	  r   r)   r  r   r%   r!  r"  r   r)  r   r#  rm   rU   rU   rd   rf   r    s.   )
	

N
A	
 

	
8r  )	r  r   r  rm  r\  r
  rn   r>   r   )|r  collections.abcr   dataclassesr   typingr   r   r   numpyr  r   torch.nn.functionalr   
functionalr   torch.utils.checkpoint.transformers.models.blip.image_processing_blipr   activationsr
   cache_utilsr   configuration_utilsr   
generationr   r   r   r   generation.utilsr   image_processing_utilsr   r   image_transformsr   r   r   image_utilsr   r   r   r   r   r   r   r   r   r    modeling_outputsr!   modeling_utilsr"   r#   processing_utilsr$   utilsr%   r&   r'   r(   r)   r*   r+   autor-   r.   r/   blip_2.modeling_blip_2r0   !chameleon.configuration_chameleonr1   chameleon.modeling_chameleonr2   r3   r4   r5   r6   idefics.modeling_ideficsr7   r8   llama.modeling_llamar9   siglip.configuration_siglipr:   siglip.modeling_siglipr;   r<   r=   r!  
get_loggerrg   r   r>   rn   r   r   r   r   r   r   rI  r   r   r   r  r
  r  r  r(  r)  r*  r+  r6  r>  rX  r\  rh  ri  rm  r  r  __all__rU   rU   rU   rf   <module>   s   0$	
aZnLMD0l  9   K