"""
Entrypoint for transcript-variant shard workers.
"""
from __future__ import annotations

import argparse
import asyncio
import logging
import os
import sys

from .config import EnvConfig
from .variant_worker import TranscriptVariantWorker


def setup_logging():
    """Configure process-wide logging for the worker.

    Calls ``logging.basicConfig`` with an INFO level, a timestamped record
    format and ``sys.stdout`` as the stream, then raises the threshold of
    the ``httpx`` and ``httpcore`` loggers to WARNING.
    """
    root_options = {
        "level": logging.INFO,
        "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
        "stream": sys.stdout,
    }
    logging.basicConfig(**root_options)

    # Only WARNING and above get through from the HTTP client loggers.
    for noisy_logger in ("httpx", "httpcore"):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Transcript variant worker")
    parser.add_argument("--worker-id", default=None)
    parser.add_argument("--gpu-type", default=None)
    parser.add_argument("--key-index", type=int, default=None)
    parser.add_argument("--max-jobs", type=int, default=None)
    parser.add_argument("--max-rows-per-job", type=int, default=None)
    parser.add_argument("--concurrent-requests", type=int, default=None)
    parser.add_argument("--pack-target-videos", type=int, default=None)
    parser.add_argument("--pack-target-rows", type=int, default=None)
    return parser.parse_args()


# (argparse attribute, environment variable) pairs for string options.
_STRING_OVERRIDES = (
    ("worker_id", "WORKER_ID"),
    ("gpu_type", "GPU_TYPE"),
)

# (argparse attribute, environment variable) pairs for integer options.
_INT_OVERRIDES = (
    ("key_index", "GEMINI_KEY_INDEX"),
    ("max_jobs", "VARIANT_MAX_JOBS"),
    ("max_rows_per_job", "VARIANT_MAX_ROWS_PER_JOB"),
    ("concurrent_requests", "VARIANT_CONCURRENT_REQUESTS"),
    ("pack_target_videos", "VARIANT_PACK_TARGET_VIDEOS"),
    ("pack_target_rows", "VARIANT_PACK_TARGET_ROWS"),
)


def _apply_cli_overrides(args: argparse.Namespace) -> None:
    """Copy explicitly supplied CLI options into ``os.environ``."""
    for attr, env_name in _STRING_OVERRIDES:
        value = getattr(args, attr)
        # Truthiness check: an empty string is treated like "not supplied".
        if value:
            os.environ[env_name] = value

    for attr, env_name in _INT_OVERRIDES:
        value = getattr(args, attr)
        # Identity check: 0 is a real override and must not be dropped.
        if value is not None:
            os.environ[env_name] = str(value)


def _collect_config_errors(config) -> list[str]:
    """Return one message per missing required setting (none in mock mode)."""
    if config.mock_mode:
        return []

    required = (
        (config.gemini_keys, "At least one GEMINI_KEY required"),
        (config.database_url, "DATABASE_URL is required"),
        (config.r2_endpoint_url, "R2_ENDPOINT_URL is required"),
    )
    return [message for value, message in required if not value]


def main() -> None:
    """Run a transcript-variant worker from the command line.

    Steps, in order: configure logging, parse CLI options, export the
    supplied options as environment variables, build ``EnvConfig``,
    validate it, log a short startup banner, and run the worker's
    ``start()`` coroutine to completion with ``asyncio.run``.

    Raises:
        SystemExit: with status 1 when required configuration is missing
            outside mock mode; each problem is logged first.
    """
    setup_logging()
    args = parse_args()
    logger = logging.getLogger("variant_main")

    # Must happen before EnvConfig() is built; presumably EnvConfig reads
    # these variables at construction time -- confirm in .config.
    _apply_cli_overrides(args)

    config = EnvConfig()
    errors = _collect_config_errors(config)
    if errors:
        for error in errors:
            logger.error("Config error: %s", error)
        sys.exit(1)

    logger.info("Starting transcript variant worker %s", config.worker_id)
    # Validation is skipped in mock mode, so gemini_keys may be falsy here
    # (possibly None -- EnvConfig is not visible from this module); fall
    # back to an empty tuple so the banner cannot raise.
    logger.info(
        "  Gemini key index: %s (of %s keys)",
        config.gemini_key_index,
        len(config.gemini_keys or ()),
    )
    logger.info("  GPU type: %s", config.gpu_type)

    worker = TranscriptVariantWorker(config)
    asyncio.run(worker.start())


# Script entry point. The relative imports at the top of this module only
# resolve when it is executed as part of its package, i.e. with
# ``python -m <package>.<this_module>`` rather than by file path.
if __name__ == "__main__":
    main()
