"""
Custom LFM2 implementation with KaniTTS-2 frame-level position encoding.

Key Innovation:
- Frame-level position encoding: all 4 tokens within an audio frame share the same position ID.
  This reduces RoPE distance between tokens across frames, improving long-form generation.

Compatible with Flash Attention 2 for 10-20x training speedup.

FIXED: Proper frame-level position tracking during generation with KV-cache.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.utils import TransformersKwargs
from transformers.processing_utils import Unpack
from transformers.cache_utils import Cache
from transformers.generation.utils import GenerationMixin
from transformers.models.lfm2.modeling_lfm2 import (
    Lfm2Model,
    Lfm2ForCausalLM,
    Lfm2PreTrainedModel,
    Lfm2HybridConvCache,
)
from transformers.models.lfm2.configuration_lfm2 import Lfm2Config
def compute_frame_level_positions(
    input_ids: torch.Tensor,
    audio_tokens_start: int,
    tokens_per_frame: int = 4,
    audio_step: float = 1.0,
) -> torch.Tensor:
    """
    Vectorized computation of frame-level position IDs (10-50x faster than Python loops).

    Key insight: use cumulative counts to determine positions.

    - Text tokens: sequential positions (step 1.0)
    - Audio tokens: frame-level positions (step audio_step per frame)

    Args:
        input_ids: Input token IDs [batch_size, seq_len]
        audio_tokens_start: Token ID where audio tokens begin (typically 64410)
        tokens_per_frame: Number of tokens per audio frame (typically 4)
        audio_step: Position step size per audio frame (default 1.0).
                    Set to < 1.0 (e.g., 0.5) to compress audio position space.

    Returns:
        position_ids: Position IDs [batch_size, seq_len].
                      A float audio_step yields a FloatTensor.

    Example:
        >>> input_ids = torch.tensor([[100, 200, 64410, 68442, 72474, 76506, 300]])
        >>> # Tokens:                [text, text, aud0,  aud1,  aud2,  aud3,  text]
        >>> compute_frame_level_positions(input_ids, 64410, 4, audio_step=1.0)
        tensor([[0., 1., 2., 2., 2., 2., 3.]])

        Text tokens accumulate 1 per token, so the two leading text tokens sit
        at 0 and 1. All four audio tokens share position 2, the completed frame
        advances the counter by audio_step, and the trailing text token lands
        at 3. With audio_step=0.5 it would land at 2.5 instead.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    is_audio = input_ids >= audio_tokens_start
    text_mask = ~is_audio

    # Exclusive cumulative counts: prepend a zero column and drop the last
    # column so each position counts only the tokens strictly before it.
    zeros = torch.zeros(batch_size, 1, device=device, dtype=torch.long)
    text_count = torch.cat([zeros, text_mask.long()], dim=1).cumsum(dim=1)[:, :-1]
    audio_token_count = torch.cat([zeros, is_audio.long()], dim=1).cumsum(dim=1)[:, :-1]

    # Every completed group of tokens_per_frame audio tokens is one frame.
    audio_frame_count = audio_token_count // tokens_per_frame

    # Text contributes 1 per token; audio contributes audio_step per frame.
    position_ids = text_count + audio_frame_count * audio_step
    return position_ids
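
# Worked example of the arithmetic above (a sketch; token values are
# illustrative, assuming audio IDs start at 64410 and tokens_per_frame=4):
#
#   tokens:      [t,   t,   a a a a,     a a a a,             t  ]
#   text_count:  [0,   1,   2 2 2 2,     2 2 2 2,             2  ]
#   frame_count: [0,   0,   0 0 0 0,     1 1 1 1,             2  ]
#   positions:   [0.0, 1.0, 2.0 x4,      2.5 x4,              3.0]   (audio_step=0.5)
#
# Eight audio tokens span a single position unit instead of eight, which keeps
# RoPE distances short across long audio stretches.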


class LearnableRotaryEmbedding(nn.Module):
    """
    Learnable RoPE with layer-wise frequency scaling.

    Each layer has a learnable alpha parameter that scales the RoPE frequencies:

        θᵢ^(l) = α^(l) · base^(-2i/d)

    where α^(l) is constrained to [alpha_min, alpha_max] via sigmoid reparameterization:

        α^(l) = alpha_min + (alpha_max - alpha_min) · sigmoid(w^(l))

    This allows the model to learn optimal positional encoding frequencies per layer.
    """

    def __init__(
        self,
        config,
        layer_idx,
        total_attention_layers,
        alpha_min=0.1,
        alpha_max=2.0,
        device=None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.total_attention_layers = total_attention_layers
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max

        dim = config.hidden_size // config.num_attention_heads
        base = config.rope_theta
        max_position_embeddings = config.max_position_embeddings
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings

        # Standard RoPE inverse frequencies; the learnable alpha scales them.
        inv_freq_base = 1.0 / (
            base
            ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        self.register_buffer("inv_freq_base", inv_freq_base, persistent=False)

        # Unconstrained weight; sigmoid(0.0) = 0.5 starts alpha mid-range.
        self.alpha_weight = nn.Parameter(torch.tensor(0.0))

    @property
    def alpha(self):
        """
        Compute constrained alpha via sigmoid reparameterization.

        Returns:
            Scalar alpha value in range [alpha_min, alpha_max]
        """
        return self.alpha_min + (self.alpha_max - self.alpha_min) * torch.sigmoid(self.alpha_weight)

    @property
    def inv_freq(self):
        """
        Compute scaled inverse frequencies: α^(l) · θᵢ

        Returns:
            Tensor of shape [d/2] with scaled frequencies
        """
        return self.inv_freq_base * self.alpha

    def forward(self, x, position_ids):
        """
        Compute rotary position embeddings for the input.

        Args:
            x: Input tensor of shape [batch_size, num_heads, seq_len, head_dim]
            position_ids: Position indices of shape [batch_size, seq_len]

        Returns:
            Tuple of (cos, sin) embeddings, each of shape [batch_size, seq_len, head_dim]
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
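
# Quick numeric check of the sigmoid reparameterization (assuming the default
# bounds alpha_min=0.1, alpha_max=2.0): alpha_weight starts at 0.0, so
#   alpha = 0.1 + (2.0 - 0.1) * sigmoid(0.0) = 0.1 + 1.9 * 0.5 = 1.05,
# i.e. each layer begins near unscaled RoPE (alpha = 1) and training can move
# it anywhere within (0.1, 2.0) without ever leaving the bounds.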
edededededef fddZ								d de	e
j de	e
j de	e
j de	e de	e
j de	e
j de	e de	e
j dee def fddZ  ZS )!Lfm2ForKaniModelz
    Custom LFM2 model with KaniTTS-2 frame-level position encoding.

    This version only overrides position ID computation - everything else
    uses the standard Lfm2Model implementation.
    r   r   Fr0   r1      rK   r   r   r   use_learnable_roper:   r;   speaker_emb_dimc	              
      s  t  | || _|| _|| _|| _|| _|| _|| _t	j
||jdd| _|rg }	t|drBt|jD ]\}
}|dkr@|	|
 q3ntt|j}	t|	}t	 | _t|jD ]%}
|
|	v rvt||
|||t|drk|jnd d}| j| qW| jd  qWtd td|  td	|  td
| d|j  td td| d td| d| d d S d | _td td|  td	|  td|  td
| d|j  td td d S )NFbiaslayer_typesfull_attentionr   )rK   r8   r9   r:   r;   r   u!   ✅ Lfm2ForKaniModel initialized:z   - Audio tokens start: z   - Tokens per frame: z   - Speaker embedding: z -> z4   - Using frame-level position encoding (KaniTTS-2)z    - Learnable RoPE ENABLED for z attention layersz   - Alpha range: [z, ]z   - Audio step: z,   - Learnable RoPE DISABLED (standard RoPE))r6   r7   r   r   r   rl   r:   r;   rm   rF   Linearr<   speaker_emb_projectionhasattr	enumeraterp   appendlistrangenum_hidden_layerslen
ModuleListlearnable_rope_layersr/   r   print)rJ   rK   r   r   r   rl   r:   r;   rm   attention_layer_indicesidx
layer_typer9   learnable_roperL   r,   r-   r7      s^   


zLfm2ForKaniModel.__init__Nr   attention_maskr+   past_key_valuesinputs_embedsspeaker_emb	use_cachecache_positionkwargsr   c	              
      s  |du r|durt || j| j| jd}| js't jd|||||||d|	S ddlm}
m	} |du r;|du r;t
d|du rD| |}|rZ|du rZ|jd }|
| j|| j| jd}|du rx|durf| nd}|jd }tj||| |jd	}|du r|d}|| j|||||d
}|}d}t| jd| jj D ],\}}| j| dur| j| ||}n
|du r| ||}||f|||||d|	}q| |}t||dS )a^  
        Forward pass with custom frame-level position IDs and speaker embeddings.

        Speaker embeddings are inserted at position 1 (after the first token).
        All subsequent positions are shifted by +1 to maintain sequential ordering.

        If learnable RoPE is disabled:
            Delegates to parent class after computing frame-level position IDs.

        If learnable RoPE is enabled:
            Overrides position embedding computation to use per-layer learnable RoPE.

        Args:
            speaker_emb: Speaker embeddings [batch_size, speaker_emb_dim] (e.g., [1, 128])
        Nr   r   r   r   )r   r   r+   r   r   r   r   r   )r   create_causal_maskz;You must specify at least one of input_ids or inputs_embeds)rK   max_batch_sizer   r   r   r   )rK   input_embedsr   r   r   r+   )r   r+   r   r   position_embeddings)last_hidden_stater   r,   )r.   r   r   r   rl   r6   rc   &transformers.models.lfm2.modeling_lfm2r   r   
ValueErrorembed_tokensr   rK   r   r   get_seq_lengthr   rA   	unsqueezerv   layersrz   r}   pos_embembedding_normr   )rJ   r   r   r+   r   r   r   r   r   r   r   r   r$   past_seen_tokens
seq_lengthcausal_maskhidden_statesr   r8   decoder_layerrL   r,   r-   rc     s   



	

zLfm2ForKaniModel.forwardr   r   Fr0   r1   rk   )NNNNNNNN)rd   re   rf   rg   r   intrD   boolr7   r   r   
LongTensorTensorr   FloatTensorr   r   r   rc   ri   r,   r,   rL   r-   rj      sl    	I	
rj   c                       s  e Zd ZdZdgZedefddZ						
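
# Shape sketch of the speaker-embedding splice performed at prefill time in
# KaniTTS2ForCausalLM.prepare_inputs_for_generation below (dims assume the
# default speaker_emb_dim=128):
#
#   speaker_emb                                   [1, 128]
#   -> speaker_emb_projection (Linear, no bias)   [1, hidden_size]
#   -> unsqueeze(1)                               [1, 1, hidden_size]
#   -> cat after token 0 of inputs_embeds         [1, seq_len + 1, hidden_size]
#
# The attention mask gains one extra position and cache positions after the
# first token shift by +1, matching the docstring of Lfm2ForKaniModel.forward.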
	d?dede	de	de
dede
de
de	f fddZd@dee	 fddZde	fddZ										dAdeej deej d eej d!ee d"eej d#eej d$eej d%ee d&eej d'ee	ejf d(ee defd)d*Z						+dBd,d-Z fd.d/Ze							dCd0ede	de	de
dede
de
de	fd1d2Zd3d4 Zd5d6 Zd7d8 Zd9d: Z d;d< Z!d=d> Z"  Z#S )DKaniTTS2ForCausalLMa  
    Flash Attention compatible LFM2 for causal language modeling with KaniTTS-2 frame-level positions.

    Features:
    - Frame-level position encoding for audio tokens (KaniTTS-2 innovation)
    - Optional learnable RoPE with per-layer frequency scaling (alpha parameters)
    - Proper position tracking during generation with KV-cache
    - Flash Attention 2 optimized
    - Standard causal masking
    - Compatible with existing KaniTTS inference pipeline
    - Includes GenerationMixin for text generation capabilities
    lm_head.weightr   c                 C   s   dS )NFr,   )clsr,   r,   r-   _supports_default_dynamic_cache  s   z3KaniTTS2ForCausalLM._supports_default_dynamic_cacher   r   Fr0   r1   rk   rK   r   r   r   rl   r:   r;   rm   c	           	   
      s   t  | t||||||||d| _|j| _tj|j|jdd| _|| _	|| _
|| _|| _|| _|| _|| _d | _d | _t|drE|jnd | _d| _|   d S )N)r   rl   r:   r;   rm   Frn   generation_configr   )r6   r7   rj   model
vocab_sizerF   rs   r<   lm_headr   r   r   rl   r:   r;   rm   _generation_state_current_speaker_embru   r   main_input_name	post_init)	rJ   rK   r   r   r   rl   r:   r;   rm   rL   r,   r-   r7     s2   
zKaniTTS2ForCausalLM.__init__Nstarting_frame_positionc                 C   s    d|dur	t |ndd| _dS )a  
        Reset generation state before starting new generation.

        This tracks:
        - The position where the first audio frame should start
        - How many audio tokens we've generated
        - What the current frame position should be

        Args:
            starting_frame_position: The position ID where the first audio frame begins.
                                   If None, will be determined when first audio token is generated.
        r   N)audio_tokens_generatedcurrent_frame_position)rD   r   )rJ   r   r,   r,   r-   _reset_generation_state  s   z+KaniTTS2ForCausalLM._reset_generation_statenew_token_idc                 C   s2   | j du rdS || jkr| j d  d7  < dS dS )z
        Update generation state after generating a token.

        Args:
            new_token_id: The token ID that was just generated
        Nr   r   )r   r   )rJ   r   r,   r,   r-   _update_generation_state  s
   

z,KaniTTS2ForCausalLM._update_generation_stater   r   r   r+   r   r   r   labelsr   r   logits_to_keepr   c                 K   s   | j d||||||||	d|}|j}t|
trt|
 dn|
}| |dd|ddf }d}|durC| jd||| jjd|}t	|||j
|j|jdS )z
        Forward pass - delegates to flash-compatible model with frame-level position encoding.

        Args:
            speaker_emb: Speaker embeddings [batch_size, speaker_emb_dim] (e.g., [1, 128])
        )r   r   r+   r   r   r   r   r   N)logitsr   r   )lossr   r   r   
attentionsr,   )r   r   rX   r   slicer   loss_functionrK   r   r   r   r   r   )rJ   r   r   r+   r   r   r   r   r   r   r   r   outputsr   slice_indicesr   r   r,   r,   r-   rc     s2   	zKaniTTS2ForCausalLM.forwardTc                 K   sR  |du r| j dur| j|}| j| j }	|	d}	tj|ddddddf |	|ddddddf gdd}|duretj|ddddf tj|jd d|j	|j
d|ddddf gdd}|durtj|dd |dd d |dd d gdd}d}|dur	t|ttfr| }
|
}nt|dkr|d d jd nd }
}|dur|dur|jd |jd kr|dd|jd |  df }n;||jd k r|dd|df }n)||jd kr|ddddf }n|dur	||jd k r	|dd|df }|du r=|dur| nd}|dur#|jd n|jd }|dur0|j	n|j	}tj||| |d}|du r|dur| jdur|durU|j	n|j	}|d	  }|| jk ri| }n5| jd
 du rz| }|| jd
< | jd | j }|dkr| jd dkr| jd
  | j7  < | jd
 }t|trtj|gg|tjd}ntj|gg|tjd}| | n<|dur| j dur|jd }|j	}tj||dd}n|durt|| j| j| jd}|du r|r| jdd |||||d}|du r|dur| j dus||d< |dur'|du r'||d< |S )z
        Prepare inputs for generation with proper frame-level position encoding.

        CRITICAL FIX: Maintains frame-level positions during generation with KV-cache.
        Nr   r   r   r   r2   r   r   )r   r   r   r   r   )r   )r   r+   r   r   r   r   r   )r   r   r   rt   r   r   r"   onesr   r   r   rX   r	   r   r   r{   rA   r   itemr   r   r   rD   rH   r!   r   r.   r   )rJ   r   r   r   r   r   r+   r   r   speaker_emb_projectedcache_lengthpast_lengthr   r   current_tokenposfirst_frame_postoken_in_framer%   model_inputsr,   r,   r-   prepare_inputs_for_generation  s   


&"








 
z1KaniTTS2ForCausalLM.prepare_inputs_for_generationc                    sL   | dd}d| _|| _zt j|i |}W d| _d| _|S d| _d| _w )a6  
        Override generate to reset state before generation.

        This ensures frame-level position tracking starts fresh for each generation call.
        Also handles speaker embeddings if provided.

        Args:
            speaker_emb: Optional speaker embedding [batch_size, speaker_emb_dim]
        r   N)popr   r   r6   generate)rJ   argsr   r   resultrL   r,   r-   r     s   zKaniTTS2ForCausalLM.generatepretrained_model_name_or_pathc	              
   O   s  dd |
  D }ddlm} |j|fi |}|du r*t|dd}|du r*td|du r4t|dd	}|du r>t|d
d}|du rHt|dd}|du rRt|dd}|du r\t|dd}|du rft|dd}| ||||||||d}|rddlm} ddlm	} ddl
}|j|r|j|d}n||dd}||}|j|dd\}}d|v rd|v r|jjj|j_dd |D }|rtdt|  t|dkr|D ]	}td |  q|rtd!t|  dd"lm} z
||}||_W n ty   | |_Y nw |d#d$}|d$krtj rd%nd&}||}n-tj|fi |}|jj|j dd |j|j  t|d'rE|j|_||j }td(|  |S ))a  
        Load a pretrained LFM2 model with KaniTTS-2 flash-compatible implementation.

        Args:
            pretrained_model_name_or_path: HuggingFace model ID or local path
            audio_tokens_start: Token ID where audio tokens begin (typically 64410).
                              If None, reads from model config.
            tokens_per_frame: Number of tokens per audio frame (default: 4).
                            If None, reads from model config.
            audio_step: Step size per audio frame (default: 1.0). Use 0.5 for new models.
                       If None, reads from model config.
            use_learnable_rope: Enable learnable RoPE with per-layer alpha (default: False).
                              If None, reads from model config.
            alpha_min: Minimum alpha value for learnable RoPE (default: 0.1).
                      If None, reads from model config.
            alpha_max: Maximum alpha value for learnable RoPE (default: 2.0).
                      If None, reads from model config.
            speaker_emb_dim: Dimension of speaker embeddings (default: 128).
                           If None, reads from model config.
        c                 S   s   i | ]\}}|d vr||qS ))rl   r:   r;   rm   r,   ).0kvr,   r,   r-   
<dictcomp>  s    z7KaniTTS2ForCausalLM.from_pretrained.<locals>.<dictcomp>r   )
AutoConfigNr   zaudio_tokens_start not provided and not found in model config. Please specify audio_tokens_start explicitly or add it to the model's config.jsonr   r   r   r   rl   Fr:   r0   r;   r1   rm   rk   )rK   r   r   r   rl   r:   r;   rm   )	load_file)hf_hub_downloadzmodel.safetensors)repo_idfilename)strictr   zmodel.embed_tokens.weightc                 S   s   g | ]}|d kr|qS )r   r,   )r   r   r,   r,   r-   
<listcomp>L  s    z7KaniTTS2ForCausalLM.from_pretrained.<locals>.<listcomp>u:      ⚠️  Missing keys (will use random initialization):    z      - u&      ⚠️  Unexpected keys (ignored): )GenerationConfig
device_mapautocudarS   r   u   ✅ Model loaded from )!itemstransformersr   from_pretrainedgetattrr   safetensors.torchr   huggingface_hubr   ospathisdirjoinload_state_dictr   r   weightr   r~   r{   r   r   	Exceptiongetr   r   is_availablerC   r   
state_dictru   r   )r   r   r   r   r   rl   r:   r;   rm   
model_argsr   base_kwargsr   rK   r   r   r   r   safetensors_pathr   missing_keysunexpected_keyskeyr   r   r   r   
base_modelr,   r,   r-   r     s   #



z#KaniTTS2ForCausalLM.from_pretrainedc                 C   s   | j jS Required by GenerationMixin.r   r   rO   r,   r,   r-   get_input_embeddingsy  s   z(KaniTTS2ForCausalLM.get_input_embeddingsc                 C   s   || j _dS r   Nr   )rJ   valuer,   r,   r-   set_input_embeddings}  s   z(KaniTTS2ForCausalLM.set_input_embeddingsc                 C      | j S r   r   rO   r,   r,   r-   get_output_embeddings     z)KaniTTS2ForCausalLM.get_output_embeddingsc                 C   
   || _ dS r   r   )rJ   new_embeddingsr,   r,   r-   set_output_embeddings     
z)KaniTTS2ForCausalLM.set_output_embeddingsc                 C   r  r   r   )rJ   decoderr,   r,   r-   set_decoder  r  zKaniTTS2ForCausalLM.set_decoderc                 C   r   r   r  rO   r,   r,   r-   get_decoder  r   zKaniTTS2ForCausalLM.get_decoderr   )N)
NNNNNNNNNr   )NNNNNT)NNNNNNN)$rd   re   rf   rg   _tied_weights_keysclassmethodr   r   r   r   rD   r7   r   r   r   r   r   r   r	   r   r   r   r   r   rc   r   r   rY   r   r   r   r   r  r  r  ri   r,   r,   rL   r-   r     s    		0	

4
 *	 r   )r   r   )"rg   r   torch.nnrF   typingr   r   r   transformers.modeling_outputsr   r   transformers.utilsr   transformers.processing_utilsr   transformers.cache_utilsr	   transformers.generation.utilsr
   r   r   r   r   r   +transformers.models.lfm2.configuration_lfm2r   r   r   rD   r.   Moduler/   rj   r   r,   r,   r,   r-   <module>   s8    
?^ I