o
    eiL                    @   s  d Z ddlZddlmZ ddlmZ ddlZddlmZ ddl	m
Z
 ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZmZ ddlmZmZmZ ddlmZ e e!Z"G dd dej#Z$G dd dej#Z%G dd dej#Z&G dd dej#Z'G dd dej#Z(G dd dej#Z)		ddej#dej*dej*d ej*d!ej*dB d"e+dB d#e+d$ee fd%d&Z,G d'd( d(ej#Z-G d)d* d*ej#Z.G d+d, d,ej#Z/G d-d. d.ej#Z0G d/d0 d0ej#Z1G d1d2 d2ej#Z2G d3d4 d4ej#Z3eG d5d6 d6e
Z4G d7d8 d8ej#Z5		9	dd:ej*d;e+d<e6dB d=e7d>e8f
d?d@Z9		dd:ej*dAe6e8B d<e6dB d>e8fdBdCZ:G dDdE dEej#Z;G dFdG dGej#Z<G dHdI dIej#Z=G dJdK dKej#Z>G dLdM dMej#Z?eedNdOG dPdQ dQeZ@G dRdS dSe4ZAeedTdOG dUdV dVeZBedWdOG dXdY dYe4ZCeedZdOG d[d\ d\eZDed]dOG d^d_ d_e4ZEeed`dOG dadb dbeZFeedcdOG ddde deeZGeedcdOG dfdg dgeZHdhejIjJdiej*djej*fdkdlZKddmej*dnej*dB djej*fdodpZLG dqdr dre4ZMeedsdOG dtdu dueZNG dvdw dwe4ZOeedxdOG dydz dzeZPG d{d| d|ej#ZQed}dOG d~d de4ZRg dZSdS )zPyTorch PatchTSMixer model.    N)Callable)	dataclass)PreTrainedModel)ModelOutput   )initialization)FlashAttentionKwargs)ALL_ATTENTION_FUNCTIONS)Unpack)NegativeBinomialOutputNormalOutputStudentTOutput)TransformersKwargsauto_docstringlogging   )PatchTSMixerConfigc                       s2   e Zd ZdZdedef fddZdd Z  ZS )PatchTSMixerGatedAttentionz
    Module that applies gated attention to input data.

    Args:
        in_size (`int`): The input size.
        out_size (`int`): The output size.
    in_sizeout_sizec                    s*   t    t||| _tjdd| _d S )Ndim)super__init__nnLinear
attn_layerSoftmaxattn_softmax)selfr   r   	__class__ t/home/ubuntu/transcripts/venv/lib/python3.10/site-packages/transformers/models/patchtsmixer/modeling_patchtsmixer.pyr   /   s   
z#PatchTSMixerGatedAttention.__init__c                 C   s   |  | |}|| }|S N)r   r   )r    inputsattn_weightr#   r#   r$   forward4   s   z"PatchTSMixerGatedAttention.forward)__name__
__module____qualname____doc__intr   r(   __classcell__r#   r#   r!   r$   r   &   s    r   c                       6   e Zd ZdZdef fddZdejfddZ  Z	S )PatchTSMixerBatchNormzP
    Compute batch normalization over the sequence length (time) dimension.
    configc                    s"   t    tj|j|jd| _d S )Neps)r   r   r   BatchNorm1dd_modelnorm_eps	batchnormr    r1   r!   r#   r$   r   @   s   
zPatchTSMixerBatchNorm.__init__r&   c                 C   s"   | dd}| |}| ddS )a  
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        r      )	transposer7   )r    r&   outputr#   r#   r$   r(   D   s   
zPatchTSMixerBatchNorm.forward
r)   r*   r+   r,   r   r   torchTensorr(   r.   r#   r#   r!   r$   r0   ;   s    r0   c                       sN   e Zd ZdZdef fddZededejfddZ	de
jfd	d
Z  ZS )PatchTSMixerPositionalEncodingz'
    Class for positional encoding
    r1   c                    s<   t    |jr| || _d S tt|j	|j
| _d S r%   )r   r   use_positional_encoding_init_peposition_encr   	Parameterr=   zerosnum_patchesr5   r8   r!   r#   r$   r   V   s   
z'PatchTSMixerPositionalEncoding.__init__returnc                 C   s   | j dkrtjt| j| jdd}|S | j dkrvt| j| j}td| j	d}t
td| jdtd| j   }t|| |d d dd df< t|| |d d dd df< ||  }|| d	  }tj|d
d}|S t| j  d)NrandomTrequires_gradsincosr   r   r9   g     @
   FzN is not a valid positional encoder. Available types are 'random' and 'sincos'.)positional_encoding_typer   rC   r=   randnrE   r5   rD   arange	unsqueezeexpmathlogsincosmeanstd
ValueError)r1   rB   positiondiv_termr#   r#   r$   rA   ^   s    

(  
z'PatchTSMixerPositionalEncoding._init_pepatch_inputc                 C   s   || j  }|S r%   )rB   )r    rZ   hidden_stater#   r#   r$   r(   r   s   
z&PatchTSMixerPositionalEncoding.forward)r)   r*   r+   r,   r   r   staticmethodr   rC   rA   r=   r>   r(   r.   r#   r#   r!   r$   r?   Q   s    r?   c                       r/   )PatchTSMixerNormLayerzeNormalization block

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    sF   t    |j| _d|j v rt|| _d S tj|j|j	d| _d S )Nbatchr2   )
r   r   norm_mlplowerr0   normr   	LayerNormr5   r6   r8   r!   r#   r$   r      s
   
zPatchTSMixerNormLayer.__init__r&   c                 C   sf   d| j  v r,t||jd |jd  |jd |jd f}| |}t||j}|S | |}|S )a  
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                Input to the normalization layer.
        Returns:
            `torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`
        r^   r   r   r9   r   )r_   r`   r=   reshapeshapera   )r    r&   inputs_reshapedr#   r#   r$   r(      s   


zPatchTSMixerNormLayer.forwardr<   r#   r#   r!   r$   r]   x   s    
r]   c                       s,   e Zd Z fddZdejfddZ  ZS )PatchTSMixerMLPc                    sP   t    ||j }t||| _t|j| _t||| _	t|j| _
d S r%   )r   r   expansion_factorr   r   fc1Dropoutdropoutdropout1fc2dropout2)r    in_featuresout_featuresr1   
num_hiddenr!   r#   r$   r      s   

zPatchTSMixerMLP.__init__r&   c                 C   s0   |  tj| |}| |}| |}|S )z
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                Input to the MLP layer.
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        )rk   r   
functionalgelurh   rl   rm   )r    r&   r#   r#   r$   r(      s   

zPatchTSMixerMLP.forward)r)   r*   r+   r   r=   r>   r(   r.   r#   r#   r!   r$   rf      s    rf   c                       r/   )$PatchTSMixerChannelFeatureMixerBlockzThis module mixes the features in the channel dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    P   t    t|| _|j| _t|j|j|d| _|jr&t|j|jd| _	d S d S Nrn   ro   r1   r   r   )
r   r   r]   ra   
gated_attnrf   num_input_channelsmlpr   gating_blockr8   r!   r#   r$   r      s   

z-PatchTSMixerChannelFeatureMixerBlock.__init__r&   c                 C   sT   |}|  |}|dddd}| jr| |}| |}|dddd}|| }|S )z
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                input to the MLP layer
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        r   r   r9   r   )ra   permuterx   r{   rz   )r    r&   residualoutr#   r#   r$   r(      s   


z,PatchTSMixerChannelFeatureMixerBlock.forwardr<   r#   r#   r!   r$   rs      s    rs           modulequerykeyvalueattention_maskscalingrj   kwargsc           
      K   s   |d u r| dd }t||dd| }|d ur|| }tjj|dd}tjj||| jd}t||}	|	dd	 }	|	|fS )Nr         r9   r   r   )ptrainingr   )
sizer=   matmulr:   r   rq   softmaxrj   r   
contiguous)
r   r   r   r   r   r   rj   r   attn_weightsattn_outputr#   r#   r$   eager_attention_forward   s   
r   c                       s   e Zd ZdZ					ddededed	ed
edededB f fddZ			dde	j
de	j
dB de	j
dB dedB dee dee	j
e	j
dB ee	j
 dB f fddZ  ZS )PatchTSMixerAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN	embed_dim	num_headsrj   
is_decoderbias	is_causalr1   c                    s   t    || _|| _|| _|| | _|| _| j| | jkr*td| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: ).r   )r   )r   r   r   r   rj   head_dimr1   rW   r   r   r   r   r   k_projv_projq_projout_proj)r    r   r   rj   r   r   r   r1   r!   r#   r$   r     s&   



zPatchTSMixerAttention.__init__hidden_stateskey_value_statesr   output_attentionsr   rF   c                 K   s  |du}|j dd \}}|r|j d n|}	||d| jf}
||	d| jf}| |j|
 dd}|r4|n|}| |j| dd}| |j| dd}t| j	j
t}|| ||||f| jsbdn| j| j|d|\}}|||d }| |}||dfS )z#Input shape: Batch x Time x ChannelNr   r   r9   r   )rj   r   r   )rd   r   r   viewr:   r   r   r	   get_interfacer1   _attn_implementationr   r   rj   r   rc   r   r   )r    r   r   r   r   r   is_cross_attentionbsztgt_lensrc_lenq_input_shapekv_input_shapequery_statescurrent_states
key_statesvalue_statesattention_interfacer   r   r#   r#   r$   r(   0  s8   	


zPatchTSMixerAttention.forward)r   FTFN)NNF)r)   r*   r+   r,   r-   floatboolr   r   r=   r>   r
   r   tupler(   r.   r#   r#   r!   r$   r     sL    "	r   c                       .   e Zd ZdZdef fddZdd Z  ZS )PatchMixerBlockzxThis module mixes the patch dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    s   t    t|| _|j| _|j| _t|j|j|d| _|jr(t	|j|jd| _
|jr>t|j|j|j|d| _t|| _d S d S )Nrv   rw   )r   r   rj   r1   )r   r   r]   ra   	self_attnrx   rf   rE   rz   r   r{   r   r5   self_attn_headsrj   self_attn_layer	norm_attnr8   r!   r#   r$   r   l  s(   

zPatchMixerBlock.__init__c                 C   s   |}|  |}| jr,|j\}}}}||| ||}| j|dd\}}	}	|||||}|dd}| |}| jr?| |}|dd}| jrO| 	|| }|| }
|
S )z
        Args:
            hidden_state (`torch.Tensor`): Input tensor.

        Returns:
            `torch.Tensor`: Transformed tensor.
        F)r   r9   r   )
ra   r   rd   rc   r   r:   rz   rx   r{   r   )r    r[   r}   
batch_sizen_varsrE   r5   hidden_state_reshapedx_attn_r~   r#   r#   r$   r(     s    


zPatchMixerBlock.forwardr)   r*   r+   r,   r   r   r(   r.   r#   r#   r!   r$   r   d  s    r   c                       r/   )FeatureMixerBlockzThis module mixes the hidden feature dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    r1   c                    rt   ru   )
r   r   r]   ra   rx   rf   r5   rz   r   r{   r8   r!   r#   r$   r     s   

zFeatureMixerBlock.__init__hiddenc                 C   s4   |}|  |}| |}| jr| |}|| }|S )
        Args:
            hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
                Input tensor to the layer.

        Returns:
            `torch.Tensor`: Transformed tensor.
        )ra   rz   rx   r{   )r    r   r}   r~   r#   r#   r$   r(     s   	


zFeatureMixerBlock.forwardr<   r#   r#   r!   r$   r     s    r   c                       r/   )PatchTSMixerLayerz
    The `PatchTSMixer` layer that does all three kinds of mixing.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    r1   c                    sH   t    t|d| _t|d| _|j| _|jdkr"t|d| _d S d S )Nr1   mix_channel)	r   r   r   patch_mixerr   feature_mixermoders   channel_feature_mixerr8   r!   r#   r$   r     s   

zPatchTSMixerLayer.__init__r   c                 C   s,   | j dkr
| |}| |}| |}|S )r   r   )r   r   r   r   )r    r   r#   r#   r$   r(     s
   
	


zPatchTSMixerLayer.forwardr<   r#   r#   r!   r$   r     s    	r   c                       s6   e Zd ZdZdef fddZd	defddZ  ZS )
PatchTSMixerBlockzThe main computing framework of the `PatchTSMixer` model.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    s2   t     j}t fddt|D | _d S )Nc                    s   g | ]}t  d qS )r   )r   .0r   r   r#   r$   
<listcomp>
  s    z.PatchTSMixerBlock.__init__.<locals>.<listcomp>)r   r   
num_layersr   
ModuleListrangemixers)r    r1   r   r!   r   r$   r     s   
"zPatchTSMixerBlock.__init__Foutput_hidden_statesc                 C   s>   g }|}| j D ]}||}|r|| q|r||fS |dfS )as  
        Args:
            hidden_state (`torch.Tensor`): The input tensor.
            output_hidden_states (`bool`, *optional*, defaults to False.):
                Whether to output the hidden states as well.

        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`.
        N)r   append)r    r[   r   all_hidden_states	embeddingmodr#   r#   r$   r(     s   

zPatchTSMixerBlock.forwardF)	r)   r*   r+   r,   r   r   r   r(   r.   r#   r#   r!   r$   r     s    r   c                       0   e Zd ZdZddef fddZdd Z  ZS )	PatchTSMixerForPredictionHeadzqPrediction Head for Forecasting

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    Nr1   c                    s|   t    |j| _| jd ur| j  t|j| _|d u r+t|j	|j
 |j| _n
||j	|j
 | _tjdd| _d S )N	start_dim)r   r   prediction_channel_indicessortr   ri   head_dropoutdropout_layerr   rE   r5   prediction_lengthbase_forecast_blockget_parameter_projectionFlattenflatten)r    r1   distribution_outputr!   r#   r$   r   .  s   



z&PatchTSMixerForPredictionHead.__init__c                    s     |} |} |}t|trtdd |D }n|dd} jdurBt|tr;t fdd|D }|S |d jf }|S )ar  

        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
                or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.

        c                 s   s    | ]	}| d dV  qdS )r   r   N)r:   r   zr#   r#   r$   	<genexpr>Q  s    z8PatchTSMixerForPredictionHead.forward.<locals>.<genexpr>r   r   Nc                 3   s    | ]
}|d  j f V  qdS ).N)r   r   r    r#   r$   r   W  s    .)r   r   r   
isinstancer   r:   r   r    hidden_featuresforecastr#   r   r$   r(   @  s   





z%PatchTSMixerForPredictionHead.forwardr%   r   r#   r#   r!   r$   r   &  s    r   c                       r   )	PatchTSMixerLinearHeadzLinear head for Classification and Regression.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    Nr1   c                    s   t    |j| _|j| _|jd u r|j}nd}|| _|d u r.t|j|j	 | |j
| _n||j|j	 | | _|jd u rGtjdd| _ntjdd| _t|j| _d S )Nr   r   r   )r   r   head_aggregationoutput_rangerE   r   r   r   r5   ry   num_targets
projectionr   r   r   ri   r   rj   )r    r1   r   
mul_factorr!   r#   r$   r   f  s&   


zPatchTSMixerLinearHead.__init__c                 C   s   | dd}| jdkr|d }n| jdkr|jddj}n| jdkr(|jdd}| jr0| |}| |}| |}| jdu rX| j	durXt
|| j	d	 | j	d
   | j	d
  }|S )ai  
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x num_targets)`.
        r   r   use_last).r   max_poolr   avg_poolNr   r   )r:   r   maxvaluesrU   r   rj   r   r   r   r=   sigmoid)r    r   r#   r#   r$   r(     s   






&zPatchTSMixerLinearHead.forwardr%   r   r#   r#   r!   r$   r   ^  s    r   c                   @   s6   e Zd ZU eed< dZdZdZdZe	
 dd ZdS )	PatchTSMixerPreTrainedModelr1   modelpast_values)timeFc                 C   s  t |tr| jjdkrtj|jddd dS dS t |tjtj	frKt
|j t|j t|dddurIt
|j t|j t
|j dS dS t |tr`t
|jj t|jj dS t |tjr~tj|jd| jjd |jdurt
|j dS dS dS )zInitialize weightsrG   r   g?)rU   rV   running_meanN)r   r?   r1   rL   initnormal_rB   r   rb   r4   zeros_r   ones_weightgetattrr   running_varnum_batches_trackedr0   r7   r   init_std)r    r   r#   r#   r$   _init_weights  s*   


z)PatchTSMixerPreTrainedModel._init_weightsN)r)   r*   r+   r   __annotations__base_model_prefixmain_input_nameinput_modalitiessupports_gradient_checkpointingr=   no_gradr  r#   r#   r#   r$   r     s   
 r   c                       r   )PatchTSMixerPretrainHeadzcPretraining head.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    s.   t    t|j| _t|j|j| _	d S r%   )
r   r   r   ri   r   r   r   r5   patch_lengthbase_pt_blockr8   r!   r#   r$   r     s   
z!PatchTSMixerPretrainHead.__init__c                 C   s   |  |}| |}|S )a  
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`.
        )r   r  r   r#   r#   r$   r(     s   

z PatchTSMixerPretrainHead.forwardr   r#   r#   r!   r$   r    s    r  Fr&   
mask_ratiounmasked_channel_indiceschannel_consistent_masking
mask_valuec                 C   s*  |dk s|dkrt d| d| j\}}}}| j}	t|d|  }
|r5tj|d||	d}|d|d}n	tj||||	d}tj||||	d}d|ddddd|
f< tj|dd}tj|dd}tj	|d|d	}|
dddd|}|durd|dd|ddddf< | | |}||d
 fS )a  random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
            across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
        n]
    r   r   zMask ratio z has to be between 0 and 1.deviceNr   r   )r   index.r   )rW   rd   r  r-   r=   randrepeatonesargsortgatherrO   masked_fillr   )r&   r  r  r  r  r   num_channelssequence_lengthnum_featuresr  len_keepnoisemaskids_shuffleids_restoreinputs_maskr#   r#   r$   random_masking  s&   r(  num_forecast_mask_patchesc                 C   s  t |tr|g}dd |D }| j\}}}}tj|||| jd}	g }
d}t|}t||D ](\}}|dks9||krAtd| dt|| | }|
	|||g ||7 }q-t
|
dd d	}
||k rq|
d d
 ||  |
d d
< n||kr|
d d
 ||  |
d d
< d}|
D ]\}}}|| }d|	||dd| df< |}qt|	jd }|	| }	|	dddd|}	|durd|	dd|ddddf< | |	 |}||	d fS )a  Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
        num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
    c                 S   s   g | ]}d qS )r   r#   r   r#   r#   r$   r   9  s    z$forecast_masking.<locals>.<listcomp>r  r   znum_forecast_mask_patches z6 should be greater than 0 and less than total patches.c                 S   s   | d S Nr9   r#   )xr#   r#   r$   <lambda>K  s    z"forecast_masking.<locals>.<lambda>)r   r9   r   r   Nr  )r   r-   rd   r=   rD   r  sumziprW   r   sortedrandpermrO   r  r  r   )r&   r)  r  r  forecast_mask_ratiosr   r  r   r!  r$  t_listtotal_lengthtotal_ratior  ratiotemp_lenbatch1	patch_lenr   batch2permr'  r#   r#   r$   forecast_masking  sB   


r;  c                       r/   )PatchTSMixerPatchifyz
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    r1   c                    s   t    |j| _|j| _|j| _| j| jkr$td| j d| j dt| j| j| j | j d | _| j| j| jd   }| j| | _	d S )NzSequence length (z+) has to be greater than the patch length ()r   )
r   r   context_lengthr   r  patch_striderW   r   rE   sequence_start)r    r1   new_sequence_lengthr!   r#   r$   r   l  s   
 zPatchTSMixerPatchify.__init__r   c                 C   sp   |j d }|| jkrtd| d| j d|dd| jdddf }|jd| j| jd}|dd }|S )a!  
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        r   zInput sequence length (z%) doesn't match model configuration (r   N)	dimensionr   stepr   )	rd   r   rW   r@  unfoldr  r?  r:   r   )r    r   r   r;   r#   r#   r$   r(   }  s   
	
zPatchTSMixerPatchify.forwardr<   r#   r#   r!   r$   r<  d  s    r<  c                       r/   )PatchTSMixerMaskinga  
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSMixerConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    r1   c                    sX   t    |j| _|j| _|j| _|j| _|j| _|j| _| jd ur*t| j| _d S d S r%   )	r   r   random_mask_ratior  	mask_typer)  r  r  r/  r8   r!   r#   r$   r     s   

zPatchTSMixerMasking.__init__rZ   c                 C   sr   | j dkrt|| j| j| j| jd\}}n| j dkr(t|| j| j| jd\}}n	td| j  d|	 }||fS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        rG   )r&   r  r  r  r  r   )r&   r)  r  r  zInvalid mask type .)
rG  r(  rF  r  r  r  r;  r)  rW   r   )r    rZ   masked_inputr$  r#   r#   r$   r(     s$   

zPatchTSMixerMasking.forwardr<   r#   r#   r!   r$   rE    s    rE  c                	       P   e Zd ZdZdef fddZdejdejdeejejejf fdd	Z	  Z
S )
PatchTSMixerStdScalerz
    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
    subtracting from the mean and dividing by the standard deviation.
    r1   c                    sV   t    t|dr|jnd| _t|dr|jnd| _t|dr&|j| _d S d| _d S )Nscaling_dimr   keepdimTminimum_scalegh㈵>)r   r   hasattrrL  r   rM  rN  r8   r!   r#   r$   r     s   
 zPatchTSMixerStdScaler.__init__dataobserved_indicatorrF   c                 C   sz   |j | j| jd}|d}|| j | j| jd| }|| | d j | j| jd| }t|| j }|| | ||fS )C  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        rM        ?r9   )r-  r   rM  	clamp_minr=   sqrtrN  )r    rP  rQ  denominatorlocvariancescaler#   r#   r$   r(     s   
"zPatchTSMixerStdScaler.forwardr)   r*   r+   r,   r   r   r=   r>   r   r(   r.   r#   r#   r!   r$   rK    s    rK  c                	       rJ  )
PatchTSMixerMeanScalerz
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    r1   c                    sl   t    t|dr|jnd| _t|dr|jnd| _t|dr#|jnd| _t|dr1|j| _d S d | _d S )NrL  r   rM  TrN  绽|=default_scale)r   r   rO  rL  r   rM  rN  r^  r8   r!   r#   r$   r     s
   
 zPatchTSMixerMeanScaler.__init__rP  rQ  rF   c           
      C   s   ||   j| jdd}|j| jdd}|tj|dd }| jdu r:|jdd}tj|ddd}t|| }n| jt| }t|dk||}tj|| j	d}|| }	| j
sa|j| jd}|	t||fS )rR  TrS  r   minNr   r   )absr-  r   r=   clampr^  squeeze	ones_likewhererN  rM  
zeros_like)
r    rP  rQ  ts_sumnum_observedrZ  	batch_sumbatch_observationsr^  scaled_datar#   r#   r$   r(     s   
zPatchTSMixerMeanScaler.forwardr[  r#   r#   r!   r$   r\    s    r\  c                
       sX   e Zd ZdZdef fddZ	ddejdejdB deejejejf fd	d
Z	  Z
S )PatchTSMixerNOPScalerz|
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    r1   c                    s@   t    t|dr|jnd| _t|dr|j| _d S d| _d S )NrL  r   rM  T)r   r   rO  rL  r   rM  r8   r!   r#   r$   r   2  s   
 zPatchTSMixerNOPScaler.__init__NrP  rQ  rF   c                 C   sB   t j|ddj| j| jd}t j|ddj| j| jd}|||fS )a  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        FrH   )r   rM  )r=   rd  rU   r   rM  rf  )r    rP  rQ  rZ  rX  r#   r#   r$   r(   7  s   
zPatchTSMixerNOPScaler.forwardr%   r[  r#   r#   r!   r$   rl  -  s    rl  zS
    Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.
    )custom_introc                   @   s:   e Zd ZU dZdZejdB ed< dZe	ej dB ed< dS )PatchTSMixerEncoderOutputa-  
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
        Hidden-state at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    Nlast_hidden_stater   )
r)   r*   r+   r,   ro  r=   FloatTensorr  r   r   r#   r#   r#   r$   rn  H  s   
 rn  c                       sX   e Zd ZdZdef fddZe		ddejde	dB d	e	dB d
e
eB fddZ  ZS )PatchTSMixerEncoderz
    Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    sb   t  | |j| _t|j|j| _|jrt	|d| _
nd | _
t|d| _|jr/|   d S d S )Nr   )r   r   use_return_dictr   r   r  r5   patcherr@   r?   positional_encoderr   mlp_mixer_encoder	post_initr8   r!   r#   r$   r   c  s   zPatchTSMixerEncoder.__init__FNr   r   return_dictrF   c                 K   sh   |dur|n| j }| |}| jdur| |}| j||d\}}|s.tdd ||fD S t||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to
            predict the masked portion. For a forecasting task, this denotes the history/past time series values.
            Similarly, for classification or regression tasks, it denotes the appropriate context values of the
            time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
            it is greater than 1.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
        N)r   c                 s       | ]}|V  qd S r%   r#   r   vr#   r#   r$   r     
    
z.PatchTSMixerEncoder.forward.<locals>.<genexpr>)ro  r   )rr  rs  rt  ru  r   rn  )r    r   r   rw  r   patchesro  r   r#   r#   r$   r(   s  s   


zPatchTSMixerEncoder.forward)FN)r)   r*   r+   r,   r   r   r   r=   r>   r   r   rn  r(   r.   r#   r#   r!   r$   rq  Z  s    rq  zG
    Base class for model's outputs, with potential hidden states.
    c                   @   s   e Zd ZU dZdZejdB ed< dZe	ej dB ed< dZ
ejdB ed< dZejdB ed< dZejdB ed< dZejdB ed< dS )	PatchTSMixerModelOutputa  
    last_hidden_state (`torch.FloatTensor`  of shape `(batch_size, num_channels, num_patches, d_model)`):
        Hidden-state at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Patched input data to the model.
    mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
        Bool Tensor indicating True in masked patches and False otherwise.
    loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin
        enabled.
    scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin
        enabled.
    Nro  r   rZ   r$  rX  rZ  )r)   r*   r+   r,   ro  r=   rp  r  r   r   rZ   r$  rX  rZ  r#   r#   r#   r$   r}    s   
 r}  z=
    The PatchTSMixer Model for time-series forecasting.
    c                       sb   e Zd Zddedef fddZe			ddejdejdB d	edB d
edB de	f
ddZ
  ZS )PatchTSMixerModelFr1   
mask_inputc                    s   t  | |j| _t|| _t|| _|du rt|| _nd| _|j	dkr,t
|| _n|j	dks6|j	du r<t|| _nt|| _|jrJ|   dS dS )z
        mask_input (bool, *optional*, defaults to `False`):
            Whether to mask the input using the [`PatchTSMixerMasking`] module.
        TNrU   rV   )r   r   rr  rq  encoderr<  patchingrE  maskingr   r\  scalerrK  rl  rv  )r    r1   r  r!   r#   r$   r     s   



zPatchTSMixerModel.__init__Nr   observed_maskr   rw  rF   c                 K   s   |dur|n| j }d}|du rt|}| ||\}}}	| |}
|
}| jdur0| |
\}}| j|||d}t|trAt	| }|sTtdd |j
|j|
|||	fD S t|j
|j|
|||	dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        Nr   rw  c                 s   rx  r%   r#   ry  r#   r#   r$   r     r{  z,PatchTSMixerModel.forward.<locals>.<genexpr>)ro  r   rZ   r$  rX  rZ  )rr  r=   rd  r  r  r  r  r   r   rn  ro  r   r}  )r    r   r  r   rw  r   r$  scaled_past_valuesrX  rZ  	patched_x	enc_inputencoder_outputr#   r#   r$   r(     sD   



zPatchTSMixerModel.forwardr   )NFN)r)   r*   r+   r   r   r   r   r=   r>   r}  r(   r.   r#   r#   r!   r$   r~    s"    r~  z>
    Output type of [`PatchTSMixerForPreTrainingOutput`].
    c                   @   ^   e Zd ZU dZdZejdB ed< dZejdB ed< dZ	ejdB ed< dZ
eej dB ed< dS ) PatchTSMixerForPreTrainingOutputa@  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
        Prediction output from the pretrain head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    Nlossprediction_outputsro  r   r)   r*   r+   r,   r  r=   rp  r  r  ro  r   r   r#   r#   r#   r$   r  '     
 r  z.
    `PatchTSMixer` for mask pretraining.
    c                       sb   e Zd Zdef fddZe				ddejdejdB d	edB d
ededB de	fddZ
  ZS )PatchTSMixerForPretrainingr1   c                    sL   t  | t|dd| _t|d| _|j| _|j| _|jr$|   d S d S )NT)r  r   )	r   r   r~  r   r  headmasked_lossrr  rv  r8   r!   r#   r$   r   E  s   z#PatchTSMixerForPretraining.__init__NFTr   r  r   return_lossrw  rF   c                 K   s   |dur|n| j }| jdu rtjjdd}ntjjdd}| j||||d}t|tr/t| }| 	|j
}	|du r@||	|j}
nd}
| jdu r]|
dur]|
jdd|j  |j d	  }
|sntd
d |
|	|j
|jfD S t|
|	|j
|jdS )aT  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.
        NTnone	reductionrU   r  r   rw  r   r   r]  c                 s   rx  r%   r#   ry  r#   r#   r$   r     r{  z5PatchTSMixerForPretraining.forward.<locals>.<genexpr>r  r  ro  r   )rr  r  r=   r   MSELossr   r   r   r}  r  ro  rZ   rU   r$  r-  r   r  )r    r   r  r   r  rw  r   r  model_outputx_hatloss_valr#   r#   r$   r(   P  s@   

$
z"PatchTSMixerForPretraining.forwardNFTN)r)   r*   r+   r   r   r   r=   r>   r   r  r(   r.   r#   r#   r!   r$   r  ?  s(    r  z=
    Output type of [`PatchTSMixerForPredictionOutput`].
    c                   @   s   e Zd ZU dZdZejdB ed< dZejdB ed< dZ	ejdB ed< dZ
eej dB ed< dZejdB ed< dZejdB ed< dS )	PatchTSMixerForPredictionOutputaD  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
        Prediction output from the forecast head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
        Input mean
    scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
        Input std dev
    Nr  r  ro  r   rX  rZ  )r)   r*   r+   r,   r  r=   rp  r  r  ro  r   r   rX  rZ  r#   r#   r#   r$   r    s   
 r  z
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.
    c                   @   $   e Zd ZU dZdZejdB ed< dS )"SamplePatchTSMixerPredictionOutput
    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
        Sampled values from the chosen distribution.
    N	sequencesr)   r*   r+   r,   r  r=   rp  r  r#   r#   r#   r$   r       
 r  c                   @   r  )"SamplePatchTSMixerRegressionOutputr  Nr  r  r#   r#   r#   r$   r    r  r  inputtargetrF   c                 C   s   |  | S )zc
    Computes the negative log likelihood loss from input distribution with respect to target.
    )log_prob)r  r  r#   r#   r$   nll  s   r  input_tensorweightsc                 C   sr   |dur3t |dk| | t | }t j|r|j|dn| dd}|r-|j|d| S | | S | j|dS )aj  
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    Nr   r   rT  r_  )r=   re  rf  rb  r-  rU   )r  r  r   weighted_tensorsum_weightsr#   r#   r$   weighted_average  s
   " r  c                       s   e Zd ZdZdef fddZe					ddejd	ejdB d
ejdB de	dB de	de	dB de
fddZe 	ddejd	ejdB defddZ  ZS )PatchTSMixerForPredictionz
    `PatchTSMixer` for forecasting application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r1   c                    s   t  | |j| _|j| _|j| _|j| _|jdkrd | _n"|j}tt	t
d}||j}|d ur9||d| _ntd|j t|| _t|| jd| _|jrW|   d S d S )Nmse	student_tnormalnegative_binomialr   Unknown distribution output r1   r   )r   r   r  rr  r   num_parallel_samplesr   r   r   r   r   getrW   r~  r   r   r  rv  )r    r1   r   distribution_output_mapoutput_classr!   r#   r$   r     s0   

z"PatchTSMixerForPrediction.__init__NFTr   r  future_valuesr   r  rw  rF   c                 K   s  | j dkrtjdd}n| j dkrt}ntd|dur|n| j}| j||||d}	t|	tr3t	|	 }	| 
|	j}
d}| jdur| jro| jj|
|	jd| jf |	jd| jf d	}|durn|d
u rn|||d| jf }t|}nZ|
|	jd| jf  |	jd| jf  }
|dur|d
u r||
|d| jf }n5| jr| jj|
|	j|	jd	}|dur|d
u r|||}t|}n|
|	j |	j }
|dur|d
u r||
|}| jdur|	jd| jf }|	jd| jf }n|	j}|	j}|stdd ||
|	j|	j||fD S t||
|	j|	j||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,:
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target values of the time series, that serve as labels for the model. The `future_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.
        r  rU   r  r  2Invalid loss function: Allowed values: mse and nllNr  .rX  rZ  Tc                 s   rx  r%   r#   ry  r#   r#   r$   r     r{  z4PatchTSMixerForPrediction.forward.<locals>.<genexpr>)r  r  ro  r   rX  rZ  )r  r   r  r  rW   rr  r   r   r   r}  r  ro  r   r   distributionrX  rZ  r  r   r  )r    r   r  r  r   r  rw  r   r  r  y_hatr  r  rX  rZ  r#   r#   r$   r(   $  s   
%






z!PatchTSMixerForPrediction.forwardc                    s\   | j }| |d|dd}| jj|j|j|jd  fddt|D }tj|dd}t	|d	S )
a  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.

            observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, prediction_length, num_input_channels)`.
        NF)r   r  r  r   r  c                       g | ]}   qS r#   sampler   r  r#   r$   r     s    z6PatchTSMixerForPrediction.generate.<locals>.<listcomp>r   r   r  )
r  r   r  r  rX  rZ  r   r=   stackr  )r    r   r  r  outputssamplesr#   r  r$   generate  s   	
z"PatchTSMixerForPrediction.generate)NNFTNr%   )r)   r*   r+   r,   r   r   r   r=   r>   r   r  r(   r  r  r  r.   r#   r#   r!   r$   r    sB     	zr  zK
    Output type of [`PatchTSMixerForTimeSeriesClassificationOutput`].
    c                   @   r  )-PatchTSMixerForTimeSeriesClassificationOutputaP  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
        Prediction output from the classification head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nr  r  ro  r   r  r#   r#   r#   r$   r    r  r  c                       sf   e Zd ZdZdef fddZe				ddejd	ejdB d
e	dB de	de	dB de
fddZ  ZS )'PatchTSMixerForTimeSeriesClassificationz
    `PatchTSMixer` for classification application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r1   c                    sd   t  | t|| _t|d| _|j| _|jdv r$t|j	|j
d| _nd | _|jr0|   d S d S )Nr   rV   rU   Tr5   rE   )r   r   r~  r   r   r  rr  r   InjectScalerStatistics4Dr5   rE   inject_scalerv  r8   r!   r#   r$   r     s   

z0PatchTSMixerForTimeSeriesClassification.__init__NFTr   target_valuesr   r  rw  rF   c                 K   s   t j }|dur|n| j}| j|||d}t|trt| }| jdur0| j|j	|j
|jd|_	| |j	}	|durD|du rD||	|}
nd}
|sWtdd |
|	|j	|jfD S t|
|	|j	|jdS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target
            values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.
        Nr  r  Tc                 s   rx  r%   r#   ry  r#   r#   r$   r   E  r{  zBPatchTSMixerForTimeSeriesClassification.forward.<locals>.<genexpr>r  )r=   r   CrossEntropyLossrr  r   r   r   r}  r  ro  rX  rZ  r  r   r  )r    r   r  r   r  rw  r   r  r  r  r  r#   r#   r$   r(     sB   
%


z/PatchTSMixerForTimeSeriesClassification.forwardr  )r)   r*   r+   r,   r   r   r   r=   r>   r   r  r(   r.   r#   r#   r!   r$   r    s*    r  z=
    Output type of [`PatchTSMixerForRegressionOutput`].
    c                   @   r  )PatchTSMixerForRegressionOutputaM  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Prediction output from the regression head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nr  regression_outputsro  r   )r)   r*   r+   r,   r  r=   rp  r  r  ro  r   r   r#   r#   r#   r$   r  W  r  r  c                       sH   e Zd Zddededef fddZdejdejd	ejfd
dZ  ZS )r  r9   r5   rE   	expansionc                    s`   t    t|d || | _t|| || _tdd| | _td| d| _|| _d S r*  )	r   r   r   r   inverse_trans_expansioninverse_trans_compressionmap_scale_expansionmap_scale_compressionrE   )r    r5   rE   r  r!   r#   r$   r   p  s   

z!InjectScalerStatistics4D.__init__r&   rX  rZ  c                 C   s   | dd}|d}|dd| jd}| dd}|d}|dd| jd}tj||gdd}| |}| |}tj||gdd}| |}| 	|}|S )a  
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
            loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
            scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
        Returns:
            `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
        r   r   r   r   )
r:   rO   r  rE   r=   catr  r  r  r  )r    r&   rX  rZ  rU   stdevconcat_statsr#   r#   r$   r(   y  s   






z InjectScalerStatistics4D.forward)r9   )	r)   r*   r+   r-   r   r=   r>   r(   r.   r#   r#   r!   r$   r  o  s    $	r  z4
    `PatchTSMixer` for regression application.
    c                       s~   e Zd Zdef fddZe				ddejdejdB d	edB d
ededB de	fddZ
e dejdefddZ  ZS )PatchTSMixerForRegressionr1   c                    s   t  | t|| _|j| _|j| _|j| _|j| _|jdkr$d | _n tt	t
d}||j}|d ur<||jd| _ntd|j |jdv rSt|j|jd| _nd | _t|| jd| _|jrg|   d S d S )Nr  r  r   r  r  r  r  )r   r   r~  r   r  r   rr  r  r   r   r   r  r   rW   r   r  r5   rE   r  r   r  rv  )r    r1   r  r  r!   r#   r$   r     s4   


z"PatchTSMixerForRegression.__init__NFTr   r  r   r  rw  rF   c                    sD   j dkrtjdd}n j dkrt}ntd|dur|n j} j|||d}t|tr2t	| } j
durC j
|j|j|jd|_ |j}	|dur|d	u r jr jd
krdt|dk rdtd j|	}
t fdd|	D }	||
|}t|}n||	|}nd}|stdd ||	|j|jfD S t||	|j|jdS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.
        r  rU   r  r  r  Nr  r  Tr  r   zDtarget_values cannot be negative for negative_binomial distribution.c                 3   s     | ]}| d  jjV  qdS )r   N)r   r1   r   )r   itemr   r#   r$   r     s    z4PatchTSMixerForRegression.forward.<locals>.<genexpr>c                 s   rx  r%   r#   ry  r#   r#   r$   r     r{  )r  r  ro  r   )r  r   r  r  rW   rr  r   r   r   r}  r  ro  rX  rZ  r  r   r=   any	Exceptionr  r  r   r  )r    r   r  r   r  rw  r   r  r  r  r  r  r#   r   r$   r(     sX   
$





z!PatchTSMixerForRegression.forwardc                    s^   | j }| |ddd}| j|j  fddt|D }tj|ddd|| jj	}t
|d	S )
a
  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the target values.

        Return:
            [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, num_targets)`.
        NF)r   r  r   c                    r  r#   r  r   r  r#   r$   r   @  s    z6PatchTSMixerForRegression.generate.<locals>.<listcomp>r   r   r   r  )r  r   r  r  r   r=   r  r   r1   r   r  )r    r   r  r  r  r#   r  r$   r  "  s   

z"PatchTSMixerForRegression.generater  )r)   r*   r+   r   r   r   r=   r>   r   r  r(   r  r  r  r.   r#   r#   r!   r$   r    s4    ']r  )r   r~  r  r  r  r  )Nr   )NFr   )Nr   )NN)Tr,   rQ   collections.abcr   dataclassesr   r=   torch.nnr   transformers.modeling_utilsr   transformers.utilsr    r   r   modeling_flash_attention_utilsr   modeling_utilsr	   processing_utilsr
   time_series_utilsr   r   r   utilsr   r   r   configuration_patchtsmixerr   
get_loggerr)   loggerModuler   r0   r?   r]   rf   rs   r>   r   r   r   r   r   r   r   r   r   r   r  listr   r-   r(  r;  r<  rE  rK  r\  rl  rn  rq  r}  r~  r  r  r  r  r  distributionsDistributionr  r  r  r  r  r  r  r  __all__r#   r#   r#   r$   <module>   s  
'17
VF-&)8G"
>
E1=$7FbU	
" Yo( .