#!/usr/bin/env python3
"""Download 128kbps, run Demucs, run Pyannote diarization for a single video."""

import os, json, time, requests, re, sys, subprocess, glob
import numpy as np
import soundfile as sf

# Credentials for the vidssave.com web API used by download_128kbps(). Each
# value may be overridden through the environment so the script can run on
# another host (or with a refreshed token) without editing the source.
AUTH = os.environ.get("VIDSSAVE_AUTH", "20250901majwlqo")
DOMAIN = os.environ.get("VIDSSAVE_DOMAIN", "api-ak.vidssave.com")
# Root working directory; every video gets its own sub-folder named after its ID.
BASE_DIR = os.environ.get(
    "PIPELINE_BASE_DIR", "/home/ubuntu/Speech_maker_pipeline/pawan_kalyan"
)

def flush_print(*args, **kwargs):
    """Behave exactly like the builtin ``print`` but flush the stream at once.

    Every progress message in this script goes through here so output shows
    up immediately even when stdout is buffered (e.g. redirected to a file).
    """
    return print(*args, flush=True, **kwargs)

def extract_video_id(url):
    """Return the video ID embedded in *url*, used as the per-video folder name.

    ``...watch?v=ID`` style URLs take the ``v`` query parameter. Anything else
    (``youtu.be/ID``, ``/shorts/ID``, a bare ID) takes the last path segment.
    Query strings, ``#fragments`` and trailing slashes are ignored, so they
    never leak into directory names.
    """
    # Stop at "&" (next parameter) and "#" (fragment such as "#t=30").
    m = re.search(r'[?&]v=([^&#]+)', url)
    if m:
        return m.group(1)
    path = url.split('#', 1)[0].split('?', 1)[0].rstrip('/')
    return path.split('/')[-1]

# ─── STEP 1: Download 128kbps ───
def download_128kbps(url):
    """Download the audio track of *url* (preferring 128 kbps) via the vidssave API.

    The file is stored as ``BASE_DIR/<video_id>/yt_downloaded_128kbps.<ext>``.
    If a completed file with that name already exists it is reused and no
    network request is made.

    Args:
        url: Video URL; the folder name comes from ``extract_video_id``.

    Returns:
        ``(video_id, path_to_audio_file)``.

    Raises:
        Exception: if the parse or download-task API reports failure, no audio
            resource is offered, or the SSE stream ends without a link.
        requests.HTTPError: if the final file download returns an HTTP error.
    """
    vid = extract_video_id(url)
    out_dir = os.path.join(BASE_DIR, vid)
    os.makedirs(out_dir, exist_ok=True)

    # Only a finished download counts. The "." after the prefix excludes the
    # Demucs output ("yt_downloaded_128kbps_demucs.wav") kept in this folder,
    # and ".part" files are remnants of an interrupted download.
    existing = sorted(
        f for f in os.listdir(out_dir)
        if f.startswith("yt_downloaded_128kbps.") and not f.endswith(".part")
    )
    if existing:
        flush_print(f"[{vid}] Already downloaded: {existing[0]}")
        return vid, os.path.join(out_dir, existing[0])

    flush_print(f"[{vid}] Step 1: Parsing video metadata...")
    resp = requests.post("https://api.vidssave.com/api/contentsite_api/media/parse", data={
        "auth": AUTH, "domain": DOMAIN, "origin": "source", "link": url
    }, timeout=30)
    data = resp.json()
    if data.get("status") != 1:
        raise Exception(f"Parse failed: {data}")

    info = data["data"]
    title = info.get("title", "unknown")
    # "or" also covers an explicit null value, which would break "/60" below.
    duration = info.get("duration") or 0
    resources = info.get("resources") or []
    audio_resources = [r for r in resources if r.get("type") == "audio"]
    flush_print(f"[{vid}] Title: {title}")
    flush_print(f"[{vid}] Duration: {duration}s ({duration/60:.1f}min)")

    flush_print(f"[{vid}] Available audio resources:")
    for r in audio_resources:
        flush_print(f"  - {r['quality']} {r['format']} (direct_url: {bool(r.get('download_url'))})")

    # Prefer the 128 kbps stream; otherwise fall back to the first audio resource.
    target = next(
        (r for r in audio_resources if r.get("quality") == "128KBPS"),
        audio_resources[0] if audio_resources else None,
    )
    if target is None:
        raise Exception("No audio resource found")

    quality = target["quality"]
    fmt = target["format"].lower()
    flush_print(f"[{vid}] Selected: {quality} {fmt}")

    if target.get("download_url"):
        dl_url = target["download_url"]
        flush_print(f"[{vid}] Direct download URL available")
    else:
        flush_print(f"[{vid}] Requesting download task...")
        resp = requests.post("https://api.vidssave.com/api/contentsite_api/media/download", data={
            "auth": AUTH, "domain": DOMAIN,
            "request": target["resource_content"], "no_encrypt": "1"
        }, timeout=30)
        dl_data = resp.json()
        if dl_data.get("status") != 1:
            raise Exception(f"Download request failed: {dl_data}")

        task_id = dl_data["data"]["task_id"]
        flush_print(f"[{vid}] Task ID: {task_id}, polling SSE...")

        sse_params = {
            "auth": AUTH,
            "domain": DOMAIN,
            "task_id": task_id,
            "download_domain": "vidssave.com",
            "origin": "content_site",
        }
        dl_url = None
        # Each event arrives as a "data: <json>" line. "with" closes the
        # streaming connection as soon as we stop reading (success/error).
        with requests.get(
            "https://api.vidssave.com/sse/contentsite_api/media/download_query",
            params=sse_params, stream=True, timeout=120,
        ) as resp:
            for line in resp.iter_lines(decode_unicode=True):
                if not line or not line.startswith("data:"):
                    continue
                event_data = json.loads(line[5:].strip())
                if event_data.get("status") == "success":
                    dl_url = event_data.get("download_link")
                    break
                elif event_data.get("status") == "error":
                    raise Exception(f"SSE error: {event_data}")
                elif "progress" in event_data:
                    flush_print(f"[{vid}] Progress: {event_data.get('progress', '?')}%")

        if not dl_url:
            raise Exception("No download link from SSE")

    ext_map = {"opus": "opus", "m4a": "m4a", "webm": "webm", "mp4": "m4a"}
    ext = ext_map.get(fmt, fmt)
    out_path = os.path.join(out_dir, f"yt_downloaded_128kbps.{ext}")
    part_path = out_path + ".part"

    flush_print(f"[{vid}] Downloading file...")
    with requests.get(dl_url, stream=True, timeout=300, allow_redirects=True) as resp:
        resp.raise_for_status()
        with open(part_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    # Publish only after the transfer completed, so an interrupted run never
    # leaves a truncated file that the reuse check at the top would accept.
    os.replace(part_path, out_path)

    size_mb = os.path.getsize(out_path) / (1024 * 1024)
    flush_print(f"[{vid}] Downloaded: {size_mb:.1f}MB -> {out_path}")
    return vid, out_path


# ─── STEP 2: Demucs vocal separation (htdemucs, two stems) ───
def run_demucs(vid, src_path):
    """Separate the vocals of *src_path* with Demucs ``htdemucs`` (two-stem mode).

    The vocals stem is moved to
    ``BASE_DIR/<vid>/yt_downloaded_128kbps_demucs.wav``. If that file already
    exists and is larger than 1000 bytes, it is reused and Demucs is not run.

    Args:
        vid: Video ID; selects the working folder under ``BASE_DIR``.
        src_path: Audio file to separate.

    Returns:
        Path of the vocals WAV.

    Raises:
        Exception: if Demucs exits non-zero or no ``vocals.wav`` is produced.
        subprocess.TimeoutExpired: if Demucs runs longer than one hour.
    """
    import shutil

    folder = os.path.join(BASE_DIR, vid)
    dst = os.path.join(folder, "yt_downloaded_128kbps_demucs.wav")

    if os.path.exists(dst) and os.path.getsize(dst) > 1000:
        flush_print(f"[{vid}] Demucs output already exists: {os.path.getsize(dst)/(1024*1024):.1f}MB")
        return dst

    tmp_dir = os.path.join(folder, "demucs_tmp")
    # Start from an empty scratch dir: a failed earlier run leaves tmp_dir
    # behind, and its stale vocals.wav must not be mistaken for this run's.
    shutil.rmtree(tmp_dir, ignore_errors=True)
    os.makedirs(tmp_dir, exist_ok=True)

    flush_print(f"[{vid}] Step 2: Running Demucs htdemucs on {src_path}...")
    t0 = time.time()

    cmd = [
        # Run Demucs with the interpreter executing this script, so it is
        # resolved from the same environment rather than whatever "python3"
        # happens to be first on PATH.
        sys.executable or "python3", "-m", "demucs",
        "--two-stems", "vocals",
        "-n", "htdemucs",
        "-o", tmp_dir,
        src_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600)
    if result.returncode != 0:
        flush_print(f"[{vid}] Demucs stderr: {result.stderr[-1000:]}")
        raise Exception(f"Demucs failed with code {result.returncode}")

    # Demucs writes <out>/<model>/<input file stem>/<stem>.wav; fall back to a
    # recursive search in case the layout differs.
    track = os.path.splitext(os.path.basename(src_path))[0]
    vocals_path = os.path.join(tmp_dir, "htdemucs", track, "vocals.wav")
    if not os.path.exists(vocals_path):
        candidates = sorted(glob.glob(os.path.join(tmp_dir, "**", "vocals.wav"), recursive=True))
        if not candidates:
            raise Exception(f"vocals.wav not found in {tmp_dir}")
        vocals_path = candidates[0]

    os.replace(vocals_path, dst)
    shutil.rmtree(tmp_dir, ignore_errors=True)

    elapsed = time.time() - t0
    size_mb = os.path.getsize(dst) / (1024 * 1024)
    flush_print(f"[{vid}] Demucs done: {size_mb:.1f}MB in {elapsed:.0f}s ({elapsed/60:.1f}min)")
    return dst


# ─── STEP 3: Pyannote Diarization ───
def _convert_to_16k_mono(vid, src_wav, dst_wav):
    """Write a 16 kHz mono copy of *src_wav* to *dst_wav* with ffmpeg.

    Only logs and returns when *dst_wav* already exists.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits non-zero.
        subprocess.TimeoutExpired: if ffmpeg takes longer than 120 s.
    """
    if os.path.exists(dst_wav):
        flush_print(f"[{vid}] 16kHz mono already exists")
        return

    flush_print(f"[{vid}] Step 3: Converting demucs output to 16kHz mono...")
    # Convert into a temporary name (keeping the .wav suffix, from which
    # ffmpeg picks the output format) and rename afterwards, so an
    # interrupted conversion is never reused by the existence check above.
    root, ext = os.path.splitext(dst_wav)
    tmp_wav = f"{root}.partial{ext}"
    subprocess.run([
        "ffmpeg", "-y", "-i", src_wav,
        "-ar", "16000", "-ac", "1", tmp_wav
    ], capture_output=True, check=True, timeout=120)
    os.replace(tmp_wav, dst_wav)
    size_mb = os.path.getsize(dst_wav) / (1024 * 1024)
    flush_print(f"[{vid}] Converted: {size_mb:.1f}MB")


def _overlap_free_tracks(audio, sr, speakers):
    """Cut one track per speaker with every overlapped region removed.

    Args:
        audio: 1-D array of samples.
        sr: Sample rate of *audio* in Hz.
        speakers: Mapping of speaker label -> list of ``(start_s, end_s)`` turns.

    Returns:
        ``(overlap_seconds, tracks)``. ``tracks`` is a list of
        ``(speaker, samples)``, ordered by descending number of samples the
        speaker was active, where ``samples`` is that speaker's overlap-free
        audio concatenated back-to-back. Speakers with no overlap-free audio
        are omitted.
    """
    n = len(audio)
    masks = {}
    for spk, turns in speakers.items():
        mask = np.zeros(n, dtype=np.bool_)
        for start, end in turns:
            mask[int(start * sr):min(int(end * sr), n)] = True
        masks[spk] = mask

    # Per-sample count of active speakers; any sample claimed by two or more
    # speakers is dropped from every track. int16 (not int8) cannot overflow.
    active = np.zeros(n, dtype=np.int16)
    for mask in masks.values():
        active += mask
    overlap = active >= 2

    tracks = []
    for spk in sorted(speakers, key=lambda s: np.sum(masks[s]), reverse=True):
        # Boolean indexing keeps the selected samples in order, i.e. the
        # overlap-free segments concatenated without gaps. Works for n == 0.
        samples = audio[masks[spk] & ~overlap]
        if samples.size:
            tracks.append((spk, samples))
    return np.sum(overlap) / sr, tracks


def run_pyannote(vid, demucs_wav):
    """Diarize *demucs_wav* with Pyannote and export one clean WAV per speaker.

    Files written under ``BASE_DIR/<vid>/``:
      * ``demucs_16k.wav``: 16 kHz mono copy of *demucs_wav* (reused if present).
      * ``diarized/diarization.json``: list of ``{"speaker", "start", "end"}``
        turns in seconds (rounded to 3 decimals), sorted by start.
      * ``diarized/<speaker>.wav``: that speaker's audio, with every region
        where two or more speakers overlap removed, concatenated back-to-back.

    The Pyannote token comes from the ``PYANNOTE_API_KEY`` environment
    variable, falling back to the key that used to be hard-coded here.
    """
    import torchaudio
    # Shim for torchaudio builds without list_audio_backends(); presumably
    # still called by pyannote.audio — TODO confirm and drop when unneeded.
    if not hasattr(torchaudio, 'list_audio_backends'):
        torchaudio.list_audio_backends = lambda: ["soundfile"]

    # NOTE(review): secrets should not be committed to source. The environment
    # variable takes precedence; the literal fallback only keeps existing
    # setups working and should be rotated and removed.
    PYANNOTE_API_KEY = os.environ.get(
        "PYANNOTE_API_KEY", "sk_4477f5473f584d1190f2c3bdbf37445b"
    )
    MODEL = "pyannote/speaker-diarization-community-1-cloud"

    base = os.path.join(BASE_DIR, vid)
    mono_16k = os.path.join(base, "demucs_16k.wav")
    out_dir = os.path.join(base, "diarized")
    os.makedirs(out_dir, exist_ok=True)

    _convert_to_16k_mono(vid, demucs_wav, mono_16k)

    flush_print(f"[{vid}] Loading Pyannote {MODEL}...")
    from pyannote.audio import Pipeline

    pipeline = Pipeline.from_pretrained(MODEL, token=PYANNOTE_API_KEY)

    flush_print(f"[{vid}] Running diarization...")
    result = pipeline(mono_16k)
    # The pipeline may return an output object exposing .speaker_diarization,
    # or the annotation itself.
    diarization = getattr(result, 'speaker_diarization', result)

    segments = [
        {"speaker": speaker, "start": round(turn.start, 3), "end": round(turn.end, 3)}
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]
    segments.sort(key=lambda seg: seg["start"])

    with open(os.path.join(out_dir, "diarization.json"), "w") as f:
        json.dump(segments, f, indent=2)

    speakers = {}
    for seg in segments:
        speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))

    totals = {spk: sum(e - s for s, e in turns) for spk, turns in speakers.items()}
    flush_print(f"[{vid}] Found {len(speakers)} speaker(s):")
    for spk in sorted(totals, key=totals.get, reverse=True):
        total = totals[spk]
        flush_print(f"  {spk}: {total:.1f}s ({total/60:.1f}min), {len(speakers[spk])} segments")

    flush_print(f"[{vid}] Extracting clean speaker tracks (overlaps removed, tight concat)...")
    audio, sr = sf.read(mono_16k, dtype='float32')
    if audio.ndim > 1:
        audio = audio.mean(axis=1)  # defensive downmix; the ffmpeg step writes mono

    overlap_secs, tracks = _overlap_free_tracks(audio, sr, speakers)
    flush_print(f"[{vid}] Overlap detected: {overlap_secs:.1f}s ({overlap_secs/60:.1f}min)")

    for spk, samples in tracks:
        path = os.path.join(out_dir, f"{spk}.wav")
        sf.write(path, samples, sr)
        flush_print(f"  {spk}: {len(samples)/sr:.1f}s -> {path}")

    flush_print(f"[{vid}] All done!")


if __name__ == "__main__":
    # Usage: script.py [VIDEO_URL] — falls back to a fixed sample video.
    default_url = "https://www.youtube.com/watch?v=49WU4Lx3Gkk"
    cli_args = sys.argv[1:]
    url = cli_args[0] if cli_args else default_url

    banner = "=" * 60
    for header_line in (banner, f"Processing: {url}", banner):
        flush_print(header_line)

    # Pipeline: download -> vocal separation -> diarization, blank line between steps.
    video_id, downloaded_audio = download_128kbps(url)
    flush_print()
    vocals_wav = run_demucs(video_id, downloaded_audio)
    flush_print()
    run_pyannote(video_id, vocals_wav)
