"""
Local per-chunk latency test using the repo's direct inference approach.
Tests raw GPU inference time per 80ms chunk — NO network, NO batching overhead.
This is the REAL latency baseline the Modal deployment should approach.

Uses DirectStreamingASR pattern from nemotron/webapp/app.py
"""

# Standard library
import os
import sys
import time
import json
# Third-party: numeric / signal / audio-IO stack
import numpy as np
import torch
import soundfile as sf
from scipy.signal import resample_poly
from math import gcd
# Dataset + metric + env helpers
from datasets import load_dataset
from jiwer import wer, cer
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment
# (presumably HF auth for dataset/model downloads — confirm against .env).
load_dotenv()

# ─── Config ──────────────────────────────────────────────────────────────────
# Checkpoint path; overridable via env var so the script runs on machines where
# the snapshot lives elsewhere (unset → original hard-coded default).
MODEL_PATH = os.environ.get(
    "NEMOTRON_MODEL_PATH",
    "/home/ubuntu/.cache/huggingface/hub/models--BayAreaBoys--nemotron-hindi/snapshots/f9625adc4d8151e6ed5ae59e8f23fa96adc345c1/final_model.nemo",
)
SAMPLE_RATE = 16000    # model input rate (Hz)
CHUNK_MS = 80          # 80ms = minimum for 8x subsampling (1 encoder frame)
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_MS // 1000  # 1280 samples
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Results are written next to this script, under benchmark_results/
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark_results")


def normalize_hindi(text: str) -> str:
    """Normalize text for WER/CER: strip punctuation (incl. Danda '।'),
    collapse runs of whitespace, and lowercase. Falsy input yields ''."""
    if not text:
        return ""
    punct = "।,.?!;:-—–\"'()[]{}<>"
    # Map every punctuation character to a space in one C-level pass,
    # then let split()/join collapse all whitespace runs.
    cleaned = text.translate(str.maketrans(punct, " " * len(punct)))
    return " ".join(cleaned.split()).lower()


# ─── Direct Streaming ASR (from repo's webapp/app.py) ────────────────────────
class DirectStreamingASR:
    """
    Direct chunk-by-chunk inference — measures raw GPU time per chunk.
    Each 80ms chunk: preprocessor → encoder → RNNT decoder.
    No batching, no WebSocket, no network — pure GPU inference latency.
    """

    def __init__(self, model_path: str):
        """Restore the RNNT model from a .nemo checkpoint and move it to DEVICE.

        Args:
            model_path: Filesystem path to the .nemo checkpoint.
        """
        import nemo.collections.asr as nemo_asr
        from omegaconf import open_dict

        print(f"Loading model: {model_path}")
        self.model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(model_path)

        # WHY: Disable CUDA graph decoder — cu_call unpacking bug on this CUDA version
        with open_dict(self.model.cfg):
            if hasattr(self.model.cfg, 'decoding'):
                self.model.cfg.decoding.greedy.loop_labels = False
                self.model.cfg.decoding.greedy.use_cuda_graph_decoder = False
        self.model.change_decoding_strategy(self.model.cfg.decoding)

        self.model.eval()
        self.model = self.model.to(DEVICE)
        print(f"Model loaded on {DEVICE}")

    def _infer_chunk(self, chunk: np.ndarray) -> tuple:
        """Run one fixed-size chunk through preprocessor → encoder → decoder.

        Shared by the main loop and the padded-tail path of
        transcribe_chunked (previously duplicated inline).

        Args:
            chunk: 1-D float32 waveform of exactly CHUNK_SAMPLES samples.

        Returns:
            (text, latency_ms): decoded text ("" when nothing was emitted) and
            wall-clock inference time in ms, including the host-to-device copy.
        """
        t0 = time.perf_counter()

        audio_tensor = torch.tensor(chunk, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        audio_length = torch.tensor([audio_tensor.shape[1]], device=DEVICE)

        # Preprocess (mel spectrogram)
        processed, processed_len = self.model.preprocessor(
            input_signal=audio_tensor, length=audio_length,
        )
        # Encode
        encoded, encoded_len = self.model.encoder(
            audio_signal=processed, length=processed_len,
        )
        # Decode (greedy RNNT)
        hypotheses = self.model.decoding.rnnt_decoder_predictions_tensor(
            encoder_output=encoded, encoded_lengths=encoded_len,
        )

        latency_ms = (time.perf_counter() - t0) * 1000

        text = ""
        if hypotheses and len(hypotheses) > 0:
            h = hypotheses[0]
            text = h.text if hasattr(h, 'text') else str(h)
        return text, latency_ms

    @torch.inference_mode()
    def transcribe_chunked(self, audio: np.ndarray) -> dict:
        """
        Process audio chunk-by-chunk, measure latency per chunk.

        Args:
            audio: 1-D float32 waveform at SAMPLE_RATE.

        Returns:
            dict with "transcript" (chunk texts joined by spaces),
            "latencies" (per-chunk ms) and "num_chunks".
        """
        num_chunks = len(audio) // CHUNK_SAMPLES
        latencies = []
        all_text = []

        for i in range(num_chunks):
            chunk = audio[i * CHUNK_SAMPLES:(i + 1) * CHUNK_SAMPLES]
            text, latency_ms = self._infer_chunk(chunk)
            latencies.append(latency_ms)
            if text:
                all_text.append(text)

        # Handle remaining samples: zero-pad the tail to a full chunk so the
        # encoder always sees CHUNK_SAMPLES worth of audio.
        remaining = len(audio) % CHUNK_SAMPLES
        if remaining > 0:
            chunk = np.pad(audio[-remaining:], (0, CHUNK_SAMPLES - remaining))
            text, latency_ms = self._infer_chunk(chunk)
            latencies.append(latency_ms)
            if text:
                all_text.append(text)

        return {
            "transcript": " ".join(all_text),
            "latencies": latencies,
            "num_chunks": len(latencies),
        }

    def transcribe_full(self, audio_paths: list) -> list:
        """Full-file transcription for WER comparison.

        Args:
            audio_paths: list of WAV file paths on disk.

        Returns:
            list of transcript strings, one per input file.
        """
        # WHY: Monkey-patch to bypass lhotse DynamicCutSampler compat bug
        import types
        from omegaconf import DictConfig

        def patched_setup(self_model, config):
            # Build the dataloader config directly so NeMo's lhotse sampler
            # path is skipped entirely.
            if 'manifest_filepath' in config:
                manifest_filepath = config['manifest_filepath']
                batch_size = config['batch_size']
            else:
                manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
                batch_size = min(config['batch_size'], len(config['paths2audio_files']))
            dl_config = {
                'manifest_filepath': manifest_filepath,
                'sample_rate': self_model.preprocessor._sample_rate,
                'batch_size': batch_size,
                'shuffle': False,
                # WHY: os.cpu_count() may return None; `or 1` guards the
                # arithmetic against a TypeError on such platforms.
                'num_workers': config.get('num_workers', min(batch_size, (os.cpu_count() or 1) - 1)),
                'pin_memory': True,
                'channel_selector': config.get('channel_selector', None),
                'use_start_end_token': self_model.cfg.validation_ds.get('use_start_end_token', False),
            }
            return self_model._setup_dataloader_from_config(config=DictConfig(dl_config))

        self.model._setup_transcribe_dataloader = types.MethodType(patched_setup, self.model)
        results = self.model.transcribe(audio_paths, batch_size=4, verbose=False)
        texts = []
        for r in results:
            if isinstance(r, str):
                texts.append(r)
            elif hasattr(r, 'text'):
                texts.append(r.text)
            else:
                texts.append(str(r))
        return texts


# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Run the benchmark end-to-end.

    1. Load the model and warm up the GPU.
    2. Stream each FLEURS Hindi test clip through 80ms chunks, timing each.
    3. Re-transcribe the full files in batch mode for a clean WER/CER figure.
    4. Print a summary and persist it as JSON under RESULTS_DIR.
    """
    import tempfile
    import shutil

    print("=" * 70)
    print("  LOCAL Per-Chunk Latency Test (Direct GPU Inference)")
    print("  Model: BayAreaBoys/nemotron-hindi")
    print(f"  Chunk: {CHUNK_MS}ms ({CHUNK_SAMPLES} samples)")
    print(f"  GPU:   {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    print("=" * 70)

    asr = DirectStreamingASR(MODEL_PATH)

    # Warmup — 3 dummy chunks so CUDA kernel compilation/caching doesn't
    # inflate the first measured latencies.
    print("\nWarming up GPU...")
    for _ in range(3):
        dummy = np.random.randn(CHUNK_SAMPLES).astype(np.float32) * 0.01
        asr.transcribe_chunked(dummy)
    print("  GPU warm.")

    # Load FLEURS Hindi test
    print("\nLoading FLEURS Hindi test set...")
    ds = load_dataset("google/fleurs", "hi_in", split="test")
    N = min(50, len(ds))

    all_latencies = []
    references = []
    hypotheses_chunked = []

    tmp_dir = tempfile.mkdtemp()
    audio_paths = []

    # WHY: try/finally guarantees the temp WAVs are removed even if dataset
    # iteration or inference raises partway through (the old unlink/rmdir
    # sequence only ran on full success and could leak the directory).
    try:
        for i in range(N):
            item = ds[i]
            audio = np.array(item["audio"]["array"], dtype=np.float32)
            sr = item["audio"]["sampling_rate"]

            # Resample to 16kHz (polyphase, exact integer ratio)
            if sr != SAMPLE_RATE:
                g = gcd(sr, SAMPLE_RATE)
                audio = resample_poly(audio, SAMPLE_RATE // g, sr // g).astype(np.float32)

            # Per-chunk latency test
            result = asr.transcribe_chunked(audio)
            all_latencies.extend(result["latencies"])
            hypotheses_chunked.append(result["transcript"])
            references.append(item["transcription"])

            # Save resampled audio for the full-file comparison pass
            path = os.path.join(tmp_dir, f"{i}.wav")
            sf.write(path, audio, SAMPLE_RATE)
            audio_paths.append(path)

            if (i + 1) % 10 == 0:
                med = np.median(all_latencies)
                print(f"  [{i+1}/{N}] chunks={len(all_latencies)}, median={med:.1f}ms")

        # Full-file transcription for accurate WER
        print(f"\nRunning full-file transcription for WER ({N} files)...")
        t0 = time.time()
        hypotheses_full = asr.transcribe_full(audio_paths)
        full_time = time.time() - t0
    finally:
        # Cleanup temp files (whole directory in one call)
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # ─── Results ──────────────────────────────────────────────────────────
    lats = np.array(all_latencies)

    # WER on full-file transcriptions; drop pairs whose reference normalizes
    # to empty so jiwer never divides by zero reference words.
    refs_n = [normalize_hindi(r) for r in references]
    hyps_n = [normalize_hindi(h) for h in hypotheses_full]
    pairs = [(r, h) for r, h in zip(refs_n, hyps_n) if r.strip()]
    refs_f, hyps_f = zip(*pairs) if pairs else ([], [])
    corpus_wer = wer(list(refs_f), list(hyps_f)) if refs_f else 0
    corpus_cer = cer(list(refs_f), list(hyps_f)) if refs_f else 0

    # Audio duration computed from the ORIGINAL (pre-resample) sampling rate
    total_audio = sum(len(np.array(ds[i]["audio"]["array"])) / ds[i]["audio"]["sampling_rate"] for i in range(N))

    print(f"\n{'='*70}")
    print("  RESULTS — Local Direct Inference (A100)")
    print(f"{'='*70}")
    print(f"  Samples:        {N}")
    print(f"  Total audio:    {total_audio:.0f}s")
    print(f"  Total chunks:   {len(lats)}")
    print(f"  Chunk size:     {CHUNK_MS}ms")
    print("")
    print("  === PER-CHUNK LATENCY (raw GPU inference) ===")
    print(f"  Median:         {np.median(lats):.1f}ms")
    print(f"  Mean:           {np.mean(lats):.1f}ms")
    print(f"  P95:            {np.percentile(lats, 95):.1f}ms")
    print(f"  P99:            {np.percentile(lats, 99):.1f}ms")
    print(f"  Min:            {np.min(lats):.1f}ms")
    print(f"  Max:            {np.max(lats):.1f}ms")
    print("")
    print("  === ACCURACY (full-file transcription) ===")
    print(f"  WER:            {corpus_wer*100:.2f}%")
    print(f"  CER:            {corpus_cer*100:.2f}%")
    print(f"  Full-file time: {full_time:.1f}s ({total_audio/full_time:.0f}x real-time)")
    print(f"{'='*70}")

    # Save machine-readable summary alongside the script
    os.makedirs(RESULTS_DIR, exist_ok=True)
    summary = {
        "test": "local_direct_inference",
        "model": "BayAreaBoys/nemotron-hindi",
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
        "chunk_ms": CHUNK_MS,
        "samples": N,
        "total_chunks": len(lats),
        "total_audio_sec": round(total_audio, 1),
        "chunk_latency_median_ms": round(float(np.median(lats)), 1),
        "chunk_latency_mean_ms": round(float(np.mean(lats)), 1),
        "chunk_latency_p95_ms": round(float(np.percentile(lats, 95)), 1),
        "chunk_latency_p99_ms": round(float(np.percentile(lats, 99)), 1),
        "chunk_latency_min_ms": round(float(np.min(lats)), 1),
        "chunk_latency_max_ms": round(float(np.max(lats)), 1),
        "wer": round(corpus_wer * 100, 2),
        "cer": round(corpus_cer * 100, 2),
        "throughput_rtx": round(total_audio / full_time, 1),
    }
    with open(os.path.join(RESULTS_DIR, "local_direct_latency.json"), "w") as f:
        json.dump(summary, f, indent=2)
    print("\n  Saved: benchmark_results/local_direct_latency.json")


# Script entry point — run the full benchmark when executed directly.
if __name__ == "__main__":
    main()
