o
    	۷iL                    @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddl	m
Z
 ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZ ddlmZmZ ddlmZ eeZ G dd de
j!Z"G dd de
j!Z#G dd de
j!Z$G dd de
j!Z%G dd de
j!Z&G dd de
j!Z'			dde
j!dej(dej(dej(d eej( d!ee) d"e)d#eej( fd$d%Z*G d&d' d'e
j!Z+G d(d) d)e
j!Z,G d*d+ d+e
j!Z-G d,d- d-e
j!Z.G d.d/ d/e
j!Z/G d0d1 d1e
j!Z0G d2d3 d3e
j!Z1eG d4d5 d5eZ2G d6d7 d7e
j!Z3		8	dd9ej(d:e)d;ee4 d<e5d=e6f
d>d?Z7		dd9ej(d@ee4e6f d;ee4 d=e6fdAdBZ8G dCdD dDe
j!Z9G dEdF dFe
j!Z:G dGdH dHe
j!Z;G dIdJ dJe
j!Z<G dKdL dLe
j!Z=eedMdNG dOdP dPeZ>G dQdR dRe2Z?eedSdNG dTdU dUeZ@edVdNG dWdX dXe2ZAeedYdNG dZd[ d[eZBed\dNG d]d^ d^e2ZCeed_dNG d`da daeZDeedbdNG dcdd ddeZEeedbdNG dedf dfeZFdgejGjHdhej(diej(fdjdkZIddlej(dmeej( diej(fdndoZJG dpdq dqe2ZKeedrdNG dsdt dteZLG dudv dve2ZMeedwdNG dxdy dyeZNG dzd{ d{e
j!ZOed|dNG d}d~ d~e2ZPg dZQdS )zPyTorch PatchTSMixer model.    N)	dataclass)CallableOptionalUnion)PreTrainedModel)ModelOutput   )FlashAttentionKwargs)ALL_ATTENTION_FUNCTIONS)Unpack)NegativeBinomialOutputNormalOutputStudentTOutput)auto_docstringlogging   )PatchTSMixerConfigc                       s2   e Zd ZdZdedef fddZdd Z  ZS )PatchTSMixerGatedAttentionz
    Module that applies gated attention to input data.

    Args:
        in_size (`int`): The input size.
        out_size (`int`): The output size.
    in_sizeout_sizec                    s*   t    t||| _tjdd| _d S )Ndim)super__init__nnLinear
attn_layerSoftmaxattn_softmax)selfr   r   	__class__ l/home/ubuntu/vllm_env/lib/python3.10/site-packages/transformers/models/patchtsmixer/modeling_patchtsmixer.pyr   /   s   
z#PatchTSMixerGatedAttention.__init__c                 C   s   |  | |}|| }|S N)r   r   )r    inputsattn_weightr#   r#   r$   forward4   s   z"PatchTSMixerGatedAttention.forward)__name__
__module____qualname____doc__intr   r(   __classcell__r#   r#   r!   r$   r   &   s    r   c                       6   e Zd ZdZdef fddZdejfddZ  Z	S )PatchTSMixerBatchNormzP
    Compute batch normalization over the sequence length (time) dimension.
    configc                    s"   t    tj|j|jd| _d S )Neps)r   r   r   BatchNorm1dd_modelnorm_eps	batchnormr    r1   r!   r#   r$   r   @   s   
zPatchTSMixerBatchNorm.__init__r&   c                 C   s"   | dd}| |}| ddS )a  
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        r      )	transposer7   )r    r&   outputr#   r#   r$   r(   D   s   
zPatchTSMixerBatchNorm.forward
r)   r*   r+   r,   r   r   torchTensorr(   r.   r#   r#   r!   r$   r0   ;   s    r0   c                       sN   e Zd ZdZdef fddZededejfddZ	de
jfd	d
Z  ZS )PatchTSMixerPositionalEncodingz'
    Class for positional encoding
    r1   c                    s<   t    |jr| || _d S tt|j	|j
| _d S r%   )r   r   use_positional_encoding_init_peposition_encr   	Parameterr=   zerosnum_patchesr5   r8   r!   r#   r$   r   V   s   
z'PatchTSMixerPositionalEncoding.__init__returnc                 C   s   | j dkrtjt| j| jdd}|S | j dkrvt| j| j}td| j	d}t
td| jdtd| j   }t|| |d d dd df< t|| |d d dd df< ||  }|| d	  }tj|d
d}|S t| j  d)NrandomTrequires_gradsincosr   r   r9   g     @
   FzN is not a valid positional encoder. Available types are 'random' and 'sincos'.)positional_encoding_typer   rC   r=   randnrE   r5   rD   arange	unsqueezeexpmathlogsincosmeanstd
ValueError)r1   rB   positiondiv_termr#   r#   r$   rA   ^   s    

(  
z'PatchTSMixerPositionalEncoding._init_pepatch_inputc                 C   s   || j  }|S r%   )rB   )r    rZ   hidden_stater#   r#   r$   r(   r   s   
z&PatchTSMixerPositionalEncoding.forward)r)   r*   r+   r,   r   r   staticmethodr   rC   rA   r=   r>   r(   r.   r#   r#   r!   r$   r?   Q   s    r?   c                       r/   )PatchTSMixerNormLayerzeNormalization block

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    sF   t    |j| _d|j v rt|| _d S tj|j|j	d| _d S )Nbatchr2   )
r   r   norm_mlplowerr0   normr   	LayerNormr5   r6   r8   r!   r#   r$   r      s
   
zPatchTSMixerNormLayer.__init__r&   c                 C   sf   d| j  v r,t||jd |jd  |jd |jd f}| |}t||j}|S | |}|S )a  
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                Input to the normalization layer.
        Returns:
            `torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`
        r^   r   r   r9   r   )r_   r`   r=   reshapeshapera   )r    r&   inputs_reshapedr#   r#   r$   r(      s   


zPatchTSMixerNormLayer.forwardr<   r#   r#   r!   r$   r]   x   s    
r]   c                       s,   e Zd Z fddZdejfddZ  ZS )PatchTSMixerMLPc                    sP   t    ||j }t||| _t|j| _t||| _	t|j| _
d S r%   )r   r   expansion_factorr   r   fc1Dropoutdropoutdropout1fc2dropout2)r    in_featuresout_featuresr1   
num_hiddenr!   r#   r$   r      s   

zPatchTSMixerMLP.__init__r&   c                 C   s0   |  tj| |}| |}| |}|S )z
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                Input to the MLP layer.
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        )rk   r   
functionalgelurh   rl   rm   )r    r&   r#   r#   r$   r(      s   

zPatchTSMixerMLP.forward)r)   r*   r+   r   r=   r>   r(   r.   r#   r#   r!   r$   rf      s    rf   c                       r/   )$PatchTSMixerChannelFeatureMixerBlockzThis module mixes the features in the channel dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    P   t    t|| _|j| _t|j|j|d| _|jr&t|j|jd| _	d S d S Nrn   ro   r1   r   r   )
r   r   r]   ra   
gated_attnrf   num_input_channelsmlpr   gating_blockr8   r!   r#   r$   r      s   

z-PatchTSMixerChannelFeatureMixerBlock.__init__r&   c                 C   sT   |}|  |}|dddd}| jr| |}| |}|dddd}|| }|S )z
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                input to the MLP layer
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        r   r   r9   r   )ra   permuterx   r{   rz   )r    r&   residualoutr#   r#   r$   r(      s   


z,PatchTSMixerChannelFeatureMixerBlock.forwardr<   r#   r#   r!   r$   rs      s    rs           modulequerykeyvalueattention_maskscalingrj   	head_maskc                 K   s   |d u r| dd }t||dd| }	|d ur|	| }	tjj|	dd}	|d ur5|	|dddd }	tjj|	|| j	d}	t|	|}
|
dd
 }
|
|	fS )Nr         r9   r   r   r   )ptraining)sizer=   matmulr:   r   rq   softmaxviewrj   r   
contiguous)r   r   r   r   r   r   rj   r   kwargsattn_weightsattn_outputr#   r#   r$   eager_attention_forward   s   r   c                       s   e Zd ZdZ					ddededed	ed
ededee f fddZ					dde
jdee
j dee
j dee
j dee dee dee
jee
j eee
j  f fddZ  ZS )PatchTSMixerAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN	embed_dim	num_headsrj   
is_decoderbias	is_causalr1   c                    s   t    || _|| _|| _|| | _|| _| j| | jkr*td| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: ).r   )r   )r   r   r   r   rj   head_dimr1   rW   r   r   r   r   r   k_projv_projq_projout_proj)r    r   r   rj   r   r   r   r1   r!   r#   r$   r     s&   



zPatchTSMixerAttention.__init__hidden_stateskey_value_statesr   layer_head_maskoutput_attentionsr   rF   c                 K   s  |du}|j dd \}}	|r|j d n|	}
||	d| jf}||
d| jf}| |j| dd}|r4|n|}| |j| dd}| |j| dd}t}| jj	dkr\t
| jj	 }|| ||||f| jshdn| j| j||d|\}}|||	d }| |}||dfS )z#Input shape: Batch x Time x ChannelNr   r   r9   eagerr   )rj   r   r   r   )rd   r   r   r   r:   r   r   r   r1   _attn_implementationr
   r   rj   r   rc   r   r   )r    r   r   r   r   r   r   is_cross_attentionbsztgt_lensrc_lenq_input_shapekv_input_shapequery_statescurrent_states
key_statesvalue_statesattention_interfacer   r   r#   r#   r$   r(   2  s:   



zPatchTSMixerAttention.forward)r   FTFN)NNNF)r)   r*   r+   r,   r-   floatboolr   r   r   r=   r>   r   r	   tupler(   r.   r#   r#   r!   r$   r     sR    "	
r   c                       .   e Zd ZdZdef fddZdd Z  ZS )PatchMixerBlockzxThis module mixes the patch dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    s   t    t|| _|j| _|j| _t|j|j|d| _|jr(t	|j|jd| _
|jr>t|j|j|j|d| _t|| _d S d S )Nrv   rw   )r   r   rj   r1   )r   r   r]   ra   	self_attnrx   rf   rE   rz   r   r{   r   r5   self_attn_headsrj   self_attn_layer	norm_attnr8   r!   r#   r$   r   p  s(   

zPatchMixerBlock.__init__c                 C   s   |}|  |}| jr,|j\}}}}||| ||}| j|dd\}}	}	|||||}|dd}| |}| jr?| |}|dd}| jrO| 	|| }|| }
|
S )z
        Args:
            hidden_state (`torch.Tensor`): Input tensor.

        Returns:
            `torch.Tensor`: Transformed tensor.
        F)r   r9   r   )
ra   r   rd   rc   r   r:   rz   rx   r{   r   )r    r[   r}   
batch_sizen_varsrE   r5   hidden_state_reshapedx_attn_r~   r#   r#   r$   r(     s    


zPatchMixerBlock.forwardr)   r*   r+   r,   r   r   r(   r.   r#   r#   r!   r$   r   h  s    r   c                       r/   )FeatureMixerBlockzThis module mixes the hidden feature dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    r1   c                    rt   ru   )
r   r   r]   ra   rx   rf   r5   rz   r   r{   r8   r!   r#   r$   r     s   

zFeatureMixerBlock.__init__hiddenc                 C   s4   |}|  |}| |}| jr| |}|| }|S )
        Args:
            hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
                Input tensor to the layer.

        Returns:
            `torch.Tensor`: Transformed tensor.
        )ra   rz   rx   r{   )r    r   r}   r~   r#   r#   r$   r(     s   	


zFeatureMixerBlock.forwardr<   r#   r#   r!   r$   r     s    r   c                       r/   )PatchTSMixerLayerz
    The `PatchTSMixer` layer that does all three kinds of mixing.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    r1   c                    sH   t    t|d| _t|d| _|j| _|jdkr"t|d| _d S d S )Nr1   mix_channel)	r   r   r   patch_mixerr   feature_mixermoders   channel_feature_mixerr8   r!   r#   r$   r     s   

zPatchTSMixerLayer.__init__r   c                 C   s,   | j dkr
| |}| |}| |}|S )r   r   )r   r   r   r   )r    r   r#   r#   r$   r(     s
   
	


zPatchTSMixerLayer.forwardr<   r#   r#   r!   r$   r     s    	r   c                       s6   e Zd ZdZdef fddZd	defddZ  ZS )
PatchTSMixerBlockzThe main computing framework of the `PatchTSMixer` model.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    s2   t     j}t fddt|D | _d S )Nc                    s   g | ]}t  d qS )r   )r   .0r   r   r#   r$   
<listcomp>  s    z.PatchTSMixerBlock.__init__.<locals>.<listcomp>)r   r   
num_layersr   
ModuleListrangemixers)r    r1   r   r!   r   r$   r   	  s   
"zPatchTSMixerBlock.__init__Foutput_hidden_statesc                 C   s>   g }|}| j D ]}||}|r|| q|r||fS |dfS )as  
        Args:
            hidden_state (`torch.Tensor`): The input tensor.
            output_hidden_states (`bool`, *optional*, defaults to False.):
                Whether to output the hidden states as well.

        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`.
        N)r   append)r    r[   r   all_hidden_states	embeddingmodr#   r#   r$   r(     s   

zPatchTSMixerBlock.forwardF)	r)   r*   r+   r,   r   r   r   r(   r.   r#   r#   r!   r$   r     s    r   c                       0   e Zd ZdZddef fddZdd Z  ZS )	PatchTSMixerForPredictionHeadzqPrediction Head for Forecasting

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    Nr1   c                    s|   t    |j| _| jd ur| j  t|j| _|d u r+t|j	|j
 |j| _n
||j	|j
 | _tjdd| _d S )N	start_dim)r   r   prediction_channel_indicessortr   ri   head_dropoutdropout_layerr   rE   r5   prediction_lengthbase_forecast_blockget_parameter_projectionFlattenflatten)r    r1   distribution_outputr!   r#   r$   r   2  s   



z&PatchTSMixerForPredictionHead.__init__c                    s     |} |} |}t|trtdd |D }n|dd} jdurBt|tr;t fdd|D }|S |d jf }|S )ar  

        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
                or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.

        c                 s   s    | ]	}| d dV  qdS )r   r   N)r:   r   zr#   r#   r$   	<genexpr>U  s    z8PatchTSMixerForPredictionHead.forward.<locals>.<genexpr>r   r   Nc                 3   s    | ]
}|d  j f V  qdS ).N)r   r   r    r#   r$   r   [  s    .)r   r   r   
isinstancer   r:   r   r    hidden_featuresforecastr#   r   r$   r(   D  s   





z%PatchTSMixerForPredictionHead.forwardr%   r   r#   r#   r!   r$   r   *  s    r   c                       r   )	PatchTSMixerLinearHeadzLinear head for Classification and Regression.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    Nr1   c                    s   t    |j| _|j| _|jd u r|j}nd}|| _|d u r.t|j|j	 | |j
| _n||j|j	 | | _|jd u rGtjdd| _ntjdd| _t|j| _d S )Nr   r   r   )r   r   head_aggregationoutput_rangerE   r   r   r   r5   ry   num_targets
projectionr   r   r   ri   r   rj   )r    r1   r   
mul_factorr!   r#   r$   r   j  s&   


zPatchTSMixerLinearHead.__init__c                 C   s   | dd}| jdkr|d }n| jdkr|jddj}n| jdkr(|jdd}| jr0| |}| |}| |}| jdu rX| j	durXt
|| j	d	 | j	d
   | j	d
  }|S )ai  
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x num_targets)`.
        r   r   use_last).r   max_poolr   avg_poolNr   r   )r:   r   maxvaluesrU   r   rj   r   r   r   r=   sigmoid)r    r   r#   r#   r$   r(     s   






&zPatchTSMixerLinearHead.forwardr%   r   r#   r#   r!   r$   r   b  s    r   c                   @   s*   e Zd ZU eed< dZdZdZdd ZdS )PatchTSMixerPreTrainedModelr1   modelpast_valuesFc                 C   s   t |tr| jjdkrtjj|jddd dS dS t |tjtj	fr1|j
j  |jjd dS t |trG|jj
j  |jjjd dS t |tjre|jjjd| jjd |j
durg|j
j  dS dS dS )zInitialize weightsrG   r   g?)rU   rV         ?N)r   r?   r1   rL   r   initnormal_rB   rb   r4   r   datazero_weightfill_r0   r7   r   init_std)r    r   r#   r#   r$   _init_weights  s    


z)PatchTSMixerPreTrainedModel._init_weightsN)	r)   r*   r+   r   __annotations__base_model_prefixmain_input_namesupports_gradient_checkpointingr  r#   r#   r#   r$   r     s   
 r   c                       r   )PatchTSMixerPretrainHeadzcPretraining head.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    s.   t    t|j| _t|j|j| _	d S r%   )
r   r   r   ri   r   r   r   r5   patch_lengthbase_pt_blockr8   r!   r#   r$   r     s   
z!PatchTSMixerPretrainHead.__init__c                 C   s   |  |}| |}|S )a  
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`.
        )r   r  r   r#   r#   r$   r(     s   

z PatchTSMixerPretrainHead.forwardr   r#   r#   r!   r$   r    s    r  Fr&   
mask_ratiounmasked_channel_indiceschannel_consistent_masking
mask_valuec                 C   s*  |dk s|dkrt d| d| j\}}}}| j}	t|d|  }
|r5tj|d||	d}|d|d}n	tj||||	d}tj||||	d}d|ddddd|
f< tj|dd}tj|dd}tj	|d|d	}|
dddd|}|durd|dd|ddddf< | | |}||d
 fS )a  random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
            across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
        n]
    r   r   zMask ratio z has to be between 0 and 1.deviceNr   r   )r   index.r   )rW   rd   r  r-   r=   randrepeatonesargsortgatherrO   masked_fillr   )r&   r  r  r  r  r   num_channelssequence_lengthnum_featuresr  len_keepnoisemaskids_shuffleids_restoreinputs_maskr#   r#   r$   random_masking  s&   r%  num_forecast_mask_patchesc                 C   s  t |tr|g}dd |D }| j\}}}}tj|||| jd}	g }
d}t|}t||D ](\}}|dks9||krAtd| dt|| | }|
	|||g ||7 }q-t
|
dd d	}
||k rq|
d d
 ||  |
d d
< n||kr|
d d
 ||  |
d d
< d}|
D ]\}}}|| }d|	||dd| df< |}qt|	jd }|	| }	|	dddd|}	|durd|	dd|ddddf< | |	 |}||	d fS )a  Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
        num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
    c                 S   s   g | ]}d qS )r   r#   r   r#   r#   r$   r   7  s    z$forecast_masking.<locals>.<listcomp>r  r   znum_forecast_mask_patches z6 should be greater than 0 and less than total patches.c                 S   s   | d S Nr9   r#   )xr#   r#   r$   <lambda>I  s    z"forecast_masking.<locals>.<lambda>)r   r9   r   r   Nr  )r   r-   rd   r=   rD   r  sumziprW   r   sortedrandpermrO   r  r  r   )r&   r&  r  r  forecast_mask_ratiosr   r  r  r  r!  t_listtotal_lengthtotal_ratior  ratiotemp_lenbatch1	patch_lenr   batch2permr$  r#   r#   r$   forecast_masking  sB   


r8  c                       r/   )PatchTSMixerPatchifyz
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    r1   c                    s   t    |j| _|j| _|j| _| j| jkr$td| j d| j dt| j| j| j | j d | _| j| j| jd   }| j| | _	d S )NzSequence length (z+) has to be greater than the patch length ()r   )
r   r   context_lengthr  r  patch_striderW   r   rE   sequence_start)r    r1   new_sequence_lengthr!   r#   r$   r   j  s   
 zPatchTSMixerPatchify.__init__r   c                 C   sp   |j d }|| jkrtd| d| j d|dd| jdddf }|jd| j| jd}|dd }|S )a!  
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        r   zInput sequence length (z%) doesn't match model configuration (r   N)	dimensionr   stepr   )	rd   r  rW   r=  unfoldr  r<  r:   r   )r    r   r  r;   r#   r#   r$   r(   {  s   
	
zPatchTSMixerPatchify.forwardr<   r#   r#   r!   r$   r9  b  s    r9  c                       r/   )PatchTSMixerMaskinga  
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSMixerConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    r1   c                    sX   t    |j| _|j| _|j| _|j| _|j| _|j| _| jd ur*t| j| _d S d S r%   )	r   r   random_mask_ratior  	mask_typer&  r  r  r,  r8   r!   r#   r$   r     s   

zPatchTSMixerMasking.__init__rZ   c                 C   sr   | j dkrt|| j| j| j| jd\}}n| j dkr(t|| j| j| jd\}}n	td| j  d|	 }||fS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        rG   )r&   r  r  r  r  r   )r&   r&  r  r  zInvalid mask type .)
rD  r%  rC  r  r  r  r8  r&  rW   r   )r    rZ   masked_inputr!  r#   r#   r$   r(     s$   

zPatchTSMixerMasking.forwardr<   r#   r#   r!   r$   rB    s    rB  c                	       P   e Zd ZdZdef fddZdejdejdeejejejf fdd	Z	  Z
S )
PatchTSMixerStdScalerz
    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
    subtracting from the mean and dividing by the standard deviation.
    r1   c                    sV   t    t|dr|jnd| _t|dr|jnd| _t|dr&|j| _d S d| _d S )Nscaling_dimr   keepdimTminimum_scalegh㈵>)r   r   hasattrrI  r   rJ  rK  r8   r!   r#   r$   r     s   
 zPatchTSMixerStdScaler.__init__r  observed_indicatorrF   c                 C   sz   |j | j| jd}|d}|| j | j| jd| }|| | d j | j| jd| }t|| j }|| | ||fS )C  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        rJ  r   r9   )r*  r   rJ  	clamp_minr=   sqrtrK  )r    r  rM  denominatorlocvariancescaler#   r#   r$   r(     s   
"zPatchTSMixerStdScaler.forwardr)   r*   r+   r,   r   r   r=   r>   r   r(   r.   r#   r#   r!   r$   rH    s    rH  c                	       rG  )
PatchTSMixerMeanScalerz
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    r1   c                    sl   t    t|dr|jnd| _t|dr|jnd| _t|dr#|jnd| _t|dr1|j| _d S d | _d S )NrI  r   rJ  TrK  绽|=default_scale)r   r   rL  rI  r   rJ  rK  rY  r8   r!   r#   r$   r     s
   
 zPatchTSMixerMeanScaler.__init__r  rM  rF   c           
      C   s   ||   j| jdd}|j| jdd}|tj|dd }| jdu r:|jdd}tj|ddd}t|| }n| jt| }t|dk||}tj|| j	d}|| }	| j
sa|j| jd}|	t||fS )rN  TrO  r   minNr   r   )absr*  r   r=   clamprY  squeeze	ones_likewhererK  rJ  
zeros_like)
r    r  rM  ts_sumnum_observedrU  	batch_sumbatch_observationsrY  scaled_datar#   r#   r$   r(     s   
zPatchTSMixerMeanScaler.forwardrV  r#   r#   r!   r$   rW    s    rW  c                
       sX   e Zd ZdZdef fddZ	ddejdeej de	ejejejf fd	d
Z
  ZS )PatchTSMixerNOPScalerz|
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    r1   c                    s@   t    t|dr|jnd| _t|dr|j| _d S d| _d S )NrI  r   rJ  T)r   r   rL  rI  r   rJ  r8   r!   r#   r$   r   0  s   
 zPatchTSMixerNOPScaler.__init__Nr  rM  rF   c                 C   sB   t j|ddj| j| jd}t j|ddj| j| jd}|||fS )a  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        FrH   )r   rJ  )r=   r_  rU   r   rJ  ra  )r    r  rM  rU  rS  r#   r#   r$   r(   5  s   
zPatchTSMixerNOPScaler.forwardr%   )r)   r*   r+   r,   r   r   r=   r>   r   r   r(   r.   r#   r#   r!   r$   rg  +  s    rg  zS
    Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.
    )custom_introc                   @   s:   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dS )PatchTSMixerEncoderOutputa-  
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
        Hidden-state at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    Nlast_hidden_stater   )r)   r*   r+   r,   rj  r   r=   FloatTensorr  r   r   r#   r#   r#   r$   ri  F  s   
 ri  c                       s\   e Zd ZdZdef fddZe		ddejde	e
 d	e	e
 d
eeef fddZ  ZS )PatchTSMixerEncoderz
    Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r1   c                    sb   t  | |j| _t|j|j| _|jrt	|d| _
nd | _
t|d| _|jr/|   d S d S )Nr   )r   r   use_return_dictr   r   r  r5   patcherr@   r?   positional_encoderr   mlp_mixer_encoder	post_initr8   r!   r#   r$   r   a  s   zPatchTSMixerEncoder.__init__FNr   r   return_dictrF   c                 C   sh   |dur|n| j }| |}| jdur| |}| j||d\}}|s.tdd ||fD S t||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to
            predict the masked portion. For a forecasting task, this denotes the history/past time series values.
            Similarly, for classification or regression tasks, it denotes the appropriate context values of the
            time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
            it is greater than 1.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
        N)r   c                 s       | ]}|V  qd S r%   r#   r   vr#   r#   r$   r     
    
z.PatchTSMixerEncoder.forward.<locals>.<genexpr>)rj  r   )rm  rn  ro  rp  r   ri  )r    r   r   rr  patchesrj  r   r#   r#   r$   r(   q  s   


zPatchTSMixerEncoder.forward)FN)r)   r*   r+   r,   r   r   r   r=   r>   r   r   r   r   ri  r(   r.   r#   r#   r!   r$   rl  X  s    
rl  zG
    Base class for model's outputs, with potential hidden states.
    c                   @   s   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< dS )	PatchTSMixerModelOutputa  
    last_hidden_state (`torch.FloatTensor`  of shape `(batch_size, num_channels, num_patches, d_model)`):
        Hidden-state at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Patched input data to the model.
    mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
        Bool Tensor indicating True in masked patches and False otherwise.
    loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin
        enabled.
    scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin
        enabled.
    Nrj  r   rZ   r!  rS  rU  )r)   r*   r+   r,   rj  r   r=   rk  r  r   r   rZ   r!  rS  rU  r#   r#   r#   r$   rx    s   
 rx  z=
    The PatchTSMixer Model for time-series forecasting.
    c                       sb   e Zd Zddedef fddZe			ddejde	ej d	e	e d
e	e de
f
ddZ  ZS )PatchTSMixerModelFr1   
mask_inputc                    s   t  | |j| _t|| _t|| _|du rt|| _nd| _|j	dkr,t
|| _n|j	dks6|j	du r<t|| _nt|| _|jrJ|   dS dS )z
        mask_input (bool, *optional*, defaults to `False`):
            Whether to mask the input using the [`PatchTSMixerMasking`] module.
        TNrU   rV   )r   r   rm  rl  encoderr9  patchingrB  maskingr   rW  scalerrH  rg  rq  )r    r1   rz  r!   r#   r$   r     s   



zPatchTSMixerModel.__init__Nr   observed_maskr   rr  rF   c                 C   s   |dur|n| j }d}|du rt|}| ||\}}}| |}	|	}
| jdur0| |	\}
}| j|
||d}t|trAt	| }|sTtdd |j
|j|	|||fD S t|j
|j|	|||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        Nr   rr  c                 s   rs  r%   r#   rt  r#   r#   r$   r     rv  z,PatchTSMixerModel.forward.<locals>.<genexpr>)rj  r   rZ   r!  rS  rU  )rm  r=   r_  r~  r|  r}  r{  r   r   ri  rj  r   rx  )r    r   r  r   rr  r!  scaled_past_valuesrS  rU  	patched_x	enc_inputencoder_outputr#   r#   r$   r(     sD   



zPatchTSMixerModel.forwardr   )NFN)r)   r*   r+   r   r   r   r   r=   r>   r   rx  r(   r.   r#   r#   r!   r$   ry    s"    ry  z>
    Output type of [`PatchTSMixerForPreTrainingOutput`].
    c                   @   ^   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dS ) PatchTSMixerForPreTrainingOutputa@  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
        Prediction output from the pretrain head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer.
    Nlossprediction_outputsrj  r   r)   r*   r+   r,   r  r   r=   rk  r  r  rj  r   r   r#   r#   r#   r$   r  #     
 r  z.
    `PatchTSMixer` for mask pretraining.
    c                       sb   e Zd Zdef fddZe				ddejdeej d	ee	 d
e	dee	 de
fddZ  ZS )PatchTSMixerForPretrainingr1   c                    sL   t  | t|dd| _t|d| _|j| _|j| _|jr$|   d S d S )NT)rz  r   )	r   r   ry  r   r  headmasked_lossrm  rq  r8   r!   r#   r$   r   A  s   z#PatchTSMixerForPretraining.__init__NFTr   r  r   return_lossrr  rF   c           
      C   s   |dur|n| j }| jdu rtjjdd}ntjjdd}| j||||d}t|tr/t| }| 	|j
}|du r@|||j}	nd}	| jdu r]|	dur]|	jdd|j  |j d	  }	|sntd
d |	||j
|jfD S t|	||j
|jdS )aT  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.
        NTnone	reductionrU   r  r   rr  r   r   rX  c                 s   rs  r%   r#   rt  r#   r#   r$   r     rv  z5PatchTSMixerForPretraining.forward.<locals>.<genexpr>r  r  rj  r   )rm  r  r=   r   MSELossr   r   r   rx  r  rj  rZ   rU   r!  r*  r   r  )
r    r   r  r   r  rr  r  model_outputx_hatloss_valr#   r#   r$   r(   L  s@   

$
z"PatchTSMixerForPretraining.forwardNFTN)r)   r*   r+   r   r   r   r=   r>   r   r   r  r(   r.   r#   r#   r!   r$   r  ;  s(    r  z=
    Output type of [`PatchTSMixerForPredictionOutput`].
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dZeej ed< dZeej ed< dS )	PatchTSMixerForPredictionOutputaD  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
        Prediction output from the forecast head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
        Input mean
    scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
        Input std dev
    Nr  r  rj  r   rS  rU  )r)   r*   r+   r,   r  r   r=   rk  r  r  rj  r   r   rS  rU  r#   r#   r#   r$   r    s   
 r  z
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.
    c                   @   $   e Zd ZU dZdZeej ed< dS )"SamplePatchTSMixerPredictionOutput
    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
        Sampled values from the chosen distribution.
    N	sequences	r)   r*   r+   r,   r  r   r=   rk  r  r#   r#   r#   r$   r       
 r  c                   @   r  )"SamplePatchTSMixerRegressionOutputr  Nr  r  r#   r#   r#   r$   r    r  r  inputtargetrF   c                 C   s   |  | S )zc
    Computes the negative log likelihood loss from input distribution with respect to target.
    )log_prob)r  r  r#   r#   r$   nll  s   r  input_tensorweightsc                 C   sr   |dur3t |dk| | t | }t j|r|j|dn| dd}|r-|j|d| S | | S | j|dS )aj  
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    Nr   r   r   rZ  )r=   r`  ra  r]  r*  rU   )r  r  r   weighted_tensorsum_weightsr#   r#   r$   weighted_average  s
   " r  c                       s   e Zd ZdZdef fddZe					ddejd	e	ej d
e	ej de	e
 de
de	e
 defddZe 	ddejd	e	ej defddZ  ZS )PatchTSMixerForPredictionz
    `PatchTSMixer` for forecasting application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r1   c                    s   t  | |j| _|j| _|j| _|j| _|jdkrd | _n#|j}tt	t
d}||jd }|d ur:||d| _ntd|j t|| _t|| jd| _|jrX|   d S d S )Nmse	student_tnormalnegative_binomialr   Unknown distribution output r1   r   )r   r   r  rm  r   num_parallel_samplesr   r   r   r   r   getrW   ry  r   r   r  rq  )r    r1   r   distribution_output_mapoutput_classr!   r#   r$   r     s0   

z"PatchTSMixerForPrediction.__init__NFTr   r  future_valuesr   r  rr  rF   c                 C   s  | j dkrtjdd}n| j dkrt}ntd|dur|n| j}| j||||d}t|tr3t	| }| 
|j}	d}
| jdur| jro| jj|	|jd| jf |jd| jf d	}|durn|d
u rn|||d| jf }
t|
}
nZ|	|jd| jf  |jd| jf  }	|dur|d
u r||	|d| jf }
n5| jr| jj|	|j|jd	}|dur|d
u r|||}
t|
}
n|	|j |j }	|dur|d
u r||	|}
| jdur|jd| jf }|jd| jf }n|j}|j}|stdd |
|	|j|j||fD S t|
|	|j|j||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,:
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target values of the time series, that serve as labels for the model. The `future_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.
        r  rU   r  r  2Invalid loss function: Allowed values: mse and nllNr  .rS  rU  Tc                 s   rs  r%   r#   rt  r#   r#   r$   r     rv  z4PatchTSMixerForPrediction.forward.<locals>.<genexpr>)r  r  rj  r   rS  rU  )r  r   r  r  rW   rm  r   r   r   rx  r  rj  r   r   distributionrS  rU  r  r   r  )r    r   r  r  r   r  rr  r  r  y_hatr  r  rS  rU  r#   r#   r$   r(     s   
$






z!PatchTSMixerForPrediction.forwardc                    s\   | j }| |d|dd}| jj|j|j|jd  fddt|D }tj|dd}t	|d	S )
a  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.

            observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, prediction_length, num_input_channels)`.
        NF)r   r  r  r   r  c                       g | ]}   qS r#   sampler   r  r#   r$   r     s    z6PatchTSMixerForPrediction.generate.<locals>.<listcomp>r   r   r  )
r  r   r  r  rS  rU  r   r=   stackr  )r    r   r  r  outputssamplesr#   r  r$   generate  s   	
z"PatchTSMixerForPrediction.generate)NNFTNr%   )r)   r*   r+   r,   r   r   r   r=   r>   r   r   r  r(   no_gradr  r  r.   r#   r#   r!   r$   r    sB     yr  zK
    Output type of [`PatchTSMixerForTimeSeriesClassificationOutput`].
    c                   @   r  )-PatchTSMixerForTimeSeriesClassificationOutputaP  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
        Prediction output from the classification head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nr  r  rj  r   r  r#   r#   r#   r$   r    r  r  c                       sf   e Zd ZdZdef fddZe				ddejd	e	ej d
e	e
 de
de	e
 defddZ  ZS )'PatchTSMixerForTimeSeriesClassificationz
    `PatchTSMixer` for classification application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r1   c                    sd   t  | t|| _t|d| _|j| _|jdv r$t|j	|j
d| _nd | _|jr0|   d S d S )Nr   rV   rU   Tr5   rE   )r   r   ry  r   r   r  rm  r   InjectScalerStatistics4Dr5   rE   inject_scalerq  r8   r!   r#   r$   r     s   

z0PatchTSMixerForTimeSeriesClassification.__init__NFTr   target_valuesr   r  rr  rF   c           
      C   s   t j }|dur|n| j}| j|||d}t|trt| }| jdur0| j|j	|j
|jd|_	| |j	}|durD|du rD|||}	nd}	|sWtdd |	||j	|jfD S t|	||j	|jdS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target
            values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.
        Nr  r  Tc                 s   rs  r%   r#   rt  r#   r#   r$   r   >  rv  zBPatchTSMixerForTimeSeriesClassification.forward.<locals>.<genexpr>r  )r=   r   CrossEntropyLossrm  r   r   r   rx  r  rj  rS  rU  r  r   r  )
r    r   r  r   r  rr  r  r  r  r  r#   r#   r$   r(     sB   
$


z/PatchTSMixerForTimeSeriesClassification.forwardr  )r)   r*   r+   r,   r   r   r   r=   r>   r   r   r  r(   r.   r#   r#   r!   r$   r    s*    r  z=
    Output type of [`PatchTSMixerForRegressionOutput`].
    c                   @   r  )PatchTSMixerForRegressionOutputaM  
    loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
        Total loss.
    regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Prediction output from the regression head.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
        Backbone embeddings before passing through the head.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nr  regression_outputsrj  r   )r)   r*   r+   r,   r  r   r=   rk  r  r  rj  r   r   r#   r#   r#   r$   r  P  r  r  c                       sH   e Zd Zddededef fddZdejdejd	ejfd
dZ  ZS )r  r9   r5   rE   	expansionc                    s`   t    t|d || | _t|| || _tdd| | _td| d| _|| _d S r'  )	r   r   r   r   inverse_trans_expansioninverse_trans_compressionmap_scale_expansionmap_scale_compressionrE   )r    r5   rE   r  r!   r#   r$   r   i  s   

z!InjectScalerStatistics4D.__init__r&   rS  rU  c                 C   s   | dd}|d}|dd| jd}| dd}|d}|dd| jd}tj||gdd}| |}| |}tj||gdd}| |}| 	|}|S )a  
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
            loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
            scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
        Returns:
            `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
        r   r   r   r   )
r:   rO   r  rE   r=   catr  r  r  r  )r    r&   rS  rU  rU   stdevconcat_statsr#   r#   r$   r(   r  s   






z InjectScalerStatistics4D.forward)r9   )	r)   r*   r+   r-   r   r=   r>   r(   r.   r#   r#   r!   r$   r  h  s    $	r  z4
    `PatchTSMixer` for regression application.
    c                       s~   e Zd Zdef fddZe				ddejdeej d	ee	 d
e	dee	 de
fddZe dejdefddZ  ZS )PatchTSMixerForRegressionr1   c                    s   t  | t|| _|j| _|j| _|j| _|j| _|jdkr$d | _n tt	t
d}||j}|d ur<||jd| _ntd|j |jdv rSt|j|jd| _nd | _t|| jd| _|jrg|   d S d S )Nr  r  r   r  r  r  r  )r   r   ry  r   r  r   rm  r  r   r   r   r  r   rW   r   r  r5   rE   r  r   r  rq  )r    r1   r  r  r!   r#   r$   r     s4   


z"PatchTSMixerForRegression.__init__NFTr   r  r   r  rr  rF   c                    sD   j dkrtjdd}n j dkrt}ntd|dur|n j} j|||d}t|tr2t	| } j
durC j
|j|j|jd|_ |j}|dur|d	u r jr jd
krdt|dk rdtd j|}	t fdd|D }||	|}
t|
}
n|||}
nd}
|stdd |
||j|jfD S t|
||j|jdS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.
        r  rU   r  r  r  Nr  r  Tr  r   zDtarget_values cannot be negative for negative_binomial distribution.c                 3   s     | ]}| d  jjV  qdS )r   N)r   r1   r   )r   itemr   r#   r$   r     s    z4PatchTSMixerForRegression.forward.<locals>.<genexpr>c                 s   rs  r%   r#   rt  r#   r#   r$   r   	  rv  )r  r  rj  r   )r  r   r  r  rW   rm  r   r   r   rx  r  rj  rS  rU  r  r   r=   any	Exceptionr  r  r   r  )r    r   r  r   r  rr  r  r  r  r  r  r#   r   r$   r(     sX   
#





z!PatchTSMixerForRegression.forwardc                    s^   | j }| |ddd}| j|j  fddt|D }tj|ddd|| jj	}t
|d	S )
a
  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the target values.

        Return:
            [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, num_targets)`.
        NF)r   r  r   c                    r  r#   r  r   r  r#   r$   r   8  s    z6PatchTSMixerForRegression.generate.<locals>.<listcomp>r   r   r   r  )r  r   r  r  r   r=   r  r   r1   r   r  )r    r   r  r  r  r#   r  r$   r    s   

z"PatchTSMixerForRegression.generater  )r)   r*   r+   r   r   r   r=   r>   r   r   r  r(   r  r  r  r.   r#   r#   r!   r$   r    s4    '\r  )r   ry  r  r  r  r  )Nr   N)NFr   )Nr   )NN)Rr,   rQ   dataclassesr   typingr   r   r   r=   torch.nnr   transformers.modeling_utilsr   transformers.utilsr   modeling_flash_attention_utilsr	   modeling_utilsr
   processing_utilsr   time_series_utilsr   r   r   utilsr   r   configuration_patchtsmixerr   
get_loggerr)   loggerModuler   r0   r?   r]   rf   rs   r>   r   r   r   r   r   r   r   r   r   r   r  listr   r-   r%  r8  r9  rB  rH  rW  rg  ri  rl  rx  ry  r  r  r  r  r  distributionsDistributionr  r  r  r  r  r  r  r  __all__r#   r#   r#   r$   <module>   s  
'17
XF-&)8G"
>

E1=$7EaT	
" Xn( -