o
    }oi                     @   s   d dl Z d dlZd dlZd dlZd dlmZmZmZm	Z	m
Z
mZmZmZmZ zd dlZed dZW n ey>   dZY nw d dlmZ d dlmZmZmZ G dd	 d	ZdS )
    N)	MAELossMSELossSDRLosscalculate_mae_batchcalculate_meancalculate_mse_batchcalculate_sdr_batchconvolution_invariant_targetscale_invariant_target
torchaudioTF)CombinedLoss)calculate_sdr_numpy"convolution_invariant_target_numpyscale_invariant_target_numpyc                   @   s  e Zd Zejjejdddgejdddgejdddgdededefdd	Z	ejjd
d Z
ejjdd Zejjdd Zejjdd Zejjdd ZejjejdddgdefddZejjejdddgdefddZejjejdddgdd ZejjejdddgdefddZejjejdddgdd Zejjejddgejddd gdedefd!d"Zejjejd#dd$gejdddgejdddgejdddgded#ededefd%d&Zejjejd#dd$gejdddgded#efd'd(Zejjd)d* Zejjd+d, Zejjd-d. Zejjd/d0 Zejjejdddgejd1d2dgded1efd3d4Zejjejdddgejd1d2dgded1efd5d6Zejjejdddgejd1d2dgded1efd7d8Zejjd9d: Zejjd;d< Z ejjd=d> Z!ejjejdddgejd1d2dgded1efd?d@Z"ejjejdddgejd1d2dgded1efdAdBZ#ejjejdddgejd1d2dgded1efdCdDZ$ejjdEdF Z%ejjdGdH Z&ejjdIdJ Z'ejjejj(e) dKdLdMdN Z*dOS )PTestAudioLossesnum_channels      use_maskTFuse_input_lengthc                 C   s  d}d}d}d}d}d}	t  }
|
| t|D ]}t j|||f|
d}t jd||f|
d	}t j|d|fd
}t|D ]}d||ddd|| f< q<|rn|rntt	 t
||||d W d   n1 shw   Y  q|rxt
|||d}n|rt
|||d}nt
||d}t|D ]3}t|D ],}t |||d|s|r|| n|f }t j|||f ||	dsJ d| d| qqqdS )zTest mean calculation   2   
   *   g|=ư>)size	generatorr   )lowhighr   r   r         ?N)maskinput_lengtheps)r!   r#   )r"   r#   )r#   atolzMean not matching for example 
, channel )torch	Generatormanual_seedrangerandnrandintzerospytestraisesRuntimeErrorr   meanallclose)selfr   r   r   
batch_sizenum_samplesnum_batchesrandom_seedr#   r%   rngninput_signalr"   r!   iuutbcgolden r@   ]/home/ubuntu/.local/lib/python3.10/site-packages/tests/collections/audio/test_audio_losses.pytest_calculate_mean3   sJ   
 z#TestAudioLosses.test_calculate_meanc                 C   sX   t jdd}t jdd}tt t||ddd W d   dS 1 s%w   Y  dS )CTest SDR calculation with scale and conovolution invariant options.r   r   d   r   T)estimatetargetscale_invariantconvolution_invariantN)r'   r+   r.   r/   
ValueErrorr   r3   rF   rG   r@   r@   rA   2test_calculate_sdr_scale_and_convolution_invariantd   s
   "zBTestAudioLosses.test_calculate_sdr_scale_and_convolution_invariantc                 C   p   t jdd}t jdd}t dg}t jdd}tt t||||d W d   dS 1 s1w   Y  dS )z=Test MSE calculation with simultaneous input length and mask.rD   r   rE   rF   rG   r"   r!   N)r'   r+   tensoronesr.   r/   r0   r   r3   rF   rG   r"   r!   r@   r@   rA   !test_calculate_mse_input_and_maskn      "z1TestAudioLosses.test_calculate_mse_input_and_maskc                 C      t jdd}t jdd}tt t||d W d   n1 s"w   Y  t jdd}t j|jd}tt t||d W d   dS 1 sKw   Y  dS )z1Test MSE calculation with unsupported dimensions.r   r   rE   r   r   rD   rF   rG   Nr   r   rE   r      )r'   r+   r.   r/   AssertionErrorr   shaper0   rK   r@   r@   rA   %test_calculate_mse_invalid_dimensionsz      "z5TestAudioLosses.test_calculate_mse_invalid_dimensionsc                 C   rM   )z=Test MAE calculation with simultaneous input length and mask.rD   r   rE   rN   N)r'   r+   rO   rP   r.   r/   r0   r   rQ   r@   r@   rA   !test_calculate_mae_input_and_mask   rS   z1TestAudioLosses.test_calculate_mae_input_and_maskc                 C   rT   )z1Test MAE calculation with unsupported dimensions.rU   r   rD   rV   NrW   )r'   r+   r.   r/   rY   r   rZ   r0   rK   r@   r@   rA   %test_calculate_mae_invalid_dimensions   r\   z5TestAudioLosses.test_calculate_mae_invalid_dimensionsc                 C   s  g d}d}d}d}d}d}t jj|d}dD ]}	|D ]}
t|
|	d	}t|D ]}|j|||fd
}|jddd|j|||fd
 }|| }||jddd7 }||jddd7 }t|}t|}t 	||f}t|D ]#}t|D ]}t
|||ddf |||ddf |	|
d|||f< qsqmt|||	|
d}|||d}t j|   ||dsJ d| d|
 d|	 t j||  |dsJ d| d|
 d|	 q'qqdS )zTest SDR calculation)r   gؗҜ<皙?r   r   r   r   r   seed)TF)r#   remove_meanr   {Gz?r   r   r   N)rF   rG   rb   r#   rV   r$   zSDR not matching for example z, eps=z, remove_mean=!SDRLoss not matching for example )nprandomdefault_rngr   r*   normaluniformr'   rO   r-   r   r   r2   cpudetachnumpyiscloser1   )r3   r   test_epsr4   r5   r6   r7   r%   _rngrb   r#   sdr_lossr9   rG   noiserF   tensor_estimatetensor_target
golden_sdrr=   muut_sdruut_sdr_lossr@   r@   rA   test_sdr   s`    

	zTestAudioLosses.test_sdrc                    s.  d}d}d}d}d}t jj|d}|jdd|d	}|t | }t|d
}	t|D ]k}
|j|||fd|jddd|jjd }| t	
}t	
}d}t|D ]  fddt|D }t t || }||7 }qV|| }|	||d}t j|   | |dsJ d|
 q)dS )z0Test SDR calculation with weighting for channelsr   r   r   r   r   r`   rc   r    r   r   r   weightr   MbP?rd   r   c              	      s4   g | ]}t  |d d f  |d d f dqS NrV   r   .0rw   r=   rF   rG   r@   rA   
<listcomp>      &z5TestAudioLosses.test_sdr_weighted.<locals>.<listcomp>rV   r$   rf   N)rg   rh   ri   rk   sumr   r*   rj   rZ   r'   rO   arrayr2   rl   rm   rn   )r3   r   r4   r5   r6   r7   r%   rq   channel_weightrr   r9   rs   rt   ru   rv   sdrry   r@   r   rA   test_sdr_weighted   s<   



z!TestAudioLosses.test_sdr_weightedc                    s*  d}d}d}d}d}t jj|d}t }t|D ]z}	|j|||fd|jddd	|jjd }
|
 |jd
||d}t	
}t	
}t	
|}d}t|D ]\  fddt|D }t t |}||7 }qR|| }||||d}t j|   | |dsJ d|	 qdS )z'Test SDR calculation with input length.r   r   r   r   r   r`   r   r~   rd   r   r{   r   c              	      s4   g | ]}t  |d f  |d f dqS r   r   r   r=   b_lenrF   rG   r@   rA   r   >  r   z9TestAudioLosses.test_sdr_input_length.<locals>.<listcomp>rF   rG   r"   r$   rf   Nrg   rh   ri   r   r*   rj   rk   rZ   integersr'   rO   	enumerater1   r   r2   rl   rm   rn   r3   r   r4   max_num_samplesr6   r7   r%   rq   rr   r9   rs   r"   rt   ru   tensor_input_lengthrv   r   ry   r@   r   rA   test_sdr_input_length  s<   



z%TestAudioLosses.test_sdr_input_lengthc                    s.  d}d}d}d}d}t jj|d}tdd}t|D ]z}	|j|||fd	|jd
dd|jjd	 }
|
 |jd||d}t	
}t	
}t	
|}d}t|D ]\  fddt|D }t t |}||7 }qT|| }||||d}t j|   | |dsJ d|	 qdS )z1Test SDR calculation with scale invariant option.r   r   r   r   r   r`   T)rH   r   r~   rd   r   r{   r   c              	      s6   g | ]}t  |d f  |d f ddqS )NT)rF   rG   rH   r   r   r   r@   rA   r   p  s    "z<TestAudioLosses.test_sdr_scale_invariant.<locals>.<listcomp>r   r$   rf   Nr   r   r@   r   rA   test_sdr_scale_invariantN  s<   




z(TestAudioLosses.test_sdr_scale_invariantc                    s,  d}d}d}d}d}t jj|d}t }t|D ]{}	|j|||fd|jddd	|jjd }
|
 |jd
d|||fdt	
}t	
}t	
}d
}t|D ]  fddt|D }t t |}||7 }qU|| }||||d}t j|   | |dsJ d|	 qdS )z(Test SDR calculation with temporal mask.r   r   r   r   r   r`   r   r~   rd   r      r{   c                    sP   g | ]$}t  | |d d f dkf  | |d d f dkf dqS )Nr   rV   r   r   r=   rF   r!   rG   r@   rA   r     s    <z8TestAudioLosses.test_sdr_binary_mask.<locals>.<listcomp>)rF   rG   r!   r$   rf   N)rg   rh   ri   r   r*   rj   rk   rZ   r   r'   rO   r1   r   r2   rl   rm   rn   )r3   r   r4   r   r6   r7   r%   rq   rr   r9   rs   rt   ru   tensor_maskrv   r   ry   r@   r   rA   test_sdr_binary_mask  s<   



z$TestAudioLosses.test_sdr_binary_masksdr_maxr   r   c                    s0  d}d}d}d}d}t jj|d}td}	t|D ]{}
|j|||fd|jd	dd
|jjd }| |jd||d}t	
}t	
}t	
|}d}t|D ]\  fddt|D }t t |}||7 }qT|| }|	|||d}t j|   | |dsJ d|
 qdS )z-Test SDR calculation with soft max threshold.r   r   r   r   r   r`   )r   r   r~   rd   r   r{   r   c              	      s6   g | ]}t  |d f  |d f dqS )N)rF   rG   r   r   r   r=   r   rF   r   rG   r@   rA   r     s    (z0TestAudioLosses.test_sdr_max.<locals>.<listcomp>r   r$   rf   Nr   )r3   r   r   r4   r   r6   r7   r%   rq   rr   r9   rs   r"   rt   ru   r   rv   r   ry   r@   r   rA   test_sdr_max  s<   




zTestAudioLosses.test_sdr_maxfilter_length    c              
   C   s  d}d}d}d}d}	t jj|d}
t|D ]V}|
j|||fd}|
jddd	|
j|jd }|| }|
j|||d
}tj	|d|fd}t|D ]}d||ddd|| f< qH|r|rt
t tt|t|t||d W d   n1 s}w   Y  t
t tt|t|t|||d W d   n1 sw   Y  qtt|t||rt|nd|r|ndd}tt|t||rt|nd|r|nd|d}|dkrt||sJ dt|D ]|}|s|r|| n|}t|D ]k}t|||d|f |||d|f d}t j|||d|f    ||	ds5J d| d| t|||d|f |||d|f |d}t j|||d|f    ||	dsjJ d| d| qqqdS )z>Test target calculation with scale and convolution invariance.r   r   r   r   r   r`   r   r~   rd   r{   r   r    NrN   )rF   rG   r"   r!   r   z*SI and CI should match for filter_length=1rV   r$   zSI not matching for example r&   )rF   rG   r   zCI not matching for example )rg   rh   ri   r*   rj   rk   rZ   r   r'   r-   r.   r/   r0   r
   rO   r	   r2   r   rl   rm   rn   r   )r3   r   r   r   r   r4   r   r6   r7   r%   rq   r9   rG   rs   rF   r"   r!   r;   	si_target	ci_targetr=   r   rw   si_target_refci_target_refr@   r@   rA   test_target_calculation  s   


  
" 
z'TestAudioLosses.test_target_calculationc                    s&  d}d}d}d}d}t jj|d}tdd}	t|D ]u}
|j|||fd	|jd
dd|jjd	 }| |j||d}|	t	
t	
t	
|d}d}t|D ]\  fddt|D }t t |}||7 }qV|| }t j|   | |dsJ d|
 qdS )z7Test SDR calculation with convolution invariant option.r   r   r   r   r   r`   T)rI   convolution_filter_lengthr   r~   rd   r{   r   r   c              	      s8   g | ]}t  |d f  |d f ddqS )NT)rF   rG   rI   r   r   r   r=   r   rF   r   rG   r@   rA   r   l  s    zBTestAudioLosses.test_sdr_convolution_invariant.<locals>.<listcomp>r$   rf   Nr   )r3   r   r   r4   r   r6   r7   r%   rq   rr   r9   rs   r"   ry   rv   r   r@   r   rA   test_sdr_convolution_invariantI  s:   	
z.TestAudioLosses.test_sdr_convolution_invariantc                 C   s<   t t tddd W d   dS 1 sw   Y  dS )rC   TrH   rI   Nr.   r/   rJ   r   r3   r@   r@   rA   (test_sdr_scale_and_convolution_invariant~  s   "z8TestAudioLosses.test_sdr_scale_and_convolution_invariantc                 C   s|   t jdd}t jdd}t dg}t jdd}tddd}tt |||||d W d   dS 1 s7w   Y  dS )z=Test SDR calculation with simultaneous input length and mask.rD   r   rE   Fr   rN   N)r'   r+   rO   rP   r   r.   r/   r0   )r3   rF   rG   r"   r!   rr   r@   r@   rA   test_sdr_length_and_mask  s   "z(TestAudioLosses.test_sdr_length_and_maskc                 C   v   t t tddgd W d   n1 sw   Y  t t tddgd W d   dS 1 s4w   Y  dS )zTest SDR with invalid weights.re   r   r|   Nr_   r   r   r@   r@   rA   test_sdr_invalid_weight     "z'TestAudioLosses.test_sdr_invalid_weightc                 C   :   t t tdd W d   dS 1 sw   Y  dS )z Test SDR with invalid reduction.not-mean	reductionNr   r   r@   r@   rA   test_sdr_invalid_reduction  s   "z*TestAudioLosses.test_sdr_invalid_reductionndim   c              	   C   s  d}d}d}d}d}d}|dkr||||fn|||f}	|dkr!dnd	}
t |d
}tjj|d}t|D ]}|j|	d}|jddd|j|	d }|| }||jd	dd7 }||jd	dd7 }t|}t|}t	||f}t|D ]+}t|D ]$}|||ddf |||ddf  }tj
t|d |
d|||f< qyqst||d}|||d}tj|   ||dsJ d| tj||
 |dsJ d| q3dS )zTest MSE calculationr   r   {   r   r   r   r   re   re   r   r`   r   rc   r   rd   Nr   axisrV   r$   zMSE not matching for example !MSELoss not matching for example )r   rg   rh   ri   r*   rj   rk   r'   rO   r-   r1   absr   r2   rl   rm   rn   ro   )r3   r   r   r4   r5   num_featuresr6   r7   r%   signal_shapereduction_dimmse_lossrq   r9   rG   rs   rF   rt   ru   
golden_mser=   rw   erruut_mseuut_mse_lossr@   r@   rA   test_mse  sH   


$"$zTestAudioLosses.test_msec                    \  d}d}d}d}d}d}|dkr||||fn|||f}	|dkr!dnd	t jj|d
}
|
jdd|d}|t | }t||d}t|D ]h}|
j|	d|
jddd|
jjd }| t	
}t	
}d}t|D ]  fddt|D }t t || }||7 }qm|| }|||d}t j|   ||dsJ d| qCdS )z0Test MSE calculation with weighting for channelsr   r   r   r   r   r   r   r   re   r`   rc   r    r{   r}   r   r   r~   rd   r   c                    sD   g | ]}t jt  |d d f  |d d f  d dqS )Nr   r   rg   r1   r   r   r=   rF   r   rG   r@   rA   r         6z5TestAudioLosses.test_mse_weighted.<locals>.<listcomp>rV   r$   r   N)rg   rh   ri   rk   r   r   r*   rj   rZ   r'   rO   r   r2   rl   rm   rn   )r3   r   r   r4   r5   r   r6   r7   r%   r   rq   r   r   r9   rs   rt   ru   r   mser   r@   r   rA   test_mse_weighted  H   


z!TestAudioLosses.test_mse_weightedc                    Z  d}d}d}d}d}d}|dkr||||fn|||f}	|dkr!dnd	t jj|d
}
t|d}t|D ]w}|
j|	d|
jddd|
jjd }| |
jd||d}t	
}t	
}t	
|}d}t|D ]\  fddt|D }t t |}||7 }qj|| }||||d}t j|   ||dsJ d| q3dS )z'Test MSE calculation with input length.r   r   r   r   r   r   r   r   re   r`   r   r   r~   rd   r   r{   r   c                    sH   g | ] }t jt  |d df  |d df  d dqS ).Nr   r   r   r   r=   r   rF   r   rG   r@   rA   r   J  s    :z9TestAudioLosses.test_mse_input_length.<locals>.<listcomp>r   r$   r   N)rg   rh   ri   r   r*   rj   rk   rZ   r   r'   rO   r   r1   r   r2   rl   rm   rn   )r3   r   r   r4   r   r   r6   r7   r%   r   rq   r   r9   rs   r"   rt   ru   r   r   r   r   r@   r   rA   test_mse_input_length  H   




z%TestAudioLosses.test_mse_input_lengthc                 C   r   )z"Test MSE with unsupported weights.re   r   r|   Nr_   r.   r/   rJ   r   r   r@   r@   rA   test_mse_invalid_weightZ  r   z'TestAudioLosses.test_mse_invalid_weightc                 C   r   )z$Test MSE with unsupported reduction.r   r   Nr   r   r@   r@   rA   test_mse_invalid_reductione     "z*TestAudioLosses.test_mse_invalid_reductionc                 C   n   t t tdd W d   n1 sw   Y  t t tdd W d   dS 1 s0w   Y  dS )z%Test MSE with unsupported dimensions.r   r   N   r   r   r@   r@   rA   test_mse_invalid_ndiml     "z%TestAudioLosses.test_mse_invalid_ndimc              	   C   sz  d}d}d}d}d}d}|dkr||||fn|||f}	|dkr!dnd	}
t |d
}tjj|d}t|D ]}|j|	d}|jddd|j|	d }|| }||jd	dd7 }||jd	dd7 }t|}t|}t	||f}t|D ])}t|D ]"}|||ddf |||ddf  }tj
t||
d|||f< qyqs|||d}tj|   |
 |dsJ d| q3dS )zTest MAE calculationr   r   r   r   r   r   r   r   re   r   r`   r   rc   r   rd   Nr   rV   r$   zMAE not matching for example )r   rg   rh   ri   r*   rj   rk   r'   rO   r-   r1   r   r2   rl   rm   rn   )r3   r   r   r4   r5   r   r6   r7   r%   r   r   mae_lossrq   r9   rG   rs   rF   rt   ru   
golden_maer=   rw   r   uut_mae_lossr@   r@   rA   test_maew  sD   


$zTestAudioLosses.test_maec                    r   )z0Test MAE calculation with weighting for channelsr   r   r   r   r   r   r   r   re   r`   rc   r    r{   r   r   r~   rd   r   c                    s@   g | ]}t jt  |d d f  |d d f  dqS )Nr   r   r   r   r@   rA   r     s    2z5TestAudioLosses.test_mae_weighted.<locals>.<listcomp>rV   r$   !MAELoss not matching for example N)rg   rh   ri   rk   r   r   r*   rj   rZ   r'   rO   r   r2   rl   rm   rn   )r3   r   r   r4   r5   r   r6   r7   r%   r   rq   r   r   r9   rs   rt   ru   r   maer   r@   r   rA   test_mae_weighted  r   z!TestAudioLosses.test_mae_weightedc                    r   )z'Test MAE calculation with input length.r   r   r   r   r   r   r   r   re   r`   r   r   r~   rd   r   r{   r   c                    sD   g | ]}t jt  |d df  |d df  dqS ).Nr   r   r   r   r@   rA   r     r   z9TestAudioLosses.test_mae_input_length.<locals>.<listcomp>r   r$   r   N)rg   rh   ri   r   r*   rj   rk   rZ   r   r'   rO   r   r1   r   r2   rl   rm   rn   )r3   r   r   r4   r   r   r6   r7   r%   r   rq   r   r9   rs   r"   rt   ru   r   r   r   r   r@   r   rA   test_mae_input_length  r   z%TestAudioLosses.test_mae_input_lengthc                 C   r   )zTest MAE with invalid weights.re   r   r|   Nr_   r.   r/   rJ   r   r   r@   r@   rA   test_mae_invalid_weight&  r   z'TestAudioLosses.test_mae_invalid_weightc                 C   r   )z Test MAE with invalid reduction.r   r   Nr   r   r@   r@   rA   test_mae_invalid_reduction1  r   z*TestAudioLosses.test_mae_invalid_reductionc                 C   r   )z!Test MAE with invalid dimensions.r   r   Nr   r   r   r@   r@   rA   test_mae_invalid_ndim8  r   z%TestAudioLosses.test_mae_invalid_ndimz'Modules in this test require torchaudio)reasonc                 C   s   t j|ddd}d}g d}d}|D ]j}|\}}|\}	}
}}}ttj r)dnd}td	d
dd|	|
|||d	|}t	t
|t
j}|||ddf}||}t	t
|jt
j|ddf|}|j||d }t
j|||ds}J qd S )Naudiomaxinez	input.binrc   )))r   r   r   TTg      T@))r   r   r   TTg)0?))r   r   r   TTgH}83@))r   r   r   TTg	h"lLY@   zcuda:0rl   i>  i  i  i@  )	sample_rate
fft_length
hop_lengthnum_melssisnr_loss_weightasr_loss_weightspectral_loss_weightuse_asr_lossuse_mel_specr   re   rV   r$   )ospathjoinr'   devicecudais_availabler   torO   rg   fromfilefloat32repeatreshaper-   rZ   forwardrl   r2   )r3   test_data_dirINPUT_LOCATIONATOLGOLDEN_VALUESr4   valueconfigr?   sisnr_wtasr_wtspec_wtuse_asruse_specr   loss_instance
input_datarF   lossr@   r@   rA   test_maxine_combined_lossC  s8   

(z)TestAudioLosses.test_maxine_combined_lossN)+__name__
__module____qualname__r.   markunitparametrizeintboolrB   rL   rR   r[   r]   r^   rz   r   r   r   r   floatr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   skipifHAVE_TORCHAUDIOr  r@   r@   r@   rA   r   2   s    -
	



A/0220$[2




;89




589




r   )r   rn   rg   r.   r'   #nemo.collections.audio.losses.audior   r   r   r   r   r   r   r	   r
   	importlibimport_moduler  ModuleNotFoundError$nemo.collections.audio.losses.maxiner   (nemo.collections.audio.parts.utils.audior   r   r   r   r@   r@   r@   rA   <module>   s   ,
