import os, json, shutil
from pathlib import Path
import soundfile as sf
import numpy as np
from omegaconf import OmegaConf
from nemo.collections.asr.models import ClusteringDiarizer

# Working directories for the diarization pipeline.
WORK = Path("/home/ubuntu/cleanup_pipeline")
AUDIO = WORK / "vocals_mono16k.wav"       # 16 kHz mono input fed to the diarizer
OUT   = WORK / "diarization_out"          # NeMo diarizer output (RTTM, intermediates)
FINAL = WORK / "final_speaker_tracks"     # per-speaker WAVs + summary JSON

# parents=True: also create WORK itself if it does not exist yet, instead of
# failing with FileNotFoundError on a fresh machine.
OUT.mkdir(parents=True, exist_ok=True)
FINAL.mkdir(parents=True, exist_ok=True)

# Write the single-entry NeMo manifest (one JSON object per line) that tells
# the diarizer which file to process; None fields let NeMo infer them.
manifest_path = WORK / "input_manifest.json"
manifest_entry = {
    "audio_filepath": str(AUDIO),
    "offset": 0,
    "duration": None,
    "label": "infer",
    "text": "-",
    "num_speakers": None,
    "rttm_filepath": None,
    "uem_filepath": None,
}
with open(manifest_path, "w") as f:
    f.write(json.dumps(manifest_entry) + "\n")

# Clustering parameters: let NeMo estimate the speaker count (no oracle),
# capped at 10 speakers.
_clustering_params = {
    "oracle_num_speakers": False,
    "max_num_speakers": 10,
    "enhanced_count_thres": 80,
    "max_rp_threshold": 0.25,
    "sparse_search_volume": 10,
    "maj_vote_spk_count": False,
}

# Voice-activity detection (multilingual MarbleNet) with median smoothing.
_vad_config = {
    "model_path": "vad_multilingual_marblenet",
    "parameters": {
        "window_length_in_sec": 0.15,
        "shift_length_in_sec": 0.01,
        "smoothing": "median",
        "overlap": 0.5,
        "onset": 0.8,
        "offset": 0.6,
        "pad_onset": 0.05,
        "pad_offset": -0.1,
        "min_duration_on": 0.2,
        "min_duration_off": 0.2,
    },
}

# Multi-scale TitaNet speaker embeddings, all scales weighted equally.
_embeddings_config = {
    "model_path": "titanet_large",
    "parameters": {
        "window_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5],
        "shift_length_in_sec": [0.75, 0.625, 0.5, 0.375, 0.25],
        "multiscale_weights": [1, 1, 1, 1, 1],
        "save_embeddings": False,
    },
}

cfg = OmegaConf.create({
    "device": "cuda",
    "verbose": True,
    "num_workers": 4,
    "sample_rate": 16000,
    "batch_size": 64,
    "diarizer": {
        "manifest_filepath": str(manifest_path),
        "out_dir": str(OUT),
        "oracle_vad": False,
        "ignore_overlap": True,
        "clustering": {"parameters": _clustering_params},
        "vad": _vad_config,
        "speaker_embeddings": _embeddings_config,
    },
})

# Run the full clustering-diarization pipeline (VAD -> speaker embeddings ->
# clustering); all outputs, including RTTM files, land under cfg.diarizer.out_dir.
print("=== Starting NeMo ClusteringDiarizer ===")
diarizer = ClusteringDiarizer(cfg=cfg)
diarizer.diarize()

# Locate the RTTM produced by the diarizer. Sort the glob results so the
# chosen file is deterministic when more than one RTTM exists (rglob order
# is filesystem-dependent).
rttm_files = sorted(OUT.rglob("*.rttm"))
if not rttm_files:
    # Dump whatever the diarizer did write, to aid debugging.
    for p in OUT.rglob("*"):
        print(f"  found: {p}")
    raise FileNotFoundError("No RTTM output found")

rttm_path = rttm_files[0]
print(f"\n=== RTTM output: {rttm_path} ===")

# Parse the RTTM. Relevant fields of a SPEAKER row:
#   parts[3] = start (sec), parts[4] = duration (sec), parts[7] = speaker label.
segments = []  # list of (speaker_label, start_sec, end_sec)
with open(rttm_path) as f:
    for line in f:
        parts = line.split()
        # Length guard: a blank line would make parts[0] raise IndexError,
        # and a truncated row would break the parts[7] access below.
        if len(parts) >= 8 and parts[0] == "SPEAKER":
            start = float(parts[3])
            dur = float(parts[4])
            segments.append((parts[7], start, start + dur))

speakers = sorted({spk for spk, _, _ in segments})
print(f"\nDetected {len(speakers)} speakers: {speakers}")
print(f"Total segments: {len(segments)}")

# Load the original (pre-downmix) separated vocals; the per-speaker tracks
# below are cut from this file, not from the 16 kHz diarization input.
vocals_orig, sr_orig = sf.read(str(WORK / "htdemucs" / "clip_00" / "vocals.wav"))
# shape[0] is the sample count for both mono (1-D) and multi-channel
# (samples, channels) arrays, so no branch on ndim is needed.
n_samples = vocals_orig.shape[0]

# Write one WAV per detected speaker: a buffer of silence with that speaker's
# diarized segments copied in verbatim from the original vocals. Slicing the
# first axis works for both mono and multi-channel audio.
for spk in speakers:
    spk_segs = [(start, end) for label, start, end in segments if label == spk]
    track = np.zeros_like(vocals_orig)
    for start_sec, end_sec in spk_segs:
        lo = int(start_sec * sr_orig)
        hi = min(int(end_sec * sr_orig), n_samples)  # clamp to file length
        track[lo:hi] = vocals_orig[lo:hi]

    out_path = FINAL / f"{spk}_clean.wav"
    sf.write(str(out_path), track, sr_orig)
    total_speech = sum(end - start for start, end in spk_segs)
    print(f"  {spk}: {len(spk_segs)} segments, {total_speech:.1f}s speech -> {out_path}")

# Build a machine-readable summary mirroring the per-speaker tracks above,
# then persist it and print a final human-readable report.
summary = {"speakers": {}}
for spk in speakers:
    own = [(start, end) for label, start, end in segments if label == spk]
    summary["speakers"][spk] = {
        "num_segments": len(own),
        "total_speech_sec": round(sum(end - start for start, end in own), 2),
        "output_file": str(FINAL / f"{spk}_clean.wav"),
        "segments": [{"start": round(start, 3), "end": round(end, 3)} for start, end in own],
    }

with open(FINAL / "diarization_summary.json", "w") as f:
    f.write(json.dumps(summary, indent=2))

print(f"\n=== Done! Final tracks in {FINAL} ===")
for spk, info in summary["speakers"].items():
    print(f"  {spk}: {info['total_speech_sec']}s across {info['num_segments']} segments")
