"""
Compare VibeVoice vs Fish S2 Pro vs Sooktam-2 on Modi Hindi voice.
Same text, same reference audio, measure quality + speed.
"""

import os
import time
import torch
import soundfile as sf
import numpy as np

# All generated audio samples from every model land in this directory.
os.makedirs("/home/ubuntu/comparison_samples", exist_ok=True)

# Reference voice clip and its transcript; every model clones this voice.
REF_AUDIO = "/home/ubuntu/vibevoice/demo/voices/modi.wav"
REF_TEXT = "मेरे प्यारे देशवासियों, मुझे सीतापुर के ओजस्वी ने लिखा है कि अमृत महोत्सव से जुड़ी चर्चाएं उन्हें खूब पसंद आ रही हैं।"

# Hindi test sentences synthesized by each model; the key labels the output file
# (e.g. fish_medium.wav, sooktam_speech.wav, vibevoice_medium.wav).
TEXTS = {
    "medium": "मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.",
    "speech": "भारत आज दुनिया की पाँचवीं सबसे बड़ी अर्थव्यवस्था है. हमारे युवाओं की ऊर्जा, हमारे वैज्ञानिकों की प्रतिभा, और हमारे किसानों की मेहनत, यही हमारी असली ताकत है.",
}


def test_fish_s2():
    """Benchmark Fish Audio S2 Pro on each sentence in TEXTS.

    Generation is delegated to the ``python3 -m fish_speech.inference`` CLI in a
    subprocess; for each label this prints audio duration, wall-clock generation
    time, and real-time factor (RTF), and writes
    /home/ubuntu/comparison_samples/fish_<label>.wav.  Any failure is reported
    rather than raised so the remaining models still run.
    """
    print("\n" + "=" * 60)
    print("FISH AUDIO S2 PRO")
    print("=" * 60)
    try:
        # Availability probe only: raises ImportError (caught below) when the
        # fish_speech package is not installed.  Importing the module itself —
        # not specific names — avoids a spurious failure if the package layout
        # differs from what we expect.
        import fish_speech.inference  # noqa: F401
        import subprocess

        print("Downloading Fish S2 Pro model...")
        for label, text in TEXTS.items():
            # Compute the output path once instead of re-building the f-string
            # for the command, the existence check, and the read.
            out_path = f"/home/ubuntu/comparison_samples/fish_{label}.wav"
            t0 = time.perf_counter()
            result = subprocess.run([
                "python3", "-m", "fish_speech.inference",
                "--model", "fishaudio/s2-pro",
                "--text", text,
                "--reference", REF_AUDIO,
                "--output", out_path,
            ], capture_output=True, text=True, timeout=120)
            gen = time.perf_counter() - t0

            if os.path.exists(out_path):
                data, sr = sf.read(out_path)
                dur = len(data) / sr
                # Guard against a zero-length wav so RTF can't divide by zero.
                rtf = gen / dur if dur > 0 else float("inf")
                print(f"  {label}: {dur:.2f}s audio | {gen:.2f}s gen | RTF={rtf:.3f}x")
            else:
                print(f"  {label}: FAILED - {result.stderr[:200]}")
    except Exception as e:
        print(f"  Fish S2 Pro failed: {e}")
        print("  Trying alternative approach...")
        try:
            # Secondary probe: the lower-level text2semantic entry point.
            from fish_speech.models.text2semantic.inference import launch as fish_launch  # noqa: F401
            print("  fish_speech models available, but need setup")
        except Exception as e2:
            print(f"  Alternative also failed: {e2}")


def test_sooktam2():
    """Benchmark BharatGen Sooktam-2 voice cloning on each sentence in TEXTS.

    Runs the repo's setup-cls.sh if present, loads the model via
    ``transformers.AutoModel`` with remote code enabled, then synthesizes each
    test sentence and prints duration, generation time, RTF, and VRAM usage.
    Failures are printed (with traceback) instead of raised so the other model
    benchmarks still run.
    """
    print("\n" + "=" * 60)
    print("SOOKTAM-2 (BharatGen)")
    print("=" * 60)
    try:
        import sys
        sys.path.insert(0, "/home/ubuntu/sooktam2")

        # One-time classifier setup shipped with the repo, if it exists.
        setup_script = "/home/ubuntu/sooktam2/setup-cls.sh"
        if os.path.exists(setup_script):
            print("Running setup-cls.sh...")
            import subprocess
            proc = subprocess.run(
                ["bash", "setup-cls.sh"],
                cwd="/home/ubuntu/sooktam2",
                capture_output=True,
                text=True,
                timeout=120,
            )
            if proc.returncode != 0:
                print(f"  setup-cls.sh stderr: {proc.stderr[:300]}")

        from transformers import AutoModel

        print("Loading Sooktam-2 model...")
        model = AutoModel.from_pretrained(
            "/home/ubuntu/sooktam2",
            trust_remote_code=True,
        )
        print(f"  Loaded. VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")

        for label, text in TEXTS.items():
            start = time.perf_counter()
            wav, sr, _ = model.infer(
                ref_file=REF_AUDIO,
                ref_text=REF_TEXT,
                gen_text=text,
                tokenizer="cls",
                cls_language="hindi",
                file_wave=f"/home/ubuntu/comparison_samples/sooktam_{label}.wav",
            )
            elapsed = time.perf_counter() - start
            audio_sec = len(wav) / sr
            # 999 is a sentinel RTF for degenerate (zero-length) output.
            rtf = 999 if audio_sec <= 0 else elapsed / audio_sec
            print(f"  {label}: {audio_sec:.2f}s audio | {elapsed:.2f}s gen | RTF={rtf:.3f}x | VRAM={torch.cuda.memory_allocated()/1e9:.1f}GB")

        # Free GPU memory before the next model loads.
        del model
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"  Sooktam-2 failed: {e}")
        import traceback
        traceback.print_exc()


def test_vibevoice():
    """Benchmark VibeVoice 1.5B voice cloning on each sentence in TEXTS.

    Loads the model with flash-attention-2 (falling back to SDPA), does one
    warmup generation, then times each test sentence with CUDA synchronization
    around generate() for accurate wall-clock numbers.  Prints duration,
    generation time, RTF, and VRAM; saves audio to
    /home/ubuntu/comparison_samples/vibevoice_<label>.wav.  Failures are
    printed (with traceback) instead of raised.
    """
    print("\n" + "=" * 60)
    print("VIBEVOICE 1.5B (our current best)")
    print("=" * 60)
    try:
        from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
        from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

        processor = VibeVoiceProcessor.from_pretrained("microsoft/VibeVoice-1.5B")
        try:
            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                "microsoft/VibeVoice-1.5B", torch_dtype=torch.bfloat16, device_map="cuda",
                attn_implementation="flash_attention_2")
        except Exception:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt /
            # SystemExit.  Fall back to SDPA when flash-attn is unavailable.
            model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                "microsoft/VibeVoice-1.5B", torch_dtype=torch.bfloat16, device_map="cuda",
                attn_implementation="sdpa")
        model.eval()
        model.set_ddpm_inference_steps(num_steps=10)

        def _to_cuda(batch):
            # Move every tensor in the processor output onto the GPU, in place.
            for k, v in batch.items():
                if torch.is_tensor(v):
                    batch[k] = v.to("cuda")
            return batch

        # Warmup run so the timed runs don't pay one-off CUDA init costs.
        inp = _to_cuda(processor(text=["Speaker 1: test."], voice_samples=[[REF_AUDIO]],
                                 padding=True, return_tensors="pt", return_attention_mask=True))
        _ = model.generate(**inp, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
                           generation_config={"do_sample": False}, verbose=False, is_prefill=True, show_progress_bar=False)

        for label, text in TEXTS.items():
            full_text = f"Speaker 1: {text}"
            inp = _to_cuda(processor(text=[full_text], voice_samples=[[REF_AUDIO]],
                                     padding=True, return_tensors="pt", return_attention_mask=True))

            torch.manual_seed(42)  # fixed seed so runs are comparable
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            out = model.generate(**inp, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
                                 generation_config={"do_sample": False}, verbose=False, is_prefill=True, show_progress_bar=False)
            torch.cuda.synchronize()
            gen = time.perf_counter() - t0

            if out.speech_outputs[0] is not None:
                path = f"/home/ubuntu/comparison_samples/vibevoice_{label}.wav"
                processor.save_audio(out.speech_outputs[0], output_path=path)
                # VibeVoice outputs 24 kHz audio; last dim is the sample count.
                dur = out.speech_outputs[0].shape[-1] / 24000
                rtf = gen / dur if dur > 0 else float("inf")
                print(f"  {label}: {dur:.2f}s audio | {gen:.2f}s gen | RTF={rtf:.3f}x | VRAM={torch.cuda.memory_allocated()/1e9:.1f}GB")
            else:
                # Previously this case was skipped silently; report it.
                print(f"  {label}: FAILED - no audio produced")

        # Free GPU memory before the next model loads.
        del model
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"  VibeVoice failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    print("COMPARISON: VibeVoice vs Fish S2 Pro vs Sooktam-2")
    print(f"Reference audio: {REF_AUDIO}")
    # Fixed: was an f-string with no placeholders (ruff F541).
    print("Language: Hindi")

    # In-process GPU models run first, each followed by an explicit cache
    # flush so the next model starts with as much free VRAM as possible.
    test_sooktam2()
    torch.cuda.empty_cache()

    test_vibevoice()
    torch.cuda.empty_cache()

    # Fish runs in a subprocess, so no in-process VRAM to reclaim.
    test_fish_s2()

    print("\n" + "=" * 60)
    print("All samples saved in /home/ubuntu/comparison_samples/")
    print("=" * 60)
