"""
Export recover-worker reference snapshots and optionally upload them to R2.

The recover worker only needs a narrow subset of columns from
`transcription_results` and `transcription_flags`, so this exporter writes
compact parquet files tailored for recover-mode lookups:

  - `transcription_results_recover.parquet`
  - `transcription_flags_recover.parquet`
  - `validated_segment_ids.parquet` (deduplicated from the validation CSVs)

along with a JSON manifest recording row counts and file sizes. All parquet
outputs are sorted for better row-group pruning during per-video and
per-segment DuckDB queries inside workers.
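
Typical invocation (script path illustrative; `--no-upload` needs only
DATABASE_URL, while uploads additionally require the R2_* credentials):

    python export_recover_reference.py --no-upload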
"""
from __future__ import annotations

import argparse
import json
import os
import subprocess
import time
from pathlib import Path
from urllib.parse import quote, unquote, urlparse

import boto3
import duckdb
import pyarrow.parquet as pq
from boto3.s3.transfer import TransferConfig
from dotenv import load_dotenv

TX_FILENAME = "transcription_results_recover.parquet"
FLAGS_FILENAME = "transcription_flags_recover.parquet"
VALIDATED_FILENAME = "validated_segment_ids.parquet"
MANIFEST_FILENAME = "recover_reference_manifest.json"
RAW_TX_CSV_FILENAME = "_raw_transcription_results_recover.csv"
RAW_FLAGS_CSV_FILENAME = "_raw_transcription_flags_recover.csv"
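# 100k-row groups keep parquet min/max statistics fine-grained enough for
# effective pruning on the sorted key columns.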
ROW_GROUP_SIZE = 100_000

TX_QUERY = """
SELECT
    video_id,
    segment_file,
    expected_language_hint,
    detected_language,
    transcription,
    tagged,
    quality_score,
    speaker_emotion,
    speaker_style,
    speaker_pace,
    speaker_accent
FROM transcription_results
"""

FLAGS_QUERY = """
SELECT
    segment_id,
    flag_type
FROM transcription_flags
"""


def parse_args():
    p = argparse.ArgumentParser(description="Export recover reference parquet snapshots")
    p.add_argument("--output-dir", default="data/recover_reference", help="Local output directory")
    p.add_argument("--validation-dir", default="data", help="Dir with golden/redo/dispose CSVs")
    p.add_argument("--bucket", default="", help="R2 bucket for upload (default: R2_VALIDATION_REFERENCE_BUCKET)")
    p.add_argument("--prefix", default="reference-data", help="R2 object prefix")
    p.add_argument("--no-upload", action="store_true", help="Skip R2 upload")
    return p.parse_args()


def export_query_to_csv(dsn: str, query: str, csv_path: Path):
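    """Stream `query` to `csv_path` via psql's `COPY ... TO STDOUT`.

    COPY streams rows straight to the file without materialising the result
    set client-side; `statement_timeout=0` keeps the server from cancelling
    long exports.
    """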
    export_dsn = _direct_export_dsn(dsn)
    cmd = [
        "psql",
        export_dsn,
        "-q",
        "-v",
        "ON_ERROR_STOP=1",
        "-c",
        "SET statement_timeout=0",
        "-c",
        f"COPY ({query}) TO STDOUT WITH CSV HEADER",
    ]
    target = "direct postgres" if export_dsn != dsn else "configured dsn"
    print(f"Exporting CSV -> {csv_path.name} via {target}")
    t0 = time.time()
    with csv_path.open("wb") as fh:
        subprocess.run(cmd, stdout=fh, check=True)
    print(
        f"Finished CSV export {csv_path.name} in {time.time() - t0:.1f}s "
        f"({_format_bytes(_size_bytes(csv_path))})"
    )


def sort_csv_to_parquet(csv_path: Path, final_path: Path, *, columns_sql: str, order_by: str):
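    """Sort a raw CSV export in DuckDB and write it as ZSTD-compressed parquet.

    The explicit column schema (`columns_sql`) avoids type sniffing on large
    files; sorting by `order_by` clusters rows so per-key queries can prune
    row groups.
    """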
    skip_rows = _csv_skip_rows(csv_path)
    con = duckdb.connect()
    try:
        con.execute(f"SET threads = {min(os.cpu_count() or 4, 16)}")
        con.execute("SET memory_limit = '16GB'")
        csv_sql = _sql_path(csv_path)
        final_sql = _sql_path(final_path)
        con.execute(f"""
            COPY (
                SELECT *
                FROM read_csv(
                    '{csv_sql}',
                    auto_detect=false,
                    header=true,
                    columns={columns_sql},
                    skip={skip_rows},
                    delim=',',
                    quote='"',
                    escape='"'
                )
                ORDER BY {order_by}
            ) TO '{final_sql}' (
                FORMAT PARQUET,
                COMPRESSION ZSTD,
                ROW_GROUP_SIZE {ROW_GROUP_SIZE}
            )
        """)
    finally:
        con.close()


def upload_file(bucket: str, key: str, path: Path):
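    """Upload `path` to the R2 bucket via the S3-compatible API.

    Multipart transfers kick in above 64 MiB to parallelise large uploads.
    """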
    transfer = TransferConfig(
        multipart_threshold=64 * 1024 * 1024,
        multipart_chunksize=64 * 1024 * 1024,
        max_concurrency=16,
        use_threads=True,
    )
    s3 = boto3.client(
        "s3",
        endpoint_url=os.environ["R2_ENDPOINT_URL"],
        aws_access_key_id=os.environ["R2_ACCESS_KEY_ID"],
        aws_secret_access_key=os.environ["R2_SECRET_ACCESS_KEY"],
        region_name="auto",
    )
    t0 = time.time()
    print(f"Uploading {path.name} -> s3://{bucket}/{key}")
    s3.upload_file(str(path), bucket, key, Config=transfer)
    print(f"Uploaded {path.name} in {time.time() - t0:.1f}s")


def build_validated_segment_ids(validation_dir: Path, output_path: Path) -> int:
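    """Merge the validation CSVs into a deduplicated (video_id, segment_file) parquet.

    Returns the row count written, or 0 when no validation CSVs are present.
    """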
    csv_paths = [
        validation_dir / "golden_segments.csv",
        validation_dir / "redo_segments.csv",
        validation_dir / "dispose_segments.csv",
    ]
    existing = [p for p in csv_paths if p.exists()]
    if not existing:
        print("No validation CSVs found, skipping validated_segment_ids export")
        return 0

    csv_list = ", ".join(f"'{_sql_path(p)}'" for p in existing)
    con = duckdb.connect()
    try:
        con.execute(f"SET threads = {min(os.cpu_count() or 4, 16)}")
        con.execute("SET memory_limit = '16GB'")
        out_sql = _sql_path(output_path)
        con.execute(f"""
            COPY (
                SELECT DISTINCT video_id, segment_file
                FROM read_csv_auto([{csv_list}], union_by_name=true, header=true)
                WHERE video_id IS NOT NULL AND segment_file IS NOT NULL
                ORDER BY video_id, segment_file
            ) TO '{out_sql}' (
                FORMAT PARQUET,
                COMPRESSION ZSTD,
                ROW_GROUP_SIZE {ROW_GROUP_SIZE}
            )
        """)
    finally:
        con.close()
    rows = pq.ParquetFile(output_path).metadata.num_rows
    print(f"Built validated_segment_ids: {rows:,} rows ({_format_bytes(_size_bytes(output_path))})")
    return rows


def _size_bytes(path: Path) -> int:
    return path.stat().st_size if path.exists() else 0


def _format_bytes(size_bytes: int) -> str:
    if size_bytes <= 0:
        return "0B"
    value = float(size_bytes)
    for unit in ["B", "KB", "MB", "GB"]:
        if value < 1024.0:
            return f"{value:.2f}{unit}"
        value /= 1024.0
    return f"{value:.2f}TB"


def _sql_path(path: Path) -> str:
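    """Escape a path for embedding in a single-quoted DuckDB SQL literal."""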
    return path.as_posix().replace("'", "''")


def _csv_skip_rows(path: Path) -> int:
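    """Return 1 when a stray `SET` command tag precedes the CSV header.

    Depending on psql's quiet settings, the tag from `SET statement_timeout=0`
    can land in the captured stdout ahead of the COPY output.
    """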
    with path.open("r", encoding="utf-8", errors="replace") as fh:
        first_line = fh.readline().strip()
    return 1 if first_line == "SET" else 0


def _direct_export_dsn(dsn: str) -> str:
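    """Rewrite a Supabase pooler DSN into a direct-connection DSN.

    Pooler usernames look like `postgres.<project_ref>`; the matching direct
    host is `db.<project_ref>.supabase.co:5432`, which avoids routing a
    long-running COPY export through the pooler. Illustrative rewrite:

        postgresql://postgres.abc123:pw@xyz.pooler.supabase.com:6543/postgres
        -> postgresql://postgres:pw@db.abc123.supabase.co:5432/postgres

    Non-pooler DSNs are returned unchanged.
    """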
    parsed = urlparse(dsn)
    host = parsed.hostname or ""
    username = parsed.username or ""
    if "pooler.supabase.com" not in host or "." not in username:
        return dsn

    _, project_ref = username.split(".", 1)
    password = quote(unquote(parsed.password or ""), safe="")
    path = parsed.path or "/postgres"
    return f"postgresql://postgres:{password}@db.{project_ref}.supabase.co:5432{path}"


def main():
    args = parse_args()
    load_dotenv(Path(__file__).resolve().parent.parent / ".env")
    dsn = os.getenv("DATABASE_URL")
    if not dsn:
        raise SystemExit("DATABASE_URL missing")

    bucket = (
        args.bucket
        or os.getenv("R2_VALIDATION_REFERENCE_BUCKET")
        or os.getenv("R2_VALIDATION_MODEL_BUCKET")
        or "validation-results"
    )
    prefix = args.prefix.strip("/")
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    raw_tx_csv_path = output_dir / RAW_TX_CSV_FILENAME
    raw_flags_csv_path = output_dir / RAW_FLAGS_CSV_FILENAME
    tx_path = output_dir / TX_FILENAME
    flags_path = output_dir / FLAGS_FILENAME
    manifest_path = output_dir / MANIFEST_FILENAME

    t0 = time.time()
    print(f"Exporting recover reference data into {output_dir}")

    if raw_tx_csv_path.exists() and raw_tx_csv_path.stat().st_size > 0:
        print(f"Reusing existing CSV {raw_tx_csv_path.name} ({_format_bytes(_size_bytes(raw_tx_csv_path))})")
    else:
        export_query_to_csv(dsn, TX_QUERY, raw_tx_csv_path)
    print("Sorting transcription snapshot for row-group pruning...")
    sort_csv_to_parquet(
        raw_tx_csv_path,
        tx_path,
        columns_sql="""{
            'video_id': 'VARCHAR',
            'segment_file': 'VARCHAR',
            'expected_language_hint': 'VARCHAR',
            'detected_language': 'VARCHAR',
            'transcription': 'VARCHAR',
            'tagged': 'VARCHAR',
            'quality_score': 'FLOAT',
            'speaker_emotion': 'VARCHAR',
            'speaker_style': 'VARCHAR',
            'speaker_pace': 'VARCHAR',
            'speaker_accent': 'VARCHAR'
        }""",
        order_by="video_id, segment_file",
    )
    raw_tx_csv_path.unlink(missing_ok=True)

    if raw_flags_csv_path.exists() and raw_flags_csv_path.stat().st_size > 0:
        print(f"Reusing existing CSV {raw_flags_csv_path.name} ({_format_bytes(_size_bytes(raw_flags_csv_path))})")
    else:
        export_query_to_csv(dsn, FLAGS_QUERY, raw_flags_csv_path)
    print("Sorting flags snapshot for row-group pruning...")
    sort_csv_to_parquet(
        raw_flags_csv_path,
        flags_path,
        columns_sql="""{
            'segment_id': 'VARCHAR',
            'flag_type': 'VARCHAR'
        }""",
        order_by="segment_id, flag_type",
    )
    raw_flags_csv_path.unlink(missing_ok=True)

    validated_path = output_dir / VALIDATED_FILENAME
    validated_rows = build_validated_segment_ids(Path(args.validation_dir), validated_path)

    tx_rows = pq.ParquetFile(tx_path).metadata.num_rows
    flags_rows = pq.ParquetFile(flags_path).metadata.num_rows

    payload = {
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "tx_rows": tx_rows,
        "flags_rows": flags_rows,
        "validated_rows": validated_rows,
        "tx_filename": TX_FILENAME,
        "flags_filename": FLAGS_FILENAME,
        "validated_filename": VALIDATED_FILENAME,
        "tx_size_bytes": _size_bytes(tx_path),
        "flags_size_bytes": _size_bytes(flags_path),
        "validated_size_bytes": _size_bytes(validated_path),
        "tx_size_human": _format_bytes(_size_bytes(tx_path)),
        "flags_size_human": _format_bytes(_size_bytes(flags_path)),
        "validated_size_human": _format_bytes(_size_bytes(validated_path)),
        "bucket": bucket,
        "prefix": prefix,
        "elapsed_s": round(time.time() - t0, 2),
    }
    manifest_path.write_text(json.dumps(payload, indent=2))

    print(json.dumps(payload, indent=2))

    if not args.no_upload:
        upload_file(bucket, f"{prefix}/{TX_FILENAME}", tx_path)
        upload_file(bucket, f"{prefix}/{FLAGS_FILENAME}", flags_path)
        # The validated-ids parquet is only written when validation CSVs exist.
        if validated_path.exists():
            upload_file(bucket, f"{prefix}/{VALIDATED_FILENAME}", validated_path)
        upload_file(bucket, f"{prefix}/{MANIFEST_FILENAME}", manifest_path)


if __name__ == "__main__":
    main()
