#!/usr/bin/env python3
"""Pyannote community-1-cloud diarization with overlap removal + tight concatenation."""

import torchaudio
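# Some torchaudio builds lack list_audio_backends, which pyannote.audio may still
# call at import time; patch in a soundfile-only stub before pyannote is imported.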
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

import os, sys, json, subprocess, numpy as np, soundfile as sf

PYANNOTE_API_KEY = os.environ.get("PYANNOTE_API_KEY", "")  # read from the environment instead of hardcoding the key
MODEL = "pyannote/speaker-diarization-community-1-cloud"


def flush_print(*args, **kwargs):
    print(*args, **kwargs, flush=True)


def process(video_id):
    base = f"/home/ubuntu/Speech_maker_pipeline/pawan_kalyan/{video_id}"
    demucs_wav = os.path.join(base, "yt_downloaded_256kbps_demucs.wav")
    mono_16k = os.path.join(base, "demucs_16k.wav")
    out_dir = os.path.join(base, "diarized")
    os.makedirs(out_dir, exist_ok=True)

    # Step 1: Convert to 16kHz mono for Pyannote
    if not os.path.exists(mono_16k):
        flush_print(f"[{video_id}] Converting demucs output to 16kHz mono...")
        subprocess.run([
            "ffmpeg", "-y", "-i", demucs_wav,
            "-ar", "16000", "-ac", "1", mono_16k
        ], capture_output=True, check=True, timeout=120)
        size_mb = os.path.getsize(mono_16k) / (1024*1024)
        flush_print(f"[{video_id}] Converted: {size_mb:.1f}MB")
    else:
        flush_print(f"[{video_id}] 16kHz mono already exists")

    # Step 2: Run Pyannote cloud diarization
    flush_print(f"[{video_id}] Loading Pyannote {MODEL}...")
    from pyannote.audio import Pipeline

    pipeline = Pipeline.from_pretrained(MODEL, token=PYANNOTE_API_KEY)

    flush_print(f"[{video_id}] Running diarization...")
    result = pipeline(mono_16k)

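    # The cloud pipeline may return a wrapper object exposing the annotation as
    # .speaker_diarization; otherwise use the result directly.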
    if hasattr(result, 'speaker_diarization'):
        diarization = result.speaker_diarization
    else:
        diarization = result

    segments = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        segments.append({
            "speaker": speaker,
            "start": round(turn.start, 3),
            "end": round(turn.end, 3),
        })
    segments.sort(key=lambda x: x["start"])

    with open(os.path.join(out_dir, "diarization.json"), "w") as f:
        json.dump(segments, f, indent=2)

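    # Group each speaker's (start, end) segments together.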
    speakers = {}
    for seg in segments:
        speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))

    flush_print(f"[{video_id}] Found {len(speakers)} speaker(s):")
    for spk in sorted(speakers, key=lambda s: sum(e-st for st, e in speakers[s]), reverse=True):
        total = sum(e - s for s, e in speakers[spk])
        flush_print(f"  {spk}: {total:.1f}s ({total/60:.1f}min), {len(speakers[spk])} segments")

    # Step 3: Overlap removal + tight concatenation
    flush_print(f"[{video_id}] Extracting clean speaker tracks (overlaps removed, tight concat)...")
    audio, sr = sf.read(mono_16k, dtype='float32')
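    # Downmix to mono if the decoded audio has more than one channel.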
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    total_samples = len(audio)

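    # Build a boolean per-sample mask for each speaker marking where they are active.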
    speaker_masks = {}
    for spk, segs in speakers.items():
        mask = np.zeros(total_samples, dtype=np.bool_)
        for s, e in segs:
            si, ei = int(s * sr), min(int(e * sr), total_samples)
            mask[si:ei] = True
        speaker_masks[spk] = mask

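    # Count concurrently active speakers per sample; two or more means overlapping speech.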
    active_count = np.zeros(total_samples, dtype=np.int8)
    for mask in speaker_masks.values():
        active_count += mask.astype(np.int8)
    overlap_mask = active_count >= 2

    overlap_secs = np.sum(overlap_mask) / sr
    flush_print(f"[{video_id}] Overlap detected: {overlap_secs:.1f}s ({overlap_secs/60:.1f}min)")

    for spk in sorted(speakers, key=lambda s: np.sum(speaker_masks[s]), reverse=True):
        clean_mask = speaker_masks[spk] & ~overlap_mask

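        # Find contiguous True runs in the mask: a diff of +1 marks a run start, -1 marks a run end.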
        diff = np.diff(clean_mask.astype(np.int8))
        starts = np.where(diff == 1)[0] + 1
        ends = np.where(diff == -1)[0] + 1
        if clean_mask[0]:
            starts = np.concatenate([[0], starts])
        if clean_mask[-1]:
            ends = np.concatenate([ends, [total_samples]])

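        # Tight concatenation: splice the clean runs back-to-back with no gaps between them.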
        chunks = []
        for si, ei in zip(starts, ends):
            chunks.append(audio[si:ei])

        if chunks:
            out = np.concatenate(chunks)
            path = os.path.join(out_dir, f"{spk}.wav")
            sf.write(path, out, sr)
            flush_print(f"  {spk}: {len(out)/sr:.1f}s -> {path}")
        else:
            flush_print(f"  {spk}: no non-overlapping audio left, skipped")

    flush_print(f"[{video_id}] Done!")


if __name__ == "__main__":
    vid = sys.argv[1] if len(sys.argv) > 1 else "42AMTLWpZ9A"
    process(vid)
