o
    %ݫi                     @   s   d Z ddlZddlmZ ddlmZmZ ddlmZ ee	Z
G dd dejjZG dd	 d	ejjZG d
d dejjZG dd dejjZG dd dejjZG dd dejjZdS )zbLibrary implementing complex-valued recurrent neural networks.

Authors
 * Titouan Parcollet 2020
    N)CLinear)
CBatchNorm
CLayerNorm)
get_loggerc                       sJ   e Zd ZdZ							d fdd		Zd
d ZdddZdd Z  ZS )CLSTMa  This function implements a complex-valued LSTM.

    Input format is (batch, time, fea) or (batch, time, fea, channel).
    In the latter shape, the two last dimensions will be merged:
    (batch, time, fea * channel)

    Arguments
    ---------
    hidden_size : int
        Number of output neurons (i.e, the dimensionality of the output).
        Specified value is in term of complex-valued neurons. Thus, the output
        is 2*hidden_size.
    input_shape : tuple
        The expected shape of the input.
    num_layers : int, optional
        Number of layers to employ in the RNN architecture (default 1).
    bias: bool, optional
        If True, the additive bias b is adopted (default True).
    dropout : float, optional
        It is the dropout factor (must be between 0 and 1) (default 0.0).
    bidirectional : bool, optional
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used (default False).
    return_hidden : bool, optional
        It True, the function returns the last hidden layer.
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the complex-valued weights (default "glorot").
    weight_init : str, optional
        (complex, unitary).
        This parameter defines the initialization procedure of the
        complex-valued weights (default "complex"). "complex" will generate random complex-valued
        weights following the init_criterion and the complex polar form.
        "unitary" will normalize the weights to lie on the unit circle.
        More details in: "Deep Complex Networks", Trabelsi C. et al.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16, 40])
    >>> rnn = CLSTM(hidden_size=16, input_shape=inp_tensor.shape)
    >>> out_tensor = rnn(inp_tensor)
    >>>
    torch.Size([10, 16, 32])
       T        Fglorotcomplexc
           
         s   t    |d | _|| _|| _|| _|| _d| _|| _|| _	|	| _
t|dkr+d| _tt|dd  | _|d | _|  | _d S N   F   Tr   )super__init__hidden_size
num_layersbiasdropoutbidirectionalreshapereturn_hiddeninit_criterionweight_initlentorchprodtensorfea_dim
batch_size_init_layersrnn)
selfr   input_shaper   r   r   r   r   r   r   	__class__ [/home/ubuntu/.local/lib/python3.10/site-packages/speechbrain/nnet/complex_networks/c_RNN.pyr   C   s   


zCLSTM.__init__c                 C   sn   t jg }| j}t| jD ]&}t|| j| j| j| j	| j
| j| jd}|| | j
r1| jd }q| j}q|S )z
        Initializes the layers of the ComplexLSTM.

        Returns
        -------
        rnn : ModuleList
            The list of CLSTM_Layers.
        )r   r   r   r   r   )r   nn
ModuleListr   ranger   CLSTM_Layerr   r   r   r   r   r   appendr!   r    current_dimirnn_layr%   r%   r&   r   c   s$   

zCLSTM._init_layersNc                 C   ^   | j r|jdkr| |jd |jd |jd |jd  }| j||d\}}| jr-||fS |S )as  Returns the output of the CLSTM.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            The hidden layer.

        Returns
        -------
        output : torch.Tensor
            The output tensor.
        hh : torch.Tensor
            If return_hidden, the second tensor is hidden states.
           r   r   r   r   hxr   ndimshape_forward_rnnr   r!   xr3   outputhhr%   r%   r&   forward      
*zCLSTM.forwardc                 C      g }|dur| j r|| j| jd | j}t| jD ]%\}}|dur+|||| d}n||dd}||dddddf  qtj	|dd}| j r^||j
d d |j
d | j}||fS |dd}||fS )aX  Returns the output of the CLSTM.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            The hidden layer.

        Returns
        -------
        x : torch.Tensor
            The output tensor.
        h : torch.Tensor
            The hidden states for each step.
        Nr   r2   r   dimr   r   r   r   r   r   	enumerater    r+   r   stackr6   	transposer!   r9   r3   hr.   r/   r%   r%   r&   r7      s"    zCLSTM._forward_rnn)r   Tr   FFr	   r
   N	__name__
__module____qualname____doc__r   r   r<   r7   __classcell__r%   r%   r#   r&   r      s    3 
!r   c                       sT   e Zd ZdZ				d fdd	Zdd	d
Zdd Zdd Zdd Zdd Z	  Z
S )r*   a  This function implements complex-valued LSTM layer.

    Arguments
    ---------
    input_size : int
        Feature dimensionality of the input tensors (in term of real values).
    hidden_size : int
        Number of output values (in term of real values).
    num_layers : int, optional
        Number of layers to employ in the RNN architecture (default 1).
    batch_size : int
        Batch size of the input tensors.
    dropout : float, optional
        It is the dropout factor (must be between 0 and 1) (default 0.0).
    bidirectional : bool, optional
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used (default False).
    init_criterion : str, optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the complex-valued weights (default "glorot").
    weight_init : str, optional
        (complex, unitary).
        This parameter defines the initialization procedure of the
        complex-valued weights (default "complex"). "complex" will generate random complex-valued
        weights following the init_criterion and the complex polar form.
        "unitary" will normalize the weights to lie on the unit circle.
        More details in: "Deep Complex Networks", Trabelsi C. et al.
    r   Fr	   r
   c	           	         s   t    t|d | _t|| _|| _|| _|| _|| _|| _	t
| j| jd d| j	| jd| _t
| jd | jd d| j	| jd| _| jrK| jd | _| dtd| jd  | | j tjj| jdd| _td	g | _d S )
Nr   r1   Tr"   	n_neuronsr   r   r   h_initr   Fpinplace      ?)r   r   intr   
input_sizer   r   r   r   r   r   wuregister_bufferr   zeros
_init_dropr'   Dropoutdropr   floatdrop_mask_te)	r!   rW   r   r   r   r   r   r   r   r#   r%   r&   r      s8   

zCLSTM_Layer.__init__Nc                 C   s   | j r|d}tj||gdd}| | | |}|dur&| ||}n| || j}| j rG|jddd\}}|d}tj||gdd}|S )a/  Returns the output of the CRNN_layer.

        Arguments
        ---------
        x : torch.Tensor
            Linearly transformed input.
        hx : torch.Tensor
            Hidden layer.

        Returns
        -------
        h : torch.Tensor
            The hidden states for each step.
        r   r   r@   Nr   )	r   flipr   cat_change_batch_sizerX   _complexlstm_cellrQ   chunkr!   r9   r3   x_fliprX   rG   h_fh_br%   r%   r&   r<     s   



zCLSTM_Layer.forwardc              	   C   s   g }| j }| |}t|jd D ]a}|dd|f | | }|dd\}}	}
}}}}}ttj||	gdd}ttj|
|gdd}ttj||gdd}|t	tj||gdd | ||  }|t	| }|
| qtj|dd}|S )7  Returns the hidden states for each time step.

        Arguments
        ---------
        w : torch.Tensor
            Linearly transformed input.
        ht : torch.Tensor
            Hidden layer.

        Returns
        -------
        h : torch.Tensor
            The hidden states for each step.
        r   N   r?   r@   )rQ   _sample_drop_maskr)   r6   rY   re   r   sigmoidrb   tanhr+   rD   )r!   rX   hthiddensct	drop_maskkgatesitritiftrftiotrotictrctiitftotrG   r%   r%   r&   rd   E  s"   
zCLSTM_Layer._complexlstm_cellc              
   C   Z   t jj| jdd| _t dg | _d| _d| _	| 
d| t | j| jd j dS 	zwInitializes the recurrent dropout operation. To speed it up,
        the dropout masks are sampled in advance.
        FrR   rU   i>  r   
drop_masksr   Nr   r'   r]   r   r^   r   r_   r`   N_drop_masksdrop_mask_cntrZ   onesr   datar!   r   r%   r%   r&   r\   p     zCLSTM_Layer._init_dropc                 C      | j r6| j| j | jkr!d| _| tj| j| jd |jdj	| _
| j
| j| j| j  }| j| j | _|S | j|j| _| j}|S z,Selects one of the pre-defined dropout masksr   r   )devicetrainingr   r   r   r^   r   r   r   r   r   r   r`   tor!   rX   rr   r%   r%   r&   rl     $   zCLSTM_Layer._sample_drop_maskc                 C   L   | j |jd kr"|jd | _ | jr$| t| j| jd j| _	dS dS dS   This function changes the batch size when it is different from
        the one detected in the initialization method. This might happen in
        the case of multi-gpu or when we have different batch sizes in train
        and test. We also update the h_int and drop masks.
        r   r   N
r   r6   r   r^   r   r   r   r   r   r   r!   r9   r%   r%   r&   rc        zCLSTM_Layer._change_batch_size)r   Fr	   r
   rH   )rJ   rK   rL   rM   r   r<   rd   r\   rl   rc   rN   r%   r%   r#   r&   r*      s    %
3'+r*   c                       sL   e Zd ZdZ								d fd	d
	Zdd ZdddZdd Z  ZS )CRNNa@  This function implements a vanilla complex-valued RNN.

    Input format is (batch, time, fea) or (batch, time, fea, channel).
    In the latter shape, the two last dimensions will be merged:
    (batch, time, fea * channel)

    Arguments
    ---------
    hidden_size : int
        Number of output neurons (i.e, the dimensionality of the output).
        Specified value is in term of complex-valued neurons. Thus, the output
        is 2*hidden_size.
    input_shape : tuple
        The expected shape of the input.
    nonlinearity : str, optional
        Type of nonlinearity (tanh, relu) (default "tanh").
    num_layers : int, optional
        Number of layers to employ in the RNN architecture (default 1).
    bias : bool, optional
        If True, the additive bias b is adopted (default True).
    dropout : float, optional
        It is the dropout factor (must be between 0 and 1) (default 0.0).
    bidirectional : bool, optional
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used (default False).
    return_hidden : bool, optional
        It True, the function returns the last hidden layer (default False).
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the complex-valued weights (default "glorot").
    weight_init : str, optional
        (complex, unitary).
        This parameter defines the initialization procedure of the
        complex-valued weights (default "complex"). "complex" will generate random complex-valued
        weights following the init_criterion and the complex polar form.
        "unitary" will normalize the weights to lie on the unit circle.
        More details in: "Deep Complex Networks", Trabelsi C. et al.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16, 30])
    >>> rnn = CRNN(hidden_size=16, input_shape=inp_tensor.shape)
    >>> out_tensor = rnn(inp_tensor)
    >>>
    torch.Size([10, 16, 32])
    rn   r   Tr   Fr	   r
   c                    s   t    |d | _|| _|| _|| _|| _|| _d| _|| _	|	| _
|
| _t|dkr.d| _tt|dd  | _|d | _|  | _d S r   )r   r   r   nonlinearityr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    )r!   r   r"   r   r   r   r   r   r   r   r   r#   r%   r&   r     s    


zCRNN.__init__c                 C   sr   t jg }| j}t| jD ](}t|| j| j| j| j	| j
| j| j| jd	}|| | jr3| jd }q| j}q|S )z
        Initializes the layers of the CRNN.

        Returns
        -------
        rnn : ModuleList
            The list of CRNN_Layers.
        )r   r   r   r   r   r   )r   r'   r(   r   r)   r   
CRNN_Layerr   r   r   r   r   r   r   r+   r,   r%   r%   r&   r     s&   	
zCRNN._init_layersNc                 C   r0   )a  Returns the output of the vanilla CRNN.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Hidden layers.

        Returns
        -------
        output : torch.Tensor
            The outputs of the CliGRU.
        hh : torch.Tensor
            If return_hidden, also returns the hidden states for each step.
        r1   r   r   r   r   r2   r4   r8   r%   r%   r&   r<     r=   zCRNN.forwardc                 C   r>   )a_  Returns the output of the vanilla CRNN.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            The hidden layer.

        Returns
        -------
        x : torch.Tensor
            The output tensor.
        h : torch.Tensor
            The hidden states for each step.
        Nr   r2   r?   r   r@   r   rB   rF   r%   r%   r&   r7   =  s"    zCRNN._forward_rnn)rn   r   Tr   FFr	   r
   rH   rI   r%   r%   r#   r&   r     s    5"
"r   c                       sV   e Zd ZdZ					d fdd	Zdd
dZdd Zdd Zdd Zdd Z	  Z
S )r   a  This function implements complex-valued recurrent layer.

    Arguments
    ---------
    input_size : int
        Feature dimensionality of the input tensors (in term of real values).
    hidden_size : int
        Number of output values (in term of real values).
    num_layers : int, optional
        Number of layers to employ in the RNN architecture (default 1).
    batch_size : int
        Batch size of the input tensors.
    dropout : float, optional
        It is the dropout factor (must be between 0 and 1) (default 0.0).
    nonlinearity : str, optional
        Type of nonlinearity (tanh, relu) (default "tanh").
    bidirectional : bool, optional
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used (default False).
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the complex-valued weights (default "glorot").
    weight_init : str, optional
        (complex, unitary).
        This parameter defines the initialization procedure of the
        complex-valued weights (default "complex"). "complex" will generate random complex-valued
        weights following the init_criterion and the complex polar form.
        "unitary" will normalize the weights to lie on the unit circle.
        More details in: "Deep Complex Networks", Trabelsi C. et al.
    r   rn   Fr	   r
   c
           
         s  t    t|d | _t|| _|| _|| _|| _|| _|	| _	t
| j| jd| j	| jd| _t
| jd | jd| j	| jd| _| jrG| jd | _| dtd| jd  | | j tjj| jdd| _tdg | _|dkrytj | _d S tj | _d S )	Nr   FrO   rQ   r   rR   rU   rn   )r   r   rV   r   rW   r   r   r   r   r   r   rX   rY   rZ   r   r[   r\   r'   r]   r^   r   r_   r`   TanhactReLU)
r!   rW   r   r   r   r   r   r   r   r   r#   r%   r&   r     s>   

zCRNN_Layer.__init__Nc                 C   s   | j r|d}tj||gdd}| |}|dur!| ||}n| || j}| j rB|jddd\}}|d}tj||gdd}|S )a%  Returns the output of the CRNN_layer.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            The hidden layer.

        Returns
        -------
        h : torch.Tensor
            The hidden states for each step.
        r   r   r@   Nr   )r   ra   r   rb   rX   _complexrnn_cellrQ   re   rf   r%   r%   r&   r<     s   


zCRNN_Layer.forwardc                 C   sf   g }|  |}t|jd D ]}|dd|f | | }| || }|| qtj|dd}|S )a;  Returns the hidden states for each time step.

        Arguments
        ---------
        w : torch.Tensor
            Linearly transformed input.
        ht : torch.Tensor
            The hidden layer.

        Returns
        -------
        h : torch.Tensor
            The hidden states for each step.
        r   Nr@   )rl   r)   r6   rY   r   r+   r   rD   )r!   rX   ro   rp   rr   rs   atrG   r%   r%   r&   r     s   
zCRNN_Layer._complexrnn_cellc              
   C   r   r   r   r   r%   r%   r&   r\   	  r   zCRNN_Layer._init_dropc                 C   r   r   r   r   r%   r%   r&   rl     r   zCRNN_Layer._sample_drop_maskc                 C   r   r   r   r   r%   r%   r&   rc   2  r   zCRNN_Layer._change_batch_size)r   rn   Fr	   r
   rH   )rJ   rK   rL   rM   r   r<   r   r\   rl   rc   rN   r%   r%   r#   r&   r   g  s    '
9(r   c                       sN   e Zd ZdZ										d fd
d	Zdd ZdddZdd Z  ZS )CLiGRUaN
  This function implements a complex-valued Light GRU (liGRU).

    Ligru is single-gate GRU model based on batch-norm + relu
    activations + recurrent dropout. For more info see:

    "M. Ravanelli, P. Brakel, M. Omologo, Y. Bengio,
    Light Gated Recurrent Units for Speech Recognition,
    in IEEE Transactions on Emerging Topics in Computational Intelligence,
    2018" (https://arxiv.org/abs/1803.10225)

    To speed it up, it is compiled with the torch just-in-time compiler (jit)
    right before using it.

    It accepts in input tensors formatted as (batch, time, fea).
    In the case of 4d inputs like (batch, time, fea, channel) the tensor is
    flattened as (batch, time, fea*channel).

    Arguments
    ---------
    hidden_size : int
        Number of output neurons (i.e, the dimensionality of the output).
        Specified value is in term of complex-valued neurons. Thus, the output
        is 2*hidden_size.
    input_shape : tuple
        The expected size of the input.
    nonlinearity : str
        Type of nonlinearity (tanh, relu).
    normalization : str
        Type of normalization for the ligru model (batchnorm, layernorm).
        Every string different from batchnorm and layernorm will result
        in no normalization.
    num_layers : int
        Number of layers to employ in the RNN architecture.
    bias : bool
        If True, the additive bias b is adopted.
    dropout : float
        It is the dropout factor (must be between 0 and 1).
    bidirectional : bool
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used.
    return_hidden : bool
        If True, the function returns the last hidden layer.
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the complex-valued weights (default "glorot").
    weight_init : str, optional
        (complex, unitary).
        This parameter defines the initialization procedure of the
        complex-valued weights (default "complex"). "complex" will generate random complex-valued
        weights following the init_criterion and the complex polar form.
        "unitary" will normalize the weights to lie on the unit circle.
        More details in: "Deep Complex Networks", Trabelsi C. et al.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16, 30])
    >>> rnn = CLiGRU(input_shape=inp_tensor.shape, hidden_size=16)
    >>> out_tensor = rnn(inp_tensor)
    >>>
    torch.Size([4, 10, 5])
    relu	batchnormr   Tr   Fr	   r
   c                    s   t    |d | _|| _|| _|| _|| _|| _|| _d| _	|	| _
|
| _|| _t|dkr1d| _	tt|dd  | _|d | _|  | _d S r   )r   r   r   r   r   normalizationr   r   r   r   r   r   r   r   r   r   r   r   r   r   r    )r!   r   r"   r   r   r   r   r   r   r   r   r   r#   r%   r&   r     s"   


zCLiGRU.__init__c                 C   sv   t jg }| j}t| jD ]*}t|| j| j| j| j	| j
| j| j| j| jd
}|| | jr5| jd }q| j}q|S )zInitializes the layers of the liGRU.

        Returns
        -------
        rnn : ModuleList
            The list of CLiGRU_Layers.
        )r   r   r   r   r   r   r   )r   r'   r(   r   r)   r   CLiGRU_Layerr   r   r   r   r   r   r   r   r+   r,   r%   r%   r&   r     s(   	
zCLiGRU._init_layersNc                 C   r0   )a  Returns the output of the CliGRU.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Hidden layers.

        Returns
        -------
        output : torch.Tensor
            The outputs of the CliGRU.
        hh : torch.Tensor
            If return_hidden, also returns the hidden states for each step.
        r1   r   r   r   r   r2   )r   r5   r6   _forward_ligrur   r8   r%   r%   r&   r<     s   
*zCLiGRU.forwardc                 C   r>   )aY  Returns the output of the CliGRU.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            The hidden layer.

        Returns
        -------
        x : torch.Tensor
            The output tensor.
        h : torch.Tensor
            The hidden states for each step.
        Nr   r2   r?   r   r@   r   rB   )r!   r9   r3   rG   r.   	ligru_layr%   r%   r&   r     s"    zCLiGRU._forward_ligru)	r   r   r   Tr   FFr	   r
   rH   )	rJ   rK   rL   rM   r   r   r<   r   rN   r%   r%   r#   r&   r   B  s    D"
!r   c                       sX   e Zd ZdZ						d fdd		ZdddZdd Zdd Zdd Zdd Z	  Z
S )r   a  
    This function implements complex-valued Light-Gated Recurrent Unit layer.

    Arguments
    ---------
    input_size : int
        Feature dimensionality of the input tensors.
    hidden_size : int
        Number of output values.
    num_layers : int
        Number of layers to employ in the RNN architecture.
    batch_size : int
        Batch size of the input tensors.
    dropout : float
        It is the dropout factor (must be between 0 and 1).
    nonlinearity : str
        Type of nonlinearity (tanh, relu).
    normalization : str
        Type of normalization (batchnorm, layernorm).
        Every string different from batchnorm and layernorm will result
        in no normalization.
    bidirectional : bool
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used.
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the complex-valued weights (default "glorot").
    weight_init : str, optional
        (complex, unitary).
        This parameter defines the initialization procedure of the
        complex-valued weights (default "complex"). "complex" will generate random complex-valued
        weights following the init_criterion and the complex polar form.
        "unitary" will normalize the weights to lie on the unit circle.
        More details in: "Deep Complex Networks", Trabelsi C. et al.
    r   r   r   Fr	   r
   c                    s  t    t|d | _t|| _|| _|| _|| _|	| _|
| _	|| _
|| _t| j| jd d| j	| jd| _t| jd | jd d| j	| jd| _| jrQ| jd | _d| _| j
dkrgt|d ddd| _d| _n| j
d	kryt|d dd
| _d| _nt|d dd
| _d| _| dtd| jd  | | j tjj| jdd| _tdg | _| jdkrtj | _d S tj | _d S )Nr   FrO   r   r?   g?)rW   rA   momentumT	layernorm)rW   rA   rQ   r   rR   rU   rn   ) r   r   rV   r   rW   r   r   r   r   r   r   r   r   rX   rY   	normalizer   normr   rZ   r   r[   r\   r'   r]   r^   r   r_   r`   r   r   r   )r!   rW   r   r   r   r   r   r   r   r   r   r#   r%   r&   r   5  sX   





zCLiGRU_Layer.__init__Nc           	      C   s   | j r|d}tj||gdd}| | | |}| jrB| ||j	d |j	d  |j	d }||j	d |j	d |j	d }|durM| 
||}n| 
|| j}| j rn|jddd\}}|d}tj||gdd}|S )a*  Returns the output of the Complex liGRU layer.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Hidden layer.

        Returns
        -------
        h : torch.Tensor
            The hidden states for each step.
        r   r   r@   r   N)r   ra   r   rb   rc   rX   r   r   r   r6   _complex_ligru_cellrQ   re   )	r!   r9   r3   rg   rX   w_bnrG   rh   ri   r%   r%   r&   r<     s    


( 
zCLiGRU_Layer.forwardc                 C   s   g }|  |}t|jd D ]F}|dd|f | | }|dd\}}}	}
tj||gdd}tj|	|
gdd}t|}| || }|| d| |  }|	| qtj
|dd}|S )rj   r   Nr1   r?   r@   )rl   r)   r6   rY   re   r   rb   rm   r   r+   rD   )r!   rX   ro   rp   rr   rs   rt   atratiztrztir   zthcandrG   r%   r%   r&   r     s   

z CLiGRU_Layer._complex_ligru_cellc              
   C   r   r   r   r   r%   r%   r&   r\     r   zCLiGRU_Layer._init_dropc                 C   r   r   r   r   r%   r%   r&   rl     r   zCLiGRU_Layer._sample_drop_maskc                 C   sH   | j |jd kr |jd | _ | jr"| t| j| jj| _	dS dS dS )r   r   Nr   r   r%   r%   r&   rc     s   zCLiGRU_Layer._change_batch_size)r   r   r   Fr	   r
   rH   )rJ   rK   rL   rM   r   r<   r   r\   rl   rc   rN   r%   r%   r#   r&   r     s    ,
N-$r   )rM   r   *speechbrain.nnet.complex_networks.c_linearr   1speechbrain.nnet.complex_networks.c_normalizationr   r   speechbrain.utils.loggerr   rJ   loggerr'   Moduler   r*   r   r   r   r   r%   r%   r%   r&   <module>   s"     9 _ ? \ M