"""
Veena3 Streaming Pipeline - Sliding Window Approach

Implements streaming for both SNAC and BiCodec tokens.

For BiCodec (Spark TTS):
1. Generate tokens from vLLM
2. Parse NEW token IDs incrementally (O(1) per token via cache)
3. Buffer semantic + global tokens
4. Apply sliding window (every N token pairs)
5. Decode and stream audio chunks

OPTIMIZATION (Dec 2025):
- Replaced O(n²) pattern (decode-all + regex-all per iteration)
- Now uses incremental token parsing with BiCodecTokenParser
- ~10x CPU reduction in streaming hot loop

For SNAC (legacy):
1. Filter SNAC token IDs directly
2. Apply sliding window
3. Decode and stream
"""

import asyncio
import logging
import os
import re
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind
from veena3modal.audio.crossfade import crossfade_bytes_int16
from veena3modal.core.token_utils import BiCodecTokenParser

from veena3modal.core.constants import (
    CODE_END_TOKEN_ID,
    CODE_START_TOKEN_ID,
    SNAC_MIN_ID,
    SNAC_MAX_ID,
    TRAINING_STOP_TOKEN_IDS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_MIN_TOKENS,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_SEED,
)

logger = logging.getLogger(__name__)
# Output mode requested from vLLM: "delta" yields only newly generated tokens per
# iteration, anything else falls back to cumulative (full sequence re-sent each time).
_STREAM_OUTPUT_KIND_RAW = os.environ.get("VEENA3_STREAM_OUTPUT_KIND", "delta").strip().lower()
# Any of these spellings selects DELTA output; see RequestOutputKind usage below.
_STREAM_USE_DELTA_OUTPUT = _STREAM_OUTPUT_KIND_RAW in {"delta", "d", "1", "true", "yes", "on"}


def _env_int(name: str, default: int, minimum: int = 0) -> int:
    raw = os.environ.get(name, "")
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        return default
    return max(minimum, value)


def _env_flag(name: str, default: bool) -> bool:
    raw = os.environ.get(name, "")
    if not raw:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# --- Streaming tuning knobs (env-overridable) --------------------------------
# BiCodec emits a fixed block of 32 global (speaker) tokens; decoding waits for
# all of them before any audio can be produced.
_STREAM_EXPECTED_GLOBAL_COUNT = 32
# Minimum semantic tokens buffered before the first chunk may be decoded.
_STREAM_MIN_SEMANTIC_FIRST_CHUNK = _env_int("VEENA3_STREAM_MIN_SEMANTIC_FIRST", 10, minimum=1)
# Steady-state decode cadence: decode once every N new semantic tokens.
_STREAM_DECODE_INTERVAL = _env_int("VEENA3_STREAM_DECODE_INTERVAL", 48, minimum=1)
# Crossfade length between emitted chunks (0 disables crossfading).
_STREAM_CROSSFADE_MS = _env_int("VEENA3_STREAM_CROSSFADE_MS", 50, minimum=0)
# With windowed decode enabled, only the trailing N semantic tokens are re-decoded.
_STREAM_WINDOW_SIZE = _env_int("VEENA3_STREAM_WINDOW_SIZE", 128, minimum=16)
# Adaptive decode: vary the decode interval based on TTFB state and decoder backlog.
_STREAM_ADAPTIVE_DECODE = _env_flag("VEENA3_STREAM_ADAPTIVE_DECODE", False)
# Interval used before the first chunk is emitted (smaller -> lower TTFB).
_STREAM_DECODE_INTERVAL_FIRST = _env_int("VEENA3_STREAM_DECODE_INTERVAL_FIRST", _STREAM_DECODE_INTERVAL, minimum=1)
# Interval used when the decoder backlog exceeds the busy threshold below.
_STREAM_DECODE_INTERVAL_BUSY = _env_int(
    "VEENA3_STREAM_DECODE_INTERVAL_BUSY",
    max(_STREAM_DECODE_INTERVAL, 64),
    minimum=1,
)
# Pending decoder requests at which the "busy" interval kicks in.
_STREAM_DECODE_BUSY_PENDING = _env_int("VEENA3_STREAM_DECODE_BUSY_PENDING", 32, minimum=1)
# CONSISTENCY FIX: reuse the _env_flag helper instead of duplicating the
# truthy-string parsing inline. (Only observable difference: a variable set to
# an empty string now falls back to the default True instead of False.)
_STREAM_WINDOWED_DECODE = _env_flag("VEENA3_STREAM_WINDOWED_DECODE", True)


class Veena3SlidingWindowPipeline:
    """
    Streaming TTS pipeline using sliding window approach.
    
    This eliminates choppy audio and popping artifacts by:
    - Decoding overlapping 28-token windows (4 frames)
    - Keeping only the middle 2048 samples from each decode
    - Creating natural continuity between chunks
    
    Based on the official Canopy Labs implementation.
    """
    
    def __init__(
        self,
        model,
        prompt_builder,
        snac_decoder,
    ):
        """
        Initialize sliding window streaming pipeline.

        Args:
            model: Veena3Model instance
            prompt_builder: Veena3PromptBuilder instance
            snac_decoder: SNACDecoder instance (with batching enabled)
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder

        # Pre-warm BiCodecTokenParser once at construction rather than per
        # request: building it iterates a ~166K-entry vocab (~123ms), which
        # would otherwise land on every streaming request's TTFB.
        tok = getattr(model, "tokenizer", None)
        if tok is None:
            tok = getattr(model.engine, "tokenizer", None)
        self.token_parser = BiCodecTokenParser(tok) if tok else None

        print(f"🌊 Veena3SlidingWindowPipeline initialized (sliding window: 28 tokens)")

    @staticmethod
    def _p50(values: List[float]) -> float:
        if not values:
            return 0.0
        ordered = sorted(values)
        return ordered[(len(ordered) - 1) // 2]

    @staticmethod
    def _to_number(value: Any) -> Optional[float]:
        if value is None:
            return None
        if isinstance(value, (int, float)):
            return float(value)
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _extract_request_metrics(self, raw_metrics: Any) -> Dict[str, float]:
        if raw_metrics is None:
            return {}

        numeric: Dict[str, float] = {}
        items: List[Tuple[str, Any]] = []

        if isinstance(raw_metrics, dict):
            items.extend(raw_metrics.items())
        else:
            if hasattr(raw_metrics, "__dict__"):
                try:
                    items.extend(dict(raw_metrics.__dict__).items())
                except Exception:
                    pass
            for name in dir(raw_metrics):
                if name.startswith("_"):
                    continue
                try:
                    value = getattr(raw_metrics, name)
                except Exception:
                    continue
                if callable(value):
                    continue
                items.append((name, value))

        seen: set[str] = set()
        for key, value in items:
            key_s = str(key)
            if key_s in seen:
                continue
            seen.add(key_s)
            parsed = self._to_number(value)
            if parsed is not None:
                numeric[key_s] = parsed
        return numeric

    def _apply_stats(self, perf: Dict[str, Any], prefix: str, values: List[float]) -> None:
        if not values:
            return
        perf[f"{prefix}_total"] = float(sum(values))
        perf[f"{prefix}_min"] = float(min(values))
        perf[f"{prefix}_max"] = float(max(values))
        perf[f"{prefix}_p50"] = float(self._p50(values))

    def _get_decoder_pending_requests(self) -> int:
        getter = getattr(self.snac_decoder, "get_pending_requests", None)
        if not callable(getter):
            return 0
        try:
            return max(0, int(getter()))
        except Exception:
            return 0

    def _resolve_decode_interval(self, first_chunk_emitted: bool) -> Tuple[int, int]:
        """Pick the semantic-token decode interval for the current state.

        Returns (interval, pending_decoder_requests). With adaptive decode
        disabled the static interval is returned and pending is reported as 0.
        """
        base = int(_STREAM_DECODE_INTERVAL)
        if not _STREAM_ADAPTIVE_DECODE:
            return base, 0

        # Use the (typically smaller) first-chunk interval until the first
        # chunk has actually been emitted, to minimize TTFB.
        interval = base if first_chunk_emitted else int(_STREAM_DECODE_INTERVAL_FIRST)
        backlog = self._get_decoder_pending_requests()
        if backlog >= _STREAM_DECODE_BUSY_PENDING:
            # Decoder is saturated: decode less often, in larger batches.
            interval = max(interval, int(_STREAM_DECODE_INTERVAL_BUSY))
        return max(1, interval), backlog
    
    async def generate_speech_stream(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech audio with sliding window streaming (SNAC path).

        Streams tokens from vLLM, buffers SNAC audio-token IDs, and decodes an
        overlapping 28-token (4-frame) window every 7 new tokens so adjacent
        chunks share context and play back without seams.

        Args:
            description: Character/voice description
            text: Text to synthesize (with optional <emotion> tags)
            temperature: Sampling temperature
            top_p: Nucleus sampling probability mass
            max_tokens: Max SNAC tokens to generate
            repetition_penalty: Penalty > 1.0 to prevent token loops
            seed: None for random sampling, an int for reproducible output

        Yields:
            Audio bytes (int16 PCM, 24kHz mono)
        """
        logger.debug("sliding-window streaming generation started")
        logger.debug("description_prefix=%s", description[:80])
        logger.debug("text_len=%d", len(text))

        # Build prompt
        prompt = self.prompt_builder.build_prefix(description, text)
        logger.debug("prompt built chars=%d", len(prompt))

        # Configure sampling; generation terminates on the audio EOS token only.
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],  # Only audio EOS
            seed=seed,  # None = random, int = reproducible
        )

        logger.debug("sampling temp=%s top_p=%s sliding_window=28", temperature, top_p)

        # Token buffer - keeps ALL SNAC tokens seen so far (not chunked); the
        # sliding window always reads the trailing 28 entries.
        token_buffer = []
        total_tokens_generated = 0
        total_audio_chunks = 0

        # Generate tokens with vLLM (streaming)
        print(f"🔮 Starting token generation...")

        # Generate unique request ID for concurrent streaming support
        import uuid
        import time
        request_id = f"slide-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        # Stream tokens with sliding window. Note: this path relies on vLLM's
        # default cumulative output (token_ids grows across iterations).
        async for request_output in results_generator:
            # Extract generated token IDs (cumulative list)
            generated_ids = request_output.outputs[0].token_ids

            # Process only tokens added since the previous iteration
            new_tokens = generated_ids[total_tokens_generated:]
            total_tokens_generated = len(generated_ids)

            # Filter and buffer SNAC tokens (IDs within [SNAC_MIN_ID, SNAC_MAX_ID])
            for token_id in new_tokens:
                if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
                    token_buffer.append(token_id)

                    # Process every 7 tokens (1 frame) when we have enough for sliding window
                    # Official approach: once count > 27, take last 28 tokens
                    if len(token_buffer) % 7 == 0 and len(token_buffer) > 27:
                        # Sliding window: take last 28 tokens (4 frames)
                        window_tokens = token_buffer[-28:]

                        # Decode with sliding window mode (returns middle 2048 samples only)
                        # NOTE(review): the middle-samples trimming happens inside
                        # the decoder — confirm against SNACDecoder implementation.
                        if self.snac_decoder.enable_batching:
                            audio_bytes = await self.snac_decoder.decode_single_async(
                                window_tokens,
                                trim_warmup=False,  # Sliding window handles trimming
                                use_sliding_window=True  # CRITICAL: Use sliding window mode
                            )
                        else:
                            # Synchronous fallback when decoder batching is disabled
                            audio_bytes = self.snac_decoder.decode_to_bytes(
                                window_tokens,
                                trim_warmup=False,
                                use_sliding_window=True
                            )

                        if audio_bytes:
                            total_audio_chunks += 1
                            if total_audio_chunks == 1:
                                logger.debug("first sliding-window chunk decoded bytes=%d", len(audio_bytes))
                            yield audio_bytes

        # Note: No final chunk processing needed - sliding window handles all tokens
        # as they come in (every 7 tokens after the first 28)

        logger.debug(
            "sliding-window complete tokens=%d chunks=%d",
            total_tokens_generated,
            total_audio_chunks,
        )
    
    def _extract_bicodec_tokens_from_text(self, text: str) -> tuple[List[int], List[int]]:
        """
        Extract BiCodec semantic and global tokens from generated text.
        
        Args:
            text: Generated text containing BiCodec token markers
        
        Returns:
            Tuple of (semantic_ids, global_ids)
        """
        semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", text)
        semantic_ids = [int(t) for t in semantic_matches]
        
        global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", text)
        global_ids = [int(t) for t in global_matches]
        
        return semantic_ids, global_ids
    
    async def generate_speech_stream_indic(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,  # Added for Spark TTS parity with non-streaming
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
        emit_progress: bool = False,
    ) -> AsyncGenerator[Union[bytes, Tuple[bytes, Dict[str, Any]]], None]:
        """
        Generate speech audio with sliding window streaming for Indic model.
        
        Yields audio chunks using overlapping windows for smooth playback.
        
        Args:
            speaker: Speaker name (one of 12 predefined speakers)
            text: Text to synthesize with inline emotion tags
                Examples:
                - "Hello! Welcome."
                - "<laugh> Hello there!"
                - "नमस्ते! <excited> आज का दिन बहुत अच्छा है।"
            temperature: Sampling temperature
            top_p: Nucleus sampling
            max_tokens: Max SNAC tokens to generate
            repetition_penalty: Prevent loops
            seed: Random seed for reproducibility
        
        Yields:
            Audio bytes (int16 PCM, 16kHz mono - BiCodec)
        """
        import time
        import uuid

        t_request_start = time.perf_counter()
        timeline_marks: List[Dict[str, float]] = [{"stage": "request_start", "t_ms": 0.0, "dt_ms": 0.0}]
        last_mark_ms = 0.0
        mark_stage_set: set[str] = {"request_start"}

        def _mark(stage: str) -> float:
            nonlocal last_mark_ms
            now_ms = (time.perf_counter() - t_request_start) * 1000.0
            dt_ms = now_ms - last_mark_ms
            last_mark_ms = now_ms
            timeline_marks.append({"stage": stage, "t_ms": float(now_ms), "dt_ms": float(dt_ms)})
            mark_stage_set.add(stage)
            return float(now_ms)

        prompt_t0 = time.perf_counter()
        prompt = self.prompt_builder.build_prefix(speaker, text)
        prompt_build_ms = (time.perf_counter() - prompt_t0) * 1000.0
        _mark("prompt_ready")

        sampling_t0 = time.perf_counter()
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop=TRAINING_STOP_TOKEN_IDS,
            skip_special_tokens=False,
            seed=seed,
            output_kind=(RequestOutputKind.DELTA if _STREAM_USE_DELTA_OUTPUT else RequestOutputKind.CUMULATIVE),
        )
        sampling_params_ms = (time.perf_counter() - sampling_t0) * 1000.0
        _mark("sampling_ready")

        request_id = f"bicodec-stream-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        token_parser = self.token_parser
        if token_parser is None:
            tokenizer = getattr(self.model, "tokenizer", None) or getattr(self.model.engine, "tokenizer", None)
            token_parser = BiCodecTokenParser(tokenizer) if tokenizer else None
        if token_parser is None:
            raise RuntimeError("BiCodecTokenParser unavailable for streaming")

        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        semantic_buffer: List[int] = []
        global_buffer: List[int] = []
        processed_token_count = 0
        total_generated_tokens = 0
        total_audio_chunks = 0
        total_samples_emitted_to_user = 0
        previous_chunk_tail: Optional[bytes] = None
        last_decode_count = 0
        last_decode_end_ts: Optional[float] = None
        t_first_chunk_decode_done: Optional[float] = None
        t_first_chunk_emitted: Optional[float] = None
        final_request_metrics: Dict[str, float] = {}
        prev_gpu_cumulative_s: Optional[float] = None

        llm_batch_wall_ms: List[float] = []
        llm_batch_gpu_ms: List[float] = []
        llm_decode_wall_ms: List[float] = []
        llm_decode_gpu_ms: List[float] = []
        llm_parse_ms_samples: List[float] = []
        tokens_per_batch: List[float] = []
        bicodec_decode_wall_ms: List[float] = []
        bicodec_decode_tokens: List[float] = []
        bicodec_decode_intervals_ms: List[float] = []
        decode_interval_applied_tokens: List[float] = []
        decode_pending_requests: List[float] = []

        t_llm_start = time.perf_counter()
        t_last_batch = t_llm_start

        async for request_output in results_generator:
            now = time.perf_counter()
            batch_wall = (now - t_last_batch) * 1000.0
            t_last_batch = now

            outputs = request_output.outputs or []
            if not outputs:
                continue

            generated_ids = outputs[0].token_ids
            if _STREAM_USE_DELTA_OUTPUT:
                new_token_ids = generated_ids
                processed_token_count += len(new_token_ids)
            else:
                new_token_ids = generated_ids[processed_token_count:]
                processed_token_count = len(generated_ids)

            new_token_count = len(new_token_ids)
            if new_token_count <= 0:
                continue

            total_generated_tokens += new_token_count
            llm_batch_wall_ms.append(float(batch_wall))
            tokens_per_batch.append(float(new_token_count))

            if "llm_first_batch" not in mark_stage_set:
                _mark("llm_first_batch")
            elif llm_batch_wall_ms:
                llm_decode_wall_ms.append(float(batch_wall))

            raw_req_metrics = self._extract_request_metrics(getattr(request_output, "metrics", None))
            if raw_req_metrics:
                final_request_metrics = raw_req_metrics
                cumulative_gpu_s: Optional[float] = None
                if "model_execute_time" in raw_req_metrics:
                    cumulative_gpu_s = raw_req_metrics["model_execute_time"]
                elif "model_forward_time" in raw_req_metrics:
                    cumulative_gpu_s = raw_req_metrics["model_forward_time"]

                if cumulative_gpu_s is not None:
                    if prev_gpu_cumulative_s is None:
                        gpu_delta_ms = cumulative_gpu_s * 1000.0
                    elif cumulative_gpu_s >= prev_gpu_cumulative_s:
                        gpu_delta_ms = (cumulative_gpu_s - prev_gpu_cumulative_s) * 1000.0
                    else:
                        gpu_delta_ms = cumulative_gpu_s * 1000.0
                    prev_gpu_cumulative_s = cumulative_gpu_s
                    if gpu_delta_ms >= 0.0:
                        llm_batch_gpu_ms.append(float(gpu_delta_ms))
                        if len(llm_batch_wall_ms) > 1:
                            llm_decode_gpu_ms.append(float(gpu_delta_ms))

            parse_t0 = time.perf_counter()
            token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer)
            llm_parse_ms_samples.append((time.perf_counter() - parse_t0) * 1000.0)

            if len(global_buffer) >= _STREAM_EXPECTED_GLOBAL_COUNT:
                semantic_count = len(semantic_buffer)
                if semantic_count >= _STREAM_MIN_SEMANTIC_FIRST_CHUNK:
                    if "first_chunk_ready" not in mark_stage_set:
                        _mark("first_chunk_ready")

                    decode_interval, pending_requests = self._resolve_decode_interval(
                        first_chunk_emitted=t_first_chunk_emitted is not None
                    )
                    semantic_delta = semantic_count - last_decode_count
                    if semantic_delta >= decode_interval:
                        decode_step_tokens = semantic_delta
                        last_decode_count = semantic_count
                        decode_interval_applied_tokens.append(float(decode_interval))
                        decode_pending_requests.append(float(pending_requests))
                        decode_global = global_buffer[:_STREAM_EXPECTED_GLOBAL_COUNT]

                        if (
                            _STREAM_WINDOWED_DECODE
                            and semantic_count > _STREAM_WINDOW_SIZE
                        ):
                            decode_semantic = semantic_buffer[-_STREAM_WINDOW_SIZE:]
                        else:
                            decode_semantic = semantic_buffer

                        decode_t0 = time.perf_counter()
                        audio_bytes = await self.snac_decoder.decode_single_async(
                            semantic_ids=decode_semantic,
                            global_ids=decode_global,
                            trim_warmup=False,
                            use_sliding_window=False,
                        )
                        decode_t1 = time.perf_counter()

                        decode_wall_ms = (decode_t1 - decode_t0) * 1000.0
                        bicodec_decode_wall_ms.append(float(decode_wall_ms))
                        bicodec_decode_tokens.append(float(len(decode_semantic)))
                        if last_decode_end_ts is not None:
                            bicodec_decode_intervals_ms.append((decode_t1 - last_decode_end_ts) * 1000.0)
                        last_decode_end_ts = decode_t1

                        if audio_bytes:
                            total_samples_decoded = len(audio_bytes) // 2
                            if (
                                _STREAM_WINDOWED_DECODE
                                and semantic_count > _STREAM_WINDOW_SIZE
                            ):
                                new_samples_approx = decode_step_tokens * 320
                                if total_samples_decoded > new_samples_approx:
                                    new_bytes_start = (total_samples_decoded - new_samples_approx) * 2
                                else:
                                    new_bytes_start = 0
                                new_audio_bytes = audio_bytes[new_bytes_start:]
                            else:
                                if total_samples_decoded > total_samples_emitted_to_user:
                                    new_bytes_start = total_samples_emitted_to_user * 2
                                    new_audio_bytes = audio_bytes[new_bytes_start:]
                                else:
                                    new_audio_bytes = b""

                            if new_audio_bytes:
                                to_emit, previous_chunk_tail = crossfade_bytes_int16(
                                    previous_chunk_tail,
                                    new_audio_bytes,
                                    sample_rate_hz=16000,
                                    crossfade_ms=_STREAM_CROSSFADE_MS,
                                )
                                if to_emit:
                                    total_audio_chunks += 1
                                    total_samples_emitted_to_user += len(to_emit) // 2
                                    if t_first_chunk_decode_done is None:
                                        t_first_chunk_decode_done = decode_t1
                                        _mark("first_chunk_decoded")
                                    if t_first_chunk_emitted is None:
                                        t_first_chunk_emitted = time.perf_counter()
                                        _mark("first_chunk_emitted")
                                    if emit_progress:
                                        yield (to_emit, {})
                                    else:
                                        yield to_emit

        llm_generation_wall_ms = (time.perf_counter() - t_llm_start) * 1000.0
        _mark("llm_done")

        if len(semantic_buffer) > last_decode_count and len(global_buffer) >= _STREAM_EXPECTED_GLOBAL_COUNT:
            decode_semantic = semantic_buffer
            decode_t0 = time.perf_counter()
            audio_bytes = self.snac_decoder.decode_streaming(
                semantic_ids=decode_semantic,
                global_ids=global_buffer[:_STREAM_EXPECTED_GLOBAL_COUNT],
                use_sliding_window=False,
                trim_warmup=False,
            )
            decode_t1 = time.perf_counter()
            decode_wall_ms = (decode_t1 - decode_t0) * 1000.0
            bicodec_decode_wall_ms.append(float(decode_wall_ms))
            bicodec_decode_tokens.append(float(len(decode_semantic)))
            if last_decode_end_ts is not None:
                bicodec_decode_intervals_ms.append((decode_t1 - last_decode_end_ts) * 1000.0)
            last_decode_end_ts = decode_t1

            if audio_bytes:
                total_samples_decoded = len(audio_bytes) // 2
                if total_samples_decoded > total_samples_emitted_to_user:
                    new_bytes_start = total_samples_emitted_to_user * 2
                    new_audio_bytes = audio_bytes[new_bytes_start:]
                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        new_audio_bytes,
                        sample_rate_hz=16000,
                        crossfade_ms=_STREAM_CROSSFADE_MS,
                    )
                    if to_emit:
                        total_audio_chunks += 1
                        total_samples_emitted_to_user += len(to_emit) // 2
                        if t_first_chunk_decode_done is None:
                            t_first_chunk_decode_done = decode_t1
                            _mark("first_chunk_decoded")
                        if t_first_chunk_emitted is None:
                            t_first_chunk_emitted = time.perf_counter()
                            _mark("first_chunk_emitted")
                        if emit_progress:
                            yield (to_emit, {})
                        else:
                            yield to_emit

        if previous_chunk_tail:
            total_audio_chunks += 1
            total_samples_emitted_to_user += len(previous_chunk_tail) // 2
            if t_first_chunk_emitted is None:
                t_first_chunk_emitted = time.perf_counter()
                _mark("first_chunk_emitted")
            if emit_progress:
                yield (previous_chunk_tail, {})
            else:
                yield previous_chunk_tail

        _mark("stream_done")
        _mark("request_done")

        perf: Dict[str, Any] = {
            "prompt_build_ms": float(prompt_build_ms),
            "sampling_params_ms": float(sampling_params_ms),
            "llm_generation_wall_ms": float(llm_generation_wall_ms),
            "llm_batch_count": int(len(llm_batch_wall_ms)),
            "llm_token_total": int(total_generated_tokens),
            "semantic_token_total": int(len(semantic_buffer)),
            "global_token_total": int(len(global_buffer)),
            "llm_parse_ms": float(sum(llm_parse_ms_samples)),
            "bicodec_decode_calls": int(len(bicodec_decode_wall_ms)),
            "llm_decode_calls": int(len(bicodec_decode_wall_ms)),
            "chunks_sent": int(total_audio_chunks),
            "audio_duration_seconds": float(total_samples_emitted_to_user / 16000.0),
            "ttfb_ms": float((t_first_chunk_emitted - t_request_start) * 1000.0 if t_first_chunk_emitted else 0.0),
            "stream_output_kind": "delta" if _STREAM_USE_DELTA_OUTPUT else "cumulative",
            "stream_decode_interval_tokens": int(_STREAM_DECODE_INTERVAL),
            "stream_decode_interval_first_tokens": int(_STREAM_DECODE_INTERVAL_FIRST),
            "stream_decode_interval_busy_tokens": int(_STREAM_DECODE_INTERVAL_BUSY),
            "stream_decode_busy_pending_threshold": int(_STREAM_DECODE_BUSY_PENDING),
            "stream_adaptive_decode": bool(_STREAM_ADAPTIVE_DECODE),
            "stream_window_size_tokens": int(_STREAM_WINDOW_SIZE),
            "stream_windowed_decode": bool(_STREAM_WINDOWED_DECODE),
            "timeline_markers": list(timeline_marks),
            "timeline_total_ms": float((time.perf_counter() - t_request_start) * 1000.0),
        }

        decoder_batch_stats = {}
        if hasattr(self.snac_decoder, "get_batching_stats"):
            try:
                decoder_batch_stats = self.snac_decoder.get_batching_stats() or {}
            except Exception:
                decoder_batch_stats = {}
        if decoder_batch_stats:
            perf.update(decoder_batch_stats)

        self._apply_stats(perf, "llm_batch_wall_ms", llm_batch_wall_ms)
        self._apply_stats(perf, "llm_batch_gpu_ms", llm_batch_gpu_ms)
        self._apply_stats(perf, "llm_decode_wall_ms", llm_decode_wall_ms)
        self._apply_stats(perf, "llm_decode_gpu_ms", llm_decode_gpu_ms)
        self._apply_stats(perf, "tokens_per_batch", tokens_per_batch)
        self._apply_stats(perf, "llm_parse_step_ms", llm_parse_ms_samples)
        self._apply_stats(perf, "bicodec_decode_wall_ms", bicodec_decode_wall_ms)
        self._apply_stats(perf, "bicodec_decode_tokens", bicodec_decode_tokens)
        self._apply_stats(perf, "bicodec_decode_interval_ms", bicodec_decode_intervals_ms)
        self._apply_stats(perf, "decode_interval_applied_tokens", decode_interval_applied_tokens)
        self._apply_stats(perf, "decode_pending_requests", decode_pending_requests)

        perf["bicodec_decode_wall_ms"] = float(perf.get("bicodec_decode_wall_ms_total", 0.0))
        perf["bicodec_decode_gpu_ms"] = 0.0
        perf["bicodec_decode_cpu_ms"] = float(perf.get("bicodec_decode_wall_ms_total", 0.0))

        if total_generated_tokens > 0:
            perf["llm_time_per_token_ms"] = float(llm_generation_wall_ms) / float(total_generated_tokens)
        if llm_batch_wall_ms:
            perf["llm_time_per_batch_wall_ms"] = float(sum(llm_batch_wall_ms)) / float(len(llm_batch_wall_ms))
        if llm_batch_gpu_ms:
            perf["llm_time_per_batch_gpu_ms"] = float(sum(llm_batch_gpu_ms)) / float(len(llm_batch_gpu_ms))

        for mark in timeline_marks:
            stage = mark.get("stage")
            if not stage:
                continue
            perf[f"timeline_{stage}_ms"] = float(mark.get("t_ms", 0.0))
            perf[f"timeline_{stage}_delta_ms"] = float(mark.get("dt_ms", 0.0))

        tl_first_batch = perf.get("timeline_llm_first_batch_ms")
        tl_first_emit = perf.get("timeline_first_chunk_emitted_ms")
        tl_done = perf.get("timeline_request_done_ms")
        if isinstance(tl_first_batch, (int, float)):
            perf["timeline_to_first_batch_ms"] = float(tl_first_batch)
        if isinstance(tl_first_emit, (int, float)):
            perf["timeline_to_first_chunk_emitted_ms"] = float(tl_first_emit)
        if isinstance(tl_first_batch, (int, float)) and isinstance(tl_first_emit, (int, float)) and tl_first_emit >= tl_first_batch:
            perf["timeline_first_batch_to_first_chunk_emitted_ms"] = float(tl_first_emit - tl_first_batch)
        if isinstance(tl_first_emit, (int, float)) and isinstance(tl_done, (int, float)) and tl_done >= tl_first_emit:
            perf["timeline_first_chunk_to_request_done_ms"] = float(tl_done - tl_first_emit)

        sec_fields = {
            "time_in_queue": "llm_time_in_queue_ms",
            "scheduler_time": "llm_scheduler_ms",
            "model_forward_time": "llm_model_forward_ms",
            "model_execute_time": "llm_model_execute_ms",
            "first_token_latency": "llm_first_token_ms",
        }
        for src, dst in sec_fields.items():
            if src in final_request_metrics:
                perf[dst] = float(final_request_metrics[src]) * 1000.0

        def _delta_ms(start_s: Optional[float], end_s: Optional[float]) -> Optional[float]:
            if start_s is None or end_s is None:
                return None
            if end_s < start_s:
                return None
            return (end_s - start_s) * 1000.0

        queued_ts = final_request_metrics.get("queued_ts")
        scheduled_ts = final_request_metrics.get("scheduled_ts")
        first_token_ts = final_request_metrics.get("first_token_ts")
        last_token_ts = final_request_metrics.get("last_token_ts")
        q_to_sched = _delta_ms(queued_ts, scheduled_ts)
        sched_to_first = _delta_ms(scheduled_ts, first_token_ts)
        first_to_last = _delta_ms(first_token_ts, last_token_ts)
        queued_to_last = _delta_ms(queued_ts, last_token_ts)
        if q_to_sched is not None:
            perf["llm_queued_to_scheduled_ms"] = q_to_sched
            perf["llm_time_in_queue_ms"] = q_to_sched
        if sched_to_first is not None:
            perf["llm_scheduled_to_first_token_ms"] = sched_to_first
            perf["llm_scheduler_ms"] = sched_to_first
        if first_to_last is not None:
            perf["llm_first_to_last_token_ms"] = first_to_last
            perf["llm_model_execute_ms"] = first_to_last
        if queued_to_last is not None:
            perf["llm_queued_to_last_token_ms"] = queued_to_last
            perf["llm_request_lifecycle_ms"] = queued_to_last

        logger.debug(
            "streaming complete audio=%.2fs ttfb=%.0fms chunks=%d llm_tokens=%d",
            perf["audio_duration_seconds"],
            perf["ttfb_ms"],
            perf["chunks_sent"],
            perf["llm_token_total"],
        )

        if emit_progress:
            yield (b"", perf)
    
    async def generate_speech_stream_indic_first_chunk(
        self,
        speaker: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[tuple, None]:
        """
        Generate speech for the FIRST chunk of multi-chunk text, capturing global tokens.
        
        This method is specifically for chunked generation: it yields (audio_bytes, global_ids)
        tuples instead of just audio_bytes. The caller captures global_ids from the first
        yield and passes them to generate_speech_stream_indic_continuation() for subsequent chunks.
        
        Use Case (chunked streaming with voice consistency):
            globals_captured = None
            async for audio_bytes, global_ids in pipeline.generate_speech_stream_indic_first_chunk(...):
                if globals_captured is None and global_ids:
                    globals_captured = global_ids  # Capture once
                yield audio_bytes
            # Now use globals_captured for subsequent chunks
        
        Args:
            speaker: Speaker name
            text: First text chunk to synthesize
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate
            repetition_penalty: Prevent repetition
            seed: Random seed for reproducibility
        
        Yields:
            Tuple of (audio_bytes, global_ids)
            - audio_bytes: Raw PCM audio (int16, 16kHz)
            - global_ids: List of 32 global token IDs (populated after first decode, else empty)
        
        Thread Safety:
            This method is stateless and thread-safe. Each call creates its own
            request-scoped state. No global state is modified or shared.
        """
        import time
        import uuid
        
        # Build prompt using Indic prompt builder
        prompt = self.prompt_builder.build_prefix(speaker, text)
        
        # Configure sampling
        # OPTIMIZATION: Use Spark TTS stop token, not legacy SNAC CODE_END_TOKEN_ID
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop=TRAINING_STOP_TOKEN_IDS,  # "<|im_end|>" - matches non-streaming pipeline
            skip_special_tokens=False,
            seed=seed,
            output_kind=(RequestOutputKind.DELTA if _STREAM_USE_DELTA_OUTPUT else RequestOutputKind.CUMULATIVE),
        )
        
        # BiCodec token buffers (request-scoped; filled incrementally by the parser)
        semantic_buffer = []
        global_buffer = []
        processed_token_count = 0  # Used for cumulative mode; ignored for delta mode.
        total_audio_chunks = 0
        
        # BiCodec streaming configuration.
        # CONSISTENCY: use the module-level, env-overridable constants so this
        # method agrees with generate_speech_stream_indic_continuation()
        # (defaults are unchanged: 32 globals, 10 semantic tokens).
        EXPECTED_GLOBAL_COUNT = _STREAM_EXPECTED_GLOBAL_COUNT
        MIN_SEMANTIC_FOR_FIRST_CHUNK = _STREAM_MIN_SEMANTIC_FIRST_CHUNK
        DECODE_INTERVAL = 48  # OPTIMIZATION: Increased from 24 to reduce O(n) re-decode calls
        CROSSFADE_MS = 50
        
        # Generate unique request ID (uuid + microsecond timestamp for uniqueness)
        request_id = f"bicodec-first-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        
        # Reuse pre-warmed token parser (singleton, created once in __init__)
        token_parser = self.token_parser
        
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )
        
        # Track state
        last_decode_count = 0
        total_samples_emitted_to_user = 0
        previous_chunk_tail: Optional[bytes] = None
        captured_globals: List[int] = []  # Will be populated once we have 32 globals
        
        # Stream tokens with BiCodec extraction
        async for request_output in results_generator:
            generated_ids = request_output.outputs[0].token_ids
            
            # In DELTA mode token_ids are already incremental.
            if _STREAM_USE_DELTA_OUTPUT:
                new_token_ids = generated_ids
            else:
                new_token_ids = generated_ids[processed_token_count:]
                processed_token_count = len(generated_ids)
            token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer)
            
            # Capture global tokens once we have all 32 (copied once, never mutated after)
            if not captured_globals and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
                captured_globals = global_buffer[:EXPECTED_GLOBAL_COUNT]
            
            # Process if we have all global tokens
            if len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
                semantic_count = len(semantic_buffer)
                if semantic_count >= MIN_SEMANTIC_FOR_FIRST_CHUNK:
                    if semantic_count - last_decode_count >= DECODE_INTERVAL:
                        last_decode_count = semantic_count
                        
                        all_semantic = semantic_buffer
                        window_global = global_buffer[:EXPECTED_GLOBAL_COUNT]
                        
                        if self.snac_decoder.enable_batching:
                            audio_bytes = await self.snac_decoder.decode_single_async(
                                semantic_ids=all_semantic,
                                global_ids=window_global,
                                trim_warmup=False,
                                use_sliding_window=False
                            )
                        else:
                            audio_bytes = self.snac_decoder.decode_streaming(
                                semantic_ids=all_semantic,
                                global_ids=window_global,
                                use_sliding_window=False,
                                trim_warmup=False
                            )
                        
                        if audio_bytes:
                            total_samples_decoded = len(audio_bytes) // 2  # int16 PCM: 2 bytes/sample
                            
                            # Decode is cumulative; only emit samples beyond what
                            # the caller has already received.
                            if total_samples_decoded > total_samples_emitted_to_user:
                                new_bytes_start = total_samples_emitted_to_user * 2
                                new_audio_bytes = audio_bytes[new_bytes_start:]
                                
                                to_emit, previous_chunk_tail = crossfade_bytes_int16(
                                    previous_chunk_tail,
                                    new_audio_bytes,
                                    sample_rate_hz=16000,
                                    crossfade_ms=CROSSFADE_MS,
                                )
                                
                                samples_emitted_in_this_chunk = len(to_emit) // 2
                                total_samples_emitted_to_user += samples_emitted_in_this_chunk
                                
                                if to_emit:
                                    total_audio_chunks += 1
                                    # Yield tuple: (audio_bytes, global_ids)
                                    yield (to_emit, captured_globals)
        
        # Final chunk processing: decode any semantic tokens accrued since the
        # last interval-triggered decode.
        if len(semantic_buffer) > last_decode_count and len(global_buffer) >= EXPECTED_GLOBAL_COUNT:
            audio_bytes = self.snac_decoder.decode_streaming(
                semantic_ids=semantic_buffer,
                global_ids=global_buffer[:EXPECTED_GLOBAL_COUNT],
                use_sliding_window=False,
                trim_warmup=False
            )
            if audio_bytes:
                total_samples_decoded = len(audio_bytes) // 2
                
                if total_samples_decoded > total_samples_emitted_to_user:
                    new_bytes_start = total_samples_emitted_to_user * 2
                    new_audio_bytes = audio_bytes[new_bytes_start:]
                    
                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        new_audio_bytes,
                        sample_rate_hz=16000,
                        crossfade_ms=CROSSFADE_MS,
                    )
                    
                    samples_emitted_in_final = len(to_emit) // 2
                    total_samples_emitted_to_user += samples_emitted_in_final
                    
                    if to_emit:
                        total_audio_chunks += 1
                        yield (to_emit, captured_globals)
        
        # Flush remaining crossfade tail held back by the last crossfade call.
        if previous_chunk_tail:
            total_audio_chunks += 1
            tail_samples = len(previous_chunk_tail) // 2
            total_samples_emitted_to_user += tail_samples
            yield (previous_chunk_tail, captured_globals)
        
        audio_duration_s = total_samples_emitted_to_user / 16000
        logger.debug(
            "first-chunk complete audio=%.2fs chunks=%d captured_globals=%d",
            audio_duration_s,
            total_audio_chunks,
            len(captured_globals),
        )
    
    async def generate_speech_stream_indic_continuation(
        self,
        speaker: str,
        text: str,
        global_ids: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        top_k: int = DEFAULT_TOP_K,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
        emit_progress: bool = False,
    ) -> AsyncGenerator[Union[bytes, Tuple[bytes, Dict[str, Any]]], None]:
        """
        Generate speech for CONTINUATION chunks using pre-captured global tokens.
        
        This is the key method for voice consistency in chunked generation:
        - Uses global_ids from the first chunk to maintain identical voice
        - Model generates only semantic tokens (skips global token generation)
        - Ensures no voice drift across chunks
        
        CRITICAL for production:
        - global_ids is request-scoped, passed explicitly (no shared state)
        - Thread-safe: each request has its own state
        - No global caching that could cause cross-request contamination
        
        Args:
            speaker: Speaker name (must match first chunk)
            text: Continuation text chunk to synthesize
            global_ids: 32 global token IDs captured from first chunk
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            max_tokens: Max tokens to generate
            repetition_penalty: Prevent repetition
            seed: Random seed (use same as first chunk for consistency)
            emit_progress: When True, every yield is a (audio_bytes, dict) tuple
                and a final (b"", perf_stats) tuple is emitted after the audio.
                When False (default), raw audio bytes are yielded.
        
        Yields:
            Audio bytes (int16 PCM, 16kHz mono), or (bytes, dict) tuples when
            emit_progress is True.
        
        Raises:
            ValueError: If global_ids doesn't contain exactly 32 tokens
        
        Thread Safety:
            Fully thread-safe. All state is request-scoped and passed explicitly.
        """
        import time
        import uuid
        # BUGFIX: use perf_counter for t_start. It is later subtracted from
        # t_first_chunk_emitted (also perf_counter) to compute ttfb_ms; the
        # previous time.time() value lives on a different clock epoch, which
        # made the difference meaningless.
        t_start = time.perf_counter()
        
        # Validate global tokens
        if len(global_ids) != _STREAM_EXPECTED_GLOBAL_COUNT:
            raise ValueError(
                f"Expected exactly {_STREAM_EXPECTED_GLOBAL_COUNT} global tokens, got {len(global_ids)}. "
                f"global_ids must be captured from first chunk generation."
            )
        
        # Build prompt WITH pre-filled global tokens
        # This tells the model to skip global generation and go straight to semantic tokens
        prompt = self.prompt_builder.build_prefix_with_globals(speaker, text, global_ids)
        
        # Configure sampling
        # OPTIMIZATION: Use Spark TTS stop token, not legacy SNAC CODE_END_TOKEN_ID
        sampling_params = SamplingParams(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop=TRAINING_STOP_TOKEN_IDS,  # "<|im_end|>" - matches non-streaming pipeline
            skip_special_tokens=False,
            seed=seed,
            output_kind=(RequestOutputKind.DELTA if _STREAM_USE_DELTA_OUTPUT else RequestOutputKind.CUMULATIVE),
        )
        
        # BiCodec token buffers
        semantic_buffer = []
        global_buffer_unused = []  # Not used in continuation, but needed for parser API
        processed_token_count = 0  # Used for cumulative mode; ignored for delta mode.
        total_audio_chunks = 0
        total_generated_tokens = 0
        
        # BiCodec streaming configuration (module-level, env-overridable)
        MIN_SEMANTIC_FOR_FIRST_CHUNK = _STREAM_MIN_SEMANTIC_FIRST_CHUNK
        DECODE_INTERVAL = _STREAM_DECODE_INTERVAL
        CROSSFADE_MS = _STREAM_CROSSFADE_MS
        WINDOW_SIZE = _STREAM_WINDOW_SIZE
        
        # Generate unique request ID (uuid + microsecond timestamp for uniqueness)
        request_id = f"bicodec-cont-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"
        
        # Reuse pre-warmed token parser (singleton, created once in __init__)
        token_parser = self.token_parser
        
        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )
        
        # Track state
        last_decode_count = 0
        total_samples_emitted_to_user = 0
        previous_chunk_tail: Optional[bytes] = None
        t_first_chunk_emitted: Optional[float] = None  # perf_counter of first emitted audio
        decode_calls = 0
        decode_wall_ms_total = 0.0
        decode_interval_applied_tokens: List[float] = []
        decode_pending_requests: List[float] = []
        
        # Stream tokens - model will generate ONLY semantic tokens (globals are pre-filled)
        async for request_output in results_generator:
            generated_ids = request_output.outputs[0].token_ids
            
            # In DELTA mode token_ids are already incremental.
            if _STREAM_USE_DELTA_OUTPUT:
                new_token_ids = generated_ids
                processed_token_count += len(new_token_ids)
            else:
                new_token_ids = generated_ids[processed_token_count:]
                processed_token_count = len(generated_ids)
            total_generated_tokens += len(new_token_ids)
            token_parser.parse_incremental(new_token_ids, semantic_buffer, global_buffer_unused)
            
            semantic_count = len(semantic_buffer)
            if semantic_count >= MIN_SEMANTIC_FOR_FIRST_CHUNK:
                # Adaptive decode interval: may grow under load / after first chunk.
                decode_interval, pending_requests = self._resolve_decode_interval(
                    first_chunk_emitted=t_first_chunk_emitted is not None
                )
                semantic_delta = semantic_count - last_decode_count
                if semantic_delta >= decode_interval:
                    decode_step_tokens = semantic_delta
                    last_decode_count = semantic_count
                    decode_interval_applied_tokens.append(float(decode_interval))
                    decode_pending_requests.append(float(pending_requests))
                    if _STREAM_WINDOWED_DECODE and semantic_count > WINDOW_SIZE:
                        decode_semantic = semantic_buffer[-WINDOW_SIZE:]
                    else:
                        decode_semantic = semantic_buffer
                    
                    # Use pre-captured global tokens for decoding
                    decode_t0 = time.perf_counter()
                    if self.snac_decoder.enable_batching:
                        audio_bytes = await self.snac_decoder.decode_single_async(
                            semantic_ids=decode_semantic,
                            global_ids=global_ids,  # Use captured globals
                            trim_warmup=False,
                            use_sliding_window=False
                        )
                    else:
                        audio_bytes = self.snac_decoder.decode_streaming(
                            semantic_ids=decode_semantic,
                            global_ids=global_ids,  # Use captured globals
                            use_sliding_window=False,
                            trim_warmup=False
                        )
                    decode_calls += 1
                    decode_wall_ms_total += (time.perf_counter() - decode_t0) * 1000.0
                    
                    if audio_bytes:
                        total_samples_decoded = len(audio_bytes) // 2  # int16 PCM
                        if _STREAM_WINDOWED_DECODE and semantic_count > WINDOW_SIZE:
                            # Windowed decode re-synthesizes overlap; keep only the
                            # samples corresponding to the new tokens (~320 samples
                            # per semantic token at 16kHz).
                            new_samples_approx = decode_step_tokens * 320
                            if total_samples_decoded > new_samples_approx:
                                new_bytes_start = (total_samples_decoded - new_samples_approx) * 2
                            else:
                                new_bytes_start = 0
                            new_audio_bytes = audio_bytes[new_bytes_start:]
                        elif total_samples_decoded > total_samples_emitted_to_user:
                            new_bytes_start = total_samples_emitted_to_user * 2
                            new_audio_bytes = audio_bytes[new_bytes_start:]
                        else:
                            new_audio_bytes = b""
                            
                        if new_audio_bytes:
                            to_emit, previous_chunk_tail = crossfade_bytes_int16(
                                previous_chunk_tail,
                                new_audio_bytes,
                                sample_rate_hz=16000,
                                crossfade_ms=CROSSFADE_MS,
                            )
                            
                            if to_emit:
                                total_audio_chunks += 1
                                total_samples_emitted_to_user += len(to_emit) // 2
                                if t_first_chunk_emitted is None:
                                    t_first_chunk_emitted = time.perf_counter()
                                if emit_progress:
                                    yield (to_emit, {})
                                else:
                                    yield to_emit
        
        # Final chunk processing: decode the full buffer one last time.
        if len(semantic_buffer) > last_decode_count:
            decode_t0 = time.perf_counter()
            audio_bytes = self.snac_decoder.decode_streaming(
                semantic_ids=semantic_buffer,
                global_ids=global_ids,  # Use captured globals
                use_sliding_window=False,
                trim_warmup=False
            )
            decode_calls += 1
            decode_wall_ms_total += (time.perf_counter() - decode_t0) * 1000.0
            if audio_bytes:
                total_samples_decoded = len(audio_bytes) // 2
                
                if total_samples_decoded > total_samples_emitted_to_user:
                    new_bytes_start = total_samples_emitted_to_user * 2
                    new_audio_bytes = audio_bytes[new_bytes_start:]
                    
                    to_emit, previous_chunk_tail = crossfade_bytes_int16(
                        previous_chunk_tail,
                        new_audio_bytes,
                        sample_rate_hz=16000,
                        crossfade_ms=CROSSFADE_MS,
                    )
                    
                    if to_emit:
                        total_audio_chunks += 1
                        total_samples_emitted_to_user += len(to_emit) // 2
                        if t_first_chunk_emitted is None:
                            t_first_chunk_emitted = time.perf_counter()
                        if emit_progress:
                            yield (to_emit, {})
                        else:
                            yield to_emit
        
        # Flush remaining crossfade tail held back by the last crossfade call.
        if previous_chunk_tail:
            total_audio_chunks += 1
            tail_samples = len(previous_chunk_tail) // 2
            total_samples_emitted_to_user += tail_samples
            if t_first_chunk_emitted is None:
                t_first_chunk_emitted = time.perf_counter()
            if emit_progress:
                yield (previous_chunk_tail, {})
            else:
                yield previous_chunk_tail
        
        audio_duration_s = total_samples_emitted_to_user / 16000
        logger.debug(
            "continuation complete audio=%.2fs chunks=%d",
            audio_duration_s,
            total_audio_chunks,
        )

        if emit_progress:
            perf: Dict[str, Any] = {
                "llm_token_total": int(total_generated_tokens),
                "semantic_token_total": int(len(semantic_buffer)),
                "global_token_total": int(len(global_ids)),
                "bicodec_decode_calls": int(decode_calls),
                "llm_decode_calls": int(decode_calls),
                "bicodec_decode_wall_ms": float(decode_wall_ms_total),
                "bicodec_decode_gpu_ms": 0.0,
                "bicodec_decode_cpu_ms": float(decode_wall_ms_total),
                "chunks_sent": int(total_audio_chunks),
                "audio_duration_seconds": float(audio_duration_s),
                # Both timestamps are perf_counter readings, so the delta is valid.
                "ttfb_ms": float((t_first_chunk_emitted - t_start) * 1000.0 if t_first_chunk_emitted else 0.0),
                "stream_output_kind": "delta" if _STREAM_USE_DELTA_OUTPUT else "cumulative",
                "stream_decode_interval_tokens": int(DECODE_INTERVAL),
                "stream_decode_interval_first_tokens": int(_STREAM_DECODE_INTERVAL_FIRST),
                "stream_decode_interval_busy_tokens": int(_STREAM_DECODE_INTERVAL_BUSY),
                "stream_decode_busy_pending_threshold": int(_STREAM_DECODE_BUSY_PENDING),
                "stream_adaptive_decode": bool(_STREAM_ADAPTIVE_DECODE),
                "stream_window_size_tokens": int(WINDOW_SIZE),
                "stream_windowed_decode": bool(_STREAM_WINDOWED_DECODE),
            }
            self._apply_stats(perf, "decode_interval_applied_tokens", decode_interval_applied_tokens)
            self._apply_stats(perf, "decode_pending_requests", decode_pending_requests)
            decoder_batch_stats = {}
            if hasattr(self.snac_decoder, "get_batching_stats"):
                try:
                    decoder_batch_stats = self.snac_decoder.get_batching_stats() or {}
                except Exception:
                    decoder_batch_stats = {}
            if decoder_batch_stats:
                perf.update(decoder_batch_stats)
            yield (b"", perf)