"""
TTS runtime wrapper (container-scoped singleton).

In Modal, each container should load the model once and then handle many requests.
This wrapper owns:
- SparkTTSModel (vLLM engine)
- IndicPromptBuilder
- BiCodecDecoder
- SparkTTSPipeline + Veena3SlidingWindowPipeline
- optional SuperResolutionService

Imports framework-agnostic inference + processing code from `veena3modal/core`,
`veena3modal/processing`, and `veena3modal/audio`.
"""

from __future__ import annotations

import os
import sys
import time
import logging
import threading
import asyncio
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


def _env_int(name: str, default: int, minimum: int = 0) -> int:
    """
    Read an integer setting from the environment.

    Returns ``default`` when the variable is unset, empty, or not parseable as
    an ``int``; otherwise returns the parsed value clamped to ``minimum``.
    """
    text = os.environ.get(name, "")
    if text:
        try:
            parsed = int(text)
        except ValueError:
            parsed = None
        if parsed is not None:
            return max(minimum, parsed)
    return default


def _env_float(name: str, default: float, minimum: float = 0.0) -> float:
    """
    Read a float setting from the environment.

    Returns ``default`` when the variable is unset, empty, not parseable as a
    float, or non-finite (``inf``/``nan``); otherwise returns the parsed value
    clamped to ``minimum``.

    Non-finite values are rejected because these settings feed timeouts and
    poll intervals: ``inf`` would make ``asyncio.sleep`` block forever, and
    ``nan`` would otherwise be silently clamped to ``minimum``.
    """
    import math

    raw = os.environ.get(name, "")
    if not raw:
        return default
    try:
        value = float(raw)
    except ValueError:
        return default
    if not math.isfinite(value):
        return default
    return max(minimum, value)


@dataclass
class StreamingAdmissionLease:
    """
    Admission ticket for one active streaming request.

    Returned by ``StreamingAdmissionController.acquire`` and handed back to
    ``StreamingAdmissionController.release`` when the stream finishes.
    """

    # Milliseconds between calling acquire() and being granted the slot.
    wait_ms: float = 0.0
    # True if the request had to enter the wait queue before admission.
    queued: bool = False
    # In-flight stream count (including this one) right after admission.
    inflight_on_grant: int = 0
    # Requests still waiting in the queue right after admission.
    waiters_on_grant: int = 0
    # Queue depth (including this request) when it entered the queue; 0 if never queued.
    queue_depth_on_entry: int = 0
    # Set by release() so that releasing the same lease twice is a no-op.
    released: bool = False


class StreamingAdmissionError(RuntimeError):
    """
    Raised when streaming admission control cannot accept a request.

    Attributes:
        reason: Machine-readable cause, ``"queue_full"`` or ``"wait_timeout"``.
        retry_after_ms: Suggested client back-off in milliseconds (never negative).
        snapshot: Controller counters captured at rejection time; ``{}`` if
            none were supplied.
    """

    def __init__(
        self,
        message: str,
        *,
        reason: str,
        retry_after_ms: float = 0.0,
        snapshot: Optional[Dict[str, float]] = None,
    ):
        super().__init__(message)
        self.reason = reason
        self.retry_after_ms = float(max(0.0, retry_after_ms))
        self.snapshot = snapshot or {}


class StreamingAdmissionController:
    """
    Lightweight admission controller for streaming requests.

    Keeps a hard cap on active streams (`max_inflight`) and an optional bounded
    wait queue (`max_queue`) with per-request wait timeout (`max_wait_ms`).

    - ``max_inflight == 0`` disables the cap: every request is admitted at once
      (in-flight streams are still counted).
    - ``max_queue == 0`` means no waiting: a request that cannot be admitted
      immediately is rejected with reason ``"queue_full"``.
    - ``max_wait_ms == 0`` means queued requests wait without a deadline.
    - Queued requests poll every ``poll_ms`` milliseconds, so admission order
      among waiters is not guaranteed to be FIFO.
    """

    def __init__(
        self,
        max_inflight: int,
        max_queue: int,
        max_wait_ms: float,
        poll_ms: float,
    ):
        self.max_inflight = int(max(0, max_inflight))
        self.max_queue = int(max(0, max_queue))
        self.max_wait_ms = float(max(0.0, max_wait_ms))
        self.poll_ms = float(max(0.5, poll_ms))
        self.enabled = self.max_inflight > 0

        # Guards every counter below. It is only held for short sections that
        # never await, so taking it from coroutines does not stall the loop.
        self._lock = threading.Lock()
        self._inflight = 0
        self._waiters = 0
        self._total_admitted = 0
        self._total_queue_entries = 0
        # Total queue time of requests that have left the queue
        # (admitted, timed out, or cancelled).
        self._total_queue_wait_ms = 0.0
        self._total_rejected_queue_full = 0
        self._total_rejected_timeout = 0

    @classmethod
    def from_env(cls) -> "StreamingAdmissionController":
        """Build a controller from the ``VEENA3_STREAM_ADMISSION_*`` env vars (disabled by default)."""
        return cls(
            max_inflight=_env_int("VEENA3_STREAM_ADMISSION_MAX_INFLIGHT", 0, minimum=0),
            max_queue=_env_int("VEENA3_STREAM_ADMISSION_MAX_QUEUE", 0, minimum=0),
            max_wait_ms=_env_float("VEENA3_STREAM_ADMISSION_MAX_WAIT_MS", 0.0, minimum=0.0),
            poll_ms=_env_float("VEENA3_STREAM_ADMISSION_POLL_MS", 2.0, minimum=0.5),
        )

    def _snapshot_locked(self) -> Dict[str, float]:
        """Return configuration, live state, and counters; caller must hold ``self._lock``."""
        queue_avg = (
            self._total_queue_wait_ms / self._total_queue_entries
            if self._total_queue_entries > 0
            else 0.0
        )
        return {
            "enabled": 1.0 if self.enabled else 0.0,
            "max_inflight": float(self.max_inflight),
            "max_queue": float(self.max_queue),
            "max_wait_ms": float(self.max_wait_ms),
            "poll_ms": float(self.poll_ms),
            "inflight_now": float(self._inflight),
            "waiters_now": float(self._waiters),
            "admitted_total": float(self._total_admitted),
            "queue_entries_total": float(self._total_queue_entries),
            "queue_wait_ms_avg": float(queue_avg),
            "rejected_queue_full_total": float(self._total_rejected_queue_full),
            "rejected_timeout_total": float(self._total_rejected_timeout),
        }

    def snapshot(self) -> Dict[str, float]:
        """Thread-safe copy of the controller's configuration, live state, and counters."""
        with self._lock:
            return self._snapshot_locked()

    def _leave_queue_locked(self, wait_ms: float) -> None:
        """Drop one waiter and record its time in the queue; caller must hold ``self._lock``."""
        self._waiters = max(0, self._waiters - 1)
        self._total_queue_wait_ms += max(0.0, wait_ms)

    async def acquire(self) -> StreamingAdmissionLease:
        """
        Wait for a streaming slot and return a lease for it.

        Raises:
            StreamingAdmissionError: reason ``"queue_full"`` when the wait queue
                already holds ``max_queue`` requests; reason ``"wait_timeout"``
                when no slot frees up within ``max_wait_ms``.
            asyncio.CancelledError: if the awaiting task is cancelled; any
                queue position held by this request is returned first.
        """
        start = time.perf_counter()
        queued = False
        # True while this request occupies one unit of ``self._waiters``. The
        # ``finally`` clause uses it to hand the position back when the loop is
        # left abnormally (e.g. task cancellation during ``asyncio.sleep``);
        # otherwise abandoned waiters would permanently shrink the queue.
        holding_queue_slot = False
        queue_depth_on_entry = 0
        deadline = (
            start + (self.max_wait_ms / 1000.0)
            if self.max_wait_ms > 0.0
            else None
        )

        try:
            while True:
                now = time.perf_counter()
                with self._lock:
                    if (not self.enabled) or (self._inflight < self.max_inflight):
                        wait_ms = (now - start) * 1000.0
                        self._inflight += 1
                        if holding_queue_slot:
                            self._leave_queue_locked(wait_ms)
                            holding_queue_slot = False
                        self._total_admitted += 1
                        return StreamingAdmissionLease(
                            wait_ms=float(wait_ms),
                            queued=bool(queued),
                            inflight_on_grant=int(self._inflight),
                            waiters_on_grant=int(self._waiters),
                            queue_depth_on_entry=int(queue_depth_on_entry),
                        )

                    if not queued:
                        if self._waiters >= self.max_queue:
                            self._total_rejected_queue_full += 1
                            snapshot = self._snapshot_locked()
                            raise StreamingAdmissionError(
                                "streaming admission queue full",
                                reason="queue_full",
                                retry_after_ms=self.poll_ms,
                                snapshot=snapshot,
                            )
                        queued = True
                        holding_queue_slot = True
                        self._waiters += 1
                        queue_depth_on_entry = self._waiters
                        self._total_queue_entries += 1

                if deadline is not None and now >= deadline:
                    with self._lock:
                        if holding_queue_slot:
                            self._leave_queue_locked((now - start) * 1000.0)
                            holding_queue_slot = False
                        self._total_rejected_timeout += 1
                        snapshot = self._snapshot_locked()
                    raise StreamingAdmissionError(
                        "streaming admission wait timeout",
                        reason="wait_timeout",
                        retry_after_ms=self.poll_ms,
                        snapshot=snapshot,
                    )

                await asyncio.sleep(self.poll_ms / 1000.0)
        finally:
            if holding_queue_slot:
                with self._lock:
                    self._leave_queue_locked((time.perf_counter() - start) * 1000.0)

    def release(self, lease: Optional[StreamingAdmissionLease]) -> None:
        """Return a lease's slot. ``None`` and repeated releases are no-ops."""
        if lease is None:
            return
        with self._lock:
            # Check-and-set under the lock so two concurrent releases of the same
            # lease cannot decrement the in-flight count twice.
            if lease.released:
                return
            lease.released = True
            self._inflight = max(0, self._inflight - 1)


@dataclass
class TTSRuntime:
    """
    Holds long-lived, per-container inference objects.

    initialize_runtime() builds one of these and stores it in the module-level
    ``_runtime`` singleton; request handlers obtain it via get_runtime().
    
    Thread Safety:
        This class is designed for async concurrency within a single container.
        The vLLM engine handles internal batching and scheduling.
        Do NOT share across processes without proper synchronization.
    """
    # Inference components; all default to None until initialize_runtime() fills them in.
    model: Any = None  # SparkTTSModel, or the multi-engine model when num_engines > 1
    pipeline: Any = None  # SparkTTSPipeline used for non-streaming generation
    streaming_pipeline: Any = None  # Veena3SlidingWindowPipeline; stays None if its import fails
    long_text_processor: Any = None  # LongTextProcessor; _get_long_text_processor() creates it lazily if None
    prompt_builder: Any = None  # IndicPromptBuilder
    bicodec_decoder: Any = None  # BiCodecDecoder
    sr_service: Optional[Any] = None  # SuperResolutionService, or None when SR is disabled/failed to load
    
    # Identification / readiness; is_initialized() checks is_loaded.
    model_version: str = "not_loaded"
    is_loaded: bool = False
    load_time_ms: float = 0.0
    
    # Configuration
    model_path: str = ""
    bicodec_path: str = ""
    sr_checkpoint_dir: Optional[str] = None
    device: str = "cuda"
    
    # OPTIMIZATION: Pre-computed global tokens per speaker (eliminates ~110ms pre-roll in streaming)
    # Maps speaker_name -> list of 32 global token IDs captured at startup
    speaker_global_cache: Dict[str, List[int]] = field(default_factory=dict)
    # Streaming admission controller; created from env by initialize_runtime(),
    # or lazily by _get_stream_admission_controller() if still None.
    stream_admission: Optional[StreamingAdmissionController] = None


# Module-level singleton (per-container). None until initialize_runtime()
# assigns it; read through get_runtime() / is_initialized().
_runtime: Optional[TTSRuntime] = None


def get_runtime() -> Optional[TTSRuntime]:
    """Return the container's TTS runtime singleton, or ``None`` if it has not been created."""
    current = _runtime
    return current


def is_initialized() -> bool:
    """Report whether the runtime singleton exists and has finished loading."""
    current = _runtime
    if current is None:
        return False
    return current.is_loaded


def _get_stream_admission_controller(runtime: Optional[TTSRuntime]) -> StreamingAdmissionController:
    """
    Fetch the runtime's admission controller, creating one from env on first use.

    Raises:
        RuntimeError: if ``runtime`` is ``None``.
    """
    if runtime is None:
        raise RuntimeError("TTS runtime not initialized")
    existing = runtime.stream_admission
    if existing is not None:
        return existing
    created = StreamingAdmissionController.from_env()
    runtime.stream_admission = created
    return created


def get_streaming_admission_snapshot() -> Dict[str, float]:
    """
    Report the admission controller's live state plus its lifetime counters.

    Before the runtime exists this returns ``{}``. If the runtime exists but
    has no controller yet, one is created from the environment first.
    """
    runtime = get_runtime()
    if runtime is not None:
        controller = _get_stream_admission_controller(runtime)
        return controller.snapshot()
    return {}


async def acquire_streaming_slot() -> StreamingAdmissionLease:
    """
    Reserve one streaming slot for the caller.

    Raises:
        RuntimeError: if the runtime is not initialized.
        StreamingAdmissionError: if admission control rejects the request.
    """
    if is_initialized():
        controller = _get_stream_admission_controller(get_runtime())
        return await controller.acquire()
    raise RuntimeError("TTS runtime not initialized")


def release_streaming_slot(lease: Optional[StreamingAdmissionLease]) -> None:
    """
    Give a streaming slot back (idempotent; ``None`` is accepted).

    Silently does nothing when the runtime singleton does not exist.
    """
    runtime = get_runtime()
    if runtime is not None:
        _get_stream_admission_controller(runtime).release(lease)


def initialize_runtime(
    model_path: Optional[str] = None,
    bicodec_path: Optional[str] = None,
    sr_checkpoint_dir: Optional[str] = None,
    device: str = "cuda",
    hf_token: Optional[str] = None,
    gpu_memory_utilization: float = 0.25,
    enable_sr: bool = False,
    num_engines: int = 1,
    max_num_batched_tokens: Optional[int] = None,
    max_num_seqs: Optional[int] = None,
    enable_chunked_prefill: Optional[bool] = None,
    enable_prefix_caching: Optional[bool] = None,
    disable_log_stats: Optional[bool] = None,
    enforce_eager: Optional[bool] = None,
    precompute_speaker_globals: Optional[bool] = None,
) -> TTSRuntime:
    """
    Initialize the TTS runtime with all components and install it as the
    module-level singleton returned by get_runtime().
    
    This should be called once per container (e.g., in Modal's @modal.enter).
    Calling it again builds fresh components and replaces the previous singleton.
    
    Args:
        model_path: Path to Spark TTS model (env: SPARK_TTS_MODEL_PATH, then
            MODEL_PATH, else '/models/spark_tts_4speaker')
        bicodec_path: Path to BiCodec model (env: BICODEC_MODEL_PATH, defaults to model_path)
        sr_checkpoint_dir: Path to super-resolution checkpoints (env: AP_BWE_CHECKPOINT_DIR)
        device: Device for BiCodec decoding and super-resolution (cuda/cpu)
        hf_token: HuggingFace token for private models (env: HF_TOKEN, then
            HUGGING_FACE_HUB_TOKEN)
        gpu_memory_utilization: vLLM GPU memory fraction (default: 0.25). With
            num_engines > 1 it is split evenly across the engines.
        enable_sr: Enable super-resolution service (only takes effect when a
            checkpoint dir is also available)
        num_engines: Number of vLLM engine instances (default: 1, set to 2-3 for Tier 3 optimization)
        max_num_batched_tokens: Optional vLLM scheduler cap override (ignored
            unless a positive int).
        max_num_seqs: Optional vLLM concurrent sequence cap override (ignored
            unless a positive int).
        enable_chunked_prefill: Optional vLLM chunked prefill toggle.
        enable_prefix_caching: Optional vLLM prefix caching toggle.
        disable_log_stats: Optional vLLM internal stats log toggle.
        enforce_eager: Optional vLLM eager-mode toggle (disables CUDA graphs when true).
        precompute_speaker_globals: Whether to warm speaker global tokens at startup.
            If None, reads PRECOMPUTE_SPEAKER_GLOBALS env var (default: false).
            Ignored (skipped) when num_engines > 1.
    
    Returns:
        Initialized TTSRuntime instance (also stored as the module singleton)
    
    Raises:
        RuntimeError: If initialization fails; the original exception is chained
            as ``__cause__``.
    """
    global _runtime
    
    start_time = time.time()
    
    # Resolve paths from env vars if not provided.
    # Explicit arguments win; falsy arguments (None or "") fall back to env.
    model_path = model_path or os.environ.get(
        'SPARK_TTS_MODEL_PATH',
        os.environ.get('MODEL_PATH', '/models/spark_tts_4speaker')
    )
    bicodec_path = bicodec_path or os.environ.get(
        'BICODEC_MODEL_PATH',
        model_path  # BiCodec is usually in the same directory
    )
    sr_checkpoint_dir = sr_checkpoint_dir or os.environ.get('AP_BWE_CHECKPOINT_DIR')
    hf_token = hf_token or os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    if precompute_speaker_globals is None:
        raw_flag = os.environ.get("PRECOMPUTE_SPEAKER_GLOBALS", "false").strip().lower()
        precompute_speaker_globals = raw_flag in {"1", "true", "yes", "on"}
    
    logger.info(f"Initializing TTS runtime...")
    logger.info(f"  Model path: {model_path}")
    logger.info(f"  BiCodec path: {bicodec_path}")
    logger.info(f"  SR enabled: {enable_sr}, path: {sr_checkpoint_dir}")
    
    try:
        # Import inference components (vendored into veena3modal for Modal deployment)
        from veena3modal.core.model_loader import SparkTTSModel
        from veena3modal.core.pipeline import SparkTTSPipeline
        from veena3modal.core.bicodec_decoder import BiCodecDecoder
        from veena3modal.processing.prompt_builder import IndicPromptBuilder
        from veena3modal.processing.long_text_processor import LongTextProcessor

        # Forward only the engine overrides the caller actually set, so the
        # engine constructor keeps its own defaults for everything else.
        engine_kwargs: Dict[str, Any] = {}
        if isinstance(max_num_batched_tokens, int) and max_num_batched_tokens > 0:
            engine_kwargs["max_num_batched_tokens"] = max_num_batched_tokens
        if isinstance(max_num_seqs, int) and max_num_seqs > 0:
            engine_kwargs["max_num_seqs"] = max_num_seqs
        if enable_chunked_prefill is not None:
            engine_kwargs["enable_chunked_prefill"] = bool(enable_chunked_prefill)
        if enable_prefix_caching is not None:
            engine_kwargs["enable_prefix_caching"] = bool(enable_prefix_caching)
        if disable_log_stats is not None:
            engine_kwargs["disable_log_stats"] = bool(disable_log_stats)
        if enforce_eager is not None:
            engine_kwargs["enforce_eager"] = bool(enforce_eager)
        
        # Load SparkTTS model with vLLM engine
        if num_engines > 1:
            # TIER 3 OPTIMIZATION: Multiple vLLM engines on same GPU
            # Each engine gets gpu_memory_utilization / num_engines to share GPU fairly.
            # E.g., 3 engines at 0.08 each = 0.24 total, ~8GB per engine, ~24GB total.
            from veena3modal.core.multi_engine import create_multi_engine_model
            per_engine_mem = gpu_memory_utilization / num_engines
            logger.info(f"Loading {num_engines} vLLM engines ({per_engine_mem:.2f} GPU mem each)...")
            model = create_multi_engine_model(
                model_path=model_path,
                num_engines=num_engines,
                hf_token=hf_token,
                gpu_memory_per_engine=per_engine_mem,
                **engine_kwargs,
            )
        else:
            # num_engines <= 1 (including 0 or negative) falls through to a single engine.
            logger.info("Loading SparkTTS model with vLLM...")
            model = SparkTTSModel(
                model_path=model_path,
                hf_token=hf_token,
                gpu_memory_utilization=gpu_memory_utilization,
                **engine_kwargs,
            )
        
        # Initialize prompt builder
        logger.info("Initializing prompt builder...")
        prompt_builder = IndicPromptBuilder(
            tokenizer=model.tokenizer,
            model=model,
        )
        
        # Initialize BiCodec decoder.
        # Batching is on unless VEENA3_BICODEC_BATCHING is set to something
        # other than 1/true/yes/on.
        logger.info("Initializing BiCodec decoder...")
        batch_raw = os.environ.get("VEENA3_BICODEC_BATCHING", "true").strip().lower()
        enable_bicodec_batching = batch_raw in {"1", "true", "yes", "on"}
        bicodec_decoder = BiCodecDecoder(
            device=device,
            model_path=bicodec_path,
            enable_batching=enable_bicodec_batching,
        )
        
        # Initialize pipeline
        logger.info("Initializing TTS pipeline...")
        pipeline = SparkTTSPipeline(
            model=model,
            prompt_builder=prompt_builder,
            bicodec_decoder=bicodec_decoder,
        )
        
        # Initialize streaming pipeline (for M4).
        # Optional: an ImportError only logs a warning and leaves it as None.
        streaming_pipeline = None
        try:
            from veena3modal.core.streaming_pipeline import Veena3SlidingWindowPipeline
            logger.info("Initializing streaming pipeline...")
            # NOTE: Parameter is named 'snac_decoder' for legacy reasons, but works with BiCodecDecoder
            streaming_pipeline = Veena3SlidingWindowPipeline(
                model=model,
                prompt_builder=prompt_builder,
                snac_decoder=bicodec_decoder,  # BiCodecDecoder is interface-compatible
            )
        except ImportError as e:
            logger.warning(f"Streaming pipeline not available: {e}")

        # Initialize long text processor once (avoid per-request construction/logging)
        long_text_processor = LongTextProcessor(
            pipeline=pipeline,
            streaming_pipeline=streaming_pipeline,
        )
        
        # Initialize super-resolution (optional).
        # Any failure here is downgraded to a warning and SR stays disabled.
        sr_service = None
        if enable_sr and sr_checkpoint_dir:
            try:
                from veena3modal.core.super_resolution import SuperResolutionService
                logger.info(f"Initializing super-resolution from {sr_checkpoint_dir}...")
                sr_service = SuperResolutionService(checkpoint_dir=sr_checkpoint_dir)
                # Load the model explicitly
                if sr_service.load_model(device=device):
                    logger.info("✅ Super-resolution model loaded successfully")
                else:
                    logger.warning("Super-resolution model failed to load")
                    sr_service = None
            except Exception as e:
                logger.warning(f"Super-resolution not available: {e}")
        
        # Determine model version (last path component of model_path)
        model_version = os.path.basename(model_path.rstrip('/'))
        if not model_version:
            model_version = "spark-tts"
        
        # load_time excludes the optional speaker-globals precompute below.
        load_time = (time.time() - start_time) * 1000
        stream_admission = StreamingAdmissionController.from_env()
        
        # Create runtime.
        # NOTE(review): the singleton is installed here, before the speaker
        # precompute and set_model_version steps. If either of those raises
        # (set_model_version: anything other than ImportError), this function
        # raises RuntimeError but leaves _runtime with is_loaded=True.
        _runtime = TTSRuntime(
            model=model,
            pipeline=pipeline,
            streaming_pipeline=streaming_pipeline,
            long_text_processor=long_text_processor,
            prompt_builder=prompt_builder,
            bicodec_decoder=bicodec_decoder,
            sr_service=sr_service,
            model_version=model_version,
            is_loaded=True,
            load_time_ms=load_time,
            model_path=model_path,
            bicodec_path=bicodec_path,
            sr_checkpoint_dir=sr_checkpoint_dir,
            device=device,
            stream_admission=stream_admission,
        )
        
        logger.info(f"✅ TTS runtime initialized in {load_time:.0f}ms")
        logger.info(f"   Model version: {model_version}")
        if stream_admission.enabled:
            logger.info(
                "   Streaming admission: inflight<=%d queue<=%d wait<=%.1fms poll=%.1fms",
                stream_admission.max_inflight,
                stream_admission.max_queue,
                stream_admission.max_wait_ms,
                stream_admission.poll_ms,
            )
        else:
            logger.info("   Streaming admission: disabled")
        
        # Pre-compute speaker globals only when explicitly enabled.
        # Startup warmup runs async generation and can bind AsyncLLMEngine state
        # to a non-serving event loop if executed in a different loop context.
        # (The "multi-engine mode" message is logged whenever num_engines > 1,
        # regardless of whether precompute was requested.)
        if precompute_speaker_globals and num_engines <= 1:
            logger.info("Speaker globals startup precompute enabled")
            _precompute_speaker_globals(_runtime)
        elif num_engines > 1:
            logger.info("Speaker globals cache startup precompute skipped (multi-engine mode)")
        else:
            logger.info("Speaker globals cache startup precompute skipped (disabled)")
        
        # Update FastAPI app with model version (optional; skipped if the API
        # module cannot be imported)
        try:
            from veena3modal.api.fastapi_app import set_model_version
            set_model_version(model_version)
        except ImportError:
            pass
        
        return _runtime
        
    except Exception as e:
        # Log, print the traceback to stderr, then re-raise as RuntimeError
        # chained to the original cause.
        logger.error(f"❌ Failed to initialize TTS runtime: {e}")
        import traceback
        traceback.print_exc()
        raise RuntimeError(f"TTS runtime initialization failed: {e}") from e


def _precompute_speaker_globals(runtime: TTSRuntime) -> None:
    """
    Pre-compute and cache the 32 BiCodec global tokens for every speaker in
    SPEAKER_MAP, storing them in ``runtime.speaker_global_cache``.
    
    OPTIMIZATION: BiCodec streaming requires 32 "global tokens" before any audio can be emitted.
    These encode speaker identity. Because the speaker set is fixed, this
    generates one short utterance per speaker, extracts the first 32
    ``<|bicodec_global_N|>`` tokens from the generated text, and caches them so
    streaming can inject them via build_prefix_with_globals() instead of paying
    the global-token pre-roll (reported as ~110ms) on every request.
    
    Runs synchronously and blocks the caller until every speaker has been
    attempted (reportedly ~5-10s of cold start). Per-speaker failures are logged
    as warnings and that speaker is simply left out of the cache.
    """
    import asyncio
    import re
    from vllm import SamplingParams
    from veena3modal.core.constants import (
        SPEAKER_MAP, TRAINING_STOP_TOKEN_IDS,
        DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P,
    )
    
    logger.info("Pre-computing global tokens for all speakers...")
    t_start = time.time()
    
    test_text = "Hello, this is a voice test."
    
    # NOTE(review): vLLM's SamplingParams `stop` takes stop *strings*; integer
    # token IDs belong in `stop_token_ids`. Confirm the type of
    # TRAINING_STOP_TOKEN_IDS matches the parameter it is passed to.
    sampling_params = SamplingParams(
        temperature=DEFAULT_TEMPERATURE,
        top_k=DEFAULT_TOP_K,
        top_p=DEFAULT_TOP_P,
        max_tokens=128,  # Only need ~60 tokens (32 global + some semantic)
        stop=TRAINING_STOP_TOKEN_IDS,
        skip_special_tokens=False,  # keep <|bicodec_global_N|> markers in the output text
    )
    
    async def _capture_globals_for_speaker(speaker_name: str) -> Optional[List[int]]:
        """Generate a short utterance; return its first 32 global token IDs, or None if fewer."""
        prompt = runtime.prompt_builder.build_prefix(speaker_name, test_text)
        request_id = f"warmup-{speaker_name}-{int(time.time() * 1000)}"
        
        # Drain the async stream; only the last RequestOutput is inspected.
        final_output = None
        async for request_output in runtime.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        ):
            final_output = request_output
        
        if final_output is None:
            return None
        
        generated_text = final_output.outputs[0].text
        global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", generated_text)
        global_ids = [int(t) for t in global_matches]
        
        if len(global_ids) >= 32:
            return global_ids[:32]
        return None
    
    async def _run_all():
        # Speakers are processed one after another, not concurrently.
        for speaker_name in SPEAKER_MAP.keys():
            try:
                global_ids = await _capture_globals_for_speaker(speaker_name)
                if global_ids and len(global_ids) == 32:
                    runtime.speaker_global_cache[speaker_name] = global_ids
                    logger.info(f"  Cached globals for {speaker_name}")
                else:
                    logger.warning(f"  Failed to capture globals for {speaker_name}")
            except Exception as e:
                logger.warning(f"  Error caching globals for {speaker_name}: {e}")
    
    # Run async warmup in event loop.
    # Preference order: reuse the thread's current (non-running) loop, so the
    # engine stays bound to it; otherwise fall back to a fresh asyncio.run loop.
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # Already in async context (unlikely at startup, but handle it).
            # NOTE(review): this blocks the running loop until warmup finishes,
            # and runs generation on a separate loop in a worker thread.
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor() as pool:
                pool.submit(lambda: asyncio.run(_run_all())).result()
        else:
            loop.run_until_complete(_run_all())
    except RuntimeError:
        # No usable event loop (none set for this thread, or
        # run_until_complete refused, e.g. a closed loop)
        asyncio.run(_run_all())
    
    elapsed = (time.time() - t_start) * 1000
    cached = len(runtime.speaker_global_cache)
    total = len(SPEAKER_MAP)
    logger.info(f"Speaker globals cached: {cached}/{total} speakers in {elapsed:.0f}ms")


def _get_long_text_processor(runtime: TTSRuntime):
    """
    Return the runtime's LongTextProcessor, building and caching it on first use.

    initialize_runtime() normally creates the processor up front; this lazy
    path only runs for a runtime whose ``long_text_processor`` is still None.
    """
    if runtime.long_text_processor is not None:
        return runtime.long_text_processor

    from veena3modal.processing.long_text_processor import LongTextProcessor

    runtime.long_text_processor = LongTextProcessor(
        pipeline=runtime.pipeline,
        streaming_pipeline=runtime.streaming_pipeline,
    )
    return runtime.long_text_processor


async def generate_speech(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    output_sample_rate: str = "16khz",
) -> Tuple[Optional[bytes], Dict[str, Any]]:
    """
    Generate speech audio (non-streaming).
    
    Args:
        text: Text to synthesize (already normalized)
        speaker: Internal speaker name (resolved)
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty
        seed: Random seed for reproducibility
        output_sample_rate: "16khz" or "48khz". "48khz" applies super-resolution
            when the SR service is loaded; otherwise 16 kHz audio is returned.
    
    Returns:
        Tuple of (audio_bytes, metrics_dict)
        audio_bytes: WAV audio bytes (16kHz or 48kHz, 16-bit PCM) or None/empty if
            the pipeline produced nothing
        metrics_dict: Timing metrics (ttfb_ms, generation_ms, tokens_generated,
            audio_duration_seconds, sr_applied, output_sample_rate, optional
            sr_ms) merged with whatever the pipeline profiler reports
    
    Raises:
        RuntimeError: If runtime not initialized
        Exception: Any pipeline error is logged and re-raised unchanged.
            Super-resolution errors are NOT raised; they fall back to 16 kHz.
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")
    
    runtime = get_runtime()
    metrics: Dict[str, Any] = {
        "ttfb_ms": 0,
        "generation_ms": 0,
        "tokens_generated": 0,
        "audio_duration_seconds": 0.0,
        "sr_applied": False,
        "output_sample_rate": 16000,
    }
    
    start_time = time.time()
    
    try:
        # Generate audio using pipeline (always at 16kHz)
        audio_bytes, perf = await runtime.pipeline.generate_speech_profiled(
            speaker=speaker,
            text=text,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        )
        if perf:
            metrics.update(perf)
        
        # Wall-clock timings below overwrite any same-named keys from `perf`.
        generation_time = time.time() - start_time
        metrics["generation_ms"] = int(generation_time * 1000)
        metrics["ttfb_ms"] = metrics["generation_ms"]  # Non-streaming: TTFB = total time
        
        if audio_bytes:
            sample_rate = 16000
            sr_service = runtime.sr_service
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "SR check: output_sample_rate=%r, sr_service=%s, is_loaded=%s",
                    output_sample_rate,
                    sr_service is not None,
                    sr_service.is_loaded if sr_service else False,
                )
            
            # Apply super-resolution if requested and available
            if output_sample_rate == "48khz" and sr_service and sr_service.is_loaded:
                try:
                    sr_start = time.time()
                    logger.info(f"Applying super-resolution to {len(audio_bytes)} bytes...")
                    audio_bytes = _apply_super_resolution(audio_bytes, sr_service)
                    sr_time = (time.time() - sr_start) * 1000
                    metrics["sr_ms"] = int(sr_time)
                    metrics["sr_applied"] = True
                    sample_rate = 48000
                    logger.info(f"✅ Super-resolution applied in {sr_time:.1f}ms, output={len(audio_bytes)} bytes")
                except Exception as e:
                    # Best effort: keep the 16 kHz audio. exc_info routes the
                    # traceback through logging instead of raw stderr.
                    logger.warning(
                        "Super-resolution failed, returning 16kHz: %s", e, exc_info=True
                    )
            
            metrics["output_sample_rate"] = sample_rate
            
            # Duration from PCM size: assumes a 44-byte WAV header and mono 16-bit samples.
            audio_duration = (len(audio_bytes) - 44) / (sample_rate * 2)
            metrics["audio_duration_seconds"] = max(0.0, audio_duration)
        
        return audio_bytes, metrics
        
    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        raise


def _apply_super_resolution(audio_bytes: bytes, sr_service) -> bytes:
    """
    Apply super-resolution to 16 kHz WAV bytes and return 48 kHz WAV bytes.

    The input is treated as a canonical WAV with a 44-byte header followed by
    mono 16-bit little-endian PCM; the header is skipped, not parsed.
    NOTE(review): confirm the pipeline never emits extra RIFF chunks.
    
    Args:
        audio_bytes: WAV audio at 16kHz
        sr_service: SuperResolutionService instance (anything exposing
            ``process_chunk(tensor) -> tensor``)
    
    Returns:
        WAV audio at 48kHz (mono, 16-bit PCM, 44-byte header)

    Raises:
        ValueError: if ``audio_bytes`` is shorter than a WAV header.
    """
    import numpy as np
    import struct
    import torch
    
    if len(audio_bytes) < 44:
        raise ValueError("Invalid WAV data")
    
    # Extract PCM data (skip 44-byte WAV header). int16 samples are 2 bytes
    # wide, so drop a dangling trailing byte instead of letting np.frombuffer
    # raise on an odd-length buffer.
    pcm_bytes = audio_bytes[44:]
    if len(pcm_bytes) % 2:
        pcm_bytes = pcm_bytes[:-1]
    pcm_data = np.frombuffer(pcm_bytes, dtype=np.int16)
    
    # Convert to float32 in [-1, 1) for the SR model
    audio_float = pcm_data.astype(np.float32) / 32768.0
    audio_tensor = torch.from_numpy(audio_float)
    
    # Apply super-resolution using process_chunk (16kHz -> 48kHz)
    # process_chunk expects [batch, samples] or [samples] and returns [batch, samples]
    upsampled_tensor = sr_service.process_chunk(audio_tensor)
    
    # Flatten (instead of squeeze) so a single-sample result stays 1-D;
    # detach/float so .numpy() also works for grad-tracking or half-precision output.
    upsampled = upsampled_tensor.detach().reshape(-1).float().cpu().numpy()
    
    # Back to int16 with saturation
    upsampled_int16 = np.clip(upsampled * 32768.0, -32768, 32767).astype(np.int16)
    
    # Create new WAV header for 48kHz
    sample_rate = 48000
    num_channels = 1
    bits_per_sample = 16
    byte_rate = sample_rate * num_channels * bits_per_sample // 8
    block_align = num_channels * bits_per_sample // 8
    data_size = len(upsampled_int16) * 2
    
    wav_header = struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        36 + data_size,
        b'WAVE',
        b'fmt ',
        16,  # fmt chunk size
        1,   # PCM format
        num_channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b'data',
        data_size,
    )
    
    return wav_header + upsampled_int16.tobytes()


async def generate_speech_chunked(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    sample_rate: int = 16000,
    output_sample_rate: str = "16khz",
) -> Tuple[Optional[bytes], Dict[str, Any]]:
    """
    Generate speech (non-streaming) with automatic text chunking for long inputs.

    When ``LongTextProcessor.should_chunk`` reports the text as long, the
    processor's ``generate_with_chunking`` splits the text and stitches the audio.
    Otherwise the pipeline's ``generate_speech_profiled`` is called directly, and
    its profiling dict is merged into the metrics.

    Args:
        text: Text to synthesize
        speaker: Internal speaker name
        temperature, top_k, top_p, max_tokens, repetition_penalty, seed:
            Sampling parameters forwarded to the pipeline (same as generate_speech)
        sample_rate: Internal generation sample rate (always 16000); forwarded
            only on the chunked path
        output_sample_rate: "16khz" or "48khz". "48khz" triggers
            super-resolution when the runtime's SR service is loaded. If the
            service is unavailable or SR fails, 16kHz audio is returned and a
            warning is logged.

    Returns:
        Tuple of (audio_bytes, metrics_dict). ``audio_bytes`` may be falsy when
        nothing was generated. ``ttfb_ms`` equals ``generation_ms`` because
        nothing is streamed. ``sr_ms`` is present only when SR ran.

    Raises:
        RuntimeError: if the TTS runtime is not initialized. Any generation error
            is logged with its traceback and re-raised.
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")

    runtime = get_runtime()
    metrics: Dict[str, Any] = {
        "ttfb_ms": 0,
        "generation_ms": 0,
        "chunks_processed": 0,
        "audio_duration_seconds": 0.0,
        "text_chunked": False,
        "sr_applied": False,
        "output_sample_rate": 16000,
    }

    # Monotonic clock: interval measurements must not jump with wall-clock changes.
    start_time = time.perf_counter()

    try:
        long_processor = _get_long_text_processor(runtime)

        if long_processor.should_chunk(text):
            metrics["text_chunked"] = True
            audio_bytes = await long_processor.generate_with_chunking(
                text=text,
                speaker=speaker,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
                sample_rate=sample_rate,
            )
        else:
            # Short text: direct generation
            audio_bytes, perf = await runtime.pipeline.generate_speech_profiled(
                speaker=speaker,
                text=text,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
            )
            if perf:
                metrics.update(perf)

        generation_ms = int((time.perf_counter() - start_time) * 1000)
        metrics["generation_ms"] = generation_ms
        metrics["ttfb_ms"] = generation_ms

        if audio_bytes:
            final_sample_rate = 16000
            if output_sample_rate == "48khz":
                sr_service = runtime.sr_service
                if sr_service and sr_service.is_loaded:
                    try:
                        sr_start = time.perf_counter()
                        logger.info(f"Applying super-resolution to chunked audio ({len(audio_bytes)} bytes)...")
                        audio_bytes = _apply_super_resolution(audio_bytes, sr_service)
                        sr_time = (time.perf_counter() - sr_start) * 1000
                        metrics["sr_ms"] = int(sr_time)
                        metrics["sr_applied"] = True
                        final_sample_rate = 48000
                        logger.info(f"✅ Super-resolution applied in {sr_time:.1f}ms, output={len(audio_bytes)} bytes")
                    except Exception as e:
                        # Best-effort: keep the 16kHz audio, but route the traceback
                        # through logging rather than printing to stderr.
                        logger.warning(f"Super-resolution failed, returning 16kHz: {e}", exc_info=True)
                else:
                    logger.warning(
                        "48khz output requested but super-resolution service is not loaded; returning 16kHz"
                    )

            metrics["output_sample_rate"] = final_sample_rate
            # 44 = canonical WAV header size; 2 bytes per int16 mono sample.
            audio_duration = (len(audio_bytes) - 44) / (final_sample_rate * 2)
            metrics["audio_duration_seconds"] = max(0.0, audio_duration)

        return audio_bytes, metrics

    except Exception as e:
        logger.exception(f"Chunked speech generation failed: {e}")
        raise


async def generate_speech_streaming(
    text: str,
    speaker: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_tokens: int = 4096,
    repetition_penalty: float = 1.05,
    seed: Optional[int] = None,
    enable_chunking: bool = True,
    admission_lease: Optional[StreamingAdmissionLease] = None,
    release_admission_lease: bool = True,
) -> AsyncGenerator[Tuple[bytes, Dict[str, Any]], None]:
    """
    Generate speech audio with true streaming (yields chunks as they're generated).

    The first yield is a WAV header (``data_size=0``, length unknown) prepended to
    the first PCM chunk. Later yields are raw int16 PCM at 16kHz. Long text, when
    ``enable_chunking`` is set and the long-text processor says so, is streamed
    chunk by chunk with voice consistency. Otherwise a single stream is used; it
    skips global-token generation when the speaker's globals are cached.

    Args:
        text: Text to synthesize (already normalized)
        speaker: Internal speaker name (resolved)
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        max_tokens: Maximum tokens to generate
        repetition_penalty: Repetition penalty
        seed: Random seed for reproducibility
        enable_chunking: Enable text chunking for long inputs with voice consistency
        admission_lease: Optional pre-acquired admission lease; one is acquired
            via ``acquire_streaming_slot`` when omitted.
        release_admission_lease: Whether to release the lease when the generator
            finishes, fails, or is closed.

    Yields:
        Tuple of (audio_bytes, metrics_dict). The same metrics dict object is
        yielded each time and updated in place, so after the last chunk it holds
        the final ttfb/byte/duration counters.

    Raises:
        RuntimeError: If runtime not initialized or streaming pipeline unavailable
    """
    if not is_initialized():
        raise RuntimeError("TTS runtime not initialized")

    runtime = get_runtime()

    if runtime.streaming_pipeline is None:
        raise RuntimeError("Streaming pipeline not available")

    lease = admission_lease
    if lease is None:
        lease = await acquire_streaming_slot()

    # Everything after the slot is held runs under try/finally, so setup failures
    # (controller lookup, imports) can no longer leak the admission slot.
    try:
        controller = _get_stream_admission_controller(runtime)
        admission_wait_ms = float(getattr(lease, "wait_ms", 0.0) or 0.0)

        from veena3modal.audio.utils import create_wav_header

        metrics: Dict[str, Any] = {
            "ttfb_ms": 0,
            "chunks_sent": 0,
            "total_bytes": 0,
            "audio_duration_seconds": 0.0,
            "text_chunked": False,
            "admission_wait_ms": float(admission_wait_ms),
            "admission_queued": bool(getattr(lease, "queued", False)),
            "admission_inflight_on_grant": int(getattr(lease, "inflight_on_grant", 0) or 0),
            "admission_waiters_on_grant": int(getattr(lease, "waiters_on_grant", 0) or 0),
            "admission_queue_depth_on_entry": int(getattr(lease, "queue_depth_on_entry", 0) or 0),
            "admission_max_inflight": int(controller.max_inflight),
            "admission_max_queue": int(controller.max_queue),
            "admission_max_wait_ms": float(controller.max_wait_ms),
        }

        # Include admission wait in end-to-end TTFB accounting.
        start_time = time.time() - (admission_wait_ms / 1000.0)
        first_chunk_time: Optional[float] = None
        total_pcm_bytes = 0
        header_len = 0
        sample_rate = 16000  # BiCodec sample rate

        def _frame(pcm: bytes) -> bytes:
            """Update streaming counters for one PCM chunk; prepend the WAV header to the first."""
            nonlocal first_chunk_time, total_pcm_bytes, header_len
            payload = pcm
            if first_chunk_time is None:
                first_chunk_time = time.time()
                metrics["ttfb_ms"] = int((first_chunk_time - start_time) * 1000)
                # Streaming WAV header: data_size=0 because the total length is unknown.
                wav_header = create_wav_header(sample_rate=sample_rate, data_size=0)
                header_len = len(wav_header)
                payload = wav_header + pcm
            # Count PCM bytes directly rather than inferring the header from chunks_sent.
            total_pcm_bytes += len(pcm)
            metrics["chunks_sent"] += 1
            metrics["total_bytes"] = total_pcm_bytes + header_len
            metrics["audio_duration_seconds"] = total_pcm_bytes / (sample_rate * 2)
            return payload

        long_processor = _get_long_text_processor(runtime)
        needs_chunking = enable_chunking and long_processor.should_chunk(text)

        if needs_chunking:
            # Chunked streaming with voice consistency (global token caching)
            metrics["text_chunked"] = True
            async for audio_chunk, chunk_metrics in _stream_chunked_text(
                runtime, long_processor, text, speaker, temperature, top_k, top_p,
                max_tokens, repetition_penalty, seed, sample_rate
            ):
                framed = _frame(audio_chunk)
                if chunk_metrics:
                    metrics.update(chunk_metrics)
                yield framed, metrics
        else:
            # OPTIMIZATION: with pre-cached global tokens for this speaker, use
            # continuation mode to skip the global token pre-roll; the model jumps
            # straight to semantic token generation.
            cached_globals = runtime.speaker_global_cache.get(speaker)

            if cached_globals:
                stream_gen = runtime.streaming_pipeline.generate_speech_stream_indic_continuation(
                    speaker=speaker,
                    text=text,
                    global_ids=cached_globals,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_tokens=max_tokens,
                    repetition_penalty=repetition_penalty,
                    seed=seed,
                    emit_progress=True,
                )
            else:
                # Fallback: no cached globals, generate them inline
                stream_gen = runtime.streaming_pipeline.generate_speech_stream_indic(
                    speaker=speaker,
                    text=text,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_tokens=max_tokens,
                    repetition_penalty=repetition_penalty,
                    seed=seed,
                    emit_progress=True,
                )

            async for stream_item in stream_gen:
                # With emit_progress=True items may be (audio, metrics) pairs.
                stream_metrics: Dict[str, Any] = {}
                audio_chunk = stream_item
                if (
                    isinstance(stream_item, tuple)
                    and len(stream_item) == 2
                    and isinstance(stream_item[1], dict)
                ):
                    audio_chunk, stream_metrics = stream_item

                if stream_metrics:
                    metrics.update(stream_metrics)

                if not audio_chunk:
                    # Metrics-only event from streaming pipeline.
                    continue

                yield _frame(audio_chunk), metrics
    finally:
        if release_admission_lease:
            release_streaming_slot(lease)


async def _stream_chunked_text(
    runtime: TTSRuntime,
    long_processor,
    text: str,
    speaker: str,
    temperature: float,
    top_k: int,
    top_p: float,
    max_tokens: int,
    repetition_penalty: float,
    seed: Optional[int],
    sample_rate: int,
) -> AsyncGenerator[Tuple[bytes, Dict[str, Any]], None]:
    """
    Internal helper: stream long text chunk by chunk while keeping one voice.

    The first text chunk goes through ``generate_speech_stream_indic_first_chunk``,
    which also reports global token ids. The first non-empty set is kept and
    passed to ``generate_speech_stream_indic_continuation`` for every later chunk.
    If none were captured, later chunks fall back to
    ``generate_speech_stream_indic``. Every audio piece is passed through
    ``crossfade_bytes_int16`` (50 ms) together with the carried-over tail, and the
    last tail is flushed at the end.

    Yields:
        ``(pcm_bytes, progress)``; ``progress`` is one shared dict whose
        ``"chunks_processed"`` is the 1-based index of the current text chunk.
    """
    from veena3modal.audio.crossfade import crossfade_bytes_int16

    segments = long_processor.chunk_text(text)
    if not segments:
        return

    sampling = dict(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=repetition_penalty,
        seed=seed,
    )
    voice_globals: Optional[List[int]] = None
    carry: Optional[bytes] = None
    progress = {"chunks_processed": 0}

    async def _pcm_pieces(index: int, segment: str):
        # Choose the generation mode for this segment and yield plain PCM pieces.
        nonlocal voice_globals
        if index == 0:
            async for piece, global_ids in runtime.streaming_pipeline.generate_speech_stream_indic_first_chunk(
                speaker=speaker, text=segment, **sampling
            ):
                if voice_globals is None and global_ids:
                    voice_globals = global_ids
                yield piece
        elif voice_globals is None:
            logger.warning(f"No captured globals for chunk {index+1}, using regular streaming")
            async for piece in runtime.streaming_pipeline.generate_speech_stream_indic(
                speaker=speaker, text=segment, **sampling
            ):
                yield piece
        else:
            async for piece in runtime.streaming_pipeline.generate_speech_stream_indic_continuation(
                speaker=speaker, text=segment, global_ids=voice_globals, **sampling
            ):
                yield piece

    for index, segment in enumerate(segments):
        progress["chunks_processed"] = index + 1
        async for piece in _pcm_pieces(index, segment):
            ready, carry = crossfade_bytes_int16(
                carry,
                piece,
                sample_rate_hz=sample_rate,
                crossfade_ms=50,
            )
            if ready:
                yield ready, progress

    # Flush whatever tail the crossfader is still holding.
    if carry:
        yield carry, progress
