o
    ۷iec                     @   s   d dl Z d dlmZ d dlmZ d dlZddlmZmZ ddl	m
Z
mZ ddlmZ dd	lmZ eeZeG d
d de
ZG dd deeZdS )    N)	dataclass)Literal   )ConfigMixinregister_to_config)
BaseOutputlogging)randn_tensor   )SchedulerMixinc                   @   s.   e Zd ZU dZejed< dZejdB ed< dS )EDMEulerSchedulerOutputaq  
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    prev_sampleNpred_original_sample)__name__
__module____qualname____doc__torchTensor__annotations__r    r   r   _/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/schedulers/scheduling_edm_euler.pyr      s   
 
r   c                   @   s  e Zd ZdZg ZdZe									
dUdedededed de	ded deded ddfddZ
edefddZede	fddZede	fddZdVd!e	ddfd"d#Zd$ejd%eejB dejfd&d'Zd%eejB dejfd(d)Zd$ejd*ejd%eejB dejfd+d,Zd$ejd-eejB dejfd.d/Z			dWd0e	d1eejB d2ejee B dB fd3d4Z		dXd5ejdedB dedB dejfd6d7Z		dXd5ejdedB dedB dejfd8d9Z	dYd-eejB d:ejdB de	fd;d<Zd-eejB ddfd=d>Zd?d?ed@dAddBdfd*ejd-eejB d$ejdCedDedEedFedGejdB dHedIejdB de e!B fdJdKZ"dLejdMejdNejdejfdOdPZ#d%eejB deejB fdQdRZ$de	fdSdTZ%dS )ZEDMEulerScheduleraL	  
    Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1].

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://huggingface.co/papers/2206.00364

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        sigma_min (`float`, *optional*, defaults to `0.002`):
            Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
            range is [0, 10].
        sigma_max (`float`, *optional*, defaults to `80.0`):
            Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
            range is [0.2, 80.0].
        sigma_data (`float`, *optional*, defaults to `0.5`):
            The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
        sigma_schedule (`Literal["karras", "exponential"]`, *optional*, defaults to `"karras"`):
            Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
            (https://huggingface.co/papers/2206.00364). The `"exponential"` schedule was incorporated in this model:
            https://huggingface.co/stabilityai/cosxl.
        num_train_timesteps (`int`, *optional*, defaults to `1000`):
            The number of diffusion steps to train the model.
        prediction_type (`Literal["epsilon", "v_prediction"]`, *optional*, defaults to `"epsilon"`):
            Prediction type of the scheduler function. `"epsilon"` predicts the noise of the diffusion process, and
            `"v_prediction"` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
        rho (`float`, *optional*, defaults to `7.0`):
            The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
        final_sigmas_type (`Literal["zero", "sigma_min"]`, *optional*, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
    r
   Mb`?      T@      ?karras  epsilon      @zero	sigma_min	sigma_max
sigma_datasigma_scheduler   exponentialnum_train_timestepsprediction_type)r   v_predictionrhofinal_sigmas_type)r    r!   returnNc	                 C   s  |dvrt d|dd | _tjj rtjntj}	tj|d |	d| }
|dkr0| 	|
}
n	|dkr9| 
|
}
|
tj}
| |
| _| jjdkrP|
d	 }n| jjd
krYd}n	t d| jj t|
tjd||
jdg| _d| _d | _d | _| jd| _d S )Nr%   z-Wrong value for provided for `sigma_schedule=z`.`r
   dtyper   r&   r!   r    r   C`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got r
   
fill_valuedeviceFcpu)
ValueErrornum_inference_stepsr   backendsmpsis_availablefloat32float64arange_compute_karras_sigmas_compute_exponential_sigmastoprecondition_noise	timestepsconfigr+   catfullr4   sigmasis_scale_input_called_step_index_begin_index)selfr!   r"   r#   r$   r'   r(   r*   r+   sigmas_dtyperF   
sigma_lastr   r   r   __init__W   s.   

zEDMEulerScheduler.__init__c                 C   s   | j jd d d S )z
        Return the standard deviation of the initial noise distribution.

        Returns:
            `float`:
                The initial noise sigma value computed as `(sigma_max**2 + 1) ** 0.5`.
        r   r
   r   )rC   r"   rJ   r   r   r   init_noise_sigma   s   	z"EDMEulerScheduler.init_noise_sigmac                 C      | j S )z
        Return the index counter for the current timestep. The index will increase by 1 after each scheduler step.

        Returns:
            `int` or `None`:
                The current step index, or `None` if not yet initialized.
        )rH   rN   r   r   r   
step_index   s   	zEDMEulerScheduler.step_indexc                 C   rP   )z
        Return the index for the first timestep. This should be set from the pipeline with the `set_begin_index`
        method.

        Returns:
            `int` or `None`:
                The begin index, or `None` if not yet set.
        rI   rN   r   r   r   begin_index   s   
zEDMEulerScheduler.begin_indexr   rS   c                 C   s
   || _ dS )z
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        NrR   )rJ   rS   r   r   r   set_begin_index   s   
z!EDMEulerScheduler.set_begin_indexsamplesigmac                 C   s   |  |}|| }|S )a  
        Precondition the input sample by scaling it according to the EDM formulation.

        Args:
            sample (`torch.Tensor`):
                The input sample tensor to precondition.
            sigma (`float` or `torch.Tensor`):
                The current sigma (noise level) value.

        Returns:
            `torch.Tensor`:
                The scaled input sample.
        )_get_conditioning_c_in)rJ   rU   rV   c_inscaled_sampler   r   r   precondition_inputs   s   
z%EDMEulerScheduler.precondition_inputsc                 C   s*   t |tjst|g}dt| }|S )aS  
        Precondition the noise level by applying a logarithmic transformation.

        Args:
            sigma (`float` or `torch.Tensor`):
                The sigma (noise level) value to precondition.

        Returns:
            `torch.Tensor`:
                The preconditioned noise value computed as `0.25 * log(sigma)`.
        g      ?)
isinstancer   r   tensorlog)rJ   rV   c_noiser   r   r   rA      s   z$EDMEulerScheduler.precondition_noisemodel_outputc                 C   s   | j j}|d |d |d   }| j jdkr%|| |d |d  d  }n | j jdkr;| | |d |d  d  }n
td| j j d|| ||  }|S )a  
        Precondition the model outputs according to the EDM formulation.

        Args:
            sample (`torch.Tensor`):
                The input sample tensor.
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sigma (`float` or `torch.Tensor`):
                The current sigma (noise level) value.

        Returns:
            `torch.Tensor`:
                The denoised sample computed by combining the skip connection and output scaling.
        r   r   r   r)   zPrediction type z is not supported.)rC   r#   r(   r6   )rJ   rU   r_   rV   r#   c_skipc_outdenoisedr   r   r   precondition_outputs   s    z&EDMEulerScheduler.precondition_outputstimestepc                 C   s6   | j du r
| | | j| j  }| ||}d| _|S )a  
        Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
        need to scale the denoising model input depending on the current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample tensor.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        NT)rQ   _init_step_indexrF   rZ   rG   )rJ   rU   rd   rV   r   r   r   scale_model_input   s   

z#EDMEulerScheduler.scale_model_inputr7   r4   rF   c                 C   s   || _ tjj rtjntj}|du rtjdd| j |d}nt|t	r+tj
||d}n||}| jjdkr<| |}n| jjdkrG| |}|jtj|d}| || _| jjdkr`|d	 }n| jjd
krid}n	td| jj t|tjd||jdg| _d| _d| _| jd| _dS )a  
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`torch.Tensor | list[float]`, *optional*):
                Custom sigmas to use for the denoising process. If not defined, the default behavior when
                `num_inference_steps` is passed will be used.
        Nr   r
   r-   r   r&   )r.   r4   r!   r/   r    r0   r1   r2   r5   )r7   r   r8   r9   r:   r;   r<   linspacer[   floatr\   r@   rC   r$   r>   r?   rA   rB   r+   r6   rD   rE   r4   rF   rH   rI   )rJ   r7   r4   rF   rK   rL   r   r   r   set_timesteps  s0   



zEDMEulerScheduler.set_timestepsrampc                 C   sP   |p| j j}|p| j j}| j j}|d|  }|d|  }||||   | }|S )aT  
        Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).

        Args:
            ramp (`torch.Tensor`):
                A tensor of values in [0, 1] representing the interpolation positions.
            sigma_min (`float`, *optional*):
                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
            sigma_max (`float`, *optional*):
                Maximum sigma value. If `None`, uses `self.config.sigma_max`.

        Returns:
            `torch.Tensor`:
                The computed Karras sigma schedule.
        r
   )rC   r!   r"   r*   )rJ   rj   r!   r"   r*   min_inv_rhomax_inv_rhorF   r   r   r   r>   C  s   z(EDMEulerScheduler._compute_karras_sigmasc                 C   sD   |p| j j}|p| j j}tt|t|t| 	d}|S )a  
        Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
        https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26

        Args:
            ramp (`torch.Tensor`):
                A tensor of values representing the interpolation positions.
            sigma_min (`float`, *optional*):
                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
            sigma_max (`float`, *optional*):
                Maximum sigma value. If `None`, uses `self.config.sigma_max`.

        Returns:
            `torch.Tensor`:
                The computed exponential sigma schedule.
        r   )
rC   r!   r"   r   rg   mathr]   lenexpflip)rJ   rj   r!   r"   rF   r   r   r   r?   a  s   (z-EDMEulerScheduler._compute_exponential_sigmasschedule_timestepsc                 C   s:   |du r| j }||k }t|dkrdnd}||  S )ak  
        Find the index of a given timestep in the timestep schedule.

        Args:
            timestep (`float` or `torch.Tensor`):
                The timestep value to find in the schedule.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule. For the very first step, returns the second index if
                multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
        Nr
   r   )rB   nonzerorn   item)rJ   rd   rq   indicesposr   r   r   index_for_timestep}  s
   z$EDMEulerScheduler.index_for_timestepc                 C   s@   | j du rt|tjr|| jj}| || _dS | j	| _dS )z
        Initialize the step index for the scheduler based on the given timestep.

        Args:
            timestep (`float` or `torch.Tensor`):
                The current timestep to initialize the step index from.
        N)
rS   r[   r   r   r@   rB   r4   rv   rH   rI   )rJ   rd   r   r   r   re     s
   
z"EDMEulerScheduler._init_step_index        infg      ?Ts_churns_tmins_tmaxs_noise	generatorreturn_dictr   c                 C   sN  t |ttjtjfrtd| jstd | j	du r | 
| |tj}| j| j	 }||  kr6|krDn nt|t| jd  dnd}||d  }|dkrmt|j|j|j|d}|| }|||d	 |d	  d
   }|
du rx| |||}
||
 | }| j| j	d  | }|||  }||j}|  jd7  _|	s||
fS t||
dS )a  
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, *optional*, defaults to `0.0`):
                The amount of stochasticity to add at each step. Higher values add more noise.
            s_tmin (`float`, *optional*, defaults to `0.0`):
                The minimum sigma threshold below which no noise is added.
            s_tmax (`float`, *optional*, defaults to `float("inf")`):
                The maximum sigma threshold above which no noise is added.
            s_noise (`float`, *optional*, defaults to `1.0`):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator for reproducibility.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple.
            pred_original_sample (`torch.Tensor`, *optional*):
                The predicted denoised sample from a previous step. If provided, skips recomputation.

        Returns:
            [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the previous sample tensor and the
                second element is the predicted original sample tensor.
        zPassing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to `EDMEulerScheduler.step()` is not supported. Make sure to pass one of the `scheduler.timesteps` as a timestep.zThe `scale_model_input` function should be called before `step` to ensure correct denoising. See `StableDiffusionPipeline` for a usage example.Nr
   g4y?rw   r   )r.   r4   r}   r   r   )r   r   )r[   intr   	IntTensor
LongTensorr6   rG   loggerwarningrQ   re   r@   r;   rF   minrn   r	   shaper.   r4   rc   rH   r   )rJ   r_   rd   rU   ry   rz   r{   r|   r}   r~   r   rV   gamma	sigma_hatnoiseeps
derivativedtr   r   r   r   step  sF   .

4zEDMEulerScheduler.steporiginal_samplesr   rB   c                    s
  j j|j|jd}|jjdkr)t|r)jj|jtjd |j|jtjd}nj|j ||j}j	du rF fdd|D }nj
durUj
g|jd  }n	j	g|jd  }||  }t|jt|jk r}|d}t|jt|jk sn|||  }|S )	am  
        Add noise to the original samples according to the noise schedule at the specified timesteps.

        Args:
            original_samples (`torch.Tensor`):
                The original samples to which noise will be added.
            noise (`torch.Tensor`):
                The noise tensor to add to the original samples.
            timesteps (`torch.Tensor`):
                The timesteps at which to add noise, determining the noise level from the schedule.

        Returns:
            `torch.Tensor`:
                The noisy samples with added noise scaled according to the timestep schedule.
        )r4   r.   r9   r-   Nc                    s   g | ]} | qS r   )rv   ).0trq   rJ   r   r   
<listcomp>9  s    z/EDMEulerScheduler.add_noise.<locals>.<listcomp>r   r/   )rF   r@   r4   r.   typer   is_floating_pointrB   r;   rS   rQ   r   flattenrn   	unsqueeze)rJ   r   r   rB   rF   step_indicesrV   noisy_samplesr   r   r   	add_noise  s"   


zEDMEulerScheduler.add_noisec                 C   s    d|d | j jd  d  }|S )a4  
        Compute the input conditioning factor for the EDM formulation.

        Args:
            sigma (`float` or `torch.Tensor`):
                The current sigma (noise level) value.

        Returns:
            `float` or `torch.Tensor`:
                The input conditioning factor `c_in`.
        r
   r   r   )rC   r#   )rJ   rV   rX   r   r   r   rW   H  s   z(EDMEulerScheduler._get_conditioning_c_inc                 C   s   | j jS N)rC   r'   rN   r   r   r   __len__W  s   zEDMEulerScheduler.__len__)r   r   r   r   r   r   r   r    )r   )NNN)NNr   )&r   r   r   r   _compatiblesorderr   rh   r   r   rM   propertyrO   rQ   rS   rT   r   r   rZ   rA   rc   rf   strr4   listri   r>   r?   rv   re   	Generatorboolr   tupler   r   rW   r   r   r   r   r   r   1   s
   "	
,

 

 #
5
!

	

m
0r   )rm   dataclassesr   typingr   r   configuration_utilsr   r   utilsr   r   utils.torch_utilsr	   scheduling_utilsr   
get_loggerr   r   r   r   r   r   r   r   <module>   s   
