"""
Entry point for the validation pipeline.
Usage:
  python -m validations.main [--max-videos N] [--mock] [--worker-id ID]
  python -m validations.main --test-local VIDEO_ID  # test on a single video
"""
from __future__ import annotations

import argparse
import asyncio
import logging
import sys
import time
from pathlib import Path

# Configure root logging once at import time so every pipeline module that
# logs via `logging.getLogger(...)` inherits the same compact
# "HH:MM:SS [LEVEL] name: message" format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("validations")


def parse_args():
    p = argparse.ArgumentParser(description="Validation pipeline: LID + CTC scoring")
    p.add_argument("--max-videos", type=int, default=0, help="Max videos to process (0=unlimited)")
    p.add_argument("--mock", action="store_true", help="Mock mode (no R2/DB)")
    p.add_argument("--worker-id", type=str, default="", help="Worker ID")
    p.add_argument("--test-local", type=str, default="", help="Test single video from R2")
    p.add_argument(
        "--models", type=str, default="all",
        help="Comma-separated model list: mms,vox,conformer,wav2vec or 'all'",
    )
    p.add_argument("--shard-size", type=int, default=50, help="Videos per parquet shard")
    return p.parse_args()


async def run_worker(args):
    """Build a ValidationConfig from CLI args and run the validation worker.

    Exits the process with status 1 when the config fails validation
    (unless running in mock mode, where config errors are tolerated).
    """
    from .config import ValidationConfig
    from .worker import ValidationWorker

    config = ValidationConfig(
        mock_mode=args.mock,
        max_videos=args.max_videos,
        worker_id=args.worker_id,
    )

    # Enable only the models named on the CLI; "all" leaves config defaults.
    if args.models != "all":
        selected = set(args.models.split(","))
        toggles = {
            "enable_mms_lid": "mms",
            "enable_voxlingua": "vox",
            "enable_conformer_multi": "conformer",
            "enable_wav2vec_lang": "wav2vec",
        }
        for attr, token in toggles.items():
            setattr(config, attr, token in selected)

    problems = config.validate()
    if problems and not args.mock:
        for problem in problems:
            logger.error(f"Config error: {problem}")
        sys.exit(1)

    await ValidationWorker(config).start()


def _apply_model_toggles(config, models_arg: str) -> None:
    """Enable only the models named in comma-separated *models_arg* ("all" = leave defaults)."""
    if models_arg == "all":
        return
    models = set(models_arg.split(","))
    config.enable_mms_lid = "mms" in models
    config.enable_voxlingua = "vox" in models
    config.enable_conformer_multi = "conformer" in models
    config.enable_wav2vec_lang = "wav2vec" in models


def _download_tar(config, video_id: str, tar_path: Path) -> bool:
    """Download ``{video_id}_transcribed.tar`` from R2 into *tar_path*.

    Tars are split across buckets, so both known locations are probed in
    order. Returns True on success, False when the tar is not found anywhere.
    """
    import boto3
    s3 = boto3.client(
        "s3",
        endpoint_url=config.r2_endpoint_url,
        aws_access_key_id=config.r2_access_key_id,
        aws_secret_access_key=config.r2_secret_access_key,
        region_name="auto",
    )
    key = f"{video_id}_transcribed.tar"
    for bucket, obj_key in [(config.r2_bucket_source, key), ("1-cleaned-data", f"transcribed/{key}")]:
        try:
            # head_object first so a missing key fails fast without a download attempt.
            s3.head_object(Bucket=bucket, Key=obj_key)
            logger.info(f"Downloading s3://{bucket}/{obj_key}")
            s3.download_file(bucket, obj_key, str(tar_path))
            return True
        except Exception:
            continue
    return False


def _log_summary(results, proc_time: float) -> None:
    """Log an overview of per-segment validation results (first 5 shown in detail)."""
    logger.info(f"\n{'='*60}")
    logger.info(f"Results: {len(results)} segments in {proc_time:.1f}s")
    logger.info(f"{'='*60}")

    # Guard: an empty result set previously raised ZeroDivisionError here.
    if results:
        consensus_count = sum(1 for r in results if r.lid_consensus)
        logger.info(f"LID consensus: {consensus_count}/{len(results)} ({100*consensus_count/len(results):.0f}%)")

    for r in results[:5]:
        logger.info(
            f"  {r.segment_file}: gemini={r.gemini_lang} mms={r.mms_lang_iso1}({r.mms_confidence:.2f}) "
            f"vox={r.vox_lang_iso1}({r.vox_confidence:.2f}) consensus={r.lid_consensus} "
            f"ctc_norm={r.conformer_multi_ctc_normalized}"
        )
    if len(results) > 5:
        logger.info(f"  ... and {len(results)-5} more")


async def test_local(video_id: str, args) -> None:
    """Test the pipeline end-to-end on a single video — download, process, inspect.

    Downloads the transcribed tar from R2, extracts it, loads its segments,
    runs the full validation pipeline, logs a summary, and writes a test
    parquet shard. The work directory is kept on success for inspection but
    removed on failure paths.

    Args:
        video_id: Identifier used to locate ``{video_id}_transcribed.tar``.
        args: Parsed CLI namespace (reads ``args.models``).
    """
    import shutil
    import tarfile
    import tempfile

    from .config import ValidationConfig
    from .audio_loader import load_video_segments
    from .pipeline import ValidationPipeline
    from .packer import ParquetPacker

    config = ValidationConfig()
    _apply_model_toggles(config, args.models)

    logger.info(f"=== Test local: {video_id} ===")
    logger.info(
        f"Models: MMS={config.enable_mms_lid}, Vox={config.enable_voxlingua}, "
        f"Conformer={config.enable_conformer_multi}, Wav2Vec={config.enable_wav2vec_lang}"
    )

    work_dir = Path(tempfile.mkdtemp(prefix=f"val_test_{video_id}_"))
    tar_path = work_dir / f"{video_id}_transcribed.tar"

    t0 = time.time()
    if not _download_tar(config, video_id, tar_path):
        logger.error(f"Tar not found in any R2 location for {video_id}")
        # Previously this path leaked the temp work dir; clean up like the
        # "no segments" path does.
        shutil.rmtree(work_dir, ignore_errors=True)
        return
    dl_time = time.time() - t0
    size_mb = tar_path.stat().st_size / 1e6
    logger.info(f"Downloaded: {size_mb:.1f}MB in {dl_time:.1f}s")

    # Extract, then delete the tar to free disk before model work.
    # The "data" filter rejects path-traversal and other unsafe members.
    with tarfile.open(tar_path, "r:*") as tf:
        tf.extractall(work_dir, filter="data")
    tar_path.unlink()

    metadata, segments = load_video_segments(work_dir, video_id)
    logger.info(f"Loaded {len(segments)} segments, language={metadata.get('language', '?')}")

    if not segments:
        logger.error("No segments found!")
        shutil.rmtree(work_dir)
        return

    # Load models and process
    pipeline = ValidationPipeline(config)
    pipeline.load_models()

    t1 = time.time()
    results = pipeline.process_video(video_id, segments)
    proc_time = time.time() - t1

    _log_summary(results, proc_time)

    # Write a test parquet shard so the output schema can be inspected too.
    packer = ParquetPacker(config, work_dir / "output")
    packer.add_video_results(video_id, results)
    shard_path = packer.flush()
    if shard_path:
        logger.info(f"Test parquet: {shard_path}")

    pipeline.unload_models()
    logger.info(f"Total time: {time.time()-t0:.1f}s (download={dl_time:.1f}s, process={proc_time:.1f}s)")

    # Keep work dir for inspection
    logger.info(f"Work dir (inspect results): {work_dir}")


def main() -> None:
    """CLI entry point: run single-video test mode or the worker loop."""
    args = parse_args()
    coro = test_local(args.test_local, args) if args.test_local else run_worker(args)
    asyncio.run(coro)


if __name__ == "__main__":
    main()
