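"""
Quantized model benchmark: bf16 vs int8 vs int4 (NF4).
Measures real-time factor (RTF), time to the first streamed audio chunk (TTFB),
and peak VRAM for streaming generation at 20 DDPM steps, across several batch sizes.
"""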

import gc
import time
import threading
import torch
from transformers import BitsAndBytesConfig
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def run_streaming(model, processor, voice_path, batch_size=1, ddpm_steps=20, cfg_scale=1.3,
                  text=None, sample_rate=24000):
    """Run one streaming generation and return timing / memory metrics.

    The same prompt is replicated ``batch_size`` times with the same voice
    sample. One consumer thread per batch item drains the ``AudioStreamer``.
    Each thread records when its first chunk arrived and how many samples it
    received.

    Args:
        model: loaded VibeVoice inference model (inputs are moved to "cuda").
        processor: matching ``VibeVoiceProcessor``.
        voice_path: reference voice sample used for every batch item.
        batch_size: number of identical requests generated together.
        ddpm_steps: value passed to ``model.set_ddpm_inference_steps``.
        cfg_scale: guidance scale forwarded to ``model.generate``.
        text: script to synthesize; ``None`` selects the built-in Hindi prompt.
        sample_rate: samples per second used to convert streamed sample
            counts to seconds. 24000 keeps the original hard-coded value;
            confirm it matches the model's output rate.

    Returns:
        dict with ``gen_time_s`` (wall time of ``generate``), ``total_audio_s``
        (audio summed over the batch), ``avg_audio_s`` (per item),
        ``throughput`` (audio seconds per wall second), ``peak_mem_gb``
        (peak allocated CUDA memory since the reset just before ``generate``),
        ``avg_ttfb_ms`` (mean time to first chunk, ``-1`` if no chunk arrived)
        and ``rtf`` (gen time / per-item audio; ``inf`` if no audio).
    """
    if text is None:
        text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."
    texts = [text] * batch_size
    voice_batch = [[voice_path]] * batch_size

    inputs = processor(
        text=texts, voice_samples=voice_batch,
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for key, value in inputs.items():
        if torch.is_tensor(value):
            inputs[key] = value.to("cuda")

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    streamer = AudioStreamer(batch_size=batch_size, stop_signal=None)
    first_chunk_times = [None] * batch_size
    total_samples = [0] * batch_size

    def consumer(idx):
        # Each thread writes only its own slot. The main thread reads the slots
        # after join(), so no lock is used.
        for chunk in streamer.get_stream(idx):
            now = time.perf_counter()
            if first_chunk_times[idx] is None:
                first_chunk_times[idx] = now
            total_samples[idx] += chunk.shape[-1]

    threads = [threading.Thread(target=consumer, args=(i,), daemon=True) for i in range(batch_size)]
    for thread in threads:
        thread.start()

    # Drain pending GPU work before resetting the peak counter and starting the clock.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_time = time.perf_counter()

    model.generate(
        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False,
    )

    torch.cuda.synchronize()
    end_time = time.perf_counter()
    for thread in threads:
        thread.join(timeout=5)
    # Consumers presumably exit when generate() ends the stream (confirm in AudioStreamer).
    # A consumer still alive here means its sample count may be incomplete, so say so
    # rather than silently reporting a wrong RTF/throughput.
    unfinished = [i for i, thread in enumerate(threads) if thread.is_alive()]
    if unfinished:
        print(f"WARNING: stream consumers {unfinished} still running after join timeout; "
              f"audio totals may be incomplete")

    gen_time = end_time - start_time
    peak_mem = torch.cuda.max_memory_allocated() / 1e9
    ttfbs = [(t_first - start_time) * 1000 for t_first in first_chunk_times if t_first is not None]
    avg_ttfb = sum(ttfbs) / len(ttfbs) if ttfbs else -1
    total_audio = sum(s / sample_rate for s in total_samples)
    avg_audio = total_audio / batch_size
    rtf = gen_time / avg_audio if avg_audio > 0 else float("inf")

    return {
        "gen_time_s": gen_time,
        "total_audio_s": total_audio,
        "avg_audio_s": avg_audio,
        "throughput": total_audio / gen_time if gen_time > 0 else 0,
        "peak_mem_gb": peak_mem,
        "avg_ttfb_ms": avg_ttfb,
        "rtf": rtf,
    }


def load_model(model_path, quant_mode, attn_impl="flash_attention_2", torch_dtype=torch.bfloat16):
    """Load VibeVoice for inference on CUDA in the requested precision.

    Args:
        model_path: Hugging Face repo id or local path for ``from_pretrained``.
        quant_mode: one of:
            - ``"bf16"``: no quantization.
            - ``"int8"``: bitsandbytes 8-bit.
            - ``"int4"``: bitsandbytes NF4, computing in ``torch_dtype``.
        attn_impl: forwarded as ``attn_implementation``.
        torch_dtype: dtype for every weight that is not quantized. It is passed
            in all modes so the comparison measures quantization only. Without
            it, transformers' bitsandbytes integration falls back to float16 for
            the non-quantized modules in the int8/int4 modes, while the bf16
            baseline runs in bfloat16.

    Returns:
        The loaded model, placed with ``device_map="cuda"``.

    Raises:
        ValueError: if ``quant_mode`` is not one of the modes above. This is
            checked before any weights are loaded.
    """
    if quant_mode == "bf16":
        quant_kwargs = {}
    elif quant_mode == "int8":
        quant_kwargs = {"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}
    elif quant_mode == "int4":
        quant_kwargs = {
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_quant_type="nf4",
            )
        }
    else:
        raise ValueError(f"Unknown quant mode: {quant_mode}")

    return VibeVoiceForConditionalGenerationInference.from_pretrained(
        model_path, torch_dtype=torch_dtype, device_map="cuda",
        attn_implementation=attn_impl, **quant_kwargs)


def _load_with_fallback(model_path, quant_mode):
    """Load the model with FlashAttention-2, falling back to SDPA.

    Returns:
        ``(model, attn_impl)`` on success, or ``(None, None)`` if both loads
        raised (each error is printed).
    """
    try:
        return load_model(model_path, quant_mode, "flash_attention_2"), "flash_attention_2"
    except Exception as e:
        print(f"flash_attention_2 failed: {e}")
    try:
        return load_model(model_path, quant_mode, "sdpa"), "sdpa"
    except Exception as e2:
        print(f"sdpa also failed: {e2}")
    return None, None


def _format_row(bs, r):
    """Format one result row so every field is exactly as wide as its header column."""
    can_stream = "YES" if r["rtf"] < 1.0 else "NO"
    return (
        f"{bs:>6} "
        f"{r['gen_time_s']:>8.2f}s "
        f"{r['avg_audio_s']:>9.2f}s "
        f"{r['rtf']:>6.3f}x "
        f"{r['throughput']:>6.2f} s/s "
        f"{r['avg_ttfb_ms']:>6.0f}ms "
        f"{r['peak_mem_gb']:>5.1f}GB "
        f"{can_stream:>8}"
    )


def _run_batch_sweep(model, processor, voice_path, batch_sizes):
    """Benchmark each batch size in order and print one row per size.

    Stops at the first failure (OOM or any other error), since larger
    batches would fail the same way.
    """
    header = (f"{'Batch':>6} {'GenTime':>9} {'Audio/usr':>10} {'RTF':>7} "
              f"{'Thruput':>10} {'TTFB':>8} {'VRAM':>7} {'Stream?':>8}")
    print(header)
    print("-" * len(header))

    for bs in batch_sizes:
        gc.collect()
        torch.cuda.empty_cache()
        try:
            r = run_streaming(model, processor, voice_path, bs, ddpm_steps=20)
        except torch.cuda.OutOfMemoryError:
            # Memory is released by the caller's cleanup, outside this handler.
            # While the exception is live, its traceback can keep tensors from
            # the failed call's frames alive.
            print(f"{bs:>6} OOM")
            break
        except Exception as e:
            print(f"{bs:>6} ERROR: {e}")
            break
        print(_format_row(bs, r))


def main():
    """Benchmark each quantization mode over a sweep of batch sizes.

    Prints one table per mode. A mode that fails to load, or fails its warmup
    run, is reported and skipped; the remaining modes are still measured.
    """
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"
    batch_sizes = [1, 4, 8, 16]
    quant_modes = ["bf16", "int8", "int4"]

    processor = VibeVoiceProcessor.from_pretrained(model_path)

    for qmode in quant_modes:
        gc.collect()
        torch.cuda.empty_cache()

        print(f"\n{'='*100}")
        print(f"QUANTIZATION: {qmode.upper()}")
        print(f"{'='*100}")

        model, attn = _load_with_fallback(model_path, qmode)
        if model is None:
            continue

        try:
            model.eval()

            mem = torch.cuda.memory_allocated() / 1e9
            print(f"Model VRAM: {mem:.1f} GB | Attention: {attn}")

            # Warmup run (result discarded) so one-off first-call costs stay out
            # of the timed runs. A failure here (e.g. bf16 OOM on a small GPU)
            # must not abort the whole script, because the quantized modes may
            # still fit. Only the message is kept; holding the exception would
            # keep its traceback alive.
            warmup_error = None
            try:
                run_streaming(model, processor, voice_path, 1, ddpm_steps=20)
            except Exception as e:
                warmup_error = f"{type(e).__name__}: {e}"
            if warmup_error is not None:
                print(f"warmup failed, skipping {qmode}: {warmup_error}")
                continue

            _run_batch_sweep(model, processor, voice_path, batch_sizes)
        finally:
            # Runs on every path (including `continue` and errors) so this mode's
            # weights can be freed before the next mode is loaded.
            del model
            gc.collect()
            torch.cuda.empty_cache()

    print(f"\n{'='*100}")


if __name__ == "__main__":
    main()
