"""Strict validator for final export runs.

Rebuilds the expected per-segment records from the raw source audio and the
reference store, then compares them field by field and byte for byte against
the microshard and final-shard packs uploaded to R2. Pass --video-id one or
more times to restrict the check; --skip-microshards and --skip-final-shards
validate only one pack type.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import sys
import tarfile
from pathlib import Path

import asyncpg
import pyarrow.parquet as pq
from dotenv import load_dotenv

# Make the project root importable and load its .env before importing src modules.
ROOT = Path("/home/ubuntu/transcripts")
sys.path.insert(0, str(ROOT))
load_dotenv(ROOT / ".env")

from src.audio_polish import polish_all_segments
from src.final_export_common import build_export_segment_payload, replay_segment_id, sha256_bytes
from src.final_export_config import FinalExportConfig
from src.final_export_r2 import FinalExportR2Client
from src.final_export_reference_store import FinalExportReferenceStore
from src.r2_client import R2Client


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Strict validator for final export runs")
    parser.add_argument("--run-id", required=True, help="Export run to validate")
    parser.add_argument("--video-id", action="append", default=[], help="Restrict validation to this video ID (repeatable)")
    parser.add_argument("--skip-microshards", action="store_true", help="Skip validating microshard packs")
    parser.add_argument("--skip-final-shards", action="store_true", help="Skip validating final shard packs")
    return parser.parse_args()


async def fetch_target_video_ids(conn: asyncpg.Connection, run_id: str, explicit_video_ids: list[str]) -> list[str]:
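    """Return the explicit video IDs when provided, else every video recorded for the run."""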
    if explicit_video_ids:
        return explicit_video_ids
    rows = await conn.fetch(
        """
        SELECT video_id
        FROM final_export_video_outputs
        WHERE run_id = $1
        ORDER BY video_id
        """,
        run_id,
    )
    return [str(row["video_id"]) for row in rows]


async def fetch_pack_rows(
    conn: asyncpg.Connection,
    *,
    table: str,
    run_id: str,
    video_ids: list[str],
) -> list[dict]:
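    """Fetch pack object locations for the run from the given table.

    Microshards are written per video, so they are filtered by video_ids here;
    final shards span many videos, so all shards for the run are fetched and
    their contents are narrowed to the requested videos in load_pack_records.
    """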
    if table == "final_export_microshards":
        rows = await conn.fetch(
            """
            SELECT video_id, language, output_bucket, metadata_key, audio_key, audio_index_key, manifest_key
            FROM final_export_microshards
            WHERE run_id = $1
              AND video_id = ANY($2::text[])
            ORDER BY language, video_id, metadata_key
            """,
            run_id,
            video_ids,
        )
    else:
        rows = await conn.fetch(
            """
            SELECT output_bucket, metadata_key, audio_key, audio_index_key, manifest_key
            FROM final_export_shards
            WHERE run_id = $1
            ORDER BY language, shard_id
            """,
            run_id,
        )
    return [dict(row) for row in rows]


def normalize_meta_information(raw: str) -> dict:
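    """Normalize export provenance so exporter and validator payloads compare equal.

    worker_id and exported_at legitimately differ between runs, so they are
    reduced to presence flags before comparison.
    """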
    payload = json.loads(raw)
    export = payload.get("export_provenance", {})
    normalized = dict(payload)
    normalized["export_provenance"] = {
        "run_id": export.get("run_id"),
        "audio_sha256_type": export.get("audio_sha256_type"),
        "worker_id_present": bool(export.get("worker_id")),
        "exported_at_present": bool(export.get("exported_at")),
    }
    return normalized


def compare_rows(label: str, expected: dict, actual: dict, mismatches: list[str], key: tuple[str, str]) -> None:
    """Append a mismatch entry for every expected field whose actual value differs."""
    for field, expected_value in expected.items():
        actual_value = actual.get(field)
        if field == "meta_information":
            if normalize_meta_information(str(expected_value)) != normalize_meta_information(str(actual_value)):
                mismatches.append(f"{label}:{key}:meta_information_mismatch")
            continue
        if actual_value != expected_value:
            mismatches.append(f"{label}:{key}:{field}:{actual_value!r}!={expected_value!r}")


def load_pack_records(
    *,
    r2: FinalExportR2Client,
    rows: list[dict],
    work_root: Path,
    filter_video_ids: set[str],
) -> dict[tuple[str, str], dict]:
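    """Download each pack and extract records for the requested videos.

    Returns {(video_id, segment_id): record}, where a record bundles the
    metadata row, the audio_index row, and the FLAC bytes pulled from the
    pack's audio tar.
    """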
    records: dict[tuple[str, str], dict] = {}
    for idx, row in enumerate(rows):
        local_dir = work_root / f"pack_{idx:06d}"
        local_dir.mkdir(parents=True, exist_ok=True)
        metadata_path = local_dir / "metadata.parquet"
        audio_index_path = local_dir / "audio_index.parquet"
        audio_tar_path = local_dir / "audio.tar"
        r2.download_file(row["output_bucket"], row["metadata_key"], metadata_path)
        r2.download_file(row["output_bucket"], row["audio_index_key"], audio_index_path)
        r2.download_file(row["output_bucket"], row["audio_key"], audio_tar_path)
        metadata_rows = pq.read_table(metadata_path).to_pylist()
        audio_index_rows = pq.read_table(audio_index_path).to_pylist()
        audio_index_map = {
            (str(item["video_id"]), str(item["segment_id"])): item
            for item in audio_index_rows
        }
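        # Stream each referenced FLAC member out of the tar, keeping only the
        # requested videos.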
        with tarfile.open(audio_tar_path, "r") as tf:
            for item in metadata_rows:
                key = (str(item["video_id"]), str(item["segment_id"]))
                if key[0] not in filter_video_ids:
                    continue
                audio_index = audio_index_map.get(key)
                if audio_index is None:
                    raise RuntimeError(f"Missing audio_index row for {key}")
                member_name = str(audio_index["tar_member_name"])
                handle = tf.extractfile(member_name)
                if handle is None:
                    raise RuntimeError(f"Missing tar member {member_name} for {key}")
                flac_bytes = handle.read()
                records[key] = {
                    "metadata_row": item,
                    "audio_index_row": audio_index,
                    "audio_bytes": flac_bytes,
                }
    return records


async def build_expected_records(
    *,
    config: FinalExportConfig,
    video_ids: list[str],
) -> dict[tuple[str, str], dict]:
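    """Recompute the expected segment records for the given videos from raw sources.

    Downloads and polishes each video's raw audio, then rebuilds the export
    payloads with the same shared helpers the export pipeline uses, so they
    can be compared against what was actually uploaded.
    """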
    work_root = config.local_work_root / "validator_expected" / config.run_id
    work_root.mkdir(parents=True, exist_ok=True)
    store = FinalExportReferenceStore(config, work_root)
    raw_r2 = R2Client(config.base)
    expected: dict[tuple[str, str], dict] = {}
    try:
        await store.start()
        loop = asyncio.get_running_loop()
        for video_id in video_ids:
            video_dir = work_root / video_id
            video_dir.mkdir(parents=True, exist_ok=True)
            tar_path = await loop.run_in_executor(None, raw_r2.download_tar, video_id, video_dir)
            extracted = await loop.run_in_executor(None, raw_r2.extract_tar, tar_path, video_id)
            reference_rows = store.get_video_reference_rows(video_id)
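            # Polishing is CPU-bound, so run it in an executor; the lambda is
            # awaited immediately, making the loop-variable capture safe.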
            polished_segments = await loop.run_in_executor(
                None,
                lambda: polish_all_segments(sorted(extracted.segment_paths), max_workers=config.polish_threads),
            )
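            # Mirror the exporter's inclusion filters (discarded trims, supported
            # language, variant text, validation) so the expected set matches
            # what should have been uploaded.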
            for polished in polished_segments:
                if polished.trim_meta.discarded:
                    continue
                segment_id = replay_segment_id(
                    polished.trim_meta.original_file,
                    polished.trim_meta.was_split,
                    polished.trim_meta.split_index,
                )
                row = reference_rows.get(segment_id)
                if row is None:
                    continue
                segment_language = str(row.get("segment_language") or "").strip().lower()
                if not segment_language or segment_language not in config.supported_languages:
                    continue
                if config.require_variants:
                    native = str(row.get("native_script_text") or "")
                    romanized = str(row.get("romanized_text") or "")
                    if not native.strip() and not romanized.strip():
                        continue
                if config.require_validation and not bool(row.get("final_has_validation")):
                    continue

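                # Rebuild the payload exactly as the exporter would; the
                # placeholder worker_id/exported_at values are reduced to
                # presence flags by normalize_meta_information.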
                payload = build_export_segment_payload(
                    video_id=video_id,
                    canonical_row=row,
                    polished_segment=polished,
                    run_id=config.run_id,
                    worker_id="validator",
                    exported_at="validator",
                )
                key = (video_id, segment_id)
                expected[key] = {
                    "metadata_row": payload["metadata_row"],
                    "audio_index_row": {
                        "video_id": video_id,
                        "segment_id": segment_id,
                        "tar_member_name": payload["audio_row"]["tar_member_name"],
                        "flac_size_bytes": len(payload["audio_row"]["flac_bytes"]),
                        "flac_sha256": payload["audio_row"]["flac_sha256"],
                        "audio_duration_s": payload["audio_row"]["audio_duration_s"],
                    },
                    "audio_bytes": payload["audio_row"]["flac_bytes"],
                }
        return expected
    finally:
        await store.close()


def validate_records(
    *,
    label: str,
    expected: dict[tuple[str, str], dict],
    actual: dict[tuple[str, str], dict],
) -> list[str]:
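    """Compare expected and actual record maps, returning mismatch descriptions.

    Reports missing and extra rows, field-level differences, tar bytes whose
    hash differs from the expected flac_sha256, and raw audio byte mismatches.
    """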
    mismatches: list[str] = []
    expected_keys = set(expected)
    actual_keys = set(actual)
    for missing in sorted(expected_keys - actual_keys):
        mismatches.append(f"{label}:{missing}:missing_actual_row")
    for extra in sorted(actual_keys - expected_keys):
        mismatches.append(f"{label}:{extra}:unexpected_extra_row")
    for key in sorted(expected_keys & actual_keys):
        compare_rows(label, expected[key]["metadata_row"], actual[key]["metadata_row"], mismatches, key)
        compare_rows(label, expected[key]["audio_index_row"], actual[key]["audio_index_row"], mismatches, key)
        if sha256_bytes(actual[key]["audio_bytes"]) != expected[key]["audio_index_row"]["flac_sha256"]:
            mismatches.append(f"{label}:{key}:actual_tar_bytes_hash_mismatch")
        if actual[key]["audio_bytes"] != expected[key]["audio_bytes"]:
            mismatches.append(f"{label}:{key}:audio_bytes_mismatch")
    return mismatches


async def main_async(args: argparse.Namespace) -> None:
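    """Validate a single export run end to end, printing a JSON summary and exiting non-zero on mismatch."""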
    config = FinalExportConfig.from_env()
    config.run_id = args.run_id
    errors = config.validate_for_video_stage()
    if errors:
        raise SystemExit("\n".join(errors))

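    # statement_cache_size=0 keeps asyncpg compatible with transaction-mode
    # poolers (e.g. pgbouncer), which break prepared-statement caching.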
    conn = await asyncpg.connect(dsn=config.database_url, ssl="require", statement_cache_size=0)
    try:
        video_ids = await fetch_target_video_ids(conn, args.run_id, args.video_id)
        if not video_ids:
            raise SystemExit("No videos found for validation")
        expected = await build_expected_records(config=config, video_ids=video_ids)
        r2 = FinalExportR2Client(config)
        work_root = config.local_work_root / "validator_actual" / args.run_id
        work_root.mkdir(parents=True, exist_ok=True)

        all_mismatches: list[str] = []
        if not args.skip_microshards:
            micro_rows = await fetch_pack_rows(
                conn,
                table="final_export_microshards",
                run_id=args.run_id,
                video_ids=video_ids,
            )
            actual_micro = load_pack_records(
                r2=r2,
                rows=micro_rows,
                work_root=work_root / "microshards",
                filter_video_ids=set(video_ids),
            )
            all_mismatches.extend(validate_records(label="microshards", expected=expected, actual=actual_micro))

        if not args.skip_final_shards:
            shard_rows = await fetch_pack_rows(
                conn,
                table="final_export_shards",
                run_id=args.run_id,
                video_ids=video_ids,
            )
            actual_shards = load_pack_records(
                r2=r2,
                rows=shard_rows,
                work_root=work_root / "shards",
                filter_video_ids=set(video_ids),
            )
            all_mismatches.extend(validate_records(label="final_shards", expected=expected, actual=actual_shards))

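        # Cap reported mismatches at 200 to keep the failure output readable.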
        if all_mismatches:
            print(json.dumps({"ok": False, "mismatches": all_mismatches[:200]}, indent=2))
            raise SystemExit(1)

        print(
            json.dumps(
                {
                    "ok": True,
                    "run_id": args.run_id,
                    "video_count": len(video_ids),
                    "segment_count": len(expected),
                    "validated_microshards": not args.skip_microshards,
                    "validated_final_shards": not args.skip_final_shards,
                },
                indent=2,
            )
        )
    finally:
        await conn.close()


def main() -> None:
    args = parse_args()
    asyncio.run(main_async(args))


if __name__ == "__main__":
    main()
