"""Pick the longest single clean SPEAKER_00 turn from clip.wav as ref."""
import json
from pathlib import Path
import soundfile as sf

HERE = Path(__file__).parent
# Inputs live next to this script. Read the diarization JSON through
# Path.read_text so the file handle is closed immediately (a bare
# json.load(open(...)) leaks it until GC) and so decoding is explicitly UTF-8
# rather than the platform's locale default.
JOB = json.loads((HERE / "diarization.json").read_text(encoding="utf-8"))
# List of turn dicts; this script reads their "start", "end", "speaker"
# (and optionally "confidence") keys.
turns = JOB["output"]["diarization"]
# always_2d=False: soundfile returns a 1-D array for mono input; sr is the
# sample rate soundfile reports for the clip.
audio, sr = sf.read(HERE / "clip.wav", always_2d=False)

# Find every span where two *different* speakers are active at once, then
# coalesce those spans into disjoint intervals ordered by start time.
turns_sorted = sorted(turns, key=lambda turn: turn["start"])
overlaps = []
for pos, left in enumerate(turns_sorted):
    for right in turns_sorted[pos + 1:]:
        # turns_sorted is ordered by start: once a later turn begins at or
        # after `left` ends, no subsequent turn can overlap `left` either.
        if right["start"] >= left["end"]:
            break
        if left["speaker"] != right["speaker"]:
            lo = max(left["start"], right["start"])
            hi = min(left["end"], right["end"])
            overlaps.append((lo, hi))

# Merge overlapping or touching spans (tuples sort by start, then end).
overlaps.sort()
merged_ov = []
for lo, hi in overlaps:
    last = merged_ov[-1] if merged_ov else None
    if last is not None and lo <= last[1]:
        merged_ov[-1] = (last[0], max(last[1], hi))
    else:
        merged_ov.append((lo, hi))

# A turn is "clean" if it does not intersect any overlap
def intersects(s, e, holes):
    """Return True if the open interval (s, e) overlaps any interval in *holes*.

    *holes* is an iterable of (start, end) pairs that must be sorted by start.
    Scanning stops at the first hole beginning at or after *e*, because every
    later hole starts even later. Intervals that merely touch (hole end == s
    or hole start == e) do not count as intersecting.
    """
    for hole_start, hole_end in holes:
        if hole_start >= e:
            return False
        # Here hole_start < e, so the hole overlaps iff it ends after s.
        if hole_end > s:
            return True
    return False

TARGET_SPEAKER = "SPEAKER_00"

# Candidate turns: spoken by the target speaker and touching no cross-talk.
clean_sp00 = [t for t in turns if t["speaker"] == TARGET_SPEAKER
              and not intersects(t["start"], t["end"], merged_ov)]
if not clean_sp00:
    # Fail with a clear message instead of an IndexError on clean_sp00[0].
    raise SystemExit(f"No {TARGET_SPEAKER} turn is free of overlapping "
                     "speech; cannot pick a reference.")

# Rank purely by duration, longest first. Confidence is printed only for manual
# inspection and does NOT influence the choice. Python's sort is stable even
# with reverse=True, so equally long turns keep their diarization order.
clean_sp00.sort(key=lambda t: t["end"] - t["start"], reverse=True)
print(f"Top 5 clean {TARGET_SPEAKER} turns:")
for t in clean_sp00[:5]:
    print(f"  {t['start']:.2f}-{t['end']:.2f} ({t['end']-t['start']:.2f}s) "
          f"conf={t.get('confidence',{})}")

best = clean_sp00[0]
# Convert seconds to sample indices; slicing silently clamps to the clip end.
i0 = int(round(best["start"] * sr))
i1 = int(round(best["end"] * sr))
ref = audio[i0:i1]
if len(ref) == 0:
    # Turn lies outside clip.wav (or rounds to zero samples); don't write an
    # empty reference file.
    raise SystemExit(f"Selected turn {best['start']:.2f}-{best['end']:.2f}s "
                     f"yields no samples from clip.wav ({len(audio)/sr:.2f}s).")
ref_path = HERE / "main_speaker_ref.wav"
sf.write(ref_path, ref, sr, subtype="PCM_16")
print(f"\nSaved reference: {ref_path}")
print(f"  range: {best['start']:.2f}-{best['end']:.2f}s ({len(ref)/sr:.2f}s)")
