"""From pyannote diarization, build:
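- speaker_XX_clean.wav for each speaker (their turns MINUS every overlap window)
- overlaps/overlap_NNNN.wav with a LOOSE cut (earliest start -> latest end of the
  overlapping turns; chained overlaps are merged into one window)
- manifest.json summarizing both outputs for the next stage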
"""
import json
from pathlib import Path
import numpy as np
import soundfile as sf

HERE = Path(__file__).parent

# Load the diarization job output. The context manager closes the file handle
# deterministically instead of leaving it to garbage collection.
with open(HERE / "diarization.json", encoding="utf-8") as fh:
    JOB = json.load(fh)
# Each turn is a dict read below via its "start", "end" and "speaker" keys;
# start/end are multiplied by the sample rate later, so presumably seconds.
turns = JOB["output"]["diarization"]

# always_2d=False: mono audio comes back 1-D, multichannel audio 2-D with
# frames on the first axis, so len() is the frame count either way.
audio, sr = sf.read(HERE / "clip.wav", always_2d=False)
n = len(audio)
print(f"sr={sr}, samples={n}, dur={n/sr:.1f}s")
print(f"turns: {len(turns)}")

# ---- 1. find overlap windows (loose cut) ----
# Walk turns in start order and pair each one with the later-starting turns
# that begin before it ends. Every cross-speaker pair contributes a "loose"
# window covering both turns in full: earliest start to latest end.
turns_sorted = sorted(turns, key=lambda turn: turn["start"])
overlap_windows = []  # entries: (loose_start, loose_end, sorted speaker pair)
for idx, first in enumerate(turns_sorted):
    for second in turns_sorted[idx + 1:]:
        if second["start"] >= first["end"]:
            # Starts are non-decreasing, so no later turn can overlap `first`.
            break
        if first["speaker"] != second["speaker"]:
            pair = sorted([first["speaker"], second["speaker"]])
            overlap_windows.append(
                (min(first["start"], second["start"]),
                 max(first["end"], second["end"]),
                 pair)
            )

# merge overlapping loose windows so we don't double-extract the same region
# (windows that merely touch are merged too; speaker lists are unioned).
overlap_windows.sort()
merged = []  # mutable [start, end, speakers] triples, extended in place
for win_start, win_end, win_speakers in overlap_windows:
    if merged and win_start <= merged[-1][1]:
        last = merged[-1]
        last[1] = max(last[1], win_end)
        last[2] = sorted(set(last[2]).union(win_speakers))
        continue
    # Starts after the previous window ends: open a new one.
    merged.append([win_start, win_end, win_speakers])
merged_total = sum(win[1] - win[0] for win in merged)
print(f"overlap windows (merged): {len(merged)}, "
      f"total {merged_total:.1f}s")

# ---- 2. build per-speaker CLEAN intervals ----
# A speaker's clean region = union of their turns, MINUS any overlap window
def subtract(intervals, holes):
    """Cut every hole out of every interval.

    Args:
        intervals: iterable of (start, end) pairs.
        holes: (start, end) pairs to remove. It is iterated once per
            interval, so it must be re-iterable (e.g. a list).

    Returns:
        A flat list of the surviving (start, end) pieces, in the order of
        ``intervals``. A hole that only touches a piece at an endpoint leaves
        that piece intact.
    """
    result = []
    for start, end in intervals:
        pieces = [(start, end)]
        for hole_start, hole_end in holes:
            remaining = []
            for piece_start, piece_end in pieces:
                disjoint = hole_end <= piece_start or hole_start >= piece_end
                if disjoint:
                    remaining.append((piece_start, piece_end))
                else:
                    # Keep whatever sticks out on the left and/or right.
                    if hole_start > piece_start:
                        remaining.append((piece_start, hole_start))
                    if hole_end < piece_end:
                        remaining.append((hole_end, piece_end))
            pieces = remaining
        result += pieces
    return result

speakers = sorted(set(turn["speaker"] for turn in turns))
holes = [(win[0], win[1]) for win in merged]
clean_per_speaker = {}
for speaker in speakers:
    own_turns = sorted(
        (turn["start"], turn["end"]) for turn in turns if turn["speaker"] == speaker
    )
    # Coalesce this speaker's own turns wherever they overlap or touch, so each
    # continuous stretch is one interval before the overlap windows are cut out.
    coalesced = []
    for start, end in own_turns:
        if coalesced and start <= coalesced[-1][1]:
            coalesced[-1] = (coalesced[-1][0], max(coalesced[-1][1], end))
        else:
            coalesced.append((start, end))
    clean_per_speaker[speaker] = subtract(coalesced, holes)
    clean = clean_per_speaker[speaker]
    lengths = [end - start for start, end in clean]
    print(f"  {speaker}: {len(clean)} clean intervals, {sum(lengths):.1f}s "
          f"(longest={max(lengths, default=0):.1f}s)")

# ---- 3. write per-speaker clean concat wavs ----
def slice_concat(intervals, *, signal=None, rate=None):
    """Concatenate the frames covered by ``intervals`` into one array.

    Args:
        intervals: iterable of (start, end) pairs in seconds. Each bound is
            converted to a frame index with ``round(t * rate)`` and clamped to
            ``[0, len(signal)]``. Pairs that end up empty are skipped.
        signal: array to slice along its first (frame) axis. Defaults to the
            module-level ``audio``.
        rate: sample rate of ``signal``. Defaults to the module-level ``sr``.

    Returns:
        The selected frames concatenated along axis 0, with ``signal``'s dtype.
        When nothing is selected the empty result keeps ``signal``'s trailing
        (channel) dimensions. Multichannel input therefore yields shape
        ``(0, channels)`` rather than a 1-D ``(0,)`` that would disagree with
        the non-empty case.
    """
    if signal is None:
        signal = audio
    if rate is None:
        rate = sr
    total = len(signal)
    chunks = []
    for start, end in intervals:
        i0 = max(0, int(round(start * rate)))
        i1 = min(total, int(round(end * rate)))
        if i1 > i0:
            chunks.append(signal[i0:i1])
    if not chunks:
        return np.zeros((0,) + signal.shape[1:], dtype=signal.dtype)
    return np.concatenate(chunks)

# One 16-bit PCM file per speaker: all of their clean intervals back to back.
for speaker in speakers:
    clean_audio = slice_concat(clean_per_speaker[speaker])
    target = HERE / f"{speaker.lower()}_clean.wav"
    sf.write(target, clean_audio, sr, subtype="PCM_16")
    print(f"wrote {target}  ({len(clean_audio)/sr:.1f}s)")

# ---- 4. write each loose-cut overlap clip ----
# File names use the window's index in `merged`.
# NOTE(review): manifest.json lists an entry for every merged window, including
# any skipped here because they map to no audio frames.
ovdir = HERE / "overlaps"
ovdir.mkdir(exist_ok=True)
written = 0
skipped = []
for i, (s, e, sp) in enumerate(merged):
    i0 = max(0, int(round(s * sr)))
    i1 = min(n, int(round(e * sr)))
    if i1 <= i0:
        # After clamping to the loaded audio the window is empty (e.g. it
        # starts at or after the end of clip.wav), so there is nothing to write.
        skipped.append(i)
        continue
    sf.write(ovdir / f"overlap_{i:04d}.wav", audio[i0:i1], sr, subtype="PCM_16")
    written += 1
# Report what was actually written, not the number of windows.
print(f"wrote {written} overlap clips to {ovdir}")
if skipped:
    print(f"  skipped {len(skipped)} empty window(s): {skipped}")

# ---- 5. dump a manifest for the next stage ----
# Durations are the summed interval lengths (before any clamping to the audio
# length), rounded to 10 ms; overlap bounds are rounded to 1 ms.
clean_files = {}
clean_durations = {}
for speaker in speakers:
    clean_files[speaker] = f"{speaker.lower()}_clean.wav"
    clean_durations[speaker] = round(
        sum(end - start for start, end in clean_per_speaker[speaker]), 2
    )

overlap_entries = [
    {
        "index": idx,
        "start": round(win[0], 3),
        "end": round(win[1], 3),
        "speakers": win[2],
        "file": f"overlaps/overlap_{idx:04d}.wav",
    }
    for idx, win in enumerate(merged)
]

manifest = {
    "speakers": speakers,
    "clean_files": clean_files,
    "clean_durations_s": clean_durations,
    "overlap_count": len(merged),
    "overlap_total_s": round(sum(win[1] - win[0] for win in merged), 2),
    "overlaps": overlap_entries,
}
manifest_path = HERE / "manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2))
print("wrote manifest.json")
