o
    wiQ                    @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZ dd
lmZmZ ddlmZ ddlmZmZmZ ddlmZmZmZ ddlmZ ee Z!			dtde	j"dej#dej#dej#deej# dee$ de$deej# fddZ%G dd de	j"Z&G dd de	j"Z'			dud ej#d!e$d"ee( d#e)d$e*f
d%d&Z+		dvd ej#d'ee(e*f d"ee( d$e*fd(d)Z,G d*d+ d+e	j"Z-G d,d- d-e	j"Z.G d.d/ d/e	j"Z/eG d0d1 d1eZ0G d2d3 d3e	j"Z1G d4d5 d5e	j"Z2G d6d7 d7e0Z3eed8d9G d:d; d;eZ4eed<d9G d=d> d>eZ5eed?d9G d@dA dAeZ6eedBd9G dCdD dDeZ7eedEd9G dFdG dGeZ8eedHd9G dIdJ dJeZ9dKej:j;dLej#dMej#fdNdOZ<dwdPej#dQeej# dMej#fdRdSZ=G dTdU dUe	j"Z>G dVdW dWe	j"Z?G dXdY dYe	j"Z@G dZd[ d[e	j"ZAeG d\d] d]e0ZBG d^d_ d_e	j"ZCed`d9G dadb dbe0ZDG dcdd dde	j"ZEeded9G dfdg dge0ZFedhd9G didj dje	j"ZGedkd9G dldm dme0ZHG dndo doe	j"ZIedpd9G dqdr dre0ZJg dsZKdS )xzPyTorch PatchTST model.    N)	dataclass)CallableOptionalUnion)nn   )ACT2CLS)FlashAttentionKwargs)BaseModelOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)NegativeBinomialOutputNormalOutputStudentTOutput)ModelOutputauto_docstringlogging   )PatchTSTConfig        modulequerykeyvalueattention_maskscalingdropout	head_maskc                 K   s   |d u r| dd }t||dd| }	|d ur|	| }	tjj|	dd}	|d ur5|	|dddd }	tjj|	|| j	d}	t|	|}
|
dd
 }
|
|	fS )N         r   dimr   )ptraining)sizetorchmatmul	transposer   
functionalsoftmaxviewr   r%   
contiguous)r   r   r   r   r   r   r   r   kwargsattn_weightsattn_output r1   k/home/ubuntu/sommelier/.venv/lib/python3.10/site-packages/transformers/models/patchtst/modeling_patchtst.pyeager_attention_forward&   s   r3   c                       s   e Zd ZdZ					ddededed	ed
ededee f fddZ						dde
jdee
j deee
j  dee
j dee
j dedee dee
jee
j eee
j  f fddZ  ZS )PatchTSTAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN	embed_dim	num_headsr   
is_decoderbias	is_causalconfigc                    s   t    || _|| _|| _|| | _|| _| j| | jkr*td| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: ).r    r8   )super__init__r5   r6   r   head_dimr:   
ValueErrorr   r7   r9   r   Lineark_projv_projq_projout_proj)selfr5   r6   r   r7   r8   r9   r:   	__class__r1   r2   r>   H   s&   



zPatchTSTAttention.__init__hidden_stateskey_value_statespast_key_valuer   layer_head_maskoutput_attentionsr.   returnc                 K   s  |du}|j dd \}	}
|r|j d n|
}|	|
d| jf}|	|d| jf}| |j| dd}|rK|durK|d j d |j d krK|d }|d }nf|rf| |j| dd}| |j| dd}nK|dur| |j| dd}| |j| dd}tj|d |gdd}tj|d |gdd}n| |j| dd}| |j| dd}| j	r||f}t
}| jjdkrt| jj }|| ||||f| jsdn| j| j||d	|\}}||	|
d }| |}|||fS )
z#Input shape: Batch x Time x ChannelNr   r   r!   r   r"   eagerr   )r   r   rM   r   )shaper?   rD   r,   r)   rB   rC   r'   catr7   r3   r:   _attn_implementationr   r%   r   r   reshaper-   rE   )rF   rI   rJ   rK   r   rL   rM   r.   is_cross_attentionbsztgt_lensrc_lenq_input_shapekv_input_shapequery_states
key_statesvalue_statesattention_interfacer0   r/   r1   r1   r2   forwardg   sX   




zPatchTSTAttention.forward)r   FTFN)NNNNF)__name__
__module____qualname____doc__intfloatboolr   r   r>   r'   Tensortupler   r	   r^   __classcell__r1   r1   rG   r2   r4   E   sX    "
r4   c                       6   e Zd ZdZdef fddZdejfddZ  Z	S )PatchTSTBatchNormzP
    Compute batch normalization over the sequence length (time) dimension.
    r:   c                    s"   t    tj|j|jd| _d S )Neps)r=   r>   r   BatchNorm1dd_modelnorm_eps	batchnormrF   r:   rG   r1   r2   r>      s   
zPatchTSTBatchNorm.__init__inputsc                 C   s"   | dd}| |}| ddS )a  
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        r   r!   )r)   rp   )rF   rr   outputr1   r1   r2   r^      s   
zPatchTSTBatchNorm.forward
r_   r`   ra   rb   r   r>   r'   rf   r^   rh   r1   r1   rG   r2   rj      s    rj   Frr   
mask_ratiounmasked_channel_indiceschannel_consistent_masking
mask_valuec                 C   s*  |dk s|dkrt d| d| j\}}}}| j}	t|d|  }
|r5tj|d||	d}|d|d}n	tj||||	d}tj||||	d}d|ddddd|
f< tj|dd}tj|dd}tj	|d|d	}|
dddd|}|durd|dd|ddddf< | | |}||d
 fS )a  random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
            across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
        n]
    r   r   zMask ratio z has to be between 0 and 1.deviceNr   r"   )r#   index.r   )r@   rP   rz   rc   r'   randrepeatonesargsortgather	unsqueezemasked_fillre   )rr   ru   rv   rw   rx   
batch_sizenum_channelssequence_lengthnum_featuresrz   len_keepnoisemaskids_shuffleids_restoreinputs_maskr1   r1   r2   random_masking   s&   r   num_forecast_mask_patchesc                 C   s  t |tr|g}dd |D }| j\}}}}tj|||| jd}	g }
d}t|}t||D ](\}}|dks9||krAtd| dt|| | }|
	|||g ||7 }q-t
|
dd d	}
||k rq|
d d
 ||  |
d d
< n||kr|
d d
 ||  |
d d
< d}|
D ]\}}}|| }d|	||dd| df< |}qt|	jd }|	| }	|	dddd|}	|durd|	dd|ddddf< | |	 |}||	d fS )a  Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
        num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
    c                 S   s   g | ]}d qS )r   r1   .0_r1   r1   r2   
<listcomp>)  s    z$forecast_masking.<locals>.<listcomp>ry   r   znum_forecast_mask_patches z6 should be greater than 0 and less than total patches.c                 S   s   | d S )Nr!   r1   )xr1   r1   r2   <lambda>;  s    z"forecast_masking.<locals>.<lambda>)r   r!   r   r   Nr|   )
isinstancerc   rP   r'   zerosrz   sumzipr@   appendsortedrandpermr   r~   r   re   )rr   r   rv   rx   forecast_mask_ratiosr   r   r   r   r   t_listtotal_lengthtotal_ratiopatch_lengthratiotemp_lenbatch1	patch_lenr   batch2permr   r1   r1   r2   forecast_masking  sB   


r   c                       ri   )PatchTSTPatchifyz
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    r:   c                    s   t    |j| _|j| _|j| _| j| jkr$td| j d| j dt| j| j| j | j d | _| j| j| jd   }| j| | _	d S )NzSequence length (z+) has to be greater than the patch length ()r   )
r=   r>   context_lengthr   r   patch_strider@   maxnum_patchessequence_start)rF   r:   new_sequence_lengthrG   r1   r2   r>   [  s   
 zPatchTSTPatchify.__init__past_valuesc                 C   sp   |j d }|| jkrtd| d| j d|dd| jdddf }|jd| j| jd}|dd }|S )a!  
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        zInput sequence length (z%) doesn't match model configuration (r;   N)	dimensionr&   step)	rP   r   r@   r   unfoldr   r   r)   r-   )rF   r   r   rs   r1   r1   r2   r^   l  s   
	
zPatchTSTPatchify.forwardrt   r1   r1   rG   r2   r   S  s    r   c                       ri   )PatchTSTMaskinga  
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSTConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    r:   c                    sX   t    |j| _|j| _|j| _|j| _|j| _|j| _| jd ur*t| j| _d S d S N)	r=   r>   random_mask_ratiorw   	mask_typer   rv   rx   r   rq   rG   r1   r2   r>     s   

zPatchTSTMasking.__init__patch_inputc                 C   sr   | j dkrt|| j| j| j| jd\}}n| j dkr(t|| j| j| jd\}}n	td| j  d|	 }||fS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        random)rr   ru   rv   rw   rx   forecast)rr   r   rv   rx   zInvalid mask type .)
r   r   r   rv   rw   rx   r   r   r@   re   )rF   r   masked_inputr   r1   r1   r2   r^     s$   

zPatchTSTMasking.forwardrt   r1   r1   rG   r2   r     s    r   c                       s@   e Zd ZdZdef fddZd
dejdee	 fdd	Z
  ZS )PatchTSTEncoderLayerz 
    PatchTST encoder layer
    r:   c              
      s  t    |j| _t|j|j|j|d| _|jdkr t	
|jnt	 | _|jdkr0t|| _n|jdkr@t	j|j|jd| _nt|j d| jr~|jdkrVt	
|jnt	 | _|jdkrft|| _n|jdkrvt	j|j|jd| _nt|j dt	t	j|j|j|jdt|j  |jdkrt	
|jnt	 t	j|j|j|jd| _|jdkrt	
|jnt	 | _|jdkrt|| _n|jdkrt	j|j|jd| _nt|j d|j| _d S )N)r5   r6   r   r:   r   rp   	layernormrk   z$ is not a supported norm layer type.r<   ) r=   r>   channel_attentionr4   rn   num_attention_headsattention_dropout	self_attnpath_dropoutr   DropoutIdentitydropout_path1	norm_typerj   norm_sublayer1	LayerNormro   r@   dropout_path2norm_sublayer2
SequentialrA   ffn_dimr8   r   activation_function
ff_dropoutffdropout_path3norm_sublayer3pre_normrq   rG   r1   r2   r>     sD   
 

 


 

zPatchTSTEncoderLayer.__init__Nhidden_staterM   c                 C   s  |j \}}}}||| ||}| jr(| j| ||d\}}}	|| | }n| j||d\}}}	| || | }|||||}| jr|dd	 }||| ||}| jrp| j| 
||d\}}
}	|| | }n| j||d\}}
}	| 
|| | }|||||}|dd	 }||| ||}| jr|| | | | }n| || | | }|||||}|f}|r|| jr||
fn|f7 }|S )a  
        Parameters:
            hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*):
                Past values of the time series
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
        Return:
            `torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`

        )rI   rM   r!   r   )rP   r,   r   r   r   r   rS   r   r)   r-   r   r   r   r   r   )rF   r   rM   r   num_input_channelsr   rn   r0   r/   r   channel_attn_weightsoutputsr1   r1   r2   r^     sF   

zPatchTSTEncoderLayer.forwardr   )r_   r`   ra   rb   r   r>   r'   rf   r   re   r^   rh   r1   r1   rG   r2   r     s    "2r   c                   @   s.   e Zd ZeZdZdZdZdd Zd	ddZ	dS )
PatchTSTPreTrainedModelmodelr   Fc                 C   s   t |tr&| jjrtjj|jdd | jjdkr$tjj|j	ddd dS dS t |tj
r;|jj  |jjd dS t |trQ|jjj  |jjjd dS t |tjtjfrr|jjjd| jjd |jdurt|jj  dS dS dS )	z$
        Initialize weights
        g{Gz?)stdr   r   g?)meanr         ?N)r   PatchTSTPositionalEncodingr:   use_cls_tokenr   initnormal_	cls_tokenpositional_encoding_typeposition_encr   r8   datazero_weightfill_rj   rp   rA   Conv1dinit_std)rF   r   r1   r1   r2   _init_weightsQ  s$   


z%PatchTSTPreTrainedModel._init_weightsc                 C   s   t |tr
||_d S d S r   )r   PatchTSTEncodergradient_checkpointing)rF   r   r   r1   r1   r2   _set_gradient_checkpointingg  s   

z3PatchTSTPreTrainedModel._set_gradient_checkpointingN)F)
r_   r`   ra   r   config_classbase_model_prefixmain_input_namesupports_gradient_checkpointingr   r   r1   r1   r1   r2   r   J  s    r   c                       2   e Zd Zdef fddZdejfddZ  ZS )PatchTSTEmbeddingr:   c                    sl   t    |j| _|j| _| jrt|j|j| _d S t	 | _t
|jD ]}| jt|j|j q%d S r   )r=   r>   r   share_embeddingr   rA   r   rn   input_embedding
ModuleListranger   )rF   r:   r   rG   r1   r2   r>   m  s   

zPatchTSTEmbedding.__init__r   c                    sj    j d }|jkrtdj d| djr  }|S  fddt|D }tj|dd}|S )a%  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input for embedding
        return:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        r   z&The defined number of input channels (zQ) in the config has to be the same as the number of channels in the batch input (r   c              	      s2   g | ]}j |  d d |d d d d f qS r   )r   r   ir   rF   r1   r2   r     s   2 z-PatchTSTEmbedding.forward.<locals>.<listcomp>r"   )rP   r   r@   r   r   r   r'   stack)rF   r   r   
embeddingsr1   r   r2   r^   y  s   
	


zPatchTSTEmbedding.forward	r_   r`   ra   r   r>   r'   rf   r^   rh   r1   r1   rG   r2   r   l  s    r   c                       sV   e Zd ZdZdedef fddZedededej	fddZ
d	ejfd
dZ  ZS )r   z'
    Class for positional encoding
    r:   r   c                    sz   t    |j| _|j| _|jr!ttddd|j| _	|d7 }| 
||| _|jdkr6t|j| _d S t | _d S )Nr   r   )r=   r>   r   r   r   	Parameterr'   r   rn   r   _init_per   positional_dropoutr   r   rF   r:   r   rG   r1   r2   r>     s   
z#PatchTSTPositionalEncoding.__init__rN   c                 C   s   | j dkrtjt|| jdd}|S | j dkrst|| j}td|d}t	td| jdt
d| j   }t|| |d d dd df< t|| |d d dd df< ||  }|| d	  }tj|d
d}|S t| j  d)Nr   Trequires_gradsincosr   r   r!   g     @
   FzN is not a valid positional encoder. Available types are 'random' and 'sincos'.)r   r   r  r'   randnrn   r   aranger   expmathlogsincosr   r   r@   )r:   r   r   positiondiv_termr1   r1   r2   r    s    

(  
z#PatchTSTPositionalEncoding._init_per   c                 C   s   | j r8| || jdd d d f  }| j| jd dd d f  }||jd | jdd}tj||fdd}|S | || j }|S )Nr   r   r   r!   r"   )	r   r  r   r   expandrP   r   r'   rQ   )rF   r   r   
cls_tokensr   r1   r1   r2   r^     s    z"PatchTSTPositionalEncoding.forward)r_   r`   ra   rb   r   rc   r>   staticmethodr   r  r  r'   rf   r^   rh   r1   r1   rG   r2   r     s    r   c                	       sT   e Zd ZdZdedef fddZ		ddejde	e
 d	e	e
 d
efddZ  ZS )r   z
    PatchTST Encoder
    r:   r   c                    sT   t    d| _t | _t || _t fddt	 j
D | _|   d S )NFc                    s   g | ]}t  qS r1   )r   r   r:   r1   r2   r         z,PatchTSTEncoder.__init__.<locals>.<listcomp>)r=   r>   r   r   embedderr   positional_encoderr   r   r   num_hidden_layerslayers	post_initr  rG   r  r2   r>     s   
 zPatchTSTEncoder.__init__Nr   output_hidden_statesrM   rN   c           	      C   s   |dur|n| j j}|dur|n| j j}| |}| |}|r"dnd}|r(dnd}| jD ]}|r6||f }|||d}|d }|rI||d f }q-t|||dS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Past values of the time series
            output_hidden_states (bool, optional): Indicates if hidden states should be outputted.
            output_attentions (bool, optional): Indicates if attentions should be outputted.

        return:
            `BaseModelOutput`
        Nr1   )r   rM   r   r   )last_hidden_staterI   
attentions)r:   rM   r  r  r  r  r
   )	rF   r   r  rM   r   encoder_statesall_attentionsencoder_layerlayer_outputsr1   r1   r2   r^     s    



zPatchTSTEncoder.forwardNN)r_   r`   ra   rb   r   rc   r>   r'   rf   r   re   r
   r^   rh   r1   r1   rG   r2   r     s    r   zG
    Base class for model's outputs, with potential hidden states.
    )custom_introc                   @   s   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZee
ej  ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed	< dS )
PatchTSTModelOutputa>  
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of
        the model at the output of each layer plus the optional initial embedding outputs.
    mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
        Bool masked tensor indicating which patches are masked
    loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
        Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
        Patched input to the Transformer
    Nr  rI   r  r   locscaler   )r_   r`   ra   rb   r  r   r'   FloatTensor__annotations__rI   rg   r  r   r&  r'  r   r1   r1   r1   r2   r%    s   
 r%  z4
    Output type of [`PatchTSTForPretraining`].
    c                   @   b   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )PatchTSTForPretrainingOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    prediction_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction outputs of the time series modeling heads.
    Nlossprediction_outputrI   r  )r_   r`   ra   rb   r,  r   r'   r(  r)  r-  rI   rg   r  r1   r1   r1   r2   r+  '     
 r+  z3
    Output type of [`PatchTSTForRegression`].
    c                   @   r*  )PatchTSTForRegressionOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Regression outputs of the time series modeling heads.
    Nr,  regression_outputsrI   r  )r_   r`   ra   rb   r,  r   r'   r(  r)  r0  rI   rg   r  r1   r1   r1   r2   r/  ;  r.  r/  z3
    Output type of [`PatchTSTForPrediction`].
    c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dZeej ed< dZeej ed< dS )	PatchTSTForPredictionOutputa!  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        MSE loss.
    prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, -1)`):
        Prediction outputs of the time series modeling heads.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
        Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
        Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    Nr,  prediction_outputsrI   r  r&  r'  )r_   r`   ra   rb   r,  r   r'   r(  r)  r2  rI   rg   r  r&  r'  r1   r1   r1   r2   r1  O  s   
 r1  z7
    Output type of [`PatchTSTForClassification`].
    c                   @   r*  )PatchTSTForClassificationOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
        Prediction scores of the PatchTST modeling head (scores before SoftMax).
    Nr,  prediction_logitsrI   r  )r_   r`   ra   rb   r,  r   r'   r(  r)  r4  rI   rg   r  r1   r1   r1   r2   r3  o  s   
 r3  z
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.
    c                   @   s$   e Zd ZU dZdZeej ed< dS )SamplePatchTSTOutputz
    sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`):
        Sampled values from the chosen distribution.
    N	sequences)	r_   r`   ra   rb   r6  r   r'   r(  r)  r1   r1   r1   r2   r5    s   
 r5  inputtargetrN   c                 C   s   |  | S )zc
    Computes the negative log likelihood loss from input distribution with respect to target.
    )log_prob)r7  r8  r1   r1   r2   nll  s   r:  input_tensorweightsc                 C   sr   |dur3t |dk| | t | }t j|r|j|dn| dd}|r-|j|d| S | | S | j|dS )aj  
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    Nr   r"   r   min)r'   where
zeros_likeclampr   r   )r;  r<  r#   weighted_tensorsum_weightsr1   r1   r2   weighted_average  s
   " rD  c                	       P   e Zd ZdZdef fddZdejdejdeejejejf fdd	Z	  Z
S )
PatchTSTStdScalerz
    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
    subtracting from the mean and dividing by the standard deviation.
    r:   c                    sV   t    t|dr|jnd| _t|dr|jnd| _t|dr&|j| _d S d| _d S )Nscaling_dimr   keepdimTminimum_scalegh㈵>)r=   r>   hasattrrG  r#   rH  rI  rq   rG   r1   r2   r>     s   
 zPatchTSTStdScaler.__init__r   observed_indicatorrN   c                 C   sz   |j | j| jd}|d}|| j | j| jd| }|| | d j | j| jd| }t|| j }|| | ||fS )C  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        rH  r   r!   )r   r#   rH  	clamp_minr'   sqrtrI  )rF   r   rK  denominatorr&  variancer'  r1   r1   r2   r^     s   
"zPatchTSTStdScaler.forwardr_   r`   ra   rb   r   r>   r'   rf   rg   r^   rh   r1   r1   rG   r2   rF    s    rF  c                	       rE  )
PatchTSTMeanScalerz
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    r:   c                    sl   t    t|dr|jnd| _t|dr|jnd| _t|dr#|jnd| _t|dr1|j| _d S d | _d S )NrG  r   rH  TrI  绽|=default_scale)r=   r>   rJ  rG  r#   rH  rI  rU  rq   rG   r1   r2   r>     s
   
 zPatchTSTMeanScaler.__init__r   rK  rN   c           
      C   s   ||   j| jdd}|j| jdd}|tj|dd }| jdu r:|jdd}tj|ddd}t|| }n| jt| }t|dk||}tj|| j	d}|| }	| j
sa|j| jd}|	t||fS )rL  TrM  r   r=  Nr   r"   )absr   r#   r'   rA  rU  squeeze	ones_liker?  rI  rH  r@  )
rF   r   rK  ts_sumnum_observedr'  	batch_sumbatch_observationsrU  scaled_datar1   r1   r2   r^     s   
zPatchTSTMeanScaler.forwardrR  r1   r1   rG   r2   rS    s    rS  c                
       sX   e Zd ZdZdef fddZ	ddejdeej de	ejejejf fd	d
Z
  ZS )PatchTSTNOPScalerz|
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    r:   c                    s@   t    t|dr|jnd| _t|dr|j| _d S d| _d S )NrG  r   rH  T)r=   r>   rJ  rG  r#   rH  rq   rG   r1   r2   r>     s   
 zPatchTSTNOPScaler.__init__Nr   rK  rN   c                 C   sB   t j|ddj| j| jd}t j|ddj| j| jd}|||fS )a  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        Fr  )r#   rH  )r'   rX  r   r#   rH  r@  )rF   r   rK  r'  r&  r1   r1   r2   r^     s   
zPatchTSTNOPScaler.forwardr   )r_   r`   ra   rb   r   r>   r'   rf   r   rg   r^   rh   r1   r1   rG   r2   r^    s    r^  c                	       sL   e Zd Zdef fddZdejdejdeejejejf fddZ  Z	S )	PatchTSTScalerr:   c                    sR   t    |jdks|jdu rt|| _d S |jdkr"t|| _d S t|| _d S )Nr   Tr   )r=   r>   r   rS  scalerrF  r^  rq   rG   r1   r2   r>   -  s   

zPatchTSTScaler.__init__r   rK  rN   c                 C   s   |  ||\}}}|||fS )a>  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Input for scaler calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, um_input_channels)`)
        )r`  )rF   r   rK  r&  r'  r1   r1   r2   r^   6  s   
zPatchTSTScaler.forward)
r_   r`   ra   r   r>   r'   rf   rg   r^   rh   r1   r1   rG   r2   r_  ,  s    	r_  c                       sv   e Zd Zdef fddZ					ddejdeej deej dee d	ee d
ee de	e
ef fddZ  ZS )PatchTSTModelr:   c                    sf   t  | t|| _t|| _|j| _| jj}| jr!t|| _	nt
 | _	t||d| _|   d S )N)r   )r=   r>   r_  r`  r   
patchifierdo_mask_inputr   r   maskingr   r   r   encoderr  r  rG   r1   r2   r>   J  s   


zPatchTSTModel.__init__Nr   past_observed_maskfuture_valuesr  rM   return_dictrN   c              	   C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}|du r't|}| ||\}}}	| |}
| jr@| 	|
\}}n| 	|
d}}| j
|||d}|sk|j|j|jf}||||	|
f }tdd |D S t|j|j|j|||	|
dS )a  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            future_values (`torch.BoolTensor` of shape `(batch_size, prediction_length, num_input_channels)`, *optional*):
                Future target values associated with the `past_values`
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*):
                Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTModel

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> model = PatchTSTModel.from_pretrained("namctin/patchtst_etth1_pretrain")

        >>> # during training, one provides both past and future values
        >>> outputs = model(
        ...     past_values=batch["past_values"],
        ...     future_values=batch["future_values"],
        ... )

        >>> last_hidden_state = outputs.last_hidden_state
        ```N)r   r  rM   c                 s   s    | ]	}|d ur|V  qd S r   r1   )r   vr1   r1   r2   	<genexpr>      z(PatchTSTModel.forward.<locals>.<genexpr>)r  rI   r  r   r&  r'  r   )r:   use_return_dictrM   r  r'   rX  r`  rb  rc  rd  re  r  rI   r  rg   r%  )rF   r   rf  rg  r  rM   rh  scaled_past_valuesr&  r'  patched_valuesmasked_valuesr   encoder_outputr   r1   r1   r2   r^   \  s6   6

zPatchTSTModel.forwardNNNNN)r_   r`   ra   r   r>   r'   rf   r   re   r   rg   r%  r^   rh   r1   r1   rG   r2   ra  H  s,    
ra  c                       s<   e Zd ZdZdef fddZdejdejfddZ  Z	S )	PatchTSTMaskPretrainHeadz-
    Pretraining head for mask modelling
    r:   c                    sH   t    |jdkrt|jnt | _t|j|j	| _
|j| _d S Nr   )r=   r>   head_dropoutr   r   r   r   rA   rn   r   linearr   rq   rG   r1   r2   r>     s   
 z!PatchTSTMaskPretrainHead.__init__	embeddingrN   c                 C   s:   |  | |}| jr|ddddddddf }|S )a  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                    `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                            `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True

        Nr   )ru  r   r   )rF   rv  r1   r1   r2   r^     s    z PatchTSTMaskPretrainHead.forwardrt   r1   r1   rG   r2   rr    s    rr  z*
    The PatchTST for pretrain model.
    c                       sj   e Zd Zdef fddZ				ddejdeej dee dee d	ee d
e	e
ef fddZ  ZS )PatchTSTForPretrainingr:   c                    s4   t  | d|_t|d| _t|| _|   d S )NTr  )r=   r>   rc  ra  r   rr  headr  rq   rG   r1   r2   r>     s
   
zPatchTSTForPretraining.__init__Nr   rf  r  rM   rh  rN   c                 C   s   |dur|n| j j}| j||||dd}| |j}tjdd}|||j}	|	jdd|j	 
 |j	
 d  }
|j}|sU|f|d	d
  }|
durQ|
f| }|S |}|S t|
|||jdS )a	  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
            `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForPretraining

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> # Config for random mask pretraining
        >>> config = PatchTSTConfig(
        ...     num_input_channels=7,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     mask_type='random',
        ...     random_mask_ratio=0.4,
        ...     use_cls_token=True,
        ... )
        >>> # Config for forecast mask pretraining
        >>> config = PatchTSTConfig(
        ...     num_input_channels=7,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     mask_type='forecast',
        ...     num_forecast_mask_patches=5,
        ...     use_cls_token=True,
        ... )
        >>> model = PatchTSTForPretraining(config)

        >>> # during training, one provides both past and future values
        >>> outputs = model(past_values=batch["past_values"])

        >>> loss = outputs.loss
        >>> loss.backward()
        ```NTr   rf  r  rM   rh  none	reductionr   r"   rT  r   )r,  r-  rI   r  )r:   rl  r   rx  r  r   MSELossr   r   r   r   rI   r+  r  )rF   r   rf  r  rM   rh  model_outputx_hatr,  loss_valmasked_lossr  r   r1   r1   r2   r^     s,   E
$
zPatchTSTForPretraining.forward)NNNN)r_   r`   ra   r   r>   r'   rf   r   re   r   rg   r+  r^   rh   r1   r1   rG   r2   rw    s&    
rw  c                       r   )PatchTSTClassificationHeadr:   c                    sd   t    |j| _|j| _tjdd| _|jdkrt|jnt	 | _
t|j|j |j| _d S Nr   	start_dimr   )r=   r>   r   pooling_typer   Flattenflattenrt  r   r   r   rA   r   rn   num_targetsru  rq   rG   r1   r2   r>   J  s   
 z#PatchTSTClassificationHead.__init__rv  c                 C   s   | j r|dddddddf }n"| jdkr|jdd}n| jdkr+|jddj}n	td| j d| |}| | |}|S )	a[  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                     `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, num_targets)`

        Nr   r   r!   r"   r   pooling operator  is not implemented yet)	r   r  r   r   valuesr@   r  ru  r   rF   rv  pooled_embeddingrs   r1   r1   r2   r^   R  s   



z"PatchTSTClassificationHead.forwardr   r1   r1   rG   r2   r  I  s    r  z0
    The PatchTST for classification model.
    c                       sx   e Zd Zdef fddZe					ddejdeej dee	 dee	 d	ee	 d
ee	 de
eef fddZ  ZS )PatchTSTForClassificationr:   c                    sB   t  | |jrtd d|_t|| _t|| _| 	  d S )N+Setting `do_mask_input` parameter to False.F)
r=   r>   rc  loggerwarningra  r   r  rx  r  rq   rG   r1   r2   r>   t  s   


z"PatchTSTForClassification.__init__Nr   target_valuesrf  r  rM   rh  rN   c                 C   s   |dur|n| j j}| j||||dd}| |j}d}	|dur)t }
|
||}	|sC|f|dd  }|	dur?|	f| }|S |}|S t|	||j|j	dS )ac  
        past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
            Input sequence to the model
        target_values (`torch.Tensor`, *optional*):
            Labels associates with the `past_values`
        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:

            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Examples:

        ```python
        >>> from transformers import PatchTSTConfig, PatchTSTForClassification

        >>> # classification task with two input channel2 and 3 classes
        >>> config = PatchTSTConfig(
        ...     num_input_channels=2,
        ...     num_targets=3,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     use_cls_token=True,
        ... )
        >>> model = PatchTSTForClassification(config=config)

        >>> # during inference, one only provides past values
        >>> past_values = torch.randn(20, 512, 2)
        >>> outputs = model(past_values=past_values)
        >>> labels = outputs.prediction_logits
        ```NTry  r   r   )r,  r4  rI   r  )
r:   rl  r   rx  r  r   CrossEntropyLossr3  rI   r  )rF   r   r  rf  r  rM   rh  r  y_hatr  r,  r   r1   r1   r2   r^     s2   ,
z!PatchTSTForClassification.forwardrq  )r_   r`   ra   r   r>   r   r'   rf   r   re   r   rg   r3  r^   rh   r1   r1   rG   r2   r  n  s.    
r  z,
    The PatchTST for regression Model.
    c                       s8   e Zd Zd	dedef fddZdejfddZ  Z	S )
PatchTSTPredictionHeadNr:   r   c                    sD  t    |j| _|j| _|j| _|j| _| js| jr|j}n|j| }| jsvt | _	t | _
t | _t| jD ]8}| jtjdd |du rW| j	t||j n	| j	|| | j
|jdkrnt|jnt  q;dS tjdd| _|du rt||j| _n||| _|jdkrt|jnt | _dS )a  
        num_patches (`int`):
            The number of patches in the input sequence.
        distribution_output (`DistributionOutput`, *optional*):
            The distribution output layer for probabilistic forecasting. If None, a linear output layer is used.
        r!   r  Nr   )r=   r>   share_projectionr   r   r  rn   r   r   projectionsdropoutsflattensr   r   r  rA   prediction_lengthget_parameter_projectionrt  r   r   r  
projectionr   )rF   r:   r   distribution_outputr?   r   rG   r1   r2   r>     s0   




($zPatchTSTPredictionHead.__init__rv  c                 C   s  | j r|dddddddf }n| jdkr|jdd}n| jdkr+|jddj}n|}| jseg }t| jD ]%}| j| |dd|ddf }| j	| |}| j
| |}|| q7tj|dd}n| |}| |}| |}t|trtdd	 |D }|S |dd}|S )
aj  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                     `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, forecast_len, num_channels)`

        Nr   r   r!   r"   r   r   c                 s   s    | ]	}| d dV  qdS )r!   r   N)r)   )r   zr1   r1   r2   rj  )  rk  z1PatchTSTPredictionHead.forward.<locals>.<genexpr>)r   r  r   r   r  r  r   r   r  r  r  r   r'   r   r  r   r  r   rg   r)   )rF   rv  r  rs   r   r1   r1   r2   r^     s.   


 



zPatchTSTPredictionHead.forwardr   )
r_   r`   ra   r   rc   r>   r'   rf   r^   rh   r1   r1   rG   r2   r    s    +r  z,
    The PatchTST for prediction model.
    c                       s   e Zd Zdef fddZ					ddejdeej deej dee d	ee d
ee de	e
ef fddZe 	ddejdeej defddZ  ZS )PatchTSTForPredictionr:   c                    s   t  | |jrtd d|_t|| _|jdkrd | _n/|jdkr,t	|j
d| _n"|jdkr9t|j
d| _n|jdkrFt|j
d| _ntd|j t|| jjj| jd	| _|   d S )
Nr  Fmse	student_tr"   normalnegative_binomialUnknown distribution output )r  )r=   r>   rc  r  r  ra  r   r,  r  r   r  r   r   r@   r  rb  r   rx  r  rq   rG   r1   r2   r>   5  s$   





zPatchTSTForPrediction.__init__Nr   rf  rg  r  rM   rh  rN   c                 C   s   |dur|n| j j}| j||||dd}| |j}d}	| jr"|}
n||j |j }
|durQ| jrF| jj||j|jd}t	||}	t
|	}	ntjdd}||
|}	|j}|j}|sq|
f|dd  }|	durm|	f| }|S |}|S t|	|
|j|j||d	S )
aV	  
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*):
                Future target values associated with the `past_values`
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*):
                Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
            `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForPrediction

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> # Prediction task with 7 input channels and prediction length is 96
        >>> model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast")

        >>> # during training, one provides both past and future values
        >>> outputs = model(
        ...     past_values=batch["past_values"],
        ...     future_values=batch["future_values"],
        ... )

        >>> loss = outputs.loss
        >>> loss.backward()

        >>> # during inference, one only provides past values, the model outputs future values
        >>> outputs = model(past_values=batch["past_values"])
        >>> prediction_outputs = outputs.prediction_outputs
        ```NTry  r&  r'  r   r{  r   r   )r,  r2  rI   r  r&  r'  )r:   rl  r   rx  r  r  r'  r&  distributionr:  rD  r   r~  r1  rI   r  )rF   r   rf  rg  r  rM   rh  r  r  r  	y_hat_outr  r,  r&  r'  r   r1   r1   r2   r^   R  sL   =



zPatchTSTForPrediction.forwardc                    sr   | j j}| |d|dd}| jr.| jj|j|j|jd  fddt|D }tj	|dd}n|j
d}t|d	S )
a   
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)`
            for multivariate predictions.
        NF)r   rg  rf  r  r  c                       g | ]}   qS r1   sampler   r  r1   r2   r     r  z2PatchTSTForPrediction.generate.<locals>.<listcomp>r   r"   r6  )r:   num_parallel_samplesr  r  r2  r&  r'  r   r'   r   r   r5  rF   r   rf  r  r   samplesr1   r  r2   generate  s   
zPatchTSTForPrediction.generaterq  r   )r_   r`   ra   r   r>   r'   rf   r   re   r   rg   r1  r^   no_gradr5  r  rh   r1   r1   rG   r2   r  /  s>     

mr  c                       s8   e Zd ZdZd	def fddZdejfddZ  Z	S )
PatchTSTRegressionHeadz
    Regression head
    Nr:   c                    s   t    |j| _|j| _|j| _|| _|j|j }t	j
dd| _|jdkr,t	|jnt	 | _|d u r?t	||j| _d S ||| _d S r  )r=   r>   output_rangey_ranger   r  r  r   rn   r   r  r  rt  r   r   r   rA   r  r  r  )rF   r:   r  r?   rG   r1   r2   r>     s   
 zPatchTSTRegressionHead.__init__rv  c                 C   s   | j r|dddddddf }n"| jdkr|jdd}n| jdkr+|jddj}n	td| j d| | |}| |}| j	du | j
du@ r_t|| j
d	 | j
d   | j
d  }|S )
aY  
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                    `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, output_dim)`

        Nr   r   r!   r"   r   r  r  r   )r   r  r   r   r  r@   r   r  r  r  r  r'   sigmoidr  r1   r1   r2   r^     s   



(zPatchTSTRegressionHead.forwardr   rt   r1   r1   rG   r2   r    s    r  z,
    The PatchTST for regression model.
    c                       s   e Zd Zdef fddZe					ddejdeej deej dee	 d	ee	 d
ee	 de
eef fddZe 	ddejdeej defddZ  ZS )PatchTSTForRegressionr:   c                    s   t  | |jrtd d|_t|| _|jdkrd | _n/|jdkr,t	|j
d| _n"|jdkr9t|j
d| _n|jdkrFt|j
d| _ntd|j t|| j| _|   d S )	Nr  Fr  r  r"   r  r  r  )r=   r>   rc  r  r  ra  r   r,  r  r   r  r   r   r@   r  rx  r  rq   rG   r1   r2   r>   -  s    





zPatchTSTForRegression.__init__Nr   r  rf  r  rM   rh  rN   c                    s   |dur|n j j} j||||dd} |j}d}	|durI jr> j|}
t fdd|D }t|
|}	t	|	}	nt
jdd}	|	||}	|sc|f|dd	  }|	dur_|	f| }|S |}|S t|	||j|jd
S )a#  
        past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
            Input sequence to the model
        target_values (`torch.Tensor` of shape `(bs, num_input_channels)`):
            Target values associates with the `past_values`
        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:

            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
            Whether or not to return a `ModelOutput` instead of a plain tuple.

        Examples:

        ```python
        >>> from transformers import PatchTSTConfig, PatchTSTForRegression

        >>> # Regression task with 6 input channels and regress 2 targets
        >>> model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression")

        >>> # during inference, one only provides past values, the model outputs future values
        >>> past_values = torch.randn(20, 512, 6)
        >>> outputs = model(past_values=past_values)
        >>> regression_outputs = outputs.regression_outputs
        ```NTry  c                    s   g | ]
}| d  jjqS )r   )r,   r:   r  )r   itemrF   r1   r2   r   }  s    z1PatchTSTForRegression.forward.<locals>.<listcomp>r   r{  r   r   )r,  r0  rI   r  )r:   rl  r   rx  r  r  r  rg   r:  rD  r   r~  r/  rI   r  )rF   r   r  rf  r  rM   rh  r  r  r,  r  r   r1   r  r2   r^   G  s<   %


zPatchTSTForRegression.forwardc                    sb   | j j}| |d|dd}| j|j  fddt|D }tj|ddd|| j j	}t
|d	S )
a  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, num_targets)`.
        NF)r   r  rf  r  c                    r  r1   r  r   r  r1   r2   r     r  z2PatchTSTForRegression.generate.<locals>.<listcomp>r   r"   r   r  )r:   r  r  r  r0  r   r'   r   r,   r  r5  r  r1   r  r2   r    s   
zPatchTSTForRegression.generaterq  r   )r_   r`   ra   r   r>   r   r'   rf   r   re   r   rg   r/  r^   r  r5  r  rh   r1   r1   rG   r2   r  '  s@    
Ir  )ra  r   r  rw  r  r  )Nr   N)NFr   rs  r#  )Lrb   r  dataclassesr   typingr   r   r   r'   r   activationsr   modeling_flash_attention_utilsr	   modeling_outputsr
   modeling_utilsr   r   processing_utilsr   time_series_utilsr   r   r   utilsr   r   r   configuration_patchtstr   
get_loggerr_   r  Modulerf   rd   r3   r4   rj   listre   rc   r   r   r   r   r   r   r   r   r   r%  r+  r/  r1  r3  r5  distributionsDistributionr:  rD  rF  rS  r^  r_  ra  rr  rw  r  r  r  r  r  r  __all__r1   r1   r1   r2   <module>   s  


z
=

D0< !$8>
"$7po%W` =7 