o
    iJ                    @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZ dd
lmZmZ ddlmZ ddlmZmZmZ ddlmZmZmZ ddlmZ ee Z!			dtde	j"dej#dej#dej#deej# dee$ de$deej# fddZ%G dd de	j"Z&G dd de	j"Z'			dud ej#d!e$d"ee( d#e)d$e*f
d%d&Z+		dvd ej#d'ee(e*f d"ee( d$e*fd(d)Z,G d*d+ d+e	j"Z-G d,d- d-e	j"Z.G d.d/ d/e	j"Z/eG d0d1 d1eZ0G d2d3 d3e	j"Z1G d4d5 d5e	j"Z2G d6d7 d7e0Z3eed8d9G d:d; d;eZ4eed<d9G d=d> d>eZ5eed?d9G d@dA dAeZ6eedBd9G dCdD dDeZ7eedEd9G dFdG dGeZ8eedHd9G dIdJ dJeZ9dKej:j;dLej#dMej#fdNdOZ<dwdPej#dQeej# dMej#fdRdSZ=G dTdU dUe	j"Z>G dVdW dWe	j"Z?G dXdY dYe	j"Z@G dZd[ d[e	j"ZAeG d\d] d]e0ZBG d^d_ d_e	j"ZCed`d9G dadb dbe0ZDG dcdd dde	j"ZEeded9G dfdg dge0ZFedhd9G didj dje	j"ZGedkd9G dldm dme0ZHG dndo doe	j"ZIedpd9G dqdr dre0ZJg dsZKdS )xzPyTorch PatchTST model.    N)	dataclass)CallableOptionalUnion)nn   )ACT2CLS)FlashAttentionKwargs)BaseModelOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)NegativeBinomialOutputNormalOutputStudentTOutput)ModelOutputauto_docstringlogging   )PatchTSTConfig        modulequerykeyvalueattention_maskscalingdropout	head_maskc                 K   s   |d u r| dd }t||dd| }	|d ur|	| }	tjj|	dd}	|d ur5|	|dddd }	tjj|	|| j	d}	t|	|}
|
dd
 }
|
|	fS )N         r   dimr   )ptraining)sizetorchmatmul	transposer   
functionalsoftmaxviewr   r%   
contiguous)r   r   r   r   r   r   r   r   kwargsattn_weightsattn_output r1   k/home/ubuntu/veenaModal/venv/lib/python3.10/site-packages/transformers/models/patchtst/modeling_patchtst.pyeager_attention_forward&   s   r3   c                       s   e Zd ZdZ					ddededed	ed
ededee f fddZ					dde
jdee
j dee
j dee
j dee dee dee
jee
j eee
j  f fddZ  ZS )PatchTSTAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN	embed_dim	num_headsr   
is_decoderbias	is_causalconfigc                    s   t    || _|| _|| _|| | _|| _| j| | jkr*td| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: ).r    r8   )super__init__r5   r6   r   head_dimr:   
ValueErrorr   r7   r9   r   Lineark_projv_projq_projout_proj)selfr5   r6   r   r7   r8   r9   r:   	__class__r1   r2   r>   H   s&   



zPatchTSTAttention.__init__hidden_stateskey_value_statesr   layer_head_maskoutput_attentionsr.   returnc                 K   s  |du}|j dd \}}	|r|j d n|	}
||	d| jf}||
d| jf}| |j| dd}|r4|n|}| |j| dd}| |j| dd}t}| jj	dkr\t
| jj	 }|| ||||f| jshdn| j| j||d|\}}|||	d }| |}||dfS )z#Input shape: Batch x Time x ChannelNr   r   r!   eagerr   )r   r   rL   r   )shaper?   rD   r,   r)   rB   rC   r3   r:   _attn_implementationr   r%   r   r   reshaper-   rE   )rF   rI   rJ   r   rK   rL   r.   is_cross_attentionbsztgt_lensrc_lenq_input_shapekv_input_shapequery_statescurrent_states
key_statesvalue_statesattention_interfacer0   r/   r1   r1   r2   forwardg   s:   



zPatchTSTAttention.forward)r   FTFN)NNNF)__name__
__module____qualname____doc__intfloatboolr   r   r>   r'   Tensorr   r	   tupler]   __classcell__r1   r1   rG   r2   r4   E   sR    "	
r4   c                       6   e Zd ZdZdef fddZdejfddZ  Z	S )PatchTSTBatchNormzP
    Compute batch normalization over the sequence length (time) dimension.
    r:   c                    s"   t    tj|j|jd| _d S )Neps)r=   r>   r   BatchNorm1dd_modelnorm_eps	batchnormrF   r:   rG   r1   r2   r>      s   
zPatchTSTBatchNorm.__init__inputsc                 C   s"   | dd}| |}| ddS )a  
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        r   r!   )r)   ro   )rF   rq   outputr1   r1   r2   r]      s   
zPatchTSTBatchNorm.forward
r^   r_   r`   ra   r   r>   r'   re   r]   rg   r1   r1   rG   r2   ri      s    ri   Frq   
mask_ratiounmasked_channel_indiceschannel_consistent_masking
mask_valuec                 C   s*  |dk s|dkrt d| d| j\}}}}| j}	t|d|  }
|r5tj|d||	d}|d|d}n	tj||||	d}tj||||	d}d|ddddd|
f< tj|dd}tj|dd}tj	|d|d	}|
dddd|}|durd|dd|ddddf< | | |}||d
 fS )a  random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
            across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
        n]
    r   r   zMask ratio z has to be between 0 and 1.deviceNr   r"   )r#   index.r   )r@   rO   ry   rb   r'   randrepeatonesargsortgather	unsqueezemasked_fillrd   )rq   rt   ru   rv   rw   
batch_sizenum_channelssequence_lengthnum_featuresry   len_keepnoisemaskids_shuffleids_restoreinputs_maskr1   r1   r2   random_masking   s&   r   num_forecast_mask_patchesc                 C   s  t |tr|g}dd |D }| j\}}}}tj|||| jd}	g }
d}t|}t||D ](\}}|dks9||krAtd| dt|| | }|
	|||g ||7 }q-t
|
dd d	}
||k rq|
d d
 ||  |
d d
< n||kr|
d d
 ||  |
d d
< d}|
D ]\}}}|| }d|	||dd| df< |}qt|	jd }|	| }	|	dddd|}	|durd|	dd|ddddf< | |	 |}||	d fS )a  Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
        num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
    c                 S   s   g | ]}d qS )r   r1   .0_r1   r1   r2   
<listcomp>  s    z$forecast_masking.<locals>.<listcomp>rx   r   znum_forecast_mask_patches z6 should be greater than 0 and less than total patches.c                 S   s   | d S )Nr!   r1   )xr1   r1   r2   <lambda>  s    z"forecast_masking.<locals>.<lambda>)r   r!   r   r   Nr{   )
isinstancerb   rO   r'   zerosry   sumzipr@   appendsortedrandpermr   r}   r   rd   )rq   r   ru   rw   forecast_mask_ratiosr   r   r   r   r   t_listtotal_lengthtotal_ratiopatch_lengthratiotemp_lenbatch1	patch_lenr   batch2permr   r1   r1   r2   forecast_masking   sB   


r   c                       rh   )PatchTSTPatchifyz
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    r:   c                    s   t    |j| _|j| _|j| _| j| jkr$td| j d| j dt| j| j| j | j d | _| j| j| jd   }| j| | _	d S )NzSequence length (z+) has to be greater than the patch length ()r   )
r=   r>   context_lengthr   r   patch_strider@   maxnum_patchessequence_start)rF   r:   new_sequence_lengthrG   r1   r2   r>   9  s   
 zPatchTSTPatchify.__init__past_valuesc                 C   sp   |j d }|| jkrtd| d| j d|dd| jdddf }|jd| j| jd}|dd }|S )a!  
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        zInput sequence length (z%) doesn't match model configuration (r;   N)	dimensionr&   step)	rO   r   r@   r   unfoldr   r   r)   r-   )rF   r   r   rr   r1   r1   r2   r]   J  s   
	
zPatchTSTPatchify.forwardrs   r1   r1   rG   r2   r   1  s    r   c                       rh   )PatchTSTMaskinga  
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSTConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    r:   c                    sX   t    |j| _|j| _|j| _|j| _|j| _|j| _| jd ur*t| j| _d S d S N)	r=   r>   random_mask_ratiorv   	mask_typer   ru   rw   r   rp   rG   r1   r2   r>   n  s   

zPatchTSTMasking.__init__patch_inputc                 C   sr   | j dkrt|| j| j| j| jd\}}n| j dkr(t|| j| j| jd\}}n	td| j  d|	 }||fS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        random)rq   rt   ru   rv   rw   forecast)rq   r   ru   rw   zInvalid mask type .)
r   r   r   ru   rv   rw   r   r   r@   rd   )rF   r   masked_inputr   r1   r1   r2   r]   y  s$   

zPatchTSTMasking.forwardrs   r1   r1   rG   r2   r   a  s    r   c                       s@   e Zd ZdZdef fddZd
dejdee	 fdd	Z
  ZS )PatchTSTEncoderLayerz 
    PatchTST encoder layer
    r:   c              
      s  t    |j| _t|j|j|j|d| _|jdkr t	
|jnt	 | _|jdkr0t|| _n|jdkr@t	j|j|jd| _nt|j d| jr~|jdkrVt	
|jnt	 | _|jdkrft|| _n|jdkrvt	j|j|jd| _nt|j dt	t	j|j|j|jdt|j  |jdkrt	
|jnt	 t	j|j|j|jd| _|jdkrt	
|jnt	 | _|jdkrt|| _n|jdkrt	j|j|jd| _nt|j d|j| _d S )N)r5   r6   r   r:   r   ro   	layernormrj   z$ is not a supported norm layer type.r<   ) r=   r>   channel_attentionr4   rm   num_attention_headsattention_dropout	self_attnpath_dropoutr   DropoutIdentitydropout_path1	norm_typeri   norm_sublayer1	LayerNormrn   r@   dropout_path2norm_sublayer2
SequentialrA   ffn_dimr8   r   activation_function
ff_dropoutffdropout_path3norm_sublayer3pre_normrp   rG   r1   r2   r>     sD   
 

 


 

zPatchTSTEncoderLayer.__init__Nhidden_staterL   c                 C   s  |j \}}}}||| ||}| jr(| j| ||d\}}}	|| | }n| j||d\}}}	| || | }|||||}| jr|dd	 }||| ||}| jrp| j| 
||d\}}
}	|| | }n| j||d\}}
}	| 
|| | }|||||}|dd	 }||| ||}| jr|| | | | }n| || | | }|||||}|f}|r|| jr||
fn|f7 }|S )a  
        Parameters:
            hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*):
                Past values of the time series
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
        Return:
            `torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`

        )rI   rL   r!   r   )rO   r,   r   r   r   r   rQ   r   r)   r-   r   r   r   r   r   )rF   r   rL   r   num_input_channelsr   rm   r0   r/   r   channel_attn_weightsoutputsr1   r1   r2   r]     sF   

zPatchTSTEncoderLayer.forwardr   )r^   r_   r`   ra   r   r>   r'   re   r   rd   r]   rg   r1   r1   rG   r2   r     s    "2r   c                   @   s<   e Zd ZU eed< dZdZdZdej	fddZ
ddd	Zd
S )PatchTSTPreTrainedModelr:   modelr   Fr   c                 C   s   t |tr3t| jj| jj| jj | jj d }| jjr)tj	j
|jdd |d7 }|| j||_dS t |tjrH|jj  |jjd dS t |tr^|jjj  |jjjd dS t |tjr||jjj
d| jjd |jdur~|jj  dS dS dS )z$
        Initialize weights
        r   g{Gz?)std      ?r   )meanr   N)r   PatchTSTPositionalEncodingr   r:   r   r   r   use_cls_tokenr   initnormal_	cls_token_init_peposition_encr   r8   datazero_weightfill_ri   ro   rA   init_std)rF   r   r   r1   r1   r2   _init_weights/  s,   


z%PatchTSTPreTrainedModel._init_weightsc                 C   s   t |tr
||_d S d S r   )r   PatchTSTEncodergradient_checkpointing)rF   r   r   r1   r1   r2   _set_gradient_checkpointingI  s   

z3PatchTSTPreTrainedModel._set_gradient_checkpointingN)F)r^   r_   r`   r   __annotations__base_model_prefixmain_input_namesupports_gradient_checkpointingr   Moduler   r   r1   r1   r1   r2   r   (  s   
 r   c                       2   e Zd Zdef fddZdejfddZ  ZS )PatchTSTEmbeddingr:   c                    sl   t    |j| _|j| _| jrt|j|j| _d S t	 | _t
|jD ]}| jt|j|j q%d S r   )r=   r>   r   share_embeddingr   rA   r   rm   input_embedding
ModuleListranger   )rF   r:   r   rG   r1   r2   r>   O  s   

zPatchTSTEmbedding.__init__r   c                    sj    j d }|jkrtdj d| djr  }|S  fddt|D }tj|dd}|S )a%  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input for embedding
        return:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        r   z&The defined number of input channels (zQ) in the config has to be the same as the number of channels in the batch input (r   c              	      s2   g | ]}j |  d d |d d d d f qS r   )r   r   ir   rF   r1   r2   r   m  s   2 z-PatchTSTEmbedding.forward.<locals>.<listcomp>r"   )rO   r   r@   r   r   r   r'   stack)rF   r   r   
embeddingsr1   r   r2   r]   [  s   
	


zPatchTSTEmbedding.forward	r^   r_   r`   r   r>   r'   re   r]   rg   r1   r1   rG   r2   r   N  s    r   c                       sV   e Zd ZdZdedef fddZedededej	fddZ
d	ejfd
dZ  ZS )r   z'
    Class for positional encoding
    r:   r   c                    sz   t    |j| _|j| _|jr!ttddd|j| _	|d7 }| 
||| _|jdkr6t|j| _d S t | _d S )Nr   r   )r=   r>   r   r   r   	Parameterr'   r   rm   r   r   r   positional_dropoutr   r   rF   r:   r   rG   r1   r2   r>   w  s   
z#PatchTSTPositionalEncoding.__init__rM   c                 C   s   | j dkrtjt|| jdd}|S | j dkrst|| j}td|d}t	td| jdt
d| j   }t|| |d d dd df< t|| |d d dd df< ||  }|| d	  }tj|d
d}|S t| j  d)Nr   Trequires_gradsincosr   r   r!   g     @
   FzN is not a valid positional encoder. Available types are 'random' and 'sincos'.)positional_encoding_typer   r   r'   randnrm   r   aranger   expmathlogsincosr   r   r@   )r:   r   r   positiondiv_termr1   r1   r2   r     s    

(  
z#PatchTSTPositionalEncoding._init_per   c                 C   s   | j r8| || jdd d d f  }| j| jd dd d f  }||jd | jdd}tj||fdd}|S | || j }|S )Nr   r   r   r!   r"   )	r   r  r   r   expandrO   r   r'   cat)rF   r   r   
cls_tokensr   r1   r1   r2   r]     s    z"PatchTSTPositionalEncoding.forward)r^   r_   r`   ra   r   rb   r>   staticmethodr   r   r   r'   re   r]   rg   r1   r1   rG   r2   r   r  s    r   c                	       sT   e Zd ZdZdedef fddZ		ddejde	e
 d	e	e
 d
efddZ  ZS )r   z
    PatchTST Encoder
    r:   r   c                    sT   t    d| _t | _t || _t fddt	 j
D | _|   d S )NFc                    s   g | ]}t  qS r1   )r   r   r:   r1   r2   r         z,PatchTSTEncoder.__init__.<locals>.<listcomp>)r=   r>   r   r   embedderr   positional_encoderr   r   r   num_hidden_layerslayers	post_initr  rG   r  r2   r>     s   
 zPatchTSTEncoder.__init__Nr   output_hidden_statesrL   rM   c           	      C   s   |dur|n| j j}|dur|n| j j}| |}| |}|r"dnd}|r(dnd}| jD ]}|r6||f }|||d}|d }|rI||d f }q-t|||dS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Past values of the time series
            output_hidden_states (bool, optional): Indicates if hidden states should be outputted.
            output_attentions (bool, optional): Indicates if attentions should be outputted.

        return:
            `BaseModelOutput`
        Nr1   )r   rL   r   r   )last_hidden_staterI   
attentions)r:   rL   r  r  r  r  r
   )	rF   r   r  rL   r   encoder_statesall_attentionsencoder_layerlayer_outputsr1   r1   r2   r]     s    



zPatchTSTEncoder.forwardNN)r^   r_   r`   ra   r   rb   r>   r'   re   r   rd   r
   r]   rg   r1   r1   rG   r2   r     s    r   zG
    Base class for model's outputs, with potential hidden states.
    )custom_introc                   @   s   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZee
ej  ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed	< dS )
PatchTSTModelOutputa>  
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of
        the model at the output of each layer plus the optional initial embedding outputs.
    mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
        Bool masked tensor indicating which patches are masked
    loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Patched input to the Transformer
    Nr  rI   r  r   locscaler   )r^   r_   r`   ra   r  r   r'   FloatTensorr   rI   rf   r  r   r&  r'  r   r1   r1   r1   r2   r%    s   
 r%  z4
    Output type of [`PatchTSTForPretraining`].
    c                   @   b   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )PatchTSTForPretrainingOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    prediction_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction outputs of the time series modeling heads.
    Nlossprediction_outputrI   r  )r^   r_   r`   ra   r+  r   r'   r(  r   r,  rI   rf   r  r1   r1   r1   r2   r*  	     
 r*  z3
    Output type of [`PatchTSTForRegression`].
    c                   @   r)  )PatchTSTForRegressionOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Regression outputs of the time series modeling heads.
    Nr+  regression_outputsrI   r  )r^   r_   r`   ra   r+  r   r'   r(  r   r/  rI   rf   r  r1   r1   r1   r2   r.    r-  r.  z3
    Output type of [`PatchTSTForPrediction`].
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dZeej ed< dZeej ed< dS )	PatchTSTForPredictionOutputa!  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, -1)`):
        Prediction outputs of the time series modeling heads.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
        Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
        Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    Nr+  prediction_outputsrI   r  r&  r'  )r^   r_   r`   ra   r+  r   r'   r(  r   r1  rI   rf   r  r&  r'  r1   r1   r1   r2   r0  1  s   
 r0  z7
    Output type of [`PatchTSTForClassification`].
    c                   @   r)  )PatchTSTForClassificationOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Prediction scores of the PatchTST modeling head (scores before SoftMax).
    Nr+  prediction_logitsrI   r  )r^   r_   r`   ra   r+  r   r'   r(  r   r3  rI   rf   r  r1   r1   r1   r2   r2  Q  s   
 r2  z
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.
    c                   @   s$   e Zd ZU dZdZeej ed< dS )SamplePatchTSTOutputz
    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`):
        Sampled values from the chosen distribution.
    N	sequences)	r^   r_   r`   ra   r5  r   r'   r(  r   r1   r1   r1   r2   r4  f  s   
 r4  inputtargetrM   c                 C   s   |  | S )zc
    Computes the negative log likelihood loss from input distribution with respect to target.
    )log_prob)r6  r7  r1   r1   r2   nllw  s   r9  input_tensorweightsc                 C   sr   |dur3t |dk| | t | }t j|r|j|dn| dd}|r-|j|d| S | | S | j|dS )aj  
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    Nr   r"   r   min)r'   where
zeros_likeclampr   r   )r:  r;  r#   weighted_tensorsum_weightsr1   r1   r2   weighted_average  s
   " rC  c                	       P   e Zd ZdZdef fddZdejdejdeejejejf fdd	Z	  Z
S )
PatchTSTStdScalerz
    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
    subtracting from the mean and dividing by the standard deviation.
    r:   c                    sV   t    t|dr|jnd| _t|dr|jnd| _t|dr&|j| _d S d| _d S )Nscaling_dimr   keepdimTminimum_scalegh㈵>)r=   r>   hasattrrF  r#   rG  rH  rp   rG   r1   r2   r>     s   
 zPatchTSTStdScaler.__init__r   observed_indicatorrM   c                 C   sz   |j | j| jd}|d}|| j | j| jd| }|| | d j | j| jd| }t|| j }|| | ||fS )C  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        rG  r   r!   )r   r#   rG  	clamp_minr'   sqrtrH  )rF   r   rJ  denominatorr&  variancer'  r1   r1   r2   r]     s   
"zPatchTSTStdScaler.forwardr^   r_   r`   ra   r   r>   r'   re   rf   r]   rg   r1   r1   rG   r2   rE    s    rE  c                	       rD  )
PatchTSTMeanScalerz
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    r:   c                    sl   t    t|dr|jnd| _t|dr|jnd| _t|dr#|jnd| _t|dr1|j| _d S d | _d S )NrF  r   rG  TrH  绽|=default_scale)r=   r>   rI  rF  r#   rG  rH  rT  rp   rG   r1   r2   r>     s
   
 zPatchTSTMeanScaler.__init__r   rJ  rM   c           
      C   s   ||   j| jdd}|j| jdd}|tj|dd }| jdu r:|jdd}tj|ddd}t|| }n| jt| }t|dk||}tj|| j	d}|| }	| j
sa|j| jd}|	t||fS )rK  TrL  r   r<  Nr   r"   )absr   r#   r'   r@  rT  squeeze	ones_liker>  rH  rG  r?  )
rF   r   rJ  ts_sumnum_observedr'  	batch_sumbatch_observationsrT  scaled_datar1   r1   r2   r]     s   
zPatchTSTMeanScaler.forwardrQ  r1   r1   rG   r2   rR    s    rR  c                
       sX   e Zd ZdZdef fddZ	ddejdeej de	ejejejf fd	d
Z
  ZS )PatchTSTNOPScalerz|
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    r:   c                    s@   t    t|dr|jnd| _t|dr|j| _d S d| _d S )NrF  r   rG  T)r=   r>   rI  rF  r#   rG  rp   rG   r1   r2   r>     s   
 zPatchTSTNOPScaler.__init__Nr   rJ  rM   c                 C   sB   t j|ddj| j| jd}t j|ddj| j| jd}|||fS )a  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        Fr  )r#   rG  )r'   rW  r   r#   rG  r?  )rF   r   rJ  r'  r&  r1   r1   r2   r]     s   
zPatchTSTNOPScaler.forwardr   )r^   r_   r`   ra   r   r>   r'   re   r   rf   r]   rg   r1   r1   rG   r2   r]    s    r]  c                	       sL   e Zd Zdef fddZdejdejdeejejejf fddZ  Z	S )	PatchTSTScalerr:   c                    sR   t    |jdks|jdu rt|| _d S |jdkr"t|| _d S t|| _d S )Nr   Tr   )r=   r>   r   rR  scalerrE  r]  rp   rG   r1   r2   r>     s   

zPatchTSTScaler.__init__r   rJ  rM   c                 C   s   |  ||\}}}|||fS )a>  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Input for scaler calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, um_input_channels)`)
        )r_  )rF   r   rJ  r&  r'  r1   r1   r2   r]     s   
zPatchTSTScaler.forward)
r^   r_   r`   r   r>   r'   re   rf   r]   rg   r1   r1   rG   r2   r^    s    	r^  c                       sv   e Zd Zdef fddZ					ddejdeej deej dee d	ee d
ee de	e
ef fddZ  ZS )PatchTSTModelr:   c                    sf   t  | t|| _t|| _|j| _| jj}| jr!t|| _	nt
 | _	t||d| _|   d S )N)r   )r=   r>   r^  r_  r   
patchifierdo_mask_inputr   r   maskingr   r   r   encoderr  r  rG   r1   r2   r>   ,  s   


zPatchTSTModel.__init__Nr   past_observed_maskfuture_valuesr  rL   return_dictrM   c              	   C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}|du r't|}| ||\}}}	| |}
| jr@| 	|
\}}n| 	|
d}}| j
|||d}|sk|j|j|jf}||||	|
f }tdd |D S t|j|j|j|||	|
dS )a  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            future_values (`torch.BoolTensor` of shape `(batch_size, prediction_length, num_input_channels)`, *optional*):
                Future target values associated with the `past_values`
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*):
                Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTModel

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> model = PatchTSTModel.from_pretrained("namctin/patchtst_etth1_pretrain")

        >>> # during training, one provides both past and future values
        >>> outputs = model(
        ...     past_values=batch["past_values"],
        ...     future_values=batch["future_values"],
        ... )

        >>> last_hidden_state = outputs.last_hidden_state
        ```N)r   r  rL   c                 s   s    | ]	}|d ur|V  qd S r   r1   )r   vr1   r1   r2   	<genexpr>      z(PatchTSTModel.forward.<locals>.<genexpr>)r  rI   r  r   r&  r'  r   )r:   use_return_dictrL   r  r'   rW  r_  ra  rb  rc  rd  r  rI   r  rf   r%  )rF   r   re  rf  r  rL   rg  scaled_past_valuesr&  r'  patched_valuesmasked_valuesr   encoder_outputr   r1   r1   r2   r]   >  s6   6

zPatchTSTModel.forwardNNNNN)r^   r_   r`   r   r>   r'   re   r   rd   r   rf   r%  r]   rg   r1   r1   rG   r2   r`  *  s,    
r`  c                       s<   e Zd ZdZdef fddZdejdejfddZ  Z	S )	PatchTSTMaskPretrainHeadz-
    Pretraining head for mask modelling
    r:   c                    sH   t    |jdkrt|jnt | _t|j|j	| _
|j| _d S Nr   )r=   r>   head_dropoutr   r   r   r   rA   rm   r   linearr   rp   rG   r1   r2   r>     s   
 z!PatchTSTMaskPretrainHead.__init__	embeddingrM   c                 C   s:   |  | |}| jr|ddddddddf }|S )a  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                    `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                            `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True

        Nr   )rt  r   r   )rF   ru  r1   r1   r2   r]     s    z PatchTSTMaskPretrainHead.forwardrs   r1   r1   rG   r2   rq    s    rq  z*
    The PatchTST for pretrain model.
    c                       sj   e Zd Zdef fddZ				ddejdeej dee dee d	ee d
e	e
ef fddZ  ZS )PatchTSTForPretrainingr:   c                    s4   t  | d|_t|d| _t|| _|   d S )NTr  )r=   r>   rb  r`  r   rq  headr  rp   rG   r1   r2   r>     s
   
zPatchTSTForPretraining.__init__Nr   re  r  rL   rg  rM   c                 C   s   |dur|n| j j}| j||||dd}| |j}tjdd}|||j}	|	jdd|j	 
 |j	
 d  }
|j}|sU|f|d	d
  }|
durQ|
f| }|S |}|S t|
|||jdS )a	  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
            `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForPretraining

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> # Config for random mask pretraining
        >>> config = PatchTSTConfig(
        ...     num_input_channels=7,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     mask_type='random',
        ...     random_mask_ratio=0.4,
        ...     use_cls_token=True,
        ... )
        >>> # Config for forecast mask pretraining
        >>> config = PatchTSTConfig(
        ...     num_input_channels=7,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     mask_type='forecast',
        ...     num_forecast_mask_patches=5,
        ...     use_cls_token=True,
        ... )
        >>> model = PatchTSTForPretraining(config)

        >>> # during training, one provides both past and future values
        >>> outputs = model(past_values=batch["past_values"])

        >>> loss = outputs.loss
        >>> loss.backward()
        ```NTr   re  r  rL   rg  none	reductionr   r"   rS  r   )r+  r,  rI   r  )r:   rk  r   rw  r  r   MSELossr   r   r   r   rI   r*  r  )rF   r   re  r  rL   rg  model_outputx_hatr+  loss_valmasked_lossr  r   r1   r1   r2   r]     s,   E
$
zPatchTSTForPretraining.forward)NNNN)r^   r_   r`   r   r>   r'   re   r   rd   r   rf   r*  r]   rg   r1   r1   rG   r2   rv    s&    
rv  c                       r   )PatchTSTClassificationHeadr:   c                    sd   t    |j| _|j| _tjdd| _|jdkrt|jnt	 | _
t|j|j |j| _d S Nr   	start_dimr   )r=   r>   r   pooling_typer   Flattenflattenrs  r   r   r   rA   r   rm   num_targetsrt  rp   rG   r1   r2   r>   ,  s   
 z#PatchTSTClassificationHead.__init__ru  c                 C   s   | j r|dddddddf }n"| jdkr|jdd}n| jdkr+|jddj}n	td| j d| |}| | |}|S )	a[  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                     `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, num_targets)`

        Nr   r   r!   r"   r   pooling operator  is not implemented yet)	r   r  r   r   valuesr@   r  rt  r   rF   ru  pooled_embeddingrr   r1   r1   r2   r]   4  s   



z"PatchTSTClassificationHead.forwardr   r1   r1   rG   r2   r  +  s    r  z0
    The PatchTST for classification model.
    c                       sx   e Zd Zdef fddZe					ddejdeej dee	 dee	 d	ee	 d
ee	 de
eef fddZ  ZS )PatchTSTForClassificationr:   c                    sB   t  | |jrtd d|_t|| _t|| _| 	  d S )N+Setting `do_mask_input` parameter to False.F)
r=   r>   rb  loggerwarningr`  r   r  rw  r  rp   rG   r1   r2   r>   V  s   


z"PatchTSTForClassification.__init__Nr   target_valuesre  r  rL   rg  rM   c                 C   s   |dur|n| j j}| j||||dd}| |j}d}	|dur)t }
|
||}	|sC|f|dd  }|	dur?|	f| }|S |}|S t|	||j|j	dS )ac  
        past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
            Input sequence to the model
        target_values (`torch.Tensor`, *optional*):
            Labels associates with the `past_values`
        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:

            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Examples:

        ```python
        >>> from transformers import PatchTSTConfig, PatchTSTForClassification

        >>> # classification task with two input channel2 and 3 classes
        >>> config = PatchTSTConfig(
        ...     num_input_channels=2,
        ...     num_targets=3,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     use_cls_token=True,
        ... )
        >>> model = PatchTSTForClassification(config=config)

        >>> # during inference, one only provides past values
        >>> past_values = torch.randn(20, 512, 2)
        >>> outputs = model(past_values=past_values)
        >>> labels = outputs.prediction_logits
        ```NTrx  r   r   )r+  r3  rI   r  )
r:   rk  r   rw  r  r   CrossEntropyLossr2  rI   r  )rF   r   r  re  r  rL   rg  r~  y_hatr  r+  r   r1   r1   r2   r]   d  s2   ,
z!PatchTSTForClassification.forwardrp  )r^   r_   r`   r   r>   r   r'   re   r   rd   r   rf   r2  r]   rg   r1   r1   rG   r2   r  P  s.    
r  z,
    The PatchTST for regression Model.
    c                       s8   e Zd Zd	dedef fddZdejfddZ  Z	S )
PatchTSTPredictionHeadNr:   r   c                    sD  t    |j| _|j| _|j| _|j| _| js| jr|j}n|j| }| jsvt | _	t | _
t | _t| jD ]8}| jtjdd |du rW| j	t||j n	| j	|| | j
|jdkrnt|jnt  q;dS tjdd| _|du rt||j| _n||| _|jdkrt|jnt | _dS )a  
        num_patches (`int`):
            The number of patches in the input sequence.
        distribution_output (`DistributionOutput`, *optional*):
            The distribution output layer for probabilistic forecasting. If None, a linear output layer is used.
        r!   r  Nr   )r=   r>   share_projectionr   r   r  rm   r   r   projectionsdropoutsflattensr   r   r  rA   prediction_lengthget_parameter_projectionrs  r   r   r  
projectionr   )rF   r:   r   distribution_outputr?   r   rG   r1   r2   r>     s0   




($zPatchTSTPredictionHead.__init__ru  c                 C   s  | j r|dddddddf }n| jdkr|jdd}n| jdkr+|jddj}n|}| jseg }t| jD ]%}| j| |dd|ddf }| j	| |}| j
| |}|| q7tj|dd}n| |}| |}| |}t|trtdd	 |D }|S |dd}|S )
aj  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                     `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, forecast_len, num_channels)`

        Nr   r   r!   r"   r   r   c                 s   s    | ]	}| d dV  qdS )r!   r   N)r)   )r   zr1   r1   r2   ri    rj  z1PatchTSTPredictionHead.forward.<locals>.<genexpr>)r   r  r   r   r  r  r   r   r  r  r  r   r'   r   r  r   r  r   rf   r)   )rF   ru  r  rr   r   r1   r1   r2   r]     s.   


 



zPatchTSTPredictionHead.forwardr   )
r^   r_   r`   r   rb   r>   r'   re   r]   rg   r1   r1   rG   r2   r    s    +r  z,
    The PatchTST for prediction model.
    c                       s   e Zd Zdef fddZ					ddejdeej deej dee d	ee d
ee de	e
ef fddZe 	ddejdeej defddZ  ZS )PatchTSTForPredictionr:   c                    s   t  | |jrtd d|_t|| _|jdkrd | _n/|jdkr,t	|j
d| _n"|jdkr9t|j
d| _n|jdkrFt|j
d| _ntd|j t|| jjj| jd	| _|   d S )
Nr  Fmse	student_tr"   normalnegative_binomialUnknown distribution output )r  )r=   r>   rb  r  r  r`  r   r+  r  r   r  r   r   r@   r  ra  r   rw  r  rp   rG   r1   r2   r>     s$   





zPatchTSTForPrediction.__init__Nr   re  rf  r  rL   rg  rM   c                 C   s   |dur|n| j j}| j||||dd}| |j}d}	| jr"|}
n||j |j }
|durQ| jrF| jj||j|jd}t	||}	t
|	}	ntjdd}||
|}	|j}|j}|sq|
f|dd  }|	durm|	f| }|S |}|S t|	|
|j|j||d	S )
aV	  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*):
                Future target values associated with the `past_values`
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*):
                Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
            `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForPrediction

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> # Prediction task with 7 input channels and prediction length is 96
        >>> model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast")

        >>> # during training, one provides both past and future values
        >>> outputs = model(
        ...     past_values=batch["past_values"],
        ...     future_values=batch["future_values"],
        ... )

        >>> loss = outputs.loss
        >>> loss.backward()

        >>> # during inference, one only provides past values, the model outputs future values
        >>> outputs = model(past_values=batch["past_values"])
        >>> prediction_outputs = outputs.prediction_outputs
        ```NTrx  r&  r'  r   rz  r   r   )r+  r1  rI   r  r&  r'  )r:   rk  r   rw  r  r  r'  r&  distributionr9  rC  r   r}  r0  rI   r  )rF   r   re  rf  r  rL   rg  r~  r  r  	y_hat_outr  r+  r&  r'  r   r1   r1   r2   r]   4  sL   =



zPatchTSTForPrediction.forwardc                    sr   | j j}| |d|dd}| jr.| jj|j|j|jd  fddt|D }tj	|dd}n|j
d}t|d	S )
a   
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)`
            for multivariate predictions.
        NF)r   rf  re  r  r  c                       g | ]}   qS r1   sampler   r  r1   r2   r     r  z2PatchTSTForPrediction.generate.<locals>.<listcomp>r   r"   r5  )r:   num_parallel_samplesr  r  r1  r&  r'  r   r'   r   r   r4  rF   r   re  r  r   samplesr1   r  r2   generate  s   
zPatchTSTForPrediction.generaterp  r   )r^   r_   r`   r   r>   r'   re   r   rd   r   rf   r0  r]   no_gradr4  r  rg   r1   r1   rG   r2   r    s>     

mr  c                       s8   e Zd ZdZd	def fddZdejfddZ  Z	S )
PatchTSTRegressionHeadz
    Regression head
    Nr:   c                    s   t    |j| _|j| _|j| _|| _|j|j }t	j
dd| _|jdkr,t	|jnt	 | _|d u r?t	||j| _d S ||| _d S r  )r=   r>   output_rangey_ranger   r  r  r   rm   r   r  r  rs  r   r   r   rA   r  r  r  )rF   r:   r  r?   rG   r1   r2   r>     s   
 zPatchTSTRegressionHead.__init__ru  c                 C   s   | j r|dddddddf }n"| jdkr|jdd}n| jdkr+|jddj}n	td| j d| | |}| |}| j	du | j
du@ r_t|| j
d	 | j
d   | j
d  }|S )
aY  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                    `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, output_dim)`

        Nr   r   r!   r"   r   r  r  r   )r   r  r   r   r  r@   r   r  r  r  r  r'   sigmoidr  r1   r1   r2   r]     s   



(zPatchTSTRegressionHead.forwardr   rs   r1   r1   rG   r2   r    s    r  z,
    The PatchTST for regression model.
    c                       s   e Zd Zdef fddZe					ddejdeej deej dee	 d	ee	 d
ee	 de
eef fddZe 	ddejdeej defddZ  ZS )PatchTSTForRegressionr:   c                    s   t  | |jrtd d|_t|| _|jdkrd | _n/|jdkr,t	|j
d| _n"|jdkr9t|j
d| _n|jdkrFt|j
d| _ntd|j t|| j| _|   d S )	Nr  Fr  r  r"   r  r  r  )r=   r>   rb  r  r  r`  r   r+  r  r   r  r   r   r@   r  rw  r  rp   rG   r1   r2   r>     s    





zPatchTSTForRegression.__init__Nr   r  re  r  rL   rg  rM   c                    s   |dur|n j j} j||||dd} |j}d}	|durI jr> j|}
t fdd|D }t|
|}	t	|	}	nt
jdd}	|	||}	|sc|f|dd	  }|	dur_|	f| }|S |}|S t|	||j|jd
S )a#  
        past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
            Input sequence to the model
        target_values (`torch.Tensor` of shape `(bs, num_input_channels)`):
            Target values associates with the `past_values`
        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:

            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            Whether or not to return a `ModelOutput` instead of a plain tuple.

        Examples:

        ```python
        >>> from transformers import PatchTSTConfig, PatchTSTForRegression

        >>> # Regression task with 6 input channels and regress 2 targets
        >>> model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression")

        >>> # during inference, one only provides past values, the model outputs future values
        >>> past_values = torch.randn(20, 512, 6)
        >>> outputs = model(past_values=past_values)
        >>> regression_outputs = outputs.regression_outputs
        ```NTrx  c                 3   s     | ]}| d  jjV  qdS )r   N)r,   r:   r  )r   itemrF   r1   r2   ri  _  s    z0PatchTSTForRegression.forward.<locals>.<genexpr>r   rz  r   r   )r+  r/  rI   r  )r:   rk  r   rw  r  r  r  rf   r9  rC  r   r}  r.  rI   r  )rF   r   r  re  r  rL   rg  r~  r  r+  r  r   r1   r  r2   r]   )  s<   %


zPatchTSTForRegression.forwardc                    sb   | j j}| |d|dd}| j|j  fddt|D }tj|ddd|| j j	}t
|d	S )
a  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, num_targets)`.
        NF)r   r  re  r  c                    r  r1   r  r   r  r1   r2   r     r  z2PatchTSTForRegression.generate.<locals>.<listcomp>r   r"   r   r  )r:   r  r  r  r/  r   r'   r   r,   r  r4  r  r1   r  r2   r  s  s   
zPatchTSTForRegression.generaterp  r   )r^   r_   r`   r   r>   r   r'   re   r   rd   r   rf   r.  r]   r  r4  r  rg   r1   r1   rG   r2   r  	  s@    
Ir  )r`  r   r  rv  r  r  )Nr   N)NFr   rr  r#  )Lra   r  dataclassesr   typingr   r   r   r'   r   activationsr   modeling_flash_attention_utilsr	   modeling_outputsr
   modeling_utilsr   r   processing_utilsr   time_series_utilsr   r   r   utilsr   r   r   configuration_patchtstr   
get_loggerr^   r  r   re   rc   r3   r4   ri   listrd   rb   r   r   r   r   r   r   r   r   r   r%  r*  r.  r0  r2  r4  distributionsDistributionr9  rC  rE  rR  r]  r^  r`  rq  rv  r  r  r  r  r  r  __all__r1   r1   r1   r2   <module>   s  


X
=

D0< %$8>
"$7po%W` =7 