"""
Create and seed `validation_recover_queue` from historical transcriptions.

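Default behavior:
  - one row per transcribed video
  - `tx_segments` = count(*) in transcription_results
  - `status` = pending for newly inserted rows; rows that already exist keep
    their current status and only get `tx_segments` / `updated_at` refreshed
    (pass --reset to delete them first)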

For smoke tests you can seed only a few videos:
  python scripts/init_recover_queue.py --video-id abc --video-id def --reset
"""
from __future__ import annotations

import argparse
import asyncio
import os
from pathlib import Path

import asyncpg
from dotenv import load_dotenv


def parse_args():
    """Parse the command-line options of the queue seeding script.

    Returns:
        The ``argparse.Namespace`` with ``video_id`` (list of strings, empty
        when the flag is never given), ``limit`` (int, defaults to 0),
        ``reset`` (bool) and ``queue_table`` (str).
    """
    parser = argparse.ArgumentParser(description="Init validation_recover_queue")

    # Declared in the order the options should appear in --help output.
    option_specs = (
        ("--video-id", {"action": "append", "default": [], "help": "Only seed these video IDs"}),
        ("--limit", {"type": int, "default": 0, "help": "Limit seeded videos (after ordering)"}),
        ("--reset", {"action": "store_true", "help": "Delete existing rows before seeding the selected scope"}),
        ("--queue-table", {"default": "validation_recover_queue"}),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)

    options = parser.parse_args()
    return options


async def main():
    """Create the queue table if needed and seed it from `transcription_results`.

    Steps:
      1. Load `.env` from the directory above this script's directory and read
         `DATABASE_URL`.
      2. Create the queue table and its two indexes when they do not exist.
      3. Count rows per `video_id` in `transcription_results` (optionally
         restricted by `--video-id` and `--limit`).
      4. In a single transaction: apply `--reset` (delete the selected video
         IDs, or truncate the whole table when no IDs were given) and upsert
         one row per selected video.

    The reset and the upsert share one transaction so that a failure while
    seeding cannot leave the queue wiped but empty.

    Raises:
        SystemExit: if `DATABASE_URL` is not set, or if `--queue-table` is not
            a plain unquoted SQL identifier.
    """
    args = parse_args()
    load_dotenv(Path(__file__).resolve().parent.parent / ".env")
    dsn = os.getenv("DATABASE_URL")
    if not dsn:
        raise SystemExit("DATABASE_URL missing")

    table = args.queue_table
    # Identifiers cannot be bound as query parameters, so the table name is
    # interpolated into the SQL text and into the index names below. Accept
    # only a plain unquoted identifier (letter/underscore/non-ASCII first, then
    # letters, digits, `_`, `$`); anything else would either break the index
    # name or smuggle extra SQL into the statements.
    valid_name = (
        (table[:1].isalpha() or table[:1] == "_" or table[:1] > "\x7f")
        and all(ch.isalnum() or ch in "_$" or ch > "\x7f" for ch in table)
    )
    if not valid_name:
        raise SystemExit(
            f"Invalid --queue-table {table!r}: expected a plain unquoted SQL identifier"
        )

    # NOTE(review): the statement cache is disabled when the DSN mentions port
    # 6543 — presumably a connection-pooler port where cached prepared
    # statements are unsafe; confirm against the deployment.
    conn = await asyncpg.connect(dsn=dsn, ssl="require", statement_cache_size=0 if ":6543" in dsn else 100)
    try:
        # Idempotent schema setup.
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS {table} (
                video_id TEXT PRIMARY KEY,
                status TEXT NOT NULL DEFAULT 'pending',
                tx_segments INTEGER NOT NULL,
                recovered_segments INTEGER NOT NULL DEFAULT 0,
                replayed_segments INTEGER NOT NULL DEFAULT 0,
                extra_regen_segments INTEGER NOT NULL DEFAULT 0,
                missing_tx_segments INTEGER NOT NULL DEFAULT 0,
                missing_parent_files INTEGER NOT NULL DEFAULT 0,
                extra_timeout_segments INTEGER NOT NULL DEFAULT 0,
                extra_error_segments INTEGER NOT NULL DEFAULT 0,
                extra_flagged_segments INTEGER NOT NULL DEFAULT 0,
                extra_unflagged_segments INTEGER NOT NULL DEFAULT 0,
                extra_regen_ids_json TEXT,
                claimed_by TEXT,
                claimed_at TIMESTAMPTZ,
                completed_at TIMESTAMPTZ,
                error_message TEXT,
                created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
            )
        """)
        await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{table}_status ON {table}(status)")
        await conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{table}_claimed_at ON {table}(claimed_at)")

        where = ""
        params = []
        if args.video_id:
            where = "WHERE video_id = ANY($1::text[])"
            params.append(args.video_id)

        # args.limit is an int (argparse type=int), so inlining it is safe.
        limit_sql = f"LIMIT {args.limit}" if args.limit > 0 else ""

        # Read the source counts BEFORE touching the queue: if this query fails
        # nothing has been deleted yet, and the reset below holds its lock only
        # for the duration of the inserts.
        rows = await conn.fetch(
            f"""
            SELECT video_id, count(*)::int AS tx_segments
            FROM transcription_results
            {where}
            GROUP BY video_id
            ORDER BY tx_segments DESC, video_id
            {limit_sql}
            """,
            *params,
        )

        # Reset and seed atomically: either both are committed or neither is.
        async with conn.transaction():
            if args.reset:
                if args.video_id:
                    await conn.execute(f"DELETE FROM {table} WHERE video_id = ANY($1::text[])", args.video_id)
                else:
                    await conn.execute(f"TRUNCATE {table}")

            if rows:
                # Existing rows keep their status; only the segment count and
                # the update timestamp are refreshed.
                await conn.executemany(
                    f"""
                    INSERT INTO {table} (video_id, tx_segments, status, updated_at)
                    VALUES ($1, $2, 'pending', now())
                    ON CONFLICT (video_id) DO UPDATE SET
                        tx_segments = EXCLUDED.tx_segments,
                        updated_at = now()
                    """,
                    [(row["video_id"], row["tx_segments"]) for row in rows],
                )

        if not rows:
            print("No videos selected.")
            return

        total_segments = sum(row["tx_segments"] for row in rows)
        print(
            f"Seeded {len(rows):,} videos into {table} "
            f"({total_segments:,} tx segments)"
        )
        if args.video_id:
            print("Videos:")
            for row in rows:
                print(f"  {row['video_id']}\t{row['tx_segments']}")
    finally:
        await conn.close()


# Script entry point: when executed directly (not imported), run the async
# `main()` coroutine to completion on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
