"""
Veena3 Streaming Pipeline - Sliding Window Approach

Implements streaming for both SNAC and BiCodec tokens.

For BiCodec (Spark TTS):
1. Generate tokens from vLLM
2. Parse NEW token IDs incrementally (O(1) per token via cache)
3. Buffer semantic + global tokens
4. Apply sliding window (every N token pairs)
5. Decode and stream audio chunks

OPTIMIZATION (Dec 2025):
- Replaced O(n²) pattern (decode-all + regex-all per iteration)
- Now uses incremental token parsing with BiCodecTokenParser
- ~10x CPU reduction in streaming hot loop

For SNAC (legacy):
1. Filter SNAC token IDs directly
2. Apply sliding window
3. Decode and stream
"""

import asyncio
import re
import time
import uuid
from typing import AsyncGenerator, List, Optional

from vllm import SamplingParams

from veena3modal.audio.crossfade import crossfade_bytes_int16
from veena3modal.core.constants import (
    CODE_END_TOKEN_ID,
    CODE_START_TOKEN_ID,
    SNAC_MIN_ID,
    SNAC_MAX_ID,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_MIN_TOKENS,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_SEED,
)
from veena3modal.core.token_utils import BiCodecTokenParser


class Veena3SlidingWindowPipeline:
    """
    Streaming TTS pipeline using sliding window approach.
    
    This eliminates choppy audio and popping artifacts by:
    - Decoding overlapping 28-token windows (4 frames)
    - Keeping only the middle 2048 samples from each decode
    - Creating natural continuity between chunks
    
    Based on the official Canopy Labs implementation.
    """
    
    def __init__(
        self,
        model,
        prompt_builder,
        snac_decoder,
    ):
        """
        Initialize sliding window streaming pipeline.

        Args:
            model: Veena3Model instance (exposes ``.engine`` used for vLLM generation)
            prompt_builder: Veena3PromptBuilder instance
            snac_decoder: SNACDecoder instance (with batching enabled)
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder

        # Plain string: the previous f-string had no placeholders (ruff F541).
        print("🌊 Veena3SlidingWindowPipeline initialized (sliding window: 28 tokens)")
    
    async def generate_speech_stream(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech audio with sliding window streaming (SNAC path).

        Yields audio chunks using overlapping windows for smooth playback:
        once the buffer holds more than 27 SNAC tokens, the last 28 tokens
        (4 frames) are decoded every 7 new tokens and the decoder keeps only
        the middle samples, producing seamless chunk boundaries.

        Args:
            description: Character/voice description
            text: Text to synthesize (with optional <emotion> tags)
            temperature: Sampling temperature
            top_p: Nucleus sampling
            max_tokens: Max SNAC tokens to generate
            repetition_penalty: Prevent loops
            seed: Random seed (None = random, int = reproducible)

        Yields:
            Audio bytes (int16 PCM, 24kHz mono)
        """
        print("\n🌊 Sliding window streaming generation")
        print(f"📝 Description: {description[:80]}...")
        print(f"💬 Text: {text}")

        # Build prompt
        prompt = self.prompt_builder.build_prefix(description, text)
        print(f"✅ Prompt built ({len(prompt)} chars)")

        # Configure sampling
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],  # Only audio EOS
            seed=seed,  # None = random, int = reproducible
        )

        print(f"🎲 Sampling: temp={temperature}, top_p={top_p}, sliding_window=28 tokens")

        # Token buffer - keeps ALL SNAC tokens seen so far (not chunked)
        token_buffer = []
        total_tokens_generated = 0
        total_audio_chunks = 0

        print("🔮 Starting token generation...")

        # Unique request ID for concurrent streaming support.
        # `uuid`/`time` are module-level imports: no per-call import overhead.
        request_id = f"slide-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        # Stream tokens with sliding window
        async for request_output in results_generator:
            # vLLM returns the cumulative token list each iteration
            generated_ids = request_output.outputs[0].token_ids

            # Process only tokens we have not seen yet
            new_tokens = generated_ids[total_tokens_generated:]
            total_tokens_generated = len(generated_ids)

            # Filter and buffer SNAC tokens
            for token_id in new_tokens:
                if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
                    token_buffer.append(token_id)

                    # Decode every 7 tokens (1 frame) once we have enough for
                    # a full window (official approach: once count > 27).
                    if len(token_buffer) % 7 == 0 and len(token_buffer) > 27:
                        # Sliding window: take last 28 tokens (4 frames)
                        window_tokens = token_buffer[-28:]

                        # Decode with sliding window mode (returns middle 2048 samples only)
                        if self.snac_decoder.enable_batching:
                            audio_bytes = await self.snac_decoder.decode_single_async(
                                window_tokens,
                                trim_warmup=False,  # Sliding window handles trimming
                                use_sliding_window=True  # CRITICAL: Use sliding window mode
                            )
                        else:
                            audio_bytes = self.snac_decoder.decode_to_bytes(
                                window_tokens,
                                trim_warmup=False,
                                use_sliding_window=True
                            )

                        if audio_bytes:
                            total_audio_chunks += 1
                            if total_audio_chunks == 1:
                                print(f"🎵 First chunk decoded ({len(audio_bytes)} bytes, sliding window)")
                            yield audio_bytes

        # Note: No final chunk processing needed - sliding window handles all tokens
        # as they come in (every 7 tokens after the first 28)
        print(f"✅ Sliding window streaming complete: {total_tokens_generated} tokens → {total_audio_chunks} audio chunks")
    
    def _extract_bicodec_tokens_from_text(self, text: str) -> tuple[List[int], List[int]]:
        """
        Extract BiCodec semantic and global tokens from generated text.
        
        Args:
            text: Generated text containing BiCodec token markers
        
        Returns:
            Tuple of (semantic_ids, global_ids)
        """
        semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", text)
        semantic_ids = [int(t) for t in semantic_matches]
        
        global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", text)
        global_ids = [int(t) for t in global_matches]
        
        return semantic_ids, global_ids
    
    async def generate_speech_stream_indic(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,  # Added for Spark TTS parity with non-streaming
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech audio with incremental BiCodec streaming (Indic model).

        Two-phase BiCodec generation:
        1. Pre-roll: wait for the 32 global (speaker identity) tokens.
        2. Streaming: every DECODE_INTERVAL new semantic tokens, decode ALL
           semantic tokens accumulated so far, yield only the samples not yet
           emitted, and crossfade each chunk with the previous chunk's
           held-back tail to avoid audible seams.

        Args:
            speaker: Speaker name (one of 12 predefined speakers)
            text: Text to synthesize with inline emotion tags
                Examples:
                - "Hello! Welcome."
                - "<laugh> Hello there!"
                - "नमस्ते! <excited> आज का दिन बहुत अच्छा है।"
            temperature: Sampling temperature
            top_k: Top-k sampling (parity with non-streaming pipeline)
            top_p: Nucleus sampling
            max_tokens: Max BiCodec tokens to generate
            repetition_penalty: Prevent loops
            seed: Random seed for reproducibility

        Yields:
            Audio bytes (int16 PCM, 16kHz mono - BiCodec)
        """
        t_start = time.time()

        # Build prompt using Indic prompt builder
        prompt = self.prompt_builder.build_prefix(speaker, text)

        # Configure sampling (matching non-streaming pipeline parameters)
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],  # Audio EOS
            seed=seed,
        )

        # BiCodec token buffers (separate for semantic and global)
        semantic_buffer = []
        global_buffer = []
        processed_token_count = 0  # How many vLLM tokens we've parsed so far
        total_audio_chunks = 0

        # BiCodec streaming configuration
        EXPECTED_GLOBAL_COUNT = 32  # BiCodec uses 32 fixed global tokens (per paper)
        MIN_SEMANTIC_FOR_FIRST_CHUNK = 16  # Start streaming after 16 semantic tokens (~320ms at 50 TPS)
        DECODE_INTERVAL = 24  # Decode every 24 new semantic tokens (~480ms per chunk)
        CROSSFADE_MS = 50

        # Unique request ID for concurrent streaming support.
        # `time`/`uuid` are module-level imports: no per-call import overhead.
        request_id = f"bicodec-stream-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"

        # Get tokenizer from SparkTTSModel (falls back to engine tokenizer if unavailable)
        tokenizer = getattr(self.model, "tokenizer", None)
        if tokenizer is None:
            tokenizer = getattr(self.model.engine, "tokenizer", None)
        if tokenizer is None:
            raise RuntimeError("Spark TTS tokenizer unavailable on model or engine")

        # OPTIMIZATION: incremental token parser (caches token_id -> (type, value));
        # replaces the O(n²) decode-all + regex-all pattern
        token_parser = BiCodecTokenParser(tokenizer)

        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        # Timing diagnostics
        t_first_token = None
        t_first_chunk_ready = None
        t_first_chunk_decoded = None
        last_decode_count = 0

        # Track audio samples EMITTED to the user (not just decoded).
        # CRITICAL: counts samples that have been YIELDED, not held in tail.
        total_samples_emitted_to_user = 0
        # Tail held back to enable crossfading with the next emitted chunk
        previous_chunk_tail: Optional[bytes] = None

        # Stream tokens with BiCodec extraction
        async for request_output in results_generator:
            if t_first_token is None:
                t_first_token = time.time()

            # vLLM returns the cumulative token list each iteration
            generated_ids = request_output.outputs[0].token_ids

            # OPTIMIZATION: parse only NEW tokens (O(1) per token).
            # Old code did tokenizer.decode(ALL) + re.findall(ALL) = O(n²) per stream.
            new_token_ids = generated_ids[processed_token_count:]
            processed_token_count = len(generated_ids)

            # Parse new tokens and append to buffers (in-place modification)
            token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer)

            # Phase 1 (pre-roll): need all 32 global tokens before any decode
            if len(global_buffer) < EXPECTED_GLOBAL_COUNT:
                continue

            if t_first_chunk_ready is None and len(semantic_buffer) >= MIN_SEMANTIC_FOR_FIRST_CHUNK:
                t_first_chunk_ready = time.time()

            # Phase 2: stream semantic tokens incrementally; decode only when
            # enough NEW semantic tokens arrived since the last decode
            semantic_count = len(semantic_buffer)
            if semantic_count < MIN_SEMANTIC_FOR_FIRST_CHUNK:
                continue
            if semantic_count - last_decode_count < DECODE_INTERVAL:
                continue
            last_decode_count = semantic_count

            # Decode ALL semantic tokens accumulated so far.
            # CRITICAL: always use the SAME first 32 global tokens (they get pooled);
            # we track what was already sent and only yield NEW samples.
            all_semantic = semantic_buffer
            window_global = global_buffer[:EXPECTED_GLOBAL_COUNT]

            # Decode without sliding window (get all samples)
            if self.snac_decoder.enable_batching:
                audio_bytes = await self.snac_decoder.decode_single_async(
                    semantic_ids=all_semantic,
                    global_ids=window_global,  # 32 globals, will be broadcasted inside
                    trim_warmup=False,
                    use_sliding_window=False  # Get ALL samples
                )
            else:
                audio_bytes = self.snac_decoder.decode_streaming(
                    semantic_ids=all_semantic,
                    global_ids=window_global,  # 32 globals, will be broadcasted inside
                    use_sliding_window=False,  # Get ALL samples
                    trim_warmup=False
                )
            t_decode_end = time.time()

            if not audio_bytes:
                continue

            # audio_bytes is int16 PCM, so 2 bytes per sample
            total_samples_decoded = len(audio_bytes) // 2

            # Only yield NEW samples (avoid overlap/echo).
            # CRITICAL: offset from where we last EMITTED, not last DECODED.
            if total_samples_decoded <= total_samples_emitted_to_user:
                continue
            new_audio_bytes = audio_bytes[total_samples_emitted_to_user * 2:]

            # Crossfade with previous tail and hold back a new tail
            to_emit, previous_chunk_tail = crossfade_bytes_int16(
                previous_chunk_tail,
                new_audio_bytes,
                sample_rate_hz=16000,
                crossfade_ms=CROSSFADE_MS,
            )

            # CRITICAL: counter reflects what was actually YIELDED;
            # the tail is held back and not counted yet
            total_samples_emitted_to_user += len(to_emit) // 2

            if to_emit:
                total_audio_chunks += 1
                if t_first_chunk_decoded is None:
                    t_first_chunk_decoded = t_decode_end
                yield to_emit

        # Final chunk: decode any remaining semantic tokens with the same
        # 32 global tokens (they get pooled)
        if len(semantic_buffer) > last_decode_count and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
            # FIX: honor the decoder's batching mode here, as the in-loop path
            # does (previously this always called the synchronous decoder)
            if self.snac_decoder.enable_batching:
                audio_bytes = await self.snac_decoder.decode_single_async(
                    semantic_ids=semantic_buffer,  # All semantic tokens
                    global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],  # First 32 globals
                    trim_warmup=False,
                    use_sliding_window=False  # Get all samples
                )
            else:
                audio_bytes = self.snac_decoder.decode_streaming(
                    semantic_ids=semantic_buffer,  # All semantic tokens
                    global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],  # First 32 globals
                    use_sliding_window=False,  # Get all samples
                    trim_warmup=False
                )
            if audio_bytes:
                # Only yield NEW samples
                total_samples_decoded = len(audio_bytes) // 2
                if total_samples_decoded > total_samples_emitted_to_user:
                    new_audio_bytes = audio_bytes[total_samples_emitted_to_user * 2:]

                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        new_audio_bytes,
                        sample_rate_hz=16000,
                        crossfade_ms=CROSSFADE_MS,
                    )

                    total_samples_emitted_to_user += len(to_emit) // 2

                    if to_emit:
                        total_audio_chunks += 1
                        yield to_emit

        # Flush any remaining tail that was held back for crossfade
        if previous_chunk_tail:
            total_audio_chunks += 1
            total_samples_emitted_to_user += len(previous_chunk_tail) // 2
            yield previous_chunk_tail

        # Log completion summary
        audio_duration_s = total_samples_emitted_to_user / 16000  # 16kHz sample rate
        ttfb_ms = (t_first_chunk_decoded - t_start) * 1000 if t_first_chunk_decoded else 0
        print(f"🎵 Streaming complete: {audio_duration_s:.2f}s audio, TTFB: {ttfb_ms:.0f}ms, {total_audio_chunks} chunks")
    
    async def generate_speech_stream_indic_first_chunk(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[tuple, None]:
        """
        Generate speech for the FIRST chunk of multi-chunk text, capturing global tokens.

        This method is specifically for chunked generation: it yields
        (audio_bytes, global_ids) tuples instead of just audio_bytes. The caller
        captures global_ids from the first yield and passes them to
        generate_speech_stream_indic_continuation() for subsequent chunks.

        Use Case (chunked streaming with voice consistency):
            globals_captured = None
            async for audio_bytes, global_ids in pipeline.generate_speech_stream_indic_first_chunk(...):
                if globals_captured is None and global_ids:
                    globals_captured = global_ids  # Capture once
                yield audio_bytes
            # Now use globals_captured for subsequent chunks

        Args:
            speaker: Speaker name
            text: First text chunk to synthesize
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate
            repetition_penalty: Prevent repetition
            seed: Random seed for reproducibility

        Yields:
            Tuple of (audio_bytes, global_ids)
            - audio_bytes: Raw PCM audio (int16, 16kHz)
            - global_ids: List of 32 global token IDs (populated after pre-roll, else empty)

        Thread Safety:
            This method is stateless and thread-safe. Each call creates its own
            request-scoped state. No global state is modified or shared.
        """
        # Build prompt using Indic prompt builder
        prompt = self.prompt_builder.build_prefix(speaker, text)

        # Configure sampling
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],
            seed=seed,
        )

        # BiCodec token buffers
        semantic_buffer = []
        global_buffer = []
        processed_token_count = 0  # How many vLLM tokens we've parsed so far
        total_audio_chunks = 0

        # BiCodec streaming configuration
        EXPECTED_GLOBAL_COUNT = 32
        MIN_SEMANTIC_FOR_FIRST_CHUNK = 16
        DECODE_INTERVAL = 24
        CROSSFADE_MS = 50

        # Unique request ID (time/uuid are module-level imports)
        request_id = f"bicodec-first-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"

        # Get tokenizer (falls back to engine tokenizer if unavailable)
        tokenizer = getattr(self.model, "tokenizer", None)
        if tokenizer is None:
            tokenizer = getattr(self.model.engine, "tokenizer", None)
        if tokenizer is None:
            raise RuntimeError("Spark TTS tokenizer unavailable on model or engine")

        # OPTIMIZATION: incremental token parser (O(1) per new token)
        token_parser = BiCodecTokenParser(tokenizer)

        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        # Request-scoped emission state
        last_decode_count = 0
        total_samples_emitted_to_user = 0  # Samples YIELDED (tail excluded)
        previous_chunk_tail: Optional[bytes] = None  # Held back for crossfade
        captured_globals: List[int] = []  # Populated once we have 32 globals

        # Stream tokens with BiCodec extraction
        async for request_output in results_generator:
            generated_ids = request_output.outputs[0].token_ids

            # OPTIMIZATION: incremental parsing of only the new tokens
            new_token_ids = generated_ids[processed_token_count:]
            processed_token_count = len(generated_ids)
            token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer)

            # Capture global tokens once we have all 32
            if not captured_globals and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
                captured_globals = global_buffer[:EXPECTED_GLOBAL_COUNT]

            # Pre-roll gate: need all global tokens before decoding
            if len(global_buffer) < EXPECTED_GLOBAL_COUNT:
                continue

            semantic_count = len(semantic_buffer)
            if semantic_count < MIN_SEMANTIC_FOR_FIRST_CHUNK:
                continue
            if semantic_count - last_decode_count < DECODE_INTERVAL:
                continue
            last_decode_count = semantic_count

            # Decode ALL semantic tokens so far with the same 32 globals
            all_semantic = semantic_buffer
            window_global = global_buffer[:EXPECTED_GLOBAL_COUNT]

            if self.snac_decoder.enable_batching:
                audio_bytes = await self.snac_decoder.decode_single_async(
                    semantic_ids=all_semantic,
                    global_ids=window_global,
                    trim_warmup=False,
                    use_sliding_window=False
                )
            else:
                audio_bytes = self.snac_decoder.decode_streaming(
                    semantic_ids=all_semantic,
                    global_ids=window_global,
                    use_sliding_window=False,
                    trim_warmup=False
                )

            if not audio_bytes:
                continue

            # int16 PCM => 2 bytes per sample
            total_samples_decoded = len(audio_bytes) // 2

            # Only yield samples not yet emitted (avoid overlap/echo)
            if total_samples_decoded <= total_samples_emitted_to_user:
                continue
            new_audio_bytes = audio_bytes[total_samples_emitted_to_user * 2:]

            to_emit, previous_chunk_tail = crossfade_bytes_int16(
                previous_chunk_tail,
                new_audio_bytes,
                sample_rate_hz=16000,
                crossfade_ms=CROSSFADE_MS,
            )

            total_samples_emitted_to_user += len(to_emit) // 2

            if to_emit:
                total_audio_chunks += 1
                # Yield tuple: (audio_bytes, global_ids)
                yield (to_emit, captured_globals)

        # Final chunk processing
        if len(semantic_buffer) > last_decode_count and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
            # FIX: honor the decoder's batching mode here, as the in-loop path
            # does (previously this always called the synchronous decoder)
            if self.snac_decoder.enable_batching:
                audio_bytes = await self.snac_decoder.decode_single_async(
                    semantic_ids=semantic_buffer,
                    global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],
                    trim_warmup=False,
                    use_sliding_window=False
                )
            else:
                audio_bytes = self.snac_decoder.decode_streaming(
                    semantic_ids=semantic_buffer,
                    global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],
                    use_sliding_window=False,
                    trim_warmup=False
                )
            if audio_bytes:
                total_samples_decoded = len(audio_bytes) // 2

                if total_samples_decoded > total_samples_emitted_to_user:
                    new_audio_bytes = audio_bytes[total_samples_emitted_to_user * 2:]

                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        new_audio_bytes,
                        sample_rate_hz=16000,
                        crossfade_ms=CROSSFADE_MS,
                    )

                    total_samples_emitted_to_user += len(to_emit) // 2

                    if to_emit:
                        total_audio_chunks += 1
                        yield (to_emit, captured_globals)

        # Flush remaining tail held back for crossfade
        if previous_chunk_tail:
            total_audio_chunks += 1
            total_samples_emitted_to_user += len(previous_chunk_tail) // 2
            yield (previous_chunk_tail, captured_globals)

        audio_duration_s = total_samples_emitted_to_user / 16000
        print(f"🎵 First chunk complete: {audio_duration_s:.2f}s audio, {total_audio_chunks} chunks, captured {len(captured_globals)} globals")
    
    async def generate_speech_stream_indic_continuation(
        self,
        speaker: str,
        text: str,
        global_ids: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech for CONTINUATION chunks using pre-captured global tokens.
        
        This is the key method for voice consistency in chunked generation:
        - Uses global_ids from the first chunk to maintain identical voice
        - Model generates only semantic tokens (skips global token generation)
        - Ensures no voice drift across chunks
        
        CRITICAL for production:
        - global_ids is request-scoped, passed explicitly (no shared state)
        - Thread-safe: each request has its own state
        - No global caching that could cause cross-request contamination
        
        Args:
            speaker: Speaker name (must match first chunk)
            text: Continuation text chunk to synthesize
            global_ids: 32 global token IDs captured from first chunk
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate
            repetition_penalty: Prevent repetition
            seed: Random seed (use same as first chunk for consistency)
        
        Yields:
            Audio bytes (int16 PCM, 16kHz mono)
        
        Raises:
            ValueError: If global_ids doesn't contain exactly 32 tokens
        
        Thread Safety:
            Fully thread-safe. All state is request-scoped and passed explicitly.
        """
        import time
        t_start = time.time()
        
        # Validate global tokens
        EXPECTED_GLOBAL_COUNT = 32
        if len(global_ids) != EXPECTED_GLOBAL_COUNT:
            raise ValueError(
                f"Expected exactly {EXPECTED_GLOBAL_COUNT} global tokens, got {len(global_ids)}. "
                f"global_ids must be captured from first chunk generation."
            )
        
        # Build prompt WITH pre-filled global tokens
        # This tells the model to skip global generation and go straight to semantic tokens
        prompt = self.prompt_builder.build_prefix_with_globals(speaker, text, global_ids)
        
        # Configure sampling
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],
            seed=seed,
        )
        
        # BiCodec token buffers
        semantic_buffer = []
        global_buffer_unused = []  # Not used in continuation, but needed for parser API
        processed_token_count = 0  # Track how many vLLM tokens we've parsed
        total_audio_chunks = 0
        
        # BiCodec streaming configuration
        MIN_SEMANTIC_FOR_FIRST_CHUNK = 16
        DECODE_INTERVAL = 24
        CROSSFADE_MS = 50
        
        # Generate unique request ID
        import uuid
        request_id = f"bicodec-cont-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        
        # Get tokenizer
        tokenizer = getattr(self.model, "tokenizer", None)
        if tokenizer is None:
            tokenizer = getattr(self.model.engine, "tokenizer", None)
        if tokenizer is None:
            raise RuntimeError("Spark TTS tokenizer unavailable on model or engine")
        
        # OPTIMIZATION: Create incremental token parser
        token_parser = BiCodecTokenParser(tokenizer)
        
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )
        
        # Track state
        last_decode_count = 0
        total_samples_emitted_to_user = 0
        previous_chunk_tail: Optional[bytes] = None
        
        # Stream tokens - model will generate ONLY semantic tokens (globals are pre-filled)
        async for request_output in results_generator:
            generated_ids = request_output.outputs[0].token_ids
            
            # OPTIMIZATION: Incremental parsing (O(1) per new token)
            # Note: In continuation, model only generates semantic tokens (globals pre-filled in prompt)
            new_token_ids = generated_ids[processed_token_count:]
            processed_token_count = len(generated_ids)
            token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer_unused)
            
            semantic_count = len(semantic_buffer)
            if semantic_count >= MIN_SEMANTIC_FOR_FIRST_CHUNK:
                if semantic_count - last_decode_count >= DECODE_INTERVAL:
                    last_decode_count = semantic_count
                    
                    # Use pre-captured global tokens for decoding
                    if self.snac_decoder.enable_batching:
                        audio_bytes = await self.snac_decoder.decode_single_async(
                            semantic_ids=semantic_buffer,
                            global_ids=global_ids,  # Use captured globals
                            trim_warmup=False,
                            use_sliding_window=False
                        )
                    else:
                        audio_bytes = self.snac_decoder.decode_streaming(
                            semantic_ids=semantic_buffer,
                            global_ids=global_ids,  # Use captured globals
                            use_sliding_window=False,
                            trim_warmup=False
                        )
                    
                    if audio_bytes:
                        total_samples_decoded = len(audio_bytes) // 2
                        
                        if total_samples_decoded > total_samples_emitted_to_user:
                            new_bytes_start = total_samples_emitted_to_user * 2
                            new_audio_bytes = audio_bytes[new_bytes_start:]
                            
                            to_emit, previous_chunk_tail = crossfade_bytes_int16(
                                previous_chunk_tail,
                                new_audio_bytes,
                                sample_rate_hz=16000,
                                crossfade_ms=CROSSFADE_MS,
                            )
                            
                            samples_emitted_in_this_chunk = len(to_emit) // 2
                            total_samples_emitted_to_user += samples_emitted_in_this_chunk
                            
                            if to_emit:
                                total_audio_chunks += 1
                                yield to_emit
        
        # Final chunk processing
        if len(semantic_buffer) > last_decode_count:
            audio_bytes = self.snac_decoder.decode_streaming(
                semantic_ids=semantic_buffer,
                global_ids=global_ids,  # Use captured globals
                use_sliding_window=False,
                trim_warmup=False
            )
            if audio_bytes:
                total_samples_decoded = len(audio_bytes) // 2
                
                if total_samples_decoded > total_samples_emitted_to_user:
                    new_bytes_start = total_samples_emitted_to_user * 2
                    new_audio_bytes = audio_bytes[new_bytes_start:]
                    
                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        new_audio_bytes,
                        sample_rate_hz=16000,
                        crossfade_ms=CROSSFADE_MS,
                    )
                    
                    samples_emitted_in_final = len(to_emit) // 2
                    total_samples_emitted_to_user += samples_emitted_in_final
                    
                    if to_emit:
                        total_audio_chunks += 1
                        yield to_emit
        
        # Flush remaining tail
        if previous_chunk_tail:
            total_audio_chunks += 1
            tail_samples = len(previous_chunk_tail) // 2
            total_samples_emitted_to_user += tail_samples
            yield previous_chunk_tail
        
        t_end = time.time()
        audio_duration_s = total_samples_emitted_to_user / 16000
        print(f"🎵 Continuation chunk complete: {audio_duration_s:.2f}s audio, {total_audio_chunks} chunks (using cached globals)")

