"""
Multi-process parallel benchmark with TTFB + concurrency measurement.
Each process loads its own model, handles batch=1 requests.
Measures: TTFB (time to first audio chunk) per request, RTF, total throughput,
GPU utilization, VRAM, and power draw.
"""

import os
import sys
import time
import subprocess
import threading
import multiprocessing as mp
from pathlib import Path


def worker_main(worker_id, ready_barrier, go_event, num_requests, result_file):
    """Benchmark worker: load a private model copy, then run timed requests.

    Intended to run in its own process. The worker loads the processor and
    model, performs one untimed warmup generation, then blocks on
    ``ready_barrier`` followed by ``go_event``. Once released it issues
    ``num_requests`` streaming generations back to back and finally writes
    one CSV line per request to ``result_file``::

        ttfb_ms,generation_seconds,audio_seconds,rtf

    ``ttfb_ms`` is -1 when no audio chunk was observed, and ``rtf`` is 999
    when no audio samples were counted. Audio duration is the streamed sample
    count divided by 24000. ``worker_id`` is accepted but not used here.
    """
    import torch
    from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
    from vibevoice.modular.streamer import AudioStreamer
    from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

    checkpoint = 'microsoft/VibeVoice-1.5B'
    processor = VibeVoiceProcessor.from_pretrained(checkpoint)

    def load_model(attention_backend):
        return VibeVoiceForConditionalGenerationInference.from_pretrained(
            checkpoint, torch_dtype=torch.bfloat16, device_map='cuda',
            attn_implementation=attention_backend)

    # Try flash-attention first; any failure falls back to SDPA.
    try:
        model = load_model('flash_attention_2')
    except Exception:
        model = load_model('sdpa')
    model.eval()
    model.set_ddpm_inference_steps(num_steps=10)

    speaker_wav = 'demo/voices/modi.wav'
    prompt = 'Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.'

    model_inputs = processor(
        text=[prompt], voice_samples=[[speaker_wav]], padding=True,
        return_tensors='pt', return_attention_mask=True)
    for key, value in model_inputs.items():
        if torch.is_tensor(value):
            model_inputs[key] = value.to('cuda')

    def generate(**extra):
        # A fresh generation_config dict is built on every call, exactly as
        # the two literal call sites would.
        return model.generate(
            **model_inputs, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
            generation_config={'do_sample': False}, verbose=False, is_prefill=True,
            show_progress_bar=False, **extra)

    # Untimed warmup pass before signalling readiness.
    _ = generate()

    ready_barrier.wait()
    go_event.wait()

    rows = []
    for request_index in range(num_requests):
        streamer = AudioStreamer(batch_size=1, stop_signal=None)
        stream_state = {'first_chunk_at': None, 'samples': 0}

        def drain_stream():
            # Runs on a side thread: timestamps the first chunk and counts samples.
            for chunk in streamer.get_stream(0):
                arrived = time.perf_counter()
                if stream_state['first_chunk_at'] is None:
                    stream_state['first_chunk_at'] = arrived
                stream_state['samples'] += chunk.shape[-1]

        drainer = threading.Thread(target=drain_stream, daemon=True)
        drainer.start()

        seed = 42 + request_index
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Synchronize on both sides so the interval covers only this request's GPU work.
        torch.cuda.synchronize()
        started = time.perf_counter()
        out = generate(audio_streamer=streamer)
        torch.cuda.synchronize()
        gen_seconds = time.perf_counter() - started
        drainer.join(timeout=5)

        first_at = stream_state['first_chunk_at']
        if first_at:
            ttfb_ms = (first_at - started) * 1000
        else:
            ttfb_ms = -1
        audio_seconds = stream_state['samples'] / 24000.0
        if audio_seconds > 0:
            rtf = gen_seconds / audio_seconds
        else:
            rtf = 999
        rows.append(f'{ttfb_ms:.1f},{gen_seconds:.3f},{audio_seconds:.3f},{rtf:.3f}')

    with open(result_file, 'w') as fh:
        fh.write('\n'.join(rows))


def monitor_gpu(stop_event, samples_list):
    """Poll ``nvidia-smi`` until ``stop_event`` is set, recording GPU samples.

    Each successful poll appends a tuple
    ``(utilization_percent, memory_used / 1024, power_draw)`` to
    ``samples_list``. ``nvidia-smi`` reports ``memory.used`` in MiB when
    ``nounits`` is requested, so the second field is in GiB.

    Args:
        stop_event: ``threading.Event``-like object; polling stops once set.
        samples_list: list that receives the sample tuples (mutated in place).

    Monitoring is best-effort: if ``nvidia-smi`` cannot be executed the
    function prints a warning to stderr and returns, and samples that cannot
    be parsed (for example ``[N/A]`` fields) are skipped instead of killing
    the monitoring thread.
    """
    query = ['nvidia-smi', '--query-gpu=utilization.gpu,memory.used,power.draw',
             '--format=csv,noheader,nounits']
    while not stop_event.is_set():
        try:
            r = subprocess.run(query, capture_output=True, text=True)
        except OSError as exc:
            # nvidia-smi missing or not executable: no point in retrying.
            print(f'GPU monitor disabled: cannot run nvidia-smi ({exc})', file=sys.stderr)
            return

        # nvidia-smi prints one CSV line per GPU; only the first GPU listed is
        # sampled so that a multi-GPU host does not corrupt the field parsing.
        lines = r.stdout.strip().splitlines()
        if lines:
            parts = [field.strip() for field in lines[0].split(',')]
            if len(parts) >= 3:
                try:
                    sample = (float(parts[0]), float(parts[1]) / 1024, float(parts[2]))
                except ValueError:
                    sample = None  # e.g. '[N/A]' for an unsupported metric
                if sample is not None:
                    samples_list.append(sample)

        # Event.wait doubles as the poll interval and returns early on stop.
        stop_event.wait(0.2)


def run_test(num_procs, requests_per_proc=2, ready_timeout=900.0, join_timeout=300.0):
    """Run one benchmark round with ``num_procs`` concurrent worker processes.

    Spawns ``num_procs`` processes running ``worker_main``, waits until every
    worker has loaded its model and finished warmup, releases them all at
    once, and aggregates the per-request CSV lines each worker writes to
    ``/tmp/parallel_results/worker_<i>.txt``. GPU utilization, memory and
    power are sampled by ``monitor_gpu`` while the workers run.

    Args:
        num_procs: number of worker processes to start.
        requests_per_proc: sequential requests issued by each worker.
        ready_timeout: seconds to wait for all workers to reach the ready
            barrier before giving up.
        join_timeout: seconds to wait for each worker to exit after "go".

    Returns:
        dict with keys ``procs``, ``reqs``, ``wall``, ``total_audio``,
        ``thru``, ``ttfb_min``, ``ttfb_avg``, ``ttfb_p50``, ``ttfb_p95``,
        ``ttfb_max``, ``rtf_avg``, ``avg_gpu``, ``avg_vram``, ``avg_pow``.
        ``wall`` spans from the "go" signal until every worker process has
        exited, so it includes worker teardown.

    Raises:
        RuntimeError: a worker exited before becoming ready, or the workers
            were not all ready within ``ready_timeout``.
    """
    result_dir = Path('/tmp/parallel_results')
    result_dir.mkdir(parents=True, exist_ok=True)
    result_files = [result_dir / f'worker_{i}.txt' for i in range(num_procs)]
    # Drop leftovers from earlier runs so a worker that fails this time is
    # not credited with stale numbers.
    for rf in result_files:
        if rf.exists():
            rf.unlink()

    ctx = mp.get_context('spawn')
    # One extra party for this (parent) process: the barrier only opens once
    # every worker has finished warmup AND the parent is ready to time.
    ready_barrier = ctx.Barrier(num_procs + 1)
    go_event = ctx.Event()

    procs = []
    for i, rf in enumerate(result_files):
        p = ctx.Process(target=worker_main, args=(i, ready_barrier, go_event, requests_per_proc, str(rf)))
        p.start()
        procs.append(p)

    gpu_samples = []
    stop_mon = threading.Event()
    mon = threading.Thread(target=monitor_gpu, args=(stop_mon, gpu_samples), daemon=True)

    try:
        # Poll instead of blocking in wait() so a crashed worker is noticed
        # immediately rather than hanging the benchmark.
        ready_deadline = time.monotonic() + ready_timeout
        while ready_barrier.n_waiting < num_procs:
            dead = [p.pid for p in procs if not p.is_alive()]
            if dead or time.monotonic() > ready_deadline:
                # Breaking the barrier makes the remaining workers raise
                # BrokenBarrierError instead of waiting forever.
                ready_barrier.abort()
                if dead:
                    reason = f'worker process(es) {dead} exited during load/warmup'
                else:
                    reason = f'workers not ready after {ready_timeout:.0f}s'
                raise RuntimeError(f'run_test({num_procs}): {reason}')
            time.sleep(0.5)
        ready_barrier.wait()  # parent is the last party: releases all workers

        mon.start()
        t0 = time.perf_counter()
        go_event.set()

        for p in procs:
            p.join(timeout=join_timeout)
        wall = time.perf_counter() - t0
    finally:
        stop_mon.set()
        if mon.is_alive():
            mon.join(timeout=5)
        # Never leave stragglers behind: they would keep holding GPU memory
        # and distort every later round.
        for p in procs:
            if p.is_alive():
                p.terminate()
                p.join(timeout=10)

    # Collect results
    all_ttfbs = []
    all_rtfs = []
    total_audio = 0
    total_reqs = 0

    for rf in result_files:
        if not rf.exists():
            continue
        for line in rf.read_text().splitlines():
            if not line.strip():
                continue
            parts = line.split(',')
            ttfb, dur, rtf = float(parts[0]), float(parts[2]), float(parts[3])
            # Workers write ttfb=-1 when no chunk arrived and rtf=999 when no
            # audio was produced; keep those sentinels out of the statistics.
            if ttfb >= 0:
                all_ttfbs.append(ttfb)
            if dur > 0:
                all_rtfs.append(rtf)
            total_audio += dur
            total_reqs += 1

    def mean(values):
        return sum(values) / len(values) if values else 0

    thru = total_audio / wall if wall > 0 else 0

    ttfbs_sorted = sorted(all_ttfbs) if all_ttfbs else [0]
    rtfs_sorted = sorted(all_rtfs) if all_rtfs else [0]
    last = len(ttfbs_sorted) - 1

    return {
        'procs': num_procs,
        'reqs': total_reqs,
        'wall': wall,
        'total_audio': total_audio,
        'thru': thru,
        'ttfb_min': ttfbs_sorted[0],
        'ttfb_avg': mean(ttfbs_sorted),
        'ttfb_p50': ttfbs_sorted[len(ttfbs_sorted) // 2],
        'ttfb_p95': ttfbs_sorted[min(int(len(ttfbs_sorted) * 0.95), last)],
        'ttfb_max': ttfbs_sorted[-1],
        'rtf_avg': mean(rtfs_sorted),
        'avg_gpu': mean([g[0] for g in gpu_samples]),
        'avg_vram': mean([g[1] for g in gpu_samples]),
        'avg_pow': mean([g[2] for g in gpu_samples]),
    }


def main(proc_counts=(1, 2, 4, 6, 8), requests_per_proc=3):
    """Run the benchmark for each process count and print a results table.

    Args:
        proc_counts: iterable of worker-process counts to benchmark, in order.
        requests_per_proc: number of requests each worker issues per round.

    One table row is printed (and flushed) as soon as its round finishes, so
    progress is visible even when stdout is redirected to a file or pipe.
    """
    print('Multi-process parallel TTS benchmark')
    print('Each process: own model copy, batch=1, 10 steps, cfg=1.3, streaming')
    print()
    # Header widths include the unit suffix printed after each value
    # ('s', ' s/s', 'ms', 'x', '%', 'GB', 'W') so the columns line up with
    # the 85-character rows below.
    print(f'{"Procs":>6} {"Reqs":>5} {"Wall":>7} {"Thru":>10} {"TTFB_avg":>10} {"TTFB_p95":>10} {"RTF_avg":>8} {"GPU%":>6} {"VRAM":>7} {"Power":>7}')
    print('-' * 85)

    for n in proc_counts:
        r = run_test(n, requests_per_proc=requests_per_proc)
        print(
            f'{r["procs"]:>6} '
            f'{r["reqs"]:>5} '
            f'{r["wall"]:>6.1f}s '
            f'{r["thru"]:>6.1f} s/s '
            f'{r["ttfb_avg"]:>8.0f}ms '
            f'{r["ttfb_p95"]:>8.0f}ms '
            f'{r["rtf_avg"]:>7.3f}x '
            f'{r["avg_gpu"]:>5.0f}% '
            f'{r["avg_vram"]:>5.1f}GB '
            f'{r["avg_pow"]:>6.0f}W',
            flush=True,
        )

    print('-' * 85)


# Entry-point guard: run_test uses the 'spawn' start method, which re-imports
# this module in every child process, so main() must not run at import time.
if __name__ == '__main__':
    main()
