#!/usr/bin/env python3
"""
Build combined dataset.jsonl for VibeVoice full fine-tuning.
Merges Modi data + all soprano_data Hindi sources.

VibeVoice format: {"text": "Speaker 0: <text>", "audio": "/path/to/wav"}
"""

import csv
import json
import os
import wave
import random

SOPRANO_ROOT = "/home/ubuntu/soprano_data"
MODI_JSONL = "/home/ubuntu/modi_processed/dataset.jsonl"
OUTPUT_JSONL = "/home/ubuntu/modi_processed/dataset_combined.jsonl"

MIN_DURATION = 1.0   # seconds; shorter clips are skipped
MAX_DURATION = 30.0  # seconds; longer clips are skipped


def get_wav_duration(wav_path: str) -> float:
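    """Return duration in seconds, or 0.0 for unreadable files (filtered out by the duration check)."""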
    try:
        with wave.open(wav_path, "rb") as wf:
            return wf.getnframes() / wf.getframerate()
    except Exception:
        return 0.0


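# Illustrative metadata.csv rows (columns inferred from the loader below):
#   filename|text|duration
#   0001|नमस्ते दोस्तों|3.2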
def load_delimited_metadata(metadata_path: str, wavs_dir: str, delimiter: str = "|"):
    """Load a metadata.csv with the given delimiter.

    Pipe-delimited sources: google_tts, sarvam, polly; rasa_hindi is comma-delimited.
    Expected columns: filename, text, duration.
    """
    entries = []
    with open(metadata_path, "r", encoding="utf-8") as f:
        reader = csv.DictReader(f, delimiter=delimiter)
        for row in reader:
            filename = (row.get("filename") or "").strip()
            text = (row.get("text") or "").strip()
            # Missing or malformed durations are treated as 0 and filtered out below.
            try:
                duration = float(row.get("duration") or 0)
            except ValueError:
                continue
            if not filename or not text:
                continue
            if duration < MIN_DURATION or duration > MAX_DURATION:
                continue
            wav_path = os.path.join(wavs_dir, filename)
            if not wav_path.endswith(".wav"):
                wav_path += ".wav"
            if not os.path.exists(wav_path):
                continue
            entries.append({"text": f"Speaker 0: {text}", "audio": wav_path})
    return entries


def load_iisc_data(gender: str):
    """Load IISc SYSPIN data."""
    base = f"{SOPRANO_ROOT}/IISc_SYSPIN_Data/IISc_SYSPINProject_Hindi_{gender}_Spk001_HC"
    transcript_file = f"{base}/IISc_SYSPINProject_Hindi_{gender}_Spk001_HC_Transcripts.json"
    wav_dir = f"{base}/wav"

    if not os.path.exists(transcript_file) or not os.path.isdir(wav_dir):
        return []

    with open(transcript_file, encoding="utf-8") as f:
        data = json.load(f)

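    # The transcript JSON appears in more than one shape; normalize everything
    # to {wav_name: info}. Illustrative (assumed) shapes:
    #   {"Transcripts": {"hin_0001": {"Transcript": "..."}}}
    #   [{"Transcripts": [["hin_0001", "..."], ...]}]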
    transcripts = {}
    if isinstance(data, list):
        for item in data:
            if isinstance(item, dict) and "Transcripts" in item:
                transcripts = item["Transcripts"]
                break
    elif isinstance(data, dict) and "Transcripts" in data:
        transcripts = data["Transcripts"]

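    # Flatten list-form transcripts (pairs or one-entry dicts) into a single dict.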
    if isinstance(transcripts, list):
        trans_dict = {}
        for item in transcripts:
            if isinstance(item, (list, tuple)) and len(item) == 2:
                trans_dict[item[0]] = item[1]
            elif isinstance(item, dict):
                for k, v in item.items():
                    trans_dict[k] = v
        transcripts = trans_dict

    entries = []
    for wav_name, info in transcripts.items():
        text = info.get("Transcript", "") if isinstance(info, dict) else str(info)
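        # strip whitespace and any trailing "|" delimiter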
        text = text.strip().rstrip("|")
        if not text:
            continue
        wav_path = os.path.join(wav_dir, wav_name)
        if not wav_path.endswith(".wav"):
            wav_path += ".wav"
        if not os.path.exists(wav_path):
            continue
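        # the transcript JSON carries no duration, so measure it from the wav header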
        dur = get_wav_duration(wav_path)
        if dur < MIN_DURATION or dur > MAX_DURATION:
            continue
        entries.append({"text": f"Speaker 0: {text}", "audio": wav_path})

    return entries


def load_modi_data():
    """Load existing Modi dataset."""
    entries = []
    with open(MODI_JSONL, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
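            # keep only entries whose audio file still exists on disk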
            if os.path.exists(entry["audio"]):
                entries.append(entry)
    return entries


def main():
    all_entries = []

    # 1. Modi data
    print("Loading Modi data...")
    modi = load_modi_data()
    print(f"  Modi: {len(modi)} entries")
    all_entries.extend(modi)

    # 2. Pipe-delimited soprano data (google_tts_*, polly_kajal, sarvam_data)
    pipe_dirs = []
    for d in sorted(os.listdir(SOPRANO_ROOT)):
        full_path = os.path.join(SOPRANO_ROOT, d)
        metadata = os.path.join(full_path, "metadata.csv")
        wavs_dir = os.path.join(full_path, "wavs")
        if os.path.isfile(metadata) and os.path.isdir(wavs_dir):
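            # rasa_hindi is comma-delimited and handled separately below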
            if d == "rasa_hindi":
                continue
            pipe_dirs.append((d, metadata, wavs_dir))

    for name, metadata, wavs_dir in pipe_dirs:
        print(f"Loading {name}...")
        entries = load_delimited_metadata(metadata, wavs_dir, delimiter="|")
        print(f"  {name}: {len(entries)} entries")
        all_entries.extend(entries)

    # 3. Rasa Hindi (comma-delimited)
    rasa_meta = os.path.join(SOPRANO_ROOT, "rasa_hindi", "metadata.csv")
    rasa_wavs = os.path.join(SOPRANO_ROOT, "rasa_hindi", "wavs")
    if os.path.isfile(rasa_meta) and os.path.isdir(rasa_wavs):
        print("Loading rasa_hindi...")
        entries = load_delimited_metadata(rasa_meta, rasa_wavs, delimiter=",")
        print(f"  rasa_hindi: {len(entries)} entries")
        all_entries.extend(entries)

    # 4. IISc SYSPIN data
    for gender in ["Female", "Male"]:
        print(f"Loading IISc {gender}...")
        entries = load_iisc_data(gender)
        print(f"  IISc {gender}: {len(entries)} entries")
        all_entries.extend(entries)

    # Shuffle deterministically (fixed seed) so reruns produce the same order
    random.seed(42)
    random.shuffle(all_entries)

    # Write JSONL; ensure_ascii=False keeps Devanagari text human-readable
    with open(OUTPUT_JSONL, "w", encoding="utf-8") as f:
        for entry in all_entries:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    print(f"\nTotal: {len(all_entries)} entries written to {OUTPUT_JSONL}")


if __name__ == "__main__":
    main()
