"""Submit clip to pyannote.ai precision-2 diarization and save the result."""
import json
import os
import sys
import time
from pathlib import Path

import requests

# Credentials come from the environment instead of source control. The key that
# was previously hard-coded here must be treated as leaked and rotated.
API_KEY = os.environ.get("PYANNOTE_API_KEY")
if not API_KEY:
    sys.exit("ERROR: set the PYANNOTE_API_KEY environment variable to your pyannote.ai API key.")
BASE = "https://api.pyannote.ai/v1"
HEADERS = {"Authorization": f"Bearer {API_KEY}"}

# Input clip and result file both live next to this script.
HERE = Path(__file__).parent
AUDIO = HERE / "clip_4m12s_to_13m10s.wav"
# NOTE(review): "dairize" looks like a typo of "diarize"; left unchanged in case
# other tooling reads this exact path.
OUT = HERE / "dairize.json"

# media:// key for the clip: used both to request the upload URL (step 1) and to
# reference the uploaded audio in the diarization job (step 3).
OBJECT_KEY = "media://yt-clip-4m12s-13m10s.wav"

# 1. Get presigned upload URL
print("[1/4] Requesting presigned upload URL...")
# POST the media key; the JSON response's "url" is where the audio is PUT in
# step 2. requests has no default timeout, so bound it explicitly to avoid
# hanging forever on a stalled connection.
r = requests.post(
    f"{BASE}/media/input",
    headers=HEADERS,
    json={"url": OBJECT_KEY},
    timeout=30,
)
r.raise_for_status()
upload_url = r.json()["url"]
print(f"      OK. Will upload {AUDIO.stat().st_size/1e6:.1f} MB")

# 2. PUT the audio file
print("[2/4] Uploading audio...")
# Passing the open file as `data` streams it from disk. The (connect, read)
# timeout bounds connection setup and each wait on the server; it is not a cap
# on total transfer time, so large uploads are not cut short.
with open(AUDIO, "rb") as f:
    up = requests.put(upload_url, data=f, timeout=(10, 300))
up.raise_for_status()
print(f"      Upload status: {up.status_code}")

# 3. Submit diarization job
print("[3/4] Submitting precision-2 diarization job...")
body = {
    "url": OBJECT_KEY,
    "model": "precision-2",
    "exclusive": True,
    "confidence": True,
    "turnLevelConfidence": True,
}
r = requests.post(f"{BASE}/diarize", headers=HEADERS, json=body, timeout=30)
r.raise_for_status()
job = r.json()
# Accept either key name for the job identifier.
job_id = job.get("jobId") or job.get("id")
if not job_id:
    # Without an id, polling would request /jobs/None and fail confusingly.
    print("ERROR: diarization response contained no job id:")
    print(json.dumps(job, indent=2))
    sys.exit(1)
print(f"      jobId: {job_id}")

# 4. Poll
print("[4/4] Polling job status...")
POLL_INTERVAL = 5  # seconds between status requests
MAX_WAIT = 60 * 60  # stop polling after an hour instead of looping forever
start = time.monotonic()  # monotonic clock: immune to wall-clock adjustments
while True:
    r = requests.get(f"{BASE}/jobs/{job_id}", headers=HEADERS, timeout=30)
    r.raise_for_status()
    j = r.json()
    status = j.get("status", "?")
    elapsed = int(time.monotonic() - start)
    print(f"      [{elapsed:3d}s] status={status}")
    if status in ("succeeded", "failed", "canceled"):
        break
    if elapsed >= MAX_WAIT:
        # Fall through to the error report below, which dumps the last payload.
        print(f"      Giving up after {elapsed}s; job {job_id} has not finished.")
        break
    time.sleep(POLL_INTERVAL)

if status != "succeeded":
    print("ERROR: job did not succeed:")
    print(json.dumps(j, indent=2))
    sys.exit(1)

# Persist the entire job payload (status fields plus "output") as pretty-printed JSON.
payload = json.dumps(j, indent=2)
OUT.write_text(payload)
print("\nDone -> " + str(OUT))

# Quick summary: turn counts, distinct speakers, and total speech per speaker.
output = j.get("output") or {}
dia = output.get("diarization") or []
ex = output.get("exclusiveDiarization") or []
print(f"  diarization turns: {len(dia)}")
print(f"  exclusiveDiarization turns: {len(ex)}")
if dia:
    from collections import Counter

    # Sum (end - start) per speaker label. The label is read with .get()
    # everywhere so a turn without "speaker" is grouped under None instead of
    # raising halfway through the summary.
    tot = Counter()
    for s in dia:
        tot[s.get("speaker")] += s["end"] - s["start"]
    # key=str keeps the ordering well-defined even if a None label appears;
    # for all-string labels this equals plain sorted().
    speakers = sorted(tot, key=str)
    print(f"  speakers: {speakers}")
    # most_common() orders by descending total, ties kept in first-seen order.
    for k, v in tot.most_common():
        print(f"    {k}: {v:.1f}s")
