o
    i                    @   s  d Z ddlmZ ddlmZ ddlmZmZmZ ddl	Z	ddl	m
Z
mZ ddlmZ dd	lmZmZ dd
lmZ ddlmZmZ ddlmZ ddlmZmZmZmZmZmZ ddlm Z m!Z!m"Z" e riddl#m$Z$ e%e&Z'de	j
de	j
fddZ(de	j
de	j
fddZ)eeG dd deZ*de
de
fddZ+de
de
fddZ,d d! Z-d"d# Z.eed$d%G d&d' d'eZ/eed(d%G d)d* d*eZ0G d+d, d,ej1Z2G d-d. d.ej1Z3G d/d0 d0ej1Z4G d1d2 d2ej1Z5G d3d4 d4eZ6eG d5d6 d6eZ7G d7d8 d8ej1Z8G d9d: d:ej1Z9G d;d< d<e7Z:G d=d> d>ej1Z;G d?d@ d@e7Z<eG dAdB dBe7Z=G dCdD dDej1Z>G dEdF dFej1Z?G dGdH dHe7Z@g dIZAdS )JzPyTorch OWL-ViT model.    )	dataclass)	lru_cache)AnyOptionalUnionN)Tensornn   )ACT2FN) _create_4d_causal_attention_mask_prepare_4d_attention_mask)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPooling)PreTrainedModel)ModelOutputauto_docstringfilter_out_non_signature_kwargsis_vision_availablelogging	torch_int   )OwlViTConfigOwlViTTextConfigOwlViTVisionConfig)center_to_corners_formatlogitsreturnc                 C   s   t j| tjt| | jdS )Ndevice)r   
functionalcross_entropytorcharangelenr   )r    r%   g/home/ubuntu/veenaModal/venv/lib/python3.10/site-packages/transformers/models/owlvit/modeling_owlvit.pycontrastive_loss3   s   r'   
similarityc                 C   s    t | }t |  }|| d S )Ng       @)r'   t)r(   caption_loss
image_lossr%   r%   r&   owlvit_loss8   s   r,   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeed< dZeed	< d
ee fddZdS )OwlViTOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The image embeddings obtained by applying the projection layer to the pooled output of
        [`OwlViTVisionModel`].
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`OwlViTTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`OwlViTVisionModel`].
    Nlosslogits_per_imagelogits_per_texttext_embedsimage_embedstext_model_outputvision_model_outputr   c                       t  fdd  D S )Nc                 3   .    | ]}|d vr | nt  | V  qdS )r3   r4   Ngetattrto_tuple.0kselfr%   r&   	<genexpr>^   
    
z(OwlViTOutput.to_tuple.<locals>.<genexpr>tuplekeysr>   r%   r>   r&   r:   ]      zOwlViTOutput.to_tuple)__name__
__module____qualname____doc__r.   r   r"   FloatTensor__annotations__r/   r0   r1   r2   r3   r   r4   rC   r   r:   r%   r%   r%   r&   r-   >   s   
 r-   r)   c                 C   sD   |   r| jtjtjfv r| S |  S | jtjtjfv r| S |  S N)	is_floating_pointdtyper"   float32float64floatint32int64int)r)   r%   r%   r&   _upcaste   s   rU   boxesc                 C   sH   t | } | dddf | dddf  | dddf | dddf   S )a  
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    N   r   r	   r   )rU   )rV   r%   r%   r&   box_arean   s   @rX   c           
      C   s   t | }t |}t| d d d d df |d d d df }t| d d d dd f |d d dd f }|| jdd}|d d d d df |d d d d df  }|d d d f | | }|| }	|	|fS )NrW   r   minr   )rX   r"   maxrZ   clamp)
boxes1boxes2area1area2left_topright_bottomwidth_heightinterunioniour%   r%   r&   box_iou   s   ..,rg   c                 C   s*  | ddddf | ddddf k  std|  |ddddf |ddddf k  s:td| t| |\}}t| dddddf |ddddf }t| dddddf |ddddf }|| jdd}|dddddf |dddddf  }||| |  S )z
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
    NrW   z<boxes1 must be in [x0, y0, x1, y1] (corner) format, but got z<boxes2 must be in [x0, y0, x1, y1] (corner) format, but got r   rY   r   )all
ValueErrorrg   r"   rZ   r[   r\   )r]   r^   rf   re   top_leftbottom_rightrc   arear%   r%   r&   generalized_box_iou   s   ,	,..,rm   z6
    Output type of [`OwlViTForObjectDetection`].
    )custom_introc                   @   s   e Zd ZU dZdZeej ed< dZ	ee
 ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed	< dZeed
< dZeed< dee fddZdS )OwlViTObjectDetectionOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
        Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
        image embeddings for each patch.
    class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
        Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
        number of patches is (image_size / patch_size)**2.
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`OwlViTTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`OwlViTVisionModel`].
    Nr.   	loss_dictr   
pred_boxesr1   r2   class_embedsr3   r4   r   c                    r5   )Nc                 3   r6   r7   r8   r;   r>   r%   r&   r@      rA   z7OwlViTObjectDetectionOutput.to_tuple.<locals>.<genexpr>rB   r>   r%   r>   r&   r:      rE   z$OwlViTObjectDetectionOutput.to_tuple)rF   rG   rH   rI   r.   r   r"   rJ   rK   rp   dictr   rq   r1   r2   rr   r3   r   r4   rC   r   r:   r%   r%   r%   r&   ro      s   
 ro   zM
    Output type of [`OwlViTForObjectDetection.image_guided_detection`].
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeej ed< dZeed	< dZeed
< dee fddZdS )&OwlViTImageGuidedObjectDetectionOutputa  
    logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
        Classification logits (including no-object) for all queries.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
        image embeddings for each patch.
    query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
        image embeddings for each patch.
    target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual target image in the batch
        (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to
        retrieve the unnormalized bounding boxes.
    query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual query image in the batch
        (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to
        retrieve the unnormalized bounding boxes.
    class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
        Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
        number of patches is (image_size / patch_size)**2.
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`OwlViTTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`OwlViTVisionModel`].
    Nr   r2   query_image_embedstarget_pred_boxesquery_pred_boxesrr   r3   r4   r   c                    r5   )Nc                 3   r6   r7   r8   r;   r>   r%   r&   r@   	  rA   zBOwlViTImageGuidedObjectDetectionOutput.to_tuple.<locals>.<genexpr>rB   r>   r%   r>   r&   r:     rE   z/OwlViTImageGuidedObjectDetectionOutput.to_tuple)rF   rG   rH   rI   r   r   r"   rJ   rK   r2   ru   rv   rw   rr   r3   r   r4   rC   r   r:   r%   r%   r%   r&   rt      s   
 rt   c                       s\   e Zd Zdef fddZdejdededejfdd	Zddej	de
dejfddZ  ZS )OwlViTVisionEmbeddingsconfigc                    s   t    |j| _|| _|j| _tt	|j| _
tj|j| j|j|jdd| _|j|j d | _| jd | _t| j| j| _| jdt| jddd d S )NF)in_channelsout_channelskernel_sizestridebiasrW   r   position_idsr   
persistent)super__init__
patch_sizery   hidden_size	embed_dimr   	Parameterr"   randnclass_embeddingConv2dnum_channelspatch_embedding
image_sizenum_patchesnum_positions	Embeddingposition_embeddingregister_bufferr#   expandr?   ry   	__class__r%   r&   r     s    
"zOwlViTVisionEmbeddings.__init__
embeddingsheightwidthr   c                 C   s  |j d d }| jjd}|j d d }tj s(||kr(||kr(| | jS |ddddf }|ddddf }|j d }	|| j }
|| j }t	|d }|
d|||	}|dddd}tjj||
|fdd	d
}|dddddd|	}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   r   Nr   g      ?r	   rW   bicubicF)sizemodealign_cornersdim)shaper   weight	unsqueezer"   jit
is_tracingr   r   r   reshapepermuter   r    interpolateviewcat)r?   r   r   r   r   r   r   class_pos_embedpatch_pos_embedr   
new_height	new_widthsqrt_num_positionsr%   r%   r&   interpolate_pos_encoding%  s*   



z/OwlViTVisionEmbeddings.interpolate_pos_encodingFpixel_valuesr   c           
      C   sz   |j \}}}}| |}|ddd}| j|dd}tj||gdd}	|r3|	| |	|| }	|	S |	| 	| j
 }	|	S )NrW   r   r   r   )r   r   flatten	transposer   r   r"   r   r   r   r   )
r?   r   r   
batch_size_r   r   patch_embedsrr   r   r%   r%   r&   forwardK  s   
zOwlViTVisionEmbeddings.forwardF)rF   rG   rH   r   r   r"   r   rT   r   rJ   boolr   __classcell__r%   r%   r   r&   rx     s    $&rx   c                	       sX   e Zd Zdef fddZ			ddeej deej deej dej	fd	d
Z
  ZS )OwlViTTextEmbeddingsry   c                    sP   t    t|j|j| _t|j|j| _| j	dt
|jddd d S )Nr   r   Fr   )r   r   r   r   
vocab_sizer   token_embeddingmax_position_embeddingsr   r   r"   r#   r   r   r   r%   r&   r   Y  s   

zOwlViTTextEmbeddings.__init__N	input_idsr   inputs_embedsr   c                 C   sb   |d ur	|j d n|j d }|d u r| jd d d |f }|d u r&| |}| |}|| }|S )Nr   )r   r   r   r   )r?   r   r   r   
seq_lengthposition_embeddingsr   r%   r%   r&   r   c  s   

zOwlViTTextEmbeddings.forward)NNN)rF   rG   rH   r   r   r   r"   
LongTensorrJ   r   r   r   r%   r%   r   r&   r   X  s    r   c                       s   e Zd ZdZ fddZdejdedefddZ					
ddejde	ej de	ej de	e
 deeje	ej e	eej  f f
ddZ  ZS )OwlViTAttentionz=Multi-headed attention from 'Attention Is All You Need' paperc                    s   t    || _|j| _|j| _| j| j | _| j| j | jkr-td| j d| j d| jd | _	|j
| _t| j| j| _t| j| j| _t| j| j| _t| j| j| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).      )r   r   ry   r   r   num_attention_heads	num_headshead_dimri   scaleattention_dropoutdropoutr   Lineark_projv_projq_projout_projr   r   r%   r&   r   z  s"   

zOwlViTAttention.__init__tensorseq_lenbszc                 C   s    | ||| j| jdd S )Nr   rW   )r   r   r   r   
contiguous)r?   r   r   r   r%   r%   r&   _shape  s    zOwlViTAttention._shapeNFhidden_statesattention_maskcausal_attention_maskoutput_attentionsr   c                 C   s  |  \}}}| || j }| | |d|}	| | |d|}
|| j d| jf}| |||j| }|	j| }	|
j| }
|	 d}t	
||	dd}|  || j ||fkrmtd|| j ||f d|   |dur|  |d||fkrtd|d||f d|   ||| j||| }||| j ||}|dur|  |d||fkrtd|d||f d|   ||| j||| }||| j ||}tjj|dd}|r||| j||}||| j ||}nd}tjj|| j| jd	}||
j}t	
||
}|  || j || jfkr*td
|| j|| jf d|   ||| j|| j}|dd}||||}| |}||fS )z#Input shape: Batch x Time x Channelr   r   rW   z$Attention weights should be of size z	, but is Nz!Attention mask should be of size r   )ptrainingz `attn_output` should be of size )r   r   r   r   r   r   r   r   r   r"   bmmr   ri   r   r    softmaxr   r   torN   r   r   )r?   r   r   r   r   r   tgt_lenr   query_states
key_statesvalue_states
proj_shapesrc_lenattn_weightsattn_weights_reshaped
attn_probsattn_outputr%   r%   r&   r     sf   	



zOwlViTAttention.forwardNNF)rF   rG   rH   rI   r   r"   r   rT   r   r   r   rC   r   r   r%   r%   r   r&   r   w  s$    r   c                       s2   e Zd Z fddZdejdejfddZ  ZS )	OwlViTMLPc                    sD   t    || _t|j | _t|j|j	| _
t|j	|j| _d S rL   )r   r   ry   r
   
hidden_actactivation_fnr   r   r   intermediate_sizefc1fc2r   r   r%   r&   r     s
   
zOwlViTMLP.__init__r   r   c                 C   s"   |  |}| |}| |}|S rL   )r   r   r   )r?   r   r%   r%   r&   r     s   


zOwlViTMLP.forward)rF   rG   rH   r   r"   r   r   r   r%   r%   r   r&   r     s    r   c                       sT   e Zd Zdef fddZ	ddejdejdejdee d	e	ej
 f
d
dZ  ZS )OwlViTEncoderLayerry   c                    sR   t    |j| _t|| _tj| j|jd| _	t
|| _tj| j|jd| _d S N)eps)r   r   r   r   r   	self_attnr   	LayerNormlayer_norm_epslayer_norm1r   mlplayer_norm2r   r   r%   r&   r     s   


zOwlViTEncoderLayer.__init__Fr   r   r   r   r   c                 C   sd   |}|  |}| j||||d\}}|| }|}| |}| |}|| }|f}|r0||f7 }|S )aI  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r   r   r   r   )r   r   r   r   )r?   r   r   r   r   residualr   outputsr%   r%   r&   r     s"   




zOwlViTEncoderLayer.forwardr   )rF   rG   rH   r   r   r"   r   r   r   rC   rJ   r   r   r%   r%   r   r&   r     s    r   c                   @   s4   e Zd ZU eed< dZdZdgZdej	fddZ
dS )	OwlViTPreTrainedModelry   owlvitTr   modulec                 C   sZ  | j j}t|tr"|jjjjd|d d |jjjjd|d d nt|t	rSt
jj|jd|jd | d t
jj|jj|j j| d t
jj|jj|j j| d nt|tr|jd d|j j d  | }|jd | }t
jj|jj|d t
jj|jj|d t
jj|jj|d t
jj|jj|d n`t|tr|j jd d|j j d  | }d|j j d | }t
jj|jj|d t
jj|jj|d n,t|trt
jj|jj|jd | d t
jj|jj|jd | d |jj | j j! t|t
j"r|j#j$  |jj d t|t
j%r)|jjjd|d |j#dur+|j#j$  dS dS dS )	zInitialize the weights        g{Gz?)meanstdr   )r  rW         ?N)&ry   initializer_factor
isinstancer   r   r   datanormal_r   rx   r   initr   r   r   initializer_ranger   num_hidden_layersr   r   r   r   r   r   r   r   OwlViTModeltext_projectiontext_embed_dimvisual_projectionvision_embed_dimlogit_scalefill_logit_scale_init_valuer   r~   zero_r   )r?   r   factorin_proj_stdout_proj_stdfc_stdr%   r%   r&   _init_weights,  sN   



 
z#OwlViTPreTrainedModel._init_weightsN)rF   rG   rH   r   rK   base_model_prefixsupports_gradient_checkpointing_no_split_modulesr   Moduler  r%   r%   r%   r&   r   %  s   
 r   c                       st   e Zd ZdZdef fddZ					ddeej deej dee	 d	ee	 d
ee	 de
eef fddZ  ZS )OwlViTEncoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`OwlViTEncoderLayer`].

    Args:
        config: OwlViTConfig
    ry   c                    s4   t    t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r%   )r   )r<   r   ry   r%   r&   
<listcomp>`  s    z*OwlViTEncoder.__init__.<locals>.<listcomp>F)r   r   r   
ModuleListranger
  layersgradient_checkpointingr   r   r  r&   r   ^  s   
 
zOwlViTEncoder.__init__Nr   r   r   output_hidden_statesreturn_dictr   c                 C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}|r"dnd}|r(dnd}|}	| jD ]}
|r8||	f }|
|	|||d}|d }	|rM||d f }q/|rU||	f }|sctdd |	||fD S t|	||dS )	a  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`).
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr%   )r   r   r   c                 s       | ]	}|d ur|V  qd S rL   r%   )r<   vr%   r%   r&   r@         z(OwlViTEncoder.forward.<locals>.<genexpr>)last_hidden_stater   
attentions)ry   r   r$  use_return_dictr"  rC   r   )r?   r   r   r   r   r$  r%  encoder_statesall_attentionsr   encoder_layerlayer_outputsr%   r%   r&   r   c  s6   


zOwlViTEncoder.forwardNNNNN)rF   rG   rH   rI   r   r   r   r"   r   r   r   rC   r   r   r   r%   r%   r   r&   r  U  s*    
r  c                       sz   e Zd Zdef fddZe					ddejdeej deej dee	 d	ee	 d
ee	 de
eef fddZ  ZS )OwlViTTextTransformerry   c                    s@   t    || _|j}t|| _t|| _tj	||j
d| _d S r   )r   r   ry   r   r   r   r  encoderr   r   r   final_layer_norm)r?   ry   r   r   r%   r&   r     s   


zOwlViTTextTransformer.__init__Nr   r   r   r   r$  r%  r   c                 C   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}| }|d|d }| j||d}t||j|j	d}	|durDt
||j}| j|||	|||d}
|
d }| |}|tj|jd |j	d|tjjdd|j	f }|s||f|
dd  S t|||
j|
jd	S )
a|  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        Nr   )r   r   r   )r   r   r   r   r$  r%  r   r   r   r)  pooler_outputr   r*  )ry   r   r$  r+  r   r   r   r   rN   r   r   r2  r3  r"   r#   r   r   rT   argmaxr   r   r*  )r?   r   r   r   r   r$  r%  input_shaper   r   encoder_outputsr)  pooled_outputr%   r%   r&   r     sF   
	
zOwlViTTextTransformer.forwardr0  )rF   rG   rH   r   r   r   r"   r   r   r   r   rC   r   r   r   r%   r%   r   r&   r1    s.    
r1  c                       s   e Zd ZU eed< def fddZdejfddZdd Z	e
								dd
ejdeej dee dee dee deeef fddZ  ZS )OwlViTTextModelry   c                    "   t  | t|| _|   d S rL   )r   r   r1  
text_model	post_initr   r   r%   r&   r        
zOwlViTTextModel.__init__r   c                 C   
   | j jjS rL   r<  r   r   r>   r%   r%   r&   get_input_embeddings     
z$OwlViTTextModel.get_input_embeddingsc                 C   s   || j j_d S rL   r@  )r?   valuer%   r%   r&   set_input_embeddings  s   z$OwlViTTextModel.set_input_embeddingsNr   r   r   r$  r%  c                 C      | j |||||dS )a  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        Examples:
        ```python
        >>> from transformers import AutoProcessor, OwlViTTextModel

        >>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> inputs = processor(
        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
        ... )
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```r   r   r   r$  r%  )r<  )r?   r   r   r   r$  r%  r%   r%   r&   r      s   zOwlViTTextModel.forward)NNNN)rF   rG   rH   r   rK   r   r   r  rA  rD  r   r"   r   r   r   r   rC   r   r   r   r%   r%   r   r&   r:    s.   
 
r:  c                       sl   e Zd Zdef fddZe				ddejdee	 dee	 d	ee	 d
ee	 de
eef fddZ  ZS )OwlViTVisionTransformerry   c                    sP   t    || _t|| _tj|j|jd| _	t
|| _tj|j|jd| _d S r   )r   r   ry   rx   r   r   r   r   r   pre_layernormr  r2  post_layernormr   r   r%   r&   r   (  s   


z OwlViTVisionTransformer.__init__NFr   r   r$  r   r%  r   c                 C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}| jjjj}||}| j||d}| 	|}| j
||||d}|d }	|	d d dd d f }
| |
}
|s^|	|
f|dd   S t|	|
|j|jdS )N)r   )r   r   r$  r%  r   r   r4  )ry   r   r$  r+  r   r   r   rN   r   rH  r2  rI  r   r   r*  )r?   r   r   r$  r   r%  expected_input_dtyper   r8  r)  r9  r%   r%   r&   r   1  s2   	


zOwlViTVisionTransformer.forward)NNFN)rF   rG   rH   r   r   r   r"   rJ   r   r   r   rC   r   r   r   r%   r%   r   r&   rG  '  s(    	
rG  c                       s   e Zd ZU eed< dZdef fddZdejfddZ	e
						ddeej d
ee dee dedee deeef fddZ  ZS )OwlViTVisionModelry   r   c                    r;  rL   )r   r   rG  vision_modelr=  r   r   r%   r&   r   b  r>  zOwlViTVisionModel.__init__r   c                 C   r?  rL   )rL  r   r   r>   r%   r%   r&   rA  h  rB  z&OwlViTVisionModel.get_input_embeddingsNFr   r$  r   r%  c                 C   rE  )a  
        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, OwlViTVisionModel

        >>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```r   r   r$  r   r%  )rL  )r?   r   r   r$  r   r%  r%   r%   r&   r   k  s   zOwlViTVisionModel.forwardNNNFN)rF   rG   rH   r   rK   main_input_namer   r   r  rA  r   r   r"   rJ   r   r   rC   r   r   r   r%   r%   r   r&   rK  ^  s0   
 
rK  c                       s   e Zd ZU eed< def fddZe e	ddej	de
ej	 dejfdd	Ze e	
ddej	dedejfddZe							
		dde
ej de
ej de
ej	 de
e de
e de
e dede
e de
e deeef fddZ  ZS )r  ry   c                    s   t  | t|jtstdt|j dt|jts(tdt|j d|j}|j}|j	| _	|j
| _|j
| _t|| _t|| _tj| j| j	dd| _tj| j| j	dd| _tt|j| _|   d S )NzMconfig.text_config is expected to be of type OwlViTTextConfig but is of type .zQconfig.vision_config is expected to be of type OwlViTVisionConfig but is of type F)r~   )r   r   r  text_configr   	TypeErrortypevision_configr   projection_dimr   r  r  r1  r<  rG  rL  r   r   r  r  r   r"   r   r  r  r=  )r?   ry   rQ  rT  r   r%   r&   r     s0   

zOwlViTModel.__init__Nr   r   r   c                 C      | j ||d}| |j}|S )a  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`OwlViTTextModel`].

        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, OwlViTModel

        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> inputs = processor(
        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
        ... )
        >>> with torch.inference_mode():
        ...     text_features = model.get_text_features(**inputs)
        ```)r   r   )r<  r  r5  )r?   r   r   text_outputstext_featuresr%   r%   r&   get_text_features  s   zOwlViTModel.get_text_featuresFr   r   c                 C   rV  )ai  
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`OwlViTVisionModel`].

        Examples:
        ```python
        >>> import torch
        >>> from transformers.image_utils import load_image
        >>> from transformers import AutoProcessor, OwlViTModel

        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> with torch.inference_mode():
        ...     image_features = model.get_image_features(**inputs)
        ```r   r   )rL  r  r5  )r?   r   r   vision_outputsimage_featuresr%   r%   r&   get_image_features  s   zOwlViTModel.get_image_featuresreturn_lossr   r$  return_base_image_embedsr%  c
              	   C   s:  |dur|n| j j}|dur|n| j j}|	dur|	n| j j}	| j|||||	d}
| j|||||	d}|d }| |}|
d }| |}|tj	j
|dddd }|tj	j
|dddd }| j |j}t|| | }| }d}|r{t|}|}|	s||||||
f}|dur|f| S |S t|||||||
d	S )
a&  
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        return_base_image_embeds (`bool`, *optional*):
            Whether or not to return the base image embeddings.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, OwlViTModel

        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```NrM  rF  r   rW   r   T)ordr   keepdim)r.   r/   r0   r1   r2   r3   r4   )ry   r   r$  r+  rL  r<  r  r  r"   linalgnormr  expr   r   matmulr)   r,   r-   )r?   r   r   r   r^  r   r$  r   r_  r%  r[  rW  r1   r2   text_embeds_normr  r0   r/   r.   outputr%   r%   r&   r     sV   #	

zOwlViTModel.forwardrL   r   )	NNNNNNFNN)rF   rG   rH   r   rK   r   r   r   r"   r   r   rJ   rY  r   r]  r   r   rC   r-   r   r   r%   r%   r   r&   r    sl   
  ""	

r  c                       s>   e Zd Zd
dedef fddZdejdejfdd	Z	  Z
S )OwlViTBoxPredictionHead   ry   out_dimc                    sJ   t    |jj}t||| _t||| _t | _	t||| _
d S rL   )r   r   rT  r   r   r   dense0dense1GELUgeludense2)r?   ry   rj  r   r   r%   r&   r   Z  s   

z OwlViTBoxPredictionHead.__init__r\  r   c                 C   s6   |  |}| |}| |}| |}| |}|S rL   )rk  rn  rl  ro  )r?   r\  rg  r%   r%   r&   r   c  s   




zOwlViTBoxPredictionHead.forward)ri  )rF   rG   rH   r   rT   r   r"   r   rJ   r   r   r%   r%   r   r&   rh  Y  s    	rh  c                	       sP   e Zd Zdef fddZdejdeej deej de	ej fdd	Z
  ZS )
OwlViTClassPredictionHeadry   c                    sZ   t    |jj}|jj| _t| j|| _t| jd| _	t| jd| _
t | _d S )Nr   )r   r   rQ  r   rT  	query_dimr   r   rk  logit_shiftr  ELUelu)r?   ry   rj  r   r%   r&   r   m  s   

z"OwlViTClassPredictionHead.__init__r2   query_embeds
query_maskr   c                 C   s
  |  |}|d u r%|j}|jd d \}}t||| jf|}||fS |tjj|dddd  }|tjj|dddd  }t	d||}| 
|}	| |}
| |
d }
||	 |
 }|d ur|jdkrmtj|dd	}t|d
kt|jj|}|tj}||fS )NrW   r   T)r   ra  gư>z...pd,...qd->...pqr   r   r   r   )rk  r   r   r"   zerosrq  r   rb  rc  einsumrr  r  rt  ndimr   wherefinforN   rZ   rO   )r?   r2   ru  rv  image_class_embedsr   r   r   pred_logitsrr  r  r%   r%   r&   r   x  s&   



z!OwlViTClassPredictionHead.forward)rF   rG   rH   r   r   r"   rJ   r   r   rC   r   r   r%   r%   r   r&   rp  l  s    rp  c                       s  e Zd ZU eed< def fddZedededej	fddZ
ed	d
	d+dededeej dej	fddZ	d,dejdejdedejfddZ		d-dejdeej deej	 deej fddZ			d.dej	dejdej	dee dee dedeej fddZ			d.dejdee dee dedeej f
dd Z	d,d!ejd"ejdedejfd#d$Ze					d/dejd%eej dee dee ded&ee defd'd(Ze					d/dej	dejdeej	 dee dee ded&ee defd)d*Z  ZS )0OwlViTForObjectDetectionry   c                    s   t  | t|| _t|| _t|| _tj	|j
j|j
jd| _t | _|| _| jj
j| jj
j | _| jj
j| jj
j | _| | j| j| _d S r   )r   r   r  r   rp  
class_headrh  box_headr   r   rT  r   r   
layer_normSigmoidsigmoidry   r   r   num_patches_heightnum_patches_widthcompute_box_biasbox_biasr   r   r%   r&   r     s   



z!OwlViTForObjectDetection.__init__r  r  r   c                 C   s   t jd|d t jd}t jd| d t jd}t j||dd\}}t j||fdd}|d  |  < |d  |   < |dd	}|S )
Nr   )rN   xy)indexingr   r   .r   .r   rW   )r"   r#   rO   meshgridstackr   )r  r  x_coordinatesy_coordinatesxxyybox_coordinatesr%   r%   r&   !normalize_grid_corner_coordinates  s   z:OwlViTForObjectDetection.normalize_grid_corner_coordinatesrW   )maxsizeNfeature_mapc           	      C   s   |d urt d| ||}t|dd}t|d t| d  }t|d}|d  |  < |d  |  < t|d t| d  }tj||gdd}|S )	NzOfeature_map has been deprecated as an input. Please pass in num_patches insteadr   r  g-C6?r  r  r   r   )ri   r  r"   cliploglog1p	full_liker   )	r?   r  r  r  r  box_coord_biasbox_sizebox_size_biasr  r%   r%   r&   r    s   z)OwlViTForObjectDetection.compute_box_biasFimage_featsr   c           	      C   sR   |  |}|r|j\}}}}| ||}n| j}||j}||7 }| |}|S )a  
        Args:
            image_feats:
                Features extracted from the image, returned by the `image_text_embedder` method.
            feature_map:
                A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
            interpolate_pos_encoding:
                Whether to interpolate the pre-trained position encodings.
        Returns:
            pred_boxes:
                List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
        )r  r   r  r  r   r   r  )	r?   r  r  r   rq   r   r  r  r  r%   r%   r&   box_predictor  s   

z&OwlViTForObjectDetection.box_predictorru  rv  c                 C   s   |  |||\}}||fS )a8  
        Args:
            image_feats:
                Features extracted from the `image_text_embedder`.
            query_embeds:
                Text query embeddings.
            query_mask:
                Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
        )r  )r?   r  ru  rv  r}  r|  r%   r%   r&   class_predictor  s   z(OwlViTForObjectDetection.class_predictorr   r   r   r   r$  c              	   C   s   | j ||||||dd}|r$|j\}}}	}
|	| jjj }|
| jjj }n| j}| j}|jd }| j j	|}t
|d d d dd d f |d d d df j}|d d dd d d f | }| |}|jd |||jd f}||}|d }|||fS )NT)r   r   r   r   r$  r   r%  r   r   r   )r   r   ry   rT  r   r  r  r4   rL  rI  r"   broadcast_tor  r   )r?   r   r   r   r   r$  r   r   r   r   r   r  r  r)  r2   class_token_outnew_sizer1   r%   r%   r&   image_text_embedder	  s8   


4


z,OwlViTForObjectDetection.image_text_embedderc                 C   s   | j j||dd}|r!|j\}}}}|| jjj }	|| jjj }
n| j}	| j}
|d }| j j|}t	
|d d d dd d f |d d d df j}|d d dd d d f | }| |}|jd |	|
|jd f}||}||fS )NT)r   r   r%  r   r   r   )r   rL  r   ry   rT  r   r  r  rI  r"   r  r  r   )r?   r   r   r$  r   r[  r   r   r   r  r  r)  r2   r  r  r%   r%   r&   image_embedder<  s*   4

z'OwlViTForObjectDetection.image_embedderquery_image_featuresquery_feature_mapc                 C   s:  |  |\}}| |||}t|}g }g }	|j}
t|jd D ]f}tjg dg|
d}|| }t||\}}t	|d dkrEt
||}t|d }|d |k }| r|| |d }tj|| dd}td||}|t| }||| |  |	| q |rt|}t|	}nd	\}}|||fS )
Nr   )r   r   r   r   r   r   g?r   )axiszd,id->iNN)r  r  r   r   r!  r   r"   r   rg   rh   rm   r[   nonzeronumelsqueezer  rx  argminappendr  )r?   r  r  r   r   rr   rq   pred_boxes_as_cornersbest_class_embedsbest_box_indicespred_boxes_deviceieach_query_boxeach_query_pred_boxesiousiou_thresholdselected_indsselected_embeddingsmean_embedsmean_simbest_box_indru  box_indicesr%   r%   r&   embed_image_queryf  s6   



z*OwlViTForObjectDetection.embed_image_queryquery_pixel_valuesr%  c              
   C   s(  |dur|n| j j}|dur|n| j j}|dur|n| j j}| j||dd }| j||||d\}}	|j\}
}}}t||
|| |f}|j\}
}}}t||
|| |f}| |||\}}}| j	||d\}}| 
|||}|s|||||||	 f}tdd |D }|S t||||||d|	dS )	a  
        query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values of query image(s) to be detected. Pass in one query image per target image.

        Examples:
        ```python
        >>> import requests
        >>> from PIL import Image
        >>> import torch
        >>> from transformers import AutoProcessor, OwlViTForObjectDetection

        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch16")
        >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
        >>> query_image = Image.open(requests.get(query_url, stream=True).raw)
        >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model.image_guided_detection(**inputs)
        >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        >>> target_sizes = torch.Tensor([image.size[::-1]])
        >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> results = processor.post_process_image_guided_detection(
        ...     outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
        ... )
        >>> i = 0  # Retrieve predictions for the first image
        >>> boxes, scores = results[i]["boxes"], results[i]["scores"]
        >>> for box, score in zip(boxes, scores):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
        Detected similar object with confidence 0.856 at location [10.94, 50.4, 315.8, 471.39]
        Detected similar object with confidence 1.0 at location [334.84, 25.33, 636.16, 374.71]
        ```NrZ  r   )r   r   r$  r   )r  ru  c                 s   r&  rL   r%   r<   xr%   r%   r&   r@     r(  zBOwlViTForObjectDetection.image_guided_detection.<locals>.<genexpr>)r2   ru   rv   rw   r   rr   r3   r4   )ry   r   r$  r%  r  r   r"   r   r  r  r  r:   rC   rt   )r?   r   r  r   r$  r   r%  r  r  r[  r   r  r  
hidden_dimr  query_image_featsru  r  rw   r}  rr   rv   rg  r%   r%   r&   image_guided_detection  s^   ,

	z/OwlViTForObjectDetection.image_guided_detectionc              	   C   s4  |dur|n| j j}|dur|n| j j}|dur|n| j j}| j||||||d\}}	}
|
j}|
j}|	j\}}}}t	|	||| |f}|jd | }|	|||jd }|	|||jd }|d dk}| 
|||\}}| ||	|}|s||||	|| | f}tdd |D }|S t|	||||||dS )	a	  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids).
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the last hidden state. See `text_model_last_hidden_state` and
            `vision_model_last_hidden_state` under returned tensors for more detail.

        Examples:
        ```python
        >>> import requests
        >>> from PIL import Image
        >>> import torch

        >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection

        >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text_labels = [["a photo of a cat", "a photo of a dog"]]
        >>> inputs = processor(text=text_labels, images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        >>> target_sizes = torch.tensor([(image.height, image.width)])
        >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> results = processor.post_process_grounded_object_detection(
        ...     outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels
        ... )
        >>> # Retrieve predictions for the first image for the corresponding text queries
        >>> result = results[0]
        >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
        >>> for box, score, text_label in zip(boxes, scores, text_labels):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
        Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
        Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
        ```N)r   r   r   r   r$  r   r   r   r  c                 s   r&  rL   r%   r  r%   r%   r&   r@   \  r(  z3OwlViTForObjectDetection.forward.<locals>.<genexpr>)r2   r1   rq   r   rr   r3   r4   )ry   r   r$  r%  r  r3   r4   r   r"   r   r  r  r:   rC   ro   )r?   r   r   r   r   r$  r   r%  ru  r  r   rW  r[  r   r  r  r  r  max_text_queriesrv  r}  rr   rq   rg  r%   r%   r&   r     sT   4
	z OwlViTForObjectDetection.forwardrL   r   r  r   rN  )rF   rG   rH   r   rK   r   staticmethodrT   r"   r   r  r   r   rJ   r  r   r  rC   r  r  r  r  r   rt   r  ro   r   r   r%   r%   r   r&   r~    s   
 
$

6
.
,f	r~  )r  r   r:  rK  r~  )BrI   dataclassesr   	functoolsr   typingr   r   r   r"   r   r   activationsr
   modeling_attn_mask_utilsr   r   modeling_layersr   modeling_outputsr   r   modeling_utilsr   utilsr   r   r   r   r   r   configuration_owlvitr   r   r   transformers.image_transformsr   
get_loggerrF   loggerr'   r,   r-   rU   rX   rg   rm   ro   rt   r  rx   r   r   r   r   r   r  r1  r:  rG  rK  r  rh  rp  r~  __all__r%   r%   r%   r&   <module>   sr    
%	.-Il2/PL671 J0   Q