o
    	Ti;p                     @   s   d dl Z d dlZd dlZd dlmZ d dlmZ d dlmZ d dl	m
Z
mZmZmZ d dlZd dlmZ d dlmZ d dlmZmZ d d	lmZ d d
lmZ ddlmZ ddlmZ ddlmZm Z m!Z! e rkd dl"Z"ee#Z$G dd deZ%dS )    N)defaultdict)futures)Path)AnyCallableOptionalUnion)Accelerator)
get_logger)ProjectConfigurationset_seed)PyTorchModelHubMixin)is_wandb_available   )DDPOStableDiffusionPipeline   )
DDPOConfig)PerPromptStatTrackergenerate_model_cardget_comet_experiment_urlc                       sR  e Zd ZdZddgZ	d3dedeeje	e
 e	e gejf deg e	e
ef f ded	eeeeegef  f
d
dZd4ddZdedefddZdd ZdejdedejfddZdd Zdd Zdd Zd d! Zd"d# Zd$e	ee
f fd%d&Zd3d'ee fd(d)Zd*d+ Z fd,d-Z			d5d.ee
 d/ee
 d0ee
e e
 df fd1d2Z!  Z"S )6DDPOTrainerah  
    The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is heavily
    inspired by the work here: https://github.com/kvablack/ddpo-pytorch As of now only Stable Diffusion based pipelines
    are supported

    Attributes:
        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more:
         details.
        **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used:
        **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
    trlddpoNconfigreward_functionprompt_functionsd_pipelineimage_samples_hookc              	   C   s  t dt |d u rt d || _|| _|| _|| _td"i | jj}| jj	r}t
jt
j| jj	| j_	dt
j| jj	vr}ttdd t
| jj	}t|dkr]td| jj	 tdd	 |D }t
j| jj	d|d
  | j_	|d
 d |_t| jj| jj | _td"| jj| jj|| jj| j d| jj| _ | ! \}	}
|	st|
|jd uo|jdk}| j j"r| j j#| jj$|st%|& dn|& | jj'd t()d|  t*| jj+dd || _,| j,j-d| j j. dddd | j jdkrt/j0}n| j jdkrt/j1}nt/j2}| j,j3j4| j j5|d | j,j6j4| j j5|d | j,j7j4| j j5|d | j,8 }| j 9| j: | j ;| j< | jj=rJdt/j>j?j@_=| AtB|tsV|C n|| _D| j,6| j,jE| jjFd u rjdgn| jjFddd| j,jEjGdjH4| j j5d | _I|jJrtK|jL|jM| _N| j,jOp| j jO| _OtP| j,dr| j,jQr| j R|| jD\}| _Dttdd |C | _Sn| j R|| jD\| _S| _D| jjTrtUjV|jWd| _X|j	rt()d |j	  | j Y|j	 t|j	Zd!d
 d | _[d S d| _[d S )#Nz@DDPOTrainer is deprecated and will be removed in version 0.23.0.z8No image_samples_hook provided; no images will be loggedcheckpoint_c                 S   s   d| v S )Nr    xr   r   L/home/ubuntu/.local/lib/python3.10/site-packages/trl/trainer/ddpo_trainer.py<lambda>W   s    z&DDPOTrainer.__init__.<locals>.<lambda>r   zNo checkpoints found in c                 S   s   g | ]}t |d d qS )_)intsplit).0r!   r   r   r"   
<listcomp>]       z(DDPOTrainer.__init__.<locals>.<listcomp>r%   r   )log_withmixed_precisionproject_configgradient_accumulation_stepstensorboard)ddpo_trainer_config)r   init_kwargs
T)device_specificFTimestep)positiondisableleavedescdynamic_ncolsfp16bf16)dtype pt
max_lengthreturn_tensorspadding
truncationr?   use_lorac                 S   s   | j S N)requires_grad)pr   r   r"   r#      s    )max_workerszResuming from r$   r   )\warningswarnDeprecationWarning	prompt_fn	reward_fnr   image_samples_callbackr   project_kwargsresume_fromospathnormpath
expanduserbasenamelistfilterlistdirlen
ValueErrorsortedjoin	iterationr&   sample_num_stepstrain_timestep_fractionnum_train_timestepsr	   r+   r,   !train_gradient_accumulation_stepsaccelerator_kwargsaccelerator_config_checkis_main_processinit_trackerstracker_project_namedictto_dicttracker_kwargsloggerinfor   seedr   set_progress_bar_configis_local_main_processtorchfloat16bfloat16float32vaetodevicetext_encoderunetget_trainable_layersregister_save_state_pre_hook_save_model_hookregister_load_state_pre_hook_load_model_hook
allow_tf32backendscudamatmul_setup_optimizer
isinstance
parameters	optimizer	tokenizernegative_promptsmodel_max_length	input_idsneg_prompt_embedper_prompt_stat_trackingr   $per_prompt_stat_tracking_buffer_size"per_prompt_stat_tracking_min_countstat_trackerautocasthasattrrD   preparetrainable_layersasync_reward_computationr   ThreadPoolExecutorrH   executor
load_stater'   first_epoch)selfr   r   r   r   r   accelerator_project_configcheckpointscheckpoint_numbersis_okaymessageis_using_tensorboardinference_dtyper   rx   r   r   r"   __init__;   s   






zDDPOTrainer.__init__Fc           	         s~   |s'g }|D ]\}}}  |||\}}|tj| jjd|f qt| S  j fdd|} fdd|D }t| S )Nrv   c                    s
    j |  S rE   )rM   r    r   r   r"   r#      s   
 z-DDPOTrainer.compute_rewards.<locals>.<lambda>c                    s.   g | ]\}}t j|  jjd | fqS r   )rp   	as_tensorresultrc   rv   )r(   rewardreward_metadatar   r   r"   r)      s    z/DDPOTrainer.compute_rewards.<locals>.<listcomp>)	rM   appendrp   r   rc   rv   r   mapzip)	r   prompt_image_pairsis_asyncrewardsimagespromptsprompt_metadatar   r   r   r   r"   compute_rewards   s   
zDDPOTrainer.compute_rewardsepochglobal_stepc                    s  j jjjjd\}fddd  D j|jjd\}}t|D ]\}}||| || g q)j	durI	||j
jd  t|}j
|  }j
j||| | d|d jjrj
d	   }jjj|d
d}	j|	|}
n||  | d  }
t|
j
jdj
j j
jd< d	= d j \} t!jj"D ]v}tj#|j
jdfdd$ D t% fddt!|D }dD ]}| tj&|j
jddddf |f |< q ' }fdd|D }t(| }fdd|D }jj)*  +||||}j
j,s2t-dq|dkrK|jj. dkrKj
j/rKj
0  |S )a  
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
              and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.

        )
iterations
batch_sizec                    s&   i | ]  t  fd dD qS )c                    s   g | ]}|  qS r   r   )r(   skr   r"   r)          z/DDPOTrainer.step.<locals>.<dictcomp>.<listcomp>)rp   cat)r(   )samplesr   r"   
<dictcomp>   s   & z$DDPOTrainer.step.<locals>.<dictcomp>r   )r   N)r   r   reward_mean
reward_stdstep
prompt_idsT)skip_special_tokensg:0yE>r%   
advantages	timestepsr   c                    s   i | ]	\}}||  qS r   r   r(   r   v)permr   r"   r   '      c                    s   g | ]}t j jjd qS r   )rp   randpermrc   rv   r(   r$   )num_timestepsr   r   r"   r)   ,  r*   z$DDPOTrainer.step.<locals>.<listcomp>)r   latentsnext_latents	log_probsc                    s.   g | ]}|j d  jjg|jdd R  qS )r%   r   N)reshaper   train_batch_sizeshape)r(   r   r   r   r"   r)   8  s   . c                    s   g | ]	}t t |qS r   )rh   r   )r(   
row_values)original_keysr   r"   r)   =  r   zsOptimization step should have been performed by this point. Please check calculated gradient accumulation settings.)1_generate_samplesr   sample_num_batches_per_epochsample_batch_sizekeysr   r   	enumerateextendrN   rc   trackersrp   r   gathercpunumpylogmeanstdr   r   r   batch_decoder   updater   r   num_processesprocess_indexru   rv   r   rangetrain_num_inner_epochsr   itemsstackarangevaluesr   rx   train_train_batched_samplessync_gradientsrZ   	save_freqre   
save_state)r   r   r   prompt_image_datar   rewards_metadatai
image_datar   r   r   total_batch_sizeinner_epochpermskeyoriginal_valuesreshaped_valuestransposed_valuessamples_batchedr   )r   r   r   r   r   r"   r      sz   







&
zDDPOTrainer.stepc                 C   s(  |   L | jjr0| jt|gd t|gd |j}|d\}}	|| jj	|	|   }n	| j|||j}| jj
|||| jj|d}
|
j}W d   n1 sSw   Y  t|| jj | jj}t|| }| || jj|}dt|| d  }tt|d | jjk }|||fS )a  
        Calculate the loss for a batch of an unpacked sample

        Args:
            latents (torch.Tensor):
                The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            timesteps (torch.Tensor):
                The timesteps sampled from the diffusion model, shape: [batch_size]
            next_latents (torch.Tensor):
                The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height,
                width]
            log_probs (torch.Tensor):
                The log probabilities of the latents, shape: [batch_size]
            advantages (torch.Tensor):
                The advantages of the latents, shape: [batch_size]
            embeds (torch.Tensor):
                The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] Note: the "or" is because if
                train_cfg is True, the expectation is that negative prompts are concatenated to the embeds

        Returns:
            loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) (all of these are of shape (1,))
        r   )etaprev_sampleNg      ?      ?)r   r   	train_cfgr   rx   rp   r   samplechunksample_guidance_scalescheduler_step
sample_etar   clamptrain_adv_clip_maxexplosstrain_clip_ranger   absfloat)r   r   r   r   r   r   embeds
noise_prednoise_pred_uncondnoise_pred_textscheduler_step_outputlog_probratior  	approx_klclipfracr   r   r"   calculate_lossL  sN   
 
zDDPOTrainer.calculate_lossr   
clip_ranger  c                 C   s8   | | }| t |d| d|  }t t ||S )Nr   )rp   r  r   maximum)r   r   r  r  unclipped_lossclipped_lossr   r   r"   r    s   
zDDPOTrainer.lossc                 C   sL   | j jrdd l}|jj}ntjj}||| j j| j j| j j	f| j j
| j jdS )Nr   )lrbetasweight_decayeps)r   train_use_8bit_adambitsandbytesoptim	AdamW8bitrp   AdamWtrain_learning_ratetrain_adam_beta1train_adam_beta2train_adam_weight_decaytrain_adam_epsilon)r   trainable_layers_parametersr  optimizer_clsr   r   r"   r     s   
zDDPOTrainer._setup_optimizerc                 C   s   | j ||| |  d S rE   )r   save_checkpointpop)r   modelsweights
output_dirr   r   r"   r{     s   zDDPOTrainer._save_model_hookc                 C   s   | j || |  d S rE   )r   load_checkpointr+  )r   r,  	input_dirr   r   r"   r}     s   zDDPOTrainer._load_model_hookc                    sd  g }g } j j   j|dd}t|D ]}t fddt|D  \}} j j|ddd j jjdj	
 jj}	 j |	d }
  "  j |
| jj jj jjdd	}|j}|j}|j}W d
   n1 slw   Y  tj|dd}tj|dd} j jj|d}||	|
||d
d
d
df |d
d
dd
f ||d ||||g q||fS )a4  
        Generate samples from the model

        Args:
            iterations (int): Number of iterations to generate samples for
            batch_size (int): Batch size to use for sampling

        Returns:
            samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
        r   c                    s   g | ]}   qS r   )rL   r   r   r   r"   r)     r   z1DDPOTrainer._generate_samples.<locals>.<listcomp>r>   r?   Tr@   r   )prompt_embedsnegative_prompt_embedsnum_inference_stepsguidance_scaler   output_typeN)dimr%   )r   r1  r   r   r   r   r2  )r   rx   evalr   repeatr   r   r   r   r   ru   rc   rv   rw   r   r   r^   r  r  r   r   r   rp   r   	schedulerr   r   )r   r   r   r   r   sample_neg_prompt_embedsr$   r   r   r   r1  	sd_outputr   r   r   r   r   r   r"   r     sX   
	zDDPOTrainer._generate_samplesc                 C   s  t t}t|D ]\}}| jjrt|d |d g}n|d }t| jD ]}	| j	
| jju | |d dd|	f |d dd|	f |d dd|	f |d dd|	f |d |\}
}}|d	 | |d
 | |d |
 | j	|
 | j	jr| j	t| jts| j n| j| jj | j  | j  W d   n1 sw   Y  | j	jrdd | D }| j	j|dd}|||d | j	j||d |d7 }t t}q%q|S )a  
        Train on a batch of samples. Main training segment

        Args:
            inner_epoch (int): The current inner epoch
            epoch (int): The current epoch
            global_step (int): The current global step
            batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.

        Returns:
            global_step (int): The updated global step
        r2  r1  r   Nr   r   r   r   r  r  r  c                 S   s"   i | ]\}}|t t |qS r   )rp   r   r   r   r   r   r"   r   )  s   " z6DDPOTrainer._train_batched_samples.<locals>.<dictcomp>r   )	reduction)r   r   r   r   )r   rV   r   r   r   rp   r   r   r`   rc   
accumulater   rx   r  r   backwardr   clip_grad_norm_r   r   r   train_max_grad_normr   r   	zero_gradr   reducer   r   )r   r   r   r   batched_samplesrl   _ir   r  jr  r  r  r   r   r"   r     sN   


"z"DDPOTrainer._train_batched_samplesreturnc                 C   s   | j j| jj | j j }| j j| jj | j j }| j j| j jks/dd| j j d| j j dfS | j j| j j dksHdd| j j d| j j dfS || dksYdd| d| dfS d	S )
NFzSample batch size (z9) must be greater than or equal to the train batch size ()r   z-) must be divisible by the train batch size (zNumber of samples per epoch (z3) must be divisible by the total train batch size ()Tr=   )r   r   rc   r   r   r   ra   )r   samples_per_epochtotal_train_batch_sizer   r   r"   rd   1  s*   zDDPOTrainer._config_checkepochsc                 C   s6   d}|du r
| j j}t| j|D ]}| ||}qdS )z>
        Train the model for a given number of epochs
        r   N)r   
num_epochsr   r   r   )r   rJ  r   r   r   r   r"   r   L  s   zDDPOTrainer.trainc                 C   s   | j | |   d S rE   )r   save_pretrainedcreate_model_card)r   save_directoryr   r   r"   _save_pretrainedV  s   zDDPOTrainer._save_pretrainedc                    sL   | j jd u rt| j jj}n	| j jdd }| j|d t || d S )N/r%   )
model_name)	argshub_model_idr   r.  namer'   rM  super_save_checkpoint)r   modeltrialrQ  	__class__r   r"   rV  [  s
   zDDPOTrainer._save_checkpointrQ  dataset_nametagsc                 C   s   |   sdS t| jjdrtj| jjjs| jjj}nd}|du r&t }nt	|t
r/|h}nt|}t| jjdr?|d || j td}t||| j||t r]tjdur]tjjndt d|ddd	}|tj| jjd
 dS )a  
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        N_name_or_pathunsloth_versionunslotha          @inproceedings{black2024training,
            title        = {{Training Diffusion Models with Reinforcement Learning}},
            author       = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
            year         = 2024,
            booktitle    = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
            publisher    = {OpenReview.net},
            url          = {https://openreview.net/forum?id=YCWjhGrJFD},
        }DDPOz5Training Diffusion Models with Reinforcement Learningz
2305.13301)
base_modelrQ  rS  r[  r\  	wandb_url	comet_urltrainer_nametrainer_citationpaper_titlepaper_idz	README.md)is_world_process_zeror   rW  r   rQ   rR   isdirr]  setr   straddr   
_tag_namestextwrapdedentr   rS  r   wandbrunurlr   saver\   rR  r.  )r   rQ  r[  r\  ra  citation
model_cardr   r   r"   rM  c  s8    



zDDPOTrainer.create_model_cardrE   )F)NNN)#__name__
__module____qualname____doc__rm  r   r   rp   Tensortuplerk  r   r   r   r   r   r&   r   r  r  r  r   r{   r}   r   r   boolrd   r   rO  rV  r   rV   rM  __classcell__r   r   rY  r"   r   *   sZ    
 
kD
>=

r   )&rQ   rn  rI   collectionsr   
concurrentr   pathlibr   typingr   r   r   r   rp   
accelerater	   accelerate.loggingr
   accelerate.utilsr   r   huggingface_hubr   transformersr   r,  r   ddpo_configr   utilsr   r   r   rp  rv  rk   r   r   r   r   r"   <module>   s(   