"""
CPU-based Voice Activity Detection using Silero VAD.
WHY: Filters silence BEFORE GPU pipeline. With voice-agent traffic (~30% speech
duty cycle), this effectively triples GPU capacity for free.
Runs on CPU, ~1ms per chunk — negligible overhead.

Hysteresis: requires 3 consecutive silence frames before suppressing.
This prevents clipping speech at word boundaries.

Fix 4: Eager preload at startup — no lazy torch.hub.load in the request path.
The first-request penalty (model download + compile) is eliminated.
"""
import logging
import numpy as np
import torch

# Dotted name places this logger under "triton_asr" in the logging hierarchy, so
# handlers/levels set on the parent "triton_asr" logger also apply to VAD logs.
logger = logging.getLogger("triton_asr.vad")


class SileroVAD:
    """CPU voice-activity gate built on the Silero VAD model.

    Scores the trailing analysis window of each audio chunk and applies
    per-stream hysteresis: a stream is reported as silent only after
    ``silence_threshold_frames`` consecutive chunks score below ``threshold``.
    Whenever a decision cannot be made (model not loaded, chunk shorter than
    one window, decode/inference error) ``is_speech`` returns True, so audio is
    passed through rather than dropped.

    NOTE(review): one model instance serves every ``stream_id``. Silero VAD
    models keep internal state between calls (they expose ``reset_states()``);
    interleaving streams through a shared instance may blend that state —
    confirm whether per-stream model state is needed.

    Args:
        threshold: Speech probability at or above which a chunk counts as speech.
        sample_rate: Sample rate of the incoming 16-bit PCM audio, in Hz.
            Silero VAD supports 16000 and 8000.
        silence_threshold_frames: Consecutive sub-threshold chunks after which
            a stream is reported as silent (default 3).
    """

    def __init__(self, threshold: float = 0.5, sample_rate: int = 16000,
                 silence_threshold_frames: int = 3):
        self.threshold = threshold
        self.sample_rate = sample_rate
        self._model = None
        # stream_id -> consecutive sub-threshold chunks seen so far. Entries are
        # removed only by reset_stream(); if callers never call it, this dict
        # keeps one entry for every stream ever seen.
        self._silence_counts = {}
        self.silence_threshold_frames = silence_threshold_frames

    def preload(self) -> None:
        """
        Fix 4: Eagerly load VAD model at startup, not lazily on first request.
        WHY: torch.hub.load can trigger a network download and block the event loop
        for seconds. In production, that hits the first user with unexpected latency.
        Call this during server startup (engine.load_model) so the cost is paid once.

        Does nothing if a model is already loaded. On failure the error is logged
        and ``_model`` stays None, which makes ``is_speech`` pass everything through.
        """
        if self._model is not None:
            return
        logger.info("Preloading Silero VAD model...")
        try:
            self._model, _ = torch.hub.load(
                repo_or_dir="snakers4/silero-vad",
                model="silero_vad", force_reload=False, onnx=False,
            )
            self._model.eval()
            logger.info("Silero VAD model loaded successfully")
        except Exception as e:
            # WHY: If VAD fails to load, log and disable rather than crashing server.
            # is_speech() will return True (pass-through) when _model is None.
            logger.error("Failed to load Silero VAD: %s. VAD will be disabled.", e,
                         exc_info=True)
            self._model = None

    def _window_size(self) -> int:
        """Return the Silero analysis window length, in samples, for ``sample_rate``."""
        # Silero VAD scores fixed 32 ms windows: 512 samples at 16 kHz and 256
        # samples at 8 kHz. Any other rate keeps the previous 512 default.
        return 256 if self.sample_rate == 8000 else 512

    def is_speech(self, audio_bytes: bytes, stream_id=0) -> bool:
        """
        Check if audio chunk contains speech.

        Only the trailing analysis window of the chunk (512 samples at 16 kHz,
        256 at 8 kHz) is scored; earlier samples of a longer chunk are ignored.

        Args:
            audio_bytes: Raw signed 16-bit PCM in native byte order (assumed mono).
            stream_id: Key for this stream's hysteresis counter.

        Returns:
            False only once ``silence_threshold_frames`` consecutive chunks have
            scored below ``threshold``; True otherwise. Also True (safe
            pass-through) when the model is not loaded, the chunk is shorter
            than one window, or anything in the check raises.
        """
        if self._model is None:
            return True
        window_size = self._window_size()
        try:
            samples = np.frombuffer(audio_bytes, dtype=np.int16)
            if len(samples) < window_size:
                return True
            # Scale int16 PCM into float32 in [-1, 1) before handing it to the model.
            chunk = torch.from_numpy(samples[-window_size:].astype(np.float32) / 32768.0)
            # Pure inference: skip autograd graph construction on every chunk.
            with torch.no_grad():
                prob = self._model(chunk, self.sample_rate).item()

            if prob >= self.threshold:
                self._silence_counts[stream_id] = 0
                return True
            # Hysteresis: keep passing audio until enough consecutive silent
            # chunks accumulate, so trailing word boundaries are not clipped.
            count = self._silence_counts.get(stream_id, 0) + 1
            self._silence_counts[stream_id] = count
            return count < self.silence_threshold_frames
        except Exception:
            # Best-effort gate: never drop audio because VAD broke. Debug level
            # because this runs once per chunk and could otherwise flood the logs.
            logger.debug("Silero VAD check failed for stream %s; passing audio through",
                         stream_id, exc_info=True)
            return True

    def reset_stream(self, stream_id) -> None:
        """Clean up state for a closed stream.

        Drops the hysteresis counter for ``stream_id`` (no-op if unknown). The
        shared model's internal state is not touched.
        """
        self._silence_counts.pop(stream_id, None)
