#!/usr/bin/env python3
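"""
Run inference with a trained Hindi Soprano checkpoint and save sample audio files.

Text prompts are taken from entries spread across a validation JSON file, and
each generated clip is written as sample_<i>.wav in the output directory.

Usage:
  python inference.py --checkpoint-dir /path/to/checkpoints --out-dir /path/to/wavs
      [--val-json /path/to/val.json] [--num-samples 5]
"""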
import argparse
import json
import os

def _pick_indices(n_items, num_samples):
    """Return up to ``num_samples`` distinct indices spread across ``range(n_items)``.

    For ``num_samples <= 5`` this reproduces the script's original fixed
    positions (first, quartiles, last), truncated to ``num_samples``, so
    prompts stay comparable with earlier runs. Larger requests are spread
    evenly from the first entry to the last one. Duplicate indices, which
    appear when ``n_items`` is small, are dropped while preserving order.
    """
    if n_items <= 0 or num_samples <= 0:
        return []
    if num_samples <= 5:
        candidates = [0, n_items // 4, n_items // 2, 3 * n_items // 4, n_items - 1][:num_samples]
    else:
        # i * n // (k - 1) reaches n at the last step; clamp to the final index.
        candidates = [min(i * n_items // (num_samples - 1), n_items - 1)
                      for i in range(num_samples)]
    # dict preserves insertion order, so this is an order-preserving dedup.
    return list(dict.fromkeys(candidates))


def _truncate_text(text, max_chars=200):
    """Shorten ``text`` to at most ``max_chars`` characters, preferring a space boundary.

    Text already within the limit is returned unchanged. Otherwise the text is
    cut at the last space inside the first ``max_chars`` characters. If that
    would leave an empty or whitespace-only prompt, the hard cut is used
    instead.
    """
    if len(text) <= max_chars:
        return text
    head = text[:max_chars]
    if " " in head:
        trimmed = head.rsplit(" ", 1)[0]
        if trimmed.strip():
            return trimmed
    return head


def main():
    """Generate sample WAVs from a trained checkpoint using prompts from a val JSON.

    Reads ``--val-json``, selects up to ``--num-samples`` prompts spread across
    it, loads ``SopranoTTS`` from ``--checkpoint-dir``, writes
    ``sample_<i>.wav`` files into ``--out-dir``, and finally prints the size of
    each file that was written.

    Raises:
        SystemExit: if ``--num-samples`` is < 1 or the val JSON has no entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint-dir", type=str, default="/home/ubuntu/soprano_data/hindi_unified/checkpoints",
                        help="Path to trained checkpoint (model + tokenizer + decoder.pth)")
    parser.add_argument("--out-dir", type=str, default="/home/ubuntu/soprano_data/hindi_unified/inference_output",
                        help="Directory to save generated WAVs")
    parser.add_argument("--val-json", type=str, default="/home/ubuntu/soprano_data/hindi_unified/val.json",
                        help="Val JSON to sample prompts from")
    parser.add_argument("--num-samples", type=int, default=5, help="Number of samples to generate")
    args = parser.parse_args()
    if args.num_samples < 1:
        parser.error("--num-samples must be at least 1")

    with open(args.val_json, encoding="utf-8") as f:
        val_data = json.load(f)
    if not val_data:
        raise SystemExit(f"No entries found in {args.val_json}")

    # Sample diverse prompts (text only).
    # NOTE(review): assumes each entry is a sequence whose first element is the
    # transcript text — confirm against the data-prep script that writes val.json.
    indices = _pick_indices(len(val_data), args.num_samples)
    prompts = [val_data[i][0] for i in indices]
    out_paths = [os.path.join(args.out_dir, f"sample_{i}.wav") for i in range(len(prompts))]

    os.makedirs(args.out_dir, exist_ok=True)
    print(f"Loading model from {args.checkpoint_dir} ...")
    # Presumably imported here so that argument parsing (e.g. --help) does not
    # pay the cost of loading the TTS package.
    from soprano import SopranoTTS
    model = SopranoTTS(backend="auto", device="auto", model_path=args.checkpoint_dir)

    print(f"Generating {len(prompts)} samples ...")
    for i, (text, out_path) in enumerate(zip(prompts, out_paths)):
        # Truncate very long text for quicker generation.
        text = _truncate_text(text)
        print(f"  [{i+1}/{len(prompts)}] {text[:60]}... -> {out_path}")
        model.infer(text, out_path=out_path)

    print(f"\nDone. Audios saved to {args.out_dir}")
    for p in out_paths:
        if os.path.isfile(p):
            print(f"  {p} ({os.path.getsize(p)//1024} KB)")

if __name__ == "__main__":
    # Script entry point; main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())
