"""
Local concurrent per-chunk latency test.
Simulates N parallel streams processing through the GPU simultaneously.

WHY: Tests whether per-chunk latency holds under concurrent load on A100,
without any network/WebSocket/Modal overhead. Pure GPU contention test.

Uses ThreadPoolExecutor to simulate concurrent streams hitting the model.
"""

import os
import sys
import time
import json
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import List, Dict
from math import gcd

import numpy as np
import torch
import soundfile as sf
from scipy.signal import resample_poly
from datasets import load_dataset
from jiwer import wer, cer
from dotenv import load_dotenv

# Populate os.environ from a local .env file, if one exists. No code in this
# script reads os.environ explicitly; presumably the .env supplies credentials
# (e.g. a Hugging Face token) to downstream libraries — TODO confirm.
load_dotenv()

# Absolute, machine-specific path to the .nemo checkpoint (a pinned snapshot in
# the local Hugging Face hub cache under /home/ubuntu).
MODEL_PATH = "/home/ubuntu/.cache/huggingface/hub/models--BayAreaBoys--nemotron-hindi/snapshots/f9625adc4d8151e6ed5ae59e8f23fa96adc345c1/final_model.nemo"
# Sample rate (Hz) that all loaded audio is resampled to before chunking.
SAMPLE_RATE = 16000
# Length of each streamed chunk in milliseconds.
CHUNK_MS = 80
# Samples per chunk: 16000 * 80 / 1000 = 1280.
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000)
# Torch device used for every tensor and for the model itself.
DEVICE = "cuda"
# Output directory for JSON results, located next to this script.
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark_results")


def normalize_hindi(text: str) -> str:
    """Normalize a transcript before WER/CER scoring.

    Each character in the punctuation class (Devanagari danda, ASCII
    punctuation, en/em dashes, quotes, brackets) is replaced by a space.
    Whitespace runs are then collapsed to one space, both ends are trimmed,
    and the result is lower-cased. Falsy input ("" or None) yields "".
    """
    import re

    if not text:
        return ""

    punctuation_pattern = re.compile(r'[।,\.\?\!;:\-\—\–\"\'\(\)\[\]\{\}<>]')
    whitespace_pattern = re.compile(r'\s+')

    without_punct = punctuation_pattern.sub(' ', text)
    single_spaced = whitespace_pattern.sub(' ', without_punct)
    trimmed = single_spaced.strip()
    return trimmed.lower()


def load_model():
    """Restore the NeMo RNNT-BPE checkpoint at MODEL_PATH, ready for inference.

    In the greedy decoding config, ``use_cuda_graph_decoder`` and
    ``loop_labels`` are both switched off, and the decoding strategy is
    re-applied. The model is then put in eval mode, moved to DEVICE and
    returned. (Why those two decoder options are disabled is not stated
    here; presumably to keep the simpler greedy path — TODO confirm.)
    """
    import nemo.collections.asr as nemo_asr
    from omegaconf import open_dict

    print("Loading model...")
    asr_model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(MODEL_PATH)

    # open_dict temporarily lifts OmegaConf struct mode so these keys are writable.
    with open_dict(asr_model.cfg):
        asr_model.cfg.decoding.greedy.use_cuda_graph_decoder = False
        asr_model.cfg.decoding.greedy.loop_labels = False
    asr_model.change_decoding_strategy(asr_model.cfg.decoding)

    asr_model.eval()
    asr_model = asr_model.to(DEVICE)
    print(f"Model on {DEVICE}")
    return asr_model


def load_fleurs(n: int) -> List[Dict]:
    """Load up to ``n`` utterances from the FLEURS Hindi (hi_in) test split.

    Every returned dict has:
        id:        the row's "id" field, or the row index if absent
        audio:     float32 waveform, resampled to SAMPLE_RATE when needed
        reference: the row's "transcription" text
        duration:  waveform length in seconds at SAMPLE_RATE
    """
    print(f"Loading FLEURS Hindi ({n} samples)...")
    dataset = load_dataset("google/fleurs", "hi_in", split="test")
    count = min(n, len(dataset))

    samples = []
    for idx in range(count):
        row = dataset[idx]
        audio_info = row["audio"]
        waveform = np.array(audio_info["array"], dtype=np.float32)
        native_sr = audio_info["sampling_rate"]

        if native_sr != SAMPLE_RATE:
            # Reduce the up/down ratio to lowest terms for resample_poly.
            divisor = gcd(native_sr, SAMPLE_RATE)
            up, down = SAMPLE_RATE // divisor, native_sr // divisor
            waveform = resample_poly(waveform, up, down).astype(np.float32)

        samples.append(dict(
            id=row.get("id", idx),
            audio=waveform,
            reference=row["transcription"],
            duration=len(waveform) / SAMPLE_RATE,
        ))

    total_seconds = sum(s["duration"] for s in samples)
    print(f"  Loaded {len(samples)} samples ({total_seconds:.0f}s)")
    return samples


@dataclass
class StreamResult:
    """Outcome of pushing one audio stream through the model chunk by chunk.

    The field order below is also the positional constructor order.
    """

    # Whatever the source sample's "id" held.
    sample_id: int = 0
    # Ground-truth transcript of the sample.
    reference: str = ""
    # Non-empty per-chunk decoded texts joined with single spaces.
    hypothesis: str = ""
    # Audio length in seconds.
    duration: float = 0.0
    # Per-chunk inference latency in milliseconds, one entry per chunk.
    chunk_latencies: List[float] = field(default_factory=list)
    # Wall-clock seconds spent processing this stream.
    wall_time: float = 0.0


def _infer_chunk(model, chunk: np.ndarray):
    """Run preprocessor -> encoder -> RNNT decoding on one chunk; return (latency_ms, text).

    ``text`` is the first hypothesis's decoded text, or "" when nothing was
    decoded. This helper relies on the caller's ``torch.inference_mode()``
    context.
    """
    t0 = time.perf_counter()
    audio_tensor = torch.tensor(chunk, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    audio_length = torch.tensor([audio_tensor.shape[1]], device=DEVICE)

    processed, processed_len = model.preprocessor(input_signal=audio_tensor, length=audio_length)
    encoded, encoded_len = model.encoder(audio_signal=processed, length=processed_len)
    hypotheses = model.decoding.rnnt_decoder_predictions_tensor(
        encoder_output=encoded, encoded_lengths=encoded_len)
    # NOTE(review): this is host wall-clock time. It covers the GPU work only
    # if decoding synchronizes with the device (e.g. to build text on the
    # host) — TODO confirm for the installed NeMo version.
    latency_ms = (time.perf_counter() - t0) * 1000

    # Older NeMo releases return a (best_hypotheses, all_hypotheses) tuple
    # instead of a list of hypotheses. Unwrap it so the .text check below does
    # not silently discard every transcript. (Version detail — TODO confirm.)
    if isinstance(hypotheses, tuple):
        hypotheses = hypotheses[0]

    text = ""
    if hypotheses and hasattr(hypotheses[0], 'text') and hypotheses[0].text:
        text = hypotheses[0].text
    return latency_ms, text


@torch.inference_mode()
def process_one_stream(model, sample: Dict, stream_id: int) -> StreamResult:
    """Process one audio stream chunk-by-chunk. Runs in a thread.

    The waveform is cut into CHUNK_SAMPLES-long chunks. A trailing partial
    chunk is zero-padded to full length. Each chunk is transcribed on its own:
    no encoder or decoder state is carried from one chunk to the next. Each
    chunk's latency is recorded.

    Args:
        model: NeMo RNNT model exposing ``preprocessor``, ``encoder`` and
            ``decoding`` (as returned by load_model).
        sample: dict with "id", "audio" (waveform at SAMPLE_RATE), "reference"
            and "duration" keys, as produced by load_fleurs.
        stream_id: index of this stream. Currently unused; kept for callers.

    Returns:
        StreamResult holding:
          - the non-empty chunk texts joined with spaces,
          - one latency in ms per chunk,
          - the stream's wall time in seconds.
    """
    audio = sample["audio"]
    result = StreamResult(
        sample_id=sample["id"],
        reference=sample["reference"],
        duration=sample["duration"],
    )

    t_wall = time.perf_counter()
    all_text = []
    latencies = []

    for start in range(0, len(audio), CHUNK_SAMPLES):
        chunk = audio[start:start + CHUNK_SAMPLES]
        if len(chunk) < CHUNK_SAMPLES:
            # Trailing partial chunk: zero-pad so every call sees a full-length chunk.
            chunk = np.pad(chunk, (0, CHUNK_SAMPLES - len(chunk)))
        latency_ms, text = _infer_chunk(model, chunk)
        latencies.append(latency_ms)
        if text:
            all_text.append(text)

    result.hypothesis = " ".join(all_text)
    result.chunk_latencies = latencies
    result.wall_time = time.perf_counter() - t_wall
    return result


def run_concurrent(model, samples: List[Dict], concurrency: int) -> List[StreamResult]:
    """Run N streams concurrently via ThreadPoolExecutor.

    The first ``concurrency`` samples are each handed to process_one_stream on
    their own worker thread, and all threads share ``model``.
    NOTE(review): this assumes the model's forward and decoding paths can be
    called from several threads at once — TODO confirm for NeMo.

    Args:
        model: model passed through to process_one_stream.
        samples: sample dicts, as produced by load_fleurs.
        concurrency: number of simultaneous streams; must be >= 1. If fewer
            samples are available, only ``len(samples)`` streams run.

    Returns:
        One StreamResult per stream, in the same order as ``samples``.

    Raises:
        ValueError: if ``concurrency`` < 1.
    """
    if concurrency < 1:
        # Without this check, a negative value would silently slice
        # samples[:-k] before the executor rejected it.
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    to_run = samples[:concurrency]
    print(f"\n  Running {len(to_run)} concurrent streams...")

    t0 = time.perf_counter()
    # Size the pool to the real number of streams; max(1, ...) keeps the
    # executor valid when no samples are available.
    with ThreadPoolExecutor(max_workers=max(1, len(to_run))) as pool:
        futures = [
            pool.submit(process_one_stream, model, s, i)
            for i, s in enumerate(to_run)
        ]
        # Collect in submission order so results line up with ``samples``.
        # as_completed order varies from run to run.
        results = [f.result() for f in futures]

    wall = time.perf_counter() - t0
    total_audio = sum(r.duration for r in results)
    rtx = total_audio / wall if wall > 0 else 0.0
    print(f"  Done in {wall:.1f}s | {total_audio:.0f}s audio | {rtx:.1f}x RT")
    return results


def compute_metrics(results: List[StreamResult]) -> Dict:
    """Aggregate latency, accuracy and throughput metrics over finished streams.

    Latency statistics pool every chunk of every stream together.

    WER/CER are computed by jiwer over all streams whose normalized reference
    is non-empty, and reported as percentages. If no stream qualifies, both
    default to 1.0 (reported as 100%).

    Throughput is total audio seconds divided by the slowest stream's wall
    time.

    Empty input does not raise:
      - latency stats are NaN when no chunks were processed;
      - wall time and throughput are 0 when there are no results.

    Args:
        results: StreamResult objects, typically from run_concurrent.

    Returns:
        Dict of rounded metrics, keyed as written to the results JSON.
    """
    lats = np.array([lat for r in results for lat in r.chunk_latencies], dtype=float)

    def latency_stat(fn, *args):
        # On an empty array, np.median/np.mean only warn (returning NaN), but
        # np.percentile/np.max raise. Report NaN uniformly instead.
        if lats.size == 0:
            return float("nan")
        return round(float(fn(lats, *args)), 1)

    refs = [normalize_hindi(r.reference) for r in results]
    hyps = [normalize_hindi(r.hypothesis) for r in results]
    pairs = [(ref, hyp) for ref, hyp in zip(refs, hyps) if ref.strip()]
    if pairs:
        rf = [ref for ref, _ in pairs]
        hf = [hyp for _, hyp in pairs]
        w = wer(rf, hf)
        c = cer(rf, hf)
    else:
        w, c = 1.0, 1.0

    total_audio = sum(r.duration for r in results)
    # Streams run concurrently, so the slowest one bounds the run's wall time.
    total_wall = max((r.wall_time for r in results), default=0.0)

    return {
        "streams": len(results),
        "total_chunks": len(lats),
        "total_audio_sec": round(total_audio, 1),
        "wall_sec": round(total_wall, 1),
        "throughput_rtx": round(total_audio / total_wall, 1) if total_wall > 0 else 0,
        "chunk_median_ms": latency_stat(np.median),
        "chunk_mean_ms": latency_stat(np.mean),
        "chunk_p95_ms": latency_stat(np.percentile, 95),
        "chunk_p99_ms": latency_stat(np.percentile, 99),
        "chunk_max_ms": latency_stat(np.max),
        "wer": round(w * 100, 2),
        "cer": round(c * 100, 2),
    }


def _warmup(model, n_iters: int = 5) -> None:
    """Run ``n_iters`` single-chunk streams of low-amplitude noise through the model."""
    print("Warming up...")
    for _ in range(n_iters):
        dummy = np.random.randn(CHUNK_SAMPLES).astype(np.float32) * 0.01
        # Duration is derived from CHUNK_MS so it stays in sync with CHUNK_SAMPLES.
        process_one_stream(
            model,
            {"id": 0, "audio": dummy, "reference": "", "duration": CHUNK_MS / 1000},
            0,
        )
    print("  Warm.\n")


def _print_sweep_table(all_metrics: List[Dict]) -> None:
    """Print a fixed-width summary table with one row per concurrency level."""
    print(f"\n{'='*80}")
    print("  CONCURRENT SWEEP — Local A100 Direct Inference")
    print(f"{'='*80}")
    # A latency cell renders as a 6-wide number plus "ms" (8 chars), so those
    # headers and rules are 8 wide to keep the columns aligned.
    print(f"  {'Conc':>6} {'Strms':>6} {'Chunks':>7} {'Med':>8} {'P95':>8} "
          f"{'P99':>8} {'Max':>8} {'WER%':>7} {'Thru':>7}")
    print(f"  {'─'*6} {'─'*6} {'─'*7} {'─'*8} {'─'*8} {'─'*8} {'─'*8} {'─'*7} {'─'*7}")
    for m in all_metrics:
        print(f"  {m['concurrency']:>6} {m['streams']:>6} {m['total_chunks']:>7} "
              f"{m['chunk_median_ms']:>6.1f}ms {m['chunk_p95_ms']:>6.1f}ms "
              f"{m['chunk_p99_ms']:>6.1f}ms {m['chunk_max_ms']:>6.1f}ms "
              f"{m['wer']:>6.1f}% {m['throughput_rtx']:>6.1f}x")
    print(f"{'='*80}")


def _save_metrics(all_metrics: List[Dict]) -> None:
    """Write all metrics rows to RESULTS_DIR/local_concurrent_sweep.json, overwriting it."""
    os.makedirs(RESULTS_DIR, exist_ok=True)
    out_path = os.path.join(RESULTS_DIR, "local_concurrent_sweep.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_metrics, f, indent=2)
    print("\n  Saved: benchmark_results/local_concurrent_sweep.json")


def main():
    """CLI entry point: load model and data, run each concurrency level, report and save.

    Flags:
        --max-samples N   number of FLEURS utterances to load (default 100).
        --sweep           run concurrency levels 1, 5, 10, 25 and 50.
        --concurrency C   single level to run when --sweep is absent (default 1).
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument("--sweep", action="store_true")
    parser.add_argument("--concurrency", type=int, default=1)
    args = parser.parse_args()

    print("=" * 70)
    print("  LOCAL Concurrent Per-Chunk Latency Test (A100)")
    print(f"  Chunk: {CHUNK_MS}ms | GPU: {torch.cuda.get_device_name(0)}")
    print("=" * 70)

    model = load_model()
    _warmup(model)

    samples = load_fleurs(args.max_samples)
    levels = [1, 5, 10, 25, 50] if args.sweep else [args.concurrency]

    all_metrics = []
    for c in levels:
        if c > len(samples):
            # run_concurrent caps the stream count at len(samples). Say so, so
            # the row's "Conc" label is not mistaken for the real load.
            print(f"\n  NOTE: concurrency {c} exceeds {len(samples)} loaded samples; "
                  f"only {len(samples)} streams will run.")
        results = run_concurrent(model, samples, c)
        m = compute_metrics(results)
        m["concurrency"] = c
        all_metrics.append(m)

        print(f"  {'─'*60}")
        print(f"  Concurrency={c}: median={m['chunk_median_ms']}ms  p95={m['chunk_p95_ms']}ms  "
              f"p99={m['chunk_p99_ms']}ms  WER={m['wer']}%  {m['throughput_rtx']}x RT")
        print(f"  {'─'*60}")

    if len(all_metrics) > 1:
        _print_sweep_table(all_metrics)

    _save_metrics(all_metrics)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
