o
    }oiGE                     @   s
  d dl Z d dlZd dlZd dlmZ d dlmZmZmZ d dl	Z	d dl
Z
d dlmZ d dlmZmZ d dlmZ d dlmZmZ d dlmZ d d	lmZ G d
d de
jjZe	j e	 dd ZG dd deeZ G dd deZ!e	 dd Z"G dd dZ#dS )    N)	dataclass)AnyDictList)	open_dict)
DataLoaderDataset_speech_collate_fn)TranscribeConfigTranscriptionMixin)GenericTranscriptionType)
Hypothesisc                       s$   e Zd Z fddZdd Z  ZS )
DummyModelc                    s0   t    tjdd| _d| _d| _d| _d S )N   r   F)	super__init__torchnnLinearencoderexecution_count
flag_beginflag_endself	__class__ c/home/ubuntu/.local/lib/python3.10/site-packages/tests/collections/asr/mixins/test_transcription.pyr   !   s
   

zDummyModel.__init__c                 C   s   |  |}|S N)r   )r   xoutr   r   r   forward)   s   
zDummyModel.forward)__name__
__module____qualname__r   r#   __classcell__r   r   r   r   r       s    r   c                 C   s`   ddl }tj| ddddd}tj| ddddd}|j|d	d
\}}|j|d	d
\}}||fS )z4
    Returns a list of audio files for testing.
    r   Nasrtrainan4wavan46-mmap-b.wavzan104-mrcb-b.wavfloat32dtype)	soundfileospathjoinread)test_data_dirsfaudio_file1audio_file2audio1_audio2r   r   r   audio_files/   s   r<   c                       s   e Zd Zdef fddZdee dedefddZded	e	fd
dZ
dedefddZded	efddZdef fddZ  ZS )TranscribableDummytrcfgc                    s   t  || d| _d S NT)r   _transcribe_on_beginr   )r   audior>   r   r   r   r@   A   s   
z'TranscribableDummy._transcribe_on_beginr<   temp_dirc           	      C   s   t j|d}t|ddd}|D ]}|ddd}|t|d  qW d    n1 s.w   Y  ||j||j|j	d	}|S )
Nzdummy_manifest.jsonwzutf-8)encodingi  )audio_filepathdurationtext
)paths2audio_files
batch_sizerB   num_workerschannel_selector)
r1   r2   r3   openwritejsondumpsrK   rL   rM   )	r   r<   rB   r>   manifest_pathfp
audio_fileentry	ds_configr   r   r   %_transcribe_input_manifest_processingE   s   z8TranscribableDummy._transcribe_input_manifest_processingconfigreturnc                 C   s8   G dd dt }||d |}t||d |d dddS )Nc                   @   s2   e Zd Zdee defddZdd Zdd Zd	S )
zETranscribableDummy._setup_transcribe_dataloader.<locals>.DummyDatasetr<   rX   c                 S      || _ || _d S r    )r<   rX   )r   r<   rX   r   r   r   r   Y      
zNTranscribableDummy._setup_transcribe_dataloader.<locals>.DummyDataset.__init__c                 S   s$   | j | }tt|gd}|S )Nr   )r<   r   tensorfloatview)r   indexdatar   r   r   __getitem__]   s   
zQTranscribableDummy._setup_transcribe_dataloader.<locals>.DummyDataset.__getitem__c                 S   
   t | jS r    )lenr<   r   r   r   r   __len__b      
zMTranscribableDummy._setup_transcribe_dataloader.<locals>.DummyDataset.__len__N	r$   r%   r&   r   strr   r   ra   rd   r   r   r   r   DummyDatasetX   s    rh   rJ   rK   rL   F)datasetrK   rL   
pin_memory	drop_last)r   r   )r   rX   rh   ri   r   r   r   _setup_transcribe_dataloaderW   s   z/TranscribableDummy._setup_transcribe_dataloaderbatchc                 C   s   | |}|S r    r   )r   rm   r>   outputr   r   r   _transcribe_forwardo   s   z&TranscribableDummy._transcribe_forwardc                 C   s   |  j d7  _ g }|D ]}|t|  qt|dr'|jdkr'd|i}|S t|dr:|jdkr:dd |D }|S t|drJ|jdkrJt|}|S |S )	Nr   output_typedictrn   dict2c                 S   s   g | ]}d |iqS )rn   r   ).0resr   r   r   
<listcomp>   s    zDTranscribableDummy._transcribe_output_processing.<locals>.<listcomp>tuple)r   appendr]   itemhasattrrp   rv   )r   outputsr>   resultrn   resultsr   r   r   _transcribe_output_processings   s   z0TranscribableDummy._transcribe_output_processingc                    s   t  | d| _d S r?   )r   _transcribe_on_endr   )r   r>   r   r   r   r~      s   
z%TranscribableDummy._transcribe_on_end)r$   r%   r&   r   r@   r   rg   rW   r   r   rl   r   ro   r   r}   r~   r'   r   r   r   r   r=   @   s    r=   c                   @   s4   e Zd Zd
dee defddZdd Zdd	 ZdS )rh   Naudio_tensorsrX   c                 C   rZ   r    )r   rX   )r   r   rX   r   r   r   r      r[   zDummyDataset.__init__c                 C   sX   | j | }t|}tj|jd tjd}tjdgtjd}tjdtjd}||||fS )Nr   r.   r   )r   r   r\   shapelong)r   r_   r`   samplesseq_lentext_tokenstext_tokens_lenr   r   r   ra      s   

zDummyDataset.__getitem__c                 C   rb   r    )rc   r   r   r   r   r   rd      re   zDummyDataset.__len__r    rf   r   r   r   r   rh      s    rh   c                   C   s   t  S r    )r=   r   r   r   r   dummy_model   s   r   c                   @   sl  e Zd Zejjdd Zejjdd Zejjdd Zejjdd Z	ejjd	d
 Z
ejjdd Zejjdd Zej  ejjdd Zej ejjdd Zej ejjdd Zej ejjdd Zejjdd Zejjdd Zej ejjdd Zej ejjdd Zej ejjdd  Zej ejjd!d" Zd#S )$TestTranscriptionMixinc                 C   s&   t  }t|tr
J t|drJ d S )N
transcribe)r   
isinstancer   ry   )r   modelr   r   r   test_constructor_non_instance   s   z4TestTranscriptionMixin.test_constructor_non_instancec                 C   s   |  }|jjjd |jjjd g d}|j|dd}t|dks'J |d dks/J |d dks7J |d	 d
ks?J d S )N      ?        z1.0z2.0z3.0r   rK      r          @         @)evalr   weightr`   fill_biasr   rc   )r   r   rA   rz   r   r   r   test_transcribe   s   z&TestTranscriptionMixin.test_transcribec                 C   s   |  }|jjjd |jjjd g d}tdd}|j||d}g }d}|D ]}|| t	|dks9J t	||ksAJ |d7 }q*t	|dksNJ |d dksVJ |d d	ks^J |d
 dksfJ d S )Nr   r   r   r   r   override_configr   r   r   r   r   )
r   r   r   r`   r   r   r   transcribe_generatorextendrc   r   r   rA   transribe_config	generatorrz   r_   r{   r   r   r   test_transcribe_generator   s"   


z0TestTranscriptionMixin.test_transcribe_generatorc                 C   s   |  }|jjjd |jjjd g d}tdd}|j||d}g }d}	 zt|}W n	 t	y8   Y nw |
| t|dksFJ t||ksNJ |d7 }q)t|dks[J |d	 dkscJ |d d
kskJ |d dkssJ d S )Nr   r   r   r   r   r   Tr   r   r   r   r   )r   r   r   r`   r   r   r   r   nextStopIterationr   rc   r   r   r   r   -test_transcribe_generator_explicit_stop_check   s.   


zDTestTranscriptionMixin.test_transcribe_generator_explicit_stop_checkc                 C   s6   |  }g d}|j|dd |jsJ |jsJ d S )Nr   r   r   )r   r   r   r   )r   r   rA   r   r   r   test_transcribe_check_flags   s
   
z2TestTranscriptionMixin.test_transcribe_check_flagsc                 C   sl   t G dd d}| }g d}|ddd}tt |j||d}W d    d S 1 s/w   Y  d S )Nc                   @   &   e Zd ZU dZeed< dZeed< dS )zWTestTranscriptionMixin.test_transribe_override_config_incorrect.<locals>.OverrideConfigr   rK   rq   rp   N)r$   r%   r&   rK   int__annotations__rp   rg   r   r   r   r   OverrideConfig      
 r   )r   r   r   r   rq   rK   rp   r   )r   r   pytestraises
ValueErrorr   )r   r   r   rA   override_cfgr:   r   r   r   (test_transribe_override_config_incorrect   s   "z?TestTranscriptionMixin.test_transribe_override_config_incorrectc                 C   s  t G dd dt}| }|jjjd |jjjd g d}|ddd}|j||d	}t	|t
s6J t|dks>J |jd
ksEJ |d d dksOJ |d d dksYJ |d d dkscJ d|_|ddd}|j||d	}t	|tszJ t|d
ksJ |jd
ksJ |d d dksJ |d d dksJ |d d dksJ d|_|ddd}|j||d	}t	|tsJ t|dksJ |jd
ksJ |d d dksJ |d d dksJ |d d dksJ d S )Nc                   @   r   )zUTestTranscriptionMixin.test_transribe_override_config_correct.<locals>.OverrideConfigrq   rp   FverboseN)r$   r%   r&   rp   rg   r   r   boolr   r   r   r   r   
  r   r   r   r   r   r   rq   r   r   r   rn   r   r   r   r   rr   rv   )r   r   r   r   r   r`   r   r   r   r   rq   rc   r   listrv   )r   r   r   rA   r   rz   r   r   r   &test_transribe_override_config_correct  s@   z=TestTranscriptionMixin.test_transribe_override_config_correctc                 C   s   t j|ddddd}|j|ddd}t|dksJ t|d	 ts$J |d	 }t|jts0J t|j	t
js9J t|jt
jsBJ d S )
Nr(   r)   r*   r+   r,   r   T)rK   return_hypothesesr   )r1   r2   r3   r   rc   r   r   rH   rg   
y_sequencer   Tensor
alignments)r   r5   fast_conformer_ctc_modelrT   rz   hypr   r   r   !test_transcribe_return_hypothesis<  s   z8TestTranscriptionMixin.test_transcribe_return_hypothesisc                 C   s<   |\}}|j |dd}t|dksJ t|d tsJ d S )Nr   r   r   )r   rc   r   r   )r   r<   r   rA   r:   rz   r   r   r   test_transcribe_tensorJ  s   z-TestTranscriptionMixin.test_transcribe_tensorc                 C   s\   |\}}t |}|j||gdd}t|dksJ t|d ts#J t|d ts,J d S )Nr   r   r   r   )r   r\   r   rc   r   r   )r   r<   r   rA   audio_2rz   r   r   r   test_transcribe_multiple_tensorT  s   
z6TestTranscriptionMixin.test_transcribe_multiple_tensorc           	      C   st   |\}}t ||g}dd }t|ddd|d}|j|dd}t|dks&J t|d ts/J t|d ts8J d S )	Nc                 S   s   t | ddS )Nr   )pad_idr	   )r!   r   r   r   <lambda>i  s    zCTestTranscriptionMixin.test_transcribe_dataloader.<locals>.<lambda>r   Fr   )rK   shufflerL   
collate_fnr   r   )rh   r   r   rc   r   r   )	r   r<   r   rA   r;   ri   r   
dataloaderrz   r   r   r   test_transcribe_dataloaderb  s   z1TestTranscriptionMixin.test_transcribe_dataloaderc                 C   s   |   |\}}t|jj}t|jj}t| d|d< d|d d< d|d d< d|d d< W d    n1 s;w   Y  || |j||gd	dd
}t|dksWJ t	dd |D sbJ t	dd |D smJ t	dd |D sxJ || d S )Nmalsd_batchstrategy   beam	beam_sizeFreturn_best_hypothesisallow_cuda_graphsr   rK   
timestampsr   c                 s       | ]	}t |d kV  qdS r   Nrc   rs   rn   r   r   r   	<genexpr>      zKTestTranscriptionMixin.test_transcribe_return_nbest_rnnt.<locals>.<genexpr>c                 s       | ]}t |tV  qd S r    r   r   r   r   r   r   r         c                 s   $    | ]}|D ]}t |tV  qqd S r    r   r   rs   rn   r   r   r   r   r        " 
r   copydeepcopycfgdecodingr   change_decoding_strategyr   rc   all)r   r<   fast_conformer_transducer_modelr9   r;   orig_decoding_configdecoding_configrz   r   r   r   !test_transcribe_return_nbest_rnntr  s"   

z8TestTranscriptionMixin.test_transcribe_return_nbest_rnntc                 C   s   |   |\}}t|jj}t|jj}t| d|d d< d|d d< W d    n1 s1w   Y  || |j||gddd}t|dksMJ t	d	d
 |D sXJ t	dd
 |D scJ t	dd
 |D snJ || d S )Nr   r   r   Fr   r   r   r   c                 s   r   r   r   r   r   r   r   r     r   zMTestTranscriptionMixin.test_transcribe_return_nbest_canary.<locals>.<genexpr>c                 s   r   r    r   r   r   r   r   r     r   c                 s   r   r    r   r   r   r   r   r     r   r   )r   r<   canary_1b_flashr9   r;   r   r   rz   r   r   r   #test_transcribe_return_nbest_canary  s   

z:TestTranscriptionMixin.test_transcribe_return_nbest_canaryc                 C      |\}}|j ||gdd}t|dksJ t|d tsJ |d jdks'J |d jdks0J |d jd d d td	ksBJ |d jd d d
 tdksTJ d S )NTr   r   r   stopr   startsegment皙?endQ?r   rc   r   r   rH   	timestampr   approx)r   r<   r   r9   r;   rn   r   r   r   test_timestamps_with_transcribe     $(z6TestTranscriptionMixin.test_timestamps_with_transcribec                 C   s   |\}}|j ||gdd}t|dksJ t|d tsJ |d jdks'J |d jdks0J |d jd d d	 td
ksBJ |d jd d d tdksTJ d S )NTr   r   r   zStop?r   Start.r   r   r   r   
ףp=
?r   r   r<   fast_conformer_hybrid_modelr9   r;   rn   r   r   r   &test_timestamps_with_transcribe_hybrid  r   z=TestTranscriptionMixin.test_timestamps_with_transcribe_hybridc                 C   s   |\}}|j dd |j||gdd}t|dksJ t|d ts$J |d jdks-J |d jd	ks6J |d jd
 d d tdksHJ |d jd
 d d tdksZJ d S )Nctc)decoder_typeTr   r   r   Stopr   r   r   r   r   r   r   )	r   r   rc   r   r   rH   r   r   r   r   r   r   r   /test_timestamps_with_transcribe_hybrid_ctc_head  s   $(zFTestTranscriptionMixin.test_timestamps_with_transcribe_hybrid_ctc_headc                 C   r   )NTr   r   r   r   r   r   r   g{Gz?r   r   r   )r   r<   r   r9   r;   rn   r   r   r   ,test_timestamps_with_transcribe_canary_flash  r   zCTestTranscriptionMixin.test_timestamps_with_transcribe_canary_flashN)r$   r%   r&   r   markunitr   r   r   r   r   r   r   with_downloadsr   r   r   r   r   r   r   r   r   r   r   r   r   r   r      sV    







1


r   )$r   rP   r1   dataclassesr   typingr   r   r   r   r   	omegaconfr   torch.utils.datar   r   'nemo.collections.asr.data.audio_to_textr
   !nemo.collections.asr.parts.mixinsr   r   /nemo.collections.asr.parts.mixins.transcriptionr    nemo.collections.asr.parts.utilsr   r   Moduler   r   r   fixturer<   r=   rh   r   r   r   r   r   r   <module>   s,   N
