#!/usr/bin/env python3
"""Run training with NeMo EncDecHybridRNNTCTCBPEModel.

Usage:
  python3 scripts/train_smoke.py --config configs/train/stage1_smoke.yaml --max-steps 20
  python3 scripts/train_smoke.py --config configs/train/stage1_prod_8xh200.yaml
  python3 scripts/train_smoke.py --config ... --resume-from-checkpoint path/to/last.ckpt
"""

import argparse
import time

import lightning.pytorch as pl
import nemo.collections.asr as nemo_asr
from nemo.utils.exp_manager import exp_manager
from omegaconf import OmegaConf, open_dict


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for training-config overrides."""
    parser = argparse.ArgumentParser(description="NeMo ASR training")
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--max-steps", type=int, default=None)
    parser.add_argument("--devices", type=int, default=None)
    parser.add_argument("--train-manifest", type=str, default=None)
    parser.add_argument("--val-manifest", type=str, default=None)
    parser.add_argument(
        "--input-cfg",
        type=str,
        default=None,
        help="NeMo input_cfg YAML for multi-corpus training (overrides manifest_filepath)",
    )
    parser.add_argument("--resume-from-checkpoint", type=str, default=None)
    parser.add_argument(
        "--smoke",
        action="store_true",
        default=False,
        help="Smoke mode: force val_check_interval=1.0 (epoch-based)",
    )
    return parser


def _apply_cli_overrides(cfg, args) -> None:
    """Mutate *cfg* in place with CLI overrides.

    Must be called inside an ``open_dict(cfg)`` context so that keys not
    present in the YAML (e.g. ``input_cfg``) can be added.
    """
    if args.max_steps is not None:
        cfg.trainer.max_steps = args.max_steps
    if args.devices is not None:
        cfg.trainer.devices = args.devices
        # Multi-GPU requires an explicit DDP strategy; single device lets
        # Lightning pick automatically.
        cfg.trainer.strategy = "ddp" if args.devices > 1 else "auto"
    if args.smoke:
        # Smoke mode: epoch-based validation to avoid batch count issues
        cfg.trainer.val_check_interval = 1.0
    if args.input_cfg is not None:
        # Multi-corpus mode: use input_cfg, clear manifest_filepath
        cfg.model.train_ds.input_cfg = args.input_cfg
        if "manifest_filepath" in cfg.model.train_ds:
            cfg.model.train_ds.manifest_filepath = None
    elif args.train_manifest is not None:
        cfg.model.train_ds.manifest_filepath = args.train_manifest
    if args.val_manifest is not None:
        cfg.model.validation_ds.manifest_filepath = args.val_manifest
    if args.resume_from_checkpoint is not None:
        cfg.exp_manager.resume_from_checkpoint = args.resume_from_checkpoint
        cfg.exp_manager.resume_if_exists = True


def _print_summary(cfg, args) -> None:
    """Print the effective run configuration before training starts."""
    print(f"Config: {args.config}")
    print(f"Max steps: {cfg.trainer.max_steps}")
    # input_cfg (multi-corpus) takes precedence over a plain manifest.
    train_src = cfg.model.train_ds.get("input_cfg") or cfg.model.train_ds.get("manifest_filepath")
    print(f"Train data: {train_src}")
    print(f"Val manifest: {cfg.model.validation_ds.manifest_filepath}")
    if args.resume_from_checkpoint:
        print(f"Resume from: {args.resume_from_checkpoint}")
    print()


def main():
    """Parse CLI args, build the trainer/model, and run training.

    Loads the YAML config given by ``--config``, applies CLI overrides,
    then trains an ``EncDecHybridRNNTCTCBPEModel`` and reports timing.
    """
    args = _build_arg_parser().parse_args()

    cfg = OmegaConf.load(args.config)
    with open_dict(cfg):
        _apply_cli_overrides(cfg, args)

    _print_summary(cfg, args)

    # exp_manager attaches its own loggers/checkpointing, so the Trainer's
    # default logger is disabled here.
    trainer = pl.Trainer(**cfg.trainer, logger=False)
    exp_manager(trainer, cfg.get("exp_manager", None))

    model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=trainer)
    params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {params:,}")
    print()

    # perf_counter is monotonic, so elapsed time is not skewed by system
    # clock adjustments (time.time() can jump).
    t0 = time.perf_counter()
    trainer.fit(model)
    elapsed = time.perf_counter() - t0

    steps_done = trainer.global_step
    # Guard against division by zero when no steps ran (e.g. max_steps=0).
    avg_step = elapsed / max(steps_done, 1)

    print()
    print("Training complete!")
    print(f"  Steps: {steps_done}")
    print(f"  Elapsed: {elapsed:.1f}s")
    print(f"  Avg sec/step: {avg_step:.2f}")


if __name__ == "__main__":
    main()
