o
    i\                     @   s  d Z ddlmZmZ ddlZddlmZ ddlm	Z	m
Z
mZ ddlmZmZmZmZmZ ddlmZ dd	lmZmZmZmZ d
dlmZ eeZdZdZg dZ dZ!dZ"G dd dej#j$Z%G dd dej#j$Z&G dd dej#j$Z'G dd dej#j$Z(G dd dej#j$Z)G dd dej#j$Z*G dd dej#j$Z+G dd deZ,d Z-d!Z.eG d"d# d#ej#j$Z/ed$e-G d%d& d&e,Z0ed'e-G d(d) d)e,eZ1g d*Z2dS )+zTensorFlow ResNet model.    )OptionalUnionN   )ACT2FN) TFBaseModelOutputWithNoAttention*TFBaseModelOutputWithPoolingAndNoAttention&TFImageClassifierOutputWithNoAttention)TFPreTrainedModelTFSequenceClassificationLosskeraskeras_serializableunpack_inputs)
shape_list)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardlogging   )ResNetConfigr   zmicrosoft/resnet-50)r   i      r   z	tiger catc                       sz   e Zd Z			ddededededed	d
f fddZdejd	ejfddZddejde	d	ejfddZ
dddZ  ZS )TFResNetConvLayerr   r   reluin_channelsout_channelskernel_sizestride
activationreturnNc                    sx   t  jdi | |d | _tjj|||dddd| _tjjdddd	| _|d ur-t	| ntj
d
| _|| _|| _d S )N   validFconvolution)r   stridespaddinguse_biasnameh㈵>?normalizationepsilonmomentumr$   linear )super__init__	pad_valuer   layersConv2DconvBatchNormalizationr'   r   
Activationr   r   r   )selfr   r   r   r   r   kwargs	__class__r,   a/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/resnet/modeling_tf_resnet.pyr.   6   s   	

zTFResNetConvLayer.__init__hidden_statec                 C   s2   | j | j f }}t|d||dg}| |}|S )N)r   r   )r/   tfpadr2   )r5   r:   
height_pad	width_padr,   r,   r9   r    J   s   
zTFResNetConvLayer.convolutionFtrainingc                 C   s&   |  |}| j||d}| |}|S Nr?   )r    r'   r   )r5   r:   r?   r,   r,   r9   callQ   s   

zTFResNetConvLayer.callc                 C      | j rd S d| _ t| dd d ur2t| jj | jd d d | jg W d    n1 s-w   Y  t| dd d ur_t| jj | jd d d | j	g W d    d S 1 sXw   Y  d S d S )NTr2   r'   )
builtgetattrr;   
name_scoper2   r$   buildr   r'   r   r5   input_shaper,   r,   r9   rG   W      "zTFResNetConvLayer.build)r   r   r   FN)__name__
__module____qualname__intstrr.   r;   Tensorr    boolrB   rG   __classcell__r,   r,   r7   r9   r   5   s(    r   c                       sP   e Zd ZdZdeddf fddZddejd	edejfd
dZ	dddZ
  ZS )TFResNetEmbeddingszO
    ResNet Embeddings (stem) composed of a single aggressive convolution.
    configr   Nc                    sP   t  jd	i | t|j|jdd|jdd| _tjj	ddddd| _
|j| _d S )
Nr   r   embedder)r   r   r   r$   r   r   pooler)	pool_sizer!   r"   r$   r,   )r-   r.   r   num_channelsembedding_size
hidden_actrW   r   r0   	MaxPool2DrX   r5   rV   r6   r7   r,   r9   r.   h   s   zTFResNetEmbeddings.__init__Fpixel_valuesr?   c                 C   sj   t |\}}}}t r|| jkrtd|}| |}t|ddgddgddgddgg}| |}|S )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   )r   r;   executing_eagerlyrZ   
ValueErrorrW   r<   rX   )r5   r_   r?   _rZ   r:   r,   r,   r9   rB   u   s   
$
zTFResNetEmbeddings.callc                 C      | j rd S d| _ t| dd d ur-t| jj | jd  W d    n1 s(w   Y  t| dd d urUt| jj | jd  W d    d S 1 sNw   Y  d S d S )NTrW   rX   )rD   rE   r;   rF   rW   r$   rG   rX   rH   r,   r,   r9   rG         "zTFResNetEmbeddings.buildrK   rL   )rM   rN   rO   __doc__r   r.   r;   rR   rS   rB   rG   rT   r,   r,   r7   r9   rU   c   s
    rU   c                	       sZ   e Zd ZdZddedededdf fdd	ZddejdedejfddZ	dddZ
  ZS )TFResNetShortCutz
    ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.
    r   r   r   r   r   Nc                    sN   t  jd	i | tjj|d|ddd| _tjjdddd| _|| _|| _	d S )
Nr   Fr    )r   r!   r#   r$   r%   r&   r'   r(   r,   )
r-   r.   r   r0   r1   r    r3   r'   r   r   )r5   r   r   r   r6   r7   r,   r9   r.      s   

zTFResNetShortCut.__init__Fxr?   c                 C   s    |}|  |}| j||d}|S r@   )r    r'   )r5   rg   r?   r:   r,   r,   r9   rB      s   
zTFResNetShortCut.callc                 C   rC   )NTr    r'   )
rD   rE   r;   rF   r    r$   rG   r   r'   r   rH   r,   r,   r9   rG      rJ   zTFResNetShortCut.build)r   rK   rL   )rM   rN   rO   re   rP   r.   r;   rR   rS   rB   rG   rT   r,   r,   r7   r9   rf      s
     
rf   c                       s`   e Zd ZdZ	ddededededd	f
 fd
dZddejde	dejfddZ
dddZ  ZS )TFResNetBasicLayerzO
    A classic ResNet's residual layer composed by two `3x3` convolutions.
    r   r   r   r   r   r   r   Nc                    sz   t  jd	i | ||kp|dk}t|||dd| _t||d dd| _|r-t|||ddntjjddd| _	t
| | _d S )
Nr   layer.0r   r$   layer.1r   r$   shortcutr+   r$   r,   )r-   r.   r   conv1conv2rf   r   r0   r4   rm   r   r   )r5   r   r   r   r   r6   should_apply_shortcutr7   r,   r9   r.      s   zTFResNetBasicLayer.__init__Fr:   r?   c                 C   sD   |}| j ||d}| j||d}| j||d}||7 }| |}|S r@   )ro   rp   rm   r   r5   r:   r?   residualr,   r,   r9   rB      s   
zTFResNetBasicLayer.callc                 C   s   | j rd S d| _ t| dd d ur-t| jj | jd  W d    n1 s(w   Y  t| dd d urRt| jj | jd  W d    n1 sMw   Y  t| dd d urzt| jj | jd  W d    d S 1 ssw   Y  d S d S )NTro   rp   rm   )	rD   rE   r;   rF   ro   r$   rG   rp   rm   rH   r,   r,   r9   rG      s    "zTFResNetBasicLayer.build)r   r   rK   rL   rM   rN   rO   re   rP   rQ   r.   r;   rR   rS   rB   rG   rT   r,   r,   r7   r9   rh      s     	rh   c                       sh   e Zd ZdZ			ddedededed	ed
df fddZddejde	d
ejfddZ
dddZ  ZS )TFResNetBottleNeckLayera%  
    A classic ResNet's bottleneck layer composed by three `3x3` convolutions.

    The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
    convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`.
    r   r      r   r   r   r   	reductionr   Nc           	         s   t  jdi | ||kp|dk}|| }t||ddd| _t|||dd| _t||dd dd| _|r;t|||ddntjj	d	dd
| _
t| | _d S )Nr   ri   )r   r$   rk   rj   zlayer.2)r   r   r$   rm   r+   rn   r,   )r-   r.   r   conv0ro   rp   rf   r   r0   r4   rm   r   r   )	r5   r   r   r   r   rw   r6   rq   reduces_channelsr7   r,   r9   r.      s   	z TFResNetBottleNeckLayer.__init__Fr:   r?   c                 C   sR   |}| j ||d}| j||d}| j||d}| j||d}||7 }| |}|S r@   )rx   ro   rp   rm   r   rr   r,   r,   r9   rB      s   
zTFResNetBottleNeckLayer.callc                 C   sB  | j rd S d| _ t| dd d ur-t| jj | jd  W d    n1 s(w   Y  t| dd d urRt| jj | jd  W d    n1 sMw   Y  t| dd d urwt| jj | jd  W d    n1 srw   Y  t| dd d urt| j	j | j	d  W d    d S 1 sw   Y  d S d S )NTrx   ro   rp   rm   )
rD   rE   r;   rF   rx   r$   rG   ro   rp   rm   rH   r,   r,   r9   rG     s(   "zTFResNetBottleNeckLayer.build)r   r   rv   rK   rL   rt   r,   r,   r7   r9   ru      s(    
ru   c                       sd   e Zd ZdZ	ddedededededd	f fd
dZddejde	dejfddZ
dddZ  ZS )TFResNetStagez4
    A ResNet stage composed of stacked layers.
    r   rV   r   r   r   depthr   Nc                    sf   t  jdi |  jdkrtnt|| jddg}| fddt|d D 7 }|| _d S )N
bottleneckzlayers.0)r   r   r$   c              	      s(   g | ]} j d |d  dqS )zlayers.r   rl   )r\   ).0irV   layerr   r,   r9   
<listcomp>!  s    z*TFResNetStage.__init__.<locals>.<listcomp>r   r,   )r-   r.   
layer_typeru   rh   r\   rangestage_layers)r5   rV   r   r   r   r{   r6   r0   r7   r   r9   r.     s   

zTFResNetStage.__init__Fr:   r?   c                 C   s   | j D ]}|||d}q|S r@   )r   )r5   r:   r?   r   r,   r,   r9   rB   '  s   
zTFResNetStage.callc              	   C   j   | j rd S d| _ t| dd d ur1| jD ]}t|j |d  W d    n1 s+w   Y  qd S d S )NTr   )rD   rE   r   r;   rF   r$   rG   r5   rI   r   r,   r,   r9   rG   ,     
zTFResNetStage.build)r   r   rK   rL   )rM   rN   rO   re   r   rP   r.   r;   rR   rS   rB   rG   rT   r,   r,   r7   r9   rz     s$    rz   c                       sX   e Zd Zdeddf fddZ			ddejd	ed
ededef
ddZ	dddZ
  ZS )TFResNetEncoderrV   r   Nc                    s   t  jdi | t||j|jd |jrdnd|jd ddg| _tt	|j|jdd  |jdd  D ]\}\}}}| j
t||||d|d  d q4d S )	Nr   r   r   zstages.0)r   r{   r$   zstages.)r{   r$   r,   )r-   r.   rz   r[   hidden_sizesdownsample_in_first_stagedepthsstages	enumeratezipappend)r5   rV   r6   r~   r   r   r{   r7   r,   r9   r.   7  s   
 &zTFResNetEncoder.__init__FTr:   output_hidden_statesreturn_dictr?   c                 C   sf   |rdnd }| j D ]}|r||f }|||d}q	|r ||f }|s-tdd ||fD S t||dS )Nr,   rA   c                 s   s    | ]	}|d ur|V  qd S rL   r,   )r}   vr,   r,   r9   	<genexpr>\      z'TFResNetEncoder.call.<locals>.<genexpr>)last_hidden_statehidden_states)r   tupler   )r5   r:   r   r   r?   r   stage_moduler,   r,   r9   rB   I  s   


zTFResNetEncoder.callc              	   C   r   )NTr   )rD   rE   r   r;   rF   r$   rG   r   r,   r,   r9   rG   `  r   zTFResNetEncoder.build)FTFrL   )rM   rN   rO   r   r.   r;   rR   rS   r   rB   rG   rT   r,   r,   r7   r9   r   6  s"    
r   c                   @   s(   e Zd ZdZeZdZdZedd Z	dS )TFResNetPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    resnetr_   c                 C   s    dt jd | jjddft jdiS )Nr_      )shapedtype)r;   
TensorSpecrV   rZ   float32)r5   r,   r,   r9   input_signaturet  s    z'TFResNetPreTrainedModel.input_signatureN)
rM   rN   rO   re   r   config_classbase_model_prefixmain_input_namepropertyr   r,   r,   r,   r9   r   j  s    r   ad  
    This model is a TensorFlow
    [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
    regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ResNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
a>  
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ConvNextImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                       sv   e Zd ZeZdeddf fddZe			ddejde	e
 d	e	e
 d
e
deeej ef f
ddZdddZ  ZS )TFResNetMainLayerrV   r   Nc                    sH   t  jdi | || _t|dd| _t|dd| _tjj	dd| _
d S )NrW   rn   encoderT)keepdimsr,   )r-   r.   rV   rU   rW   r   r   r   r0   GlobalAveragePooling2DrX   r^   r7   r,   r9   r.     s
   zTFResNetMainLayer.__init__Fr_   r   r   r?   c                 C   s   |d ur|n| j j}|d ur|n| j j}tj|g dd}| j||d}| j||||d}|d }| |}t|d}t|d}d}	|dd  D ]}
|	td	d
 |
D  }	qJ|s`||f|	 S |rd|	nd }	t	|||	dS )N)r   r   r   r   )permrA   r   r   r?   r   r   r   r   r   r,   r   c                 s   s    | ]	}t |d V  qdS )r   N)r;   	transpose)r}   hr,   r,   r9   r     r   z)TFResNetMainLayer.call.<locals>.<genexpr>)r   pooler_outputr   )
rV   r   use_return_dictr;   r   rW   r   rX   r   r   )r5   r_   r   r   r?   embedding_outputencoder_outputsr   pooled_outputr   r:   r,   r,   r9   rB     s.   	
zTFResNetMainLayer.callc                 C   rc   )NTrW   r   )rD   rE   r;   rF   rW   r$   rG   r   rH   r,   r,   r9   rG     rd   zTFResNetMainLayer.buildNNFrL   )rM   rN   rO   r   r   r.   r   r;   rR   r   rS   r   r   r   rB   rG   rT   r,   r,   r7   r9   r     s&    -r   zOThe bare ResNet model outputting raw features without any specific head on top.c                       s   e Zd Zdeddf fddZeeeee	e
dede			dd	ejd
ee dee dedeeej e	f f
ddZdddZ  ZS )TFResNetModelrV   r   Nc                    s&   t  j|fi | t|dd| _d S )Nr   )rV   r$   )r-   r.   r   r   r^   r7   r,   r9   r.     s   zTFResNetModel.__init__vision)
checkpointoutput_typer   modalityexpected_outputFr_   r   r   r?   c                 C   s>   |d ur|n| j j}|d ur|n| j j}| j||||d}|S )N)r_   r   r   r?   )rV   r   r   r   )r5   r_   r   r   r?   resnet_outputsr,   r,   r9   rB     s   zTFResNetModel.callc                 C   sd   | j rd S d| _ t| dd d ur0t| jj | jd  W d    d S 1 s)w   Y  d S d S )NTr   )rD   rE   r;   rF   r   r$   rG   rH   r,   r,   r9   rG     s   "zTFResNetModel.buildr   rL   )rM   rN   rO   r   r.   r   RESNET_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r;   rR   r   rS   r   r   rB   rG   rT   r,   r,   r7   r9   r     s4    r   z
    ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    c                       s   e Zd Zdeddf fddZdejdejfddZee	e
eeeed	e					
ddeej deej dee dee dedeeej ef fddZdddZ  ZS )TFResNetForImageClassificationrV   r   Nc                    sb   t  j|fi | |j| _t|dd| _|jdkr#tjj|jddntjjddd| _	|| _
d S )Nr   rn   r   zclassifier.1r+   )r-   r.   
num_labelsr   r   r   r0   Denser4   classifier_layerrV   r^   r7   r,   r9   r.     s   

z'TFResNetForImageClassification.__init__rg   c                 C   s   t j |}| |}|S rL   )r   r0   Flattenr   )r5   rg   logitsr,   r,   r9   
classifier  s   
z)TFResNetForImageClassification.classifier)r   r   r   r   Fr_   labelsr   r   r?   c                 C   s   |dur|n| j j}| j||||d}|r|jn|d }| |}|du r'dn| ||}	|sC|f|dd  }
|	durA|	f|
 S |
S t|	||jdS )a)  
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   r   )lossr   r   )rV   r   r   r   r   hf_compute_lossr   r   )r5   r_   r   r   r   r?   outputsr   r   r   outputr,   r,   r9   rB   !  s   
z#TFResNetForImageClassification.callc                 C   s   | j rd S d| _ t| dd d ur-t| jj | jd  W d    n1 s(w   Y  t| dd d ur\t| jj | jd d | jj	d g W d    d S 1 sUw   Y  d S d S )NTr   r   )
rD   rE   r;   rF   r   r$   rG   r   rV   r   rH   r,   r,   r9   rG   H  s   "z$TFResNetForImageClassification.build)NNNNFrL   )rM   rN   rO   r   r.   r;   rR   r   r   r   r   _IMAGE_CLASS_CHECKPOINTr   r   _IMAGE_CLASS_EXPECTED_OUTPUTr   r   rS   r   r   rB   rG   rT   r,   r,   r7   r9   r     s<    r   )r   r   r   )3re   typingr   r   
tensorflowr;   activations_tfr   modeling_tf_outputsr   r   r   modeling_tf_utilsr	   r
   r   r   r   tf_utilsr   utilsr   r   r   r   configuration_resnetr   
get_loggerrM   loggerr   r   r   r   r   r0   Layerr   rU   rf   rh   ru   rz   r   r   RESNET_START_DOCSTRINGr   r   r   r   __all__r,   r,   r,   r9   <module>   sL   
.*"+:"4D+E