o
    ߥi                     @   s"  d dl Z d dlZd dlZd dlm  mZ d dlmZ ddlm	Z	m
Z
 G dd dejZG dd deZG d	d
 d
eZG dd deZG dd deZG dd deZG dd deZG dd deZG dd deZG dd deZG dd deZG dd deZG dd deZG dd  d eZG d!d" d"eZG d#d$ d$eZG d%d& d&eZG d'd( d(eZG d)d* d*eZG d+d, d,eZG d-d. d.ej j!Z"G d/d0 d0eZ#i ded
edededededededededededed&ed,ed0e#Z$d1e%fd2d3Z&dS )4    N)nn   )#create_netblock_list_from_str_innerget_right_parentheses_indexc                       sb   e Zd Z					d fdd	Zdd Zdd	 Zd
d Zdd ZedddZ	edd Z
  ZS )PlainNetBasicBlockClassNr   Fc                    sX   t t| jdi | || _|| _|| _|| _|| _| jd u r*dt	
 j| _d S d S )Nuuid{} )superr   __init__in_channelsout_channelsstride	no_create
block_nameformatuuiduuid4hex)selfr   r   r   r   r   kwargs	__class__r   k/home/ubuntu/.local/lib/python3.10/site-packages/modelscope/models/cv/tinynas_classfication/basic_blocks.pyr
      s   
z PlainNetBasicBlockClass.__init__c                 C      t dNNot implementedRuntimeErrorr   xr   r   r   forward!      zPlainNetBasicBlockClass.forwardc                 C   s   t | jd| j| j| j S )Nz
({},{},{}))type__name__r   r   r   r   r   r   r   r   __str__$   s   zPlainNetBasicBlockClass.__str__c                 C   "   t | jd| j| j| j| j S )Nz({}|{},{},{}))r"   r#   r   r   r   r   r   r$   r   r   r   __repr__(      z PlainNetBasicBlockClass.__repr__c                 C   r   r   r   r   input_resolutionr   r   r   get_output_resolution,   r!   z-PlainNetBasicBlockClass.get_output_resolutionc                 K   s   t |sJ t|}|d usJ |t| jd | }|d}|dk r.dt j	}n|d| }||d d  }|
d}t|d }	t|d }
t|d }| |	|
|||d||d d  fS )	N(|r   r   r   ,   )r   r   r   r   r   )r   is_instance_from_strr   lenr#   findr   r   r   r   splitint)clssr   r   idx	param_strtmp_idxtmp_block_nameparam_str_splitr   r   r   r   r   r   create_from_str/   s,   

z'PlainNetBasicBlockClass.create_from_strc                 C   s$   | | jd r|d dkrdS dS )Nr,   )TF)
startswithr#   )r5   r6   r   r   r   r0   H   s   z,PlainNetBasicBlockClass.is_instance_from_str)NNr   FNF)r#   
__module____qualname__r
   r    r%   r'   r+   classmethodr<   r0   __classcell__r   r   r   r   r      s    r   c                       L   e Zd Zd fdd	Zdd Zdd Zdd	 Zd
d ZedddZ	  Z
S )AdaptiveAvgPoolFc                    sP   t t| jdi | || _|| _|| _|| _|s&tj| j| jfd| _	d S d S )Noutput_sizer   )
r	   rF   r
   r   r   rH   r   r   AdaptiveAvgPool2dnetblock)r   r   rH   r   r   r   r   r   r
   R   s   
zAdaptiveAvgPool.__init__c                 C   
   |  |S NrJ   r   r   r   r   r    \      
zAdaptiveAvgPool.forwardc                 C   s$   t | jd| j| jd  | j S )Nz({},{})r/   )r"   r#   r   r   rH   r$   r   r   r   r%   _   s   zAdaptiveAvgPool.__str__c                 C   s(   t | jd| j| j| jd  | j S )Nz
({}|{},{})r/   )r"   r#   r   r   r   rH   r$   r   r   r   r'   c   s   zAdaptiveAvgPool.__repr__c                 C   s   | j S rL   rG   r)   r   r   r   r+   h   s   z%AdaptiveAvgPool.get_output_resolutionc                 K   s   t |sJ t|}|d usJ |td| }|d}|dk r+dt j}n|d| }||d d  }|	d}t
|d }	t
|d }
t |	|
||d||d d  fS )NzAdaptiveAvgPool(r-   r   r   r   r.   )r   rH   r   r   )rF   r0   r   r1   r2   r   r   r   r   r3   r4   )r5   r6   r   r   r7   r8   r9   r:   r;   r   rH   r   r   r   r<   k   s(   

zAdaptiveAvgPool.create_from_strr@   r#   rA   rB   r
   r    r%   r'   r+   rC   r<   rD   r   r   r   r   rF   P   s    
rF   c                       R   e Zd Z			d fdd	Zdd Zdd Zd	d
 Zdd ZedddZ	  Z
S )BNNFc                    s   t t| jdi | || _|d ur8t|tjsJ |jjd | _	|jjd | _
|d u s3|| j
ks3J || _d S || _	|| _
|rBd S tj| j
d| _d S )Nr   )num_featuresr   )r	   rQ   r
   r   
isinstancer   BatchNorm2dweightshaper   r   rJ   r   r   	copy_fromr   r   r   r   r   r
      s   
zBN.__init__c                 C   rK   rL   rM   r   r   r   r   r       rN   z
BN.forwardc                 C      d | jS )NzBN({})r   r   r$   r   r   r   r%         z
BN.__str__c                 C      d | j| jS )Nz	BN({}|{})r   r   r   r$   r   r   r   r'         zBN.__repr__c                 C      |S rL   r   r)   r   r   r   r+         zBN.get_output_resolutionc           	      K   s   t |sJ t|}|d usJ |td| }|d}|dk r+dt j}n|d| }||d d  }t	|}t |||d||d d  fS )NzBN(r-   r   r   r   )r   r   r   )
rQ   r0   r   r1   r2   r   r   r   r   r4   	r5   r6   r   r   r7   r8   r9   r:   r   r   r   r   r<      s"   
zBN.create_from_strNNFr@   rO   r   r   r   r   rQ      s    rQ   c                       sZ   e Zd Z							d fdd	Zdd Zdd	 Zd
d Zdd ZedddZ	  Z
S )ConvKXNr   Fc           	   	      sN  t t| jdi | || _|d urct|tjsJ |j| _|j| _|j	d | _	|j
d | _
|j| _|d u s=|| jks=J |d u sH|| jksHJ |d u sS|| j	ksSJ |d u s^|| j
ks^J || _d S || _|| _|| _
|| _|| _	| j	d d | _|s| jdks| jdks| j	dks| j
dkrd S tj| j| j| j	| j
| jd| jd| _d S Nr   r   r/   F)r   r   kernel_sizer   paddingbiasgroupsr   )r	   rc   r
   r   rS   r   Conv2dr   r   re   r   rh   rJ   rf   )	r   r   r   re   r   rh   rX   r   r   r   r   r   r
      s@   	
"
zConvKX.__init__c                 C   rK   rL   rM   r   r   r   r   r       rN   zConvKX.forwardc                 C   r&   )Nz({},{},{},{}))r"   r#   r   r   r   re   r   r$   r   r   r   r%      r(   zConvKX.__str__c                 C   s&   t | jd| j| j| j| j| j S )Nz({}|{},{},{},{}))r"   r#   r   r   r   r   re   r   r$   r   r   r   r'      s   zConvKX.__repr__c                 C   
   || j  S rL   r   r)   r   r   r   r+      rN   zConvKX.get_output_resolutionc                 K   s   |  |sJ t|}|d usJ |t| jd | }|d}|dk r.dt j}n|d| }||d d  }|	d}t
|d }	t
|d }
t
|d }t
|d }| |	|
||||d	||d d  fS )
Nr,   r-   r   r   r   r.   r/      )r   r   re   r   r   r   )r0   r   r1   r#   r2   r   r   r   r   r3   r4   )r5   r6   r   r   r7   r8   r9   r:   	split_strr   r   re   r   r   r   r   r<      s0   

zConvKX.create_from_str)NNNNr   NFr@   rO   r   r   r   r   rc      s    ,rc   c                       sV   e Zd Z					d fdd	Zdd Zdd Zd	d
 Zdd ZedddZ	  Z
S )ConvDWNFc              	      s:  t t| jdi | || _|d ur\t|tjsJ |j| _|j| _|j	d | _	|j
d | _
| j| jks6J |d u sA|| jksAJ |d u sL|| j	ksLJ |d u sW|| j
ksWJ || _d S || _|| _|| _
|| _	| j	d d | _|s| jdks| jdks| j	dks| j
dkrd S tj| j| j| j	| j
| jd| jd| _d S rd   )r	   rn   r
   r   rS   r   ri   r   r   re   r   rJ   rf   )r   r   re   r   rX   r   r   r   r   r   r
     s<   
"
zConvDW.__init__c                 C   rK   rL   rM   r   r   r   r   r    @  rN   zConvDW.forwardc                 C      d | j| j| jS )NzConvDW({},{},{})r   r   re   r   r$   r   r   r   r%   C     zConvDW.__str__c                 C      d | j| j| j| jS )NzConvDW({}|{},{},{})r   r   r   re   r   r$   r   r   r   r'   G     zConvDW.__repr__c                 C   rj   rL   rk   r)   r   r   r   r+   K  rN   zConvDW.get_output_resolutionc                 K      t |sJ t|}|d usJ |td| }|d}|dk r+dt j}n|d| }||d d  }|	d}t
|d }	t
|d }
t
|d }t |	|
|||d||d d  fS )	NzConvDW(r-   r   r   r   r.   r/   r   re   r   r   r   )rn   r0   r   r1   r2   r   r   r   r   r3   r4   )r5   r6   r   r   r7   r8   r9   r:   rm   r   re   r   r   r   r   r<   N  ,   

zConvDW.create_from_str)NNNNFr@   rO   r   r   r   r   rn     s    +rn   c                       *   e Zd Z						d fdd	Z  ZS )ConvKXG2NFc              
      *   t t| jd||||||dd| d S )Nr/   r   r   re   r   rX   r   rh   r   )r	   ry   r
   r   r   r   re   r   rX   r   r   r   r   r   r
   i     
zConvKXG2.__init__NNNNNFr#   rA   rB   r
   rD   r   r   r   r   ry   g      ry   c                       rx   )ConvKXG4NFc              
      rz   )N   r{   r   )r	   r   r
   r|   r   r   r   r
   ~  r}   zConvKXG4.__init__r~   r   r   r   r   r   r   |  r   r   c                       rx   )ConvKXG8NFc              
      rz   )N   r{   r   )r	   r   r
   r|   r   r   r   r
     r}   zConvKXG8.__init__r~   r   r   r   r   r   r     r   r   c                       rx   )	ConvKXG16NFc              
      rz   )N   r{   r   )r	   r   r
   r|   r   r   r   r
     r}   zConvKXG16.__init__r~   r   r   r   r   r   r     r   r   c                       rx   )	ConvKXG32NFc              
      rz   )N    r{   r   )r	   r   r
   r|   r   r   r   r
     r}   zConvKXG32.__init__r~   r   r   r   r   r   r     r   r   c                       rE   )FlattenFc                    ,   t t| jdi | || _|| _|| _d S Nr   )r	   r   r
   r   r   r   r   r   r   r   r   r   r   r
        
zFlatten.__init__c                 C   s   t |dS Nr   )torchflattenr   r   r   r   r      r[   zFlatten.forwardc                 C   rY   )NzFlatten({})rZ   r$   r   r   r   r%     r[   zFlatten.__str__c                 C   r\   )NzFlatten({}|{})r]   r$   r   r   r   r'     r^   zFlatten.__repr__c                 C   s   dS r   r   r)   r   r   r   r+     r`   zFlatten.get_output_resolutionc           	      K      t |sJ t|}|d usJ |td| }|d}|dk r+dt j}n|d| }||d d  }t	|}t |||d||d d  fS )NzFlatten(r-   r   r   r   r   r   r   )
r   r0   r   r1   r2   r   r   r   r   r4   ra   r   r   r   r<     "   
zFlatten.create_from_strr@   rO   r   r   r   r   r         r   c                       sV   e Zd Z					d fdd	Zdd Zdd	 Zd
d Zdd ZedddZ	  Z
S )LinearNTFc                    s   t t| jdi | || _|d urIt|tjsJ |jjd | _|jjd | _	|j
d u| _|d u s9|| jks9J |d u sD|| j	ksDJ || _d S || _|| _	|| _|sbtj| j| j	| jd| _d S d S )Nr   r   )rg   r   )r	   r   r
   r   rS   r   rU   rV   r   r   rg   use_biasrJ   )r   r   r   rg   rX   r   r   r   r   r   r
     s$   
zLinear.__init__c                 C   rK   rL   rM   r   r   r   r   r      rN   zLinear.forwardc                 C   s   d | j| jt| jS )NzLinear({},{},{}))r   r   r   r4   r   r$   r   r   r   r%     rt   zLinear.__str__c                 C   s   d | j| j| jt| jS )NzLinear({}|{},{},{}))r   r   r   r   r4   r   r$   r   r   r   r'     s   zLinear.__repr__c                 C   s   |dksJ dS r   r   r)   r   r   r   r+   "  s   zLinear.get_output_resolutionc                 K   s   t |sJ t|}|d usJ |td| }|d}|dk r+dt j}n|d| }||d d  }|	d}t
|d }	t
|d }
t
|d }t |	|
|dk||d||d d  fS )	NzLinear(r-   r   r   r   r.   r/   )r   r   rg   r   r   )r   r0   r   r1   r2   r   r   r   r   r3   r4   )r5   r6   r   r   r7   r8   r9   r:   rm   r   r   r   r   r   r   r<   &  s,   

zLinear.create_from_str)NNTNFr@   rO   r   r   r   r   r     s    r   c                       sN   e Zd Z	d fdd	Zdd Zdd Zdd	 Zd
d ZedddZ	  Z
S )MaxPoolFc                    sf   t t| jdi | || _|| _|| _|| _|d d | _|| _|s1t	j
| j| j| jd| _d S d S )Nr   r/   )re   r   rf   r   )r	   r   r
   r   r   re   r   rf   r   r   	MaxPool2drJ   )r   r   re   r   r   r   r   r   r   r
   B  s   zMaxPool.__init__c                 C   rK   rL   rM   r   r   r   r   r    U  rN   zMaxPool.forwardc                 C   ro   )NzMaxPool({},{},{})rp   r$   r   r   r   r%   X  rq   zMaxPool.__str__c                 C   rr   )NzMaxPool({}|{},{},{})rs   r$   r   r   r   r'   \  s   zMaxPool.__repr__c                 C   rj   rL   rk   r)   r   r   r   r+   a  rN   zMaxPool.get_output_resolutionc                 K   ru   )	NzMaxPool(r-   r   r   r   r.   r/   rv   )r   r0   r   r1   r2   r   r   r   r   r3   r4   )r5   r6   r   r   r7   r8   r9   r:   r;   r   re   r   r   r   r   r<   d  rw   zMaxPool.create_from_strr@   rO   r   r   r   r   r   @  s    r   c                       rE   )
SequentialFc                    sr   t t| jdi | || _|st|| _|d j| _|d j| _|| _	d}| jD ]}|
|}q*d| | _d S )Nr   r=      r   )r	   r   r
   
block_listr   
ModuleListmodule_listr   r   r   r+   r   )r   r   r   r   resblockr   r   r   r
     s   
zSequential.__init__c                 C   s   |}| j D ]}||}q|S rL   r   r   r   outputinner_blockr   r   r   r      s   

zSequential.forwardc                 C   s(   d}| j D ]}|t|7 }q|d7 }|S )NSequential(r>   )r   strr   r6   r   r   r   r   r%     s
   
zSequential.__str__c                 C      t | S rL   r   r$   r   r   r   r'     r!   zSequential.__repr__c                 C      |}| j D ]}||}q|S rL   r   r+   r   r*   the_res	the_blockr   r   r   r+     s   
z Sequential.get_output_resolutionc           
      K   s   t |sJ t|}|tdd | }|d}|dk r'dt j}n|d| }||d d  }t	|t
|d\}}	t|	dksFJ |d u sPt|dkrRdS t |||dd	fS )
Nr   r   r-   r   r   netblocks_dictr   )N )r   r   r   r   )r   r0   r   r1   r2   r   r   r   r   r   bottom_basic_dict)
r5   r6   r   r   the_right_paraen_idxr8   r9   r:   the_block_listremaining_sr   r   r   r<     s*   

zSequential.create_from_strr@   rO   r   r   r   r   r   }  s    r   c                       rE   )MultiSumBlockFc                    s   t t| jdi | || _|st|| _tdd |D | _	tdd |D | _
|| _d}| jd |}d| | _d S )Nc                 S      g | ]}|j qS r   r   .0r   r   r   r   
<listcomp>      z*MultiSumBlock.__init__.<locals>.<listcomp>c                 S   r   r   r   r   r   r   r   r     r   r   r   r   )r	   r   r
   r   r   r   r   npmaxr   r   r   r+   r   r   r   r   r   r   r   r   r   r
        zMultiSumBlock.__init__c                 C   s6   | j d |}| j dd  D ]
}||}|| }q|S Nr   r   r   )r   r   r   r   output2r   r   r   r      s
   
zMultiSumBlock.forwardc                 C   @   d | j}| jD ]
}|t|d 7 }q	|d d }|d7 }|S )NzMultiSumBlock({}|;r=   r>   r   r   r   r   r   r   r   r   r%     s   
zMultiSumBlock.__str__c                 C   r   rL   r   r$   r   r   r   r'     r!   zMultiSumBlock.__repr__c                 C   2   | j d |}| j D ]}|||ksJ q|S Nr   r   r   r   r   r   r+        
z#MultiSumBlock.get_output_resolutionc                 K   s   t |sJ t|}|d usJ |td| }|d}|dk r+dt j}n|d| }||d d  }|}g }	t|dkrpt	|t
|d\}
}|}|
d u rSnt|
dkra|	|
d  n	|	t|
|d t|dksC	 t|	dkrd ||d d  fS t |	||d||d d  fS )	NzMultiSumBlock(r-   r   r   r   r   r   r   r   r   r   )r   r0   r   r1   r2   r   r   r   r   r   r   appendr   r5   r6   r   r   r7   r8   r9   r:   the_sr   tmp_block_listr   r   r   r   r<     sD   


zMultiSumBlock.create_from_strr@   rO   r   r   r   r   r     s    r   c                       rE   )MultiCatBlockFc                    s   t t| jdi | || _|st|| _tdd |D | _	t
dd |D | _|| _d}| jd |}d| | _d S )Nc                 S   r   r   r   r   r   r   r   r     r   z*MultiCatBlock.__init__.<locals>.<listcomp>c                 S   r   r   r   r   r   r   r   r     r   r   r   r   )r	   r   r
   r   r   r   r   r   r   r   sumr   r   r+   r   r   r   r   r   r
   	  r   zMultiCatBlock.__init__c                 C   s0   g }| j D ]}||}|| qtj|ddS )Nr   )dim)r   r   r   cat)r   r   output_listr   r   r   r   r   r      s
   
zMultiCatBlock.forwardc                 C   r   )NzMultiCatBlock({}|r   r=   r>   r   r   r   r   r   r%     s   
zMultiCatBlock.__str__c                 C   r   rL   r   r$   r   r   r   r'   '  r!   zMultiCatBlock.__repr__c                 C   r   r   r   r   r   r   r   r+   *  r   z#MultiCatBlock.get_output_resolutionc                 K   s  t |sJ t|}|d usJ |td| }|d}|dk r+dt j}n|d| }||d d  }|}g }	t|dkrpt	|t
|d\}
}|}|
d u rSnt|
dkra|	|
d  n	|	t|
|d t|dksCt|	dkrd ||d d  fS t |	||d||d d  fS )	NzMultiCatBlock(r-   r   r   r   r   r   r   )r   r0   r   r1   r2   r   r   r   r   r   r   r   r   r   r   r   r   r<   1  sB   


zMultiCatBlock.create_from_strr@   rO   r   r   r   r   r     s    	r   c                       rE   )RELUFc                    r   r   )r	   r   r
   r   r   r   r   r   r   r   r
   X  r   zRELU.__init__c                 C   
   t |S rL   )Frelur   r   r   r   r    ^  rN   zRELU.forwardc                 C   rY   )NzRELU({})rZ   r$   r   r   r   r%   a  r[   zRELU.__str__c                 C   r\   )NzRELU({}|{})r]   r$   r   r   r   r'   d  r^   zRELU.__repr__c                 C   r_   rL   r   r)   r   r   r   r+   g  r`   zRELU.get_output_resolutionc           	      K   r   )NzRELU(r-   r   r   r   r   )
r   r0   r   r1   r2   r   r   r   r   r4   ra   r   r   r   r<   j  r   zRELU.create_from_strr@   rO   r   r   r   r   r   V  r   r   c                       V   e Zd ZdZ			d fdd	Zdd Zdd	 Zd
d Zdd Ze	dddZ
  ZS )ResBlockzz
    ResBlock(in_channles, inner_blocks_str). If in_channels is missing, use block_list[0].in_channels as in_channels
    NFc                    s   t t| jdi | || _|| _|| _|st|| _|d u r'|d j	| _	n|| _	|d j
| _
| jd u rAd}| |}|| | _d | _| jdksO| j	| j
krett| j	| j
d| jt| j
| _d S d S Nr   r=   r   r   r   )r	   r   r
   r   r   r   r   r   r   r   r   r+   projr   ri   rT   r   r   r   r   r   r   tmp_input_restmp_output_resr   r   r   r
     s*   




zResBlock.__init__c                 C   sR   t | jdkr	|S |}| jD ]}||}q| jd ur#|| | }|S || }|S r   r1   r   r   r   r   r   r   r      s   


zResBlock.forwardc                 C   4   d | j| j}| jD ]}|t|7 }q|d7 }|S )NzResBlock({},{},r>   r   r   r   r   r   r   r   r   r   r%     
   
zResBlock.__str__c                 C   8   d | j| j| j}| jD ]}|t|7 }q|d7 }|S )NzResBlock({}|{},{},r>   r   r   r   r   r   r   r   r   r   r   r'        
zResBlock.__repr__c                 C   r   rL   r   r   r   r   r   r+        
zResBlock.get_output_resolutionc                 K     t |sJ t|}|d usJ d }|td| }|d}|dk r-dt j}n|d| }||d d  }|d}	|	dk sL|d|	 	 sXd }
t
|t|d\}}nEt|d|	 }
||	d d  }|d}|dk sy|d| 	 st
|t|d\}}nt|d| }||d d  }t
|t|d\}}	 	 t|dksJ |d u st|dkrd ||d d  fS t ||
|||d||d d  fS )	Nz	ResBlock(r-   r   r   r   r.   r   r   r   r   r   r   )r   r0   r   r1   r2   r   r   r   r   isdigitr   r   r4   r5   r6   r   r   r7   
the_strider8   r9   r:   first_comma_indexr   r   r   second_comma_indexr   r   r   r<     h   




zResBlock.create_from_strrb   r@   r#   rA   rB   __doc__r
   r    r%   r'   r+   rC   r<   rD   r   r   r   r   r   ~  s    	r   c                       r   )ResBlockProjz~
    ResBlockProj(in_channles, inner_blocks_str). If in_channels is missing, use block_list[0].in_channels as in_channels
    NFc                    s   t t| jdi | || _|| _|| _|st|| _|d u r'|d j	| _	n|| _	|d j
| _
| jd u rAd}| |}|| | _tt| j	| j
d| jt| j
| _d S r   )r	   r   r
   r   r   r   r   r   r   r   r   r+   r   ri   rT   r   r   r   r   r   r
     s$   




zResBlockProj.__init__c                 C   s<   t | jdkr	|S |}| jD ]}||}q|| | }|S r   r   r   r   r   r   r       s   

zResBlockProj.forwardc                 C   r   )NzResBlockProj({},{},r>   r   r   r   r   r   r%   *  r   zResBlockProj.__str__c                 C   r   )NzResBlockProj({}|{},{},r>   r   r   r   r   r   r'   2  r   zResBlockProj.__repr__c                 C   r   rL   r   r   r   r   r   r+   ;  r   z"ResBlockProj.get_output_resolutionc                 K   r   )	NzResBlockProj(r-   r   r   r   r.   r   r   )r   r0   r   r1   r2   r   r   r   r   r   r   r   r4   r   r   r   r   r<   B  r   zResBlockProj.create_from_strrb   r@   r   r   r   r   r   r     s    
	r   c                       rP   )SENFc                    s   t t| jdi | || _|d urtd|| _|| _d| _tdt	t
| j| j | _|s4| jdkr6d S ttdtj| j| jdddddt| jt tj| j| jdddddt| jt | _d S )	Nr   g      ?r   r   )r   r   F)r   r   re   r   rf   rg   r   )r	   r   r
   r   r   r   r   se_ratior   r4   roundse_channelsr   r   rI   ri   rT   ReLUSigmoidrJ   rW   r   r   r   r
   y  sF   


zSE.__init__c                 C   s   |  |}|| S rL   rM   )r   r   se_xr   r   r   r      s   
z
SE.forwardc                 C   rY   )NzSE({})rZ   r$   r   r   r   r%     r[   z
SE.__str__c                 C   r\   )Nz	SE({}|{})r]   r$   r   r   r   r'     r^   zSE.__repr__c                 C   r_   rL   r   r)   r   r   r   r+     r`   zSE.get_output_resolutionc           	      K   r   )NzSE(r-   r   r   r   r   )
r   r0   r   r1   r2   r   r   r   r   r4   ra   r   r   r   r<     r   zSE.create_from_strrb   r@   rO   r   r   r   r   r   w  s    &r   c                   @   s$   e Zd Zedd Zedd ZdS )SwishImplementationc                 C   s   |t | }| | |S rL   )r   sigmoidsave_for_backward)ctxiresultr   r   r   r      s   
zSwishImplementation.forwardc                 C   s,   | j d }t|}||d|d|     S r   )saved_variablesr   r   )r   grad_outputr   	sigmoid_ir   r   r   backward  s   

zSwishImplementation.backwardN)r#   rA   rB   staticmethodr    r  r   r   r   r   r     s
    
r   c                       rP   )SwishNFc                    s<   t t| jdi | || _|d urtd|| _|| _d S )Nr   r   )r	   r  r
   r   r   r   r   rW   r   r   r   r
     s   
zSwish.__init__c                 C   r   rL   )r   applyr   r   r   r   r      rN   zSwish.forwardc                 C   rY   )Nz	Swish({})rZ   r$   r   r   r   r%     r[   zSwish.__str__c                 C   r\   )NzSwish({}|{})r]   r$   r   r   r   r'     r^   zSwish.__repr__c                 C   r_   rL   r   r)   r   r   r   r+     r`   zSwish.get_output_resolutionc           	      K   r   )NzSwish(r-   r   r   r   r   )
r  r0   r   r1   r2   r   r   r   r   r4   ra   r   r   r   r<     r   zSwish.create_from_strrb   r@   rO   r   r   r   r   r    s    r  r   c                 C   s(   t ttttd}| | | t | S )N)r   r   r   r   r   )r   r   r   r   r   updater   )r   this_py_file_netblocks_dictr   r   r   register_netblocks_dict  s   

r  )'r   numpyr   r   torch.nn.functionalr   
functionalr   global_utilsr   r   Moduler   rF   rQ   rc   rn   ry   r   r   r   r   r   r   r   r   r   r   r   r   r   r   autogradFunctionr   r  r   dictr  r   r   r   r   <module>   s~   A38XT(H=<NO( yI0	
