o
    Gi0                  	   @   s   d dl Z d dlmZmZ d dlZd dlZddlmZm	Z	 ddl
mZmZ ddlmZ ddlmZmZmZ e r;d dlZ			
ddededed dejfddZG dd deeZdS )    N)CallableLiteral   )ConfigMixinregister_to_config)	deprecateis_scipy_available)randn_tensor   )KarrasDiffusionSchedulersSchedulerMixinSchedulerOutput+?cosinenum_diffusion_timestepsmax_betaalpha_transform_type)r   explaplacereturnc                 C   s   |dkr	dd }n|dkrdd }n|dkrdd }nt d| g }t| D ]}||  }|d	 |  }|td	||||  | q(tj|tjd
S )a>  
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`, defaults to `0.999`):
            The maximum beta to use; use values lower than 1 to avoid numerical instability.
        alpha_transform_type (`str`, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.

    Returns:
        `torch.Tensor`:
            The betas used by the scheduler to step the model outputs.
    r   c                 S   s    t | d d t j d d S )NgMb?gT㥛 ?r   )mathcospit r   \/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/schedulers/scheduling_sasolver.pyalpha_bar_fn=   s    z)betas_for_alpha_bar.<locals>.alpha_bar_fnr   c              	   S   sP   dt dd|   t ddt d|    d  }t |}t |d|  S )Ng      r
         ?r   gư>)r   copysignlogfabsr   sqrt)r   lmbsnrr   r   r   r   B   s   4
r   c                 S   s   t | d S )Ng      ()r   r   r   r   r   r   r   I   s   z"Unsupported alpha_transform_type: r
   dtype)
ValueErrorrangeappendmintorchtensorfloat32)r   r   r   r   betasit1t2r   r   r   betas_for_alpha_bar#   s   


"r2   c                0   @   s  e Zd ZdZdd eD ZdZeddddd	d
d
dd	dddddddddded d	ddfde	dedede
dejee B d	B de	de	de
ded	B dededed e
d!ed"ed#ed$ed%ed&ed'ed(e
d	B d)e
d*e	f.d+d,Zed-d. Zed/d0 Zdrd1e	fd2d3Zdsd4e	d5e
ejB fd6d7Zd8ejd9ejfd:d;Zd<d= Zd>d? Zd@ejd9ejfdAdBZd@ejd4e	d9ejfdCdDZ	Edtd@ejd4e	dFedGed9ejf
dHdIZd	dJdKejd8ejd9ejfdLdMZdNdO Z dPdQ Z!dRdS Z"dTdU Z#dKejd8ejdVejdWe	dXejd9ejfdYdZZ$d[ejd\ejd]ejd^ejdWe	dXejd9ejfd_d`Z%		dudae	ejB dbejd	B d9e	fdcddZ&dedf Z'			dvdKejdae	d8ejdged9e(e)B f
dhdiZ*d8ejd9ejfdjdkZ+dlejdVejdmej,d9ejfdndoZ-dpdq Z.d	S )wSASolverScheduleru  
    `SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        predictor_order (`int`, defaults to 2):
            The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for
            guided sampling, and `predictor_order=3` for unconditional sampling.
        corrector_order (`int`, defaults to 2):
            The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for
            guided sampling, and `corrector_order=3` for unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://huggingface.co/papers/2210.02303) paper).
        tau_func (`Callable`, *optional*):
            Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`.
            SA-Solver will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample
            from vanilla diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check
            https://huggingface.co/papers/2309.05019
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `data_prediction`):
            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use
            `data_prediction` with `solver_order=2` for guided sampling like in Stable Diffusion.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Default = True.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
        variance_type (`str`, *optional*):
            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
            contains the predicted Gaussian variance.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    c                 C   s   g | ]}|j qS r   )name).0er   r   r   
<listcomp>   s    zSASolverScheduler.<listcomp>r
   i  g-C6?g{Gz?linearNr   epsilonFgףp=
?      ?data_predictionTinflinspacer   num_train_timesteps
beta_startbeta_endbeta_scheduletrained_betaspredictor_ordercorrector_orderprediction_typetau_functhresholdingdynamic_thresholding_ratiosample_max_valuealgorithm_typelower_order_finaluse_karras_sigmasuse_exponential_sigmasuse_beta_sigmasuse_flow_sigmas
flow_shiftlambda_min_clippedvariance_typetimestep_spacingsteps_offsetc                 C   s  | j jrt stdt| j j| j j| j jgdkrtd|d ur,tj	|tj
d| _n:|dkr<tj|||tj
d| _n*|dkrRtj|d |d |tj
dd | _n|d	kr\t|| _n
t| d
| j d| j | _tj| jdd| _t| j| _td| j | _t| jt| j | _d| j | j d | _d| _|dvrt| d
| j d | _tjd|d |tj
dd d d  }t|| _d gt||d  | _ d gt||d  | _!|	d u rdd | _"n|	| _"|dk| _#d| _$d | _%d | _&d | _'| j(d| _d S )Nz:Make sure to install scipy if you want to use beta sigmas.r
   znOnly one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.r%   r8   scaled_linearr   r   squaredcos_cap_v2z is not implemented for r:   r   dim)r;   noise_predictionc                 S   s   | dkr
| dkr
dS dS )N   i   r
   r   r   r   r   r   r   <lambda>       z,SASolverScheduler.__init__.<locals>.<lambda>r;   cpu))configrN   r   ImportErrorsumrM   rL   r'   r+   r,   r-   r.   r=   r2   NotImplementedError	__class__alphascumprodalphas_cumprodr"   alpha_tsigma_tr    lambda_tsigmasinit_noise_sigmanum_inference_stepsnpcopy
from_numpy	timestepsmaxtimestep_listmodel_outputsrF   
predict_x0lower_order_numslast_sample_step_index_begin_indexto)selfr>   r?   r@   rA   rB   rC   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   rS   rT   rp   r   r   r   __init__   s\   	&
zSASolverScheduler.__init__c                 C      | j S )zg
        The index counter for current timestep. It will increase 1 after each scheduler step.
        )rw   rz   r   r   r   
step_index      zSASolverScheduler.step_indexc                 C   r|   )zq
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        rx   r}   r   r   r   begin_index   r   zSASolverScheduler.begin_indexr   c                 C   s
   || _ dS )z
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        Nr   )rz   r   r   r   r   set_begin_index  s   
z!SASolverScheduler.set_begin_indexrl   devicec           
         s  t t jdgjj}jj|   }jj	dkr;t
d|d |d  ddd dd  t
j}nXjj	dkri||d  }t
d|d |  ddd dd  t
j}|jj7 }n*jj	dkrjj| }t
|d|   t
j}|d8 }n	tjj	 dt
dj j d	 }t
| jjrt
| }j||d
}t
 fdd|D  }t
||dd gt
j}nΈjjrt
| }j||d
}t
 fdd|D }t
||dd gt
j}njjr6t
| }j||d
}t
 fdd|D }t
||dd gt
j}nnjjr}t
ddjj |d }d| }t
jj| djjd |   dd  }|jj  }t
||dd gt
j}n't
 |t
dt!||}djd  jd  d	 }	t
||	ggt
j}t "|_#t "|j$|t jd_%t!|_&dgt'jj(jj)d  _*d_+d_,d_-d_.j#$d_#dS )a  
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        r   r=   r
   NrZ   leadingtrailingzY is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'.r   )	in_sigmasrl   c                       g | ]} | qS r   _sigma_to_tr5   sigma
log_sigmasrz   r   r   r7   8  r]   z3SASolverScheduler.set_timesteps.<locals>.<listcomp>c                    r   r   r   r   r   r   r   r7   =  r]   c                    r   r   r   r   r   r   r   r7   B  r]   r:   )r   r&   r^   )/r+   searchsortedflipri   r_   rQ   r>   numpyitemrS   rm   r=   roundrn   astypeint64arangerT   r'   arrayrf   r    rL   _convert_to_karrasconcatenater-   rM   _convert_to_exponentialrN   _convert_to_betarO   rP   interplenro   rj   ry   rp   rl   rq   rC   rD   rs   ru   rv   rw   rx   )
rz   rl   r   clipped_idxlast_timesteprp   
step_ratiorj   rd   
sigma_lastr   r   r   set_timesteps  sl   66 

 
 
 
2 
zSASolverScheduler.set_timestepssampler   c                 C   s   |j }|j^}}}|tjtjfvr| }|||t| }|	 }tj
|| jjdd}tj|d| jjd}|d}t|| || }|j||g|R  }||}|S )az  
        Apply dynamic thresholding to the predicted sample.

        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487

        Args:
            sample (`torch.Tensor`):
                The predicted sample to be thresholded.

        Returns:
            `torch.Tensor`:
                The thresholded sample.
        r
   rW   )r*   rq   )r&   shaper+   r-   float64floatreshaperm   prodabsquantiler_   rH   clamprI   	unsqueezery   )rz   r   r&   
batch_sizechannelsremaining_dims
abs_samplesr   r   r   _threshold_sample_  s   


z#SASolverScheduler._threshold_samplec                 C   s   t t |d}||ddt jf  }t j|dkddjddj|jd d d}|d }|| }|| }|| ||  }	t |	dd}	d|	 | |	|  }
|
|j}
|
S )a  
        Convert sigma values to corresponding timestep values through interpolation.

        Args:
            sigma (`np.ndarray`):
                The sigma value(s) to convert to timestep(s).
            log_sigmas (`np.ndarray`):
                The logarithm of the sigma schedule used for interpolation.

        Returns:
            `np.ndarray`:
                The interpolated timestep value(s) corresponding to the input sigma(s).
        g|=Nr   )axisr   )rq   r
   )	rm   r    maximumnewaxiscumsumargmaxclipr   r   )rz   r   r   	log_sigmadistslow_idxhigh_idxlowhighwr   r   r   r   r     s   ,zSASolverScheduler._sigma_to_tc                 C   s@   | j jrd| }|}||fS d|d d d  }|| }||fS )a(  
        Convert sigma values to alpha_t and sigma_t values.

        Args:
            sigma (`torch.Tensor`):
                The sigma value(s) to convert.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing (alpha_t, sigma_t) values.
        r
   r   r   )r_   rO   )rz   r   rg   rh   r   r   r   _sigma_to_alpha_sigma_t  s   z)SASolverScheduler._sigma_to_alpha_sigma_tr   c           
      C   s   t | jdr| jj}nd}t | jdr| jj}nd}|dur |n|d  }|dur,|n|d  }d}tdd|}|d|  }|d|  }||||   | }	|	S )a  
        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
        Models](https://huggingface.co/papers/2206.00364).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following the Karras noise schedule.
        	sigma_minN	sigma_maxrZ   r   g      @r
   )hasattrr_   r   r   r   rm   r=   )
rz   r   rl   r   r   rhorampmin_inv_rhomax_inv_rhorj   r   r   r   r     s   

z$SASolverScheduler._convert_to_karrasc                 C   s   t | jdr| jj}nd}t | jdr| jj}nd}|dur |n|d  }|dur,|n|d  }ttt	|t	||}|S )a  
        Construct an exponential noise schedule.

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted sigma values following an exponential schedule.
        r   Nr   rZ   r   )
r   r_   r   r   r   rm   r   r=   r   r    )rz   r   rl   r   r   rj   r   r   r   r     s   

 z)SASolverScheduler._convert_to_exponential333333?alphabetac              
      s   t | jdr| jjndt | jdr| jjnddur n|d  dur,n|d  tfdd fddd	tdd	| D D }|S )
a  
        Construct a beta noise schedule as proposed in [Beta Sampling is All You
        Need](https://huggingface.co/papers/2407.12173).

        Args:
            in_sigmas (`torch.Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            `torch.Tensor`:
                The converted sigma values following a beta distribution schedule.
        r   Nr   rZ   r   c                    s   g | ]
}|    qS r   r   )r5   ppf)r   r   r   r   r7   5  s    z6SASolverScheduler._convert_to_beta.<locals>.<listcomp>c                    s   g | ]}t jj| qS r   )scipystatsr   r   )r5   timestep)r   r   r   r   r7   7  s    r
   )r   r_   r   r   r   rm   r   r=   )rz   r   rl   r   r   rj   r   )r   r   r   r   r   r     s    

	z"SASolverScheduler._convert_to_betar   model_outputc                O   s"  t |dkr
|d n|dd}|du r#t |dkr|d }ntd|dur-tddd | j| j }| |\}}| jjd	v r| jj	d
kr_| jj
dv rV|ddddf }|||  | }	n5| jj	dkrh|}	n,| jj	dkrw|| ||  }	n| jj	dkr| j| j }|||  }	n
td| jj	 d| jjr| |	}	|	S | jjdv r| jj	d
kr| jj
dv r|ddddf }
n+|}
n(| jj	dkr|||  | }
n| jj	dkr|| ||  }
n
td| jj	 d| jjr| j| | j| }}|||
  | }	| |	}	|||	  | }
|
S dS )a=  
        Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs.
        Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
        designed to discretize an integral of the data prediction model.

        > [!TIP] > The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction
        for both > noise prediction and data prediction models.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        r   r   Nr
   /missing `sample` as a required keyword argumentrp   1.0.0zPassing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`)r;   r9   )learnedlearned_range   r   v_predictionflow_predictionzprediction_type given as zd must be one of `epsilon`, `sample`, `v_prediction`, or `flow_prediction` for the SASolverScheduler.)rY   zQ must be one of `epsilon`, `sample`, or `v_prediction` for the SASolverScheduler.)r   popr'   r   rj   r~   r   r_   rJ   rE   rR   rG   r   rg   rh   )rz   r   r   argskwargsr   r   rg   rh   x0_predr9   r   r   r   convert_model_output?  sd    



z&SASolverScheduler.convert_model_outputc                 C   s  |dv sJ d|dkrt | t || d  S |dkr4t | |d t ||  |d   S |dkrYt | |d d|  d t ||  |d d|  d   S |dkrt | |d d|d   d|  d t ||  |d d|d   d|  d   S dS )	zd
        Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end
        r   r
   r   r   )order is only supported for 0, 1, 2 and 3r   r
   r   r      Nr+   r   )rz   orderinterval_startinterval_endr   r   r   %get_coefficients_exponential_negative  s,   

 
z7SASolverScheduler.get_coefficients_exponential_negativec                 C   st  |dv sJ dd|d  | }d|d  | }|dkr1t |dt ||    d|d   S |dkrRt ||d |d t ||     d|d  d  S |dkrt ||d d|  d |d d|  d t ||     d|d  d  S |dkrt ||d d|d   d|  d |d d|d   d|  d t ||     d|d  d  S d	S )
zl
        Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
        r   r   r
   r   r   r   r      Nr   )rz   r   r   r   tauinterval_end_covinterval_start_covr   r   r   %get_coefficients_exponential_positive  sL   (	z7SASolverScheduler.get_coefficients_exponential_positivec              	   C   s$  |dv sJ |t |d ksJ |dkrdggS |dkrJd|d |d   |d  |d |d   gd|d |d   |d  |d |d   ggS |dkr|d |d  |d |d   }|d |d  |d |d   }|d |d  |d |d   }d| |d  |d  | |d |d  | gd| |d  |d  | |d |d  | gd| |d  |d  | |d |d  | ggS |dkr|d |d  |d |d   |d |d   }|d |d  |d |d   |d |d   }|d |d  |d |d   |d |d   }|d |d  |d |d   |d |d   }d| |d  |d  |d  | |d |d  |d |d   |d |d   | |d  |d  |d  | gd| |d  |d  |d  | |d |d  |d |d   |d |d   | |d  |d  |d  | gd| |d  |d  |d  | |d |d  |d |d   |d |d   | |d  |d  |d  | gd| |d  |d  |d  | |d |d  |d |d   |d |d   | |d  |d  |d  | ggS dS )zB
        Calculate the coefficient of lagrange polynomial
        r   r
   r   r   r   N)r   )rz   r   lambda_listdenominator1denominator2denominator3denominator4r   r   r   lagrange_polynomial_coefficient  s   
   
z1SASolverScheduler.lagrange_polynomial_coefficientc              
   C   s   |dv sJ |t |ksJ dg }| |d |}t|D ];}d}	t|D ]-}
| jr@|	|| |
 | |d |
 ||| 7 }	q&|	|| |
 | |d |
 || 7 }	q&||	 qt ||ksdJ d|S )N)r
   r   r   r   z4the length of lambda list must be equal to the orderr
   r   z3the length of coefficients does not match the order)r   r   r(   rt   r   r   r)   )rz   r   r   r   r   r   coefficientslagrange_coefficientr/   coefficientjr   r   r   get_coefficients_fnK  s"   

z%SASolverScheduler.get_coefficients_fnnoiser   r   c                 O   s  t |dkr
|d n|dd}|du r#t |dkr|d }ntd|du r6t |dkr2|d }ntd|du rIt |dkrE|d }ntd	|du r\t |d
krX|d
 }ntd|durftddd | j}	| j| jd  | j| j }
}| |
\}}
| |\}}t	|t	|
 }t	|t	| }t
|}|| }g }t|D ] }| j| }| | j| \}}t	|t	| }|| q| |||||}|}| jrm|dkrm| j| jd  }| |\}}t	|t	| }|d  dtd|d  |  |d d |d|d   d td|d  |   d|d  d    ||  7  < |d  dtd|d  |  |d d |d|d   d td|d  |   d|d  d    ||  8  < t|D ]>}| jr|d|d  |
 t|d  |  ||  |	|d    7 }qq|d|d   | ||  |	|d    7 }qq| jr|
tdtd|d  |   | }n||
 ttd| d  | }| jrt|d  | |
|  | | | }n
|| | | | }||j}|S )ag  
        One step for the SA-Predictor.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep.
            prev_timestep (`int`):
                The previous discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of SA-Predictor at this timestep.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        r   prev_timestepNr
   r   r   z.missing `noise` as a required keyword argumentr   .missing `order` as a required keyword argumentr   ,missing `tau` as a required keyword argumentr   zPassing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`r:   r   r   r'   r   rs   rj   r~   r   r+   r    
zeros_liker(   r)   r   rt   r   r"   ry   r&   ) rz   r   r   r   r   r   r   r   r   model_output_listrh   sigma_s0rg   alpha_s0ri   	lambda_s0gradient_parthr   r/   sialpha_sisigma_si	lambda_sigradient_coefficientsx
temp_sigmatemp_alpha_stemp_sigma_stemp_lambda_s
noise_partx_tr   r   r   !stochastic_adams_bashforth_update_  s    







FF
0*$*z3SASolverScheduler.stochastic_adams_bashforth_updatethis_model_outputrv   
last_noisethis_samplec                O   s  t |dkr
|d n|dd}	|du r#t |dkr|d }ntd|du r6t |dkr2|d }ntd|du rIt |dkrE|d }ntd	|du r\t |d
krX|d
 }ntd|du rot |dkrk|d }ntd|	durytddd | j}
| j| j | j| jd  }}| |\}}| |\}}t	|t	| }t	|t	| }t
|}|| }g }t|D ] }| j| }| | j| \}}t	|t	| }|| q|
|g }| |||||}|}| jrd|dkrd|d  dtd|d  |  |d |d|d   d td|d  |   d|d  d |    7  < |d  dtd|d  |  |d |d|d   d td|d  |   d|d  d |    8  < t|D ]>}| jr|d|d  | t|d  |  ||  ||d    7 }qh|d|d   | ||  ||d    7 }qh| jr|tdtd|d  |   | }n|| ttd| d  | }| jrt|d  | ||  | | | }n
|| | | | }||j}|S )a  
        One step for the SA-Corrector.

        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                The current timestep `t`.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The order of SA-Corrector at this step.

        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep.
        r   this_timestepNr
   z4missing `last_sample` as a required keyword argumentr   z3missing `last_noise` as a required keyword argumentr   z4missing `this_sample` as a required keyword argumentr   r      r   r   zPassing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`r:   r   r   )rz   r  rv   r  r  r   r   r   r   r  r   rh   r   rg   r   ri   r   r   r   r   r/   r  r  r  r  model_prev_listr  r  r  r  r   r   r   stochastic_adams_moulton_update  s    









FF
0*$*z1SASolverScheduler.stochastic_adams_moulton_updater   schedule_timestepsc                 C   sd   |du r| j }||k }t|dkrt| j d }|S t|dkr*|d  }|S |d  }|S )a  
        Find the index for a given timestep in the schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep for which to find the index.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule.
        Nr   r
   )rp   nonzeror   r   )rz   r   r  index_candidatesr~   r   r   r   index_for_timestepd  s   
z$SASolverScheduler.index_for_timestepc                 C   s@   | j du rt|tjr|| jj}| || _dS | j	| _dS )z
        Initialize the step_index counter for the scheduler.

        Args:
            timestep (`int` or `torch.Tensor`):
                The current timestep for which to initialize the step index.
        N)
r   
isinstancer+   Tensorry   rp   r   r  rw   rx   )rz   r   r   r   r   _init_step_index  s
   
	z"SASolverScheduler._init_step_indexreturn_dictc                 C   s  | j du r	td| jdu r| | | jdko| jdu}| j||d}|r<| | jd }| j|| j| j	|| j
|d}tt| jj| jjd d D ]}	| j|	d  | j|	< | j|	d  | j|	< qK|| jd< || jd< t|j||j|jd}
| jjrt| jjt| j| j }t| jjt| j| j d }n| jj}| jj}t|| jd | _t|| jd	 | _
| jdksJ | j
dksJ || _|
| _	| | jd }| j|||
| j|d
}| jt| jj| jjd k r|  jd7  _|  jd7  _|s|fS t|dS )a  
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the SA-Solver.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        NzaNumber of inference steps is 'None', you need to run 'set_timesteps' after creating the schedulerr   r   rZ   )r  rv   r  r  r   r   r
   )	generatorr   r&   r   )r   r   r   r   r   )prev_sample)rl   r'   r~   r  rv   r   rF   rr   r  r  this_corrector_orderr(   rq   r_   rC   rD   rs   r	   r   r   r&   rK   r*   r   rp   ru   this_predictor_orderr  rw   r   )rz   r   r   r   r  r  use_correctormodel_output_convertcurrent_taur/   r   r   r  r  r   r   r   step  sl   


"	

 
zSASolverScheduler.stepc                 O   s   |S )a?  
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        r   )rz   r   r   r   r   r   r   scale_model_input  s   z#SASolverScheduler.scale_model_inputoriginal_samplesrp   c                 C   s   | j j|jd| _ | j j|jd}||j}|| d }| }t|jt|jk r:|d}t|jt|jk s+d||  d }| }t|jt|jk r_|d}t|jt|jk sP|| ||  }|S )a2  
        Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
        diffusion process).

        Args:
            original_samples (`torch.Tensor`):
                The original samples to which noise will be added.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps indicating the noise level for each sample.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        )r   r%   r   rZ   r
   )rf   ry   r   r&   flattenr   r   r   )rz   r&  r   rp   rf   sqrt_alpha_prodsqrt_one_minus_alpha_prodnoisy_samplesr   r   r   	add_noise  s   

zSASolverScheduler.add_noisec                 C   s   | j jS N)r_   r>   r}   r   r   r   __len__9  s   zSASolverScheduler.__len__)r   )NN)r   r   r,  )NT)/__name__
__module____qualname____doc__r   _compatiblesr   r   r   intstrrm   ndarraylistr   boolr{   propertyr~   r   r   r+   r   r   r  r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r   tupler$  r%  	IntTensorr+  r-  r   r   r   r   r3   W   sX   C	
U


R,%'#
4
[,m	


 

%
f
*r3   )r   r   )r   typingr   r   r   rm   r+   configuration_utilsr   r   utilsr   r   utils.torch_utilsr	   scheduling_utilsr   r   r   scipy.statsr   r3  r   r  r2   r3   r   r   r   r   <module>   s,   
4