o
    Gih                     @   s   d dl Z d dlmZ d dlmZmZmZ d dlZd dl	Z	ddl
mZmZ ddlmZmZmZ ddlmZ e r;d dlZeeZeG d	d
 d
eZG dd deeZdS )    N)	dataclass)LiteralOptionalUnion   )ConfigMixinregister_to_config)
BaseOutputis_scipy_availablelogging   )SchedulerMixinc                   @   s   e Zd ZU dZejed< dS )%FlowMatchEulerDiscreteSchedulerOutputaJ  
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    prev_sampleN)__name__
__module____qualname____doc__torchFloatTensor__annotations__ r   r   m/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/schedulers/scheduling_flow_match_euler_discrete.pyr   !   s   
 	r   c                   @   s  e Zd ZdZg ZdZe										
					d`dedede	ded
B ded
B dedede	dede	de	de	de
d de	fddZedd Zedd  Zed!d" Zdad$efd%d&Zdefd'd(Z	
dbd)ejd*eejB d+ejd
B d,ejfd-d.Zd,efd/d0Zd1ed2ed3ejd,ejfd4d5Zd3ejd,ejfd6d7Z	
	
	
	
	
dcd8ed
B d9eejB d:ee d
B d1ed
B d;ee d
B f
d<d=Z	
dbd*eeejf d>eej d,efd?d@Zd*eeejf d,d
fdAdBZ dCdCedDdd
d
dEfdFejd*eejB d)ejdGedHedIedJedKej!d
B dLejd
B dMe	d,e"e#B fdNdOZ$dPejd8ed,ejfdQdRZ%dPejd8ed,ejfdSdTZ&	UdddPejd8edVedWed,ejf
dXdYZ'd1ed2ed3ejd,ejfdZd[Z(d1ed2ed3ejd,ejfd\d]Z)d,efd^d_Z*d
S )eFlowMatchEulerDiscreteSchedulerab  
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        use_dynamic_shifting (`bool`, defaults to False):
            Whether to apply timestep shifting on-the-fly based on the image resolution.
        base_shift (`float`, defaults to 0.5):
            Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
            with desired output.
        max_shift (`float`, defaults to 1.15):
            Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
            more exaggerated or stylized.
        base_image_seq_len (`int`, defaults to 256):
            The base image sequence length.
        max_image_seq_len (`int`, defaults to 4096):
            The maximum image sequence length.
        invert_sigmas (`bool`, defaults to False):
            Whether to invert the sigmas.
        shift_terminal (`float`, defaults to None):
            The end value of the shifted timestep schedule.
        use_karras_sigmas (`bool`, defaults to False):
            Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
        use_exponential_sigmas (`bool`, defaults to False):
            Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
        use_beta_sigmas (`bool`, defaults to False):
            Whether to use beta sigmas for step sizes in the noise schedule during sampling.
        time_shift_type (`str`, defaults to "exponential"):
            The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
        stochastic_sampling (`bool`, defaults to False):
            Whether to use stochastic sampling.
    r           ?F      ?ffffff?      Nexponentialnum_train_timestepsshiftuse_dynamic_shifting
base_shift	max_shiftbase_image_seq_lenmax_image_seq_leninvert_sigmasshift_terminaluse_karras_sigmasuse_exponential_sigmasuse_beta_sigmastime_shift_type)r    linearstochastic_samplingc                 C   s   | j jrt stdt| j j| j j| j jgdkrtd|dvr&tdtj	d||tj
dd d d  }t|jtj
d}|| }|sS|| d|d |   }|| | _d | _d | _|| _|d| _| jd  | _| jd	  | _d S )
Nz:Make sure to install scipy if you want to use beta sigmas.r   znOnly one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.>   r.   r    z;`time_shift_type` must either be 'exponential' or 'linear'.dtypecpur   )configr,   r
   ImportErrorsumr+   r*   
ValueErrornplinspacefloat32copyr   
from_numpyto	timesteps_step_index_begin_index_shiftsigmasitem	sigma_min	sigma_max)selfr!   r"   r#   r$   r%   r&   r'   r(   r)   r*   r+   r,   r-   r/   r>   rB   r   r   r   __init__Z   s6   	"
z(FlowMatchEulerDiscreteScheduler.__init__c                 C      | j S )z.
        The value used for shifting.
        rA   rF   r   r   r   r"         z%FlowMatchEulerDiscreteScheduler.shiftc                 C   rH   )zg
        The index counter for current timestep. It will increase 1 after each scheduler step.
        )r?   rJ   r   r   r   
step_index   rK   z*FlowMatchEulerDiscreteScheduler.step_indexc                 C   rH   )zq
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        r@   rJ   r   r   r   begin_index   rK   z+FlowMatchEulerDiscreteScheduler.begin_indexr   rN   c                 C   
   || _ dS )z
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        NrM   )rF   rN   r   r   r   set_begin_index      
z/FlowMatchEulerDiscreteScheduler.set_begin_indexc                 C   rO   )z
        Sets the shift value for the scheduler.

        Args:
            shift (`float`):
                The shift value to be set.
        NrI   )rF   r"   r   r   r   	set_shift   rQ   z)FlowMatchEulerDiscreteScheduler.set_shiftsampletimestepnoisereturnc                    s  j j|j|jd}|jjdkr)t|r)jj|jtjd |j|jtjd}nj|j ||j}j	du rF fdd|D }nj
durUj
g|jd  }n	j	g|jd  }||  }t|jt|jk r}|d}t|jt|jk sn|| d	| |  }|S )
a  
        Forward process in flow-matching

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`torch.FloatTensor`):
                The current timestep in the diffusion chain.
            noise (`torch.FloatTensor`):
                The noise tensor.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        )devicer1   mpsr0   Nc                    s   g | ]} | qS r   )index_for_timestep).0tschedule_timestepsrF   r   r   
<listcomp>   s    z?FlowMatchEulerDiscreteScheduler.scale_noise.<locals>.<listcomp>r   r2   r   )rB   r=   rW   r1   typer   is_floating_pointr>   r:   rN   rL   shapeflattenlen	unsqueeze)rF   rS   rT   rU   rB   step_indicessigmar   r\   r   scale_noise   s"   


z+FlowMatchEulerDiscreteScheduler.scale_noisec                 C   s   || j j S Nr4   r!   )rF   rf   r   r   r   _sigma_to_t   s   z+FlowMatchEulerDiscreteScheduler._sigma_to_tmurf   r[   c                 C   s8   | j jdkr| |||S | j jdkr| |||S dS )a  
        Apply time shifting to the sigmas.

        Args:
            mu (`float`):
                The mu parameter for the time shift.
            sigma (`float`):
                The sigma parameter for the time shift.
            t (`torch.Tensor`):
                The input timesteps.

        Returns:
            `torch.Tensor`:
                The time-shifted timesteps.
        r    r.   N)r4   r-   _time_shift_exponential_time_shift_linearrF   rk   rf   r[   r   r   r   
time_shift   s
   z*FlowMatchEulerDiscreteScheduler.time_shiftc                 C   s,   d| }|d d| j j  }d||  }|S )a,  
        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
        value.

        Reference:
        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

        Args:
            t (`torch.Tensor`):
                A tensor of timesteps to be stretched and shifted.

        Returns:
            `torch.Tensor`:
                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
        r   r2   )r4   r)   )rF   r[   one_minus_zscale_factorstretched_tr   r   r   stretch_shift_to_terminal  s   z9FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminalnum_inference_stepsrW   rB   r>   c                 C   sF  | j jr|du rtd|dur |dur t|t|kr td|dur=|dur.t||ks8|dur<t||kr<tdn|durEt|nt|}|| _|du}|r[t|tj}|du ry|du rrt	| 
| j| 
| j|}|| j j }nt|tj}t|}| j jr| |d|}n| j| d| jd |   }| j jr| |}| j jr| j||d}n| j jr| j||d}n| j jr| j||d}t|jtj|d}|s|| j j }nt|jtj|d}| j jrd| }|| j j }t|tjd|jd	g}nt|tjd|jd	g}|| _ || _!d| _"d| _#dS )
a  
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`list[float]`, *optional*):
                Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
                automatically.
            mu (`float`, *optional*):
                Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
                shifting.
            timesteps (`list[float]`, *optional*):
                Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
                automatically.
        NzC`mu` must be passed when `use_dynamic_shifting` is set to be `True`z4`sigmas` and `timesteps` should have the same lengthzq`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is providedr   r   )	in_sigmasrt   )r1   rW   )rW   )$r4   r#   r7   rc   rt   r8   arrayastyper:   r9   rj   rE   rD   r!   ro   r"   r)   rs   r*   _convert_to_karrasr+   _convert_to_exponentialr,   _convert_to_betar   r<   r=   r(   catonesrW   zerosr>   rB   r?   r@   )rF   rt   rW   rB   rk   r>   is_timesteps_providedr   r   r   set_timesteps  sf   




z-FlowMatchEulerDiscreteScheduler.set_timestepsr]   c                 C   s:   |du r| j }||k }t|dkrdnd}||  S )a  
        Get the index for the given timestep.

        Args:
            timestep (`float` or `torch.FloatTensor`):
                The timestep to find the index for.
            schedule_timesteps (`torch.FloatTensor`, *optional*):
                The schedule timesteps to validate against. If `None`, the scheduler's timesteps are used.

        Returns:
            `int`:
                The index of the timestep.
        Nr   r   )r>   nonzerorc   rC   )rF   rT   r]   indicesposr   r   r   rY     s
   z2FlowMatchEulerDiscreteScheduler.index_for_timestepc                 C   s@   | j d u rt|tjr|| jj}| || _d S | j	| _d S rh   )
rN   
isinstancer   Tensorr=   r>   rW   rY   r?   r@   )rF   rT   r   r   r   _init_step_index  s
   
z0FlowMatchEulerDiscreteScheduler._init_step_indexg        infTmodel_outputs_churns_tmins_tmaxs_noise	generatorper_token_timestepsreturn_dictc                 C   sX  t |tst |tjst |tjrtd| jdu r| | |tj	}|	durZ|	| j
j }| jddddf }||d d k }|| }|jdd\}}|d }|d }|| }n| j}| j| }| j|d  }|}|}|| }| j
jr|||  }t|}d| | ||  }n|||  }|  jd7  _|	du r||j}|
s|fS t|d	S )
a-  
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin  (`float`):
            s_tmax  (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        zPassing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass one of the `scheduler.timesteps` as a timestep.Ngư>r   )dim).Nr   r   )r   )r   intr   	IntTensor
LongTensorr7   rL   r   r=   r:   r4   r!   rB   maxr/   
randn_liker?   r1   r   )rF   r   rT   rS   r   r   r   r   r   r   r   per_token_sigmasrB   
lower_masklower_sigmas_current_sigma
next_sigmadt	sigma_idxrf   
sigma_nextx0rU   r   r   r   r   step  sL   -







z$FlowMatchEulerDiscreteScheduler.stepru   c           
      C   s   t | jdr| jj}nd}t | jdr| jj}nd}|dur |n|d  }|dur,|n|d  }d}tdd|}|d|  }|d|  }||||   | }	|	S )a  
        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
        Models](https://huggingface.co/papers/2206.00364).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following the Karras noise schedule.
        rD   NrE   r2   r   g      @r   )hasattrr4   rD   rE   rC   r8   r9   )
rF   ru   rt   rD   rE   rhorampmin_inv_rhomax_inv_rhorB   r   r   r   rx     s   

z2FlowMatchEulerDiscreteScheduler._convert_to_karrasc                 C   s   t | jdr| jj}nd}t | jdr| jj}nd}|dur |n|d  }|dur,|n|d  }ttt	|t	||}|S )a  
        Construct an exponential noise schedule.

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following an exponential schedule.
        rD   NrE   r2   r   )
r   r4   rD   rE   rC   r8   expr9   mathlog)rF   ru   rt   rD   rE   rB   r   r   r   ry   6  s   

 z7FlowMatchEulerDiscreteScheduler._convert_to_exponential333333?alphabetac              
      s   t | jdr| jjndt | jdr| jjnddur n|d  dur,n|d  tfdd fddd	tdd	| D D }|S )
a  
        Construct a beta noise schedule as proposed in [Beta Sampling is All You
        Need](https://huggingface.co/papers/2407.12173).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            `torch.Tensor`:
                The converted sigma values following a beta distribution schedule.
        rD   NrE   r2   r   c                    s   g | ]
}|    qS r   r   )rZ   ppf)rE   rD   r   r   r^   ~  s    zDFlowMatchEulerDiscreteScheduler._convert_to_beta.<locals>.<listcomp>c                    s   g | ]}t jj| qS r   )scipystatsr   r   )rZ   rT   )r   r   r   r   r^     s    r   )r   r4   rD   rE   rC   r8   rv   r9   )rF   ru   rt   r   r   rB   r   )r   r   rE   rD   r   rz   X  s    

	z0FlowMatchEulerDiscreteScheduler._convert_to_betac                 C   s$   t |t |d| d |   S Nr   )r   r   rn   r   r   r   rl     s   $z7FlowMatchEulerDiscreteScheduler._time_shift_exponentialc                 C   s   ||d| d |   S r   r   rn   r   r   r   rm     s   z2FlowMatchEulerDiscreteScheduler._time_shift_linearc                 C   s   | j jS rh   ri   rJ   r   r   r   __len__  s   z'FlowMatchEulerDiscreteScheduler.__len__)r   r   Fr   r   r   r   FNFFFr    F)r   rh   )NNNNN)r   r   )+r   r   r   r   _compatiblesorderr   r   floatboolr   rG   propertyr"   rL   rN   rP   rR   r   r   rg   rj   r   ro   rs   strrW   listr   r   r   rY   r   	Generatorr   tupler   rx   ry   rz   rl   rm   r   r   r   r   r   r   /   s   '	
6




2


k
	

f'#
0r   )r   dataclassesr   typingr   r   r   numpyr8   r   configuration_utilsr   r   utilsr	   r
   r   scheduling_utilsr   scipy.statsr   
get_loggerr   loggerr   r   r   r   r   r   <module>   s   
