"""
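Benchmark TRUE parallel LuxTTS execution on an A100 using:
1. Multiple CUDA streams (alone, and combined with request batching)
2. Multiple processes to bypass the GIL — not yet implemented here; the
   multiprocessing imports are currently unused.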
"""
import torch
import torch.nn as nn
import time
import numpy as np
import torch.multiprocessing as mp
from multiprocessing import Process, Queue as MPQueue
import os

import zipvoice.models.modules.scaling as scaling
class FSL(nn.Module):
    """Plain-torch Swoosh-L activation: log(1 + exp(x - 4)) - 0.08*x - 0.035.

    This file assigns it over ``scaling.SwooshL``; presumably a drop-in for
    the zipvoice implementation written with stock ops — TODO confirm the
    numerics match upstream.
    """

    def forward(self, x):
        # logaddexp(0, y) == log(1 + exp(y)), computed stably; the zero must
        # share x's dtype/device so no promotion or transfer happens.
        zero = x.new_zeros(())
        shifted = x - 4.0
        return torch.logaddexp(zero, shifted) - 0.08 * x - 0.035
class FSR(nn.Module):
    """Plain-torch Swoosh-R activation: log(1 + exp(x - 1)) - 0.08*x - 0.313261687.

    This file assigns it over ``scaling.SwooshR``; presumably a drop-in for
    the zipvoice implementation written with stock ops — TODO confirm the
    numerics match upstream.
    """

    def forward(self, x):
        # logaddexp(0, y) == log(1 + exp(y)); zero matches x's dtype/device.
        zero = x.new_zeros(())
        shifted = x - 1.0
        return torch.logaddexp(zero, shifted) - 0.08 * x - 0.313261687
# Swap zipvoice's Swoosh activations for the plain-torch FSL/FSR classes.
# This runs before ``zipvoice.luxvoice`` is imported: modules that bind the
# names at import time (``from ...scaling import SwooshL``) would otherwise
# keep the originals — presumably why the patch sits here; TODO confirm.
scaling.SwooshL, scaling.SwooshR = FSL, FSR

from zipvoice.luxvoice import LuxTTS
from zipvoice.modeling_utils import generate as orig_generate
from batched_inference import batched_generate

# Reference clip used to build the voice prompt. Defaults to the original
# hard-coded path; set LUXTTS_REF_AUDIO to benchmark on another machine.
REF_AUDIO = os.environ.get("LUXTTS_REF_AUDIO", "/home/ubuntu/LuxTTS/ref_audio.wav")
# Sentence synthesized by every request in every benchmark below.
TEXT = "The quick brown fox jumps over the lazy dog near the riverbank."


def _generate_one(enc, model, vocoder, tokenizer):
    """Synthesize TEXT once (unbatched, one flow step) on the current CUDA stream."""
    return orig_generate(enc['prompt_tokens'], enc['prompt_features_lens'],
                         enc['prompt_features'], enc['prompt_rms'],
                         TEXT, model, vocoder, tokenizer, num_step=1)


def _generate_batch(n, enc, model, vocoder, tokenizer, num_steps):
    """Synthesize ``n`` copies of TEXT in one batched call on the current CUDA stream."""
    # A fresh list per call, as before, in case the callee mutates its input.
    texts = [TEXT] * n
    return batched_generate(texts, model, vocoder, tokenizer,
                            enc['prompt_tokens'], enc['prompt_features'],
                            enc['prompt_features_lens'], enc['prompt_rms'],
                            num_steps=num_steps)


def _gpu_total_gb():
    """Return the total memory of the current CUDA device in GiB."""
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return props.total_memory / 1024**3


def _bench_streams(enc, model, vocoder, tokenizer, seq_ms):
    """TEST 1 body: launch one unbatched request per stream and report overlap.

    ``seq_ms`` is the measured sequential per-request latency. A wall time
    below ``n_streams * seq_ms`` means the streams overlapped.
    """
    for n_streams in [2, 4, 8]:
        streams = [torch.cuda.Stream() for _ in range(n_streams)]
        events_start = [torch.cuda.Event(enable_timing=True) for _ in range(n_streams)]
        events_end = [torch.cuda.Event(enable_timing=True) for _ in range(n_streams)]

        torch.cuda.synchronize()
        wall_start = time.perf_counter()

        # All launches come from this single host thread. If the generate call
        # blocks on the host internally (e.g. .item()/.cpu()), the streams are
        # serialized regardless of GPU capacity — TODO confirm for orig_generate.
        for stream, ev_start, ev_end in zip(streams, events_start, events_end):
            with torch.cuda.stream(stream):
                ev_start.record()  # no stream arg -> records on `stream`
                _generate_one(enc, model, vocoder, tokenizer)
                ev_end.record()

        torch.cuda.synchronize()
        wall_total = (time.perf_counter() - wall_start) * 1000

        # All events have completed after the synchronize above.
        gpu_times = [s.elapsed_time(e) for s, e in zip(events_start, events_end)]

        seq_expected = seq_ms * n_streams
        overlap = (seq_expected - wall_total) / seq_expected * 100

        print(f"\n  {n_streams} CUDA streams:")
        print(f"    Wall time:      {wall_total:.1f}ms")
        print(f"    Sequential:     {seq_expected:.1f}ms (if no overlap)")
        print(f"    Overlap:        {overlap:.1f}%")
        print(f"    Per-stream GPU: {[f'{t:.1f}ms' for t in gpu_times]}")
        print(f"    Effective TTFB: {wall_total:.1f}ms for all {n_streams} requests")


def _bench_streams_batched(enc, model, vocoder, tokenizer):
    """TEST 2 body: N streams, each running one batched call of size B."""
    for n_streams, batch_per_stream in [(2, 16), (4, 8), (8, 4), (2, 32), (4, 16)]:
        total_reqs = n_streams * batch_per_stream
        streams = [torch.cuda.Stream() for _ in range(n_streams)]

        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        t0 = time.perf_counter()

        for stream in streams:
            with torch.cuda.stream(stream):
                _generate_batch(batch_per_stream, enc, model, vocoder, tokenizer,
                                num_steps=1)

        torch.cuda.synchronize()
        wall_ms = (time.perf_counter() - t0) * 1000
        vram = torch.cuda.max_memory_allocated() / 1024**3

        print(f"  {n_streams} streams × B={batch_per_stream} = {total_reqs} reqs: "
              f"{wall_ms:.0f}ms wall, {wall_ms/total_reqs:.1f}ms/req, "
              f"TTFB={wall_ms:.0f}ms, VRAM={vram:.1f}GB")


def _bench_concurrent(enc, model, vocoder, tokenizer, target=500):
    """TEST 3 body: serve ``target`` requests in rounds of N streams × B each.

    Each round's requests complete at that round's synchronize, so they all
    share one TTFB measured from the global start (queueing time included).
    """
    total_gb = _gpu_total_gb()
    for n_streams, batch_per_stream, steps in [
        (4, 32, 1), (8, 16, 1), (8, 32, 1), (4, 32, 2), (8, 16, 2),
    ]:
        reqs_per_round = n_streams * batch_per_stream
        rounds = int(np.ceil(target / reqs_per_round))
        streams = [torch.cuda.Stream() for _ in range(n_streams)]

        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        all_ttfbs = []
        t_global = time.perf_counter()

        for r in range(rounds):
            remaining = target - r * reqs_per_round
            served = 0
            for i, stream in enumerate(streams):
                n = min(batch_per_stream, remaining - i * batch_per_stream)
                if n <= 0:
                    continue  # last round: fewer requests left than N × B
                with torch.cuda.stream(stream):
                    _generate_batch(n, enc, model, vocoder, tokenizer,
                                    num_steps=steps)
                served += n
            torch.cuda.synchronize()
            round_ttfb = (time.perf_counter() - t_global) * 1000
            all_ttfbs.extend([round_ttfb] * served)

        total_ms = (time.perf_counter() - t_global) * 1000
        vram = torch.cuda.max_memory_allocated() / 1024**3
        ttfbs = np.array(all_ttfbs)

        print(f"\n  {n_streams} streams × B={batch_per_stream} × {steps} steps ({reqs_per_round}/round, {rounds} rounds):")
        print(f"    Total wall:  {total_ms:.0f}ms")
        print(f"    Throughput:  {target/(total_ms/1000):.0f} req/s")
        # Percentage of the device's real capacity, not an assumed 80 GB card.
        print(f"    VRAM:        {vram:.1f} GB ({vram/total_gb*100:.0f}%)")
        print(f"    TTFB min:    {np.min(ttfbs):.0f}ms")
        print(f"    TTFB p50:    {np.percentile(ttfbs, 50):.0f}ms")
        print(f"    TTFB p95:    {np.percentile(ttfbs, 95):.0f}ms")
        print(f"    TTFB max:    {np.max(ttfbs):.0f}ms")


def test_cuda_streams():
    """Benchmark LuxTTS generation concurrency on the current CUDA device.

    TEST 1 compares N unbatched requests on N CUDA streams against the
    measured sequential latency. TEST 2 runs streams that each carry one
    batched call. TEST 3 serves 500 requests in rounds of streams × batches
    and reports throughput, peak VRAM and TTFB percentiles.
    """
    print("=" * 70)
    print("TEST 1: CUDA Streams Overlap")
    print("=" * 70)

    tts = LuxTTS(device='cuda')
    enc = tts.encode_prompt(REF_AUDIO, duration=5)
    model, vocoder, tokenizer = tts.model, tts.vocos, tts.tokenizer

    # Warmup so the timed runs exclude one-time initialization costs.
    for _ in range(3):
        _generate_one(enc, model, vocoder, tokenizer)

    # Baseline: sequential unbatched requests; the first two timings are dropped.
    times_seq = []
    for _ in range(10):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        _generate_one(enc, model, vocoder, tokenizer)
        torch.cuda.synchronize()
        times_seq.append(time.perf_counter() - t0)
    seq_ms = np.mean(times_seq[2:]) * 1000
    print(f"\nSequential B=1: {seq_ms:.1f}ms per request")

    _bench_streams(enc, model, vocoder, tokenizer, seq_ms)

    print(f"\n{'='*70}")
    print("TEST 2: Streams + Batching Combined")
    print(f"{'='*70}")

    # The batched path was never warmed up above. Without this, TEST 2's
    # first configuration would absorb its cold-start cost.
    _generate_batch(4, enc, model, vocoder, tokenizer, num_steps=1)
    torch.cuda.synchronize()

    _bench_streams_batched(enc, model, vocoder, tokenizer)

    # ─── Final: simulate 500 concurrent ───
    print(f"\n{'='*70}")
    print("TEST 3: 500 Concurrent — Streams + Batching")
    print(f"{'='*70}")

    _bench_concurrent(enc, model, vocoder, tokenizer, target=500)


if __name__ == "__main__":
    # Needs a CUDA device (LuxTTS is built with device='cuda') and the
    # reference clip at REF_AUDIO.
    test_cuda_streams()
