from __future__ import annotations

import argparse
import json
import shutil
import sys
import tempfile
from pathlib import Path

import pyarrow.parquet as pq
from dotenv import load_dotenv

# Project root on the deployment host; prepended to sys.path so the `src.*`
# imports below resolve when this script is executed directly.
ROOT = Path("/home/ubuntu/transcripts")
sys.path.insert(0, str(ROOT))
# Load export/R2 credentials and settings from the project's .env file
# before the project modules (which read the environment) are imported.
load_dotenv(ROOT / ".env")

from src.final_export_common import sha256_string_set
from src.final_export_config import FinalExportConfig
from src.final_export_r2 import FinalExportR2Client


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the manifest audit run."""
    parser = argparse.ArgumentParser(
        description="Audit final export manifests without downloading audio.tar"
    )
    # Optional string overrides; each defaults to the environment-derived config.
    for flag in ("--run-id", "--bucket", "--prefix"):
        parser.add_argument(flag, default=None)
    parser.add_argument("--verify-parquet", action="store_true")
    parser.add_argument("--limit-manifests", type=int, default=0)
    return parser.parse_args()


def _manifest_expected_sizes(
    manifest: dict,
    metadata_key: str,
    audio_key: str,
    audio_index_key: str,
) -> dict:
    """Map each sibling object key to the byte size recorded in the manifest (0 when absent)."""
    return {
        metadata_key: int(manifest.get("metadata_size_bytes") or 0),
        audio_key: int(manifest.get("audio_tar_size_bytes") or 0),
        audio_index_key: int(manifest.get("audio_index_size_bytes") or 0),
    }


def _record_parquet_error(result: dict, error: str) -> None:
    """Flag the result as failing parquet verification and record the error code."""
    result["parquet_ok"] = False
    result["errors"].append(error)


def _verify_parquet_files(
    *,
    r2: FinalExportR2Client,
    bucket: str,
    manifest: dict,
    result: dict,
    metadata_key: str,
    audio_index_key: str,
    tmp_dir: Path,
) -> None:
    """Download both parquet files and cross-check them against the manifest.

    Checks row counts, the segment-id set hash of each file, and the summed
    FLAC byte size. Mismatches are recorded in ``result`` via
    ``_record_parquet_error``; this never raises for a mismatch.
    """
    local_metadata = tmp_dir / "metadata.parquet"
    local_audio_index = tmp_dir / "audio_index.parquet"
    r2.download_file(bucket, metadata_key, local_metadata)
    r2.download_file(bucket, audio_index_key, local_audio_index)
    metadata_rows = pq.read_table(local_metadata).to_pylist()
    audio_index_rows = pq.read_table(local_audio_index).to_pylist()
    metadata_segment_ids = [str(row["segment_id"]) for row in metadata_rows]
    audio_index_segment_ids = [str(row["segment_id"]) for row in audio_index_rows]
    # Both files must cover the exact same segment-id set the manifest hashed.
    expected_hash = str(manifest.get("segment_id_set_sha256") or "")

    if len(metadata_rows) != int(manifest.get("metadata_row_count") or 0):
        _record_parquet_error(result, "metadata_row_count_mismatch")
    if len(audio_index_rows) != int(manifest.get("audio_index_row_count") or 0):
        _record_parquet_error(result, "audio_index_row_count_mismatch")
    if sha256_string_set(metadata_segment_ids) != expected_hash:
        _record_parquet_error(result, "metadata_segment_id_hash_mismatch")
    if sha256_string_set(audio_index_segment_ids) != expected_hash:
        _record_parquet_error(result, "audio_index_segment_id_hash_mismatch")
    if sum(int(row.get("flac_size_bytes") or 0) for row in audio_index_rows) != int(manifest.get("sum_flac_bytes") or 0):
        _record_parquet_error(result, "sum_flac_bytes_mismatch")


def audit_manifest(
    *,
    r2: FinalExportR2Client,
    bucket: str,
    manifest_key: str,
    verify_parquet: bool,
    tmp_dir: Path,
) -> dict:
    """Audit a single shard manifest against the objects stored next to it.

    Downloads ``manifest.json``, compares the sizes it records against HEAD
    results for the sibling ``metadata.parquet`` / ``audio.tar`` /
    ``audio_index.parquet`` objects, and — when ``verify_parquet`` is set —
    downloads both parquet files into ``tmp_dir`` and cross-checks row counts,
    segment-id set hashes, and the summed FLAC byte size.

    Returns a result dict with ``size_ok`` / ``parquet_ok`` flags and a list
    of machine-readable error strings; mismatches never raise.
    """
    manifest = r2.download_json(bucket, manifest_key)
    # Sibling objects live under the same prefix as the manifest itself.
    base_prefix = manifest_key.rsplit("/", 1)[0]
    metadata_key = f"{base_prefix}/metadata.parquet"
    audio_key = f"{base_prefix}/audio.tar"
    audio_index_key = f"{base_prefix}/audio_index.parquet"

    result = {
        "manifest_key": manifest_key,
        # Older manifests may lack segment_count; fall back to the row count.
        "segment_count": int(manifest.get("segment_count") or manifest.get("metadata_row_count") or 0),
        "size_ok": True,
        "parquet_ok": True,
        "errors": [],
    }

    for key, expected in _manifest_expected_sizes(manifest, metadata_key, audio_key, audio_index_key).items():
        remote_size = r2.head_size(bucket, key)
        # A recorded size of 0 means "unknown" — skip rather than flag it.
        if expected and remote_size != expected:
            result["size_ok"] = False
            result["errors"].append(f"size_mismatch:{key}:{remote_size}!={expected}")

    if verify_parquet:
        _verify_parquet_files(
            r2=r2,
            bucket=bucket,
            manifest=manifest,
            result=result,
            metadata_key=metadata_key,
            audio_index_key=audio_index_key,
            tmp_dir=tmp_dir,
        )

    return result


def main():
    """Entry point: enumerate shard manifests in R2 and audit each one.

    Prints one JSON result line per manifest and exits non-zero if any
    manifest fails its size or parquet checks.
    """
    args = parse_args()
    if args.run_id:
        import os

        os.environ["FINAL_EXPORT_RUN_ID"] = args.run_id
    config = FinalExportConfig.from_env()
    r2 = FinalExportR2Client(config)
    bucket = args.bucket or config.output_bucket
    prefix = args.prefix or config.shard_prefix

    # Collect only the per-shard manifest objects under the prefix.
    manifest_keys = []
    for key in r2.list_keys(bucket, prefix):
        if key.endswith("/manifest.json"):
            manifest_keys.append(key)
    if args.limit_manifests > 0:
        del manifest_keys[args.limit_manifests:]

    scratch_root = Path(tempfile.mkdtemp(prefix="final_export_audit_"))
    failure_count = 0
    try:
        for ordinal, key in enumerate(manifest_keys, start=1):
            # Each manifest gets its own scratch subdirectory for downloads.
            workdir = scratch_root / f"{ordinal:06d}"
            workdir.mkdir(parents=True, exist_ok=True)
            outcome = audit_manifest(
                r2=r2,
                bucket=bucket,
                manifest_key=key,
                verify_parquet=args.verify_parquet,
                tmp_dir=workdir,
            )
            if not (outcome["size_ok"] and outcome["parquet_ok"]):
                failure_count += 1
            print(json.dumps(outcome, ensure_ascii=False))
    finally:
        shutil.rmtree(scratch_root, ignore_errors=True)

    if failure_count:
        raise SystemExit(1)


if __name__ == "__main__":
    main()
