o
    Ei@                     @   s   d dl Z d dlZd dlmZ d dlmZmZmZmZm	Z	 d dl
Z
d dlZd dlmZmZmZ d dlmZmZmZ ddlmZmZmZ G dd	 d	Zed
dG dd dZed
dG dd dZG dd dZdS )    N)	dataclass)AnyDictListOptionalUnion)Qwen3ASRConfig Qwen3ASRForConditionalGenerationQwen3ASRProcessor)
AutoConfig	AutoModelAutoProcessor   )	AudioLikeensure_listnormalize_audiosc                   @   s   e Zd Zdd ZdedefddZdedefdd	Zdedefd
dZdede	e fddZ
dede	e fddZdede	e fddZdede	e fddZdede	e fddZde	e fddZdedede	e fddZdd ZdS ) Qwen3ForceAlignProcessorc                 C   s   t jt jtdd}i }t|ddd}|D ]}| }|s!q| d }d||< qW d    n1 s6w   Y  || _d | _	d S )Nassetszkorean_dict_jieba.dictrzutf-8)encodingr   g      ?)
ospathjoindirname__file__openstripsplitko_scoreko_tokenizer)selfko_dict_path	ko_scoresflineword r&   Z/home/ubuntu/training/qwen3-asr-1.7b-phase2-sft/qwen_asr/inference/qwen3_forced_aligner.py__init__&   s   

z!Qwen3ForceAlignProcessor.__init__chreturnc                 C   s2   |dkrdS t |}|ds|drdS dS )N'TLNF)unicodedatacategory
startswith)r    r)   catr&   r&   r'   is_kept_char3   s   
z%Qwen3ForceAlignProcessor.is_kept_chartokenc                    s   d  fdd|D S )N c                 3   s    | ]
}  |r|V  qd S N)r2   ).0r)   r    r&   r'   	<genexpr><   s    z7Qwen3ForceAlignProcessor.clean_token.<locals>.<genexpr>)r   )r    r3   r&   r7   r'   clean_token;   s   z$Qwen3ForceAlignProcessor.clean_tokenc                 C   s   t |}d|  kodkn  pWd|  kodkn  pWd|  ko%dkn  pWd|  ko1dkn  pWd	|  ko=d
kn  pWd|  koIdkn  pWd|  koUdkS   S )Ni N  i  i 4  iM  i   iߦ i  i? i@ i i  i i   i  )ord)r    r)   coder&   r&   r'   is_cjk_char>   s    z$Qwen3ForceAlignProcessor.is_cjk_chartextc                    sd   g g   fdd}|D ]} |r|  | q|r) | q|  q|  S )Nc                     s2    rd  } | }|r| g  d S d S Nr4   )r   r9   append)r3   cleanedcurrent_latinr    tokensr&   r'   flush_latinN   s   


zDQwen3ForceAlignProcessor.tokenize_chinese_mixed.<locals>.flush_latin)r<   r?   r2   )r    r=   rD   r)   r&   rA   r'   tokenize_chinese_mixedJ   s   	

z/Qwen3ForceAlignProcessor.tokenize_chinese_mixedc                 C   s6   t |j}g }|D ]}| |}|r|| q
|S r5   )nagisataggingwordsr9   r?   )r    r=   rH   rC   wr@   r&   r&   r'   tokenize_japanesee   s   

z*Qwen3ForceAlignProcessor.tokenize_japanesec                 C   s4   | |}g }|D ]}| |}|r|| q	|S r5   )tokenizer9   r?   )r    r   r=   
raw_tokensrC   rI   w_cleanr&   r&   r'   tokenize_koreann   s   


z(Qwen3ForceAlignProcessor.tokenize_koreansegc                    sP   g g   fdd}|D ]}|  |r|  | q | q|  S )Nc                      s     r d  g  d S d S r>   )r?   r   r&   bufrC   r&   r'   	flush_buf{   s   zFQwen3ForceAlignProcessor.split_segment_with_chinese.<locals>.flush_buf)r<   r?   )r    rO   rR   r)   r&   rP   r'   split_segment_with_chinesew   s   
z3Qwen3ForceAlignProcessor.split_segment_with_chinesec                 C   s4   g }|  D ]}| |}|r|| | q|S r5   )r   r9   extendrS   )r    r=   rC   rO   r@   r&   r&   r'   tokenize_space_lang   s   
z,Qwen3ForceAlignProcessor.tokenize_space_langc                 C   s   |  }t|}dg| }dg| }td|D ]'}t|D ] }|| || kr=|| d || kr=|| d ||< |||< qqt|}||}g }	|}
|
dkr]|	|
 ||
 }
|
dksP|	  dg| }|	D ]}
d||
< qh| }d}||k ry|| sq|}||k r|| s|d7 }||k r|| r|| }|dkrd }t|d ddD ]}|| r|| } nqd }t||D ]}|| r|| } nqt||D ]$}|d u r|||< q|d u r|||< q||d  || kr|n|||< qn{d }t|d ddD ]}|| r
|| } nqd }t||D ]}|| r|| } nq|d urI|d urI|| |d  }t||D ]}|||| d   ||< q8n%|d ur\t||D ]}|||< qSn|d urnt||D ]}|||< qf|}n|d7 }||k szdd |D S )	Nr   FTr      c                 S   s   g | ]}t |qS r&   )int)r6   resr&   r&   r'   
<listcomp>   s    z:Qwen3ForceAlignProcessor.fix_timestamp.<locals>.<listcomp>)tolistlenrangemaxindexr?   reversecopy)r    datandpparentij
max_lengthmax_idxlis_indicesidx	is_normalresultanomaly_countleft_valk	right_valstepr&   r&   r'   fix_timestamp   s   

$







"	



9z&Qwen3ForceAlignProcessor.fix_timestamplanguagec                 C   s   |  }|  dkr| |}n%|  dkr0| jd u r(ddlm} || jd| _| | j|}n| |}d|d }d| }||fS )Njapanesekoreanr   )
LTokenizer)scoresz<timestamp><timestamp>z)<|audio_start|><|audio_pad|><|audio_end|>)	lowerrJ   r   soynlp.tokenizerrw   r   rN   rU   r   )r    r=   rt   	word_listrw   
input_textr&   r&   r'   encode_timestamp   s   

z)Qwen3ForceAlignProcessor.encode_timestampc           	      C   sR   g }|  |}t|D ]\}}||d  }||d d  }||||d q|S )NrW   r   r=   
start_timeend_time)rs   	enumerater?   )	r    r{   	timestamptimestamp_outputtimestamp_fixedrf   r%   r   r   r&   r&   r'   parse_timestamp   s   

z(Qwen3ForceAlignProcessor.parse_timestampN)__name__
__module____qualname__r(   strboolr2   r9   r<   r   rE   rJ   rN   rS   rU   rX   rs   r}   r   r&   r&   r&   r'   r   %   s    		Yr   T)frozenc                   @   s*   e Zd ZU dZeed< eed< eed< dS )ForcedAlignItema  
    One aligned item span.

    Attributes:
        text (str):
            The aligned unit (cjk character or word) produced by the forced aligner processor.
        start_time (float):
            Start time in seconds.
        end_time (float):
            End time in seconds.
    r=   r   r   N)r   r   r   __doc__r   __annotations__rX   r&   r&   r&   r'   r     s
   
 r   c                   @   s@   e Zd ZU dZee ed< dd Zdd Zde	defd	d
Z
dS )ForcedAlignResultz
    Forced alignment output for one sample.

    Attributes:
        items (List[ForcedAlignItem]):
            Aligned token spans.
    itemsc                 C   
   t | jS r5   )iterr   r7   r&   r&   r'   __iter__+     
zForcedAlignResult.__iter__c                 C   r   r5   )r\   r   r7   r&   r&   r'   __len__.  r   zForcedAlignResult.__len__rk   r*   c                 C   s
   | j | S r5   r   )r    rk   r&   r&   r'   __getitem__1  r   zForcedAlignResult.__getitem__N)r   r   r   r   r   r   r   r   r   rX   r   r&   r&   r&   r'   r      s   
 r   c                
   @   s   e Zd ZdZdededefddZede	dd fd	d
Z
deee	ef  defddZe deeee f dee	ee	 f dee	ee	 f dee fddZdeee	  fddZdS )Qwen3ForcedAlignera  
    A HuggingFace-style wrapper for Qwen3-ForcedAligner model inference.

    This wrapper provides:
      - `from_pretrained()` initialization via HuggingFace AutoModel/AutoProcessor
      - audio input normalization (path/URL/base64/(np.ndarray, sr))
      - batch and single-sample forced alignment
      - structured output with attribute access (`.text`, `.start_time`, `.end_time`)
    model	processoraligner_processorc                 C   s~   || _ || _|| _t|dd | _| jd u r/z
t| j| _W n ty.   td| _Y nw t	|j
j| _t|j
j| _d S )Ndevicecpu)r   r   r   getattrr   next
parametersStopIterationtorchrX   configtimestamp_token_idfloattimestamp_segment_time)r    r   r   r   r&   r&   r'   r(   @  s   
zQwen3ForcedAligner.__init__pretrained_model_name_or_pathr*   c                 K   sv   t dt ttt ttt tj|fi |}t|ts*t	dt
| dtj|dd}t }| |||dS )a  
        Load Qwen3-ForcedAligner model and initialize processors.

        This method:
          1) Registers config/model/processor for HF auto classes.
          2) Loads the model using `AutoModel.from_pretrained(...)`.
          3) Initializes:
             - HF processor (`AutoProcessor.from_pretrained(...)`)
             - forced alignment text processor (`Qwen3ForceAlignProcessor()`)

        Args:
            pretrained_model_name_or_path (str):
                HuggingFace repo id or local directory.
            **kwargs:
                Forwarded to `AutoModel.from_pretrained(...)`.
                Typical examples: device_map="cuda:0", dtype=torch.bfloat16.

        Returns:
            Qwen3ForcedAligner:
                Initialized wrapper instance.
        	qwen3_asrzAutoModel returned z,, expected Qwen3ASRForConditionalGeneration.T)fix_mistral_regex)r   r   r   )r   registerr   r   r	   r   r
   from_pretrained
isinstance	TypeErrortyper   )clsr   kwargsr   r   r   r&   r&   r'   r   T  s   
z"Qwen3ForcedAligner.from_pretrainedr   c                 C   sP   g }|D ]}| tt|ddt|ddt|ddd qt|dS )Nr=   r4   r   r   r   r~   r   )r?   r   r   getr   r   )r    r   r   itr&   r&   r'   _to_structured_items~  s   
z'Qwen3ForcedAligner._to_structured_itemsaudior=   rt   c                 C   s  t |}t |}t|}t|dkrt|dkr|t| }t|t|  kr.t|ksBn tdt| dt| dt| g }g }t||D ]\}	}
| j|	|
\}}|| || qK| j||ddd}|	| j
j	| j
j}| j
jdi |j}|jdd	}g }t|d
 ||D ]@\}}}||| jk }|| j 	d }| j||}|D ]}t|d d d|d< t|d d d|d< q|| | q|S )a  
        Run forced alignment for batch or single sample.

        Args:
            audio:
                Audio input(s). Each item supports:
                  - local path / https URL / base64 string
                  - (np.ndarray, sr)
                All audios will be converted into mono 16k float32 arrays in [-1, 1].
            text:
                Transcript(s) for alignment.
            language:
                Language(s) for each sample (e.g., "Chinese", "English").

        Returns:
            List[ForcedAlignResult]:
                One result per sample. Each result contains `items`, and each token can be accessed via
                `.text`, `.start_time`, `.end_time`.
        r   zBatch size mismatch: audio=z, text=z, language=ptT)r=   r   return_tensorspaddingrV   )dim	input_idsr   r   g     @@   r   Nr&   )r   r   r\   
ValueErrorzipr   r}   r?   r   tor   r   dtypethinkerlogitsargmaxr   r   numpyr   roundr   )r    r   r=   rt   texts	languagesaudios
word_listsaligner_input_textstlangr{   aligner_input_textinputsr   
output_idsresultsinput_id	output_idmasked_output_idtimestamp_msr   r   r&   r&   r'   align  sD   " 
zQwen3ForcedAligner.alignc                 C   s>   t | jdd}t|sdS | }|du rdS tdd |D S )a  
        List supported language names for the current model.

        This is a thin wrapper around `self.model.get_support_languages()`.
        If the underlying model does not expose language constraints (returns None),
        this method also returns None.

        Returns:
            Optional[List[str]]:
                - A sorted list of supported language names (lowercased), if available.
                - None if the model does not provide supported languages.
        get_support_languagesNc                 S   s   h | ]}t | qS r&   )r   ry   )r6   xr&   r&   r'   	<setcomp>  s    z=Qwen3ForcedAligner.get_supported_languages.<locals>.<setcomp>)r   r   callablesorted)r    fnlangsr&   r&   r'   get_supported_languages  s   z*Qwen3ForcedAligner.get_supported_languagesN)r   r   r   r   r	   r
   r   r(   classmethodr   r   r   r   r   r   r   r   inference_moder   r   r   r   r   r&   r&   r&   r'   r   5  s6    

)Cr   )r   r.   dataclassesr   typingr   r   r   r   r   rF   r   "qwen_asr.core.transformers_backendr   r	   r
   transformersr   r   r   utilsr   r   r   r   r   r   r   r&   r&   r&   r'   <module>   s     j