"""
TTS runtime wrapper (container-scoped singleton).

In Modal, each container should load the model once and then handle many requests.
This wrapper owns:
- SparkTTSModel (vLLM engine)
- IndicPromptBuilder
- BiCodecDecoder
- SparkTTSPipeline + Veena3SlidingWindowPipeline
- LongTextProcessor (built once and cached for chunked generation)
- optional SuperResolutionService

Imports framework-agnostic inference + processing code from `veena3modal/core`,
`veena3modal/processing`, and `veena3modal/audio`.
"""

from __future__ import annotations

import os
import sys
import time
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)


@dataclass
class TTSRuntime:
    """
    Holds long-lived, per-container inference objects.

    One instance is built by initialize_runtime() and stored as the module-level
    singleton. The component fields are annotated as Any because their classes
    are imported lazily inside initialize_runtime().

    Thread Safety:
        This class is designed for async concurrency within a single container.
        The vLLM engine handles internal batching and scheduling.
        Do NOT share across processes without proper synchronization.
    """
    # --- Inference components (all None until initialize_runtime() fills them) ---
    # SparkTTSModel, or the multi-engine model when num_engines > 1
    model: Any = None
    # SparkTTSPipeline used for non-streaming generation
    pipeline: Any = None
    # Veena3SlidingWindowPipeline; stays None if its module cannot be imported
    streaming_pipeline: Any = None
    # LongTextProcessor; also created lazily by _get_long_text_processor()
    long_text_processor: Any = None
    # IndicPromptBuilder bound to the model's tokenizer
    prompt_builder: Any = None
    # BiCodecDecoder
    bicodec_decoder: Any = None
    # SuperResolutionService, or None when SR is disabled or failed to load
    sr_service: Optional[Any] = None

    # --- Status ---
    # Basename of model_path once loaded; "not_loaded" before that
    model_version: str = "not_loaded"
    # True once initialize_runtime() has built every component
    is_loaded: bool = False
    # Wall-clock time spent building the components, in milliseconds
    load_time_ms: float = 0.0

    # Configuration
    model_path: str = ""
    bicodec_path: str = ""
    sr_checkpoint_dir: Optional[str] = None
    device: str = "cuda"

    # OPTIMIZATION: Pre-computed global tokens per speaker (eliminates ~110ms pre-roll in streaming)
    # Maps speaker_name -> list of 32 global token IDs captured at startup
    # Filled by _precompute_speaker_globals(); empty when startup precompute is skipped.
    speaker_global_cache: Dict[str, List[int]] = field(default_factory=dict)


# Module-level singleton (per-container).
# None until initialize_runtime() assigns a TTSRuntime; read it through
# get_runtime() / is_initialized() rather than importing the name directly,
# because initialize_runtime() rebinds it.
_runtime: Optional[TTSRuntime] = None


def get_runtime() -> Optional[TTSRuntime]:
    """
    Get the current TTS runtime singleton.

    Returns:
        The module-level TTSRuntime, or None if initialize_runtime() has not
        assigned one yet. This accessor does not check is_loaded; use
        is_initialized() when a ready runtime is required.
    """
    return _runtime


def is_initialized() -> bool:
    """
    Report whether a runtime exists and has finished loading.

    Returns:
        False when no runtime singleton has been assigned; otherwise the
        singleton's is_loaded flag.
    """
    current = _runtime
    if current is None:
        return False
    return current.is_loaded


def initialize_runtime(
    model_path: Optional[str] = None,
    bicodec_path: Optional[str] = None,
    sr_checkpoint_dir: Optional[str] = None,
    device: str = "cuda",
    hf_token: Optional[str] = None,
    gpu_memory_utilization: float = 0.25,
    enable_sr: bool = False,
    num_engines: int = 1,
    max_num_batched_tokens: Optional[int] = None,
    max_num_seqs: Optional[int] = None,
    enable_chunked_prefill: Optional[bool] = None,
    enable_prefix_caching: Optional[bool] = None,
    disable_log_stats: Optional[bool] = None,
    enforce_eager: Optional[bool] = None,
    precompute_speaker_globals: Optional[bool] = None,
) -> TTSRuntime:
    """
    Initialize the TTS runtime with all components.

    This should be called once per container (e.g., in Modal's @modal.enter).
    On success the new runtime is stored in the module-level singleton and, if
    the FastAPI app module is importable, its model version is updated.

    Args:
        model_path: Path to Spark TTS model (env: SPARK_TTS_MODEL_PATH, then MODEL_PATH)
        bicodec_path: Path to BiCodec model (env: BICODEC_MODEL_PATH, defaults to model_path)
        sr_checkpoint_dir: Path to super-resolution checkpoints (env: AP_BWE_CHECKPOINT_DIR)
        device: Device for inference (cuda/cpu)
        hf_token: HuggingFace token for private models (env: HF_TOKEN or HUGGING_FACE_HUB_TOKEN)
        gpu_memory_utilization: vLLM GPU memory fraction (default: 0.25). With
            num_engines > 1 this total is divided evenly between the engines.
        enable_sr: Enable super-resolution service (requires a checkpoint directory)
        num_engines: Number of vLLM engine instances (default: 1, set to 2-3 for Tier 3 optimization)
        max_num_batched_tokens: Optional vLLM scheduler cap override (ignored unless a positive int).
        max_num_seqs: Optional vLLM concurrent sequence cap override (ignored unless a positive int).
        enable_chunked_prefill: Optional vLLM chunked prefill toggle.
        enable_prefix_caching: Optional vLLM prefix caching toggle.
        disable_log_stats: Optional vLLM internal stats log toggle.
        enforce_eager: Optional vLLM eager-mode toggle (disables CUDA graphs when true).
        precompute_speaker_globals: Whether to warm speaker global tokens at startup.
            If None, reads PRECOMPUTE_SPEAKER_GLOBALS env var (default: false).
            Only honoured in single-engine mode.

    Returns:
        Initialized TTSRuntime instance

    Raises:
        RuntimeError: If initialization fails (the original exception is chained)
    """
    global _runtime

    start_time = time.time()

    # Resolve paths from env vars if not provided
    model_path = model_path or os.environ.get(
        'SPARK_TTS_MODEL_PATH',
        os.environ.get('MODEL_PATH', '/models/spark_tts_4speaker')
    )
    bicodec_path = bicodec_path or os.environ.get(
        'BICODEC_MODEL_PATH',
        model_path  # BiCodec is usually in the same directory
    )
    sr_checkpoint_dir = sr_checkpoint_dir or os.environ.get('AP_BWE_CHECKPOINT_DIR')
    hf_token = hf_token or os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    if precompute_speaker_globals is None:
        raw_flag = os.environ.get("PRECOMPUTE_SPEAKER_GLOBALS", "false").strip().lower()
        precompute_speaker_globals = raw_flag in {"1", "true", "yes", "on"}

    logger.info("Initializing TTS runtime...")
    logger.info(f"  Model path: {model_path}")
    logger.info(f"  BiCodec path: {bicodec_path}")
    logger.info(f"  SR enabled: {enable_sr}, path: {sr_checkpoint_dir}")

    try:
        # Import inference components (vendored into veena3modal for Modal deployment)
        from veena3modal.core.model_loader import SparkTTSModel
        from veena3modal.core.pipeline import SparkTTSPipeline
        from veena3modal.core.bicodec_decoder import BiCodecDecoder
        from veena3modal.processing.prompt_builder import IndicPromptBuilder
        from veena3modal.processing.long_text_processor import LongTextProcessor

        # Only forward overrides the caller actually supplied, so the model
        # loader's own defaults apply to everything else.
        engine_kwargs: Dict[str, Any] = {}
        if isinstance(max_num_batched_tokens, int) and max_num_batched_tokens > 0:
            engine_kwargs["max_num_batched_tokens"] = max_num_batched_tokens
        if isinstance(max_num_seqs, int) and max_num_seqs > 0:
            engine_kwargs["max_num_seqs"] = max_num_seqs
        if enable_chunked_prefill is not None:
            engine_kwargs["enable_chunked_prefill"] = bool(enable_chunked_prefill)
        if enable_prefix_caching is not None:
            engine_kwargs["enable_prefix_caching"] = bool(enable_prefix_caching)
        if disable_log_stats is not None:
            engine_kwargs["disable_log_stats"] = bool(disable_log_stats)
        if enforce_eager is not None:
            engine_kwargs["enforce_eager"] = bool(enforce_eager)

        # Load SparkTTS model with vLLM engine
        if num_engines > 1:
            # TIER 3 OPTIMIZATION: Multiple vLLM engines on same GPU
            # Each engine gets gpu_memory_utilization / num_engines to share GPU fairly.
            # E.g., 3 engines at 0.08 each = 0.24 total, ~8GB per engine, ~24GB total.
            from veena3modal.core.multi_engine import create_multi_engine_model
            per_engine_mem = gpu_memory_utilization / num_engines
            logger.info(f"Loading {num_engines} vLLM engines ({per_engine_mem:.2f} GPU mem each)...")
            model = create_multi_engine_model(
                model_path=model_path,
                num_engines=num_engines,
                hf_token=hf_token,
                gpu_memory_per_engine=per_engine_mem,
                **engine_kwargs,
            )
        else:
            logger.info("Loading SparkTTS model with vLLM...")
            model = SparkTTSModel(
                model_path=model_path,
                hf_token=hf_token,
                gpu_memory_utilization=gpu_memory_utilization,
                **engine_kwargs,
            )

        # Initialize prompt builder
        logger.info("Initializing prompt builder...")
        prompt_builder = IndicPromptBuilder(
            tokenizer=model.tokenizer,
            model=model,
        )

        # Initialize BiCodec decoder
        logger.info("Initializing BiCodec decoder...")
        bicodec_decoder = BiCodecDecoder(
            device=device,
            model_path=bicodec_path,
        )

        # Initialize pipeline
        logger.info("Initializing TTS pipeline...")
        pipeline = SparkTTSPipeline(
            model=model,
            prompt_builder=prompt_builder,
            bicodec_decoder=bicodec_decoder,
        )

        # Initialize streaming pipeline (for M4); optional, so a missing
        # module only disables streaming instead of failing startup.
        streaming_pipeline = None
        try:
            from veena3modal.core.streaming_pipeline import Veena3SlidingWindowPipeline
            logger.info("Initializing streaming pipeline...")
            # NOTE: Parameter is named 'snac_decoder' for legacy reasons, but works with BiCodecDecoder
            streaming_pipeline = Veena3SlidingWindowPipeline(
                model=model,
                prompt_builder=prompt_builder,
                snac_decoder=bicodec_decoder,  # BiCodecDecoder is interface-compatible
            )
        except ImportError as e:
            logger.warning(f"Streaming pipeline not available: {e}")

        # Initialize long text processor once (avoid per-request construction/logging)
        long_text_processor = LongTextProcessor(
            pipeline=pipeline,
            streaming_pipeline=streaming_pipeline,
        )

        # Initialize super-resolution (optional, best-effort)
        sr_service = None
        if enable_sr and sr_checkpoint_dir:
            try:
                from veena3modal.core.super_resolution import SuperResolutionService
                logger.info(f"Initializing super-resolution from {sr_checkpoint_dir}...")
                sr_service = SuperResolutionService(checkpoint_dir=sr_checkpoint_dir)
                # Load the model explicitly
                if sr_service.load_model(device=device):
                    logger.info("✅ Super-resolution model loaded successfully")
                else:
                    logger.warning("Super-resolution model failed to load")
                    sr_service = None
            except Exception as e:
                logger.warning(f"Super-resolution not available: {e}")
                # If load_model() raised after construction, do not keep the
                # half-initialized service on the runtime; mirror the
                # "failed to load" branch above.
                sr_service = None
        elif enable_sr:
            # Previously this case was skipped silently, which made a missing
            # checkpoint path look like SR had simply been turned off.
            logger.warning(
                "Super-resolution enabled but no checkpoint directory configured "
                "(pass sr_checkpoint_dir or set AP_BWE_CHECKPOINT_DIR); continuing without it"
            )

        # Determine model version
        model_version = os.path.basename(model_path.rstrip('/'))
        if not model_version:
            model_version = "spark-tts"

        load_time = (time.time() - start_time) * 1000

        # Create runtime
        _runtime = TTSRuntime(
            model=model,
            pipeline=pipeline,
            streaming_pipeline=streaming_pipeline,
            long_text_processor=long_text_processor,
            prompt_builder=prompt_builder,
            bicodec_decoder=bicodec_decoder,
            sr_service=sr_service,
            model_version=model_version,
            is_loaded=True,
            load_time_ms=load_time,
            model_path=model_path,
            bicodec_path=bicodec_path,
            sr_checkpoint_dir=sr_checkpoint_dir,
            device=device,
        )

        logger.info(f"✅ TTS runtime initialized in {load_time:.0f}ms")
        logger.info(f"   Model version: {model_version}")

        # Pre-compute speaker globals only when explicitly enabled.
        # Startup warmup runs async generation and can bind AsyncLLMEngine state
        # to a non-serving event loop if executed in a different loop context.
        # The "disabled" case is checked first so the log states the real reason
        # for skipping even in multi-engine mode.
        if not precompute_speaker_globals:
            logger.info("Speaker globals cache startup precompute skipped (disabled)")
        elif num_engines > 1:
            logger.info("Speaker globals cache startup precompute skipped (multi-engine mode)")
        else:
            logger.info("Speaker globals startup precompute enabled")
            _precompute_speaker_globals(_runtime)

        # Update FastAPI app with model version (the app module is optional)
        try:
            from veena3modal.api.fastapi_app import set_model_version
            set_model_version(model_version)
        except ImportError:
            pass

        return _runtime

    except Exception as e:
        # logger.exception attaches the traceback to the log record, so it goes
        # through the configured logging handlers rather than raw stderr.
        logger.exception(f"❌ Failed to initialize TTS runtime: {e}")
        raise RuntimeError(f"TTS runtime initialization failed: {e}") from e


def _precompute_speaker_globals(runtime: TTSRuntime) -> None:
    """
    Pre-compute and cache global tokens for every speaker in SPEAKER_MAP at startup.

    OPTIMIZATION: BiCodec streaming requires 32 "global tokens" before any audio can be emitted.
    These encode speaker identity via FSQ quantization. Since the speaker set is fixed,
    we generate one short utterance per speaker at startup, capture the 32 global tokens,
    and cache them in runtime.speaker_global_cache. During streaming, the cached globals
    are handed to the streaming pipeline's continuation entry point, skipping the
    ~110ms global token pre-roll phase entirely.

    This runs synchronously at startup (adds ~5-10s to cold start, saves ~110ms per streaming TTFB).
    The timing figures above come from the original authors and are not derivable from this file.

    Args:
        runtime: Runtime whose prompt_builder and model.engine drive the warmup
            generations and whose speaker_global_cache is filled in place.

    Notes:
        Per-speaker failures are logged as warnings and skipped; a speaker without
        a cache entry simply falls back to inline global-token generation later.
    """
    import asyncio
    import re
    from vllm import SamplingParams
    from veena3modal.core.constants import (
        SPEAKER_MAP, TRAINING_STOP_TOKEN_IDS,
        DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P,
    )

    logger.info("Pre-computing global tokens for all speakers...")
    t_start = time.time()

    # Fixed warmup utterance; only the leading global tokens of the output are used.
    test_text = "Hello, this is a voice test."

    # NOTE(review): `stop=` receives a constant named TRAINING_STOP_TOKEN_IDS. In vLLM,
    # `stop` is for stop *strings* while integer IDs go in `stop_token_ids` — confirm
    # the constant's element type against veena3modal.core.constants.
    sampling_params = SamplingParams(
        temperature=DEFAULT_TEMPERATURE,
        top_k=DEFAULT_TOP_K,
        top_p=DEFAULT_TOP_P,
        max_tokens=128,  # Only need ~60 tokens (32 global + some semantic)
        stop=TRAINING_STOP_TOKEN_IDS,
        skip_special_tokens=False,  # keep <|bicodec_global_N|> markers in the decoded text
    )

    async def _capture_globals_for_speaker(speaker_name: str) -> Optional[List[int]]:
        """Generate a short utterance and capture the 32 global tokens."""
        prompt = runtime.prompt_builder.build_prefix(speaker_name, test_text)
        # Millisecond timestamp keeps request IDs distinct across repeated warmups.
        request_id = f"warmup-{speaker_name}-{int(time.time() * 1000)}"

        # Drain the stream; only the last item (the complete output) is inspected.
        final_output = None
        async for request_output in runtime.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        ):
            final_output = request_output

        if final_output is None:
            return None

        # Global tokens appear in the decoded text as <|bicodec_global_N|>; collect N in order.
        generated_text = final_output.outputs[0].text
        global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", generated_text)
        global_ids = [int(t) for t in global_matches]

        # Fewer than 32 markers means the capture is incomplete and unusable.
        if len(global_ids) >= 32:
            return global_ids[:32]
        return None

    async def _run_all():
        # Speakers are warmed one at a time; one failure must not abort the rest.
        for speaker_name in SPEAKER_MAP.keys():
            try:
                global_ids = await _capture_globals_for_speaker(speaker_name)
                if global_ids and len(global_ids) == 32:
                    runtime.speaker_global_cache[speaker_name] = global_ids
                    logger.info(f"  Cached globals for {speaker_name}")
                else:
                    logger.warning(f"  Failed to capture globals for {speaker_name}")
            except Exception as e:
                logger.warning(f"  Error caching globals for {speaker_name}: {e}")

    # Run async warmup in event loop. Three cases:
    #   1. A loop is already running in this thread -> run the warmup on a fresh
    #      loop in a worker thread and block until it finishes.
    #   2. A loop exists but is idle -> drive the warmup on that loop.
    #   3. No usable loop (RuntimeError) -> asyncio.run() creates a temporary one.
    # NOTE(review): cases 1 and 3 run generation on a loop other than the serving
    # loop; whether the async engine tolerates that depends on vLLM — confirm.
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # Already in async context (unlikely at startup, but handle it)
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor() as pool:
                pool.submit(lambda: asyncio.run(_run_all())).result()
        else:
            loop.run_until_complete(_run_all())
    except RuntimeError:
        # No event loop exists yet
        asyncio.run(_run_all())

    elapsed = (time.time() - t_start) * 1000
    cached = len(runtime.speaker_global_cache)
    total = len(SPEAKER_MAP)
    logger.info(f"Speaker globals cached: {cached}/{total} speakers in {elapsed:.0f}ms")


def _get_long_text_processor(runtime: TTSRuntime):
    """
    Fetch the runtime's LongTextProcessor, building and caching it on first use.

    Args:
        runtime: Runtime whose long_text_processor slot is read and, when it is
            still None, filled with a processor wired to the runtime's pipelines.

    Returns:
        The processor stored on the runtime.
    """
    processor = runtime.long_text_processor
    if processor is not None:
        return processor

    # Deferred import: only needed when the processor was not built at startup.
    from veena3modal.processing.long_text_processor import LongTextProcessor

    processor = LongTextProcessor(
        pipeline=runtime.pipeline,
        streaming_pipeline=runtime.streaming_pipeline,
    )
    runtime.long_text_processor = processor
    return processor


async def generate_speech(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    output_sample_rate: str = "16khz",
) -> Tuple[Optional[bytes], Dict[str, Any]]:
    """
    Generate speech audio (non-streaming).

    Audio is always generated at 16kHz. When "48khz" is requested and a loaded
    super-resolution service is present, the result is upsampled; if upsampling
    fails or is unavailable the 16kHz audio is returned and the metrics say so
    ("sr_applied" / "output_sample_rate").

    Args:
        text: Text to synthesize (already normalized)
        speaker: Internal speaker name (resolved)
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty
        seed: Random seed for reproducibility
        output_sample_rate: "16khz" or "48khz" (triggers super-resolution)

    Returns:
        Tuple of (audio_bytes, metrics_dict)
        audio_bytes: WAV audio bytes (16kHz or 48kHz, 16-bit PCM) or None if failed
        metrics_dict: Dictionary with timing metrics, merged with whatever the
            pipeline's profiler reports

    Raises:
        RuntimeError: If runtime not initialized
        Exception: Any pipeline error is logged and re-raised unchanged
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")

    runtime = get_runtime()
    metrics = {
        "ttfb_ms": 0,
        "generation_ms": 0,
        "tokens_generated": 0,
        "audio_duration_seconds": 0.0,
        "sr_applied": False,
        "output_sample_rate": 16000,
    }

    start_time = time.time()

    try:
        # Generate audio using pipeline (always at 16kHz)
        audio_bytes, perf = await runtime.pipeline.generate_speech_profiled(
            speaker=speaker,
            text=text,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        )
        if perf:
            metrics.update(perf)

        # Wall-clock figures are written after the merge so they win over any
        # same-named keys coming from the pipeline profiler.
        generation_time = time.time() - start_time
        metrics["generation_ms"] = int(generation_time * 1000)
        metrics["ttfb_ms"] = metrics["generation_ms"]  # Non-streaming: TTFB = total time

        if audio_bytes:
            # Apply super-resolution if requested and available
            sample_rate = 16000
            sr_service = runtime.sr_service
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "SR check: output_sample_rate=%r, sr_service=%s, is_loaded=%s",
                    output_sample_rate,
                    sr_service is not None,
                    sr_service.is_loaded if sr_service else False,
                )

            if output_sample_rate == "48khz" and sr_service and sr_service.is_loaded:
                try:
                    sr_start = time.time()
                    logger.info(f"Applying super-resolution to {len(audio_bytes)} bytes...")
                    # NOTE(review): synchronous model inference inside a coroutine;
                    # it occupies the event loop for the duration of the call.
                    audio_bytes = _apply_super_resolution(audio_bytes, sr_service)
                    sr_time = (time.time() - sr_start) * 1000
                    metrics["sr_ms"] = int(sr_time)
                    metrics["sr_applied"] = True
                    sample_rate = 48000
                    logger.info(f"✅ Super-resolution applied in {sr_time:.1f}ms, output={len(audio_bytes)} bytes")
                except Exception as e:
                    # Best-effort: keep the 16kHz audio. exc_info routes the
                    # traceback through logging instead of printing to stderr.
                    logger.warning(f"Super-resolution failed, returning 16kHz: {e}", exc_info=True)

            metrics["output_sample_rate"] = sample_rate

            # Calculate audio duration (16-bit mono => 2 bytes per sample)
            audio_duration = (len(audio_bytes) - 44) / (sample_rate * 2)  # -44 for WAV header
            metrics["audio_duration_seconds"] = max(0.0, audio_duration)

        return audio_bytes, metrics

    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        raise


def _apply_super_resolution(audio_bytes: bytes, sr_service) -> bytes:
    """
    Apply super-resolution to audio bytes.

    The input is treated as a 44-byte WAV header followed by mono 16-bit
    little-endian PCM. The header is not parsed; only its length is checked.

    Args:
        audio_bytes: WAV audio at 16kHz
        sr_service: SuperResolutionService instance; its process_chunk() is
            called with a 1-D float32 tensor scaled to [-1, 1) and must return
            a tensor of upsampled samples (any leading singleton dims allowed)

    Returns:
        WAV audio at 48kHz (fresh 44-byte header + 16-bit PCM). Header-only
        input yields a header-only 48kHz WAV without invoking the model.

    Raises:
        ValueError: If the input is shorter than a WAV header
    """
    import numpy as np
    import struct
    import torch

    header_len = 44  # canonical PCM WAV header size assumed by this module
    if len(audio_bytes) < header_len:
        raise ValueError("Invalid WAV data")

    # Whole 16-bit samples available after the header. A dangling odd byte
    # (truncated stream) is ignored instead of making frombuffer raise.
    num_samples = (len(audio_bytes) - header_len) // 2

    if num_samples == 0:
        # Nothing to upsample; do not feed an empty tensor to the model.
        upsampled_int16 = np.zeros(0, dtype=np.int16)
    else:
        # Zero-copy view of the PCM payload (skip the WAV header)
        pcm_data = np.frombuffer(
            audio_bytes, dtype=np.int16, count=num_samples, offset=header_len
        )

        # Convert to float32 for SR model (astype copies, so the array is writable)
        audio_float = pcm_data.astype(np.float32) / 32768.0

        # Convert numpy to torch tensor for process_chunk
        audio_tensor = torch.from_numpy(audio_float)

        # Apply super-resolution using process_chunk (16kHz -> 48kHz)
        # process_chunk expects [batch, samples] or [samples] and returns [batch, samples]
        upsampled_tensor = sr_service.process_chunk(audio_tensor)

        # Back to a flat numpy array:
        # - detach(): .numpy() raises on tensors that still require grad
        # - reshape(-1): unlike squeeze(), keeps a one-sample result 1-D
        upsampled = upsampled_tensor.detach().reshape(-1).cpu().numpy()

        # Convert back to int16, clamping to the representable range
        upsampled_int16 = np.clip(upsampled * 32768.0, -32768, 32767).astype(np.int16)

    # Create new WAV header for 48kHz
    sample_rate = 48000
    num_channels = 1
    bits_per_sample = 16
    byte_rate = sample_rate * num_channels * bits_per_sample // 8
    block_align = num_channels * bits_per_sample // 8
    data_size = len(upsampled_int16) * 2

    wav_header = struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        36 + data_size,
        b'WAVE',
        b'fmt ',
        16,  # fmt chunk size
        1,   # PCM format
        num_channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b'data',
        data_size,
    )

    return wav_header + upsampled_int16.tobytes()


async def generate_speech_chunked(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    sample_rate: int = 16000,
    output_sample_rate: str = "16khz",
) -> Tuple[Optional[bytes], Dict[str, Any]]:
    """
    Generate speech with automatic text chunking for long inputs.

    Uses LongTextProcessor to split text and stitch audio. Text the processor
    does not consider long is sent straight to the regular pipeline. Optional
    super-resolution is applied once to the final audio; if it fails the 16kHz
    audio is returned and the metrics reflect that.

    Args:
        text: Text to synthesize
        speaker: Internal speaker name
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty
        seed: Random seed for reproducibility
        sample_rate: Internal generation sample rate (always 16000); forwarded
            to the processor on the chunked path only
        output_sample_rate: "16khz" or "48khz" (triggers super-resolution)

    Returns:
        Tuple of (audio_bytes, metrics_dict)

    Raises:
        RuntimeError: If runtime not initialized
        Exception: Any generation error is logged and re-raised unchanged
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")

    runtime = get_runtime()
    # NOTE(review): "chunks_processed" is initialised here but nothing in this
    # function updates it; the processor returns audio only — confirm intent.
    metrics = {
        "ttfb_ms": 0,
        "generation_ms": 0,
        "chunks_processed": 0,
        "audio_duration_seconds": 0.0,
        "text_chunked": False,
        "sr_applied": False,
        "output_sample_rate": 16000,
    }

    start_time = time.time()

    try:
        long_processor = _get_long_text_processor(runtime)

        # Check if chunking is needed
        if long_processor.should_chunk(text):
            metrics["text_chunked"] = True
            audio_bytes = await long_processor.generate_with_chunking(
                text=text,
                speaker=speaker,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
                sample_rate=sample_rate,
            )
        else:
            # Short text: direct generation
            audio_bytes, perf = await runtime.pipeline.generate_speech_profiled(
                speaker=speaker,
                text=text,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
            )
            if perf:
                metrics.update(perf)

        # Non-streaming: the caller gets everything at once, so TTFB == total.
        generation_time = time.time() - start_time
        metrics["generation_ms"] = int(generation_time * 1000)
        metrics["ttfb_ms"] = metrics["generation_ms"]

        if audio_bytes:
            # Apply super-resolution if requested and available
            final_sample_rate = 16000
            sr_service = runtime.sr_service
            if output_sample_rate == "48khz" and sr_service and sr_service.is_loaded:
                try:
                    sr_start = time.time()
                    logger.info(f"Applying super-resolution to chunked audio ({len(audio_bytes)} bytes)...")
                    audio_bytes = _apply_super_resolution(audio_bytes, sr_service)
                    sr_time = (time.time() - sr_start) * 1000
                    metrics["sr_ms"] = int(sr_time)
                    metrics["sr_applied"] = True
                    final_sample_rate = 48000
                    logger.info(f"✅ Super-resolution applied in {sr_time:.1f}ms, output={len(audio_bytes)} bytes")
                except Exception as e:
                    # Best-effort: keep the 16kHz audio. exc_info routes the
                    # traceback through logging instead of printing to stderr.
                    logger.warning(f"Super-resolution failed, returning 16kHz: {e}", exc_info=True)

            metrics["output_sample_rate"] = final_sample_rate
            # 44-byte WAV header, 16-bit mono => 2 bytes per sample
            audio_duration = (len(audio_bytes) - 44) / (final_sample_rate * 2)
            metrics["audio_duration_seconds"] = max(0.0, audio_duration)

        return audio_bytes, metrics

    except Exception as e:
        logger.error(f"Chunked speech generation failed: {e}")
        raise


from typing import AsyncGenerator


async def generate_speech_streaming(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    enable_chunking: bool = True,
) -> AsyncGenerator[Tuple[bytes, Dict[str, Any]], None]:
    """
    Generate speech audio with true streaming (yields chunks as they're generated).

    This is the core streaming implementation for M4.
    First yield includes WAV header + first PCM chunk.
    Subsequent yields are raw PCM chunks.

    Args:
        text: Text to synthesize (already normalized)
        speaker: Internal speaker name (resolved)
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty
        seed: Random seed for reproducibility
        enable_chunking: Enable text chunking for long inputs with voice consistency

    Yields:
        Tuple of (audio_bytes, metrics_dict)
        - First yield: WAV header (data size 0 = unknown length) prepended to first PCM chunk
        - Subsequent yields: Raw PCM chunks (int16, 16kHz)
        - Final yield: metrics_dict has final timing info
        The same metrics dict object is yielded every time and updated in place;
        copy it if a per-chunk snapshot is needed.

    Raises:
        RuntimeError: If runtime not initialized or streaming pipeline unavailable
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")

    runtime = get_runtime()

    if runtime.streaming_pipeline is None:
        raise RuntimeError("Streaming pipeline not available")

    # Import audio utils for WAV header
    from veena3modal.audio.utils import create_wav_header

    # Metrics tracking
    metrics = {
        "ttfb_ms": 0,
        "chunks_sent": 0,
        "total_bytes": 0,
        "audio_duration_seconds": 0.0,
        "text_chunked": False,
    }

    start_time = time.time()
    first_chunk_time = None
    total_pcm_bytes = 0
    header_len = 0  # actual length of the emitted WAV header, set on first chunk
    sample_rate = 16000  # BiCodec sample rate

    def _frame_and_account(pcm_chunk: bytes) -> bytes:
        """Update byte/duration metrics for one PCM chunk; prepend the WAV header to the first."""
        nonlocal first_chunk_time, total_pcm_bytes, header_len

        outgoing = pcm_chunk
        if first_chunk_time is None:
            first_chunk_time = time.time()
            metrics["ttfb_ms"] = int((first_chunk_time - start_time) * 1000)
            # Create streaming WAV header (size=0 for unknown length)
            wav_header = create_wav_header(sample_rate=sample_rate, data_size=0)
            header_len = len(wav_header)
            outgoing = wav_header + pcm_chunk

        # Count PCM before the header is attached. This avoids assuming a
        # 44-byte header and does not depend on metrics["chunks_sent"], which
        # pipeline-provided metrics may overwrite.
        total_pcm_bytes += len(pcm_chunk)
        metrics["chunks_sent"] += 1
        metrics["total_bytes"] = total_pcm_bytes + header_len  # Include header
        metrics["audio_duration_seconds"] = total_pcm_bytes / (sample_rate * 2)
        return outgoing

    # Check if we need text chunking
    long_processor = _get_long_text_processor(runtime)
    needs_chunking = enable_chunking and long_processor.should_chunk(text)

    if needs_chunking:
        # Chunked streaming with voice consistency (global token caching)
        metrics["text_chunked"] = True
        # The helper's per-chunk metrics are not merged into the returned metrics.
        async for audio_chunk, _chunk_metrics in _stream_chunked_text(
            runtime, long_processor, text, speaker, temperature, top_k, top_p,
            max_tokens, repetition_penalty, seed, sample_rate
        ):
            yield _frame_and_account(audio_chunk), metrics
    else:
        # Simple streaming (no text chunking)
        # OPTIMIZATION: If we have pre-cached global tokens for this speaker,
        # use continuation mode to skip the ~110ms global token pre-roll.
        # The model jumps straight to semantic token generation.
        cached_globals = runtime.speaker_global_cache.get(speaker)

        if cached_globals:
            # Fast path: use cached globals, skip global token generation entirely
            stream_gen = runtime.streaming_pipeline.generate_speech_stream_indic_continuation(
                speaker=speaker,
                text=text,
                global_ids=cached_globals,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
                emit_progress=True,
            )
        else:
            # Fallback: no cached globals, generate them inline (original path)
            stream_gen = runtime.streaming_pipeline.generate_speech_stream_indic(
                speaker=speaker,
                text=text,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
                emit_progress=True,
            )

        async for stream_item in stream_gen:
            # Items are either bare audio bytes or an (audio, metrics_dict) pair.
            stream_metrics: Dict[str, Any] = {}
            audio_chunk = stream_item
            if (
                isinstance(stream_item, tuple)
                and len(stream_item) == 2
                and isinstance(stream_item[1], dict)
            ):
                audio_chunk, stream_metrics = stream_item

            if stream_metrics:
                metrics.update(stream_metrics)

            if not audio_chunk:
                # Metrics-only event from streaming pipeline.
                continue

            yield _frame_and_account(audio_chunk), metrics


async def _stream_chunked_text(
    runtime: TTSRuntime,
    long_processor,
    text: str,
    speaker: str,
    temperature: float,
    top_k: int,
    top_p: float,
    max_tokens: int,
    repetition_penalty: float,
    seed: Optional[int],
    sample_rate: int,
) -> AsyncGenerator[Tuple[bytes, Dict[str, Any]], None]:
    """
    Internal helper: stream chunked text with voice consistency.

    The text is split with ``long_processor.chunk_text`` and each text chunk is
    streamed in order through ``runtime.streaming_pipeline``:

    - The first chunk uses ``generate_speech_stream_indic_first_chunk``, which
      yields ``(audio_bytes, global_ids)`` pairs. The first non-empty
      ``global_ids`` seen is captured.
    - Later chunks use ``generate_speech_stream_indic_continuation`` with the
      captured ``global_ids``. If nothing was captured, a warning is logged and
      the chunk falls back to ``generate_speech_stream_indic``.

    Every audio piece is passed through ``crossfade_bytes_int16`` together with
    the tail held back from the previous piece; whatever tail remains after the
    last chunk is flushed at the end. This function adds no header of its own
    to the emitted bytes.

    Args:
        runtime: Runtime whose ``streaming_pipeline`` produces the audio.
        long_processor: Object providing ``chunk_text(text)``.
        text: Full input text to split and synthesize.
        speaker: Speaker name forwarded to the streaming pipeline.
        temperature, top_k, top_p, max_tokens, repetition_penalty, seed:
            Sampling settings forwarded unchanged to every pipeline call.
        sample_rate: Passed to the crossfade helper as ``sample_rate_hz``.

    Yields:
        ``(audio_bytes, chunk_metrics)`` tuples. ``chunk_metrics`` is a single
        dict that is mutated in place and re-yielded each time; its
        ``"chunks_processed"`` entry is the 1-based index of the text chunk
        currently being streamed. Nothing is yielded when the chunker returns
        no chunks.
    """
    from veena3modal.audio.crossfade import crossfade_bytes_int16

    # Chunk text using the long text processor's chunker
    chunks = long_processor.chunk_text(text)
    if not chunks:
        return

    # Crossfade length handed to the stitching helper for every audio piece.
    crossfade_ms = 50

    # Sampling settings are identical for every chunk; build them once.
    sampling_kwargs: Dict[str, Any] = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "repetition_penalty": repetition_penalty,
        "seed": seed,
    }

    captured_globals: Optional[List[int]] = None
    previous_chunk_tail: Optional[bytes] = None
    chunk_metrics = {"chunks_processed": 0}

    for i, chunk_text in enumerate(chunks):
        chunk_metrics["chunks_processed"] = i + 1
        is_first_chunk = i == 0
        pipeline = runtime.streaming_pipeline

        # Pick the stream for this chunk; the consume/stitch loop below is shared.
        if is_first_chunk:
            # First chunk: this stream also reports global tokens for voice consistency
            stream = pipeline.generate_speech_stream_indic_first_chunk(
                speaker=speaker,
                text=chunk_text,
                **sampling_kwargs,
            )
        elif captured_globals is None:
            logger.warning(
                "No captured globals for chunk %d, using regular streaming", i + 1
            )
            stream = pipeline.generate_speech_stream_indic(
                speaker=speaker,
                text=chunk_text,
                **sampling_kwargs,
            )
        else:
            # Continuation chunks: reuse captured globals for voice consistency
            stream = pipeline.generate_speech_stream_indic_continuation(
                speaker=speaker,
                text=chunk_text,
                global_ids=captured_globals,
                **sampling_kwargs,
            )

        async for item in stream:
            if is_first_chunk:
                # Only the first-chunk stream yields (audio, global_ids) pairs.
                audio_bytes, global_ids = item
                if captured_globals is None and global_ids:
                    captured_globals = global_ids
            else:
                audio_bytes = item

            # Stitch against the tail held back from the previous piece; the
            # helper returns what is safe to emit now plus the new tail.
            to_emit, previous_chunk_tail = crossfade_bytes_int16(
                previous_chunk_tail,
                audio_bytes,
                sample_rate_hz=sample_rate,
                crossfade_ms=crossfade_ms,
            )
            if to_emit:
                yield to_emit, chunk_metrics

    # Flush remaining tail
    if previous_chunk_tail:
        yield previous_chunk_tail, chunk_metrics
