"""
Veena3 Streaming Pipeline - Sliding Window Approach

Implements streaming for both SNAC and BiCodec tokens.

For BiCodec (Spark TTS):
1. Generate tokens from vLLM
2. Parse NEW token IDs incrementally (O(1) per token via cache)
3. Buffer semantic + global tokens
4. Apply sliding window (every N token pairs)
5. Decode and stream audio chunks

OPTIMIZATION (Dec 2025):
- Replaced O(n²) pattern (decode-all + regex-all per iteration)
- Now uses incremental token parsing with BiCodecTokenParser
- ~10x CPU reduction in streaming hot loop

For SNAC (legacy):
1. Filter SNAC token IDs directly
2. Apply sliding window
3. Decode and stream
"""

import asyncio
import re
from typing import AsyncGenerator, Optional, List
from vllm import SamplingParams
from veena3modal.audio.crossfade import crossfade_bytes_int16
from veena3modal.core.token_utils import BiCodecTokenParser

from veena3modal.core.constants import (
    CODE_END_TOKEN_ID,
    CODE_START_TOKEN_ID,
    SNAC_MIN_ID,
    SNAC_MAX_ID,
    TRAINING_STOP_TOKEN_IDS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_MIN_TOKENS,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_SEED,
)


class Veena3SlidingWindowPipeline:
    """
    Streaming TTS pipeline using sliding window approach.
    
    This eliminates choppy audio and popping artifacts by:
    - Decoding overlapping 28-token windows (4 frames)
    - Keeping only the middle 2048 samples from each decode
    - Creating natural continuity between chunks
    
    Based on the official Canopy Labs implementation.
    """
    
    def __init__(
        self,
        model,
        prompt_builder,
        snac_decoder,
    ):
        """
        Initialize sliding window streaming pipeline.
        
        Args:
            model: Veena3Model instance
            prompt_builder: Veena3PromptBuilder instance
            snac_decoder: SNACDecoder instance (with batching enabled)
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder
        
        # OPTIMIZATION: Pre-warm BiCodecTokenParser once at init, not per-request
        # Saves ~123ms per streaming TTFB by avoiding repeated 166K-entry vocab iteration
        tokenizer = getattr(model, "tokenizer", None)
        if tokenizer is None:
            tokenizer = getattr(model.engine, "tokenizer", None)
        self.token_parser = BiCodecTokenParser(tokenizer) if tokenizer else None
        
        print(f"🌊 Veena3SlidingWindowPipeline initialized (sliding window: 28 tokens)")
    
    async def generate_speech_stream(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Stream SNAC audio by decoding a sliding window over generated tokens.

        Every generated token ID in ``[SNAC_MIN_ID, SNAC_MAX_ID]`` is buffered.
        Each time the buffer length reaches a multiple of 7 (one frame) and
        holds at least 28 tokens, the last 28 tokens are decoded in the
        decoder's sliding-window mode and the resulting bytes are yielded.
        Tokens left over after the last full frame are not decoded.

        Args:
            description: Character/voice description
            text: Text to synthesize (with optional <emotion> tags)
            temperature: Sampling temperature
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate (passed to SamplingParams)
            repetition_penalty: Prevent loops
            seed: Random seed; None = random, int = reproducible

        Yields:
            Audio bytes (int16 PCM, 24kHz mono)
        """
        import time
        import uuid

        FRAME_TOKENS = 7  # one frame
        WINDOW_TOKENS = 28  # four frames per decode window

        print(f"\n🌊 Sliding window streaming generation")
        print(f"📝 Description: {description[:80]}...")
        print(f"💬 Text: {text}")

        prompt = self.prompt_builder.build_prefix(description, text)
        print(f"✅ Prompt built ({len(prompt)} chars)")

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],  # Only audio EOS
            seed=seed,
        )
        print(f"🎲 Sampling: temp={temperature}, top_p={top_p}, sliding_window=28 tokens")

        # The buffer keeps every SNAC token seen so far; windows are sliced
        # off its end.
        snac_tokens = []
        seen_token_count = 0
        chunks_yielded = 0

        print(f"🔮 Starting token generation...")

        # Unique request ID so concurrent streams do not collide in the engine.
        request_id = f"slide-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        token_stream = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        async for request_output in token_stream:
            # token_ids is read as the cumulative list; only the unseen suffix
            # is processed on each iteration.
            all_ids = request_output.outputs[0].token_ids
            fresh_ids = all_ids[seen_token_count:]
            seen_token_count = len(all_ids)

            for token_id in fresh_ids:
                if token_id < SNAC_MIN_ID or token_id > SNAC_MAX_ID:
                    continue
                snac_tokens.append(token_id)

                buffered = len(snac_tokens)
                if buffered % FRAME_TOKENS != 0 or buffered < WINDOW_TOKENS:
                    continue

                # Sliding window: most recent 4 frames.
                window_tokens = snac_tokens[-WINDOW_TOKENS:]

                if self.snac_decoder.enable_batching:
                    audio_bytes = await self.snac_decoder.decode_single_async(
                        window_tokens,
                        trim_warmup=False,  # sliding-window mode handles trimming
                        use_sliding_window=True,
                    )
                else:
                    audio_bytes = self.snac_decoder.decode_to_bytes(
                        window_tokens,
                        trim_warmup=False,
                        use_sliding_window=True,
                    )

                if not audio_bytes:
                    continue

                chunks_yielded += 1
                if chunks_yielded == 1:
                    print(f"🎵 First chunk decoded ({len(audio_bytes)} bytes, sliding window)")
                yield audio_bytes

        # No final-chunk pass: decoding only happens on frame boundaries
        # inside the loop above.
        print(f"✅ Sliding window streaming complete: {seen_token_count} tokens → {chunks_yielded} audio chunks")
    
    def _extract_bicodec_tokens_from_text(self, text: str) -> tuple[List[int], List[int]]:
        """
        Extract BiCodec semantic and global tokens from generated text.
        
        Args:
            text: Generated text containing BiCodec token markers
        
        Returns:
            Tuple of (semantic_ids, global_ids)
        """
        semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", text)
        semantic_ids = [int(t) for t in semantic_matches]
        
        global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", text)
        global_ids = [int(t) for t in global_matches]
        
        return semantic_ids, global_ids
    
    async def generate_speech_stream_indic(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,  # Added for Spark TTS parity with non-streaming
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Stream BiCodec audio for the Indic model.

        Generated token IDs are parsed incrementally into a semantic buffer
        and a global buffer. Once 32 global tokens are buffered, audio is
        decoded every ``DECODE_INTERVAL`` new semantic tokens. While at most
        ``WINDOW_SIZE`` semantic tokens exist the whole buffer is decoded;
        afterwards only the last ``WINDOW_SIZE`` tokens are. Each decode's new
        samples are passed through ``crossfade_bytes_int16`` together with the
        tail held back from the previous chunk. After generation ends, any
        semantic tokens not yet covered are decoded in one full pass and the
        held-back tail is flushed.

        Args:
            speaker: Speaker name (one of 12 predefined speakers)
            text: Text to synthesize with inline emotion tags
                Examples:
                - "Hello! Welcome."
                - "<laugh> Hello there!"
                - "नमस्ते! <excited> आज का दिन बहुत अच्छा है।"
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate (passed to SamplingParams)
            repetition_penalty: Prevent loops
            seed: Random seed for reproducibility

        Yields:
            Audio bytes (int16 PCM, 16kHz mono - BiCodec)

        Raises:
            RuntimeError: If the pipeline was constructed without a tokenizer,
                so no BiCodec token parser is available.
        """
        import time
        import uuid

        t_start = time.time()

        # Fail before a request is submitted to the engine instead of hitting
        # an AttributeError on None in the middle of the stream.
        token_parser = self.token_parser
        if token_parser is None:
            raise RuntimeError(
                "BiCodec streaming requires a token parser, but no tokenizer "
                "was found on the model when the pipeline was initialized."
            )

        # Build prompt using Indic prompt builder
        prompt = self.prompt_builder.build_prefix(speaker, text)

        # Configure sampling (matching non-streaming pipeline parameters)
        # OPTIMIZATION: Use Spark TTS stop token, not legacy SNAC CODE_END_TOKEN_ID (128258)
        # which doesn't exist in the Spark vocab and would cause generation to run to max_tokens
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop=TRAINING_STOP_TOKEN_IDS,  # "<|im_end|>" - matches non-streaming pipeline
            skip_special_tokens=False,  # Keep BiCodec tokens in output
            seed=seed,
        )

        # BiCodec streaming configuration
        EXPECTED_GLOBAL_COUNT = 32  # BiCodec uses 32 fixed global tokens (per paper)
        # NOTE(review): last_decode_count starts at 0, so the very first decode
        # still waits for DECODE_INTERVAL semantic tokens; this lower bound
        # never triggers a decode on its own. Confirm whether an earlier first
        # chunk was intended.
        MIN_SEMANTIC_FOR_FIRST_CHUNK = 10
        # OPTIMIZATION: Increased from 24 to 48 to halve the number of decode calls.
        DECODE_INTERVAL = 48
        CROSSFADE_MS = 50
        # OPTIMIZATION: Windowed decode - only the last WINDOW_SIZE tokens are
        # decoded once the buffer outgrows it. Original note: the window must be
        # >= decoder receptive field (~64 tokens) + DECODE_INTERVAL.
        WINDOW_SIZE = 128
        # BiCodec produces ~320 samples per semantic token at 16kHz.
        SAMPLES_PER_TOKEN = 320

        # Request-scoped buffers (separate for semantic and global tokens)
        semantic_buffer = []
        global_buffer = []
        processed_token_count = 0  # how many vLLM tokens have been parsed
        total_audio_chunks = 0
        last_decode_count = 0  # semantic tokens covered by the latest decode
        t_first_chunk_decoded = None

        # Samples actually YIELDED to the caller. Samples still held in
        # previous_chunk_tail are not counted until the final flush.
        total_samples_emitted_to_user = 0
        # Tail kept back for crossfading with the next chunk.
        previous_chunk_tail: Optional[bytes] = None

        request_id = f"bicodec-stream-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        async for request_output in results_generator:
            generated_ids = request_output.outputs[0].token_ids

            # Incremental parsing: only the unseen suffix of the cumulative
            # token list is handed to the parser, which appends in place.
            new_token_ids = generated_ids[processed_token_count:]
            processed_token_count = len(generated_ids)
            token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer)

            # Phase 1 (pre-roll): wait until all global tokens are buffered.
            if len(global_buffer) < EXPECTED_GLOBAL_COUNT:
                continue

            # Phase 2: decode once enough NEW semantic tokens have arrived.
            semantic_count = len(semantic_buffer)
            if semantic_count < MIN_SEMANTIC_FOR_FIRST_CHUNK:
                continue
            new_token_count = semantic_count - last_decode_count
            if new_token_count < DECODE_INTERVAL:
                continue
            last_decode_count = semantic_count

            window_global = global_buffer[:EXPECTED_GLOBAL_COUNT]
            windowed = semantic_count > WINDOW_SIZE
            if windowed:
                decode_semantic = semantic_buffer[-WINDOW_SIZE:]
            else:
                # Early phase: the buffer still fits the window, decode it all.
                decode_semantic = semantic_buffer

            audio_bytes = await self.snac_decoder.decode_single_async(
                semantic_ids=decode_semantic,
                global_ids=window_global,
                trim_warmup=False,
                use_sliding_window=False,
            )
            t_decode_end = time.time()

            if not audio_bytes:
                continue

            total_samples_decoded = len(audio_bytes) // 2

            if windowed:
                # FIX: the slice must start at the emitted position, exactly as
                # the early phase and the final flush do. That means covering
                # (a) every token received since the previous decode - one
                # engine output can carry several tokens, so this may exceed
                # DECODE_INTERVAL - and (b) the region of the tail that is
                # still held back for crossfading. The previous fixed
                # DECODE_INTERVAL * 320 slice skipped audio at each seam and
                # left the emitted counter behind, so the final chunk replayed
                # already-played audio.
                # Assumes crossfade_bytes_int16 withholds the returned tail
                # from its emitted output - TODO confirm against the helper.
                held_tail_samples = len(previous_chunk_tail) // 2 if previous_chunk_tail else 0
                wanted_samples = new_token_count * SAMPLES_PER_TOKEN + held_tail_samples
                start_sample = max(total_samples_decoded - wanted_samples, 0)
            else:
                # Full decode: everything past the emitted position is new.
                start_sample = total_samples_emitted_to_user

            # An out-of-range start yields b"" here, which is skipped below.
            new_audio_bytes = audio_bytes[start_sample * 2:]
            if not new_audio_bytes:
                continue

            to_emit, previous_chunk_tail = crossfade_bytes_int16(
                previous_chunk_tail,
                new_audio_bytes,
                sample_rate_hz=16000,
                crossfade_ms=CROSSFADE_MS,
            )
            total_samples_emitted_to_user += len(to_emit) // 2

            if to_emit:
                total_audio_chunks += 1
                if t_first_chunk_decoded is None:
                    t_first_chunk_decoded = t_decode_end
                yield to_emit

        # Final chunk: decode any semantic tokens that arrived after the last
        # in-loop decode. This is a full decode, so new samples are located by
        # the emitted position. The same first 32 global tokens are used.
        if len(semantic_buffer) > last_decode_count and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
            audio_bytes = self.snac_decoder.decode_streaming(
                semantic_ids=semantic_buffer,
                global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],
                use_sliding_window=False,
                trim_warmup=False
            )
            if audio_bytes:
                total_samples_decoded = len(audio_bytes) // 2

                if total_samples_decoded > total_samples_emitted_to_user:
                    new_bytes_start = total_samples_emitted_to_user * 2
                    new_audio_bytes = audio_bytes[new_bytes_start:]

                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        new_audio_bytes,
                        sample_rate_hz=16000,
                        crossfade_ms=CROSSFADE_MS,
                    )
                    total_samples_emitted_to_user += len(to_emit) // 2

                    if to_emit:
                        total_audio_chunks += 1
                        yield to_emit

        # Flush any remaining tail that was held back for crossfade
        if previous_chunk_tail:
            total_audio_chunks += 1
            total_samples_emitted_to_user += len(previous_chunk_tail) // 2
            yield previous_chunk_tail

        # Log completion summary; TTFB is measured from method entry to the
        # end of the first decode that produced emitted audio (0 if none did).
        audio_duration_s = total_samples_emitted_to_user / 16000  # 16kHz sample rate
        ttfb_ms = (t_first_chunk_decoded - t_start) * 1000 if t_first_chunk_decoded else 0
        print(f"🎵 Streaming complete: {audio_duration_s:.2f}s audio, TTFB: {ttfb_ms:.0f}ms, {total_audio_chunks} chunks")
    
    async def generate_speech_stream_indic_first_chunk(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[tuple, None]:
        """
        Stream the FIRST chunk of multi-chunk text and expose its global tokens.

        Works like the plain Indic stream but yields ``(audio_bytes, global_ids)``
        tuples, so the caller can capture the global tokens and hand them to
        generate_speech_stream_indic_continuation() for the following chunks.
        Every decode here re-decodes the whole semantic buffer and emits only
        the samples past the emitted position.

        Use Case (chunked streaming with voice consistency):
            globals_captured = None
            async for audio_bytes, global_ids in pipeline.generate_speech_stream_indic_first_chunk(...):
                if globals_captured is None and global_ids:
                    globals_captured = global_ids  # Capture once
                yield audio_bytes
            # Now use globals_captured for subsequent chunks

        Args:
            speaker: Speaker name
            text: First text chunk to synthesize
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate
            repetition_penalty: Prevent repetition
            seed: Random seed for reproducibility

        Yields:
            Tuple of (audio_bytes, global_ids)
            - audio_bytes: Raw PCM audio (int16, 16kHz)
            - global_ids: The first 32 parsed global token IDs. Audio is only
              produced once all 32 are buffered, so every yielded tuple
              carries the full list.

        Thread Safety:
            All buffers and counters are local to the call. The method does
            use the shared ``self.token_parser``, ``self.model.engine`` and
            ``self.snac_decoder``.
            NOTE(review): their thread-safety is not visible from here - confirm.
        """
        import time
        t_start = time.time()

        # BiCodec streaming configuration
        EXPECTED_GLOBAL_COUNT = 32
        MIN_SEMANTIC_FOR_FIRST_CHUNK = 10  # OPTIMIZATION: Lowered from 16 (decoder min is 8)
        DECODE_INTERVAL = 48  # OPTIMIZATION: Increased from 24 to reduce O(n) re-decode calls
        CROSSFADE_MS = 50

        # Build prompt using Indic prompt builder
        prompt = self.prompt_builder.build_prefix(speaker, text)

        # OPTIMIZATION: Use Spark TTS stop token, not legacy SNAC CODE_END_TOKEN_ID
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop=TRAINING_STOP_TOKEN_IDS,  # "<|im_end|>" - matches non-streaming pipeline
            skip_special_tokens=False,
            seed=seed,
        )

        # Request-scoped state
        semantic_buffer = []
        global_buffer = []
        parsed_upto = 0  # number of vLLM tokens already handed to the parser
        chunks_out = 0
        decoded_upto = 0  # semantic tokens covered by the latest decode
        emitted_samples = 0  # samples yielded so far (held tail excluded)
        held_tail: Optional[bytes] = None
        captured_globals: List[int] = []  # filled once 32 globals are parsed

        import uuid
        request_id = f"bicodec-first-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"

        # Reuse the parser created in __init__
        token_parser = self.token_parser

        token_stream = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        async for request_output in token_stream:
            cumulative_ids = request_output.outputs[0].token_ids

            # Incremental parsing: only the unseen suffix is parsed.
            fresh_ids = cumulative_ids[parsed_upto:]
            parsed_upto = len(cumulative_ids)
            token_parser.parse_incremental(fresh_ids, semantic_buffer, global_buffer)

            # Nothing can be decoded until every global token is present.
            if len(global_buffer) < EXPECTED_GLOBAL_COUNT:
                continue
            if not captured_globals:
                captured_globals = global_buffer[:EXPECTED_GLOBAL_COUNT]

            semantic_count = len(semantic_buffer)
            if semantic_count < MIN_SEMANTIC_FOR_FIRST_CHUNK:
                continue
            if semantic_count - decoded_upto < DECODE_INTERVAL:
                continue
            decoded_upto = semantic_count

            globals_for_decode = global_buffer[:EXPECTED_GLOBAL_COUNT]

            # Full re-decode of everything buffered so far.
            if self.snac_decoder.enable_batching:
                audio_bytes = await self.snac_decoder.decode_single_async(
                    semantic_ids=semantic_buffer,
                    global_ids=globals_for_decode,
                    trim_warmup=False,
                    use_sliding_window=False
                )
            else:
                audio_bytes = self.snac_decoder.decode_streaming(
                    semantic_ids=semantic_buffer,
                    global_ids=globals_for_decode,
                    use_sliding_window=False,
                    trim_warmup=False
                )

            if not audio_bytes:
                continue

            decoded_samples = len(audio_bytes) // 2
            if decoded_samples <= emitted_samples:
                continue

            # Only the samples past the emitted position are new.
            fresh_audio = audio_bytes[emitted_samples * 2:]
            to_emit, held_tail = crossfade_bytes_int16(
                held_tail,
                fresh_audio,
                sample_rate_hz=16000,
                crossfade_ms=CROSSFADE_MS,
            )
            emitted_samples += len(to_emit) // 2

            if to_emit:
                chunks_out += 1
                yield (to_emit, captured_globals)

        # Final chunk: decode tokens that arrived after the last in-loop decode.
        has_undecoded = len(semantic_buffer) > decoded_upto
        if has_undecoded and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
            audio_bytes = self.snac_decoder.decode_streaming(
                semantic_ids=semantic_buffer,
                global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],
                use_sliding_window=False,
                trim_warmup=False
            )
            if audio_bytes and len(audio_bytes) // 2 > emitted_samples:
                fresh_audio = audio_bytes[emitted_samples * 2:]

                to_emit, held_tail = crossfade_bytes_int16(
                    held_tail,
                    fresh_audio,
                    sample_rate_hz=16000,
                    crossfade_ms=CROSSFADE_MS,
                )
                emitted_samples += len(to_emit) // 2

                if to_emit:
                    chunks_out += 1
                    yield (to_emit, captured_globals)

        # Flush the tail that was held back for crossfading.
        if held_tail:
            chunks_out += 1
            emitted_samples += len(held_tail) // 2
            yield (held_tail, captured_globals)

        t_end = time.time()
        audio_duration_s = emitted_samples / 16000
        print(f"🎵 First chunk complete: {audio_duration_s:.2f}s audio, {chunks_out} chunks, captured {len(captured_globals)} globals")
    
    async def generate_speech_stream_indic_continuation(
        self,
        speaker: str,
        text: str,
        global_ids: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech for CONTINUATION chunks using pre-captured global tokens.
        
        This is the key method for voice consistency in chunked generation:
        - Uses global_ids from the first chunk to maintain identical voice
        - Model generates only semantic tokens (skips global token generation)
        - Ensures no voice drift across chunks
        
        CRITICAL for production:
        - global_ids is request-scoped, passed explicitly (no shared state)
        - Thread-safe: each request has its own state
        - No global caching that could cause cross-request contamination
        
        Args:
            speaker: Speaker name (must match first chunk)
            text: Continuation text chunk to synthesize
            global_ids: 32 global token IDs captured from first chunk
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate
            repetition_penalty: Prevent repetition
            seed: Random seed (use same as first chunk for consistency)
        
        Yields:
            Audio bytes (int16 PCM, 16kHz mono)
        
        Raises:
            ValueError: If global_ids doesn't contain exactly 32 tokens
        
        Thread Safety:
            Fully thread-safe. All state is request-scoped and passed explicitly.
        """
        import time
        t_start = time.time()
        
        # Validate global tokens
        EXPECTED_GLOBAL_COUNT = 32
        if len(global_ids) != EXPECTED_GLOBAL_COUNT:
            raise ValueError(
                f"Expected exactly {EXPECTED_GLOBAL_COUNT} global tokens, got {len(global_ids)}. "
                f"global_ids must be captured from first chunk generation."
            )
        
        # Build prompt WITH pre-filled global tokens
        # This tells the model to skip global generation and go straight to semantic tokens
        prompt = self.prompt_builder.build_prefix_with_globals(speaker, text, global_ids)
        
        # Configure sampling
        # OPTIMIZATION: Use Spark TTS stop token, not legacy SNAC CODE_END_TOKEN_ID
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop=TRAINING_STOP_TOKEN_IDS,  # "<|im_end|>" - matches non-streaming pipeline
            skip_special_tokens=False,
            seed=seed,
        )
        
        # BiCodec token buffers
        semantic_buffer = []
        global_buffer_unused = []  # Not used in continuation, but needed for parser API
        processed_token_count = 0  # Track how many vLLM tokens we've parsed
        total_audio_chunks = 0
        
        # BiCodec streaming configuration
        MIN_SEMANTIC_FOR_FIRST_CHUNK = 10  # OPTIMIZATION: Lowered from 16 (decoder min is 8)
        DECODE_INTERVAL = 48  # OPTIMIZATION: Increased from 24 to reduce O(n) re-decode calls
        CROSSFADE_MS = 50
        
        # Generate unique request ID
        import uuid
        request_id = f"bicodec-cont-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        
        # Reuse pre-warmed token parser (singleton, created once in __init__)
        token_parser = self.token_parser
        
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )
        
        # Track state
        last_decode_count = 0
        total_samples_emitted_to_user = 0
        previous_chunk_tail: Optional[bytes] = None
        
        # Stream tokens - model will generate ONLY semantic tokens (globals are pre-filled)
        async for request_output in results_generator:
            generated_ids = request_output.outputs[0].token_ids
            
            # OPTIMIZATION: Incremental parsing (O(1) per new token)
            # Note: In continuation, model only generates semantic tokens (globals pre-filled in prompt)
            new_token_ids = generated_ids[processed_token_count:]
            processed_token_count = len(generated_ids)
            token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer_unused)
            
            semantic_count = len(semantic_buffer)
            if semantic_count >= MIN_SEMANTIC_FOR_FIRST_CHUNK:
                if semantic_count - last_decode_count >= DECODE_INTERVAL:
                    last_decode_count = semantic_count
                    
                    # Use pre-captured global tokens for decoding
                    if self.snac_decoder.enable_batching:
                        audio_bytes = await self.snac_decoder.decode_single_async(
                            semantic_ids=semantic_buffer,
                            global_ids=global_ids,  # Use captured globals
                            trim_warmup=False,
                            use_sliding_window=False
                        )
                    else:
                        audio_bytes = self.snac_decoder.decode_streaming(
                            semantic_ids=semantic_buffer,
                            global_ids=global_ids,  # Use captured globals
                            use_sliding_window=False,
                            trim_warmup=False
                        )
                    
                    if audio_bytes:
                        total_samples_decoded = len(audio_bytes) // 2
                        
                        if total_samples_decoded > total_samples_emitted_to_user:
                            new_bytes_start = total_samples_emitted_to_user * 2
                            new_audio_bytes = audio_bytes[new_bytes_start:]
                            
                            to_emit, previous_chunk_tail = crossfade_bytes_int16(
                                previous_chunk_tail,
                                new_audio_bytes,
                                sample_rate_hz=16000,
                                crossfade_ms=CROSSFADE_MS,
                            )
                            
                            samples_emitted_in_this_chunk = len(to_emit) // 2
                            total_samples_emitted_to_user += samples_emitted_in_this_chunk
                            
                            if to_emit:
                                total_audio_chunks += 1
                                yield to_emit
        
        # Final chunk processing
        if len(semantic_buffer) > last_decode_count:
            audio_bytes = self.snac_decoder.decode_streaming(
                semantic_ids=semantic_buffer,
                global_ids=global_ids,  # Use captured globals
                use_sliding_window=False,
                trim_warmup=False
            )
            if audio_bytes:
                total_samples_decoded = len(audio_bytes) // 2
                
                if total_samples_decoded > total_samples_emitted_to_user:
                    new_bytes_start = total_samples_emitted_to_user * 2
                    new_audio_bytes = audio_bytes[new_bytes_start:]
                    
                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        new_audio_bytes,
                        sample_rate_hz=16000,
                        crossfade_ms=CROSSFADE_MS,
                    )
                    
                    samples_emitted_in_final = len(to_emit) // 2
                    total_samples_emitted_to_user += samples_emitted_in_final
                    
                    if to_emit:
                        total_audio_chunks += 1
                        yield to_emit
        
        # Flush remaining tail
        if previous_chunk_tail:
            total_audio_chunks += 1
            tail_samples = len(previous_chunk_tail) // 2
            total_samples_emitted_to_user += tail_samples
            yield previous_chunk_tail
        
        t_end = time.time()
        audio_duration_s = total_samples_emitted_to_user / 16000
        print(f"🎵 Continuation chunk complete: {audio_duration_s:.2f}s audio, {total_audio_chunks} chunks (using cached globals)")

