"""
Sooktam-2 batching + TTFB + concurrency benchmark.
Tests batch 1 -> N, measures throughput, TTFB, GPU usage.
"""
import sys, os, time, threading, subprocess, torch
# Make the local Sooktam-2 checkout importable in this process...
sys.path.insert(0, '/home/ubuntu/sooktam2/src')
sys.path.insert(0, '/home/ubuntu/sooktam2')

# ...and in any child Python processes spawned later. Only add a separator
# when PYTHONPATH already has content: an empty entry (e.g. a trailing ':')
# makes child interpreters also search the current working directory.
_existing_pythonpath = os.environ.get('PYTHONPATH', '')
os.environ['PYTHONPATH'] = os.pathsep.join(
    part for part in ('/home/ubuntu/sooktam2/src', _existing_pythonpath) if part
)

from transformers import AutoModel

# Reference audio and its transcript (passed to infer() as ref_file/ref_text),
# plus the sentence synthesized by every timed request.
REF = '/home/ubuntu/vibevoice/demo/voices/modi.wav'
REF_TEXT = 'मेरे प्यारे देशवासियों, मुझे सीतापुर के ओजस्वी ने लिखा है कि अमृत महोत्सव से जुड़ी चर्चाएं उन्हें खूब पसंद आ रही हैं।'
TEXT = 'मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.'

print('Loading Sooktam-2...', flush=True)
model = AutoModel.from_pretrained(
    '/home/ubuntu/sooktam2',
    trust_remote_code=True,
)

# Memory currently occupied by PyTorch tensors on the current CUDA device vs.
# the total capacity of device 0, both in decimal GB (1e9 bytes).
mem, total = (
    torch.cuda.memory_allocated() / 1e9,
    torch.cuda.get_device_properties(0).total_memory / 1e9,
)
print(f'Loaded. VRAM: {mem:.1f}GB / {total:.0f}GB', flush=True)

# One untimed warm-up call before any measurement starts.
model.infer(
    ref_file=REF,
    ref_text=REF_TEXT,
    gen_text='warmup test.',
    tokenizer='cls',
    cls_language='hindi',
)
print('Warmed up.', flush=True)

# Sooktam-2's infer() has no native batching, so there is no batch size to sweep.
# Instead we measure throughput for N requests issued back-to-back (sequential),
# and simulate N concurrent users by calling infer() from N threads at once.

def run_single():
    """Synthesize TEXT once in the reference voice and time the call.

    Returns:
        A ``(gen_seconds, audio_seconds)`` tuple: wall-clock time spent inside
        ``model.infer`` and the duration of the returned waveform, computed as
        ``len(wav) / sample_rate`` (assumes a 1-D waveform -- TODO confirm).
    """
    start = time.perf_counter()
    waveform, sample_rate, _ = model.infer(
        ref_file=REF,
        ref_text=REF_TEXT,
        gen_text=TEXT,
        tokenizer='cls',
        cls_language='hindi',
        nfe_step=32,
    )
    elapsed = time.perf_counter() - start
    return elapsed, len(waveform) / sample_rate

# ---------------------------------------------------------------------------
# GPU sampling helpers
# ---------------------------------------------------------------------------

_NVIDIA_SMI_CMD = [
    'nvidia-smi',
    '--query-gpu=utilization.gpu,memory.used,power.draw',
    '--format=csv,noheader,nounits',
]


def _to_float(text):
    """Parse one nvidia-smi field; return None for non-numeric values such as '[N/A]'."""
    try:
        return float(text)
    except ValueError:
        return None


def _query_gpu():
    """Take one nvidia-smi reading for the first GPU it lists.

    Returns ``(util_percent, mem_used_gib, power_watts_or_None)``, or None when
    nvidia-smi is missing, times out, or prints an unparseable line.
    """
    try:
        proc = subprocess.run(_NVIDIA_SMI_CMD, capture_output=True, text=True, timeout=5)
    except (OSError, subprocess.TimeoutExpired):
        return None
    # nvidia-smi prints one CSV row per GPU; only the first row is used.
    # NOTE(review): nvidia-smi ordering may differ from CUDA device ordering
    # (CUDA_VISIBLE_DEVICES / CUDA_DEVICE_ORDER) -- confirm on multi-GPU hosts.
    lines = proc.stdout.strip().splitlines()
    if not lines:
        return None
    fields = [field.strip() for field in lines[0].split(',')]
    if len(fields) < 3:
        return None
    util, mem_mib = _to_float(fields[0]), _to_float(fields[1])
    if util is None or mem_mib is None:
        return None
    # memory.used is reported in MiB; /1024 gives GiB (printed as "GB").
    return util, mem_mib / 1024, _to_float(fields[2])


class _GpuMonitor:
    """Poll nvidia-smi on a daemon thread for the duration of a ``with`` block.

    Each monitor owns its own stop Event and sample list, and ``__exit__``
    joins the polling thread. A sampler can therefore never outlive its
    measurement window or write into another run's samples.
    """

    def __init__(self, interval=0.15):
        self.interval = interval
        self.samples = []
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._poll, daemon=True)

    def _poll(self):
        while not self._stop.is_set():
            sample = _query_gpu()
            if sample is not None:
                self.samples.append(sample)
            # wait() returns early as soon as stop is requested.
            self._stop.wait(self.interval)

    def __enter__(self):
        self._thread.start()
        return self

    def __exit__(self, *exc_info):
        self._stop.set()
        self._thread.join()  # bounded by the nvidia-smi timeout in _query_gpu
        return False

    def averages(self):
        """Return (mean GPU util %, mean memory used GiB); (0, 0) if nothing was sampled."""
        if not self.samples:
            return 0, 0
        count = len(self.samples)
        return (sum(s[0] for s in self.samples) / count,
                sum(s[1] for s in self.samples) / count)


# ---------------------------------------------------------------------------
# Run strategies -- each returns (wall_seconds, [(gen_s, audio_s), ...], [errors])
# ---------------------------------------------------------------------------

def _run_sequential(n_reqs):
    """Issue n_reqs requests back-to-back on the calling thread.

    Exceptions propagate (fail fast), so the returned error list is always empty.
    """
    t0 = time.perf_counter()
    results = [run_single() for _ in range(n_reqs)]
    return time.perf_counter() - t0, results, []


def _run_concurrent(n_threads):
    """Start n_threads threads that each call run_single() once, and wait for all of them.

    A worker's exception is recorded in the returned error list instead of
    being lost in a background thread and silently skewing the averages.
    """
    results, errors = [], []

    def worker():
        try:
            results.append(run_single())
        except Exception as exc:  # reported by _report, never dropped
            errors.append(exc)

    t0 = time.perf_counter()
    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return time.perf_counter() - t0, results, errors


def _report(n, mode, wall, results, errors, gpu):
    """Print one table row for a finished run, plus a failure line if any request failed.

    RTF = total generation time / total audio duration (== mean gen / mean
    audio), computed identically for sequential and concurrent runs.
    Thru = seconds of audio produced per wall-clock second.
    """
    if errors:
        print(f'{n:>6} {mode:>12} {len(errors)}/{n} request(s) failed: {errors[0]!r}', flush=True)
    if not results:
        return
    total_audio = sum(dur for _, dur in results)
    total_gen = sum(gen for gen, _ in results)
    thru = total_audio / wall if wall > 0 else 0.0
    rtf = total_gen / total_audio if total_audio > 0 else float('inf')
    avg_gpu, avg_vram = gpu.averages()
    print(f'{n:>6} {mode:>12} {wall:>7.1f}s {total_audio:>7.1f}s {thru:>7.1f} s/s '
          f'{rtf:>6.2f}x {avg_gpu:>5.0f}% {avg_vram:>5.1f}GB', flush=True)


# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------

print(f'\n{"="*80}', flush=True)
print(f'{"Reqs":>6} {"Mode":>12} {"Wall":>8} {"Audio":>8} {"Thru":>10} {"RTF":>7} {"GPU%":>6} {"VRAM":>7}', flush=True)
print(f'{"-"*80}', flush=True)

# Sequential: N requests issued back-to-back on the main thread.
for n_reqs in [1, 2, 4, 8]:
    with _GpuMonitor() as gpu:
        wall, results, errors = _run_sequential(n_reqs)
    _report(n_reqs, 'sequential', wall, results, errors, gpu)

# Concurrent via threading (simulates multiple users).
# NOTE(review): all threads share one `model`; whether model.infer() is safe to
# call concurrently is not established here -- confirm before trusting these rows.
print('\n--- Concurrent (threaded) ---', flush=True)
for n_concurrent in [2, 4, 8]:
    with _GpuMonitor() as gpu:
        wall, results, errors = _run_concurrent(n_concurrent)
    _report(n_concurrent, 'concurrent', wall, results, errors, gpu)

print(f'{"="*80}', flush=True)
print('Done.', flush=True)
