from __future__ import annotations

import argparse
import asyncio
import shutil
import sys
import time
from pathlib import Path

import duckdb
from dotenv import load_dotenv

# Make the repo root importable and load .env before importing src modules,
# which may read environment variables at import time.
ROOT = Path("/home/ubuntu/transcripts")
sys.path.insert(0, str(ROOT))
load_dotenv(ROOT / ".env")

from src.config import EnvConfig  # noqa: E402
from src.variant_db import VariantPostgresDB  # noqa: E402
from src.variant_r2 import VariantR2Client  # noqa: E402


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Create transcript variant shards, upload to R2, and seed queue.")
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--run-id", default=f"variant-run-{int(time.time())}")
    parser.add_argument("--num-shards", type=int, default=1024)
    parser.add_argument("--local-shard-dir", type=Path, default=ROOT / "data" / "transcript_variant_shards")
    parser.add_argument("--input-bucket", default="")
    parser.add_argument("--input-prefix", default="")
    parser.add_argument("--output-bucket", default="")
    parser.add_argument("--output-prefix", default="")
    parser.add_argument("--id-col", default="row_id")
    parser.add_argument("--video-id-col", default="video_id")
    parser.add_argument("--segment-id-col", default="segment_id")
    parser.add_argument("--language-col", default="language_code")
    parser.add_argument("--text-col", default="text")
    parser.add_argument("--limit-rows", type=int, default=0)
    return parser.parse_args()


def _read_sql(path: Path) -> str:
    """Return a DuckDB table expression that scans the input file."""
    # Escape single quotes so an unusual path cannot break out of the SQL literal.
    posix = path.as_posix().replace("'", "''")
    if path.suffix.lower() == ".csv":
        return f"read_csv_auto('{posix}', header=true)"
    return f"read_parquet('{posix}')"


def build_shards(args: argparse.Namespace) -> list[tuple[str, Path, int]]:
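    """Shard the input into parquet files under args.local_shard_dir; return (shard_id, path, row_count) triples."""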
    args.local_shard_dir.mkdir(parents=True, exist_ok=True)
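    # Use an on-disk DuckDB database so sharding can spill to disk on large inputs.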
    db_path = args.local_shard_dir.parent / f"{args.local_shard_dir.name}_variant_init.duckdb"
    con = duckdb.connect(str(db_path))
    source_sql = _read_sql(args.input)

    limit_clause = f"LIMIT {args.limit_rows}" if args.limit_rows > 0 else ""
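    # Stage the raw input behind a view so the SQL below is agnostic to CSV vs parquet.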
    con.execute(
        f"""
        CREATE OR REPLACE TEMP VIEW variant_source AS
        SELECT * FROM {source_sql}
        {limit_clause}
        """
    )

    con.execute(
        f"""
        CREATE OR REPLACE TEMP VIEW variant_normalized AS
        WITH base AS (
            SELECT
                -- Fall back to a synthetic id (not stable across runs) when the id column is NULL or empty.
                COALESCE(NULLIF(CAST({args.id_col} AS VARCHAR), ''), printf('variant_row_%08d', row_number() OVER ())) AS row_id,
                COALESCE(CAST({args.video_id_col} AS VARCHAR), '') AS video_id,
                COALESCE(CAST({args.segment_id_col} AS VARCHAR), '') AS segment_id,
                COALESCE(CAST({args.language_col} AS VARCHAR), 'en') AS language_code,
                COALESCE(CAST({args.text_col} AS VARCHAR), '') AS text
            FROM variant_source
        )
        SELECT
            *,
            -- ntile() deals rows into {args.num_shards} near-equal buckets (1-based); shift to 0-based.
            ntile({args.num_shards}) OVER (ORDER BY row_id) - 1 AS shard_idx
        FROM base
        """
    )

    # Remove stale partition directories from earlier runs so the per-shard
    # globs below only ever see freshly written files.
    for stale_dir in args.local_shard_dir.glob("shard_idx=*"):
        shutil.rmtree(stale_dir)

    # Hive-partition the normalized rows; each shard_idx=N directory is one shard.
    con.execute(
        f"""
        COPY (
            SELECT row_id, video_id, segment_id, language_code, text, shard_idx
            FROM variant_normalized
        )
        TO '{args.local_shard_dir.as_posix()}'
        (FORMAT PARQUET, PARTITION_BY (shard_idx), COMPRESSION ZSTD, OVERWRITE_OR_IGNORE 1)
        """
    )

    shard_paths: list[tuple[str, Path, int]] = []
    for shard_dir in sorted(
        args.local_shard_dir.glob("shard_idx=*"),
        key=lambda p: int(p.name.split("=")[-1]),  # numeric sort: 2 before 10
    ):
        parquet_files = sorted(shard_dir.glob("*.parquet"))
        if not parquet_files:
            continue
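        # Zero-pad the numeric suffix so shard ids sort lexicographically.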
        shard_id = f"shard_{shard_dir.name.split('=')[-1].zfill(6)}"

        if len(parquet_files) > 1:
            # DuckDB COPY may split large partitions into multiple files;
            # merge them into a single parquet for the worker to download.
            glob_pattern = (shard_dir / "*.parquet").as_posix()
            merged_path = shard_dir / f"{shard_id}.parquet"
            # Write to a temp name the glob cannot match, then swap it in, so the
            # merged output is never treated as one of its own inputs.
            tmp_path = shard_dir / f"{shard_id}.parquet.tmp"
            con.execute(
                f"COPY (SELECT * FROM read_parquet('{glob_pattern}')) "
                f"TO '{tmp_path.as_posix()}' (FORMAT PARQUET, COMPRESSION ZSTD)"
            )
            for part in parquet_files:
                part.unlink()
            tmp_path.rename(merged_path)
            shard_path = merged_path
        else:
            shard_path = parquet_files[0]

        row_count = con.execute(
            f"SELECT count(*) FROM read_parquet('{shard_path.as_posix()}')"
        ).fetchone()[0]
        shard_paths.append((shard_id, shard_path, int(row_count)))

    con.close()
    return shard_paths


async def seed_jobs(args: argparse.Namespace, shard_paths: list[tuple[str, Path, int]]) -> None:
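    """Upload each shard to R2 and enqueue one pending job row per shard."""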
    config = EnvConfig()
    db = VariantPostgresDB(config.database_url)
    r2 = VariantR2Client(config)
    await db.connect()
    await db.init_schema()

    # CLI flags win; otherwise fall back to the configured bucket and a
    # per-run key layout under transcript-variants/<run-id>/.
    input_bucket = args.input_bucket or config.r2_bucket
    output_bucket = args.output_bucket or config.r2_bucket
    input_prefix = args.input_prefix or f"transcript-variants/{args.run_id}/input"
    output_prefix = args.output_prefix or f"transcript-variants/{args.run_id}/output"

    jobs: list[dict] = []
    for shard_id, shard_path, row_count in shard_paths:
        input_key = f"{input_prefix.rstrip('/')}/{shard_id}.parquet"
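        # upload_file is not awaited, so it is presumably a blocking client call;
        # shards upload serially as a result.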
        r2.upload_file(shard_path, input_bucket, input_key)
        jobs.append(
            {
                "shard_id": shard_id,
                "status": "pending",
                "input_bucket": input_bucket,
                "input_r2_key": input_key,
                "input_format": "parquet",
                "output_bucket": output_bucket,
                "output_prefix": output_prefix,
                "total_rows": row_count,
                "rows_processed": 0,
                "rows_skipped": 0,
                "rows_gemini": 0,
                "packs_uploaded": 0,
                "last_pack_key": "",
                "claimed_by": None,
                "claimed_at": None,
                "completed_at": None,
                "error_message": None,
                "attempt_count": 0,
                "metadata_json": {
                    "run_id": args.run_id,
                    "column_map": {
                        "id": "row_id",
                        "video_id": "video_id",
                        "segment_id": "segment_id",
                        "language_code": "language_code",
                        "text": "text",
                    },
                },
            }
        )

    await db.seed_jobs(jobs)
    await db.close()
    print(
        f"Seeded {len(jobs)} shard jobs. "
        f"Input prefix=s3://{input_bucket}/{input_prefix} "
        f"Output prefix=s3://{output_bucket}/{output_prefix}"
    )


def main() -> None:
    args = parse_args()
    shard_paths = build_shards(args)
    if not shard_paths:
        sys.exit("No shards were produced; check --input and --limit-rows.")
    asyncio.run(seed_jobs(args, shard_paths))


if __name__ == "__main__":
    main()
