o
    ۷i                  	   @   s   d dl Z d dlmZ d dlZd dlZddlmZmZ ddl	m
Z
mZ ddlmZmZmZ e r3d dlZ			dd
ededed dejfddZG dd deeZdS )    N)Literal   )ConfigMixinregister_to_config)	deprecateis_scipy_available   )KarrasDiffusionSchedulersSchedulerMixinSchedulerOutput+?cosinenum_diffusion_timestepsmax_betaalpha_transform_type)r   explaplacereturnc                 C   s   |dkr	dd }n|dkrdd }n|dkrdd }nt d| g }t| D ]}||  }|d	 |  }|td	||||  | q(tj|tjd
S )a>  
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`str`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.

    Returns:
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.
    r   c                 S   s    t | d d t j d d S )NgMb?gT㥛 ?r   )mathcospit r   d/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/schedulers/scheduling_deis_multistep.pyalpha_bar_fn<   s    z)betas_for_alpha_bar.<locals>.alpha_bar_fnr   c              	   S   sP   dt dd|   t ddt d|    d  }t |}t |d|  S )Ng      r         ?r   gư>)r   copysignlogfabsr   sqrt)r   lmbsnrr   r   r   r   A   s   4
r   c                 S   s   t | d S )Ng      ()r   r   r   r   r   r   r   H   s   z"Unsupported alpha_transform_type: r   dtype)
ValueErrorrangeappendmintorchtensorfloat32)r   r   r   r   betasit1t2r   r   r   betas_for_alpha_bar"   s   


"r0   c                0   @   s*  e Zd ZdZdd eD ZdZe							
																dmdede	de	de
d dejee	 B d	B dede
d dede	d e	d!e
d d"e
d d#ed$ed%ed&ed'ed(e	d)e
d* d+ed,ed-e
d d.d	f.d/d0Zed.efd1d2Zed.efd3d4Zdnd5ed.d	fd6d7Zdod8ed9eejB d:e	d	B fd;d<Zd=ejd.ejfd>d?Zd@ejdAejd.ejfdBdCZd@ejd.eejejf fdDdEZdFejd8ed.ejfdGdHZdFejd8ed.ejfdIdJZ	KdpdFejd8edLe	dMe	d.ejf
dNdOZd	dPdQejd=ejd.ejfdRdSZ d	dPdQejd=ejd.ejfdTdUZ!d	dPdVeej d=ejd.ejfdWdXZ"d	dPdVeej d=ejd.ejfdYdZZ#		dqd[eejB d\ejd	B d.efd]d^Z$d[eejB d.d	fd_d`Z%	drdQejd[eejB d=ejdaed.e&eB f
dbdcZ'd=ejd.ejfdddeZ(dfejdgejdhej)d.ejfdidjZ*d.efdkdlZ+d	S )sDEISMultistepScheduleru  
    `DEISMultistepScheduler` is a fast high order solver for diffusion ordinary differential equations (ODEs).

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to `1000`):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to `0.0001`):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to `0.02`):
            The final `beta` value.
        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `list[float]`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to `2`):
            The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
            Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to `0.995`):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to `1.0`):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        algorithm_type (`"deis"`, defaults to `"deis"`):
            The algorithm type for the solver.
        solver_type (`"logrho"`, defaults to `"logrho"`):
            Solver type for DEIS.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
        flow_shift (`float`, *optional*, defaults to `1.0`):
            The flow shift parameter for flow-based models.
        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to `0`):
            An offset added to the inference steps, as required by some model families.
        use_dynamic_shifting (`bool`, defaults to `False`):
            Whether to use dynamic shifting for the noise schedule.
        time_shift_type (`"exponential"`, defaults to `"exponential"`):
            The type of time shifting to apply.
    c                 C   s   g | ]}|j qS r   )name).0er   r   r   
<listcomp>   s    z!DEISMultistepScheduler.<listcomp>r     -C6?{Gz?linearNr   epsilonFףp=
?      ?deislogrhoTlinspacer   exponentialnum_train_timesteps
beta_startbeta_endbeta_schedule)r9   scaled_linearsquaredcos_cap_v2trained_betassolver_orderprediction_type)r:   samplev_predictionflow_predictionthresholdingdynamic_thresholding_ratiosample_max_valuealgorithm_typesolver_typelower_order_finaluse_karras_sigmasuse_exponential_sigmasuse_beta_sigmasuse_flow_sigmas
flow_shifttimestep_spacing)r?   leadingtrailingsteps_offsetuse_dynamic_shiftingtime_shift_typer   c                 C   s  | j jrt stdt| j j| j j| j jgdkrtd|d ur,tj	|tj
d| _n:|dkr<tj|||tj
d| _n*|dkrRtj|d |d |tj
dd | _n|d	kr\t|| _n
t| d
| j d| j | _tj| jdd| _t| j| _td| j | _t| jt| j | _d| j | j d | _d| _|dvr|dv r| jdd n
t| d
| j |dvr|dv r| jdd ntd| d
| j d | _tjd|d |tj
dd d d  }t|| _d g| | _ d| _!d | _"d | _#| j$d| _d S )Nz:Make sure to install scipy if you want to use beta sigmas.r   znOnly one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.r#   r9   rE   r   r   rF   z is not implemented for r<   r   dim)r=   )	dpmsolverzdpmsolver++r=   )rP   )r>   )midpointheunbh1bh2r>   )rQ   zsolver type cpu)%configrU   r   ImportErrorsumrT   rS   r%   r)   r*   r+   r,   r?   r0   NotImplementedError	__class__alphascumprodalphas_cumprodr    alpha_tsigma_tr   lambda_tsigmasinit_noise_sigmar   num_inference_stepsnpcopy
from_numpy	timestepsmodel_outputslower_order_nums_step_index_begin_indexto)selfrA   rB   rC   rD   rG   rH   rI   rM   rN   rO   rP   rQ   rR   rS   rT   rU   rV   rW   rX   r[   r\   r]   rx   r   r   r   __init__   sj   		&zDEISMultistepScheduler.__init__c                 C      | j S )zg
        The index counter for current timestep. It will increase 1 after each scheduler step.
        )r{   r~   r   r   r   
step_index      z!DEISMultistepScheduler.step_indexc                 C   r   )zq
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        r|   r   r   r   r   begin_index   r   z"DEISMultistepScheduler.begin_indexr   c                 C   s
   || _ dS )z
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        Nr   )r~   r   r   r   r   set_begin_index  s   
z&DEISMultistepScheduler.set_begin_indexrt   devicemuc           	         s  |durj jrj jdksJ t|j _j jdkr<tdj jd |d 	 ddd dd 
 tj}n\j jdkrlj j|d  }td|d | 	 ddd dd 
 tj}|j j7 }n,j jdkrj j| }tj jd| 	 
 tj}|d8 }n	tj j d	tdj j d
 }t| j jrt|
 }j||d}t fdd|D 	 }t||dd gtj}nΈj jrt|
 }j||d}t fdd|D }t||dd gtj}nj jr;t|
 }j||d}t fdd|D }t||dd gtj}nnj jrtddj j |d }d| }tj j| dj jd |   dd 
 }|j j 
 }t||dd gtj}n't|tdt||}djd  jd  d
 }t||ggtj}t |_!t |j"|tjd_#t|_$dgj j% _&d_'d_(d_)j!"d_!dS )aj  
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            mu (`float`, *optional*):
                The mu parameter for dynamic shifting. Only used when `use_dynamic_shifting=True` and
                `time_shift_type="exponential"`.
        Nr@   r?   r   r   re   rY   rZ   zY is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'.r   )	in_sigmasrt   c                       g | ]} | qS r   _sigma_to_tr3   sigma
log_sigmasr~   r   r   r5   >      z8DEISMultistepScheduler.set_timesteps.<locals>.<listcomp>c                    r   r   r   r   r   r   r   r5   C  r   c                    r   r   r   r   r   r   r   r5   H  r   r<   r   r$   rf   )*rg   r\   r]   ru   r   rW   rX   r?   rA   roundrv   astypeint64aranger[   r%   arrayrn   r   rS   flip_convert_to_karrasconcatenater+   rT   _convert_to_exponentialrU   _convert_to_betarV   interplenr)   rw   rr   r}   rx   rt   rH   ry   rz   r{   r|   )	r~   rt   r   r   rx   
step_ratiorr   rl   
sigma_lastr   r   r   set_timesteps  sx   6$

 
 
 
2 
z$DEISMultistepScheduler.set_timestepsrJ   c                 C   s   |j }|j^}}}|tjtjfvr| }|||t| }|	 }tj
|| jjdd}tj|d| jjd}|d}t|| || }|j||g|R  }||}|S )az  
        Apply dynamic thresholding to the predicted sample.

        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487

        Args:
            sample (`torch.Tensor`):
                The predicted sample to be thresholded.

        Returns:
            `torch.Tensor`:
                The thresholded sample.
        r   r^   )r(   max)r$   shaper)   r+   float64floatreshaperu   prodabsquantilerg   rN   clamprO   	unsqueezer}   )r~   rJ   r$   
batch_sizechannelsremaining_dims
abs_samplesr   r   r   _threshold_samplee  s   


z(DEISMultistepScheduler._threshold_sampler   r   c                 C   s   t t |d}||ddt jf  }t j|dkddjddj|jd d d}|d }|| }|| }|| ||  }	t |	dd}	d|	 | |	|  }
|
|j}
|
S )a  
        Convert sigma values to corresponding timestep values through interpolation.

        Args:
            sigma (`np.ndarray`):
                The sigma value(s) to convert to timestep(s).
            log_sigmas (`np.ndarray`):
                The logarithm of the sigma schedule used for interpolation.

        Returns:
            `np.ndarray`:
                The interpolated timestep value(s) corresponding to the input sigma(s).
        g|=Nr   )axisr   )r   r   )	ru   r   maximumnewaxiscumsumargmaxclipr   r   )r~   r   r   	log_sigmadistslow_idxhigh_idxlowhighwr   r   r   r   r     s   ,z"DEISMultistepScheduler._sigma_to_tc                 C   s@   | j jrd| }|}||fS d|d d d  }|| }||fS )a(  
        Convert sigma values to alpha_t and sigma_t values.

        Args:
            sigma (`torch.Tensor`):
                The sigma value(s) to convert.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing (alpha_t, sigma_t) values.
        r   r   r   )rg   rV   )r~   r   ro   rp   r   r   r   _sigma_to_alpha_sigma_t  s   z.DEISMultistepScheduler._sigma_to_alpha_sigma_tr   c           
      C   s   t | jdr| jj}nd}t | jdr| jj}nd}|dur |n|d  }|dur,|n|d  }d}tdd|}|d|  }|d|  }||||   | }	|	S )a  
        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
        Models](https://huggingface.co/papers/2206.00364).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following the Karras noise schedule.
        	sigma_minN	sigma_maxre   r   g      @r   )hasattrrg   r   r   itemru   r?   )
r~   r   rt   r   r   rhorampmin_inv_rhomax_inv_rhorr   r   r   r   r     s   

z)DEISMultistepScheduler._convert_to_karrasc                 C   s   t | jdr| jj}nd}t | jdr| jj}nd}|dur |n|d  }|dur,|n|d  }ttt	|t	||}|S )a  
        Construct an exponential noise schedule.

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following an exponential schedule.
        r   Nr   re   r   )
r   rg   r   r   r   ru   r   r?   r   r   )r~   r   rt   r   r   rr   r   r   r   r     s   

 z.DEISMultistepScheduler._convert_to_exponential333333?alphabetac              
      s   t | jdr| jjndt | jdr| jjnddur n|d  dur,n|d  tfdd fddd	tdd	| D D }|S )
a  
        Construct a beta noise schedule as proposed in [Beta Sampling is All You
        Need](https://huggingface.co/papers/2407.12173).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            `torch.Tensor`:
                The converted sigma values following a beta distribution schedule.
        r   Nr   re   r   c                    s   g | ]
}|    qS r   r   )r3   ppf)r   r   r   r   r5   ;  s    z;DEISMultistepScheduler._convert_to_beta.<locals>.<listcomp>c                    s   g | ]}t jj| qS r   )scipystatsr   r   )r3   timestep)r   r   r   r   r5   =  s    r   )r   rg   r   r   r   ru   r   r?   )r~   r   rt   r   r   rr   r   )r   r   r   r   r   r     s    

	z'DEISMultistepScheduler._convert_to_betarJ   model_outputc          
      O   s2  t |dkr
|d n|dd}|du r#t |dkr|d }ntd|dur-tddd | j| j }| |\}}| jjd	krI|||  | }	n5| jjd
krR|}	n,| jjdkra|| ||  }	n| jjdkrt| j| j }|||  }	n
td| jj d| jj	r| 
|	}	| jjdkr|||	  | S td)a  
        Convert the model output to the corresponding type the DEIS algorithm needs.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        r   r   Nr   /missing `sample` as a required keyword argumentrx   1.0.0Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`r:   rJ   rK   rL   zprediction_type given as zi must be one of `epsilon`, `sample`, `v_prediction`, or `flow_prediction` for the DEISMultistepScheduler.r=   'only support log-rho multistep deis now)r   popr%   r   rr   r   r   rg   rI   rM   r   rP   rj   )
r~   r   rJ   argskwargsr   r   ro   rp   x0_predr   r   r   convert_model_outputE  s<    

z+DEISMultistepScheduler.convert_model_outputc                O   s2  t |dkr
|d n|dd}t |dkr|d n|dd}|du r3t |dkr/|d }ntd|dur=tdd	d
 |durGtdd	d | j| jd  | j| j }}| |\}	}| |\}
}t|	t| }t|
t| }|| }| j	j
dkr|	|
 | |t|d  |  }|S td)au  
        One step for the first-order DEIS (equivalent to DDIM).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            prev_timestep (`int`):
                The previous discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        r   r   Nr   prev_timestepr   r   rx   r   r   Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`r=   r<   r   )r   r   r%   r   rr   r   r   r)   r   rg   rP   r   rj   )r~   r   rJ   r   r   r   r   rp   sigma_sro   alpha_srq   lambda_shx_tr   r   r   deis_first_order_update  s<     

"z.DEISMultistepScheduler.deis_first_order_updatemodel_output_listc                O   s  t |dkr
|d n|dd}t |dkr|d n|dd}|du r3t |dkr/|d }ntd|dur=tddd	 |durGtddd
 | j| jd  | j| j | j| jd  }}}	| |\}
}| |\}}| |	\}}	|d |d }}||
 || |	| }}}| jjdkrdd }|||||||| }|||||||| }|
|| ||  ||   }|S t	d)a  
        One step for the second-order multistep DEIS.

        Args:
            model_output_list (`list[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        r   timestep_listNr   r   r   r   r   Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`r   re   r=   c                 S   s2   | t | t |  d  t |t |  S )Nr   ru   r   )r   bcr   r   r   ind_fn  s   2zIDEISMultistepScheduler.multistep_deis_second_order_update.<locals>.ind_fnr   
r   r   r%   r   rr   r   r   rg   rP   rj   )r~   r   rJ   r   r   r   r   rp   sigma_s0sigma_s1ro   alpha_s0alpha_s1m0m1rho_trho_s0rho_s1r   coef1coef2r   r   r   r   "multistep_deis_second_order_update  sJ     



z9DEISMultistepScheduler.multistep_deis_second_order_updatec                O   s  t |dkr
|d n|dd}t |dkr|d n|dd}|du r3t |dkr/|d }ntd|dur=tddd	 |durGtddd
 | j| jd  | j| j | j| jd  | j| jd  f\}}}	}
| |\}}| |\}}| |	\}}	| |
\}}
|d |d |d }}}|| || |	| |
| f\}}}}| jjdkrdd }|||||||||| }|||||||||| }|||||||||| }||| ||  ||  ||   }|S t	d)a  
        One step for the third-order multistep DEIS.

        Args:
            model_output_list (`list[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        r   r   Nr   r   r   r   r   r   r   re   r   r=   c                 S   s   | t |t |t |  d  t |t |   t | t | d  dt |   d  }t |t | t |t |  }|| S )Nr   r   r   )r   r   r   d	numeratordenominatorr   r   r   r   H  s    (zHDEISMultistepScheduler.multistep_deis_third_order_update.<locals>.ind_fnr   r   )r~   r   rJ   r   r   r   r   rp   r   r   sigma_s2ro   r   r   alpha_s2r   r   m2r   r   r   rho_s2r   r   r   coef3r   r   r   r   !multistep_deis_third_order_update  sR     

$z8DEISMultistepScheduler.multistep_deis_third_order_updater   schedule_timestepsc                 C   sd   |du r| j }||k }t|dkrt| j d }|S t|dkr*|d  }|S |d  }|S )a  
        Find the index for a given timestep in the schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep for which to find the index.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule.
        Nr   r   )rx   nonzeror   r   )r~   r   r  index_candidatesr   r   r   r   index_for_timestep`  s   
z)DEISMultistepScheduler.index_for_timestepc                 C   s@   | j du rt|tjr|| jj}| || _dS | j	| _dS )z
        Initialize the step_index counter for the scheduler.

        Args:
            timestep (`int` or `torch.Tensor`):
                The current timestep for which to initialize the step index.
        N)
r   
isinstancer)   Tensorr}   rx   r   r  r{   r|   )r~   r   r   r   r   _init_step_index  s
   
	z'DEISMultistepScheduler._init_step_indexreturn_dictc           	      C   s`  | j du r	td| jdu r| | | jt| jd ko'| jjo't| jdk }| jt| jd ko<| jjo<t| jdk }| j||d}t	| jj
d D ]}| j|d  | j|< qL|| jd< | jj
dksk| jdk sk|rs| j||d}n| jj
dks| jdk s|r| j| j|d}n| j| j|d}| j| jj
k r|  jd7  _|  jd7  _|s|fS t|dS )	a  
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DEIS.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        NzaNumber of inference steps is 'None', you need to run 'set_timesteps' after creating the schedulerr      r   r   re   )prev_sample)rt   r%   r   r  r   rx   rg   rR   r   r&   rH   ry   rz   r   r   r   r{   r   )	r~   r   r   rJ   r  rR   lower_order_secondr-   r
  r   r   r   step  s2   


((

zDEISMultistepScheduler.stepc                 O   s   |S )a?  
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        r   )r~   rJ   r   r   r   r   r   scale_model_input  s   z(DEISMultistepScheduler.scale_model_inputoriginal_samplesnoiserx   c           
         s  j j|j|jd}|jjdkr)t|r)jj|jtjd |j|jtjd}nj|j ||j}j	du rF fdd|D }nj
durUj
g|jd  }n	j	g|jd  }||  }t|jt|jk r}|d}t|jt|jk sn|\}}|| ||  }	|	S )	a  
        Add noise to the original samples according to the noise schedule at the specified timesteps.

        Args:
            original_samples (`torch.Tensor`):
                The original samples without noise.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps at which to add noise to the samples.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        r   mpsr#   Nc                    r   r   )r  )r3   r   r  r~   r   r   r5     r   z4DEISMultistepScheduler.add_noise.<locals>.<listcomp>r   re   )rr   r}   r   r$   typer)   is_floating_pointrx   r+   r   r   r   flattenr   r   r   )
r~   r  r  rx   rr   step_indicesr   ro   rp   noisy_samplesr   r  r   	add_noise  s$   


z DEISMultistepScheduler.add_noisec                 C   s   | j jS N)rg   rA   r   r   r   r   __len__  s   zDEISMultistepScheduler.__len__)r6   r7   r8   r9   Nr   r:   Fr;   r<   r=   r>   TFFFFr<   r?   r   Fr@   )r   )NN)r   r   r  )T),__name__
__module____qualname____doc__r	   _compatiblesorderr   intr   r   ru   ndarraylistboolr   propertyr   r   r   strr)   r   r   r  r   r   tupler   r   r   r   r   r   r   r   r  r  r   r  r  	IntTensorr  r  r   r   r   r   r1   V   sB   =	
_"
U, %'#
4
@
A
M
\
%
@
1r1   )r   r   )r   typingr   numpyru   r)   configuration_utilsr   r   utilsr   r   scheduling_utilsr	   r
   r   scipy.statsr   r   r   r  r0   r1   r   r   r   r   <module>   s*   
4