o
    ϯiX                     @   s   d dl Z d dlZd dlZd dlZd dlZd dlmZ d dlZd dl	Z	d dl
m  mZ zd dlZW n ey<   dZY nw d dlmZmZ ddlmZ G dd deZdd	 Z	dd
dZdddZ				dddZ	dddZdd Zdd ZdS )    N)suppress)ClipLossgather_features   )	is_masterc                   @   s*   e Zd ZdZdd Zdd Zd
ddZd	S )AverageMeterz1Computes and stores the average and current valuec                 C   s   |    d S N)resetself r   M/home/ubuntu/.local/lib/python3.10/site-packages/laion_clap/training/train.py__init__   s   zAverageMeter.__init__c                 C   s   d| _ d| _d| _d| _d S )Nr   )valavgsumcountr
   r   r   r   r	      s   
zAverageMeter.resetr   c                 C   s8   || _ |  j|| 7  _|  j|7  _| j| j | _d S r   )r   r   r   r   )r   r   nr   r   r   update!   s   zAverageMeter.updateN)r   )__name__
__module____qualname____doc__r   r	   r   r   r   r   r   r      s
    r   c                 C   s   t | dr| jS | S )Nmodule)hasattrr   )modelr   r   r   unwrap_model(   s   
r   c           +      C   s   t |j}|jdkrt jjjnt}	|   t|j	|j
d|j|j|j|j|jd}
|d j|d j}}|jr@|d ur@|| |j}tt|jd d}|jdkrY|j  t }t }t }t }t|D ]\}}|| | }t|t r|! D ]}|| q~n|| |}|d }|"t |  t|t r|! D ]}|#  qn|#  |	 ) | |||\}}}}}}|jr|
||||||d	} n|
|||d
} W d    n1 sw   Y  t|t r>|d ur-|$| %  |! D ]/}|jr"|&  |'| |(  |)| W d    n	1 sw   Y  q|)| q|"  nW| %  |! D ]}|)  q5nF|d ur||$| %  |jrr|&  |'| |(  |)| W d    n	1 slw   Y  n|)| |"  n| %  |)  t * $ t+| j,-dtd |jrt+| j.-dtd W d    n	1 sw   Y  |"t |  t }|d }!t/|r}|d dks|!|kr}t|t rt0|d }"nt0|}"|!|" |j }#|j}$d|! | }%|"| 1 |" |1 }&|1 }'t|t r|jrbt23d| d|#d| d|$ d|%dd|j4dd|j5dd|j5dd|j5dddd |! D  d|&dd|'d |j4|j4|j4|&|'d d |! D d!}(nt23d| d|#d| d|$ d|%dd|j4dd|j5dd|j5dd|j5ddd"d |! D  d|&d |j4|j4|j4|&d#d |! D d$}(n|jrt23d| d|#d| d|$ d|%dd|j4dd|j5dd|j5dd|j5dd|j6d d% d&d|&dd|'d |j4|j4|j4|&|'|j6d d% d!}(nGt23d| d|#d| d|$ d|%dd|j4dd|j5dd|j5dd|j5dd|j6d d% d&d|&d |j4|j4|j4|&|j6d d% d$}(|(7 D ]+\})}*d'|) })|d ur]|8|)|*| |j9rst9d usjJ d(t9|)|*d)|i qI|:  |:  qjd S )*NampT)
local_lossgather_with_gradcache_labelsrank
world_sizeuse_horovodmlp_lossweight_loss_kappatrainr   
   toytext)audio_featurestext_featureslogit_scale_alogit_scale_taudio_features_mlptext_features_mlp)r*   r+   r,   r   d   waveformg      Y@zTrain Epoch:  [>/z (z.0fz
%)] Loss: z#.5gz#.4gz) Data (t): z.3fz Batch (t): z LR: c                 S      g | ]	}|j d  d qS r   lrparam_groups.0o_r   r   r   
<listcomp>       z#train_one_epoch.<locals>.<listcomp>z Logit Scale Audio: zLogit Scale Text: c                 S   r5   r6   r8   r:   r   r   r   r=      r>   )loss	data_time
batch_timescale_audio
scale_textr7   c                 S   r5   r6   r8   r:   r   r   r   r=      r>   c                 S   r5   r6   r8   r:   r   r   r   r=      r>   )r?   r@   rA   rB   r7   r7   5fztrain/Please install wandb.step);torchdevice	precisioncudar   autocastr   r&   r   r   r   r!   r"   horovodclap_mlplosskappa
dataloadersamplerdistributed	set_epochnum_batchesmathceillognum_samplesdataset_typedatasetgenerate_queuer   time	enumerate
isinstancedictvaluesr   	zero_gradscalebackwardsynchronizeunscale_skip_synchronizerF   no_gradr   r,   clamp_r-   r   lenitemlogginginfor   r   r9   items
add_scalarwandbr	   )+r   dataepoch	optimizerscaler	schedulerargs	tb_writerrH   rK   r?   rO   rP   num_batches_per_epochsample_digitsloss_mbatch_time_mdata_time_mendibatchrF   saudiostextsr<   r*   r+   r.   r/   r,   r-   
total_lossbatch_count
batch_sizerW   samples_per_epochpercent_completelogit_scale_scalar_alogit_scale_scalar_tlog_datanamer   r   r   r   train_one_epoch/   s  







	









"$
	$		$	$
 r   c                    s2  i }|j st|s|S t|j}|   t|rtd |jdkr'tjjj	nt
}|jddgkrb|j r7tdt| ||||||}| D ]}	||	 qEd| vrZ|d|i t|||}n+d|v r|jr||j dksx||jkr|d j}
d}|
j}i }|jrd	dg g g g d
|d< n	d	dg g d|d< t  t|
D ]Z\}}|}|d }ttdd |d D }|D ] }|| vr|jrd	dg g g g d
||< qd	dg g d||< q|  | |||\}}}}}}|j r"|jrt||||dd|j|j|j|jd
\}}}}nt||dd|j|j|j|jd\}}t|r||jd 7 }g |dD ]  dkrm|  d  |!  |  d  |!  |jrk|  d  |!  |  d  |!  q4t"#t"$dd |d D  kd }|  d  |! %dt&|'  |  d  |! %dt&|'  |jr|  d  |! %dt&|'  |  d  |! %dt&|'  q4W d    n	1 sw   Y  t|r|d dkrt()d| d| d| d qt|r}i }| D ]o |jr@t*t+|  d t+|  d |! t+|  d t+|  d |! |jd}nt*t+|  d t+|  d |! |jd} fd d!|, D | < ||   d| vr{|d|i qW d    n	1 sw   Y  t|r|s|S t()d| d"d#-d$d | D   |j.r|, D ]\}}|d ur|/d%| || qt0t1j2-|j3d&d'}|4t56| |4d# W d    n	1 sw   Y  |j7rt7d usJ d(|, D ]\}}t78d%| |d|i q|S |S ))NzEvaluating...r   Clotho	audiocapszEParallel evaluation not supported for eval only Clotho and audiocaps.rp   r   r           )cumulative_lossrW   all_audio_featuresall_text_featuresall_audio_features_mlpall_text_features_mlpallr   rW   r   r   r)   c                 S   $   g | ]}d  |ddd qS -r4   joinsplitr;   br   r   r   r=   F     $ zevaluate.<locals>.<listcomp>__url__F)
r*   r+   r.   r/   r   r   r!   r"   r#   r$   )r*   r+   r   r   r!   r"   r#   r$   r   r   r   r   c                 S   r   r   r   r   r   r   r   r=     r   r0   zEval Epoch: r2   z / ])r*   r+   r,   r.   r/   r-   r$   )r*   r+   r,   r$   c                       i | ]\}} d  | |qS r4   r   r;   kvr   r   r   
<dictcomp>      zevaluate.<locals>.<dictcomp> 
c                 S   s$   g | ]}d  dd | D qS )	c                 S   s&   g | ]\}}| d t |ddqS )z:    z.4f)roundr   r   r   r   r=        & z'evaluate.<locals>.<listcomp>.<listcomp>)r   rl   )r;   mr   r   r   r=     s    zval/zresults.jsonlza+rE   )9parallel_evalr   rG   rH   evalprintrI   rJ   r   rK   r   val_dataset_namesNotImplementedErrorevaluate_clotho_audiocapsr_   r   keys"select_top_metric_clotho_audiocapsval_frequencyepochsrO   rW   rM   rf   r\   listsetr   r!   r"   rL   shapeappendcpunpwherearrayindex_selecttensorlongrj   rk   get_metricscatrl   r   	save_logsrm   openospathcheckpoint_pathwritejsondumpsrn   rV   )r   ro   rp   rt   ru   metricsrH   rK   val_metrics_per_datasetr   rO   rW   samples_per_val	eval_infor|   r}   r   r   	all_namesr   r*   r+   r.   r/   r,   r-   idxmetrics_single_datasetr   fr   r   r   evaluate
  s  








^



  



r   Fc              	   C   sH  i }|rv||  |     }|    }	|| |     }
|
    }t| jd  }t||t|	| t|
| t|| d }|	 |d< | jd |d< ||
 d |	| d d}tt
|dd}nI||  |     }|    }t| jd  }t||t|| d }|	 |d< | jd |d< ||d}tt
|dd}| D ]^\}}tj|d	d
}t||kd }|   }| d || d< tt|d || d< dD ]}t||k || d| < qtt|dk d|d  d|| d< q|S )Nr   r   r   rW      )audio_to_texttext_to_audior   r   T
descending
_mean_rank_median_rankr      r'   z_R@r'   r   z_mAP@10)tdetachr   rG   aranger   r   Fcross_entropyri   rh   viewrl   argsortr   numpymeanr   floormedian)r*   r+   r,   r.   r/   r-   r$   r   a_logits_per_audioa_logits_per_textt_logits_per_audiot_logits_per_textlabelsr   logitsground_truthlogits_per_audiologits_per_textr   logitrankingpredsr   r   r   r   r     s\   	








,r   c           '         sx  |d j }t  i }t|D ]\}	}
|
} jdkr3ddlm fdd|
d D tndd	lm	  fd
d|
d D fddd 
 D ttdd |
d D }|D ]}||
 vrrddg g d||< qa|  | |d|}| d|}tj|dd}tj|dd}ttdd |
d D }|D ]Nttdd |
d D kd }| d | dt|  | d | dd|jd gdt| d|jd g qW d   n1 sw   Y  qi }|
 D ]| dd|\}}| }tj| d dd}tj| d dd}|| |       td dj dj  i }|jd |d< t|jd  fddtdD }fddtdD }t|t| d }|  |d < g }tdD ]9}ddd|ddf }tt!|"dd}tj#|d!d"}t||kd }||  $  qtj%|dd#}| d |d$< t&t'|d |d%< d&D ]}t||k |d'| < qtt|d(k d|d  d|d)< g } g }!tD ]g}|ddf }"tj#|"d!d"}t|d |d d d }tt(|gd |"ddkd }#t)|#}$|!|$  $  |#|#d(k    $ }%t*tdt!|%d |%d  d }&| |& qt| |d*< d&D ]}tt|!|k |d+| < qfd,d|+ D |< qW d   |S 1 sw   Y  |S )-a  
    Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
    1. for text-to-audio retrieval, do 5 times and average the results
    2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text
    3. for map@10 in audio-to-text retrieval:
        3.1: sort the rank of 5 text
        3.2: exclude the rank >=10 (0-index)
        3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks).
        (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth.
        (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc.
    r   transformerr   tokenizec                    s   g | ]} |qS r   r   r;   r   r   r   r   r=   Y      z-evaluate_clotho_audiocaps.<locals>.<listcomp>	full_textr   )	tokenizerc                    s   g | ]	}| j d qS )tmodelr   r   )rt   r   r   r   r=   ]  r>   c                    s&   i | ]  t  fd dD qS )c                    s   g | ]}|  qS r   r   r   r   r   r   r=   ^  r   z8evaluate_clotho_audiocaps.<locals>.<dictcomp>.<listcomp>)rG   r   )r;   )r   r  r   r   ^  r   z-evaluate_clotho_audiocaps.<locals>.<dictcomp>c                 S   r   r   r   r   r   r   r   r=   d  r   r   r   r   Nr   )dimc                 S   r   r   r   r   r   r   r   r=   t  r   c                 S   r   r   r   r   r   r   r   r=   x  r   r   r   r   zdataset z, logits_per_audio shape: z, logits_per_text shape: rW   c              	      s4   g | ]}t d dddd|f  qS r   Nr   r   reshaper;   d)r   r   rW   r   r   r=          c              	      s4   g | ]}t d dd|ddf  qS r  r  r  )r   r   rW   r   r   r=     r	  r   r   Tr   )axistext_to_audio_mean_ranktext_to_audio_median_rankr   ztext_to_audio_R@r'   ztext_to_audio_mAP@10zaudio_to_text_mAP@10zaudio_to_text_R@c                    r   r   r   r   r   r   r   r     r   ),rO   rG   rf   r\   r  clap_moduler   r   ro   r   r   r   r   r   	normalizer   r   r   r   r   r   r   r   r  r   r   r   rj   rk   r   ranger   ri   rh   r   r   r   concatenater   r   stackminr   rl   )'r   ro   rp   rt   rK   rH   ru   rO   r   r|   r}   r   r   r   r*   r+   r   val_metrics_allr,   r-   r   audio_to_text_losstext_to_audio_lossr   	pred_textr  r   r   r   r   pred_text_concatr   map_allpred_audio_alllogit_singleall_predmin_predall_pred_filter
map_singler   )	rt   r   r   r   r   rW   r   r   r   r   r   A  s   


	
 $&
&"

   r   c                 C   sN   g }|   D ]}| | | d | | | d  d }|| qt|S )zI
    Calculate performance for Clotho+AudioCaps for model selection.
    z/audio_to_text_mAP@10z/text_to_audio_mAP@10r   )r   r   r   r   )r   selection_performance_allr   selection_performancer   r   r   0calculate_selection_performance_clotho_audiocaps  s   
r!  c           	      C   s<  t |dsJt|}i }| D ]#}||  D ]}|| | ||dd d d |dd  < qq||d< | d |d< | | ||_||_| S t|}|j}||kri }| D ]#}||  D ]}|| | ||dd d d |dd  < qcq[||d< | d |d< | | ||_||_| S | |j | S )Ntop_selection_performancer4   r   z-topr   rp   ztop-selection-epoch)r   r!  r   r   r   
top_metricr"  )	r   r   rt   r   metric_updater   r   selection_performance_newselection_performance_oldr   r   r   r     s:   
2
2
r   r   )NNNF)r   rj   rT   r   r[   
contextlibr   r   r   rG   torch.nn.functionalnn
functionalr   rn   ImportErrorr  r   r   rQ   r   objectr   r   r   r   r   r   r!  r   r   r   r   r   <module>   sB    
 
\ s
J
 "