"""
Fast VibeVoice inference for Modi's voice.

Optimizations:
  1. flash_attention_2 for faster attention
  2. Pre-cached voice features (process the reference wav once, reuse across calls)
  3. Configurable DDPM steps (default 10 -> try 5)
  4. Batched inference (multiple texts in parallel)
  5. torch.compile on the LM backbone
  6. bf16 throughout
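
Example usage (flag names from parse_args below; the script filename is a
placeholder for wherever this file is saved):

    python fast_tts.py --text "Speaker 1: Namaste." --ddpm_steps 5
    python fast_tts.py --benchmark_steps
    python fast_tts.py --benchmark_batch --compile

Or from Python:

    engine = ModiTTSEngine(ddpm_steps=5)
    result = engine.generate(["Speaker 1: Hello!"])[0]
    engine.save_audio(result["audio"], "out.wav")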
"""

import argparse
import os
import time
from typing import List, Optional

import torch

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


class ModiTTSEngine:
    """Pre-loaded, optimized TTS engine for a single speaker."""

    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-1.5B",
        voice_path: str = "demo/voices/modi.wav",
        device: str = "cuda",
        ddpm_steps: int = 10,
        cfg_scale: float = 1.3,
        compile_model: bool = False,
    ):
        self.device = device
        self.cfg_scale = cfg_scale
        self.ddpm_steps = ddpm_steps

        print(f"[init] Loading processor from {model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(model_path)

        print(f"[init] Loading model with flash_attention_2")
        try:
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map="cuda",
                attn_implementation="flash_attention_2",
            )
        except Exception as e:
            print(f"[init] flash_attention_2 failed ({e}), falling back to sdpa")
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map="cuda",
                attn_implementation="sdpa",
            )

        self.model.eval()
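        # Fewer DDPM steps means fewer diffusion passes in the acoustic
        # decoder, so generation gets roughly proportionally faster at some
        # cost in audio quality.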
        self.model.set_ddpm_inference_steps(num_steps=ddpm_steps)

        if compile_model:
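            # "reduce-overhead" mode captures CUDA graphs to cut per-step
            # launch overhead; expect the first few generate() calls to be
            # slow while compilation happens (hence the warmup run in main()).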
            print("[init] Compiling language model with torch.compile...")
            self.model.model.language_model = torch.compile(
                self.model.model.language_model, mode="reduce-overhead"
            )

        self.voice_path = os.path.abspath(voice_path)
        self._voice_cache = {}
        print(f"[init] Pre-caching voice from {voice_path}")
        self._precache_voice(voice_path)
        print(f"[init] Engine ready. DDPM steps={ddpm_steps}, cfg={cfg_scale}")

    def _precache_voice(self, voice_path: str):
        """Encode the reference voice once and cache processor outputs."""
        dummy_text = "Speaker 1: test"
        inputs = self.processor(
            text=[dummy_text],
            voice_samples=[[voice_path]],
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        self._voice_cache["speech_tensors"] = inputs["speech_tensors"].to(self.device)
        self._voice_cache["speech_masks"] = inputs["speech_masks"].to(self.device)

    def _prepare_inputs(self, texts: List[str]) -> dict:
        """Prepare batched inputs, reusing the pre-cached voice tensors."""
        scripts = []
        for t in texts:
            # The processor expects scripts in "Speaker N: ..." format.
            if not t.startswith("Speaker"):
                t = f"Speaker 1: {t}"
            scripts.append(t)

        voice_samples_batch = [[self.voice_path]] * len(scripts)

        inputs = self.processor(
            text=scripts,
            voice_samples=voice_samples_batch,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )

        # Move everything except the voice tensors to the device; the voice
        # was featurized and uploaded once in _precache_voice, so broadcast
        # the cached copy across the batch instead of re-transferring it.
        # (Every entry uses the same voice, so the shapes line up; swap
        # .expand for .repeat if the model requires contiguous inputs.)
        for k, v in inputs.items():
            if torch.is_tensor(v) and k not in ("speech_tensors", "speech_masks"):
                inputs[k] = v.to(self.device)
        if self._voice_cache:
            cached_speech = self._voice_cache["speech_tensors"]
            cached_masks = self._voice_cache["speech_masks"]
            inputs["speech_tensors"] = cached_speech.expand(
                len(scripts), *cached_speech.shape[1:]
            )
            inputs["speech_masks"] = cached_masks.expand(
                len(scripts), *cached_masks.shape[1:]
            )
        else:
            inputs["speech_tensors"] = inputs["speech_tensors"].to(self.device)
            inputs["speech_masks"] = inputs["speech_masks"].to(self.device)
        return inputs

    @torch.inference_mode()
    def generate(
        self,
        texts: List[str],
        seed: Optional[int] = None,
        ddpm_steps: Optional[int] = None,
        cfg_scale: Optional[float] = None,
    ) -> List[dict]:
        """
        Generate speech for one or more texts (batched).

        Returns list of dicts with keys:
          - audio: torch.Tensor (waveform)
          - duration_s: float
          - generation_time_s: float
          - rtf: float
        """
        if seed is not None:
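            # Seed both CPU and CUDA RNGs so the diffusion decoder's sampling
            # is reproducible across runs.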
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        if ddpm_steps and ddpm_steps != self.ddpm_steps:
            self.model.set_ddpm_inference_steps(num_steps=ddpm_steps)
            active_steps = ddpm_steps
        else:
            active_steps = self.ddpm_steps

        active_cfg = cfg_scale if cfg_scale is not None else self.cfg_scale

        inputs = self._prepare_inputs(texts)

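        # Synchronize before and after timing so perf_counter measures
        # completed GPU work rather than just asynchronous kernel launches.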
        if self.device.startswith("cuda"):
            torch.cuda.synchronize()
        t0 = time.perf_counter()

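        # cfg_scale controls classifier-free guidance: higher values push the
        # output harder toward the conditioning (reference voice and script)
        # at some cost in naturalness; 1.3 is this engine's default.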
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=active_cfg,
            tokenizer=self.processor.tokenizer,
            generation_config={"do_sample": False},
            verbose=False,
            is_prefill=True,
            show_progress_bar=False,
        )

        if self.device.startswith("cuda"):
            torch.cuda.synchronize()
        gen_time = time.perf_counter() - t0

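        # VibeVoice's acoustic decoder outputs 24 kHz mono audio.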
        results = []
        sample_rate = 24000
        for audio in outputs.speech_outputs:
            if audio is not None:
                dur = audio.shape[-1] / sample_rate
                results.append({
                    "audio": audio.cpu(),
                    "duration_s": dur,
                    "generation_time_s": gen_time,
                    "rtf": gen_time / dur if dur > 0 else float("inf"),
                })
            else:
                results.append({
                    "audio": None,
                    "duration_s": 0.0,
                    "generation_time_s": gen_time,
                    "rtf": float("inf"),
                })

        if ddpm_steps and ddpm_steps != self.ddpm_steps:
            self.model.set_ddpm_inference_steps(num_steps=self.ddpm_steps)

        return results

    def save_audio(self, audio_tensor: torch.Tensor, path: str):
        self.processor.save_audio(audio_tensor, output_path=path)


def benchmark_steps(engine: ModiTTSEngine, text: str, steps_list: List[int], output_dir: str):
    """Benchmark different DDPM step counts."""
    print("\n" + "=" * 60)
    print("DDPM STEPS BENCHMARK")
    print("=" * 60)
    os.makedirs(output_dir, exist_ok=True)

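    # Fix the seed so every run is compared on identical noise draws;
    # only the number of diffusion steps varies between runs.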
    for steps in steps_list:
        results = engine.generate([text], seed=42, ddpm_steps=steps)
        r = results[0]
        out_path = os.path.join(output_dir, f"modi_steps_{steps}.wav")
        if r["audio"] is not None:
            engine.save_audio(r["audio"], out_path)
        print(f"  Steps={steps:2d} | Duration={r['duration_s']:.2f}s | GenTime={r['generation_time_s']:.2f}s | RTF={r['rtf']:.3f}x | -> {out_path}")

    print("=" * 60)


def benchmark_batch(engine: ModiTTSEngine, texts: List[str], output_dir: str):
    """Benchmark batched vs sequential inference."""
    print("\n" + "=" * 60)
    print("BATCHING BENCHMARK")
    print("=" * 60)
    os.makedirs(output_dir, exist_ok=True)

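    # Compare N single-text calls against one N-text call; batching amortizes
    # per-step model overhead across texts, at the cost of padded batches.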
    # Sequential
    seq_total = 0
    for i, t in enumerate(texts):
        results = engine.generate([t], seed=42)
        r = results[0]
        seq_total += r["generation_time_s"]
        out_path = os.path.join(output_dir, f"modi_seq_{i}.wav")
        if r["audio"] is not None:
            engine.save_audio(r["audio"], out_path)
        print(f"  Sequential[{i}] | Duration={r['duration_s']:.2f}s | GenTime={r['generation_time_s']:.2f}s | RTF={r['rtf']:.3f}x")

    print(f"  Sequential total: {seq_total:.2f}s")

    # Batched
    results = engine.generate(texts, seed=42)
    batch_time = results[0]["generation_time_s"]
    for i, r in enumerate(results):
        out_path = os.path.join(output_dir, f"modi_batch_{i}.wav")
        if r["audio"] is not None:
            engine.save_audio(r["audio"], out_path)
        print(f"  Batched[{i}]     | Duration={r['duration_s']:.2f}s | GenTime={batch_time:.2f}s")

    print(f"  Batched total:     {batch_time:.2f}s")
    if seq_total > 0:
        print(f"  Speedup:           {seq_total / batch_time:.2f}x")
    print("=" * 60)


def parse_args():
    p = argparse.ArgumentParser(description="Fast Modi TTS Engine")
    p.add_argument("--model_path", default="microsoft/VibeVoice-1.5B")
    p.add_argument("--voice_path", default="demo/voices/modi.wav")
    p.add_argument("--ddpm_steps", type=int, default=10)
    p.add_argument("--cfg_scale", type=float, default=1.3)
    p.add_argument("--compile", action="store_true", help="torch.compile the LM backbone")
    p.add_argument("--output_dir", default="./fast_output")
    p.add_argument("--benchmark_steps", action="store_true", help="Benchmark different DDPM step counts")
    p.add_argument("--benchmark_batch", action="store_true", help="Benchmark batched vs sequential")
    p.add_argument("--text", type=str, default=None, help="Single text to generate (Hindi or English)")
    p.add_argument("--text_file", type=str, default=None, help="File with multiple texts, one per line")
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()


def main():
    args = parse_args()

    engine = ModiTTSEngine(
        model_path=args.model_path,
        voice_path=args.voice_path,
        ddpm_steps=args.ddpm_steps,
        cfg_scale=args.cfg_scale,
        compile_model=args.compile,
    )

    os.makedirs(args.output_dir, exist_ok=True)

    # Default test scripts in Hindi (English glosses in the comments).
    test_texts = [
        # "My dear countrymen, today I want to talk with you about some very important things."
        "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ.",
        # "Digital India has changed the face of the country. The internet is reaching village after village."
        "Speaker 1: डिजिटल इंडिया ने देश की तस्वीर बदल दी है. गाँव गाँव में इंटरनेट पहुँच रहा है.",
        # "Our country is entering a new era, where technology and innovation are becoming our strength."
        "Speaker 1: हमारा देश एक नये दौर में प्रवेश कर रहा है, जहाँ टेक्नोलॉजी और इनोवेशन हमारी ताकत बन रही है.",
    ]

    if args.text:
        test_texts = [args.text]
    elif args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as f:
            test_texts = [line.strip() for line in f if line.strip()]

    if args.benchmark_steps:
        benchmark_steps(engine, test_texts[0], [3, 5, 8, 10, 15, 20], args.output_dir)

    if args.benchmark_batch:
        benchmark_batch(engine, test_texts, args.output_dir)

    if not args.benchmark_steps and not args.benchmark_batch:
        # Warmup: the first generate() call pays one-time costs (CUDA kernel
        # autotuning and, with --compile, torch.compile tracing), so run a
        # throwaway generation before the timed one.
        print("\n[warmup] Running warmup generation...")
        _ = engine.generate(["Speaker 1: test warmup."], seed=0)

        print("\n[generate] Generating speech...")
        results = engine.generate(test_texts, seed=args.seed)
        for i, r in enumerate(results):
            out_path = os.path.join(args.output_dir, f"modi_fast_{i}.wav")
            if r["audio"] is not None:
                engine.save_audio(r["audio"], out_path)
            print(f"  [{i}] Duration={r['duration_s']:.2f}s | GenTime={r['generation_time_s']:.2f}s | RTF={r['rtf']:.3f}x | -> {out_path}")


if __name__ == "__main__":
    main()
