"""
Entry point for the validation pipeline.
Usage:
  python -m validations.main [--max-videos N] [--mock] [--worker-id ID]
  python -m validations.main --test-local VIDEO_ID  # test on a single video
"""
from __future__ import annotations

import argparse
import asyncio
import logging
import sys
import time
from pathlib import Path

# Configure root logging once at import time so every submodule of the
# pipeline inherits the same concise, timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%H:%M:%S",
)
# Module-level logger shared by all entry-point helpers below.
logger = logging.getLogger("validations")


def parse_args():
    """Build the CLI parser for the validation pipeline and parse sys.argv."""
    parser = argparse.ArgumentParser(description="Validation pipeline: LID + CTC scoring")
    add = parser.add_argument

    # Worker / batch controls.
    add("--max-videos", type=int, default=0, help="Max videos to process (0=unlimited)")
    add("--mock", action="store_true", help="Mock mode (no R2/DB)")
    add("--recover", action="store_true", help="Run recover worker over validation_recover_queue")
    add("--worker-id", type=str, default="", help="Worker ID")

    # Single-video local test modes.
    add("--test-local", type=str, default="", help="Test single video from R2")
    add(
        "--test-recover-local",
        type=str,
        default="",
        help="Recover + validate a single video from raw 1-cleaned-data using transcription_results",
    )

    # Model selection and output layout.
    add(
        "--models", type=str, default="all",
        help="Comma-separated model list: mms,vox,conformer,wav2vec or 'all'",
    )
    add("--shard-size", type=int, default=50, help="Videos per parquet shard")
    add("--tx-parquet", type=str, default="data/transcription_results.parquet")
    add("--validation-dir", type=str, default="data")
    add("--recover-limit", type=int, default=0, help="Max missing segments to recover in local test")

    return parser.parse_args()


async def run_worker(args):
    """Build a ValidationConfig from CLI args and run the long-lived worker.

    Picks the recover worker when --recover was given, otherwise the normal
    validation worker, and awaits it until it exits.
    """
    from .config import ValidationConfig
    from .recover_worker import RecoverValidationWorker
    from .worker import ValidationWorker

    config = ValidationConfig(
        mock_mode=args.mock,
        max_videos=args.max_videos,
        worker_id=args.worker_id,
    )
    _apply_model_toggles(config, args.models)

    # Config problems are fatal outside mock mode; mock runs tolerate them.
    problems = config.validate()
    if problems and not args.mock:
        for problem in problems:
            logger.error(f"Config error: {problem}")
        sys.exit(1)

    if args.recover:
        worker = RecoverValidationWorker(config)
    else:
        worker = ValidationWorker(config)
    await worker.start()


async def test_local(video_id: str, args):
    """Test pipeline on a single video — download, process, inspect results.

    Downloads the video's transcribed tar from R2 (trying both known bucket
    layouts), runs the validation pipeline locally, logs a summary, and
    writes a test parquet shard. On success the work dir is deliberately
    kept so results can be inspected.

    Args:
        video_id: ID of the video whose `{video_id}_transcribed.tar` to fetch.
        args: Parsed CLI namespace (uses `models`).
    """
    import shutil
    import tarfile
    import tempfile

    from .config import ValidationConfig
    from .audio_loader import load_video_segments
    from .pipeline import ValidationPipeline
    from .packer import ParquetPacker

    config = ValidationConfig()
    _apply_model_toggles(config, args.models)

    logger.info(f"=== Test local: {video_id} ===")
    logger.info(
        f"Models: MMS={config.enable_mms_lid}, Vox={config.enable_voxlingua}, "
        f"Conformer={config.enable_conformer_multi}, Wav2Vec={config.enable_wav2vec_lang}"
    )

    # boto3 is imported lazily: only this test path talks to S3 directly.
    import boto3
    s3 = boto3.client(
        "s3",
        endpoint_url=config.r2_endpoint_url,
        aws_access_key_id=config.r2_access_key_id,
        aws_secret_access_key=config.r2_secret_access_key,
        region_name="auto",
    )

    work_dir = Path(tempfile.mkdtemp(prefix=f"val_test_{video_id}_"))
    tar_path = work_dir / f"{video_id}_transcribed.tar"

    key = f"{video_id}_transcribed.tar"
    t0 = time.time()
    # Try both R2 locations (tars split across buckets); probing is
    # best-effort, so per-bucket failures are deliberately swallowed.
    downloaded = False
    for bucket, obj_key in [(config.r2_bucket_source, key), ("1-cleaned-data", f"transcribed/{key}")]:
        try:
            s3.head_object(Bucket=bucket, Key=obj_key)
            logger.info(f"Downloading s3://{bucket}/{obj_key}")
            s3.download_file(bucket, obj_key, str(tar_path))
            downloaded = True
            break
        except Exception:
            continue
    if not downloaded:
        logger.error(f"Tar not found in any R2 location for {video_id}")
        # Nothing was downloaded — don't leak the empty temp dir.
        shutil.rmtree(work_dir, ignore_errors=True)
        return
    dl_time = time.time() - t0
    size_mb = tar_path.stat().st_size / 1e6
    logger.info(f"Downloaded: {size_mb:.1f}MB in {dl_time:.1f}s")

    # Extract; filter="data" rejects unsafe members (abs paths, links, devices).
    with tarfile.open(tar_path, "r:*") as tf:
        tf.extractall(work_dir, filter="data")
    tar_path.unlink()

    # Load segments
    metadata, segments = load_video_segments(work_dir, video_id)
    logger.info(f"Loaded {len(segments)} segments, language={metadata.get('language', '?')}")

    if not segments:
        logger.error("No segments found!")
        shutil.rmtree(work_dir)
        return

    # Load models and process
    pipeline = ValidationPipeline(config)
    pipeline.load_models()

    t1 = time.time()
    results = pipeline.process_video(video_id, segments)
    proc_time = time.time() - t1

    if not results:
        # Guard: summary math below divides by len(results).
        logger.error("Pipeline produced no results!")
        pipeline.unload_models()
        return

    # Summary
    logger.info(f"\n{'='*60}")
    logger.info(f"Results: {len(results)} segments in {proc_time:.1f}s")
    logger.info(f"{'='*60}")

    consensus_count = sum(1 for r in results if r.lid_consensus)
    logger.info(f"LID consensus: {consensus_count}/{len(results)} ({100*consensus_count/len(results):.0f}%)")

    # Show first few results
    for r in results[:5]:
        logger.info(
            f"  {r.segment_file}: gemini={r.gemini_lang} mms={r.mms_lang_iso1}({r.mms_confidence:.2f}) "
            f"vox={r.vox_lang_iso1}({r.vox_confidence:.2f}) consensus={r.lid_consensus} "
            f"ctc_norm={r.conformer_multi_ctc_normalized}"
        )
    if len(results) > 5:
        logger.info(f"  ... and {len(results)-5} more")

    # Write test parquet
    packer = ParquetPacker(config, work_dir / "output")
    packer.add_video_results(video_id, results)
    shard_path = packer.flush()
    if shard_path:
        logger.info(f"Test parquet: {shard_path}")

    pipeline.unload_models()
    logger.info(f"Total time: {time.time()-t0:.1f}s (download={dl_time:.1f}s, process={proc_time:.1f}s)")

    # Keep work dir for inspection
    logger.info(f"Work dir (inspect results): {work_dir}")


async def test_recover_local(video_id: str, args):
    """Recover + validate a single video from raw tar and historical tx rows.

    Replays the raw tar against historical transcription rows, validates only
    the segments missing from the local validation snapshot, and writes a
    test parquet shard. The work dir is kept on success for inspection and
    removed on failure.

    Args:
        video_id: Video to recover and validate.
        args: Parsed CLI namespace (uses `models`, `tx_parquet`,
            `validation_dir`, `recover_limit`).
    """
    import shutil
    import tempfile

    from src.config import EnvConfig
    from src.r2_client import R2Client

    from .config import ValidationConfig
    from .packer import ParquetPacker
    from .pipeline import ValidationPipeline
    from .recover_loader import load_recover_segments

    config = ValidationConfig()
    _apply_model_toggles(config, args.models)

    tx_rows = _load_tx_rows_for_video(Path(args.tx_parquet), video_id)
    if not tx_rows:
        logger.error(f"No transcription rows found for {video_id} in {args.tx_parquet}")
        return

    # Targets = segments that have historical transcriptions but no
    # validation row in the current local CSV snapshot.
    validated_ids = _load_validated_segment_ids(Path(args.validation_dir), video_id)
    tx_ids = {row["segment_file"] for row in tx_rows}
    target_ids = sorted(tx_ids - validated_ids)
    if args.recover_limit > 0:
        target_ids = target_ids[:args.recover_limit]

    logger.info(f"=== Test recover local: {video_id} ===")
    logger.info(
        f"Historical tx rows={len(tx_rows)}, currently validated={len(validated_ids)}, "
        f"missing_validation_targets={len(target_ids)}"
    )
    logger.info(
        f"Models: MMS={config.enable_mms_lid}, Vox={config.enable_voxlingua}, "
        f"Conformer={config.enable_conformer_multi}, Wav2Vec={config.enable_wav2vec_lang}"
    )

    if not target_ids:
        logger.info("Nothing missing for this video under the current local validation snapshot.")
        return

    raw_config = EnvConfig()
    r2 = R2Client(raw_config)
    work_dir = Path(tempfile.mkdtemp(prefix=f"recover_test_{video_id}_"))

    try:
        tar_path = r2.download_tar(video_id, work_dir)
        extracted = r2.extract_tar(tar_path, video_id)

        recover = load_recover_segments(
            extracted.work_dir,
            video_id,
            tx_rows,
            target_segment_ids=set(target_ids),
        )

        logger.info(
            f"Recovered target segments={len(recover.segments)}, matched_tx_ids={len(recover.matched_tx_ids)}, "
            f"missing_tx_ids={len(recover.missing_tx_ids)}, extra_regen_ids={len(recover.extra_regen_ids)}, "
            f"missing_parent_files={len(recover.missing_parent_files)}"
        )

        if recover.missing_tx_ids:
            logger.warning(f"Missing historical IDs after replay: {recover.missing_tx_ids[:10]}")

        if not recover.segments:
            logger.error("Replay produced no recoverable segments for validation")
            return

        pipeline = ValidationPipeline(config)
        pipeline.load_models()
        try:
            results = pipeline.process_video(video_id, recover.segments)
        finally:
            pipeline.unload_models()

        logger.info(f"Validation results: {len(results)} segment rows")

        # Pack in mock mode so nothing is uploaded during a local test.
        pack_config = ValidationConfig(mock_mode=True, worker_id=config.worker_id)
        packer = ParquetPacker(pack_config, work_dir / "recover_output")
        packer.add_video_results(video_id, results)
        shard_path = packer.flush()
        if shard_path:
            logger.info(f"Recover test parquet: {shard_path}")

        logger.info(f"Work dir (inspect recover output): {work_dir}")
    except Exception:
        # Keep the dir only on success; clean up before re-raising.
        shutil.rmtree(work_dir, ignore_errors=True)
        raise


def _apply_model_toggles(config, models_arg: str):
    if models_arg == "all":
        return
    models = {m.strip() for m in models_arg.split(",") if m.strip()}
    config.enable_mms_lid = "mms" in models
    config.enable_voxlingua = "vox" in models
    config.enable_conformer_multi = "conformer" in models
    config.enable_wav2vec_lang = "wav2vec" in models


def _load_tx_rows_for_video(tx_parquet: Path, video_id: str) -> list[dict]:
    """Load historical transcription rows for one video from a local parquet.

    Args:
        tx_parquet: Path to the transcription_results parquet snapshot.
        video_id: Video whose segment rows to fetch.

    Returns:
        One dict per row (keyed by column name), ordered by segment_file.

    Raises:
        FileNotFoundError: If the parquet snapshot does not exist.
    """
    import duckdb

    if not tx_parquet.exists():
        raise FileNotFoundError(f"Missing tx parquet: {tx_parquet}")

    con = duckdb.connect()
    try:
        rel = con.execute(
            """
            SELECT
                segment_file,
                expected_language_hint,
                detected_language,
                transcription,
                tagged,
                quality_score,
                speaker_emotion,
                speaker_style,
                speaker_pace,
                speaker_accent
            FROM read_parquet(?)
            WHERE video_id = ?
            ORDER BY segment_file
            """,
            [str(tx_parquet), video_id],
        )
        cols = [d[0] for d in rel.description]
        return [dict(zip(cols, row)) for row in rel.fetchall()]
    finally:
        # Close the connection even when the query fails (original leaked it).
        con.close()


def _load_validated_segment_ids(validation_dir: Path, video_id: str) -> set[str]:
    """Collect segment_file IDs already validated for *video_id*.

    Scans the golden/redo/dispose CSV snapshots under *validation_dir*;
    missing files are skipped, and an empty set is returned when none exist.
    """
    import duckdb

    csv_paths = [
        validation_dir / "golden_segments.csv",
        validation_dir / "redo_segments.csv",
        validation_dir / "dispose_segments.csv",
    ]
    existing = [str(p) for p in csv_paths if p.exists()]
    if not existing:
        return set()

    # Paths go into the SQL as string literals; escape single quotes so an
    # unusual directory name cannot break the statement.
    csv_list = ", ".join("'{}'".format(p.replace("'", "''")) for p in existing)
    con = duckdb.connect()
    try:
        rows = con.execute(
            f"""
            SELECT DISTINCT segment_file
            FROM read_csv_auto([{csv_list}], union_by_name=true, header=true)
            WHERE video_id = ?
            """,
            [video_id],
        ).fetchall()
        return {row[0] for row in rows}
    finally:
        # Close the connection even when the query fails (original leaked it).
        con.close()


def main():
    """Parse the CLI and dispatch to a test mode or the long-running worker."""
    args = parse_args()

    if args.test_local:
        coro = test_local(args.test_local, args)
    elif args.test_recover_local:
        coro = test_recover_local(args.test_recover_local, args)
    else:
        coro = run_worker(args)
    asyncio.run(coro)


if __name__ == "__main__":
    main()
