o
    i ]                     @   s  d dl Z d dlmZ d dlmZ d dlZd dlmZ d dlm  m	Z
 d dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d	d
lmZmZ G dd dejZG dd dejZG dd dejZG dd dejZG dd dejZdejdede eje eef f fddZ!dejdede eef de eef dejf
ddZ"ded ed!ejdejfd"d#Z#d$ejd%ejd&ejde eef d e eef dejfd'd(Z$G d)d* d*ejZ%d+d, Z&d-d. Z'G d/d0 d0eZ(G d1d2 d2ejZ)dS )3    N)Iterable)partial)CLIPVisionConfig)MMEncoderAttention)Conv2dLayer)QuantizationConfig)default_weight_loader   )CLIPEncoderCLIPVisionEmbeddingsc                	       sP   e Zd Zejfdededeej ddf fddZde	j
de	j
fd	d
Z  ZS )MLPBlockembedding_dimmlp_dimactreturnNc                    s2   t    t||| _t||| _| | _d S N)super__init__nnLinearlin1lin2r   )selfr   r   r   	__class__ \/home/ubuntu/vllm_env/lib/python3.10/site-packages/vllm/model_executor/models/deepencoder.pyr      s   
zMLPBlock.__init__xc                 C   s   |  | | |S r   )r   r   r   r   r   r   r   r   forward)   s   zMLPBlock.forward)__name__
__module____qualname__r   GELUinttypeModuler   torchTensorr   __classcell__r   r   r   r   r      s    r   c                       sB   e Zd Zddededdf fddZdejdejfd	d
Z  Z	S )LayerNorm2dư>num_channelsepsr   Nc                    s8   t    tt|| _tt|| _|| _	d S r   )
r   r   r   	Parameterr'   onesweightzerosbiasr-   )r   r,   r-   r   r   r   r   0   s   

zLayerNorm2d.__init__r   c                 C   sn   |j ddd}|| dj ddd}|| t|| j  }| jd d d d f | | jd d d d f  }|S )Nr	   T)keepdim   )meanpowr'   sqrtr-   r0   r2   )r   r   usr   r   r   r   6   s
   ,zLayerNorm2d.forward)r+   )
r    r!   r"   r$   floatr   r'   r(   r   r)   r   r   r   r   r*   /   s    r*   c                %       s   e Zd Zdddddddddejejdd	dd
ddfdededededededededede	ej
 de	ej
 dededededeedf deddf$ fd d!Zd"ejd#efd$d%Zd&ejdejfd'd(Z  ZS ))ImageEncoderViT                     @   TFr   r   img_size
patch_sizein_chans	embed_dimdepth	num_heads	mlp_ratio	out_chansqkv_bias
norm_layer	act_layeruse_abs_posuse_rel_posrel_pos_zero_initwindow_sizeglobal_attn_indexes.last_conv_outputr   Nc                    s  t    || _t||f||f||d| _d| _|r*tt	d|| || || _t
 | _t|D ]"}t||||	|
|||||vrD|nd|| || fd
}| j| q3tt||dddt|t||dddd	t|| _td
dddddd| _td|ddddd| _dS )a  
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        )kernel_sizestriderE   rF   Nr	   r   )
dimrH   rI   rK   rL   rM   rO   rP   rQ   
input_sizeF)rT   r2   r>   )rT   paddingr2   rB   i   r4   )rT   rU   rX   r2   )r   r   rC   
PatchEmbedpatch_embed	pos_embedr   r.   r'   r1   
ModuleListblocksrangeBlockappend
Sequentialr   r*   necknet_2net_3)r   rC   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   rS   iblockr   r   r   r   @   sj   
&
zImageEncoderViT.__init__abs_postgt_sizec                 C   sj   |j }|d}||kr3|dddd}|tj}tj|||fdddd|}|dddd}|S |S )	Nr	   r   r>   r4   bicubicTFsizemode	antialiasalign_corners)dtyperk   permutetor'   float32Finterpolate)r   rg   rh   ro   src_sizeold_pos_embednew_pos_embedr   r   r   get_abs_pos   s"   
zImageEncoderViT.get_abs_posr   c                 C   sn   |  |}| jd ur|| | j|d }| jD ]}||}q| |dddd}| |}| |}|S )Nr	   r   r>   r4   )	rZ   r[   rx   rk   r]   rb   rp   rc   rd   )r   r   blkneck_outputconv2_outputconv3_outputr   r   r   r      s   





zImageEncoderViT.forward)r    r!   r"   r   	LayerNormr#   r$   r:   boolr%   r&   tupler   r'   r(   rx   r   r)   r   r   r   r   r;   ?   sr    	

br;   c                       s   e Zd ZdZddejejddddfdeded	ed
e	de
ej de
ej de	de	dedeeef dB ddf fddZdejdejfddZ  ZS )r_   zWTransformer blocks with support of window attention and residual propagation
    blocksrA   TFr   NrV   rH   rI   rK   rL   rM   rO   rP   rQ   rW   r   c                    sf   t    ||| _t||||||	dkr|
n|	|	fd| _||| _t|t|| |d| _|	| _	dS )ai  
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        r   )rH   rK   rO   rP   rW   )r   r   r   N)
r   r   norm1RelPosAttentionattnnorm2r   r$   mlprQ   )r   rV   rH   rI   rK   rL   rM   rO   rP   rQ   rW   r   r   r   r      s   


	
zBlock.__init__r   c                 C   s   |}|  |}| jdkr|jd |jd }}t|| j\}}| |}| jdkr3t|| j|||f}|| }|| | | }|S )Nr   r	   r4   )r   rQ   shapewindow_partitionr   window_unpartitionr   r   )r   r   shortcutHWpad_hwr   r   r   r      s   



zBlock.forward)r    r!   r"   __doc__r   r}   r#   r$   r:   r~   r%   r&   r   r   r'   r(   r   r)   r   r   r   r   r_      sD    	
.r_   c                       sl   e Zd ZdZ					ddededed	ed
edeeef dB ddf fddZdej	dej	fddZ
  ZS )r   z=Multi-head Attention block with relative position embeddings.   TFNrV   rH   rK   rO   rP   rW   r   c                    s   t    || _|| }|d | _tj||d |d| _t||| _|| _| jrS|dus1J dt	t
d|d  d || _t	t
d|d  d || _dS dS )	a  
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool):  If True, add a learnable bias to query, key, value.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        g      r>   )r2   NzBInput size must be provided if using relative positional encoding.r4   r   r	   )r   r   rH   scaler   r   qkvprojrO   r.   r'   r1   	rel_pos_h	rel_pos_w)r   rV   rH   rK   rO   rP   rW   head_dimr   r   r   r     s   


 $zRelPosAttention.__init__r   c              	   C   s  |j \}}}}| |||| d| jdddddd}|d|| j || dd\}}}	d\}
}| jrGt|| j| j	||f||f\}
}|
|| j|| d}|
|| j|| d}|	
|| j|| d}	| jr|

|| j|
d|
d|
d}
|
|| j|d|d|d}|
| 
|| j|
d|
d|d }tjjj|||	|d}n	tjj|||	}|
|| j||dddddd|||d}| |}|S )	Nr>   r4   r   r	      )NN)	attn_mask)r   r   reshaperH   rp   unbindrO   add_decomposed_rel_posr   r   viewrk   r'   r   
functionalscaled_dot_product_attentionr   )r   r   Br   r   _r   qkvrel_hrel_w	attn_biasr   r   r   r   0  s@   *& 
zRelPosAttention.forward)r   TFTN)r    r!   r"   r   r$   r~   r   r   r'   r(   r   r)   r   r   r   r   r   
  s.    #r   r   rQ   r   c              	   C   s   | j \}}}}|||  | }|||  | }|dks|dkr+t| ddd|d|f} || || }}	| ||| ||	| ||} | dddddd d|||}
|
||	ffS )aU  
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    r   r	   r>   r4   r      r   )r   rs   padr   rp   
contiguous)r   rQ   r   r   r   Cpad_hpad_wHpWpwindowsr   r   r   r   ^  s   "r   r   r   hwc           
      C   s   |\}}|\}}| j d || | |  }| ||| || ||d}	|	dddddd |||d}	||ks=||krO|	ddd|d|ddf  }	|	S )	a  
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    r   r   r	   r>   r4   r   r   N)r   r   rp   r   )
r   rQ   r   r   r   r   r   r   r   r   r   r   r   r   z  s   $$r   q_sizek_sizerel_posc           	      C   s   t dt| | d }|jd |kr>|j}|tj}tj|	d|jd d
ddd|dd|}|	d|
dd}n|}tj| |jddddf t||  d	 }tj||jddddf t| | d	 }|| |d t| | d	  }||  S )
a\  
    Get relative positional embeddings according to the relative positions of
        query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    r4   r	   r   r   linear)rk   rl   )deviceNg      ?)r$   maxr   ro   rq   r'   rr   rs   rt   r   rp   aranger   long)	r   r   r   max_rel_distro   rel_pos_resizedq_coordsk_coordsrelative_coordsr   r   r   get_rel_pos  s*   r   r   r   r   c                 C   s   |\}}|\}}t |||}	t |||}
| j\}}}| ||||}td||	}td||
}|d}|d}|||| |d}|||| d|}||fS )a  
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
    Args:
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    zbhwc,hkc->bhwkzbhwc,wkc->bhwkr   r	   )r   r   r   r'   einsum	unsqueeze)r   r   r   r   r   q_hq_wk_hk_wRhRwr   r   rV   r_qr   r   r   r   r   r     s   

r   c                       st   e Zd ZdZ					ddeeef deeef deeef d	ed
eddf fddZdejdejfddZ	  Z
S )rY   z#
    Image to Patch Embedding.
    r=   r=   r   r   r>   r?   rT   rU   rX   rE   rF   r   Nc                    s"   t    t|||||d| _dS )aP  
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        )rT   rU   rX   N)r   r   r   r   )r   rT   rU   rX   rE   rF   r   r   r   r     s   

zPatchEmbed.__init__r   c                 C   s   |  |}|dddd}|S )Nr   r4   r>   r	   )r   rp   r   r   r   r   r     s   
zPatchEmbed.forward)r   r   r   r>   r?   )r    r!   r"   r   r   r$   r   r'   r(   r   r)   r   r   r   r   rY     s*    


rY   c                   C   s   t dddg ddS )Nr?   r@   )r4   r   r      )encoder_embed_dimencoder_depthencoder_num_headsencoder_global_attn_indexes)
_build_samr   r   r   r   build_sam_vit_b  s   r   c                 C   s<   d}d}d}t || |dttjjdd||dd|d|d	}|S )
NrB   r<   r=   r   r+   r-   T   )rG   rF   rC   rI   rL   rH   rD   rK   rO   rR   rQ   rJ   )r;   r   r'   r   r}   )r   r   r   r   prompt_embed_dim
image_sizevit_patch_sizeimage_encoderr   r   r   r     s$   r   c                   @   sD   e Zd ZdejdefddZ	ddejdejdB dejfd	d
ZdS )DeepCLIPVisionEmbeddingsrg   rh   c                 C   s   | d}|d}|d d |dd  }}tt|jd d }tt|}|j}||kr||d|||dddd	 }|
tj}tj|||fdddd	
|}	|	dddd}	|	|| |}	tj||	gdd
}
|
d|| d |}
|
S |S )Nr   r   r	   r>   r4   ri   TFrj   rV   )rk   squeezer$   mathr7   r   ro   r   rp   r   rq   r'   rr   rs   rt   cat)r   rg   rh   rV   abs_pos_new	cls_tokenrv   ru   ro   rw   vision_pos_embedr   r   r   rx   +  s6   

z$DeepCLIPVisionEmbeddings.get_abs_posNpixel_valuespatch_embedsr   c                 C   sx   |j d }|d ur|}n| |}|ddd}| j|dd}tj||gdd}|| | 	| j
|d }|S )Nr   r4   r	   r   r   )r   patch_embeddingflatten	transposeclass_embeddingexpandr'   r   rx   position_embeddingposition_idsrk   )r   r   r   
batch_sizeclass_embeds
embeddingsr   r   r   r   M  s   

z DeepCLIPVisionEmbeddings.forwardr   )r    r!   r"   r'   r(   r$   rx   r   r   r   r   r   r   *  s    #r   c                       s   e Zd Z	dddddededB dedB deddf
 fd	d
Zedd Z	edd Z
	ddddejdejdB dee dB dejfddZdeeeejf  dee fddZ  ZS )DeepCLIPVisionTransformerN )num_hidden_layers_overrideprefixconfigquant_configr   r   r   c                   s   t    || _|j}t|| _tj||jd| _	t
|||| dtd| _|j}t| jj|jkrAtd| dt| jj dd S )Nr   z.encoder)r   r   r   r   attn_clszThe original encoder only has z layers, but you requested z layers.)r   r   r   hidden_sizer   r   r   r}   layer_norm_epspre_layrnormr
   r   transformernum_hidden_layerslenlayers
ValueError)r   r   r   r   r   rF   r   r   r   r   r   `  s(   


z"DeepCLIPVisionTransformer.__init__c                 C      t |  jS r   )next
parametersro   r   r   r   r   ro        zDeepCLIPVisionTransformer.dtypec                 C   r   r   )r   r   r   r   r   r   r   r     r   z DeepCLIPVisionTransformer.device)select_layersr   r   r   c                C   s,   |  ||}| |}| j||d ud}|S )N)inputs_embedsreturn_all_hidden_states)r   r   r   )r   r   r   r   hidden_statesencoder_outputsr   r   r   r     s   
z!DeepCLIPVisionTransformer.forwardweightsc                 C   sL   t |  }t }|D ]\}}|| }t|dt}||| || q|S )Nweight_loader)dictnamed_parameterssetgetattrr   add)r   r  params_dictloaded_paramsnameloaded_weightparamr  r   r   r   load_weights  s   
z&DeepCLIPVisionTransformer.load_weightsr   )r    r!   r"   r   r   r$   strr   propertyro   r   r'   r(   listr   r   r   r  r  r)   r   r   r   r   r   _  sB    "



,r   )*r   collections.abcr   	functoolsr   r'   torch.nnr   torch.nn.functionalr   rs   transformersr   $vllm.model_executor.layers.attentionr   vllm.model_executor.layers.convr   'vllm.model_executor.layers.quantizationr   -vllm.model_executor.model_loader.weight_utilsr   clipr
   r   r&   r   r*   r;   r_   r   r(   r$   r   r   r   r   r   rY   r   r   r   r   r   r   r   r   <module>   sn    ET



'


%#	5