o
    	Ti:                  0   @   sZ  d dl Z d dlZd dlZd dlZd dlmZ d dlmZmZm	Z	m
Z
 d dlZd dlZd dlm  mZ d dlmZmZmZ d dlmZ d dlmZ ddlmZ d	d
lmZ e rdd dlmZ d dlm Z  eG dd dZ!eG dd dZ"G dd dZ#dd Z$dd Z%				dDdej&de'dej&de(de)de	ej& de"fd d!Z*e+ 				"	#									$	%					dEd&e	e
e,e-e, f  d'e	e' d(e	e' d)e'd*e(d+e	e
e,e-e, f  d,e	e' de(d-e	e
ej.e-ej. f  d.e	ej& d/e	ej& d0e	ej& d1e	e, d2e)d3e	ee'e'ej&gdf  d4e'd5e	e/e,ef  d6e(f$d7d8Z0				"	#	%	%	%	9	:									$	%					dFd&e	e
e,e-e, f  d'e	e' d(e	e' d)e'd*e(d;e)d<e)d=e)d>e'd?e1d+e	e
e,e-e, f  d,e	e' de(d-e	e
ej.e-ej. f  d.e	ej& d/e	ej& d0e	ej& d1e	e, d2e)d3e	ee'e'ej&gdf  d4e'd5e	e/e,ef  d6e(f.d@dAZ2G dBdC dCe#Z3dS )G    N)	dataclass)AnyCallableOptionalUnion)DDIMSchedulerStableDiffusionPipelineUNet2DConditionModel)rescale_noise_cfg)is_peft_available   )randn_tensor   )convert_state_dict_to_diffusers)
LoraConfig)get_peft_model_state_dictc                   @   s0   e Zd ZU dZejed< ejed< ejed< dS )DDPOPipelineOutputa_  
    Output class for the diffusers pipeline to be finetuned with the DDPO trainer

    Args:
        images (`torch.Tensor`):
            The generated images.
        latents (`list[torch.Tensor]`):
            The latents used to generate the images.
        log_probs (`list[torch.Tensor]`):
            The log probabilities of the latents.

    imageslatents	log_probsN__name__
__module____qualname____doc__torchTensor__annotations__ r   r   O/home/ubuntu/.local/lib/python3.10/site-packages/trl/models/modeling_sd_base.pyr   &   s
   
 

r   c                   @   s&   e Zd ZU dZejed< ejed< dS )DDPOSchedulerOutputad  
    Output class for the diffusers scheduler to be finetuned with the DDPO trainer

    Args:
        latents (`torch.Tensor`):
            Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)`
        log_probs (`torch.Tensor`):
            Log probability of the above mentioned sample. Shape: `(batch_size)`
    r   r   Nr   r   r   r   r   r    :   s   
 

r    c                   @   s   e Zd ZdZdefddZdefddZedd Z	ed	d
 Z
edd Zedd Zedd Zedd Zdd Zdd Zdd Zdd Zdd ZdS )DDPOStableDiffusionPipelinezU
    Main class for the diffusers pipeline to be finetuned with the DDPO trainer
    returnc                 O      t NNotImplementedErrorselfargskwargsr   r   r   __call__O      z$DDPOStableDiffusionPipeline.__call__c                 O   r#   r$   r%   r'   r   r   r   scheduler_stepR   r,   z*DDPOStableDiffusionPipeline.scheduler_stepc                 C   r#   )z@
        Returns the 2d U-Net model used for diffusion.
        r%   r(   r   r   r   unetU      z DDPOStableDiffusionPipeline.unetc                 C   r#   )zq
        Returns the Variational Autoencoder model used from mapping images to and from the latent space
        r%   r.   r   r   r   vae\   r0   zDDPOStableDiffusionPipeline.vaec                 C   r#   )zG
        Returns the tokenizer used for tokenizing text inputs
        r%   r.   r   r   r   	tokenizerc   r0   z%DDPOStableDiffusionPipeline.tokenizerc                 C   r#   )zc
        Returns the scheduler associated with the pipeline used for the diffusion process
        r%   r.   r   r   r   	schedulerj   r0   z%DDPOStableDiffusionPipeline.schedulerc                 C   r#   )zH
        Returns the text encoder used for encoding text inputs
        r%   r.   r   r   r   text_encoderq   r0   z(DDPOStableDiffusionPipeline.text_encoderc                 C   r#   )z6
        Returns the autocast context manager
        r%   r.   r   r   r   autocastx   r0   z$DDPOStableDiffusionPipeline.autocastc                 O   r#   )z?
        Sets the progress bar config for the pipeline
        r%   r'   r   r   r   set_progress_bar_config      z3DDPOStableDiffusionPipeline.set_progress_bar_configc                 O   r#   )z0
        Saves all of the model weights
        r%   r'   r   r   r   save_pretrained   r7   z+DDPOStableDiffusionPipeline.save_pretrainedc                 O   r#   )zB
        Returns the trainable parameters of the pipeline
        r%   r'   r   r   r   get_trainable_layers   r7   z0DDPOStableDiffusionPipeline.get_trainable_layersc                 O   r#   )zq
        Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state
        r%   r'   r   r   r   save_checkpoint   r7   z+DDPOStableDiffusionPipeline.save_checkpointc                 O   r#   )zq
        Light wrapper around accelerate's register_lad_state_pre_hook which is run before loading state
        r%   r'   r   r   r   load_checkpoint   r7   z+DDPOStableDiffusionPipeline.load_checkpointN)r   r   r   r   r   r+   r    r-   propertyr/   r1   r2   r3   r4   r5   r6   r8   r9   r:   r;   r   r   r   r   r!   J   s*    





r!   c                 C   s<   | j }|t|krtd| | jdt||   |S )a  
    As opposed to the default direction of broadcasting (right to left), this function broadcasts from left to right

        Args:
            input_tensor (`torch.FloatTensor`): is the tensor to broadcast
            shape (`tuple[int]`): is the shape to broadcast to
    zrThe number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to)r   )ndimlen
ValueErrorreshapeshapebroadcast_to)input_tensorrA   
input_ndimr   r   r   _left_broadcast   s   "rE   c                 C   sr   t | jd| |j}t | dk| jd| | j|j}d| }d| }|| d||   }|S )Nr   r   )r   gatheralphas_cumprodcputodevicewherefinal_alpha_cumprod)r(   timestepprev_timestepalpha_prod_talpha_prod_t_prevbeta_prod_tbeta_prod_t_prevvariancer   r   r   _get_variance   s   
rT           Fmodel_outputrM   sampleetause_clipped_model_outputprev_sampler"   c              	   C   s  | j du r	td|| jj| j   }t|d| jjd }| jd| }	t	| dk| jd| | j
}
t|	|j|j}	t|
|j|j}
d|	 }| jjdkrf||d |  |	d  }|}n>| jjdkr{|}||	d |  |d  }n)| jjdkr|	d | |d |  }|	d | |d |  }n
td	| jj d
| jjr| |}n| jjr|| jj | jj}t| ||}||d  }t||j|j}|r||	d |  |d  }d|
 |d  d | }|
d | | }|dur|durtd|du rt|j||j|jd}|||  }| | d  d|d   t| ttdttj  }|jttd|j d}t!|"|j|S )a  

    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process
    from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
            predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
            `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide
            with the one provided as input and `use_clipped_model_output` will have not effect.
        generator: random number generator.
        variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
            can directly provide the noise for the variance itself. This is useful for methods such as CycleDiffusion.
            (https://huggingface.co/papers/2210.05559)

    Returns:
        `DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample
    NzaNumber of inference steps is 'None', you need to run 'set_timesteps' after creating the schedulerr   r   epsilong      ?rW   v_predictionzprediction_type given as z6 must be one of `epsilon`, `sample`, or `v_prediction`r   zsCannot pass both generator and prev_sample. Please make sure that either `generator` or `prev_sample` stays `None`.)	generatorrJ   dtype)dim)#num_inference_stepsr?   confignum_train_timestepsr   clamprG   rF   rH   rK   rL   rE   rA   rI   rJ   prediction_typethresholding_threshold_sampleclip_sampleclip_sample_rangerT   r   r^   detachlogsqrt	as_tensornppimeantupleranger=   r    type)r(   rV   rM   rW   rX   rY   r]   rZ   rN   rO   rP   rQ   pred_original_samplepred_epsilonrS   	std_dev_tpred_sample_directionprev_sample_meanvariance_noiselog_probr   r   r   r-      sx   
"

r-   2         @pilTpromptheightwidthr`   guidance_scalenegative_promptnum_images_per_promptr]   r   prompt_embedsnegative_prompt_embedsoutput_typereturn_dictcallbackcallback_stepscross_attention_kwargsguidance_rescalec           (   
   C   s  |p	| j jj| j }|p| j jj| j }| ||||||| |dur+t|tr+d}n|dur9t|tr9t|}n|j	d }| j
}|dk}|durO|ddnd}| j||||||||d}| jj||d | jj}| j jj}| || ||||j||	|
}
t||| jj  }|
g}g }| j|d}t|D ]\}}|rt|
gd	 n|
}| j||}| j ||||d
dd } |r| d	\}!}"|!||"|!   } |r|dkrt| |"|d} t| j| ||
|}#|#j}
|#j}$||
 ||$ |t|d ks|d |kr#|d | jj dkr#|  |dur#|| dkr#||||
 qW d   n	1 s/w   Y  |dksS| j j!|
| j jj" d
dd }%| #|%||j\}%}&n|
}%d}&|&du redg|%j	d  }'ndd |&D }'| j$j%|%||'d}%t&| dr| j'dur| j'(  t)|%||S )u  
    Function invoked when calling the pipeline for generation. Args: prompt (`str` or `list[str]`, *optional*): The
    prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height
    (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the
    generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense
            of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of
            [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the
            text `prompt`, usually at the expense of lower image quality.
        negative_prompt (`str` or `list[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies
            to [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
            make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will ge generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/):
            `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be called
            with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        guidance_rescale (`float`, *optional*, defaults to 0.7):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).
            Guidance rescale factor should fix overexposure when using zero terminal SNR.

    Examples:

    Returns:
        `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated
        log probabilities
    Nr   r         ?scaler   r   
lora_scalerJ   totalr   Fencoder_hidden_statesr   r   rU   r   latentr   Tc                 S      g | ]}| qS r   r   .0has_nsfwr   r   r   
<listcomp>      z!pipeline_step.<locals>.<listcomp>r   do_denormalizefinal_offload_hook)*r/   ra   sample_sizevae_scale_factorcheck_inputs
isinstancestrlistr>   rA   _execution_deviceget_encode_promptr3   set_timesteps	timestepsin_channelsprepare_latentsr^   orderprogress_bar	enumerater   catscale_model_inputchunkr
   r-   r   r   appendupdater1   decodescaling_factorrun_safety_checkerimage_processorpostprocesshasattrr   offloadr   )(r(   r}   r~   r   r`   r   r   r   rX   r]   r   r   r   r   r   r   r   r   r   
batch_sizerJ   do_classifier_free_guidancetext_encoder_lora_scaler   num_channels_latentsnum_warmup_stepsall_latentsall_log_probsr   itlatent_model_input
noise_prednoise_pred_uncondnoise_pred_textscheduler_outputry   imagehas_nsfw_conceptr   r   r   r   pipeline_stepE  s   X


	

6
&

r   1   r   rz   truncated_backproptruncated_backprop_randgradient_checkpointtruncated_backprop_timesteptruncated_rand_backprop_minmaxc           .      C   s  |p	| j jj| j }|p| j jj| j }t r | ||||||| |dur0t|tr0d}n|dur>t|t	r>t
|}n|jd }| j}|dk}|durT|ddnd}| j||||||||d}| jj||d | jj}| j jj}| || ||||j|||}W d   n1 sw   Y  t
||| jj  }|g}g } | j|d}!t|D ]\}"}#|rt|gd	 n|}$| j|$|#}$|rtj| j |$|#||d
dd }%n| j |$|#||d
dd }%|r|rt|
d |
d }&|"|&k r|% }%n|"|	k r|% }%|r|%d	\}'}(|'||(|'   }%|r |dkr t|%|(|d}%t | j|%|#||})|)j!}|)j"}*|#| | #|* |"t
|d ksT|"d |krj|"d | jj dkrj|!$  |durj|"| dkrj||"|#| qW d   n	1 svw   Y  |dks| j%j&|| j%jj' d
dd }+| (|+||j\}+},n|}+d},|,du rdg|+jd  }-ndd |,D }-| j)j*|+||-d}+t+| dr| j,dur| j,-  t.|+|| S )u  
    Function to get RGB image with gradients attached to the model weights.

    Args:
        prompt (`str` or `list[str]`, *optional*, defaults to `None`):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to `pipeline.unet.config.sample_size * pipeline.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `pipeline.unet.config.sample_size * pipeline.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to `50`):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense
            of slower inference.
        guidance_scale (`float`, *optional*, defaults to `7.5`):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. of
            [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the
            text `prompt`, usually at the expense of lower image quality.
        truncated_backprop (`bool`, *optional*, defaults to True):
            Truncated Backpropation to fixed timesteps, helps prevent collapse during diffusion reward training as
            shown in AlignProp (https://huggingface.co/papers/2310.03739).
        truncated_backprop_rand (`bool`, *optional*, defaults to True):
            Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps, this helps
            prevent collapse during diffusion reward training as shown in AlignProp
            (https://huggingface.co/papers/2310.03739). Enabling truncated_backprop_rand allows adapting earlier
            timesteps in diffusion while not resulting in a collapse.
        gradient_checkpoint (`bool`, *optional*, defaults to True):
            Adds gradient checkpointing to Unet forward pass. Reduces GPU memory consumption while slightly increasing
            the training time.
        truncated_backprop_timestep (`int`, *optional*, defaults to 49):
            Absolute timestep to which the gradients are being backpropagated. Higher number reduces the memory usage
            and reduces the chances of collapse. While a lower value, allows more semantic changes in the diffusion
            generations, as the earlier diffusion timesteps are getting updated. However it also increases the chances
            of collapse.
        truncated_rand_backprop_minmax (`Tuple`, *optional*, defaults to (0,50)):
            Range for randomized backprop. Here the value at 0 index indicates the earlier diffusion timestep to update
            (closer to noise), while the value at index 1 indicates the later diffusion timestep to update.
        negative_prompt (`str` or `list[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies
            to [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
            make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will ge generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/):
            `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be called
            with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `pipeline.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        guidance_rescale (`float`, *optional*, defaults to 0.7):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).
            Guidance rescale factor should fix overexposure when using zero terminal SNR.

    Examples:

    Returns:
        `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated
        log probabilities
    Nr   r   r   r   r   r   r   r   F)r   use_reentrantr   rU   r   r   r   Tc                 S   r   r   r   r   r   r   r   r     r   z+pipeline_step_with_grad.<locals>.<listcomp>r   r   )/r/   ra   r   r   r   no_gradr   r   r   r   r>   rA   r   r   r   r3   r   r   r   r   r^   r   r   r   r   r   
checkpointrandomrandintri   r   r
   r-   r   r   r   r   r1   r   r   r   r   r   r   r   r   r   ).pipeliner}   r~   r   r`   r   r   r   r   r   r   r   r   rX   r]   r   r   r   r   r   r   r   r   r   r   rJ   r   r   r   r   r   r   r   r   r   r   r   r   rand_timestepr   r   r   ry   r   r   r   r   r   r   pipeline_step_with_grad  s   s



:		

6
@

r   c                   @   s   e Zd ZddddededefddZd	efd
dZd	efddZd	e	fddZ
edd Zedd Zedd Zedd Zedd Zedd Zdd Zdd Zd d! Zd"d# Zd$d% Zd&S )'"DefaultDDPOStableDiffusionPipelinemainT)pretrained_model_revisionuse_lorapretrained_model_namer   r   c                C   s   t j||d| _|| _|| _|| _z| jj|d|d d| _W n ty0   |r.t	dt
 Y nw t| jjj| j_d | j_| jjd | jjd | jj| j  d S )N)revision pytorch_lora_weights.safetensors)weight_namer   TzTrying to load LoRA weights but no LoRA weights found. Set `use_lora=False` or check that `pytorch_lora_weights.safetensors` exists in the model folder.F)r   from_pretrainedsd_pipeliner   pretrained_modelpretrained_revisionload_lora_weightsOSErrorwarningswarnUserWarningr   from_configr3   ra   safety_checkerr1   requires_grad_r4   r/   )r(   r   r   r   r   r   r   __init__  s4   
z+DefaultDDPOStableDiffusionPipeline.__init__r"   c                 O      t | jg|R i |S r$   )r   r   r'   r   r   r   r+   ?     z+DefaultDDPOStableDiffusionPipeline.__call__c                 O   r   r$   )r   r   r'   r   r   r   rgb_with_gradB  r   z0DefaultDDPOStableDiffusionPipeline.rgb_with_gradc                 O   s   t | jjg|R i |S r$   )r-   r   r3   r'   r   r   r   r-   E  s   z1DefaultDDPOStableDiffusionPipeline.scheduler_stepc                 C      | j jS r$   )r   r/   r.   r   r   r   r/   H     z'DefaultDDPOStableDiffusionPipeline.unetc                 C   r   r$   )r   r1   r.   r   r   r   r1   L  r   z&DefaultDDPOStableDiffusionPipeline.vaec                 C   r   r$   )r   r2   r.   r   r   r   r2   P  r   z,DefaultDDPOStableDiffusionPipeline.tokenizerc                 C   r   r$   )r   r3   r.   r   r   r   r3   T  r   z,DefaultDDPOStableDiffusionPipeline.schedulerc                 C   r   r$   )r   r4   r.   r   r   r   r4   X  r   z/DefaultDDPOStableDiffusionPipeline.text_encoderc                 C   s   | j rtjS d S r$   )r   
contextlibnullcontextr.   r   r   r   r5   \  s   z+DefaultDDPOStableDiffusionPipeline.autocastc                 C   s6   | j rtt| jj}| jj||d | j| d S )Nsave_directoryunet_lora_layers)r   r   r   r   r/   save_lora_weightsr8   )r(   
output_dir
state_dictr   r   r   r8   `  s   z2DefaultDDPOStableDiffusionPipeline.save_pretrainedc                 O   s   | j j|i | d S r$   )r   r6   r'   r   r   r   r6   f  s   z:DefaultDDPOStableDiffusionPipeline.set_progress_bar_configc                 C   s^   | j r+tdddg dd}| jj| | jj D ]}|jr&|tj	|_
q| jjS | jjS )N   gaussian)to_kto_qto_vzto_out.0)r
lora_alphainit_lora_weightstarget_modules)r   r   r   r/   add_adapter
parametersrequires_gradrI   r   float32data)r(   lora_configparamr   r   r   r9   i  s   z7DefaultDDPOStableDiffusionPipeline.get_trainable_layersc                 C   s   t |dkr
td| jr0t|d dr0t|d dd d ur0tt|d }| jj||d d S | jsHt	|d t
rH|d tj|d d S tdt|d  )Nr   CGiven how the trainable params were set, this should be of length 1r   peft_configr   r/   Unknown model type )r>   r?   r   r   getattrr   r   r   r   r   r	   r8   ospathjoinrr   )r(   modelsweightsr   r   r   r   r   r:   |  s   (z2DefaultDDPOStableDiffusionPipeline.save_checkpointc                 C   s   t |dkr
td| jr$| jj|dd\}}| jj|||d d d S | jsLt|d trLtj|dd}|d j	d
i |j
 |d |  ~d S td	t|d  )Nr   r  r   )r   r   )network_alphasr/   r/   )	subfolderr  r   )r>   r?   r   r   lora_state_dictload_lora_into_unetr   r	   r   register_to_configra   load_state_dictr   rr   )r(   r  	input_dirr  r  
load_modelr   r   r   r;     s   
z2DefaultDDPOStableDiffusionPipeline.load_checkpointN)r   r   r   r   boolr   r   r+   r   r    r-   r<   r/   r1   r2   r3   r4   r5   r8   r6   r9   r:   r;   r   r   r   r   r     s,     





r   )rU   FNN)NNNrz   r{   Nr   rU   NNNNr|   TNr   NrU   )NNNrz   r{   TTTr   r   Nr   rU   NNNNr|   TNr   NrU   )4r   r  r   r   dataclassesr   typingr   r   r   r   numpyrm   r   torch.utils.checkpointutilsr   	diffusersr   r   r	   >diffusers.pipelines.stable_diffusion.pipeline_stable_diffusionr
   transformers.utilsr   corer   sd_utilsr   peftr   
peft.utilsr   r   r    r!   rE   rT   FloatTensorintfloatr  r-   r   r   r   	Generatordictr   rp   r   r   r   r   r   r   <module>   s\  T	
 		
 S	

  
