# # Nemotron ASR Streaming on Modal

import time
import json
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path

import modal

app = modal.App("hindi-nemotron-asr")

# Named volume mounted at CACHE_PATH by the ASR class; HF_HOME and TORCH_HOME
# in the image environment point at the same path.
model_cache = modal.Volume.from_name("hindi-nemotron-model", create_if_missing=True)
CACHE_PATH = "/model"

# WHY: Use a named Modal secret for the Hugging Face token — credentials must
# never be committed to source.
# Create with: modal secret create hf-token HF_TOKEN=hf_xxxxx
#
# SECURITY: a literal token used to be embedded in the fallback below. Anything
# that was ever committed must be treated as leaked and revoked in the
# Hugging Face account settings.
#
# NOTE(review): modal.Secret.from_name may resolve lazily (at deploy time), in
# which case this except branch never runs — confirm against the Modal docs.
try:
    hf_secret = modal.Secret.from_name("hf-token")
except Exception:
    # Development-only fallback: forward the token from the developer's local
    # environment instead of hardcoding it. With no local HF_TOKEN the secret
    # is empty and the container sees no HF_TOKEN at all.
    import os
    import warnings

    warnings.warn(
        "Modal secret 'hf-token' unavailable; falling back to the local HF_TOKEN "
        "environment variable. Create a Modal secret 'hf-token' for production."
    )
    _local_hf_token = os.environ.get("HF_TOKEN")
    hf_secret = modal.Secret.from_dict(
        {"HF_TOKEN": _local_hf_token} if _local_hf_token else {}
    )

# Container image, assembled one layer at a time so each step reads on its own.
# Layer order is significant and matches the build sequence exactly.

# Base layer: NVIDIA CUDA + cuDNN devel image with Python 3.12 added by Modal.
image = modal.Image.from_registry(
    "nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04", add_python="3.12"
)

# Environment: Hugging Face and Torch caches both live under CACHE_PATH.
image = image.env(
    {
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
        "HF_HOME": CACHE_PATH,
        "CXX": "g++",
        "CC": "g++",
        "TORCH_HOME": CACHE_PATH,
        "PYTHONPATH": "/root/stt-services:$PYTHONPATH",
    }
)

# System packages.
image = image.apt_install("git", "libsndfile1", "ffmpeg")

# Python packages.
image = image.pip_install(
    "hf_transfer",
    "huggingface_hub[hf-xet]",
    "numpy<2",
    "fastapi[standard]",
    "orjson",
    "msgpack",
    # WHY: NeMo is pinned to an exact commit — tracking @main breaks randomly
    # on upstream changes. Bump this hash deliberately after testing.
    # To find the latest: git ls-remote https://github.com/NVIDIA/NeMo.git main
    # Pinned 2026-02-08: same commit that was previously on @main and working
    "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@557177a18d080867eb24c800c2a334002f9f9fc3",
    "nemo_text_processing",
)

# WHY: The entry file is /root/nemotron_asr.py, which shadows the package name,
# so the helper modules are copied to /root/asr_lib/ to avoid the conflict.
image = image.add_local_dir(
    str(Path(__file__).parent),
    remote_path="/root/asr_lib",
    copy=True,
)

# Build-time check: list the copied files and print the VAD module content.
image = image.run_commands("ls -la /root/asr_lib/ && cat /root/asr_lib/vad.py")


with image.imports():
    import asyncio
    import logging
    import numpy as np
    import msgpack
    import torch
    from omegaconf import OmegaConf
    from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
    from nemo.collections.asr.inference.utils.progressbar import TQDMProgressBar
    from fastapi import FastAPI, WebSocket, WebSocketDisconnect
    from starlette.websockets import WebSocketState
    # Original: import uvicorn, threading — removed: tunnel + uvicorn server replaced by @modal.asgi_app()
    
    # WHY: Helper modules are at /root/asr_lib/ (via add_local_dir) to avoid
    # name shadowing with the entry file /root/nemotron_asr.py
    import sys
    sys.path.insert(0, "/root/asr_lib")
    from asr_utils import preprocess_audio
    from pipelines import StreamingPipelineBuilder
    from vad import SileroVAD


class SileroVAD:
    """
    CPU-based Voice Activity Detection using Silero VAD.

    WHY: Filters silence BEFORE the GPU pipeline. The design assumption is that
    with 1000 connected users only ~300 are speaking at any moment, so gating
    on VAD roughly triples effective GPU capacity at a small CPU cost.

    Only the trailing ~32 ms window of each chunk is scored, so a chunk whose
    speech ends before its final window is classified as silence.

    NOTE(review): Silero models are understood to keep recurrent state between
    calls; sharing one instance across interleaved streams may mix that state —
    confirm before enabling in production.

    Args:
        threshold: Minimum speech probability for a chunk to count as speech.
        sample_rate: Sample rate in Hz of the 16-bit PCM passed to is_speech().
    """

    # Length of the scored window in milliseconds (512 samples at 16 kHz,
    # 256 samples at 8 kHz).
    _WINDOW_MS = 32

    def __init__(self, threshold=0.5, sample_rate=16000):
        self.threshold = threshold
        self.sample_rate = sample_rate
        self._model = None  # loaded lazily by _ensure_loaded()

    def _ensure_loaded(self):
        """Load the Silero model through torch.hub on first use; no-op afterwards."""
        if self._model is None:
            self._model, _ = torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad', force_reload=False, onnx=False,
            )
            self._model.eval()

    def is_speech(self, audio_bytes):
        """Return True when the tail of the chunk is classified as speech.

        Fails open: any error while decoding or scoring the chunk returns True
        so that audio is never dropped because of a VAD problem. A failure to
        load the model itself is NOT swallowed and propagates to the caller.

        Args:
            audio_bytes: Raw little-endian 16-bit PCM at ``self.sample_rate``.

        Returns:
            True if the speech probability is >= ``self.threshold``, if the
            chunk is shorter than one scoring window, or on a scoring error.
        """
        self._ensure_loaded()
        try:
            samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
            audio_tensor = torch.from_numpy(samples)
            # Derive the window from the configured rate instead of assuming
            # 16 kHz; a fixed 512 would be the wrong length at 8 kHz.
            window_size = int(self.sample_rate) * self._WINDOW_MS // 1000
            if len(audio_tensor) < window_size:
                return True  # Too short for VAD, pass through
            tail = audio_tensor[-window_size:]
            # Pure inference: skip autograd bookkeeping.
            with torch.no_grad():
                prob = self._model(tail, self.sample_rate).item()
            return prob >= self.threshold
        except Exception:
            # Deliberate best-effort: keep the audio, but leave a trace for debugging.
            logging.getLogger(__name__).debug(
                "VAD scoring failed; treating chunk as speech", exc_info=True
            )
            return True


# Configuration dataclasses

@dataclass
class BoostingTreeConfig:
    """Phrase-boosting settings, nested as ``GreedyDecodingConfig.boosting_tree``.

    Every boosting source defaults to ``None``. How the pipeline consumes these
    values is not visible in this file.
    """
    # presumably a path to a prebuilt boosting-tree model — TODO confirm
    model_path: Optional[str] = None
    # presumably a file listing phrases to boost — TODO confirm
    key_phrases_file: Optional[str] = None
    # presumably an inline alternative to key_phrases_file — TODO confirm
    key_phrases_list: Optional[List[str]] = None
    # NOTE(review): defaults to "en" while the pipeline config uses lang='hi' — confirm intended
    source_lang: str = "en"


@dataclass
class GreedyDecodingConfig:
    """Greedy decoding settings, nested as ``DecodingConfig.greedy``.

    With the defaults below, both the n-gram LM weight and the boosting-tree
    weight are 0.0 and no LM model path is set.
    """
    use_cuda_graph_decoder: bool = False
    max_symbols: int = 10
    ngram_lm_model: Optional[str] = None
    ngram_lm_alpha: float = 0.0
    # default_factory gives every instance its own BoostingTreeConfig.
    boosting_tree: BoostingTreeConfig = field(default_factory=BoostingTreeConfig)
    boosting_tree_alpha: float = 0.0


@dataclass
class DecodingConfig:
    """Decoding settings, nested as ``ASRConfig.decoding``."""
    strategy: str = "greedy_batch"
    preserve_alignments: bool = False
    # NOTE(review): -1 looks like a "disabled / use default" sentinel — confirm against NeMo
    fused_batch_size: int = -1
    # default_factory gives every instance its own GreedyDecodingConfig.
    greedy: GreedyDecodingConfig = field(default_factory=GreedyDecodingConfig)


# Hugging Face repo and file name of the fine-tuned Hindi checkpoint.
# NemotronASR.load() passes both to hf_hub_download and caches the result
# under CACHE_PATH.
HF_MODEL_REPO = "BayAreaBoys/nemotron-hindi"
HF_MODEL_FILE = "final_model.nemo"

@dataclass
class ASRConfig:
    """ASR model settings, nested as ``CacheAwarePipelineConfig.asr``."""
    # WHY: This default is only a placeholder. NemotronASR.load() overwrites
    # model_name with the local path returned by hf_hub_download before the
    # pipeline is built.
    model_name: str = "nvidia/nemotron-speech-streaming-en-0.6b"
    device: str = "cuda"
    device_id: int = 0
    compute_dtype: str = "bfloat16"
    use_amp: bool = True
    # default_factory gives every instance its own DecodingConfig.
    decoding: DecodingConfig = field(default_factory=DecodingConfig)


@dataclass
class ITNConfig:
    """Inverse Text Normalization settings, nested as ``CacheAwarePipelineConfig.itn``.

    NemotronASR.load() sets ``enable_itn = False`` before building the
    pipeline, so these values are presumably unused in this deployment —
    TODO confirm.
    """
    input_case: str = "lower_cased"
    whitelist: Optional[str] = None
    overwrite_cache: bool = False
    max_number_of_permutations_per_split: int = 729
    left_padding_size: int = 4
    batch_size: int = 32
    n_jobs: int = 16


@dataclass
class ConfidenceConfig:
    """Confidence estimation settings, nested as ``CacheAwarePipelineConfig.confidence``."""
    exclude_blank: bool = True
    aggregation: str = "mean"
    # Built through default_factory so each instance gets its own dict rather
    # than sharing one mutable default.
    method_cfg: dict = field(default_factory=lambda: {
        "name": "entropy",
        "entropy_type": "tsallis",
        "alpha": 0.5,
        "entropy_norm": "exp",
    })


@dataclass
class EndpointingConfig:
    """Endpointing settings, nested as ``CacheAwarePipelineConfig.endpointing``."""
    # NOTE(review): unit is not stated in this file (presumably milliseconds) — confirm
    stop_history_eou: int = 800
    residue_tokens_at_end: int = 2


@dataclass
class StreamingConfig:
    """Streaming, batching, connection-limit and VAD settings.

    Nested as ``CacheAwarePipelineConfig.streaming``. NemotronASR reads
    ``sample_rate``, ``att_context_size``, ``exhaustive_batching``,
    ``max_active_streams``, ``enable_vad`` and ``vad_threshold`` directly; the
    remaining fields are handed to the pipeline builder with the rest of the
    config.
    """
    sample_rate: int = 16000
    batch_size: int = 512
    word_boundary_tolerance: int = 4
    # The service derives its per-stream chunk size from att_context_size[1] + 1.
    att_context_size: List[int] = field(default_factory=lambda: [70, 6])
    use_cache: bool = True
    use_feat_cache: bool = True
    chunk_size_in_secs: Optional[float] = None
    request_type: str = "frame"
    num_slots: int = 1024
    exhaustive_batching: bool = False  # Process ALL ready frames in one cycle vs single batch
    batching_delay_secs: float = 0.300  # Legacy: only used as fallback
    # === Production batching config ===
    # WHY: Event-driven batching replaces the fixed 300ms sleep.
    # The batcher fires when EITHER max_batch_wait_ms expires OR enough chunks arrive.
    # Under low load: fires after max_batch_wait_ms (bounded latency).
    # Under high load: fires immediately when batch is full (max throughput).
    max_batch_wait_ms: float = 50.0     # Max ms to wait for a batch to fill (50ms = "feels instant")
    min_batch_fire_size: int = 1        # Fire immediately if this many chunks are ready
    # === Connection limits ===
    # Checked by the /health endpoint and by the WebSocket handler's backpressure test.
    max_active_streams: int = 400       # Per-container hard cap (matches max_inputs=400)
    # === VAD config ===
    # WHY: Disabled for initial stress test. Enable after VAD Silero compat fix.
    enable_vad: bool = False            # Gate audio through CPU VAD before GPU pipeline
    vad_threshold: float = 0.5          # Silero VAD speech probability threshold

@dataclass
class MetricsASRConfig:
    """ASR metric settings, nested as ``MetricsConfig.asr``."""
    # presumably the ground-truth attribute name in evaluation records — TODO confirm
    gt_text_attr_name: str = "text"
    clean_groundtruth_text: bool = False
    langid: str = "hi"
    use_cer: bool = True  # CER more meaningful for Hindi/Devanagari
    ignore_capitalization: bool = True
    ignore_punctuation: bool = True
    strip_punc_space: bool = False


@dataclass
class MetricsConfig:
    """Container for metric settings, nested as ``CacheAwarePipelineConfig.metrics``."""
    # default_factory gives every instance its own MetricsASRConfig.
    asr: MetricsASRConfig = field(default_factory=MetricsASRConfig)


@dataclass
class CacheAwarePipelineConfig:
    """Top-level configuration for the cache-aware RNNT streaming pipeline.

    NemotronASR.load() instantiates this class, overrides ``asr.model_name``,
    ``enable_itn``, ``enable_nmt``, ``nmt`` and ``enable_pnc``, converts the
    result with ``OmegaConf.structured`` and passes it to
    ``StreamingPipelineBuilder.build_pipeline``. Because the schema is derived
    from these declarations, field names, annotations and defaults are part of
    the runtime contract.
    """
    # ASR configuration
    asr: ASRConfig = field(default_factory=ASRConfig)

    # ITN configuration
    itn: ITNConfig = field(default_factory=ITNConfig)

    # NMT configuration (set to None to disable)
    nmt: Optional[dict] = None

    # Confidence configuration
    confidence: ConfidenceConfig = field(default_factory=ConfidenceConfig)

    # Endpointing configuration
    endpointing: EndpointingConfig = field(default_factory=EndpointingConfig)

    # Streaming configuration
    streaming: StreamingConfig = field(default_factory=StreamingConfig)

    # Pipeline settings
    matmul_precision: str = "high"
    # 20 equals logging.INFO if this is read as a stdlib logging level — TODO confirm
    log_level: int = 20
    pipeline_type: str = "cache_aware"
    asr_decoding_type: str = "rnnt"

    # Runtime arguments
    audio_file: Optional[str] = None
    output_filename: Optional[str] = None
    output_dir: Optional[str] = None
    enable_pnc: bool = False
    # Defaults to True here, but load() forces it to False for this deployment.
    enable_itn: bool = True
    enable_nmt: bool = False
    asr_output_granularity: str = "segment"
    cache_dir: Optional[str] = None
    lang: Optional[str] = 'hi'
    return_tail_result: bool = False
    calculate_wer: bool = True
    calculate_bleu: bool = True

    # Metrics
    metrics: MetricsConfig = field(default_factory=MetricsConfig)


@app.cls(
    volumes={CACHE_PATH: model_cache},
    gpu="H100",
    # WHY: 4 physical CPU cores for async event loop handling 100s of WS connections,
    # audio preprocessing, queue routing. Default 0.125 cores was starving the event loop.
    cpu=4.0,
    image=image,
    secrets=[hf_secret],
    timeout=3600,
    scaledown_window=300,
    min_containers=2,
    buffer_containers=1,
    max_containers=5,
)
# WHY: target_inputs=200 means Modal scales up a NEW container when one has 200
# connections. max_inputs=400 allows burst headroom while new container starts.
# 200 per container = well within the "stable zone" (WER 31%, ChkP99 <500ms).
# 5 containers x 200 target = 1000 users. 5 x 400 burst = 2000 users peak.
@modal.concurrent(max_inputs=400, target_inputs=200)
class NemotronASR:
        
    @modal.enter()
    async def load(self):
        # Silence chatty logs from nemo
        logging.getLogger("nemo_logger").setLevel(logging.WARNING)
        
        self.client_queues: dict[int, asyncio.Queue] = {}
        self.client_queues_lock = None  # Will be initialized in async context
        self.inference_task = None
        self.inference_running = False
        
        # Per-stream audio timestamp tracking (cumulative duration in seconds)
        self.stream_audio_timestamps: dict[int, float] = {}
        self.stream_timestamps_lock = None  # Will be initialized in async context
        
        # Track message format per stream (True = msgpack, False = raw bytes/JSON)
        self.stream_uses_msgpack: dict[int, bool] = {}
        self.stream_format_lock = None  # Will be initialized in async context
        
        # Track streams pending cleanup (WebSocket closed but may still have buffered frames)
        self.streams_pending_cleanup: set[int] = set()
        self.streams_cleanup_lock = None  # Will be initialized in async context
        
        # === EVENT-DRIVEN BATCHING ===
        # WHY: Instead of sleeping 300ms, the inference loop waits on this event.
        # recv_loop signals it when audio data arrives. The loop also has a timeout
        # (max_batch_wait_ms) so it fires periodically even without new data.
        self._batch_ready_event = None  # Will be asyncio.Event() in async context
        
        # Timing reference (set after warmup)
        self.start_time = None

        print("Initializing pipeline configuration...")
        
        # === Download our fine-tuned Hindi model from HuggingFace ===
        import os
        from huggingface_hub import hf_hub_download
        hf_token = os.environ.get("HF_TOKEN")
        print(f"Downloading {HF_MODEL_REPO}/{HF_MODEL_FILE} from HuggingFace...")
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILE,
            token=hf_token,
            cache_dir=CACHE_PATH,
        )
        print(f"  Model cached at: {model_path}")
        
        # Create config as dataclass, then convert to OmegaConf
        config = CacheAwarePipelineConfig()
        
        # WHY: Point model_name to the downloaded .nemo file
        config.asr.model_name = model_path
        
        # Disable ITN and NMT — not needed for Hindi transcription
        config.enable_itn = False
        config.enable_nmt = False
        config.nmt = None
        # Hindi-specific: disable PnC since our model already produces punctuated output
        config.enable_pnc = False
        
        # Convert to OmegaConf for NeMo
        self.cfg = OmegaConf.structured(config)
        
        print(f"Building pipeline with config:")
        print(f"  Model: {HF_MODEL_REPO} (local: {model_path})")
        print(f"  Pipeline type: {self.cfg.pipeline_type}")
        print(f"  ASR decoding type: {self.cfg.asr_decoding_type}")
        print(f"  Attention context size: {self.cfg.streaming.att_context_size}")
        print(f"  Batch size: {self.cfg.streaming.batch_size}")
        
        # Build the pipeline using PipelineBuilder
        self.pipeline = StreamingPipelineBuilder.build_pipeline(self.cfg)
        
        print("Pipeline loaded successfully!")
        
        # Warm up with test audio using streaming (3x for thorough warmup)
        print("Warming up GPU with streaming inference (3 iterations)...")
        AUDIO_URL = "https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/mono_44100/156550__acclivity__a-dream-within-a-dream.wav"
        audio_bytes = preprocess_audio(AUDIO_URL, target_sample_rate=16000)
        
        # Initialize streaming support
        print("Initializing streaming request generator...")
        self.pipeline.init_streaming_request_generator()
        print("Streaming initialized!")
        
        # Run 3 streaming warmup iterations
        for warmup_iter in range(1, 4):
            print(f"\n🔥 Warmup iteration {warmup_iter}/3...")
            options = ASRRequestOptions()
            options.asr_output_granularity = "word"
            stream_id = self.pipeline.open_streaming_session(options=options)
            print(f"   Opened stream {stream_id}")
            
            # Calculate chunk size based on streaming config
            # Use 80ms chunks (1280 samples at 16kHz)
            chunk_duration = 0.080  # seconds
            chunk_size = int(chunk_duration * 16000 * 2)  # 16kHz, 16-bit (2 bytes) = 2560 bytes
            
            step_num = 0
            full_streaming_transcript = ""
            
            # Stream the audio in chunks
            for i in range(0, len(audio_bytes), chunk_size):
                chunk = audio_bytes[i:i + chunk_size]
                
                # Append to stream
                self.pipeline.append_streaming_audio(stream_id, chunk)
                
                # Process any ready frames
                import asyncio
                requests, outputs = await self.pipeline.process_streaming_batch()
                
                if outputs: 
                    for output in outputs:
                        if output.stream_id == stream_id:
                            if output.final_transcript:
                                full_streaming_transcript += output.final_transcript
                
                step_num += 1
            
            # Close the stream and process final frames
            self.pipeline.close_streaming_session(stream_id)
            
            # Try to get any remaining outputs
            try:
                remaining_requests, remaining_outputs = await self.pipeline.process_streaming_batch()
                for output in remaining_outputs:
                    if output.stream_id == stream_id:
                        if output.final_transcript:
                            full_streaming_transcript += output.final_transcript
            except:
                pass
            
            print(f"   ✅ Iteration {warmup_iter} complete: '{full_streaming_transcript[:50]}...' ({step_num} chunks)")
            
            # Clean up this warmup iteration
            if stream_id in self.pipeline._state_pool:
                self.pipeline.delete_state(stream_id)
            if stream_id in self.pipeline._streaming_request_generator.streams:
                self.pipeline._streaming_request_generator.streams.pop(stream_id, None)
        
        # Final verification after all warmup iterations
        print(f"\n🎉 All warmup iterations complete!")
        
        # Set timing reference point (after warmup)
        self.start_time = time.perf_counter()
        
        # Store exhaustive batching config
        self.exhaustive_batching = self.cfg.streaming.exhaustive_batching
        print(f"Exhaustive batching: {'ENABLED' if self.exhaustive_batching else 'DISABLED'}")
        
        # === VAD INITIALIZATION ===
        # WHY: CPU-based voice activity detection filters silence before GPU.
        # With 1000 users in voice-agent mode, ~70% are listening (not speaking).
        # VAD skips those chunks entirely, effectively 3x GPU capacity.
        if self.cfg.streaming.enable_vad:
            self.vad = SileroVAD(
                threshold=self.cfg.streaming.vad_threshold,
                sample_rate=self.cfg.streaming.sample_rate,
            )
            print(f"VAD: ENABLED (threshold={self.cfg.streaming.vad_threshold})")
        else:
            self.vad = None
            print("VAD: DISABLED")
        
        # Initialize async primitives for multi-client support
        self.client_queues_lock = asyncio.Lock()
        self.stream_timestamps_lock = asyncio.Lock()
        self.stream_format_lock = asyncio.Lock()
        self.streams_cleanup_lock = asyncio.Lock()
        self._batch_ready_event = asyncio.Event()
        
        # Setup FastAPI WebSocket server
        self.web_app = FastAPI()
        
        # === HEALTH CHECK + METRICS ENDPOINT ===
        # WHY: Load balancers need a fast probe. Returns 503 when overloaded.
        @self.web_app.get("/health")
        async def health_check():
            try:
                vram_used = torch.cuda.memory_allocated() / 1e9
                vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
                gpu_name = torch.cuda.get_device_name(0)
            except Exception:
                vram_used = vram_total = 0
                gpu_name = "unknown"
            active = len(self.client_queues)
            max_streams = self.cfg.streaming.max_active_streams
            status = "healthy" if active < max_streams else "overloaded"
            from fastapi.responses import JSONResponse
            return JSONResponse(
                status_code=200 if status == "healthy" else 503,
                content={
                    "status": status,
                    "active_streams": active,
                    "max_streams": max_streams,
                    "gpu": gpu_name,
                    "vram_gb": round(vram_used, 1),
                    "vram_total_gb": round(vram_total, 0),
                    "inference_running": self.inference_running,
                    "uptime_s": round(time.perf_counter() - self.start_time, 0) if self.start_time else 0,
                }
            )
        
        # === HTTP BATCH TRANSCRIPTION ===
        # WHY: Benchmarking needs HTTP mode alongside WebSocket.
        # Routes through same inference pipeline as WS — single-writer safety.
        from fastapi import Request as FastAPIRequest
        import base64
        
        async def _transcribe_via_scheduler(pcm_bytes: bytes) -> str:
            """
            Feed audio through streaming pipeline and collect full transcription.
            Creates a virtual stream, feeds all audio, waits for final result.
            """
            self.start_inference_loop_if_needed()
            
            stream_id = self.pipeline.open_streaming_session()
            transcription_queue = await self.register_client(stream_id)
            
            try:
                # Feed all audio in chunks
                chunk_size_bytes = (self.cfg.streaming.att_context_size[1] + 1) * 1280 * 2
                for i in range(0, len(pcm_bytes), chunk_size_bytes):
                    chunk = pcm_bytes[i:i + chunk_size_bytes]
                    self.pipeline.append_streaming_audio(stream_id, chunk)
                    if self._batch_ready_event is not None:
                        self._batch_ready_event.set()
                    await asyncio.sleep(0)
                
                # Mark stream ended
                self.pipeline._streaming_request_generator.streams[stream_id].mark_end()
                if self._batch_ready_event is not None:
                    self._batch_ready_event.set()
                
                # Collect transcription (timeout 30s)
                full_text = ""
                deadline = time.perf_counter() + 30.0
                while time.perf_counter() < deadline:
                    try:
                        output_dict = await asyncio.wait_for(
                            transcription_queue.get(), timeout=2.0
                        )
                        if output_dict is None:
                            break
                        if output_dict.get('is_marker'):
                            break
                        json_str = output_dict.get('json_str', '{}')
                        data = json.loads(json_str)
                        if data.get('is_final') and data.get('text'):
                            full_text += data['text'] + " "
                        elif data.get('text') and not data.get('is_final'):
                            # Keep latest partial as fallback
                            full_text = data['text']
                    except asyncio.TimeoutError:
                        # Check if stream is still in generator
                        if stream_id not in self.pipeline._streaming_request_generator.streams:
                            break
                
                return full_text.strip()
            finally:
                await self.unregister_client(stream_id)
        
        @self.web_app.post("/transcribe")
        async def http_transcribe(request: FastAPIRequest):
            """HTTP batch transcription. Accepts multipart file or JSON base64."""
            t_start = time.perf_counter()
            content_type = request.headers.get("content-type", "")
            
            if content_type.startswith("application/json"):
                body = await request.json()
                audio_b64 = body.get("audio_base64") or body.get("audio")
                if not audio_b64:
                    from fastapi.responses import JSONResponse
                    return JSONResponse(status_code=400, content={"error": "audio_base64 required"})
                audio_data = base64.b64decode(audio_b64)
            else:
                form = await request.form()
                file_obj = form.get("file")
                if file_obj is None:
                    from fastapi.responses import JSONResponse
                    return JSONResponse(status_code=400, content={"error": "file required"})
                audio_data = await file_obj.read() if hasattr(file_obj, "read") else file_obj
            
            pcm_bytes = preprocess_audio(audio_data, target_sample_rate=16000)
            duration = len(pcm_bytes) / (16000 * 2)
            
            t_infer = time.perf_counter()
            text = await _transcribe_via_scheduler(pcm_bytes)
            t_done = time.perf_counter()
            
            try:
                vram_used = torch.cuda.memory_allocated() / 1e9
                vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
                gpu_name = torch.cuda.get_device_name(0)
            except Exception:
                vram_used = vram_total = 0
                gpu_name = "unknown"
            
            return {
                "text": text,
                "duration": round(duration, 2),
                "processing_time": round(t_done - t_infer, 3),
                "total_time": round(t_done - t_start, 3),
                "gpu": gpu_name,
                "vram_gb": round(vram_used, 1),
                "vram_total_gb": round(vram_total, 0),
            }
        
        # Register WebSocket handler
        @self.web_app.websocket("/ws")
        async def run_with_websocket(ws: WebSocket):
            """
            Production WebSocket endpoint with backpressure.
            Rejects connections when container is at capacity.
            """
            
            # === BACKPRESSURE: Fast-fail when overloaded ===
            # WHY: Better to reject immediately than silently degrade everyone.
            # Client gets clear signal to retry / route to another container.
            active = len(self.client_queues)
            max_streams = self.cfg.streaming.max_active_streams
            if active >= max_streams:
                await ws.accept()
                await ws.close(code=1013, reason=f"Overloaded: {active}/{max_streams}")
                return
            
            await ws.accept()
            
            stream_id = None  # Track for cleanup
            try:
                # Start centralized inference loop if not already running
                self.start_inference_loop_if_needed()
                
                # Open streaming session for this client
                stream_id = self.pipeline.open_streaming_session()
                print(f"✅ Opened stream_id={stream_id}")
                
                transcription_queue = await self.register_client(stream_id)
                elapsed = time.perf_counter() - self.start_time
                print(f"[+{elapsed:7.3f}s] Client {stream_id} connected ({len(self.client_queues)} total)")
            except Exception as e:
                print(f"❌ Error during connection setup: {e}")
                import traceback
                traceback.print_exc()
                if stream_id is not None:
                    await self.unregister_client(stream_id)
                if ws and ws.application_state is WebSocketState.CONNECTED:
                    await ws.close()
                raise
            
            async def recv_loop(ws, stream_id, transcription_queue):
                """
                Receive audio chunks and append to stream buffer
                """

                chunk_size = self.cfg.streaming.att_context_size[1] + 1
                audio_buffer = b""  # Accumulate chunks here as bytes
                num_buffer_samples = 0
                
                elapsed = time.perf_counter() - self.start_time
                print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Micro-batching enabled (batch_size={chunk_size})")
                
                try:
                    while True:
                        # EAGER MESSAGE DRAINING: Collect all available messages before processing
                        # This reduces per-message overhead by batching WebSocket receives
                        messages = []
                        
                        # [1] First message (blocking) - always wait for at least one
                        first_message = await ws.receive()
                        messages.append(first_message)
                        
                        # [2] Drain any additional messages with timeout (non-blocking)
                        # This grabs messages that arrived during processing of previous batch
                        DRAIN_TIMEOUT_MS = 1  # 1ms timeout
                        while True:
                            try:
                                # Try to get more messages with timeout
                                next_message = await asyncio.wait_for(
                                    ws.receive(), 
                                    timeout=DRAIN_TIMEOUT_MS / 1000
                                )
                                messages.append(next_message)
                            except asyncio.TimeoutError:
                                # No more messages available, proceed with what we have
                                break
                        
                        stream_ended = False
                        
                        # Process all messages - support both msgpack and raw bytes
                        for message in messages:
                            # Check for text messages (marker signal from raw bytes client)
                            if "text" in message:
                                text_msg = message["text"]
                                if text_msg == "END" or text_msg == "MARKER":
                                    # Track that this stream uses raw bytes (not msgpack)
                                    async with self.stream_format_lock:
                                        if stream_id not in self.stream_uses_msgpack:
                                            self.stream_uses_msgpack[stream_id] = False
                                            elapsed = time.perf_counter() - self.start_time
                                            print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Detected format = RAW BYTES (text marker)")
                                    
                                    elapsed = time.perf_counter() - self.start_time
                                    print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Received text Marker '{text_msg}'")
                                    await transcription_queue.put({
                                        'output': None,
                                        'audio_timestamp': 0.0,
                                        'is_marker': True
                                    })
                                    stream_ended = True
                                    break
                            
                            # Check for binary messages (audio data or msgpack)
                            if "bytes" in message:
                                raw_bytes = message["bytes"]
                                is_msgpack = False
                                
                                # Ensure raw_bytes is actually bytes, not str
                                if isinstance(raw_bytes, str):
                                    # Convert string to bytes if needed (shouldn't happen with binary WebSocket)
                                    raw_bytes = raw_bytes.encode('latin-1')
                                
                                # Try to decode as msgpack first (for backwards compatibility)
                                is_msgpack_data = False
                                try:
                                    # Try msgpack decode
                                    data = msgpack.unpackb(raw_bytes, raw=False)
                                    msg_type = data.get("type")
                                    
                                    # Successfully decoded as msgpack
                                    is_msgpack_data = True
                                    
                                    # Track that this stream uses msgpack (on first message)
                                    async with self.stream_format_lock:
                                        if stream_id not in self.stream_uses_msgpack:
                                            self.stream_uses_msgpack[stream_id] = True
                                            elapsed = time.perf_counter() - self.start_time
                                            print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Detected format = MSGPACK")
                                    
                                    # Handle msgpack Marker
                                    if msg_type == "Marker":
                                        elapsed = time.perf_counter() - self.start_time
                                        print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Received msgpack Marker")
                                        await transcription_queue.put({
                                            'output': None,
                                            'audio_timestamp': 0.0,
                                            'is_marker': True
                                        })
                                        stream_ended = True
                                        break
                                    
                                    # Handle msgpack Audio
                                    if msg_type == "Audio":
                                        pcm_bytes = data["pcm_bytes"]
                                        # Accumulate in buffer
                                        audio_buffer += pcm_bytes
                                        num_buffer_samples += len(pcm_bytes)
                                
                                except:
                                    # Silently ignore msgpack decode errors - just means it's raw bytes
                                    pass
                                
                                # If not msgpack, treat as raw PCM bytes
                                if not is_msgpack_data:
                                    # Track that this stream uses raw bytes (on first message)
                                    async with self.stream_format_lock:
                                        if stream_id not in self.stream_uses_msgpack:
                                            self.stream_uses_msgpack[stream_id] = False
                                            elapsed = time.perf_counter() - self.start_time
                                            print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Detected format = RAW BYTES")
                                    
                                    # This is the simple path for frontends without msgpack
                                    # raw_bytes should already be bytes type
                                    audio_buffer += raw_bytes
                                    num_buffer_samples += len(raw_bytes)
                        
                        # After processing all messages in batch, check if buffer should be flushed
                        if num_buffer_samples >= chunk_size * 1280 * 2:  # samples per frame
                            
                            # === VAD GATE: Skip GPU pipeline for silence ===
                            # WHY: CPU VAD runs in ~1ms. If chunk is silence, skip the
                            # entire GPU forward pass. 70% of connected users are
                            # listening (not speaking), so this 3x effective GPU capacity.
                            if self.vad is not None and not self.vad.is_speech(audio_buffer, stream_id):
                                # Silence detected — skip GPU, clear buffer
                                audio_buffer = b""
                                num_buffer_samples = 0
                                if hasattr(self, '_metrics'):
                                    self._metrics['vad_filtered_chunks'] = self._metrics.get('vad_filtered_chunks', 0) + 1
                                await asyncio.sleep(0)
                                continue
                            
                            # Pipeline append (now with batched data)  
                            self.pipeline.append_streaming_audio(stream_id, audio_buffer)
                            
                            # Clear buffer and reset sample counter
                            audio_buffer = b""
                            num_buffer_samples = 0
                            
                            # === SIGNAL THE INFERENCE LOOP ===
                            # WHY: This is the key event-driven batching trigger.
                            # Instead of the inference loop sleeping 300ms, it waits on this event.
                            # Setting it wakes the loop immediately to process the new data.
                            if self._batch_ready_event is not None:
                                self._batch_ready_event.set()
                            
                            # Yield to event loop after batch append
                            await asyncio.sleep(0)
                        
                        # Check if stream ended and break outer loop
                        if stream_ended:
                            break
                        
                except WebSocketDisconnect:
                    pass  # Normal disconnection
                except Exception as e:
                    # Safely print error without trying to decode binary data
                    try:
                        error_msg = str(e)
                    except:
                        error_msg = repr(e)
                    print(f"❌ recv_loop error stream {stream_id}: {error_msg}")
                    import traceback
                    traceback.print_exc()
                finally:
                    # FLUSH REMAINING BUFFERED CHUNKS
                    if audio_buffer:
                        elapsed = time.perf_counter() - self.start_time
                        print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Flushing {len(audio_buffer)} remaining chunks from buffer")
                        try:
                            np_data = np.frombuffer(audio_buffer, dtype=np.int16)
                            np_data = np_data.astype(np.float32) / 32768.0
                            torch_data = torch.from_numpy(np_data)
                            self.pipeline.append_streaming_audio(stream_id, torch_data)
                            audio_buffer = b""
                        except Exception as e:
                            print(f"[+{elapsed:7.3f}s] ⚠️  Error flushing buffer for stream {stream_id}: {e}")
                    
                    # Mark stream as ended so final frames can be processed
                    try:
                        self.pipeline._streaming_request_generator.streams[stream_id].mark_end()
                    except Exception as e:
                        print(f"[WARN] Could not mark stream {stream_id} as ended: {e}")
                    
                    # WHY: Signal inference loop to drain remaining frames NOW.
                    # Old: await asyncio.sleep(1.0) — blocked event loop for ALL streams.
                    # New: signal event + brief yield for one inference cycle.
                    if self._batch_ready_event is not None:
                        self._batch_ready_event.set()
                    await asyncio.sleep(0.1)  # 100ms — 2x max_batch_wait_ms for safety
                    
                    # Signal send_loop to finish if no marker was sent
                    # (This handles backwards compatibility with clients that don't send markers)
                    await transcription_queue.put(None)
            
            async def send_loop(ws, transcription_queue, stream_id):
                """
                Drain this client's transcription queue and forward each entry over the WebSocket.

                The loop ends when it dequeues ``None`` (shutdown sentinel) or an entry
                flagged ``is_marker`` (which is echoed back to the client first).
                Any exception raised while sending is logged and ends the loop.
                """
                sent_count = 0

                try:
                    while True:
                        item = await transcription_queue.get()
                        if item is None:
                            # Shutdown sentinel pushed by recv_loop
                            break

                        # The wire format is looked up per message because it may only be
                        # recorded once the first audio frame has been received.
                        async with self.stream_format_lock:
                            uses_msgpack = self.stream_uses_msgpack.get(stream_id, True)

                        if item.get('is_marker'):
                            elapsed = time.perf_counter() - self.start_time
                            print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Echoing Marker back to client (format={'msgpack' if uses_msgpack else 'text'})")

                            # Echo the end marker in whichever format the client speaks
                            if not uses_msgpack:
                                await ws.send_text("END")
                            else:
                                marker_bytes = msgpack.packb({"type": "Marker", "id": -1}, use_bin_type=True)
                                await ws.send_bytes(marker_bytes)

                            # Let other tasks run, then stop: the client closes after the echo
                            await asyncio.sleep(0)
                            break

                        if uses_msgpack:
                            # Pre-encoded msgpack payload prepared by route_outputs
                            payload = item.get('msgpack_bytes')
                            if payload:
                                await ws.send_bytes(payload)
                        else:
                            # Pre-encoded JSON payload prepared by route_outputs
                            text = item.get('json_str')
                            if text:
                                if sent_count == 0:
                                    elapsed = time.perf_counter() - self.start_time
                                    print(f"[+{elapsed:7.3f}s] Stream {stream_id}: Sending first JSON text message: {text[:100]}")
                                await ws.send_text(text)

                        sent_count += 1

                except Exception as e:
                    elapsed = time.perf_counter() - self.start_time
                    print(f"[+{elapsed:7.3f}s] ❌ send_loop error stream {stream_id}: {e}")
            
            # WebSocket already accepted at the top
            tasks = []
            try:
                tasks = [
                    asyncio.create_task(recv_loop(ws, stream_id, transcription_queue)),
                    asyncio.create_task(send_loop(ws, transcription_queue, stream_id)),
                ]

                # Send ready message as JSON text (works for all clients)
                # We don't know the client's format yet, so use JSON which is universally supported
                ready_msg = {"type": "Ready", "id": stream_id}
                await ws.send_text(json.dumps(ready_msg))
                
                # Wait for both tasks to complete
                await asyncio.gather(*tasks, return_exceptions=True)
                
            except Exception as e:
                print(f"❌ WebSocket error for stream {stream_id}: {e}")
                import traceback
                traceback.print_exc()
            finally:
                # Cleanup
                await self.unregister_client(stream_id)
                if ws and ws.application_state is WebSocketState.CONNECTED:
                    await ws.close()
                # Cancel any remaining tasks
                for task in tasks:
                    if not task.done():
                        try:
                            task.cancel()
                            await task
                        except asyncio.CancelledError:
                            pass

        # WHY: Removed uvicorn thread + modal.forward() tunnel.
        # Everything is served through @modal.asgi_app() natively — no tunnel overhead.
        print("ASGI app ready (served natively by Modal, no tunnel)")
    
    async def register_client(self, stream_id: int) -> asyncio.Queue:
        """
        Create a transcription queue for a new client and start tracking its stream.

        Args:
            stream_id (int): The stream ID for this client
        Returns:
            asyncio.Queue: The newly created queue, also stored in
            ``self.client_queues[stream_id]``
        """
        client_queue = asyncio.Queue()

        # Make the queue discoverable by stream ID
        async with self.client_queues_lock:
            self.client_queues[stream_id] = client_queue

        # Every stream starts its audio clock at zero
        async with self.stream_timestamps_lock:
            self.stream_audio_timestamps[stream_id] = 0.0

        return client_queue
    
    async def unregister_client(self, stream_id: int):
        """
        Remove a client from all per-stream tracking and queue its stream for cleanup.

        Pipeline state is not deleted here: the stream ID is added to
        ``self.streams_pending_cleanup`` and released later by
        ``_cleanup_finished_streams`` once the request generator no longer lists it.

        The VAD reset and the request-generator close are each guarded separately so
        that a failure in either one cannot prevent the stream from being marked for
        cleanup (which would otherwise leave its state allocated).

        Args:
            stream_id (int): The stream ID to remove
        """
        try:
            # Drop the per-stream bookkeeping entries (missing keys are fine)
            async with self.client_queues_lock:
                self.client_queues.pop(stream_id, None)
            async with self.stream_timestamps_lock:
                self.stream_audio_timestamps.pop(stream_id, None)
            async with self.stream_format_lock:
                self.stream_uses_msgpack.pop(stream_id, None)

            # Clean up VAD state for this stream (best effort)
            if self.vad is not None:
                try:
                    self.vad.reset_stream(stream_id)
                except Exception as e:
                    elapsed = time.perf_counter() - self.start_time
                    print(f"[+{elapsed:7.3f}s] ⚠️  VAD reset failed for stream {stream_id}: {e}")

            # Close the stream in the request generator (best effort)
            if hasattr(self.pipeline, '_streaming_request_generator'):
                try:
                    self.pipeline._streaming_request_generator.close_stream(stream_id)
                except Exception as e:
                    elapsed = time.perf_counter() - self.start_time
                    print(f"[+{elapsed:7.3f}s] ⚠️  Could not close stream {stream_id}: {e}")

            # Always mark the stream for deferred cleanup, even if a step above failed
            async with self.streams_cleanup_lock:
                self.streams_pending_cleanup.add(stream_id)

            elapsed = time.perf_counter() - self.start_time
            print(f"[+{elapsed:7.3f}s] Client {stream_id} disconnected, marked for cleanup ({len(self.client_queues)} remaining)")
        except Exception as e:
            elapsed = time.perf_counter() - self.start_time
            print(f"[+{elapsed:7.3f}s] ❌ Unregister error stream {stream_id}: {e}")
    
    async def route_outputs(self, outputs_with_timestamps):
        """
        Encode each output once in both wire formats, then hand it to its client's queue.

        Each dict in ``outputs_with_timestamps`` is mutated in place: ``msgpack_bytes``
        and ``json_str`` keys are added. The client-queue registry is only copied under
        the lock; the actual ``queue.put`` calls happen afterwards, one coroutine per stream.

        Args:
            outputs_with_timestamps: List of dicts with 'output' and 'audio_timestamp' keys
        """
        import json

        now = time.time()

        # --- Step 1: attach msgpack + JSON encodings to every entry ---
        for entry in outputs_with_timestamps:
            out = entry['output']
            stamp = entry['audio_timestamp']

            # A non-empty partial transcript means this is an interim result
            is_partial = bool(out.partial_transcript)
            payload = {
                "text": out.current_step_transcript if is_partial else out.final_transcript,
                "timestamp": now,
                "audio_timestamp": stamp,
                "is_final": not is_partial,
            }
            if len(out.final_segments) > 0:
                first_segment = out.final_segments[0]
                payload["segment_text"] = first_segment.text
                payload["segment_start_time"] = first_segment.start
                payload["segment_end_time"] = first_segment.end

            entry['msgpack_bytes'] = msgpack.packb(payload, use_bin_type=True)
            entry['json_str'] = json.dumps(payload)

        # --- Step 2: copy the registry under the lock, group entries without it ---
        async with self.client_queues_lock:
            registry = dict(self.client_queues)

        per_stream = {}  # stream_id -> (entries, queue)
        for entry in outputs_with_timestamps:
            sid = entry['output'].stream_id
            target = registry.get(sid)
            if not target:
                elapsed = time.perf_counter() - self.start_time
                print(f"[+{elapsed:7.3f}s] ⚠️ No queue for stream {sid}")
                continue
            per_stream.setdefault(sid, ([], target))[0].append(entry)

        if not per_stream:
            return

        # --- Step 3: deliver to all streams concurrently, preserving per-stream order ---
        async def deliver(stream_id, entries, queue):
            """Put every entry for one stream on its queue, logging failures."""
            for one in entries:
                try:
                    await queue.put(one)
                except Exception as e:
                    elapsed = time.perf_counter() - self.start_time
                    print(f"[+{elapsed:7.3f}s] ❌ Route error stream {stream_id}: {e}")

        await asyncio.gather(
            *(deliver(sid, entries, queue) for sid, (entries, queue) in per_stream.items()),
            return_exceptions=True,
        )
    
    async def centralized_inference_loop(self):
        """
        Production inference loop with event-driven batching.
        
        WHY event-driven instead of fixed sleep:
        - Fixed 300ms sleep adds 300ms latency ALWAYS, even when data is ready
        - Event-driven fires IMMEDIATELY when enough data arrives (max throughput)
        - Falls back to max_batch_wait_ms timeout (bounded latency, default 50ms)
        - Under high load: batches fill instantly, near-zero wait
        - Under low load: waits up to 50ms, still 6x faster than 300ms
        
        This is the "smart taxi" pattern described by big ASR labs:
        dispatch when full OR after short timeout, whichever comes first.

        State touched:
        - Sets ``self.inference_running`` to True and loops until it becomes falsy.
        - Re-creates ``self._metrics`` on every start, so counters reset whenever
          the loop task is restarted.
        - Reads ``self.cfg.streaming.max_batch_wait_ms`` once, before the loop.
        - Each productive cycle writes back ``self.stream_audio_timestamps``,
          calls ``route_outputs`` and then ``_cleanup_finished_streams``.
        """
        self.inference_running = True
        
        # Production metrics (structured, not ad-hoc)
        self._metrics = {
            'total_batches': 0,
            'total_outputs': 0,
            'total_gpu_ms': 0.0,
            'total_route_ms': 0.0,
            'batch_sizes': [],       # Rolling window for reporting
            'gpu_times_ms': [],      # Rolling window
            'cycle_times_ms': [],    # Rolling window
            'empty_cycles': 0,       # How often we woke up with nothing to do
            'vad_filtered_chunks': 0,  # Chunks skipped by VAD
        }
        METRICS_REPORT_INTERVAL = 100  # Report every N productive cycles
        METRICS_WINDOW = 200           # Rolling window size for percentiles
        
        max_wait = self.cfg.streaming.max_batch_wait_ms / 1000.0  # Convert to seconds
        
        while self.inference_running:
            try:
                # NOTE(review): read without client_queues_lock (route_outputs takes the
                # lock for its snapshot) — confirm an unlocked len() is acceptable here.
                active_clients = len(self.client_queues)
                
                # Idle if no clients — wait on event instead of polling
                if active_clients == 0:
                    # WHY: Use event.wait() instead of sleep() so we wake immediately
                    # when the first client connects and signals new data
                    try:
                        await asyncio.wait_for(self._batch_ready_event.wait(), timeout=0.1)
                        # Only reached when the event fired; a timeout leaves it untouched
                        self._batch_ready_event.clear()
                    except asyncio.TimeoutError:
                        pass
                    continue
                
                # === EVENT-DRIVEN WAIT ===
                # Wait for EITHER: data signal OR timeout, whichever comes first
                # This replaces the fixed 300ms sleep
                cycle_start = time.perf_counter()
                try:
                    await asyncio.wait_for(self._batch_ready_event.wait(), timeout=max_wait)
                except asyncio.TimeoutError:
                    pass  # Timeout is normal — just means we fire with whatever we have
                self._batch_ready_event.clear()
                
                # === PROCESS ALL READY FRAMES (exhaustive) ===
                # WHY: Always exhaustive in production. Processing one batch then sleeping
                # wastes time when multiple batches are ready. Drain everything, THEN wait.
                try:
                    all_outputs_with_timestamps = []
                    
                    # Snapshot timestamps
                    # (working copy for this cycle; written back only if outputs were produced)
                    async with self.stream_timestamps_lock:
                        cycle_timestamps = dict(self.stream_audio_timestamps)
                    
                    batches_this_cycle = 0
                    while True:
                        try:
                            # gpu_ms is the wall-clock time of the whole awaited call
                            batch_start = time.perf_counter()
                            requests, outputs = await self.pipeline.process_streaming_batch()
                            batch_end = time.perf_counter()
                            gpu_ms = (batch_end - batch_start) * 1000
                            
                            if not outputs:
                                break  # No more frames ready
                            
                            batches_this_cycle += 1
                            self._metrics['total_batches'] += 1
                            self._metrics['total_outputs'] += len(outputs)
                            self._metrics['total_gpu_ms'] += gpu_ms
                            self._metrics['batch_sizes'].append(len(outputs))
                            self._metrics['gpu_times_ms'].append(gpu_ms)
                            
                            # Trim rolling windows
                            if len(self._metrics['batch_sizes']) > METRICS_WINDOW:
                                self._metrics['batch_sizes'] = self._metrics['batch_sizes'][-METRICS_WINDOW:]
                            if len(self._metrics['gpu_times_ms']) > METRICS_WINDOW:
                                self._metrics['gpu_times_ms'] = self._metrics['gpu_times_ms'][-METRICS_WINDOW:]
                            
                            # Update timestamps
                            # (assumes request.length is a sample count at 16 kHz — TODO confirm)
                            for request in requests:
                                if hasattr(request, 'length'):
                                    sid = request.stream_id
                                    duration_secs = request.length / 16000.0
                                    cycle_timestamps[sid] = cycle_timestamps.get(sid, 0.0) + duration_secs
                            
                            # Each output carries the stream's accumulated audio time so far
                            for output in outputs:
                                all_outputs_with_timestamps.append({
                                    'output': output,
                                    'audio_timestamp': cycle_timestamps.get(output.stream_id, 0.0),
                                })
                            
                            # Yield to event loop between sub-batches
                            await asyncio.sleep(0)
                            
                        except Exception as e:
                            # The exception is matched by class name rather than by type
                            # (presumably because the class is not imported in this scope).
                            if type(e).__name__ == 'NotEnoughDataException':
                                if batches_this_cycle == 0:
                                    self._metrics['empty_cycles'] += 1
                                break
                            raise
                    
                    # Route outputs and update state
                    if all_outputs_with_timestamps:
                        # Write back only for streams that are still registered
                        async with self.stream_timestamps_lock:
                            for sid in cycle_timestamps:
                                if sid in self.stream_audio_timestamps:
                                    self.stream_audio_timestamps[sid] = cycle_timestamps[sid]
                        
                        route_start = time.perf_counter()
                        await self.route_outputs(all_outputs_with_timestamps)
                        route_end = time.perf_counter()
                        self._metrics['total_route_ms'] += (route_end - route_start) * 1000
                        
                        # Deferred stream cleanup
                        # NOTE(review): this only runs on cycles that produced outputs, and the
                        # idle branch above `continue`s before reaching it — so streams pending
                        # cleanup wait until some stream produces output again. Confirm intended.
                        await self._cleanup_finished_streams()
                        
                        cycle_end = time.perf_counter()
                        self._metrics['cycle_times_ms'].append((cycle_end - cycle_start) * 1000)
                        if len(self._metrics['cycle_times_ms']) > METRICS_WINDOW:
                            self._metrics['cycle_times_ms'] = self._metrics['cycle_times_ms'][-METRICS_WINDOW:]
                        
                        # === PERIODIC METRICS REPORT ===
                        # NOTE(review): total_batches can grow by more than one per cycle, so
                        # an exact multiple of the interval may be stepped over and the report
                        # skipped for that interval.
                        if self._metrics['total_batches'] % METRICS_REPORT_INTERVAL == 0:
                            self._log_metrics(active_clients)
                
                except Exception as e:
                    # Errors inside one cycle are logged and the loop carries on immediately
                    if type(e).__name__ != 'NotEnoughDataException':
                        print(f"[ERROR] Inference cycle: {e}")
                        import traceback
                        traceback.print_exc()
                
            except Exception as e:
                # Errors outside the cycle body (idle wait / event wait) back off briefly
                print(f"[ERROR] Inference loop outer: {e}")
                import traceback
                traceback.print_exc()
                await asyncio.sleep(0.1)  # Back off on errors
        
        print("[WARN] Centralized inference loop stopped!")
    
    def _log_metrics(self, active_clients: int):
        """
        Print a three-line metrics summary built from ``self._metrics``.

        Line 1: counters; line 2: p50/p99 of the GPU and cycle rolling windows plus
        batch-size average/max; line 3: CUDA memory usage (zeros if it cannot be read).

        Args:
            active_clients: Stream count to report on the first line
        """
        def percentile(samples, p):
            # Index-based percentile on the sorted window; 0 for an empty window
            if not samples:
                return 0
            ordered = sorted(samples)
            last = len(ordered) - 1
            return ordered[min(int(len(ordered) * p / 100), last)]

        stats = self._metrics
        elapsed = time.perf_counter() - self.start_time

        gpu_window = stats['gpu_times_ms']
        cycle_window = stats['cycle_times_ms']
        gpu_p50, gpu_p99 = percentile(gpu_window, 50), percentile(gpu_window, 99)
        cycle_p50, cycle_p99 = percentile(cycle_window, 50), percentile(cycle_window, 99)

        sizes = stats['batch_sizes']
        if sizes:
            batch_avg = sum(sizes) / len(sizes)
            batch_max = max(sizes)
        else:
            batch_avg = 0
            batch_max = 0

        # VRAM usage (any failure here just reports zeros)
        try:
            vram_used = torch.cuda.memory_allocated() / 1e9
            vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
            vram_pct = (vram_used / vram_total) * 100
        except Exception:
            vram_used = vram_total = vram_pct = 0

        counters_line = (
            f"[+{elapsed:7.1f}s] METRICS | streams={active_clients} "
            f"batches={stats['total_batches']} outputs={stats['total_outputs']} "
            f"empty={stats['empty_cycles']} vad_skip={stats['vad_filtered_chunks']}"
        )
        timing_line = (
            f"         GPU: p50={gpu_p50:.1f}ms p99={gpu_p99:.1f}ms | "
            f"Cycle: p50={cycle_p50:.1f}ms p99={cycle_p99:.1f}ms | "
            f"Batch: avg={batch_avg:.0f} max={batch_max}"
        )
        vram_line = f"         VRAM: {vram_used:.1f}/{vram_total:.0f}GB ({vram_pct:.0f}%)"

        print(counters_line)
        print(timing_line)
        print(vram_line)
    
    async def _cleanup_finished_streams(self):
        """
        Release pipeline state for streams that are pending cleanup and fully drained.

        A stream counts as drained once it no longer appears in the request generator's
        ``streams`` mapping. For each such stream its ``_state_pool`` entry is deleted,
        its context-manager slot is reset (if a context manager exists), and it is
        removed from ``streams_pending_cleanup``. Runs entirely under
        ``streams_cleanup_lock``.
        """
        async with self.streams_cleanup_lock:
            pending = self.streams_pending_cleanup
            if not pending:
                return  # Nothing is waiting

            # Streams the request generator still knows about are not safe to delete yet
            still_active = set(self.pipeline._streaming_request_generator.streams.keys())
            drained = pending - still_active
            if not drained:
                return

            elapsed = time.perf_counter() - self.start_time
            for stream_id in drained:
                try:
                    # Drop decoder state — no further frames will arrive for this stream
                    if stream_id in self.pipeline._state_pool:
                        del self.pipeline._state_pool[stream_id]
                        print(f"[+{elapsed:7.3f}s]    🧹 Cleaned up state_pool for stream {stream_id}")

                    # Free the cache slot; a failure here is logged but does not block removal
                    ctx_manager = getattr(self.pipeline, 'context_manager', None)
                    if ctx_manager is not None:
                        try:
                            ctx_manager.reset_slots([stream_id], [True])
                            print(f"[+{elapsed:7.3f}s]    🧹 Reset context_manager cache slot for stream {stream_id}")
                        except Exception as e:
                            print(f"[+{elapsed:7.3f}s]    ⚠️  Could not reset context_manager for stream {stream_id}: {e}")

                    # Done with this stream
                    pending.discard(stream_id)

                except Exception as e:
                    print(f"[+{elapsed:7.3f}s] ❌ Error cleaning up stream {stream_id}: {e}")
    
    def start_inference_loop_if_needed(self):
        """
        Start the centralized inference loop task, or restart it if it has finished.

        Must be called from within a running event loop (it uses
        ``asyncio.create_task``). Does nothing when the task is still running.
        When the previous task has finished, the reason (cancellation or exception)
        is logged before a new task is created.
        """
        previous = self.inference_task

        if previous is None:
            print("🔄 Starting centralized inference loop (first client)...")
            self.inference_task = asyncio.create_task(self.centralized_inference_loop())
            print("✅ Centralized inference loop started!")
            return

        if not previous.done():
            return  # Already running

        print("⚠️  Inference task was done! Restarting...")
        # Task.exception() raises CancelledError for a cancelled task, so check
        # cancellation explicitly instead of hiding it behind a bare except.
        if previous.cancelled():
            print("⚠️  Previous task was cancelled")
        else:
            exc = previous.exception()
            if exc:
                print(f"❌ Previous task failed with: {exc}")

        self.inference_task = asyncio.create_task(self.centralized_inference_loop())
        print("✅ Centralized inference loop restarted!")

    
    @modal.asgi_app()
    def webapp(self):
        """
        Expose the FastAPI WebSocket app as this class's ASGI endpoint.

        Returns:
            ``self.web_app``, which must already have been built by the time this
            is called (presumably during container startup — TODO confirm).
        """
        return self.web_app

    @modal.method()
    def get_config_dict(self) -> dict:
        """
        Return the pipeline configuration as plain, serializable Python containers.

        Returns:
            ``self.cfg`` converted from OmegaConf with interpolations resolved
        """
        # OmegaConf objects do not serialize across the remote call; convert first
        return OmegaConf.to_container(self.cfg, resolve=True)
    
    @modal.method()
    def transcribe_file(self, audio_url: str) -> dict:
        """
        Transcribe an audio file from a URL or local path.

        The audio is preprocessed to 16 kHz, written to a temporary mono 16-bit WAV
        file, and passed to ``self.pipeline.run``. The temporary file is removed even
        if writing it or running the pipeline raises.

        Args:
            audio_url: URL or path to audio file

        Returns:
            Dictionary with keys ``text`` (transcript of the first pipeline output),
            ``inference_time`` (seconds spent in ``pipeline.run``) and
            ``audio_duration`` (seconds, derived from the preprocessed byte length).

        Raises:
            Whatever ``preprocess_audio`` or ``self.pipeline.run`` raise; the
            temporary file is cleaned up before the exception propagates.
        """
        import os
        import tempfile
        import wave

        print(f"Transcribing audio from: {audio_url}")

        # Preprocess audio to ensure it's 16kHz mono
        preprocess_started = time.perf_counter()
        audio_bytes = preprocess_audio(audio_url, target_sample_rate=16000)
        preprocess_time = time.perf_counter() - preprocess_started
        print(f"Audio preprocessing took {preprocess_time:.2f} seconds")

        # Calculate audio duration before running pipeline
        data_dur = len(audio_bytes) / (16000 * 2)  # 16kHz, 16-bit (2 bytes per sample)

        tmp_path = None
        try:
            # delete=False because the pipeline reopens the file by path after we close it
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp_path = tmp.name
                with wave.open(tmp, "wb") as wav_file:
                    wav_file.setnchannels(1)
                    wav_file.setsampwidth(2)
                    wav_file.setframerate(16000)
                    wav_file.writeframes(audio_bytes)

            # Run pipeline
            inference_started = time.perf_counter()
            progress_bar = TQDMProgressBar()
            output = self.pipeline.run([tmp_path], progress_bar=progress_bar)
            inference_time = time.perf_counter() - inference_started
        finally:
            # Always remove the temp file, including on failure paths
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except FileNotFoundError:
                    pass

        result = {
            "text": output[0]["text"],
            "inference_time": inference_time,
            "audio_duration": data_dur,
        }

        print("Transcription complete!")
        print(f"  Text: {result['text']}")
        print(f"  Audio duration: {data_dur:.2f}s")
        print(f"  Inference time: {inference_time:.2f}s")

        return result


# ## Frontend Service
#
# We serve a simple HTML/JS frontend to interact with the transcriber.
# The frontend captures microphone input and streams it to the WebSocket endpoint.

# Image for the frontend server: Debian slim with only FastAPI installed.
# The local `nemotron-asr-frontend` directory is resolved relative to this file
# (it is a sibling of the directory that contains this file) and is added to the
# image at /root/frontend, the path WebServer serves static files from.
web_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install("fastapi")
    .add_local_dir(Path(__file__).parent.parent  / "nemotron-asr-frontend", "/root/frontend")
)

with web_image.imports():
    from fastapi import FastAPI
    from fastapi.responses import HTMLResponse, Response
    from fastapi.staticfiles import StaticFiles


@app.cls(image=web_image)
class WebServer:
    """
    Serves the static HTML/JS frontend and tells it where the ASR WebSocket lives.

    Routes:
        /static/*  files from /root/frontend
        /status    empty 200 response (health check)
        /          index.html with ``window.WS_BASE_URL`` injected
    """

    @modal.asgi_app()
    def web(self):
        """Build and return the FastAPI app for the frontend."""
        web_app = FastAPI()
        web_app.mount("/static", StaticFiles(directory="/root/frontend"), name="static")

        @web_app.get("/status")
        async def status():
            return Response(status_code=200)

        # Serve frontend
        @web_app.get("/")
        async def index():
            # Read with an explicit encoding and close the handle deterministically
            with open("/root/frontend/index.html", encoding="utf-8") as html_file:
                html_content = html_file.read()

            # Get the WebSocket URL from the NemotronASR
            cls_instance = NemotronASR()
            # Swap only the leading scheme (http -> ws, https -> wss); an unbounded
            # replace would also rewrite any later "http" in the host or path.
            ws_base_url = cls_instance.webapp.web_url.replace("http", "ws", 1) + "/ws"

            # Inject the URL just before the frontend script so it is defined when
            # that script runs
            script_tag = f'<script>window.WS_BASE_URL = "{ws_base_url}";</script>'
            html_content = html_content.replace(
                '<script src="/static/cache-aware-stt.js"></script>',
                f'{script_tag}\n    <script src="/static/cache-aware-stt.js"></script>',
            )
            return HTMLResponse(content=html_content)

        return web_app


@app.local_entrypoint()
def main():
    """Smoke test: fetch the deployed pipeline config remotely and print key settings."""
    print("Starting Hindi Nemotron ASR deployment test...")
    config = NemotronASR().get_config_dict.remote()

    # Model names can be long paths/URLs; show at most the first 60 characters
    model_name = config['asr']['model_name'][:60]
    print(f"\nDeployed config: model={model_name}...")

    streaming_cfg = config['streaming']
    for key in ("att_context_size", "batch_size"):
        print(f"  streaming.{key}={streaming_cfg[key]}")

    print("✅ Hindi ASR service is ready!")

# NOTE(review): this statement is at module level, so it runs whenever the module
# is imported or executed — not only after main() succeeds — and it repeats
# main()'s final message. Presumably a leftover; confirm and consider removing.
print("✅ Hindi ASR service is ready!")

