"""
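Profile LuxTTS inference: end-to-end latency, a per-stage breakdown, and
fm_decoder / end-to-end batch scaling, using the correct fm_decoder input format.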
"""
import os
import time

import numpy as np
import torch
import torch.nn as nn

import zipvoice.models.modules.scaling as scaling
class FSL(nn.Module):
    """Pure-torch stand-in assigned to ``scaling.SwooshL`` in this script.

    Elementwise computes ``log(1 + exp(x - 4)) - 0.08 * x - 0.035``, with the
    softplus term evaluated as ``logaddexp(0, x - 4)`` so large inputs do not
    overflow ``exp``.
    """

    def forward(self, x):
        # new_zeros(()) creates the 0-dim zero directly on x's device with x's
        # dtype. Building it with torch.tensor(0.0, device=x.device) would copy
        # a Python scalar from host memory on every call; on CUDA that is a
        # synchronizing host-to-device transfer inside the profiled hot path.
        zero = x.new_zeros(())
        return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
class FSR(nn.Module):
    """Pure-torch stand-in assigned to ``scaling.SwooshR`` in this script.

    Elementwise computes ``log(1 + exp(x - 1)) - 0.08 * x - 0.313261687``, with
    the softplus term evaluated as ``logaddexp(0, x - 1)`` so large inputs do
    not overflow ``exp``. The constant offset makes the output ~0 at ``x = 0``.
    """

    def forward(self, x):
        # new_zeros(()) creates the 0-dim zero directly on x's device with x's
        # dtype. Building it with torch.tensor(0.0, device=x.device) would copy
        # a Python scalar from host memory on every call; on CUDA that is a
        # synchronizing host-to-device transfer inside the profiled hot path.
        zero = x.new_zeros(())
        return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
# Swap in the pure-torch activations before zipvoice.luxvoice is imported
# below. NOTE(review): this only affects code that looks up
# scaling.SwooshL / scaling.SwooshR after this point; any module that already
# bound the originals via `from ... import` keeps them — confirm.
scaling.SwooshL, scaling.SwooshR = FSL, FSR

from zipvoice.luxvoice import LuxTTS
from zipvoice.modeling_utils import generate as orig_generate
from zipvoice.models.modules.solver import get_time_steps

# Reference (prompt) audio passed to encode_prompt. Override with the
# LUXTTS_REF_AUDIO environment variable on machines with a different layout.
REF_AUDIO = os.environ.get("LUXTTS_REF_AUDIO", "/home/ubuntu/LuxTTS/ref_audio.wav")
# Text synthesized by every benchmark in this script.
TEXT = "Ladies and gentlemen, welcome to the annual technology conference. Today we will discuss the latest advances in artificial intelligence, machine learning, and natural language processing."

def timeit(fn, n=10, warmup=2):
    """Return the mean wall-clock latency of ``fn()`` in milliseconds.

    ``fn`` is called ``warmup`` times untimed, then ``n`` times timed, all
    under ``torch.inference_mode()``. When CUDA is available the device is
    synchronized before and after each timed call so queued kernels are
    included in the measurement; on CPU-only setups synchronization is skipped.

    Args:
        fn: Zero-argument callable to benchmark.
        n: Number of timed calls (must be >= 1).
        warmup: Number of untimed calls made first.

    Raises:
        ValueError: if ``n`` < 1 (a mean over zero runs is undefined).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    with torch.inference_mode():
        for _ in range(warmup):
            fn()
        times = []
        for _ in range(n):
            sync()
            t0 = time.perf_counter()
            fn()
            sync()
            times.append(time.perf_counter() - t0)
    return np.mean(times) * 1000  # ms

# Samples per second used to turn waveform length into seconds. Assumed to
# match the vocoder's output rate — TODO confirm against the vocoder config.
_SAMPLE_RATE = 48000


def _audio_seconds(wav):
    """Return the duration in seconds of ``wav`` (samples on the last axis)."""
    return wav.shape[-1] / _SAMPLE_RATE


def _peak_vram_gb():
    """Return ``torch.cuda.max_memory_allocated()`` converted to GiB."""
    return torch.cuda.max_memory_allocated() / 1024**3


def _profile_end_to_end(enc, model, vocoder, tokenizer, step_counts=(1, 2, 4)):
    """Time the reference ``generate`` pipeline at B=1 for each step count.

    Returns a dict mapping ``num_step`` to mean latency in milliseconds.
    """
    print("\n--- End-to-end (B=1) ---")
    latencies = {}
    for steps in step_counts:
        def run(steps=steps):
            return orig_generate(
                enc['prompt_tokens'], enc['prompt_features_lens'], enc['prompt_features'],
                enc['prompt_rms'], TEXT, model, vocoder, tokenizer, num_step=steps,
            )

        ms = timeit(run)
        # One extra untimed run, only to measure the produced audio length.
        with torch.inference_mode():
            wav = run()
        adur = _audio_seconds(wav)
        latencies[steps] = ms
        print(f"  steps={steps}: {ms:.1f}ms, {adur:.2f}s audio, {adur/(ms/1000):.0f}x realtime")
    return latencies


def _build_decoder_inputs(enc, model, tokens, speed, device):
    """Run text encoding + duration prediction and build the decoder inputs.

    Returns ``(text_condition, padding_mask, speech_condition, x0)`` where
    ``x0`` is standard-normal noise with the same shape as ``text_condition``.
    """
    from zipvoice.utils.common import make_pad_mask

    with torch.inference_mode():
        text_condition, padding_mask = model.forward_text_inference_ratio_duration(
            tokens=tokens, prompt_tokens=enc['prompt_tokens'],
            prompt_features_lens=enc['prompt_features_lens'], speed=speed,
        )
        num_frames = text_condition.shape[1]
        prompt_features = enc['prompt_features']
        # Right-pad the prompt features along the frame axis up to num_frames,
        # then zero the frames flagged by make_pad_mask (presumably those at or
        # past each prompt's length — confirm).
        speech_condition = torch.nn.functional.pad(
            prompt_features, (0, 0, 0, num_frames - prompt_features.size(1)))
        speech_mask = make_pad_mask(enc['prompt_features_lens'], num_frames)
        speech_condition = torch.where(
            speech_mask.unsqueeze(-1), torch.zeros_like(speech_condition), speech_condition)
        x0 = torch.randn(text_condition.shape, device=device)
    return text_condition, padding_mask, speech_condition, x0


def _profile_fm_decoder_batching(model, xt_input, padding_mask, t, guidance_scale,
                                 audio_seconds, num_frames,
                                 batch_sizes=(1, 2, 4, 8, 16, 32, 64, 128)):
    """Time one ``model.fm_decoder`` forward while growing the batch size.

    The B=1 inputs are replicated along the batch axis. The CUDA peak-memory
    counter is reset before each batch size, so the printed VRAM is the peak
    reached while measuring that batch size only. Returns the largest of those
    per-batch peaks, in GiB.
    """
    print("\n--- fm_decoder Batch Scaling (seq_len={}) ---".format(num_frames))
    max_peak_gb = 0.0
    for bs in batch_sizes:
        x_b = xt_input.expand(bs, -1, -1).contiguous()
        pm_b = padding_mask.expand(bs, -1).contiguous()
        t_b = t.unsqueeze(0).expand(bs)
        gs_b = guidance_scale.unsqueeze(0).expand(bs)

        torch.cuda.reset_peak_memory_stats()
        ms_b = timeit(lambda: model.fm_decoder(x=x_b, t=t_b, padding_mask=pm_b, guidance_scale=gs_b), n=8)
        vram = _peak_vram_gb()
        max_peak_gb = max(max_peak_gb, vram)

        per_item = ms_b / bs
        speed_x = audio_seconds / (per_item / 1000)
        print(f"  B={bs:3d}: {ms_b:7.1f}ms total, {per_item:6.2f}ms/item, {speed_x:6.0f}x realtime/item, VRAM={vram:.1f}GB")
    return max_peak_gb


def _profile_batched_e2e(enc, model, vocoder, tokenizer, audio_seconds, num_steps=4,
                         batch_sizes=(1, 4, 8, 16, 32, 64)):
    """Time ``batched_inference.batched_generate`` on B copies of ``TEXT``.

    Per-item throughput is expressed as a realtime factor using
    ``audio_seconds`` (the B=1 clip length), assuming every item yields a clip
    of that length since all items synthesize the same text.
    """
    print("\n--- Full Batched End-to-End ---")
    from batched_inference import batched_generate

    for bs in batch_sizes:
        texts = [TEXT] * bs
        ms_b = timeit(lambda: batched_generate(
            texts, model, vocoder, tokenizer,
            enc['prompt_tokens'], enc['prompt_features'], enc['prompt_features_lens'],
            enc['prompt_rms'], num_steps=num_steps,
        ), n=5, warmup=1)
        per_item = ms_b / bs
        speed_x = audio_seconds / (per_item / 1000)
        print(f"  B={bs:3d}: {ms_b:7.1f}ms total, {per_item:6.2f}ms/item, {speed_x:6.0f}x realtime/item")


def main():
    """Run every profiling section and print a latency/throughput report.

    Sections, in order:
      1. End-to-end generation at B=1.
      2. Per-stage breakdown at B=1.
      3. fm_decoder batch scaling.
      4. Batched end-to-end generation.

    Requires a CUDA device.
    """
    print("=" * 70)
    print(f"LuxTTS Profiling v2 — {torch.cuda.get_device_name()}")
    print("=" * 70)

    tts = LuxTTS(device='cuda')
    enc = tts.encode_prompt(REF_AUDIO, duration=5)
    model = tts.model
    vocoder = tts.vocos
    tokenizer = tts.tokenizer
    device = next(model.parameters()).device

    e2e_ms = _profile_end_to_end(enc, model, vocoder, tokenizer)

    # --- Stage breakdown ---
    num_steps = 4
    t_shift = 0.5
    guidance = 3.0
    speed = 1.3
    print(f"\n--- Stage Breakdown (B=1, steps={num_steps}) ---")

    # 1. Tokenize
    ms_tok = timeit(lambda: tokenizer.texts_to_token_ids([TEXT]))
    print(f"  Tokenize:         {ms_tok:.2f}ms")
    tokens = tokenizer.texts_to_token_ids([TEXT])

    # 2a. Text encoder + duration prediction
    ms_text = timeit(lambda: model.forward_text_inference_ratio_duration(
        tokens=tokens, prompt_tokens=enc['prompt_tokens'],
        prompt_features_lens=enc['prompt_features_lens'], speed=speed,
    ))
    print(f"  Text encode+dur:  {ms_text:.2f}ms")

    # 2b. Decoder conditions and initial noise
    text_condition, padding_mask, speech_condition, x0 = _build_decoder_inputs(
        enc, model, tokens, speed, device)
    _, num_frames, feat_dim = text_condition.shape
    print(f"  → num_frames={num_frames}, feat_dim={feat_dim}")

    # 2c. A single forward_fm_decoder call at the first solver timestep
    t_cur = get_time_steps(0.0, 1.0, num_steps, t_shift, device)[0]
    gs = torch.tensor(guidance, device=device)
    ms_fm = timeit(lambda: model.forward_fm_decoder(
        t=t_cur, xt=x0, text_condition=text_condition,
        speech_condition=speech_condition, padding_mask=padding_mask,
        guidance_scale=gs,
    ))
    print(f"  fm_decoder (1x):  {ms_fm:.2f}ms")
    print(f"  fm_decoder ({num_steps}x):  {ms_fm * num_steps:.2f}ms ({num_steps} steps)")

    # 2d. Full solver loop
    def run_solver():
        return model.solver.sample(
            x=x0, text_condition=text_condition, speech_condition=speech_condition,
            padding_mask=padding_mask, num_step=num_steps, guidance_scale=guidance,
            t_shift=t_shift,
        )

    ms_solver = timeit(run_solver)
    print(f"  Solver ({num_steps} steps): {ms_solver:.2f}ms")

    # 2e. Vocoder: solver output has its last two axes swapped and is divided
    # by 0.1 before decoding.
    with torch.inference_mode():
        pred_perm = run_solver().permute(0, 2, 1) / 0.1
    ms_voc = timeit(lambda: vocoder.decode(pred_perm).clamp(-1, 1))
    # pred_perm is an inference tensor, so decode it under inference_mode too.
    with torch.inference_mode():
        wav = vocoder.decode(pred_perm).clamp(-1, 1)
    adur = _audio_seconds(wav)
    print(f"  Vocoder:          {ms_voc:.2f}ms → {adur:.2f}s audio")

    print("\n  === SUMMARY (B=1) ===")
    # Use the measured end-to-end latency at the same step count as the total
    # (falling back to the sum of the measured stages) rather than a guess.
    total_ms = e2e_ms.get(num_steps, ms_tok + ms_text + ms_solver + ms_voc)
    print(f"  fm_decoder is {ms_fm * num_steps / total_ms:.0%} of total time")

    # Capture the peak so far (model load + sections above) before the
    # batch-scaling section starts resetting the peak counter per batch size.
    peak_gb = _peak_vram_gb()

    # fm_decoder consumes xt, text and speech conditions concatenated on the
    # feature axis.
    xt_input = torch.cat([x0, text_condition, speech_condition], dim=2)
    peak_gb = max(peak_gb, _profile_fm_decoder_batching(
        model, xt_input, padding_mask, t_cur, gs, adur, num_frames))

    _profile_batched_e2e(enc, model, vocoder, tokenizer, adur, num_steps)
    peak_gb = max(peak_gb, _peak_vram_gb())
    print(f"\nPeak VRAM: {peak_gb:.2f} GB")

if __name__ == "__main__":
    # Run the profiler only when executed as a script, not when imported.
    main()
