o
    eiA                  	   @   s|  d Z ddlZddlZddlmZ ddlZddlmZ ddlm	Z
 ddlmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZmZmZ ddlmZ eeZeeddG dd deZdd Zdd ZdJdej de!de"dej fddZ#G dd dej$Z%G d d! d!ej$Z&G d"d# d#ej$Z'G d$d% d%ej$Z(G d&d' d'ej$Z)G d(d) d)ej$Z*G d*d+ d+ej$Z+G d,d- d-ej$Z,G d.d/ d/ej$Z-G d0d1 d1ej$Z.G d2d3 d3ej$Z/G d4d5 d5eZ0G d6d7 d7ej$Z1eG d8d9 d9eZ2eG d:d; d;e2Z3G d<d= d=ej$Z4G d>d? d?ej$Z5G d@dA dAej$Z6G dBdC dCej$Z7G dDdE dEej$Z8edFdG dGdH dHe2Z9g dIZ:dS )Kz"PyTorch Swin2SR Transformer model.    N)	dataclass)nn   )initialization)ACT2FN)GradientCheckpointingLayer)BaseModelOutputImageSuperResolutionOutput)PreTrainedModel)ModelOutputauto_docstringlogging   )Swin2SRConfigzQ
    Swin2SR encoder's outputs, with potential hidden states and attentions.
    )custom_introc                   @   sL   e Zd ZU dZejdB ed< dZeej dB ed< dZ	eej dB ed< dS )Swin2SREncoderOutputNlast_hidden_statehidden_states
attentions)
__name__
__module____qualname__r   torchFloatTensor__annotations__r   tupler    r   r   j/home/ubuntu/transcripts/venv/lib/python3.10/site-packages/transformers/models/swin2sr/modeling_swin2sr.pyr   #   s   
 r   c                 C   sR   | j \}}}}| ||| ||| ||} | dddddd d|||}|S )z2
    Partitions the given input into windows.
    r   r   r            shapeviewpermute
contiguous)input_featurewindow_size
batch_sizeheightwidthnum_channelswindowsr   r   r   window_partition0   s   $r.   c                 C   sN   | j d }| d|| || |||} | dddddd d|||} | S )z?
    Merges windows to produce higher resolution features.
    r!   r   r   r   r   r   r    r"   )r-   r(   r*   r+   r,   r   r   r   window_reverse=   s   
$r/           Finput	drop_probtrainingreturnc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )zc
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    r0   r   r   )r   )dtypedevice)r#   ndimr   randr5   r6   floor_div)r1   r2   r3   	keep_probr#   random_tensoroutputr   r   r   	drop_pathH   s   r>   c                       sT   e Zd ZdZddedB ddf fddZdejdejfdd	Zde	fd
dZ
  ZS )Swin2SRDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr2   r4   c                    s   t    || _d S N)super__init__r2   )selfr2   	__class__r   r   rB   [   s   

zSwin2SRDropPath.__init__r   c                 C   s   t || j| jS r@   )r>   r2   r3   rC   r   r   r   r   forward_   s   zSwin2SRDropPath.forwardc                 C   s   d| j  S )Nzp=)r2   rC   r   r   r   
extra_reprb   s   zSwin2SRDropPath.extra_reprr@   )r   r   r   __doc__floatrB   r   TensorrG   strrI   __classcell__r   r   rD   r   r?   X   s
    r?   c                       s>   e Zd ZdZ fddZdejdB deej fddZ	  Z
S )	Swin2SREmbeddingsz?
    Construct the patch and optional position embeddings.
    c                    s`   t    t|| _| jj}|jr tt	d|d |j
| _nd | _t|j| _|j| _d S )Nr   )rA   rB   Swin2SRPatchEmbeddingspatch_embeddingsnum_patchesuse_absolute_embeddingsr   	Parameterr   zeros	embed_dimposition_embeddingsDropouthidden_dropout_probdropoutr(   )rC   configrR   rD   r   r   rB   k   s   

zSwin2SREmbeddings.__init__pixel_valuesNr4   c                 C   s4   |  |\}}| jd ur|| j }| |}||fS r@   )rQ   rW   rZ   )rC   r\   
embeddingsoutput_dimensionsr   r   r   rG   y   s
   


zSwin2SREmbeddings.forward)r   r   r   rJ   rB   r   r   r   rL   rG   rN   r   r   rD   r   rO   f   s    &rO   c                       sD   e Zd Zd	 fdd	ZdejdB deejee f fddZ	  Z
S )
rP   Tc                    s   t    |j}|j|j}}t|tjjr|n||f}t|tjjr%|n||f}|d |d  |d |d  g}|| _	|d |d  | _
tj||j||d| _|r[t|j| _d S d | _d S )Nr   r   )kernel_sizestride)rA   rB   rV   
image_size
patch_size
isinstancecollectionsabcIterablepatches_resolutionrR   r   Conv2d
projection	LayerNorm	layernorm)rC   r[   normalize_patchesr,   ra   rb   rg   rD   r   r   rB      s   
  zSwin2SRPatchEmbeddings.__init__r]   Nr4   c                 C   sN   |  |}|j\}}}}||f}|ddd}| jd ur#| |}||fS )Nr   r   )ri   r#   flatten	transposerk   )rC   r]   _r*   r+   r^   r   r   r   rG      s   


zSwin2SRPatchEmbeddings.forward)T)r   r   r   rB   r   r   r   rL   intrG   rN   r   r   rD   r   rP      s    .rP   c                       (   e Zd ZdZ fddZdd Z  ZS )Swin2SRPatchUnEmbeddingszImage to Patch Unembeddingc                    s   t    |j| _d S r@   )rA   rB   rV   )rC   r[   rD   r   r   rB      s   
z!Swin2SRPatchUnEmbeddings.__init__c                 C   s2   |j \}}}|dd|| j|d |d }|S )Nr   r   r   )r#   rn   r$   rV   )rC   r]   x_sizer)   height_widthr,   r   r   r   rG      s   "z Swin2SRPatchUnEmbeddings.forwardr   r   r   rJ   rB   rG   rN   r   r   rD   r   rr      s    rr   c                	       sh   e Zd ZdZejfdee dedejddf fddZ	d	d
 Z
dejdeeef dejfddZ  ZS )Swin2SRPatchMerginga'  
    Patch Merging Layer.

    Args:
        input_resolution (`tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    input_resolutiondim
norm_layerr4   Nc                    sB   t    || _|| _tjd| d| dd| _|d| | _d S )Nr   r   Fbias)rA   rB   rw   rx   r   Linear	reductionnorm)rC   rw   rx   ry   rD   r   r   rB      s
   
zSwin2SRPatchMerging.__init__c                 C   sF   |d dkp|d dk}|r!ddd|d d|d f}t j||}|S )Nr   r   r   )r   
functionalpad)rC   r'   r*   r+   
should_pad
pad_valuesr   r   r   	maybe_pad   s
   zSwin2SRPatchMerging.maybe_padr'   input_dimensionsc                 C   s   |\}}|j \}}}|||||}| |||}|d d dd ddd dd d f }|d d dd ddd dd d f }	|d d dd ddd dd d f }
|d d dd ddd dd d f }t||	|
|gd}||dd| }| |}| |}|S )Nr   r   r   r!   r   )r#   r$   r   r   catr}   r~   )rC   r'   r   r*   r+   r)   rx   r,   input_feature_0input_feature_1input_feature_2input_feature_3r   r   r   rG      s   $$$$

zSwin2SRPatchMerging.forward)r   r   r   rJ   r   rj   r   rp   ModulerB   r   r   rL   rG   rN   r   r   rD   r   rv      s
    **rv   c                
       s^   e Zd Zddgf fdd	Z		ddejdejdB dedB d	eej fd
dZ	dd Z
  ZS )Swin2SRSelfAttentionr   c              
      sF  t    || dkrtd| d| d|| _t|| | _| j| j | _t|tj	j
r0|n||f| _|| _ttdt|ddf | _ttjddd	d
tjd	dtjd|dd
| _|  \}}| jd|dd | jd|dd tj| j| j|jd
| _tj| j| jdd
| _tj| j| j|jd
| _t|j| _d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()
   r   r   i   Trz   inplaceFrelative_coords_table
persistentrelative_position_index) rA   rB   
ValueErrornum_attention_headsrp   attention_head_sizeall_head_sizerc   rd   re   rf   r(   pretrained_window_sizer   rT   r   logoneslogit_scale
Sequentialr|   ReLUcontinuous_position_bias_mlpcreate_coords_table_and_indexregister_bufferqkv_biasquerykeyvaluerX   attention_probs_dropout_probrZ   )rC   r[   rx   	num_headsr(   r   r   r   rD   r   r   rB      s,   
"&zSwin2SRSelfAttention.__init__NFr   attention_maskoutput_attentionsr4   c                 C   s  |j \}}}| ||d| j| jdd}| ||d| j| jdd}| ||d| j| jdd}	tj	j
|ddtj	j
|dddd }
tj| jtdd }|
| }
| | jd| j}|| jd | jd | jd  | jd | jd  d}|ddd }d	t| }|
|d }
|d ur|j d }|
|| || j|||dd }
|
|dd }
|
d| j||}
tj	j|
dd}| |}t||	}|dddd
 }| d d | jf }||}|r||f}|S |f}|S )Nr!   r   r   )rx   g      Y@)maxr      r   )r#   r   r$   r   r   rn   r   r   r   r   	normalizer   clampr   mathr   expr   r   r   r(   r%   r&   sigmoid	unsqueezesoftmaxrZ   matmulsizer   )rC   r   r   r   r)   rx   r,   query_layer	key_layervalue_layerattention_scoresr   relative_position_bias_tablerelative_position_bias
mask_shapeattention_probscontext_layernew_context_layer_shapeoutputsr   r   r   rG     s`   &


zSwin2SRSelfAttention.forwardc           
      C   s  t j| jd d  | jd t jd }t j| jd d  | jd t jd }t t j||gddddd 	d}| j
d dkrt|d d d d d d df  | j
d d   < |d d d d d d df  | j
d d   < n5| jd dkr|d d d d d d df  | jd d   < |d d d d d d df  | jd d   < |d9 }t |t t |d  td }|t| j j}t | jd }t | jd }t t j||gdd}t |d}|d d d d d f |d d d d d f  }|ddd }|d d d d df  | jd d 7  < |d d d d df  | jd d 7  < |d d d d df  d| jd  d 9  < |d	}	||	fS )
Nr   r   r5   ij)indexingr      g      ?r!   )r   aranger(   int64rK   stackmeshgridr%   r&   r   r   signlog2absr   tonextr   
parametersr5   rm   sum)
rC   relative_coords_hrelative_coords_wr   coords_hcoords_wcoordscoords_flattenrelative_coordsr   r   r   r   r   G  s8   ((
.0..&,((,
z2Swin2SRSelfAttention.create_coords_table_and_indexNF)r   r   r   rB   r   rL   r   boolr   rG   r   rN   r   r   rD   r   r      s     
Dr   c                       s8   e Zd Z fddZdejdejdejfddZ  ZS )Swin2SRSelfOutputc                    s*   t    t||| _t|j| _d S r@   )rA   rB   r   r|   denserX   r   rZ   rC   r[   rx   rD   r   r   rB   o  s   
zSwin2SRSelfOutput.__init__r   input_tensorr4   c                 C      |  |}| |}|S r@   r   rZ   )rC   r   r   r   r   r   rG   t     

zSwin2SRSelfOutput.forwardr   r   r   rB   r   rL   rG   rN   r   r   rD   r   r   n  s    $r   c                
       sP   e Zd Zd fdd	Z		ddejdejdB dedB d	eej fd
dZ	  Z
S )Swin2SRAttentionr   c                    sD   t    t||||t|tjjr|n||fd| _t||| _	d S )Nr[   rx   r   r(   r   )
rA   rB   r   rc   rd   re   rf   rC   r   r=   )rC   r[   rx   r   r(   r   rD   r   r   rB   }  s   
	zSwin2SRAttention.__init__NFr   r   r   r4   c                 C   s4   |  |||}| |d |}|f|dd   }|S Nr   r   )rC   r=   )rC   r   r   r   self_outputsattention_outputr   r   r   r   rG     s   zSwin2SRAttention.forwardr   r   )r   r   r   rB   r   rL   r   r   r   rG   rN   r   r   rD   r   r   |  s    r   c                       2   e Zd Z fddZdejdejfddZ  ZS )Swin2SRIntermediatec                    sJ   t    t|t|j| | _t|jt	rt
|j | _d S |j| _d S r@   )rA   rB   r   r|   rp   	mlp_ratior   rc   
hidden_actrM   r   intermediate_act_fnr   rD   r   r   rB     s
   
zSwin2SRIntermediate.__init__r   r4   c                 C   r   r@   )r   r   rF   r   r   r   rG        

zSwin2SRIntermediate.forwardr   r   r   rD   r   r     s    r   c                       r   )Swin2SROutputc                    s4   t    tt|j| || _t|j| _	d S r@   )
rA   rB   r   r|   rp   r   r   rX   rY   rZ   r   rD   r   r   rB     s   
zSwin2SROutput.__init__r   r4   c                 C   r   r@   r   rF   r   r   r   rG     r   zSwin2SROutput.forwardr   r   r   rD   r   r     s    r   c                       s   e Zd Z	d fdd	Zdeeeef eeef f fddZdd	 Zd
d Z	dde	j
deeef dedB dee	j
e	j
f fddZ  ZS )Swin2SRLayerr0   r   c           	         s   t    || _| |j|jf||f\}}|d | _|d | _t|||| jt|tj	j
r/|n||fd| _tj||jd| _|dkrGt|nt | _t||| _t||| _tj||jd| _d S )Nr   r   epsr0   )rA   rB   rw   _compute_window_shiftr(   
shift_sizer   rc   rd   re   rf   	attentionr   rj   layer_norm_epslayernorm_beforer?   Identityr>   r   intermediater   r=   layernorm_after)	rC   r[   rx   rw   r   drop_path_rater   r   r(   rD   r   r   rB     s*   


	zSwin2SRLayer.__init__r4   c                 C   s6   dd t | j|D }dd t | j||D }||fS )Nc                 S   s   g | ]	\}}t ||qS r   )min).0rwr   r   r   
<listcomp>  s    z6Swin2SRLayer._compute_window_shift.<locals>.<listcomp>c                 S   s"   g | ]\}}}||krd n|qS r   r   )r   r   r   sr   r   r   r     s   " )ziprw   )rC   target_window_sizetarget_shift_sizer(   r   r   r   r   r     s   z"Swin2SRLayer._compute_window_shiftc              	   C   s  | j dkrtjd||df|d}td| j t| j | j  t| j  d f}td| j t| j | j  t| j  d f}d}|D ]}|D ]}	||d d ||	d d f< |d7 }qDq@t|| j}
|
d| j| j }
|
d|
d }||dkd|dkd}|S d }|S )Nr   r   r   r!   r   g      Yr0   )	r   r   rU   slicer(   r.   r$   r   masked_fill)rC   r*   r+   r5   img_maskheight_sliceswidth_slicescountheight_slicewidth_slicemask_windows	attn_maskr   r   r   get_attn_mask  s.   

zSwin2SRLayer.get_attn_maskc                 C   sR   | j || j   | j  }| j || j   | j  }ddd|d|f}tj||}||fS )Nr   )r(   r   r   r   )rC   r   r*   r+   	pad_right
pad_bottomr   r   r   r   r     s
   zSwin2SRLayer.maybe_padFr   r   r   Nc                 C   s  |\}}|  \}}}|}	|||||}| |||\}}
|j\}}}}| jdkr9tj|| j | j fdd}n|}t|| j}|d| j| j |}| j	|||j
d}|d ur_||j}| j|||d}|d }|d| j| j|}t|| j||}| jdkrtj|| j| jfdd}n|}|
d dkp|
d dk}|r|d d d |d |d d f  }|||| |}| |}|	| | }| |}| |}|| | | }|r||d	 f}|S |f}|S )
Nr   )r   r   )shiftsdimsr!   r   )r   r   r    r   )r   r$   r   r#   r   r   rollr.   r(   r  r5   r   r6   r   r/   r&   r   r>   r   r=   r   )rC   r   r   r   r*   r+   r)   ro   channelsshortcutr   
height_pad	width_padshifted_hidden_stateshidden_states_windowsr  attention_outputsr   attention_windowsshifted_windows
was_paddedlayer_outputlayer_outputsr   r   r   rG     sD   

$


zSwin2SRLayer.forward)r0   r   r   F)r   r   r   rB   r   rp   r   r  r   r   rL   r   rG   rN   r   r   rD   r   r     s     &
r   c                
       sT   e Zd ZdZd fdd	Z	ddejdeeef de	d	B d
eej fddZ
  ZS )Swin2SRStagezh
    This corresponds to the Residual Swin Transformer Block (RSTB) in the original implementation.
    r   c                    s   t     | _| _t fddt|D | _ jdkr.t	ddd| _
n6 jdkrdtt	d dddtjdd	d
t	d d dddtjdd	d
t	d ddd| _
t dd| _t | _d S )Nc              
      s6   g | ]}t  |d  dkrdn jd  dqS )r   r   )r[   rx   rw   r   r   r   )r   r(   )r   ir[   rx   rw   r   r   r   r   r   8  s    	z)Swin2SRStage.__init__.<locals>.<listcomp>1convr   r   3convr   皙?Tnegative_sloper   r   F)rl   )rA   rB   r[   rx   r   
ModuleListrangelayersresi_connectionrh   convr   	LeakyReLUrP   patch_embedrr   patch_unembed)rC   r[   rx   rw   depthr   r>   r   rD   r!  r   rB   3  s(   
	

zSwin2SRStage.__init__Fr   r   r   Nr4   c                 C   s   |}|\}}t | jD ]\}}||||}	|	d }q||||f}
| ||}| |}| |\}}|| }||
f}|rD||	dd  7 }|S r   )	enumerater)  r.  r+  r-  )rC   r   r   r   residualr*   r+   r   layer_moduler  r^   ro   stage_outputsr   r   r   rG   U  s   

zSwin2SRStage.forwardr   r  )r   r   r   rJ   rB   r   rL   r   rp   r   rG   rN   r   r   rD   r   r  .  s    &
r  c                       s`   e Zd Z fddZ			ddejdeeef dedB d	edB d
edB dee	B fddZ
  ZS )Swin2SREncoderc                    sn   t    t j| _ | _dd tjd jt	 jddD t
 fddt| jD | _d| _d S )Nc                 S   s   g | ]}|  qS r   )item)r   xr   r   r   r   w  s    z+Swin2SREncoder.__init__.<locals>.<listcomp>r   cpu)r6   c                    sd   g | ].}t   jd  d f j|  j| t jd| t jd|d   d dqS )r   r   N)r[   rx   rw   r/  r   r>   r   )r  rV   depthsr   r   )r   	stage_idxr[   dpr	grid_sizer   r   r   y  s    
*F)rA   rB   lenr8  
num_stagesr[   r   linspacer   r   r   r'  r(  stagesgradient_checkpointing)rC   r[   r<  rD   r:  r   rB   s  s   
$

zSwin2SREncoder.__init__FTr   r   r   Noutput_hidden_statesreturn_dictr4   c                 C   s   d}|rdnd }|rdnd }|r||f7 }t | jD ]0\}	}
|
|||}|d }|d }|d |d f}||f7 }|r@||f7 }|rJ||dd  7 }q|sYtdd |||fD S t|||d	S )
Nr   r   r   r   r!   r   c                 s   s    | ]	}|d ur|V  qd S r@   r   )r   vr   r   r   	<genexpr>  s    z)Swin2SREncoder.forward.<locals>.<genexpr>r   r   r   )r0  r@  r   r   )rC   r   r   r   rB  rC  all_input_dimensionsall_hidden_statesall_self_attentionsr   stage_moduler  r^   r   r   r   rG     s.   


zSwin2SREncoder.forward)FFT)r   r   r   rB   r   rL   r   rp   r   r   rG   rN   r   r   rD   r   r4  r  s$    
r4  c                   @   s6   e Zd ZU eed< dZdZdZdZe	
 dd ZdS )	Swin2SRPreTrainedModelr[   swin2srr\   )imageTc                 C   s  t |tjtjfr"tj|j| jjd |j	dur t
|j	 dS dS t |tjr6t
|j	 t|j dS t |tr[t|jtd | \}}t|j| t|j| dS t |tr|jjdkrz|jjdkrztg ddddd}ntdddd}t|j| dS dS )zInitialize the weights)stdNr   r   gw#?g8EGr?gB`"?r   )rc   r   r|   rh   inittrunc_normal_weightr[   initializer_ranger{   zeros_rj   ones_r   	constant_r   r   r   r   copy_r   r   Swin2SRModelr,   num_channels_outr   tensorr$   rU   mean)rC   moduler   r   r[  r   r   r   _init_weights  s&   


z$Swin2SRPreTrainedModel._init_weightsN)r   r   r   r   r   base_model_prefixmain_input_nameinput_modalitiessupports_gradient_checkpointingr   no_gradr]  r   r   r   r   rK    s   
 rK  c                       sh   e Zd Z fddZdd Zdd Ze			ddejd	e	dB d
e	dB de	dB de
eB f
ddZ  ZS )rX  c                    s   t  | || _|jdkr!|jdkr!tg ddddd}ntdddd}| j	d|dd |j
| _
t|j|jddd| _t|| _t|| jjjd| _tj|j|jd| _t|| _t|j|jddd| _|   d S )	Nr   rO  r   r[  Fr   )r<  r   )rA   rB   r[   r,   rY  r   rZ  r$   rU   r   	img_ranger   rh   rV   first_convolutionrO   r]   r4  rQ   rg   encoderrj   r   rk   rr   r.  conv_after_body	post_init)rC   r[   r[  rD   r   r   rB     s   

zSwin2SRModel.__init__c                 C   s   | j jS r@   )r]   rQ   rH   r   r   r   get_input_embeddings  s   z!Swin2SRModel.get_input_embeddingsc           	      C   sn   |  \}}}}| jj}|||  | }|||  | }tj|d|d|fd}| j|}|| | j }|S )Nr   reflect)	r   r[   r(   r   r   r   r[  type_asrc  )	rC   r\   ro   r*   r+   r(   modulo_pad_heightmodulo_pad_widthr[  r   r   r   pad_and_normalize  s   zSwin2SRModel.pad_and_normalizeNr\   r   rB  rC  r4   c                 K   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}|j\}}}}| |}| |}	| |	\}
}| j|
||||d}|d }| 	|}| 
|||f}| ||	 }|se|f|dd   }|S t||j|jdS )Nr   rB  rC  r   r   rF  )r[   r   rB  use_return_dictr#   rm  rd  r]   re  rk   r.  rf  r   r   r   )rC   r\   r   rB  rC  kwargsro   r*   r+   r]   embedding_outputr   encoder_outputssequence_outputr=   r   r   r   rG     s6   	


zSwin2SRModel.forward)NNN)r   r   r   rB   rh  rm  r   r   r   r   r   r   rG   rN   r   r   rD   r   rX    s&    rX  c                       rq   )UpsamplezUpsample module.

    Args:
        scale (`int`):
            Scale factor. Supported scales: 2^n and 3.
        num_features (`int`):
            Channel number of intermediate features.
    c                    s   t    || _||d @ dkr<ttt|D ] }| d| t	|d| ddd | d| t
d qd S |dkrTt	|d| ddd| _t
d| _d S td	| d
)Nr   r   convolution_r   r   pixelshuffle_r   	   zScale z/ is not supported. Supported scales: 2^n and 3.)rA   rB   scaler(  rp   r   r   
add_moduler   rh   PixelShuffleconvolutionpixelshuffler   )rC   rx  num_featuresr   rD   r   r   rB   6  s   
$zUpsample.__init__c                 C   s|   | j | j d @ dkr-ttt| j D ]}| d| |}| d| |}q|S | j dkr<| |}| |}|S )Nr   r   ru  rv  r   )rx  r(  rp   r   r   __getattr__r{  r|  )rC   hidden_stater   r   r   r   rG   E  s   


zUpsample.forwardru   r   r   rD   r   rt  ,  s    	rt  c                       rq   )UpsampleOneStepa  UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)

    Used in lightweight SR to save parameters.

    Args:
        scale (int):
            Scale factor. Supported scales: 2^n and 3.
        in_channels (int):
            Channel number of intermediate features.
        out_channels (int):
            Channel number of output features.
    c                    s6   t    t||d | ddd| _t|| _d S )Nr   r   r   )rA   rB   r   rh   r+  rz  pixel_shuffle)rC   rx  in_channelsout_channelsrD   r   r   rB   `  s   
zUpsampleOneStep.__init__c                 C   r   r@   )r+  r  )rC   r6  r   r   r   rG   f  r   zUpsampleOneStep.forwardru   r   r   rD   r   r  R  s    r  c                       $   e Zd Z fddZdd Z  ZS )PixelShuffleUpsamplerc                    sV   t    t|j|ddd| _tjdd| _t|j	|| _
t||jddd| _d S Nr   r   Tr   )rA   rB   r   rh   rV   conv_before_upsampler,  
activationrt  upscaleupsamplerY  final_convolutionrC   r[   r}  rD   r   r   rB   n  s
   
zPixelShuffleUpsampler.__init__c                 C   s,   |  |}| |}| |}| |}|S r@   )r  r  r  r  )rC   rs  r6  r   r   r   rG   u  s
   



zPixelShuffleUpsampler.forwardr   r   r   rB   rG   rN   r   r   rD   r   r  m  s    r  c                       r  )NearestConvUpsamplerc                    s   t    |jdkrtdt|j|ddd| _tjdd| _	t||ddd| _
t||ddd| _t||ddd| _t||jddd| _tjddd| _d S )	Nr   zNThe nearest+conv upsampler only supports an upscale factor of 4 at the moment.r   r   Tr   r$  r%  )rA   rB   r  r   r   rh   rV   r  r,  r  conv_up1conv_up2conv_hrrY  r  lrelur  rD   r   r   rB     s   

zNearestConvUpsampler.__init__c              	   C   sn   |  |}| |}| | tjjj|ddd}| | tjjj|ddd}| 	| | 
|}|S )Nr   nearest)scale_factormode)r  r  r  r  r   r   r   interpolater  r  r  )rC   rs  reconstructionr   r   r   rG     s   

zNearestConvUpsampler.forwardr  r   r   rD   r   r  ~  s    r  c                       r  )PixelShuffleAuxUpsamplerc              	      s   t    |j| _t|j|ddd| _t|j|ddd| _tj	dd| _
t||jddd| _ttd|dddtj	dd| _t|j|| _t||jddd| _d S r  )rA   rB   r  r   rh   r,   conv_bicubicrV   r  r,  r  conv_auxr   conv_after_auxrt  r  rY  r  r  rD   r   r   rB     s   
$z!PixelShuffleAuxUpsampler.__init__c                 C   s   |  |}| |}| |}| |}| |}| |d d d d d || j d || j f |d d d d d || j d || j f  }| |}||fS r@   )r  r  r  r  r  r  r  r  )rC   rs  bicubicr*   r+   auxr  r   r   r   rG     s   




0*
z PixelShuffleAuxUpsampler.forwardr  r   r   rD   r   r    s    r  zm
    Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration.
    c                       sj   e Zd Z fddZe					ddejdB dejdB dedB dedB dedB d	e	e
B fd
dZ  ZS )Swin2SRForImageSuperResolutionc                    s   t  | t|| _|j| _|j| _d}| jdkr!t||| _n4| jdkr-t||| _n(| jdkr=t	|j|j
|j| _n| jdkrIt||| _nt|j
|jddd| _|   d S )N@   r|  pixelshuffle_auxpixelshuffledirectnearest+convr   r   )rA   rB   rX  rL  	upsamplerr  r  r  r  r  rV   rY  r  r   rh   r  rg  r  rD   r   r   rB     s   




z'Swin2SRForImageSuperResolution.__init__Nr\   labelsr   rB  rC  r4   c                 K   s\  |dur|n| j j}d}|durtd|jdd \}}	| j jdkr5tjj||| j |	| j fddd}
| j	||||d}|d	 }| jd
v rM| 
|}n!| jdkrg| 
||
||	\}}|| j	j | j	j }n|| | }|| j	j | j	j }|ddddd|| j d|	| j f }|s|f|dd  }|dur|f| S |S t|||j|jdS )a  
        Example:
         ```python
         >>> import torch
         >>> import numpy as np
         >>> from PIL import Image
         >>> import httpx
        >>> from io import BytesIO

         >>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

         >>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
         >>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")

         >>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
         >>> with httpx.stream("GET", url) as response:
         ...     image = Image.open(BytesIO(response.read()))
         >>> # prepare image for the model
         >>> inputs = processor(image, return_tensors="pt")

         >>> # forward pass
         >>> with torch.no_grad():
         ...     outputs = model(**inputs)

         >>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
         >>> output = np.moveaxis(output, source=0, destination=-1)
         >>> output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
         >>> # you can visualize `output` with `Image.fromarray`
         ```Nz'Training is not supported at the momentr   r  r  F)r   r  align_cornersrn  r   )r|  r  r  r   )lossr  r   r   )r[   ro  NotImplementedErrorr#   r  r   r   r  r  rL  r  rc  r[  r  r	   r   r   )rC   r\   r  r   rB  rC  rp  r  r*   r+   r  r   rs  r  r  r=   r   r   r   rG     sH   '

,z&Swin2SRForImageSuperResolution.forward)NNNNN)r   r   r   rB   r   r   r   
LongTensorr   r   r	   rG   rN   r   r   rD   r   r    s*    r  )r  rX  rK  )r0   F);rJ   collections.abcrd   r   dataclassesr   r   r    r   rP  activationsr   modeling_layersr   modeling_outputsr   r	   modeling_utilsr
   utilsr   r   r   configuration_swin2srr   
get_loggerr   loggerr   r.   r/   rL   rK   r   r>   r   r?   rO   rP   rr   rv   r   r   r   r   r   r   r  r4  rK  rX  rt  r  r  r  r  r  __all__r   r   r   r   <module>   sd   
 7 
zD?[&r