from __future__ import annotations

import argparse
import asyncio
import logging
import os
import sys

from .final_export_compactor import FinalExportCompactor
from .final_export_config import FinalExportConfig


def setup_logging():
    """Configure root logging to write INFO-and-above records to stdout.

    Delegates to ``logging.basicConfig``, which does nothing if the root
    logger already has handlers attached.
    """
    line_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=line_format,
        datefmt=timestamp_format,
    )


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Final export Stage B language compactor")
    parser.add_argument("--worker-id", default=None)
    parser.add_argument("--gpu-type", default=None)
    parser.add_argument("--run-id", default=None)
    parser.add_argument("--max-shards", type=int, default=None)
    parser.add_argument("--allow-partial-shards", action="store_true")
    parser.add_argument("--languages", default=None, help="Comma-separated language filters")
    return parser.parse_args()


def _env_overrides_from_args(args: argparse.Namespace) -> dict[str, str]:
    """Map the CLI options that were actually supplied to environment variables.

    String options count as supplied only when non-empty. ``--max-shards``
    counts whenever it is not ``None``. ``--allow-partial-shards`` contributes
    ``"true"`` only when the flag is set. The insertion order of the returned
    dict matches the order in which the overrides are applied.
    """
    candidates = {
        "WORKER_ID": args.worker_id or None,
        "GPU_TYPE": args.gpu_type or None,
        "FINAL_EXPORT_RUN_ID": args.run_id or None,
        "FINAL_EXPORT_MAX_SHARDS": None if args.max_shards is None else str(args.max_shards),
        "FINAL_EXPORT_ALLOW_PARTIAL_SHARDS": "true" if args.allow_partial_shards else None,
        "FINAL_EXPORT_LANG_FILTERS": args.languages or None,
    }
    return {name: value for name, value in candidates.items() if value is not None}


def main():
    """Run the Stage B compactor: apply CLI overrides, validate config, start.

    Supplied CLI options are written into ``os.environ``, overwriting any
    existing value of those variables, before ``FinalExportConfig.from_env()``
    builds the config. Presumably ``from_env`` reads those variables; confirm
    against ``FinalExportConfig``. If ``validate_for_compactor()`` reports
    problems, each one is logged and the process exits with status 1.
    Otherwise a ``FinalExportCompactor`` is started with ``asyncio.run``.
    """
    setup_logging()
    args = parse_args()
    log = logging.getLogger("final_export_compact_main")
    os.environ.update(_env_overrides_from_args(args))

    config = FinalExportConfig.from_env()
    problems = config.validate_for_compactor()
    if problems:
        for problem in problems:
            log.error("Config error: %s", problem)
        sys.exit(1)

    log.info("Starting final export compactor %s", config.worker_id)
    log.info("  Run ID: %s", config.run_id)
    log.info("  Final shard rows: %s", config.final_shard_target_rows)
    asyncio.run(FinalExportCompactor(config).start())


# Start the compactor only when this module is executed directly, not when it
# is imported. Because of the relative imports at the top of the file, it must
# be run as a module inside its package (e.g. `python -m <package>.<module>`).
if __name__ == "__main__":
    main()
