"""
500 concurrent requests benchmark with multi-replica to max out A100 80GB.
Each replica runs its own batch queue on a separate CUDA stream.
"""
import torch
import torch.nn as nn
import time
import numpy as np
import threading
import queue
from dataclasses import dataclass, field
from concurrent.futures import Future
from typing import List

# Patch the SwooshL / SwooshR activations with plain tensor-op replacements,
# applied before the zipvoice imports below so the patched classes are the
# ones picked up.
import zipvoice.models.modules.scaling as scaling
class FSL(nn.Module):
    def forward(self, x):
        z = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return torch.logaddexp(z, x - 4.0) - 0.08 * x - 0.035
class FSR(nn.Module):
    def forward(self, x):
        z = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return torch.logaddexp(z, x - 1.0) - 0.08 * x - 0.313261687
scaling.SwooshL = FSL
scaling.SwooshR = FSR
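# Sanity-check sketch (not executed here): 0.313261687 == log(1 + e**-1), so the
# patched SwooshR is ~0 at the origin.
#
#   assert torch.allclose(FSR()(torch.zeros(1)), torch.zeros(1), atol=1e-6)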

from zipvoice.luxvoice import LuxTTS
from zipvoice.modeling_utils import load_models_gpu, process_audio
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import EmiliaTokenizer
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.feature import VocosFbank
from linacodec.vocoder.vocos import Vocos
from torch.nn.utils import parametrize
from huggingface_hub import snapshot_download
import json


@dataclass
class Request:
    text: str
    submit_time: float = field(default_factory=time.perf_counter)
    future: Future = field(default_factory=Future)


@torch.inference_mode()
def batched_generate(texts, model, vocoder, tokenizer, prompt_tokens,
                     prompt_features, prompt_features_lens, prompt_rms,
                     num_steps=4, guidance_scale=3.0, speed=1.3, t_shift=0.5):
    bs = len(texts)
    all_tokens = tokenizer.texts_to_token_ids(texts)
    # Reuse the single encoded prompt for every utterance in the batch.
    bp_tokens = prompt_tokens * bs
    bp_features = prompt_features.expand(bs, -1, -1)
    bp_lens = prompt_features_lens.expand(bs)

    pred_feat, pred_lens, _, _ = model.sample(
        tokens=all_tokens, prompt_tokens=bp_tokens,
        prompt_features=bp_features, prompt_features_lens=bp_lens,
        speed=speed, t_shift=t_shift, duration='predict',
        num_step=num_steps, guidance_scale=guidance_scale,
    )
    # Vocode the whole batch at once and clamp to the valid waveform range.
    wav_batch = vocoder.decode(pred_feat.permute(0, 2, 1) / 0.1).clamp(-1, 1)
    # If the reference prompt was quiet, scale the output back down to match it.
    if prompt_rms < 0.1:
        wav_batch = wav_batch * (prompt_rms / 0.1)

    results = []
    for i in range(bs):
        # Trim each utterance to its own predicted length before returning it.
        n_samples = min(int(pred_lens[i].item()) * 512, wav_batch.shape[1])
        results.append(wav_batch[i, :n_samples].cpu().numpy())
    return results
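
# Shape sketch (assumptions inferred from the indexing above): pred_feat is
# (batch, frames, feat_dim) and is permuted to (batch, feat_dim, frames) for the
# vocoder; wav_batch is (batch, samples); pred_lens counts feature frames, at
# 512 waveform samples per frame.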


class Replica:
    """Single model replica with its own CUDA stream and batch queue."""

    def __init__(self, replica_id, model, vocoder, tokenizer, prompt_data,
                 max_batch=32, max_wait_ms=20, num_steps=4):
        self.id = replica_id
        self.model = model
        self.vocoder = vocoder
        self.tokenizer = tokenizer
        self.prompt = prompt_data
        self.max_batch = max_batch
        self.max_wait = max_wait_ms / 1000
        self.num_steps = num_steps
        self.stream = torch.cuda.Stream()
        self.queue = queue.Queue()
        self._running = False

    def start(self):
        self._running = True
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def stop(self):
        self._running = False
        if hasattr(self, "_thread"):
            self._thread.join(timeout=1.0)

    def submit(self, req: Request):
        self.queue.put(req)

    def _loop(self):
        while self._running:
            batch = []
            try:
                # Block briefly for the first request of the next batch.
                batch.append(self.queue.get(timeout=0.05))
            except queue.Empty:
                continue
            # Keep collecting until the batch is full or the wait budget is spent.
            deadline = time.perf_counter() + self.max_wait
            while len(batch) < self.max_batch:
                remaining = deadline - time.perf_counter()
                if remaining <= 0:
                    break
                try:
                    batch.append(self.queue.get(timeout=remaining))
                except queue.Empty:
                    break
            self._process(batch)

    def _process(self, batch):
        texts = [r.text for r in batch]
        try:
            with torch.cuda.stream(self.stream):
                results = batched_generate(
                    texts, self.model, self.vocoder, self.tokenizer,
                    self.prompt['prompt_tokens'], self.prompt['prompt_features'],
                    self.prompt['prompt_features_lens'], self.prompt['prompt_rms'],
                    num_steps=self.num_steps,
                )
            # Make sure all work queued on this replica's stream has finished
            # before handing results back to the callers.
            self.stream.synchronize()
            done_time = time.perf_counter()
            for req, result in zip(batch, results):
                req.future.set_result((result, done_time))
        except Exception as e:
            # Fail every request in the batch so no caller is left waiting.
            for req in batch:
                req.future.set_exception(e)


class MultiReplicaEngine:
    def __init__(self, num_replicas, max_batch=32, max_wait_ms=20, num_steps=4):
        self.num_replicas = num_replicas
        self.max_batch = max_batch
        self.num_steps = num_steps

        print(f"Loading {num_replicas} model replicas...")
        t0 = time.time()

        # Download the checkpoint once; the config and tokenizer are shared by all replicas.
        model_path = snapshot_download("YatharthS/LuxTTS")
        token_file = f"{model_path}/tokens.txt"
        model_ckpt = f"{model_path}/model.pt"
        model_config_path = f"{model_path}/config.json"
        with open(model_config_path, "r") as f:
            model_config = json.load(f)
        tokenizer = EmiliaTokenizer(token_file=token_file)
        tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}

        self.replicas = []
        for i in range(num_replicas):
            # Each replica gets its own model copy
            model = ZipVoiceDistill(**model_config["model"], **tokenizer_config)
            load_checkpoint(filename=model_ckpt, model=model, strict=True)
            model = model.to('cuda').eval()

            vocos = Vocos.from_hparams(f'{model_path}/vocoder/config.yaml').to('cuda')
            parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[0], "weight")
            parametrize.remove_parametrizations(vocos.upsampler.upsample_layers[1], "weight")
            vocos.load_state_dict(torch.load(f'{model_path}/vocoder/vocos.bin', map_location='cuda'))
            vocos.freq_range = 12000

            # The shared prompt is attached later by encode_prompt().
            replica = Replica(i, model, vocos, tokenizer, None,
                              max_batch=max_batch, max_wait_ms=max_wait_ms, num_steps=num_steps)
            self.replicas.append(replica)

        load_time = time.time() - t0
        vram = torch.cuda.memory_allocated() / 1024**3
        print(f"  {num_replicas} replicas loaded in {load_time:.1f}s, VRAM: {vram:.1f} GB")
        self._next = 0

    def encode_prompt(self, audio_path, duration=5):
        # Transcribe and encode the reference prompt once with Whisper and the
        # shared process_audio helper, then reuse it for every replica.
        from transformers import pipeline
        transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base", device='cuda')
        feature_extractor = VocosFbank()
        tokenizer = self.replicas[0].tokenizer

        prompt_tokens, prompt_features_lens, prompt_features, prompt_rms = process_audio(
            audio_path, transcriber, tokenizer, feature_extractor, 'cuda', duration=duration)

        prompt_data = {
            'prompt_tokens': prompt_tokens,
            'prompt_features': prompt_features,
            'prompt_features_lens': prompt_features_lens,
            'prompt_rms': prompt_rms,
        }
        for r in self.replicas:
            r.prompt = prompt_data
        # Cleanup transcriber
        del transcriber
        torch.cuda.empty_cache()
        print("Prompt encoded for all replicas.")

    def start(self):
        for r in self.replicas:
            r.start()
        print(f"All {self.num_replicas} replicas started.")

    def stop(self):
        for r in self.replicas:
            r.stop()

    def submit(self, text: str) -> Request:
        req = Request(text=text)
        # Round-robin across replicas
        self.replicas[self._next % self.num_replicas].submit(req)
        self._next += 1
        return req
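
# Minimal programmatic usage sketch (values are illustrative and the reference
# audio path is an assumption); run_benchmark() below exercises the same flow
# end to end:
#
#   engine = MultiReplicaEngine(num_replicas=2, max_batch=16)
#   engine.encode_prompt("ref_audio.wav")
#   engine.start()
#   wav, done_at = engine.submit("Hello.").future.result(timeout=60)
#   engine.stop()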


def run_benchmark():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--replicas", type=int, default=8)
    parser.add_argument("--max-batch", type=int, default=32)
    parser.add_argument("--max-wait-ms", type=float, default=20)
    parser.add_argument("--num-steps", type=int, default=4)
    parser.add_argument("--concurrency", type=int, default=500)
    parser.add_argument("--ref-audio", default="/home/ubuntu/LuxTTS/ref_audio.wav")
    args = parser.parse_args()

    engine = MultiReplicaEngine(
        num_replicas=args.replicas,
        max_batch=args.max_batch,
        max_wait_ms=args.max_wait_ms,
        num_steps=args.num_steps,
    )
    engine.encode_prompt(args.ref_audio)
    engine.start()

    TEXTS = [
        "Hello, this is a test of the batched inference system.",
        "The quick brown fox jumps over the lazy dog near the riverbank.",
        "Welcome to the annual technology conference on artificial intelligence.",
        "Machine learning has transformed how we approach complex problems today.",
        "Natural language processing enables computers to understand human speech.",
        "Deep learning models continue to improve in both speed and accuracy.",
        "The future of computing lies in efficient parallel processing architectures.",
        "Voice synthesis technology has made remarkable progress in recent years.",
        "Today we demonstrate the power of batched inference for text to speech.",
        "High concurrency serving requires careful optimization of GPU resources.",
    ]

    # Warmup: round-robin sends exactly one request to each replica.
    print("\nWarmup...")
    for i in range(args.replicas):
        req = engine.submit("Warmup.")
        req.future.result(timeout=30)
    time.sleep(0.5)
    torch.cuda.reset_peak_memory_stats()

    conc = args.concurrency
    print(f"\n{'='*70}")
    print(f"BENCHMARK: {conc} concurrent requests")
    print(f"Config: {args.replicas} replicas, max_batch={args.max_batch}, steps={args.num_steps}")
    print(f"{'='*70}")

    requests = []
    t_start = time.perf_counter()

    # Submit all at once
    for i in range(conc):
        text = TEXTS[i % len(TEXTS)]
        req = engine.submit(text)
        requests.append(req)

    # Collect results. Outputs arrive as complete utterances, so "TTFB" below is
    # the full request latency rather than a streaming first-chunk time.
    ttfbs = []
    audio_durs = []
    for req in requests:
        result, done_time = req.future.result(timeout=300)
        ttfbs.append(done_time - req.submit_time)
        audio_durs.append(len(result) / 48000)  # assumes 48 kHz output audio

    t_end = time.perf_counter()
    total_wall = t_end - t_start
    peak_vram = torch.cuda.max_memory_allocated() / 1024**3

    # Stats
    ttfbs_ms = np.array(ttfbs) * 1000
    total_audio = sum(audio_durs)
    avg_audio = np.mean(audio_durs)

    print(f"\n--- Results ---")
    print(f"  Total wall time:    {total_wall:.3f}s")
    print(f"  Throughput:         {conc/total_wall:.1f} req/s")
    print(f"  Total audio gen:    {total_audio:.1f}s ({total_audio/60:.1f} min)")
    print(f"  Avg audio/req:      {avg_audio:.2f}s")
    print(f"  Effective RTF:      {total_wall/total_audio:.5f} ({total_audio/total_wall:.0f}x realtime)")
    print(f"  Peak VRAM:          {peak_vram:.1f} GB / 80 GB ({peak_vram/80*100:.0f}%)")
    print(f"")
    print(f"  --- TTFB ---")
    print(f"  Min:    {np.min(ttfbs_ms):8.0f}ms")
    print(f"  p25:    {np.percentile(ttfbs_ms, 25):8.0f}ms")
    print(f"  p50:    {np.percentile(ttfbs_ms, 50):8.0f}ms")
    print(f"  p75:    {np.percentile(ttfbs_ms, 75):8.0f}ms")
    print(f"  p90:    {np.percentile(ttfbs_ms, 90):8.0f}ms")
    print(f"  p95:    {np.percentile(ttfbs_ms, 95):8.0f}ms")
    print(f"  p99:    {np.percentile(ttfbs_ms, 99):8.0f}ms")
    print(f"  Max:    {np.max(ttfbs_ms):8.0f}ms")
    print(f"  Avg:    {np.mean(ttfbs_ms):8.0f}ms")

    engine.stop()
    print("\nDone.")


if __name__ == "__main__":
    run_benchmark()
