"""
Multi-process parallel launcher for video TTS classification.

Splits pending video IDs into N shards and launches N independent Python worker
processes (each with its own asyncio event loop and connection pool, which the
OS can schedule across separate cores), then merges the shard outputs into the
final CSV.

Usage:
  python scripts/classify_parallel_launcher.py --workers 16 --concurrency-per-worker 500
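
Re-running the same command without --overwrite resumes: video_ids already in
the output CSV are skipped, and newly classified rows are appended on merge.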
"""
from __future__ import annotations

import argparse
import csv
import os
import subprocess
import sys
import time
from pathlib import Path

from dotenv import load_dotenv

SHARD_DIR = Path("data/classification_shards")
SCRIPT = Path(__file__).resolve().parent / "classify_video_tts_metadata.py"


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--input-csv", default="data/youtube_video_metadata_all.csv")
    p.add_argument("--output-csv", default="data/video_tts_classification_all.csv")
    p.add_argument("--workers", type=int, default=16)
    p.add_argument("--concurrency-per-worker", type=int, default=500)
    p.add_argument("--model", default="google/gemini-3-flash-preview")
    p.add_argument("--reasoning-effort", default="low")
    p.add_argument("--progress-every", type=int, default=500)
    p.add_argument("--overwrite", action="store_true")
    return p.parse_args()


def read_done_ids(path: Path) -> set[str]:
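    """Collect video_ids already present in an output CSV (used for resume)."""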
    if not path.exists():
        return set()
    ids: set[str] = set()
    with path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            vid = row.get("video_id", "")
            if vid:
                ids.add(vid)
    return ids


def read_pending_ids(input_csv: Path, skip: set[str]) -> list[str]:
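    """Return video_ids from the input CSV, in order, excluding those in `skip`."""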
    ids: list[str] = []
    with input_csv.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            vid = row.get("video_id", "")
            if vid and vid not in skip:
                ids.append(vid)
    return ids


def write_id_file(path: Path, ids: list[str]) -> None:
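    """Write a one-column video_id CSV for a single shard's worker."""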
    with path.open("w", encoding="utf-8") as f:
        f.write("video_id\n")
        for vid in ids:
            f.write(vid + "\n")


def merge_shards(shard_outputs: list[Path], final_output: Path, append: bool) -> int:
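    """Concatenate shard CSVs into final_output; return the number of rows written.

    The header and column order come from the first non-empty shard, and the
    header is skipped when appending to a non-empty existing file. Assumes all
    shards share the same columns (they are produced by the same worker script).
    """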
    mode = "a" if append else "w"
    total = 0
    header_written = append and final_output.exists() and final_output.stat().st_size > 0

    with final_output.open(mode, encoding="utf-8", newline="") as out_f:
        writer = None
        for shard_path in shard_outputs:
            if not shard_path.exists() or shard_path.stat().st_size == 0:
                continue
            with shard_path.open("r", encoding="utf-8", newline="") as in_f:
                reader = csv.DictReader(in_f)
                if writer is None:
                    writer = csv.DictWriter(out_f, fieldnames=reader.fieldnames)
                    if not header_written:
                        writer.writeheader()
                        header_written = True
                for row in reader:
                    writer.writerow(row)
                    total += 1
    return total


def main() -> None:
    args = parse_args()
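    # Load the repo-root .env into os.environ so child workers inherit it.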
    load_dotenv(Path(__file__).resolve().parent.parent / ".env")

    input_csv = Path(args.input_csv)
    output_csv = Path(args.output_csv)

    skip: set[str] = set()
    append = False
    if output_csv.exists():
        if args.overwrite:
            output_csv.unlink()
        else:
            skip = read_done_ids(output_csv)
            append = True
            print(f"Resuming: {len(skip):,} already classified")

    pending = read_pending_ids(input_csv, skip)
    total = len(pending)
    print(f"Pending videos: {total:,}")
    if total == 0:
        print("Nothing to do.")
        return

    n_workers = min(args.workers, total)
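    # Ceiling division: every pending ID lands in exactly one shard.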
    shard_size = (total + n_workers - 1) // n_workers

    SHARD_DIR.mkdir(parents=True, exist_ok=True)

    shard_id_files: list[Path] = []
    shard_out_files: list[Path] = []
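    # Slice pending IDs into contiguous chunks; the last shard may be short.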
    for i in range(n_workers):
        chunk = pending[i * shard_size : (i + 1) * shard_size]
        if not chunk:
            break
        id_file = SHARD_DIR / f"shard_{i:03d}_ids.csv"
        out_file = SHARD_DIR / f"shard_{i:03d}_output.csv"
        write_id_file(id_file, chunk)
        shard_id_files.append(id_file)
        shard_out_files.append(out_file)

    total_concurrency = len(shard_id_files) * args.concurrency_per_worker
    print(f"Launching {len(shard_id_files)} workers x {args.concurrency_per_worker} concurrency = {total_concurrency:,} total")

    env = os.environ.copy()
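    # Unbuffered child stdout so the per-shard log files stream in real time.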
    env["PYTHONUNBUFFERED"] = "1"

    t0 = time.monotonic()
    processes: list[tuple[int, subprocess.Popen, Path]] = []

    for i, (id_file, out_file) in enumerate(zip(shard_id_files, shard_out_files)):
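        # Shard outputs are per-run scratch files, so --overwrite is safe;
        # --no-raw-response keeps them small (raw model output is not stored).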
        cmd = [
            sys.executable, "-u", str(SCRIPT),
            "--input-csv", str(input_csv),
            "--output-csv", str(out_file),
            "--video-id-file", str(id_file),
            "--concurrency", str(args.concurrency_per_worker),
            "--model", args.model,
            "--reasoning-effort", args.reasoning_effort,
            "--progress-every", str(args.progress_every),
            "--overwrite",
            "--no-raw-response",
        ]
        log_file = SHARD_DIR / f"shard_{i:03d}.log"
        log_fh = log_file.open("w")
        proc = subprocess.Popen(cmd, env=env, stdout=log_fh, stderr=subprocess.STDOUT)
        # The child holds its own copy of the descriptor, so closing the
        # parent's handle does not affect the worker's logging.
        log_fh.close()
        processes.append((i, proc, log_file))
        n_ids = min(shard_size, total - i * shard_size)
        print(f"  Worker {i}: pid={proc.pid}, {n_ids:,} videos")

    print(f"\nAll workers launched. Monitoring...")

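    # Poll worker exit status every 5s; print aggregate progress every ~30s.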
    alive = set(range(len(processes)))
    last_report = t0
    while alive:
        time.sleep(5)
        for idx, proc, log_path in processes:
            if idx not in alive:
                continue
            ret = proc.poll()
            if ret is not None:
                alive.discard(idx)
                elapsed = time.monotonic() - t0
                status = f"exit={ret}" if ret == 0 else f"exit={ret}, see {log_path}"
                print(f"  Worker {idx} finished ({status}) at {elapsed:.0f}s")

        now = time.monotonic()
        if now - last_report >= 30:
            last_report = now
            elapsed = now - t0
            done_rows = 0
            for out_file in shard_out_files:
                if out_file.exists():
                    # Raw line count minus header: a cheap progress estimate
                    # (exact unless a CSV field contains embedded newlines).
                    with out_file.open("r", encoding="utf-8") as fh:
                        done_rows += max(0, sum(1 for _ in fh) - 1)
            rate = done_rows / elapsed if elapsed > 0 else 0
            eta = (total - done_rows) / rate / 60 if rate > 0 else 0
            print(
                f"  [{elapsed:.0f}s] {done_rows:,}/{total:,} done "
                f"({rate:.1f}/s, ETA {eta:.1f}m, {len(alive)} workers alive)",
                flush=True,
            )

    elapsed = time.monotonic() - t0
    print(f"\nAll workers done in {elapsed:.0f}s ({elapsed/60:.1f}m)")

    print("Merging shard outputs...")
    merged = merge_shards(shard_out_files, output_csv, append=append)
    print(f"Merged {merged:,} new rows into {output_csv}")

    total_done = len(read_done_ids(output_csv))
    print(f"Total classified: {total_done:,}")


if __name__ == "__main__":
    main()
