o
    wi E                     @   s  d dl Z d dlZd dlZd dlmZ d dlmZmZmZ d dl	Z	d dl
Z
d dlmZ d dlmZmZ d dlmZ d dlmZmZ d dlmZ d d	lmZ G d
d de
jjZG dd deZG dd deZe	j e	  dd Z!G dd deeZ"e	  dd Z#G dd dZ$dS )    N)	dataclass)AnyDictList)	open_dict)
DataLoaderDataset_speech_collate_fn)TranscribeConfigTranscriptionMixin)GenericTranscriptionType)
Hypothesisc                       s$   e Zd Z fddZdd Z  ZS )
DummyModelc                    s0   t    tjdd| _d| _d| _d| _d S )N   r   F)	super__init__torchnnLinearencoderexecution_count
flag_beginflag_endself	__class__ l/home/ubuntu/sommelier/.venv/lib/python3.10/site-packages/tests/collections/asr/mixins/test_transcription.pyr   !   s
   

zDummyModel.__init__c                 C   s   |  |}|S N)r   )r   xoutr   r   r   forward)   s   
zDummyModel.forward)__name__
__module____qualname__r   r#   __classcell__r   r   r   r   r       s    r   c                   @   s2   e Zd Zdee defddZdd Zdd Zd	S )
DummyDatasetAudioOnlyaudio_filesconfigc                 C      || _ || _d S r    )r)   r*   )r   r)   r*   r   r   r   r   0      
zDummyDatasetAudioOnly.__init__c                 C   s$   | j | }tt|gd}|S )Nr   )r)   r   tensorfloatview)r   indexdatar   r   r   __getitem__4   s   
z!DummyDatasetAudioOnly.__getitem__c                 C   
   t | jS r    )lenr)   r   r   r   r   __len__9      
zDummyDatasetAudioOnly.__len__N	r$   r%   r&   r   strr   r   r2   r5   r   r   r   r   r(   /   s    r(   c                   @   s4   e Zd Zd
dee defddZdd Zdd	 ZdS )DummyDatasetNaudio_tensorsr*   c                 C   r+   r    )r:   r*   )r   r:   r*   r   r   r   r   >   r,   zDummyDataset.__init__c                 C   sX   | j | }t|}tj|jd tjd}tjdgtjd}tjdtjd}||||fS )Nr   dtyper   )r:   r   r-   shapelong)r   r0   r1   samplesseq_lentext_tokenstext_tokens_lenr   r   r   r2   B   s   

zDummyDataset.__getitem__c                 C   r3   r    )r4   r:   r   r   r   r   r5   N   r6   zDummyDataset.__len__r    r7   r   r   r   r   r9   =   s    r9   c                 C   s`   ddl }tj| ddddd}tj| ddddd}|j|d	d
\}}|j|d	d
\}}||fS )z4
    Returns a list of audio files for testing.
    r   Nasrtrainan4wavan46-mmap-b.wavzan104-mrcb-b.wavfloat32r;   )	soundfileospathjoinread)test_data_dirsfaudio_file1audio_file2audio1_audio2r   r   r   r)   R   s   r)   c                       s   e Zd Zdef fddZdee dedefddZded	e	fd
dZ
dedefddZded	efddZdef fddZ  ZS )TranscribableDummytrcfgc                    s   t  || d| _d S NT)r   _transcribe_on_beginr   )r   audiorV   r   r   r   rX   d   s   
z'TranscribableDummy._transcribe_on_beginr)   temp_dirc           	      C   s   t j|d}t|ddd}|D ]}|ddd}|t|d  qW d    n1 s.w   Y  ||j||j|j	d	}|S )
Nzdummy_manifest.jsonwzutf-8)encodingi  )audio_filepathdurationtext
)paths2audio_files
batch_sizerZ   num_workerschannel_selector)
rJ   rK   rL   openwritejsondumpsrc   rd   re   )	r   r)   rZ   rV   manifest_pathfp
audio_fileentry	ds_configr   r   r   %_transcribe_input_manifest_processingh   s   z8TranscribableDummy._transcribe_input_manifest_processingr*   returnc                 C   s(   t |d |}t||d |d dddS )Nrb   rc   rd   F)datasetrc   rd   
pin_memory	drop_last)r(   r   )r   r*   rq   r   r   r   _setup_transcribe_dataloaderz   s   z/TranscribableDummy._setup_transcribe_dataloaderbatchc                 C   s   | |}|S r    r   )r   ru   rV   outputr   r   r   _transcribe_forward   s   z&TranscribableDummy._transcribe_forwardc                 C   s   |  j d7  _ g }|D ]}|t|  qt|dr'|jdkr'd|i}|S t|dr:|jdkr:dd |D }|S t|drJ|jdkrJt|}|S |S )	Nr   output_typedictrv   dict2c                 S   s   g | ]}d |iqS )rv   r   ).0resr   r   r   
<listcomp>   s    zDTranscribableDummy._transcribe_output_processing.<locals>.<listcomp>tuple)r   appendr.   itemhasattrrx   r~   )r   outputsrV   resultrv   resultsr   r   r   _transcribe_output_processing   s   z0TranscribableDummy._transcribe_output_processingc                    s   t  | d| _d S rW   )r   _transcribe_on_endr   )r   rV   r   r   r   r      s   
z%TranscribableDummy._transcribe_on_end)r$   r%   r&   r   rX   r   r8   ro   r   r   rt   r   rw   r   r   r   r'   r   r   r   r   rU   c   s    rU   c                   C   s   t  S r    )rU   r   r   r   r   dummy_model   s   r   c                   @   sl  e Zd Zejjdd Zejjdd Zejjdd Zejjdd Z	ejjd	d
 Z
ejjdd Zejjdd Zej  ejjdd Zej ejjdd Zej ejjdd Zej ejjdd Zejjdd Zejjdd Zej ejjdd Zej ejjdd Zej ejjdd  Zej ejjd!d" Zd#S )$TestTranscriptionMixinc                 C   s&   t  }t|tr
J t|drJ d S )N
transcribe)r   
isinstancer   r   )r   modelr   r   r   test_constructor_non_instance   s   z4TestTranscriptionMixin.test_constructor_non_instancec                 C   s   |  }|jjjd |jjjd g d}|j|dd}t|dks'J |d dks/J |d dks7J |d	 d
ks?J d S )N      ?        z1.0z2.0z3.0r   rc      r          @         @)evalr   weightr1   fill_biasr   r4   )r   r   rY   r   r   r   r   test_transcribe   s   z&TestTranscriptionMixin.test_transcribec                 C   s   |  }|jjjd |jjjd g d}tdd}|j||d}g }d}|D ]}|| t	|dks9J t	||ksAJ |d7 }q*t	|dksNJ |d dksVJ |d d	ks^J |d
 dksfJ d S )Nr   r   r   r   r   override_configr   r   r   r   r   )
r   r   r   r1   r   r   r   transcribe_generatorextendr4   r   r   rY   transribe_config	generatorr   r0   r   r   r   r   test_transcribe_generator   s"   


z0TestTranscriptionMixin.test_transcribe_generatorc                 C   s   |  }|jjjd |jjjd g d}tdd}|j||d}g }d}	 zt|}W n	 t	y8   Y nw |
| t|dksFJ t||ksNJ |d7 }q)t|dks[J |d	 dkscJ |d d
kskJ |d dkssJ d S )Nr   r   r   r   r   r   Tr   r   r   r   r   )r   r   r   r1   r   r   r   r   nextStopIterationr   r4   r   r   r   r   -test_transcribe_generator_explicit_stop_check   s.   


zDTestTranscriptionMixin.test_transcribe_generator_explicit_stop_checkc                 C   s6   |  }g d}|j|dd |jsJ |jsJ d S )Nr   r   r   )r   r   r   r   )r   r   rY   r   r   r   test_transcribe_check_flags   s
   
z2TestTranscriptionMixin.test_transcribe_check_flagsc                 C   sl   t G dd d}| }g d}|ddd}tt |j||d}W d    d S 1 s/w   Y  d S )Nc                   @   &   e Zd ZU dZeed< dZeed< dS )zWTestTranscriptionMixin.test_transribe_override_config_incorrect.<locals>.OverrideConfigr   rc   ry   rx   N)r$   r%   r&   rc   int__annotations__rx   r8   r   r   r   r   OverrideConfig      
 r   )r   r   r   r   ry   rc   rx   r   )r   r   pytestraises
ValueErrorr   )r   r   r   rY   override_cfgrS   r   r   r   (test_transribe_override_config_incorrect   s   "z?TestTranscriptionMixin.test_transribe_override_config_incorrectc                 C   s  t G dd dt}| }|jjjd |jjjd g d}|ddd}|j||d	}t	|t
s6J t|dks>J |jd
ksEJ |d d dksOJ |d d dksYJ |d d dkscJ d|_|ddd}|j||d	}t	|tszJ t|d
ksJ |jd
ksJ |d d dksJ |d d dksJ |d d dksJ d|_|ddd}|j||d	}t	|tsJ t|dksJ |jd
ksJ |d d dksJ |d d dksJ |d d dksJ d S )Nc                   @   r   )zUTestTranscriptionMixin.test_transribe_override_config_correct.<locals>.OverrideConfigry   rx   FverboseN)r$   r%   r&   rx   r8   r   r   boolr   r   r   r   r     r   r   r   r   r   r   ry   r   r   r   rv   r   r   r   r   rz   r~   )r   r   r   r   r   r1   r   r   r   r   ry   r4   r   listr~   )r   r   r   rY   r   r   r   r   r   &test_transribe_override_config_correct	  s@   z=TestTranscriptionMixin.test_transribe_override_config_correctc                 C   s   t j|ddddd}|j|ddd}t|dksJ t|d	 ts$J |d	 }t|jts0J t|j	t
js9J t|jt
jsBJ d S )
NrC   rD   rE   rF   rG   r   T)rc   return_hypothesesr   )rJ   rK   rL   r   r4   r   r   r`   r8   
y_sequencer   Tensor
alignments)r   rN   fast_conformer_ctc_modelrl   r   hypr   r   r   !test_transcribe_return_hypothesis=  s   z8TestTranscriptionMixin.test_transcribe_return_hypothesisc                 C   s<   |\}}|j |dd}t|dksJ t|d tsJ d S )Nr   r   r   )r   r4   r   r   )r   r)   r   rY   rS   r   r   r   r   test_transcribe_tensorK  s   z-TestTranscriptionMixin.test_transcribe_tensorc                 C   s\   |\}}t |}|j||gdd}t|dksJ t|d ts#J t|d ts,J d S )Nr   r   r   r   )r   r-   r   r4   r   r   )r   r)   r   rY   audio_2r   r   r   r   test_transcribe_multiple_tensorU  s   
z6TestTranscriptionMixin.test_transcribe_multiple_tensorc           	      C   st   |\}}t ||g}dd }t|ddd|d}|j|dd}t|dks&J t|d ts/J t|d ts8J d S )	Nc                 S   s   t | ddS )Nr   )pad_idr	   )r!   r   r   r   <lambda>j  s    zCTestTranscriptionMixin.test_transcribe_dataloader.<locals>.<lambda>r   Fr   )rc   shufflerd   
collate_fnr   r   )r9   r   r   r4   r   r   )	r   r)   r   rY   rT   rq   r   
dataloaderr   r   r   r   test_transcribe_dataloaderc  s   z1TestTranscriptionMixin.test_transcribe_dataloaderc                 C   s   |   |\}}t|jj}t|jj}t| d|d< d|d d< d|d d< d|d d< W d    n1 s;w   Y  || |j||gd	dd
}t|dksWJ t	dd |D sbJ t	dd |D smJ t	dd |D sxJ || d S )Nmalsd_batchstrategy   beam	beam_sizeFreturn_best_hypothesisallow_cuda_graphsr   rc   
timestampsr   c                 s       | ]	}t |d kV  qdS r   Nr4   r{   rv   r   r   r   	<genexpr>      zKTestTranscriptionMixin.test_transcribe_return_nbest_rnnt.<locals>.<genexpr>c                 s       | ]}t |tV  qd S r    r   r   r   r   r   r   r         c                 s   $    | ]}|D ]}t |tV  qqd S r    r   r   r{   rv   r   r   r   r   r        " 
r   copydeepcopycfgdecodingr   change_decoding_strategyr   r4   all)r   r)   fast_conformer_transducer_modelrR   rT   orig_decoding_configdecoding_configr   r   r   r   !test_transcribe_return_nbest_rnnts  s"   

z8TestTranscriptionMixin.test_transcribe_return_nbest_rnntc                 C   s   |   |\}}t|jj}t|jj}t| d|d d< d|d d< W d    n1 s1w   Y  || |j||gddd}t|dksMJ t	d	d
 |D sXJ t	dd
 |D scJ t	dd
 |D snJ || d S )Nr   r   r   Fr   r   r   r   c                 s   r   r   r   r   r   r   r   r     r   zMTestTranscriptionMixin.test_transcribe_return_nbest_canary.<locals>.<genexpr>c                 s   r   r    r   r   r   r   r   r     r   c                 s   r   r    r   r   r   r   r   r     r   r   )r   r)   canary_1b_flashrR   rT   r   r   r   r   r   r   #test_transcribe_return_nbest_canary  s   

z:TestTranscriptionMixin.test_transcribe_return_nbest_canaryc                 C      |\}}|j ||gdd}t|dksJ t|d tsJ |d jdks'J |d jdks0J |d jd d d td	ksBJ |d jd d d
 tdksTJ d S )NTr   r   r   stopr   startsegment皙?endQ?r   r4   r   r   r`   	timestampr   approx)r   r)   r   rR   rT   rv   r   r   r   test_timestamps_with_transcribe     $(z6TestTranscriptionMixin.test_timestamps_with_transcribec                 C   s   |\}}|j ||gdd}t|dksJ t|d tsJ |d jdks'J |d jdks0J |d jd d d	 td
ksBJ |d jd d d tdksTJ d S )NTr   r   r   Stop?r   Start.r   r   r   r   
ףp=
?r   r   r)   fast_conformer_hybrid_modelrR   rT   rv   r   r   r   &test_timestamps_with_transcribe_hybrid  r   z=TestTranscriptionMixin.test_timestamps_with_transcribe_hybridc                 C   s   |\}}|j dd |j||gdd}t|dksJ t|d ts$J |d jdv s-J |d jd	v s6J |d jd
 d d tdksHJ |d jd
 d d tdksZJ d S )Nctc)decoder_typeTr   r   r   )Stopr   r   )Startr   r   r   r   r   r   )	r   r   r4   r   r   r`   r   r   r   r   r   r   r   /test_timestamps_with_transcribe_hybrid_ctc_head  s   $(zFTestTranscriptionMixin.test_timestamps_with_transcribe_hybrid_ctc_headc                 C   r   )NTr   r   r   r   r   r   r   g{Gz?r   r   r   )r   r)   r   rR   rT   rv   r   r   r   ,test_timestamps_with_transcribe_canary_flash  r   zCTestTranscriptionMixin.test_timestamps_with_transcribe_canary_flashN)r$   r%   r&   r   markunitr   r   r   r   r   r   r   with_downloadsr   r   r   r   r   r   r   r   r   r   r   r   r   r   r      sV    







1


r   )%r   rh   rJ   dataclassesr   typingr   r   r   r   r   	omegaconfr   torch.utils.datar   r   'nemo.collections.asr.data.audio_to_textr
   !nemo.collections.asr.parts.mixinsr   r   /nemo.collections.asr.parts.mixins.transcriptionr    nemo.collections.asr.parts.utilsr   r   Moduler   r(   r9   r  r  fixturer)   rU   r   r   r   r   r   r   <module>   s.   A
