"""
Profile exactly where time is spent in LuxTTS inference.
Break down every stage to find the bottleneck.
"""
import torch
import torch.nn as nn
import time
import numpy as np

# Patch zipvoice's SwooshL/SwooshR activations before LuxTTS is imported below.
import zipvoice.models.modules.scaling as scaling
class FastSwooshL(nn.Module):
    """Elementwise SwooshL activation: log(1 + exp(x - 4)) - 0.08*x - 0.035.

    This script assigns it over ``scaling.SwooshL``. The output has the same
    shape, dtype and device as the input.
    """

    def forward(self, x):
        # Zero is allocated directly on x's device with x's dtype. The previous
        # torch.tensor(0.0, device=...) built the scalar from host data and
        # copied it to the device on every call, which is extra per-call work
        # that would also show up inside this profiler's timed regions.
        zero = x.new_zeros(())
        # logaddexp(0, a) == log(1 + exp(a)): a numerically stable softplus.
        return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
class FastSwooshR(nn.Module):
    """Elementwise SwooshR activation: log(1 + exp(x - 1)) - 0.08*x - 0.313261687.

    This script assigns it over ``scaling.SwooshR``. The output has the same
    shape, dtype and device as the input.
    """

    def forward(self, x):
        # Zero is allocated directly on x's device with x's dtype. The previous
        # torch.tensor(0.0, device=...) built the scalar from host data and
        # copied it to the device on every call, which is extra per-call work
        # that would also show up inside this profiler's timed regions.
        zero = x.new_zeros(())
        # logaddexp(0, a) == log(1 + exp(a)): a numerically stable softplus.
        return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
# Swap the activation classes on the scaling module. This runs before the
# zipvoice/LuxTTS imports below, presumably so that code which binds these
# names at import time picks up the replacements -- confirm if import order
# changes.
scaling.SwooshL, scaling.SwooshR = FastSwooshL, FastSwooshR

from zipvoice.luxvoice import LuxTTS
from zipvoice.modeling_utils import generate as orig_generate
from zipvoice.utils.common import make_pad_mask
from zipvoice.models.modules.solver import get_time_steps

# Voice-prompt clip passed to LuxTTS.encode_prompt (absolute path on the profiling host).
REF_AUDIO = "/home/ubuntu/LuxTTS/ref_audio.wav"
# Fixed sentence synthesized in every profiled stage.
TEXT = (
    "Ladies and gentlemen, welcome to the annual technology conference. "
    "Today we will discuss the latest advances in artificial intelligence, "
    "machine learning, and natural language processing."
)

def _cuda_timeit(fn, n):
    """Time ``n`` calls of ``fn`` with CUDA synchronized around each call.

    Each call runs under ``torch.inference_mode()``. ``torch.cuda.synchronize``
    is issued before the clock starts and before it stops, so GPU work queued
    by ``fn`` is included in the measured wall time.

    Args:
        fn: Zero-argument callable to time.
        n: Number of calls. Must be at least 2, because the first call is
            discarded as warm-up.

    Returns:
        ``(mean_seconds, last_result)``: the mean wall time over calls 2..n,
        and whatever the final call returned.

    Raises:
        ValueError: if ``n < 2``.
    """
    if n < 2:
        raise ValueError(f"n must be >= 2 (the first run is discarded), got {n}")
    durations = []
    result = None
    for _ in range(n):
        torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.inference_mode():
            result = fn()
        torch.cuda.synchronize()
        durations.append(time.perf_counter() - start)
    return float(np.mean(durations[1:])), result


def profile():
    """Time each LuxTTS inference stage on CUDA and print a breakdown.

    Stages timed at batch size 1: tokenization, text encoder + duration
    conditioning, one solver step, the full solver, and vocoder decode.
    TOTAL sums tokenization, text encoder, full solver and vocoder; the
    untimed setup and the single-step measurement (a subset of the solver)
    are not included. The function then reports the solver/model wrapper
    types and times one solver step at several batch sizes.

    Requires a CUDA device and the reference clip at ``REF_AUDIO``.
    """
    print("=" * 70)
    print("LuxTTS Detailed Profiling — A100 80GB")
    print("=" * 70)

    tts = LuxTTS(device='cuda')
    enc = tts.encode_prompt(REF_AUDIO, duration=5)
    model = tts.model
    vocoder = tts.vocos
    tokenizer = tts.tokenizer
    device = next(model.parameters()).device

    prompt_tokens = enc['prompt_tokens']
    prompt_features = enc['prompt_features']
    prompt_features_lens = enc['prompt_features_lens']

    # Warmup
    for _ in range(3):
        orig_generate(prompt_tokens, prompt_features_lens, prompt_features,
                      enc['prompt_rms'], TEXT, model, vocoder, tokenizer, num_step=4)

    # --- Stage-by-stage profiling (B=1) ---
    print("\n--- Stage Breakdown (B=1) ---")
    N = 10

    # 1. Tokenization
    t_tok, tokens = _cuda_timeit(lambda: tokenizer.texts_to_token_ids([TEXT]), N)
    print(f"  1. Tokenization:      {t_tok*1000:.2f}ms")

    # 2. Text embed + condition (text_encoder + duration)
    speed = 1.0 * 1.3
    t_text, (text_condition, padding_mask) = _cuda_timeit(
        lambda: model.forward_text_inference_ratio_duration(
            tokens=tokens,
            prompt_tokens=prompt_tokens,
            prompt_features_lens=prompt_features_lens,
            speed=speed,
        ),
        N,
    )
    print(f"  2. Text encoder+cond: {t_text*1000:.2f}ms")
    print(f"     text_condition shape: {text_condition.shape}")

    # Setup (noise, speech condition) -- not timed.
    batch_size, num_frames, _ = text_condition.shape
    feat_dim = prompt_features.size(-1)
    with torch.inference_mode():
        # Pad the prompt features out to the full frame count, then zero
        # every frame past each prompt's length.
        speech_condition = torch.nn.functional.pad(
            prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
        )
        speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
        speech_condition = torch.where(
            speech_condition_mask.unsqueeze(-1),
            torch.zeros_like(speech_condition),
            speech_condition,
        )
        x0 = torch.randn(batch_size, num_frames, feat_dim, device=device)

    # Flow-matching steps -- THE CORE
    num_steps = 4
    cfg_scale = 3.0
    timesteps = get_time_steps(t_start=0.0, t_end=1.0, num_step=num_steps, t_shift=0.5, device=device)
    t_first = timesteps[0]
    # Built once, outside the timed region, so no host->device copy is measured.
    guidance_scale = torch.tensor(cfg_scale, dtype=t_first.dtype, device=device)
    # The solver's model wrapper (its forward is inspected further below).
    # Calling it with the same kwargs solver.sample receives presumably matches
    # one solver step, including any CFG handling -- confirm against
    # zipvoice.models.modules.solver.
    step_model = model.solver.model

    def one_step(x, text_cond, speech_cond, pad_mask):
        """Run one velocity evaluation through the solver's model wrapper."""
        return step_model(
            t=t_first,
            x=x,
            text_condition=text_cond,
            speech_condition=speech_cond,
            padding_mask=pad_mask,
            guidance_scale=guidance_scale,
        )

    # 3. One solver step (conditions + guidance included; the bare
    #    fm_decoder call used previously omitted both).
    t_step, _ = _cuda_timeit(
        lambda: one_step(x0, text_condition, speech_condition, padding_mask), N
    )
    print(f"  3. Single solver step: {t_step*1000:.2f}ms")

    # 4. Full solver
    t_solver, x1 = _cuda_timeit(
        lambda: model.solver.sample(
            x=x0, text_condition=text_condition,
            speech_condition=speech_condition,
            padding_mask=padding_mask,
            num_step=num_steps, guidance_scale=cfg_scale, t_shift=0.5,
        ),
        N,
    )
    print(f"  4. Solver ({num_steps} steps):   {t_solver*1000:.2f}ms")
    print(f"     (per step avg:      {t_solver*1000/num_steps:.2f}ms)")

    # 5. Vocoder
    # Output sample rate assumed for the audio-duration calculation; this is
    # the value the script has always used -- confirm against the vocoder.
    sample_rate = 48000
    with torch.inference_mode():
        # (B, T, C) -> (B, C, T) and undo the 0.1 scaling, as in the original
        # script (presumably the model's feature scale -- confirm).
        pred_features = x1.permute(0, 2, 1) / 0.1
    t_voc, wav = _cuda_timeit(lambda: vocoder.decode(pred_features).clamp(-1, 1), N)
    audio_dur = wav.shape[-1] / sample_rate
    print(f"  5. Vocoder decode:     {t_voc*1000:.2f}ms")
    print(f"     audio: {audio_dur:.2f}s")

    # Total over every timed stage. The single step is a subset of the solver,
    # so it is not added again.
    total = (t_tok + t_text + t_solver + t_voc) * 1000
    print(f"\n  TOTAL:                {total:.1f}ms for {audio_dur:.2f}s audio = {audio_dur/(total/1000):.0f}x realtime")

    # --- Now check the DiffusionModel wrapper (does CFG double the work?) ---
    print("\n--- Checking CFG overhead ---")

    from zipvoice.models.modules.solver import DistillDiffusionModel
    print(f"  Solver type: {type(model.solver).__name__}")
    print(f"  Model type: {type(step_model).__name__}")
    print(f"  isinstance(model, DistillDiffusionModel): {isinstance(step_model, DistillDiffusionModel)}")

    # Heuristic source scan of the wrapper's forward.
    import inspect
    src = inspect.getsource(step_model.forward)
    if 'guidance_scale' in src and ('2 *' in src or 'cat' in src or 'repeat' in src):
        print("  CFG: YES — model does 2x forward per step (conditional + unconditional)")
    else:
        print("  CFG: checking source...")
    # Print the key lines
    for line in src.split('\n'):
        line = line.strip()
        if any(k in line for k in ['guidance', 'model_func', 'return', 'cat', 'repeat', '2 *']):
            print(f"    {line}")

    # --- Batched profiling: one solver step per batch size ---
    print("\n--- Batched single-step timing ---")
    for bs in [1, 4, 8, 16, 32, 64]:
        # Built under inference_mode like the B=1 tensors they are derived from.
        with torch.inference_mode():
            x_batch = torch.randn(bs, num_frames, feat_dim, device=device)
            tc_batch = text_condition.expand(bs, -1, -1)
            sc_batch = speech_condition.expand(bs, -1, -1)
            pm_batch = padding_mask.expand(bs, -1)

        # The lambda runs immediately inside this iteration, so late binding of
        # the loop variables is not an issue.
        avg_s, _ = _cuda_timeit(lambda: one_step(x_batch, tc_batch, sc_batch, pm_batch), 6)
        avg = avg_s * 1000
        per_item = avg / bs
        print(f"  B={bs:3d}: {avg:.1f}ms total, {per_item:.2f}ms/item, {audio_dur/(per_item/1000):.0f}x realtime/item")

    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"\nPeak VRAM: {peak:.2f} GB")


if __name__ == "__main__":
    # Run the CUDA-only profiler when executed as a script, not on import.
    profile()
