"""
FastAPI ASGI app factory for Modal deployment.

Phase 1: expose `/v1/tts/generate`, `/v1/tts/health`, and `/v1/tts/metrics`
while reusing the existing Veena3 inference/streaming code from the Django project.

Phase 2+: extract shared code into `veena3modal/shared/` and remove Django coupling.
"""

# NOTE: Intentionally NOT using `from __future__ import annotations` here
# to ensure FastAPI recognizes Request type annotations properly

import time
import uuid
import json
import threading
from typing import Optional, Any, Dict

# Structured logging (JSON formatted, request-scoped)
from veena3modal.shared.logging import (
    get_logger,
    set_request_context,
    clear_request_context,
    log_request_received,
    log_first_audio_emitted,
    log_request_completed,
    log_request_failed,
)

# Prometheus metrics
from veena3modal.shared.metrics import (
    record_request_received as metrics_request_received,
    record_request_completed as metrics_request_completed,
    record_request_failed as metrics_request_failed,
    record_ttfb,
    record_rtf,
    record_audio_duration,
    record_chunks_sent,
    get_metrics_text,
    get_content_type as get_metrics_content_type,
)

# Auth and rate limiting
from veena3modal.api.auth import (
    get_api_validator,
    extract_api_key,
    hash_api_key,
)
from veena3modal.api.rate_limiter import get_rate_limiter
from veena3modal.api.error_handlers import (
    ErrorCode,
    create_error_response,
    get_error_status,
    is_gpu_fault,
    handle_gpu_fault,
    FeatureFlags,
)

# Module logger obtained from the project's structured logging helper.
logger = get_logger(__name__)

# Sentence storage (fire-and-forget, non-blocking).
# Cached singleton: starts as None and is populated lazily by
# _get_sentence_store(), so the sentence_store service module is only
# imported the first time a request needs it (not at module load time).
_sentence_store = None


def _get_sentence_store():
    """Return the cached sentence store, creating it on first use.

    The sentence_store service module is imported only when this function
    first needs to build the store; later calls reuse the cached instance.
    """
    global _sentence_store
    if _sentence_store is not None:
        return _sentence_store
    from veena3modal.services.sentence_store import get_sentence_store
    _sentence_store = get_sentence_store()
    return _sentence_store

# Static application version, reported by the health endpoint and headers.
_APP_VERSION = "0.1.0"
# Populated through set_model_version() once a model has been loaded.
_MODEL_VERSION: Optional[str] = None
# Wall-clock timestamp recorded by create_app(); None until the app is built.
_startup_time: Optional[float] = None

# Process-local count of in-flight generate requests. The enter/exit helpers
# below update it while holding _inflight_lock.
_inflight_requests = 0
_inflight_lock = threading.Lock()


def _inflight_enter() -> int:
    """Register one more in-flight request; return the count after the increment."""
    global _inflight_requests
    with _inflight_lock:
        updated = _inflight_requests + 1
        _inflight_requests = updated
    return updated


def _inflight_exit() -> int:
    """Release one in-flight slot (never going below zero); return the new count."""
    global _inflight_requests
    with _inflight_lock:
        remaining = _inflight_requests - 1 if _inflight_requests > 0 else 0
        _inflight_requests = remaining
    return remaining


def _serialize_header_value(value: Any) -> Optional[str]:
    """Convert metric values to compact header-safe strings."""
    if value is None:
        return None
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, int):
        return str(value)
    if isinstance(value, float):
        if value.is_integer():
            return str(int(value))
        return f"{value:.3f}"
    return str(value)


def _build_perf_headers(metrics: Dict[str, Any], request_inflight: int) -> Dict[str, str]:
    """
    Build detailed timing headers consumed by local benchmark scripts.

    Includes a compact JSON payload (`X-Perf-Details`) and selected explicit headers
    for easy parsing in shell scripts/tools.
    """
    perf_headers: Dict[str, str] = {}
    if not isinstance(metrics, dict):
        metrics = {}

    detail_payload: Dict[str, Any] = {}
    if request_inflight > 0:
        detail_payload["request_inflight"] = request_inflight

    # Expose selected metrics as explicit headers.
    header_map = {
        "x-generation-ms": "generation_ms",
        "x-request-inflight": "request_inflight",
        "x-api-preprocess-ms": "api_preprocess_ms",
        "x-api-generate-await-ms": "api_generate_await_ms",
        "x-api-postprocess-ms": "api_postprocess_ms",
        "x-api-total-ms": "api_total_ms",
        "x-api-overhead-vs-pipeline-ms": "api_overhead_vs_pipeline_ms",
        "x-api-overhead-vs-generation-ms": "api_overhead_vs_generation_ms",
        "x-timeline-first-batch-ms": "timeline_llm_first_batch_ms",
        "x-timeline-llm-done-ms": "timeline_llm_done_ms",
        "x-timeline-bicodec-done-ms": "timeline_bicodec_done_ms",
        "x-timeline-total-ms": "timeline_total_ms",
        "x-llm-token-total": "llm_token_total",
        "x-llm-batch-count": "llm_batch_count",
        "x-llm-batch-wall-ms": "llm_batch_wall_ms_total",
        "x-llm-batch-gpu-ms": "llm_batch_gpu_ms_total",
        "x-llm-batch-wall-ms-min": "llm_batch_wall_ms_min",
        "x-llm-batch-wall-ms-max": "llm_batch_wall_ms_max",
        "x-llm-batch-wall-ms-p50": "llm_batch_wall_ms_p50",
        "x-llm-batch-gpu-ms-min": "llm_batch_gpu_ms_min",
        "x-llm-batch-gpu-ms-max": "llm_batch_gpu_ms_max",
        "x-llm-batch-gpu-ms-p50": "llm_batch_gpu_ms_p50",
        "x-llm-decode-wall-ms": "llm_decode_wall_ms_total",
        "x-llm-decode-gpu-ms": "llm_decode_gpu_ms_total",
        "x-llm-decode-wall-ms-min": "llm_decode_wall_ms_min",
        "x-llm-decode-wall-ms-max": "llm_decode_wall_ms_max",
        "x-llm-decode-wall-ms-p50": "llm_decode_wall_ms_p50",
        "x-llm-decode-gpu-ms-min": "llm_decode_gpu_ms_min",
        "x-llm-decode-gpu-ms-max": "llm_decode_gpu_ms_max",
        "x-llm-decode-gpu-ms-p50": "llm_decode_gpu_ms_p50",
        "x-llm-decode-calls": "llm_decode_calls",
        "x-llm-decode-cpu-ms": "bicodec_decode_cpu_ms",
        "x-llm-parse-ms": "llm_parse_ms",
        "x-llm-parse-avg-ms": "llm_parse_ms",
        "x-llm-tokens-per-batch-min": "tokens_per_batch_min",
        "x-llm-tokens-per-batch-max": "tokens_per_batch_max",
        "x-llm-tokens-per-batch-p50": "tokens_per_batch_p50",
        "x-llm-time-per-token-ms": "llm_time_per_token_ms",
        "x-llm-time-per-batch-wall-ms": "llm_time_per_batch_wall_ms",
        "x-llm-time-per-batch-gpu-ms": "llm_time_per_batch_gpu_ms",
        "x-llm-time-in-queue-ms": "llm_time_in_queue_ms",
        "x-llm-scheduler-ms": "llm_scheduler_ms",
        "x-llm-first-token-ms": "llm_first_token_ms",
        "x-llm-lifecycle-ms": "llm_request_lifecycle_ms",
        "x-llm-text-chunked": "text_chunked",
        "x-llm-chunks-processed": "chunks_processed",
        "x-bicodec-decode-wall-ms": "bicodec_decode_wall_ms",
        "x-bicodec-decode-gpu-ms": "bicodec_decode_gpu_ms",
        "x-bicodec-decode-cpu-ms": "bicodec_decode_cpu_ms",
    }

    if request_inflight > 0:
        metrics = dict(metrics)
        metrics["request_inflight"] = request_inflight

    for header_key, metric_key in header_map.items():
        if metric_key not in metrics:
            continue
        value = metrics.get(metric_key)
        serialized = _serialize_header_value(value)
        if serialized is None:
            continue
        perf_headers[header_key] = serialized
        detail_payload[metric_key] = value

    # Include additional timing fields (bounded set) in JSON payload only.
    extra_detail_keys = (
        "prompt_build_ms",
        "sampling_params_ms",
        "api_preprocess_ms",
        "api_generate_await_ms",
        "api_postprocess_ms",
        "api_total_ms",
        "api_overhead_vs_pipeline_ms",
        "api_overhead_vs_generation_ms",
        "timeline_markers",
        "timeline_prompt_ready_ms",
        "timeline_sampling_ready_ms",
        "timeline_llm_first_batch_ms",
        "timeline_llm_done_ms",
        "timeline_parse_done_ms",
        "timeline_validation_done_ms",
        "timeline_bicodec_done_ms",
        "timeline_wav_done_ms",
        "timeline_request_done_ms",
        "timeline_to_first_batch_ms",
        "timeline_first_batch_to_llm_done_ms",
        "timeline_post_llm_ms",
        "timeline_parse_to_bicodec_done_ms",
        "timeline_bicodec_to_request_done_ms",
        "timeline_total_ms",
        "llm_prefill_wall_ms",
        "llm_prefill_gpu_ms",
        "llm_generation_wall_ms",
        "llm_time_in_queue_ms",
        "llm_scheduler_ms",
        "llm_model_forward_ms",
        "llm_model_execute_ms",
        "llm_first_token_ms",
        "llm_request_lifecycle_ms",
        "llm_queued_to_scheduled_ms",
        "llm_scheduled_to_first_token_ms",
        "llm_first_to_last_token_ms",
        "llm_queued_to_last_token_ms",
        "llm_prompt_token_total",
        "semantic_token_total",
        "global_token_total",
        "wav_pack_ms",
        "pipeline_total_ms",
        "token_validation_ms",
    )
    for key in extra_detail_keys:
        if key in metrics:
            detail_payload[key] = metrics[key]

    perf_headers["x-perf-details"] = json.dumps(
        detail_payload,
        separators=(",", ":"),
        ensure_ascii=True,
    )
    return perf_headers


def get_model_version() -> str:
    """Report the current model version string.

    Falls back to the sentinel ``"not_loaded"`` when no (non-empty) version
    has been recorded yet.
    """
    version = _MODEL_VERSION
    if not version:
        return "not_loaded"
    return version


def set_model_version(version: str) -> None:
    """Record *version* as the module-level model version (call after the model loads)."""
    global _MODEL_VERSION
    _MODEL_VERSION = version


def create_app():
    """
    Factory function to create FastAPI app.
    
    Import inside factory so local tooling can import this module without FastAPI installed.
    """
    from fastapi import FastAPI, HTTPException, Request
    from fastapi.responses import JSONResponse, Response

    global _startup_time
    _startup_time = time.time()

    app = FastAPI(
        title="Veena3 TTS (Modal)",
        version=_APP_VERSION,
        description="High-quality multilingual Text-to-Speech API with true streaming support.",
    )

    @app.get("/v1/tts/health")
    async def tts_health():
        """
        Health check endpoint.

        Returns:
            - status: "healthy" (model loaded and GPU detected), "degraded"
              (exactly one of the two), or "unhealthy" (neither)
            - model_loaded: whether the TTS runtime reports it is initialized
            - model_version: version string of loaded model
            - uptime_seconds: time since create_app() ran
            - gpu_available: whether GPU is detected (best-effort check)
            - app_version: application version string

        The HTTP status code is always 200; consumers must inspect ``status``.
        """
        # Check model status via runtime
        from veena3modal.services import tts_runtime
        model_loaded = tts_runtime.is_initialized()

        # Prefer the runtime's model version. Fall back to the module-level one
        # when there is no runtime or its version is unset, so the header
        # value below is always a string.
        runtime = tts_runtime.get_runtime()
        model_version = (runtime.model_version if runtime else None) or get_model_version()

        # Check GPU availability (best-effort)
        gpu_available = False
        try:
            import torch
            gpu_available = torch.cuda.is_available()
        except ImportError:
            pass  # torch not installed, skip GPU check

        # Compute uptime
        uptime = time.time() - _startup_time if _startup_time else 0

        # Determine overall status. The final branch previously duplicated
        # "degraded", so the documented "unhealthy" state was unreachable.
        if model_loaded and gpu_available:
            status = "healthy"
        elif model_loaded or gpu_available:
            status = "degraded"
        else:
            status = "unhealthy"

        return JSONResponse(
            content={
                "status": status,
                "model_loaded": model_loaded,
                "model_version": model_version,
                "uptime_seconds": round(uptime, 2),
                "gpu_available": gpu_available,
                "app_version": _APP_VERSION,
            },
            headers={
                "X-Model-Version": model_version,
                "X-App-Version": _APP_VERSION,
            },
        )

    @app.get("/v1/tts/metrics")
    async def tts_metrics():
        """
        Serve Prometheus metrics for scraping.

        The body is the text exposition produced by the shared metrics module,
        sent with the content type that module reports.
        """
        payload = get_metrics_text()
        content_type = get_metrics_content_type()
        return Response(content=payload, media_type=content_type)

    def _transcode_wav(wav_bytes: bytes, target_format: str, sample_rate: int):
        """
        Best-effort transcode of a WAV payload into ``target_format``.

        Returns ``(audio_bytes, media_type, served_format)``. If the encoders
        cannot be imported, encoding raises, or ``target_format`` has no encoder
        branch here, the original WAV bytes are returned as
        ``(wav_bytes, "audio/wav", "wav")`` so the request still succeeds.
        """
        from veena3modal.api.schemas import AudioFormat

        wav_fallback = (wav_bytes, "audio/wav", "wav")
        try:
            from veena3modal.audio.encoder import (
                OpusEncoder, MP3Encoder, MuLawEncoder, FLACEncoder
            )

            # Strip WAV header to get PCM data. Assumes a canonical 44-byte RIFF
            # header with no extra chunks. TODO confirm for all runtime outputs.
            pcm_data = wav_bytes[44:] if len(wav_bytes) > 44 else wav_bytes

            if target_format == AudioFormat.OPUS.value:
                encoder, media_type = OpusEncoder(sample_rate=sample_rate), "audio/opus"
            elif target_format == AudioFormat.MP3.value:
                encoder, media_type = MP3Encoder(sample_rate=sample_rate), "audio/mpeg"
            elif target_format == AudioFormat.MULAW.value:
                encoder, media_type = MuLawEncoder(), "audio/x-wav"
            elif target_format == AudioFormat.FLAC.value:
                encoder, media_type = FLACEncoder(sample_rate=sample_rate), "audio/flac"
            else:
                # No encoder for this format: serve the WAV unchanged.
                return wav_fallback

            encoded = encoder.encode(pcm_data, sample_rate)
        except ImportError as e:
            logger.warning(f"Audio encoder not available: {e}")
            return wav_fallback
        except Exception as e:
            logger.error(f"Format encoding failed: {e}")
            return wav_fallback

        logger.info(f"Encoded audio to {target_format}: {len(encoded)} bytes")
        return encoded, media_type, target_format

    @app.post("/v1/tts/generate")
    async def tts_generate(request: Request):
        """
        Generate speech from text.

        Flow:
            1. Optional API-key auth, and rate limiting keyed on the key hash
               (both feature-flagged).
            2. JSON body validation against ``TTSGenerateRequest``.
            3. ``stream=true`` is delegated to ``_handle_streaming_request``.
               Otherwise audio is generated in one shot as WAV. When a non-WAV
               ``format`` is requested it is transcoded best-effort, falling
               back to WAV if that fails. ``X-Format`` reports what was served.
        """
        from veena3modal.api.schemas import TTSGenerateRequest, AudioFormat
        from veena3modal.services import tts_runtime

        request_id = str(uuid.uuid4())
        start_time = time.time()
        request_start_perf = time.perf_counter()

        # === AUTH CHECK ===
        if FeatureFlags.is_auth_enabled():
            # Extract API key from headers
            headers_dict = dict(request.headers)
            api_key = extract_api_key(headers_dict)

            if not api_key:
                return JSONResponse(
                    status_code=401,
                    content=create_error_response(
                        code=ErrorCode.INVALID_API_KEY,
                        message="API key required. Use 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header.",
                        request_id=request_id,
                    ),
                    headers={"X-Request-ID": request_id},
                )

            # Validate API key (only its hash is handed to the validator)
            validator = get_api_validator()
            key_hash = hash_api_key(api_key)
            auth_result = validator.validate(key_hash)

            if not auth_result.is_valid:
                # INVALID_API_KEY maps to 401; any other auth failure is 403.
                status_code = 401 if auth_result.error_code == "INVALID_API_KEY" else 403
                return JSONResponse(
                    status_code=status_code,
                    content=create_error_response(
                        code=ErrorCode(auth_result.error_code),
                        message=auth_result.error_message or "Authentication failed",
                        request_id=request_id,
                    ),
                    headers={"X-Request-ID": request_id},
                )

            # === RATE LIMIT CHECK ===
            # Nested under auth because the limiter is keyed on the key hash.
            if FeatureFlags.is_rate_limiting_enabled():
                rate_limiter = get_rate_limiter()
                allowed, remaining, reset_after = rate_limiter.check(key_hash)

                if not allowed:
                    headers = rate_limiter.get_headers(allowed, remaining, reset_after)
                    headers["X-Request-ID"] = request_id
                    return JSONResponse(
                        status_code=429,
                        content=create_error_response(
                            code=ErrorCode.RATE_LIMIT_EXCEEDED,
                            message=f"Rate limit exceeded. Retry after {int(reset_after) + 1} seconds.",
                            request_id=request_id,
                        ),
                        headers=headers,
                    )

        # Parse request body. The request id header is attached so the client
        # can correlate this failure like every other error response here.
        try:
            body = await request.json()
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid JSON: {e}",
                headers={"X-Request-ID": request_id},
            ) from e

        # Validate request body. This also covers non-object JSON bodies, since
        # `**body` raises TypeError for e.g. a list.
        try:
            req = TTSGenerateRequest(**body)
        except Exception as e:
            return JSONResponse(
                status_code=400,
                content={
                    "error": {
                        "code": "VALIDATION_ERROR",
                        "message": str(e),
                        "request_id": request_id,
                    }
                },
                headers={"X-Request-ID": request_id},
            )

        # Check if runtime is initialized
        if not tts_runtime.is_initialized():
            return JSONResponse(
                status_code=503,
                content={
                    "error": {
                        "code": "MODEL_NOT_LOADED",
                        "message": "TTS model not initialized. Please wait for warmup.",
                        "request_id": request_id,
                    }
                },
                headers={"X-Request-ID": request_id},
            )

        # Get resolved speaker
        speaker = req.get_resolved_speaker()

        # Get normalized text (apply text normalization if enabled)
        normalizer_func = None
        if req.normalize:
            try:
                from veena3modal.processing.text_normalizer import normalize_text
                normalizer_func = lambda t: normalize_text(t, verbose=req.normalize_verbose)
            except ImportError:
                logger.warning("TextNormalizer not available, skipping text normalization")

        text = req.get_normalized_text(normalizer_func=normalizer_func)

        # Set request context for structured logging
        set_request_context(request_id=request_id)

        # Log request received (lifecycle event + metrics)
        log_request_received(
            request_id=request_id,
            text_length=len(text),
            speaker=speaker,
            stream=req.stream,
            format=req.format,
        )
        metrics_request_received(speaker=speaker, stream=req.stream, format=req.format)

        # Streaming mode (M4): the streaming handler manages its own lifecycle.
        if req.stream:
            return await _handle_streaming_request(req, request_id, text, speaker, start_time)

        # Format encoding is applied after generation (non-streaming only)
        target_format = req.format
        inflight_current = _inflight_enter()
        try:
            try:
                # use_enum_values=True means req.output is already a string
                output_sr = getattr(req, "output", "16khz")
                generate_call_start = time.perf_counter()

                # Generate speech (non-streaming)
                if req.chunking:
                    audio_bytes, metrics = await tts_runtime.generate_speech_chunked(
                        text=text,
                        speaker=speaker,
                        temperature=req.temperature,
                        top_k=req.top_k,
                        top_p=req.top_p,
                        max_tokens=req.max_tokens,
                        repetition_penalty=req.repetition_penalty,
                        seed=req.seed,
                        sample_rate=req.sample_rate,
                        output_sample_rate=output_sr,
                    )
                else:
                    audio_bytes, metrics = await tts_runtime.generate_speech(
                        text=text,
                        speaker=speaker,
                        temperature=req.temperature,
                        top_k=req.top_k,
                        top_p=req.top_p,
                        max_tokens=req.max_tokens,
                        repetition_penalty=req.repetition_penalty,
                        seed=req.seed,
                        output_sample_rate=output_sr,
                    )
                generate_call_end = time.perf_counter()

                if metrics is None:
                    metrics = {}
                metrics["api_preprocess_ms"] = (generate_call_start - request_start_perf) * 1000
                metrics["api_generate_await_ms"] = (generate_call_end - generate_call_start) * 1000
                metrics["request_inflight"] = max(
                    int(metrics.get("request_inflight", 0) or 0),
                    inflight_current,
                )

                if audio_bytes is None:
                    # Record the failure like the exception path does, so it
                    # is not silently missing from logs and metrics.
                    log_request_failed(
                        request_id=request_id,
                        status_code=500,
                        error_code="GENERATION_FAILED",
                        error_message="Audio generation returned no audio",
                    )
                    metrics_request_failed(
                        status_code=500,
                        error_code="GENERATION_FAILED",
                        speaker=speaker,
                    )
                    return JSONResponse(
                        status_code=500,
                        content={
                            "error": {
                                "code": "GENERATION_FAILED",
                                "message": "Audio generation failed. Check logs for details.",
                                "request_id": request_id,
                            }
                        },
                        headers={"X-Request-ID": request_id},
                    )

                # `or 0` also covers a present-but-None value, which would
                # otherwise break the comparison and formatting below.
                audio_duration = metrics.get("audio_duration_seconds") or 0
                total_time = time.time() - start_time
                rtf = total_time / audio_duration if audio_duration > 0 else 0

                # Build response headers. Fall back when the runtime version is
                # unset so the header value is always a string.
                runtime = tts_runtime.get_runtime()
                model_version = (runtime.model_version if runtime else None) or get_model_version()

                # Use actual output sample rate from metrics
                output_sample_rate = metrics.get("output_sample_rate", req.sample_rate)

                headers = {
                    "X-Request-ID": request_id,
                    "X-Model-Version": model_version,
                    "X-Format": "wav",
                    "X-Sample-Rate": str(output_sample_rate),
                    "X-Stream": "false",
                    "X-Audio-Bytes": str(len(audio_bytes)),
                    "X-Audio-Seconds": f"{audio_duration:.2f}",
                    "X-TTFB-ms": str(metrics.get("ttfb_ms", 0)),
                    "X-RTF": f"{rtf:.3f}",
                    "X-Text-Chunked": str(metrics.get("text_chunked", False)).lower(),
                    "X-SR-Applied": str(metrics.get("sr_applied", False)).lower(),
                }

                if req.seed is not None:
                    headers["X-Seed"] = str(req.seed)

                # Apply format encoding if not WAV
                media_type = "audio/wav"
                if target_format != AudioFormat.WAV.value:
                    audio_bytes, media_type, served_format = _transcode_wav(
                        audio_bytes,
                        target_format,
                        # Sample rate of the output (after potential SR)
                        metrics.get("output_sample_rate", 16000),
                    )
                    # Report what is actually in the body: the requested
                    # format on success, "wav" if transcoding fell back.
                    headers["X-Format"] = served_format
                    headers["X-Audio-Bytes"] = str(len(audio_bytes))

                api_total_ms = (time.perf_counter() - request_start_perf) * 1000
                api_preprocess_ms = float(metrics.get("api_preprocess_ms", 0.0) or 0.0)
                api_generate_await_ms = float(metrics.get("api_generate_await_ms", 0.0) or 0.0)
                metrics["api_total_ms"] = api_total_ms
                metrics["api_postprocess_ms"] = max(
                    0.0, api_total_ms - api_preprocess_ms - api_generate_await_ms
                )
                timeline_total_ms = metrics.get("timeline_total_ms")
                if isinstance(timeline_total_ms, (int, float)):
                    metrics["api_overhead_vs_pipeline_ms"] = max(
                        0.0, api_total_ms - float(timeline_total_ms)
                    )
                generation_ms = metrics.get("generation_ms")
                if isinstance(generation_ms, (int, float)):
                    metrics["api_overhead_vs_generation_ms"] = max(
                        0.0, api_total_ms - float(generation_ms)
                    )
                headers.update(_build_perf_headers(metrics, request_inflight=inflight_current))

                # Fire-and-forget sentence storage (after generation)
                sentence_store = _get_sentence_store()
                sentence_store.store_fire_and_forget(
                    request_id=request_id,
                    text=text,
                    speaker=speaker,
                    stream=False,
                    format=req.format,
                    temperature=req.temperature,
                    top_k=req.top_k,
                    top_p=req.top_p,
                    max_tokens=req.max_tokens,
                    repetition_penalty=req.repetition_penalty,
                    seed=req.seed,
                    text_chunked=metrics.get("text_chunked", False),
                    ttfb_ms=metrics.get("ttfb_ms"),
                    audio_duration_seconds=audio_duration,
                )

                # Log successful completion. RTF is recomputed so it includes
                # encoding and post-processing time.
                total_time = time.time() - start_time
                rtf = total_time / audio_duration if audio_duration > 0 else 0
                headers["X-RTF"] = f"{rtf:.3f}"
                total_duration_ms = int(total_time * 1000)
                log_request_completed(
                    request_id=request_id,
                    status_code=200,
                    total_duration_ms=total_duration_ms,
                    audio_duration_seconds=audio_duration,
                    rtf=rtf,
                    chunks_sent=1,
                )
                metrics_request_completed(
                    status_code=200,
                    duration_seconds=total_time,
                    speaker=speaker,
                    stream=False,
                )
                # `or 0` guards against a present-but-None ttfb_ms.
                record_ttfb(
                    ttfb_seconds=(metrics.get("ttfb_ms") or 0) / 1000,
                    speaker=speaker,
                    stream=False,
                )
                record_rtf(rtf=rtf, speaker=speaker)
                record_audio_duration(duration_seconds=audio_duration, speaker=speaker)

                # High-cardinality details are kept in logs; summarize key timings.
                logger.info(
                    "request_perf_summary",
                    extra={
                        "request_id": request_id,
                        "speaker": speaker,
                        "generation_ms": metrics.get("generation_ms"),
                        "llm_token_total": metrics.get("llm_token_total"),
                        "llm_batch_count": metrics.get("llm_batch_count"),
                        "llm_batch_wall_ms_total": metrics.get("llm_batch_wall_ms_total"),
                        "llm_batch_gpu_ms_total": metrics.get("llm_batch_gpu_ms_total"),
                        "timeline_llm_first_batch_ms": metrics.get("timeline_llm_first_batch_ms"),
                        "timeline_llm_done_ms": metrics.get("timeline_llm_done_ms"),
                        "timeline_bicodec_done_ms": metrics.get("timeline_bicodec_done_ms"),
                        "timeline_total_ms": metrics.get("timeline_total_ms"),
                        "bicodec_decode_wall_ms": metrics.get("bicodec_decode_wall_ms"),
                        "bicodec_decode_gpu_ms": metrics.get("bicodec_decode_gpu_ms"),
                        "api_preprocess_ms": metrics.get("api_preprocess_ms"),
                        "api_generate_await_ms": metrics.get("api_generate_await_ms"),
                        "api_postprocess_ms": metrics.get("api_postprocess_ms"),
                        "api_total_ms": metrics.get("api_total_ms"),
                        "api_overhead_vs_pipeline_ms": metrics.get("api_overhead_vs_pipeline_ms"),
                        "request_inflight": metrics.get("request_inflight"),
                    },
                )

                return Response(
                    content=audio_bytes,
                    media_type=media_type,
                    headers=headers,
                )

            except Exception as e:
                logger.exception(f"TTS generation error: {e}")

                # Log failure (lifecycle event + metrics)
                log_request_failed(
                    request_id=request_id,
                    status_code=500,
                    error_code="INTERNAL_ERROR",
                    error_message=str(e),
                )
                metrics_request_failed(
                    status_code=500,
                    error_code="INTERNAL_ERROR",
                    speaker=speaker,
                )

                # NOTE(review): str(e) is returned to the client and may expose
                # internal details. Consider a generic message.
                return JSONResponse(
                    status_code=500,
                    content={
                        "error": {
                            "code": "INTERNAL_ERROR",
                            "message": str(e),
                            "request_id": request_id,
                        }
                    },
                    headers={"X-Request-ID": request_id},
                )
        finally:
            # Runs on every non-streaming exit (success, None audio, exception),
            # so the in-flight counter and logging context are always released.
            _inflight_exit()
            clear_request_context()

    async def _handle_streaming_request(req, request_id: str, text: str, speaker: str, start_time: float):
        """
        Handle streaming TTS request.

        Returns a StreamingResponse that yields WAV header + PCM chunks.
        True streaming: first bytes sent ASAP before full audio is generated.

        Early-exit JSON responses (all carry an ``X-Request-ID`` header):
          * 503 STREAMING_UNAVAILABLE  - runtime or its streaming pipeline is missing.
          * 501 FORMAT_NOT_IMPLEMENTED - any format other than WAV.
          * 429 STREAMING_OVERLOADED   - no streaming admission slot could be
            acquired; ``Retry-After`` is the runtime's hint rounded up to whole
            seconds (minimum 1).

        Args:
            req: Parsed request. Fields read here: format, temperature, top_k,
                top_p, max_tokens, repetition_penalty, seed, chunking, sample_rate.
            request_id: ID echoed in response headers, logs, metrics and
                sentence storage.
            text: Text to synthesize.
            speaker: Speaker name; also used as a metrics label.
            start_time: Wall-clock start timestamp, compared against
                ``time.time()`` to compute total duration and RTF.
        """
        from fastapi.responses import StreamingResponse
        from starlette.background import BackgroundTask
        from veena3modal.api.schemas import AudioFormat
        from veena3modal.services import tts_runtime

        # Check if streaming pipeline is available
        runtime = tts_runtime.get_runtime()
        if runtime is None or runtime.streaming_pipeline is None:
            return JSONResponse(
                status_code=503,
                content={
                    "error": {
                        "code": "STREAMING_UNAVAILABLE",
                        "message": "Streaming pipeline not initialized.",
                        "request_id": request_id,
                    }
                },
                headers={"X-Request-ID": request_id},
            )

        # Non-WAV formats not implemented yet (M5)
        if req.format != AudioFormat.WAV.value:
            return JSONResponse(
                status_code=501,
                content={
                    "error": {
                        "code": "FORMAT_NOT_IMPLEMENTED",
                        "message": f"Streaming format '{req.format}' not yet implemented. Use format=wav.",
                        "request_id": request_id,
                    }
                },
                headers={"X-Request-ID": request_id},
            )

        # Set request context for structured logging
        set_request_context(request_id=request_id)

        try:
            admission_lease = await tts_runtime.acquire_streaming_slot()
        except tts_runtime.StreamingAdmissionError as exc:
            retry_after_ms = float(getattr(exc, "retry_after_ms", 0.0) or 0.0)
            # Round milliseconds up to whole seconds; never advertise less than 1s.
            retry_after_s = max(1, int((retry_after_ms + 999.0) // 1000.0))
            overload_snapshot = dict(getattr(exc, "snapshot", {}) or {})
            log_request_failed(
                request_id=request_id,
                status_code=429,
                error_code="STREAMING_OVERLOADED",
                error_message=str(exc),
            )
            metrics_request_failed(
                status_code=429,
                error_code="STREAMING_OVERLOADED",
                speaker=speaker,
            )
            clear_request_context()
            return JSONResponse(
                status_code=429,
                content={
                    "error": {
                        "code": "STREAMING_OVERLOADED",
                        "message": "Streaming capacity exhausted. Retry shortly.",
                        "request_id": request_id,
                        "reason": getattr(exc, "reason", "overloaded"),
                    },
                    "admission": overload_snapshot,
                },
                headers={
                    "X-Request-ID": request_id,
                    "Retry-After": str(retry_after_s),
                },
            )

        # The lease is released from two places: the generator's ``finally`` and
        # the response BackgroundTask (presumably a fallback for when the body
        # iterator never runs). Starlette executes sync background tasks in a
        # worker thread, and an abandoned async generator may be finalized later
        # on the event loop, so guard with a lock to release exactly once
        # regardless of whether release_streaming_slot is itself idempotent.
        release_lock = threading.Lock()
        lease_released = False

        def _release_lease_once():
            """Release ``admission_lease`` back to the runtime at most once."""
            nonlocal lease_released
            with release_lock:
                if lease_released:
                    return
                lease_released = True
            tts_runtime.release_streaming_slot(admission_lease)

        # Create async generator for streaming response
        async def audio_stream_generator():
            """Yield audio chunks as they are generated, then record completion."""
            final_metrics = {}
            chunks_count = 0
            try:
                async for audio_chunk, metrics in tts_runtime.generate_speech_streaming(
                    text=text,
                    speaker=speaker,
                    temperature=req.temperature,
                    top_k=req.top_k,
                    top_p=req.top_p,
                    max_tokens=req.max_tokens,
                    repetition_penalty=req.repetition_penalty,
                    seed=req.seed,
                    enable_chunking=req.chunking,
                    admission_lease=admission_lease,
                    # We own the lease; it is released in ``finally`` below.
                    release_admission_lease=False,
                ):
                    final_metrics = metrics
                    chunks_count += 1

                    # First chunk: TTFB bookkeeping (before the yield, so keep it cheap).
                    if chunks_count == 1:
                        ttfb_ms = metrics.get("ttfb_ms", 0)

                        # Log first audio emitted (TTFB marker)
                        log_first_audio_emitted(
                            request_id=request_id,
                            ttfb_ms=ttfb_ms,
                            chunk_size_bytes=len(audio_chunk),
                        )
                        record_ttfb(ttfb_seconds=ttfb_ms / 1000, speaker=speaker, stream=True)

                        # Sentence storage is best-effort telemetry: a failure here
                        # must never abort audio delivery to the client.
                        try:
                            sentence_store = _get_sentence_store()
                            sentence_store.store_fire_and_forget(
                                request_id=request_id,
                                text=text,
                                speaker=speaker,
                                stream=True,
                                format=req.format,
                                temperature=req.temperature,
                                top_k=req.top_k,
                                top_p=req.top_p,
                                max_tokens=req.max_tokens,
                                repetition_penalty=req.repetition_penalty,
                                seed=req.seed,
                                text_chunked=metrics.get("text_chunked", False),
                                ttfb_ms=ttfb_ms,
                                # audio_duration_seconds not known yet for streaming
                            )
                        except Exception as store_exc:
                            logger.warning(
                                f"Sentence store failed for request {request_id}: {store_exc}"
                            )

                    yield audio_chunk

                # Stream completed - log metrics
                total_time = time.time() - start_time
                audio_duration = final_metrics.get("audio_duration_seconds", 0)
                rtf = total_time / audio_duration if audio_duration > 0 else 0

                log_request_completed(
                    request_id=request_id,
                    status_code=200,
                    total_duration_ms=int(total_time * 1000),
                    audio_duration_seconds=audio_duration,
                    rtf=rtf,
                    chunks_sent=chunks_count,
                )
                metrics_request_completed(
                    status_code=200,
                    duration_seconds=total_time,
                    speaker=speaker,
                    stream=True,
                )
                record_rtf(rtf=rtf, speaker=speaker)
                record_audio_duration(duration_seconds=audio_duration, speaker=speaker)
                record_chunks_sent(chunks=chunks_count, speaker=speaker)
            except Exception as e:
                logger.exception(f"Streaming error: {e}")
                # Log streaming failure
                log_request_failed(
                    request_id=request_id,
                    status_code=500,
                    error_code="STREAMING_ERROR",
                    error_message=str(e),
                )
                metrics_request_failed(
                    status_code=500,
                    error_code="STREAMING_ERROR",
                    speaker=speaker,
                )
                # Headers are already sent once streaming starts, so an error
                # status cannot be returned; the best we can do is stop yielding.
            finally:
                # Runs on success, on error, and on cancellation / GeneratorExit
                # (e.g. client disconnect), which bypass ``except Exception``.
                clear_request_context()
                _release_lease_once()

        # Build response headers
        # Note: Some headers can't be set dynamically during streaming
        # TTFB and RTF will be logged but not in headers for streaming
        # ``runtime`` is guaranteed non-None by the availability check above.
        model_version = runtime.model_version

        headers = {
            "X-Request-ID": request_id,
            "X-Model-Version": model_version,
            "X-Format": "wav",
            "X-Sample-Rate": str(req.sample_rate),
            "X-Stream": "true",
            "X-Chunking-Enabled": str(req.chunking).lower(),
            "X-Admission-Wait-ms": str(int(round(float(admission_lease.wait_ms)))),
            "X-Admission-Queued": str(bool(admission_lease.queued)).lower(),
            # Transfer-Encoding: chunked is set automatically by StreamingResponse
        }

        if req.seed is not None:
            headers["X-Seed"] = str(req.seed)

        return StreamingResponse(
            audio_stream_generator(),
            media_type="audio/wav",
            headers=headers,
            background=BackgroundTask(_release_lease_once),
        )

    # === WebSocket Support (M5b) ===
    try:
        from veena3modal.api.websocket_handler import add_websocket_routes
        add_websocket_routes(app)
        logger.info("WebSocket support enabled at /v1/tts/ws")
    except ImportError as e:
        logger.warning(f"WebSocket support not available: {e}")

    return app
