# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path

import torch
from megatron.core.dist_checkpointing import load_content_metadata
from megatron.core.distributed import DistributedDataParallelConfig as McoreDDPConfig
from megatron.core.transformer.enums import AttnBackend

from nemo.collections.common.tokenizers.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.llm import MixtralConfig8x3B, MixtralModel, PreTrainingDataModule
from nemo.collections.llm.api import train
from nemo.lightning import MegatronStrategy, NeMoLogger, Trainer
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule as MegatronOptim
from nemo.lightning.pytorch.optim.megatron import OptimizerConfig


def tokenizer(vocab_path, merges_path):
    """Build the GPT-2 BPE tokenizer used by this script via ``get_nmt_tokenizer``.

    Args:
        vocab_path: Forwarded to ``get_nmt_tokenizer`` as ``vocab_file``.
        merges_path: Forwarded to ``get_nmt_tokenizer`` as ``merges_file``.

    Returns:
        The object returned by ``get_nmt_tokenizer`` for the ``"megatron"``
        library and the ``"GPT2BPETokenizer"`` model name.
    """
    bpe_files = {"vocab_file": vocab_path, "merges_file": merges_path}
    return get_nmt_tokenizer("megatron", "GPT2BPETokenizer", **bpe_files)


def load_dcp(ckpt_dir, torch_tensor=True):
    from pathlib import Path

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemReader

    if not isinstance(ckpt_dir, Path):
        ckpt_dir = Path(ckpt_dir)
    fs_reader = FileSystemReader(ckpt_dir)
    metadata = fs_reader.read_metadata()

    state_dict = {
        k: torch.empty(tp.size, dtype=tp.properties.dtype)
        for k, tp in metadata.state_dict_metadata.items()
        if type(tp).__name__ == 'TensorStorageMetadata'
    }

    dcp.load(
        state_dict,
        storage_reader=fs_reader,
    )
    return state_dict


def _expected_checkpoint_spec():
    """Return the expected ``{key: (shape, dtype, device)}`` layout of the saved checkpoint.

    Model weights are expected in bfloat16. For every weight there are three
    float32 optimizer entries (``fp32_param``, ``exp_avg``, ``exp_avg_sq``)
    that share one shape. All entries are expected on ``"cpu"``.
    """
    # (parameter key, model-weight shape, optimizer-state shape)
    params = (
        ("module.embedding.word_embeddings.weight", (50304, 128), (1, 1, 6438912)),
        ("module.decoder.layers.self_attention.linear_proj.weight", (2, 128, 128), (2, 1, 1, 16384)),
        ("module.decoder.layers.self_attention.linear_qkv.layer_norm_weight", (2, 128), (2, 1, 128)),
        ("module.decoder.layers.self_attention.linear_qkv.weight", (2, 384, 128), (2, 1, 1, 49152)),
        ("module.decoder.layers.pre_mlp_layernorm.weight", (2, 128), (2, 1, 128)),
        ("module.decoder.layers.mlp.router.weight", (2, 8, 128), (2, 1, 1, 1024)),
        ("module.decoder.layers.mlp.experts.experts.linear_fc1.weight", (2, 8, 640, 128), (2, 8, 2, 1, 40960)),
        ("module.decoder.layers.mlp.experts.experts.linear_fc2.weight", (2, 8, 128, 320), (2, 8, 1, 1, 40960)),
        ("module.decoder.final_layernorm.weight", (128,), (128,)),
        ("module.output_layer.weight", (50304, 128), (1, 1, 6438912)),
    )

    expected = {name: (torch.Size(model_shape), torch.bfloat16, "cpu") for name, model_shape, _ in params}
    # Optimizer entries are added in reverse parameter order so that the
    # resulting dict iterates in the same order as the hand-written table
    # this helper replaces.
    for name, _, optim_shape in reversed(params):
        for state in ("fp32_param", "exp_avg", "exp_avg_sq"):
            expected[f"optimizer.state.{state}.{name}"] = (torch.Size(optim_shape), torch.float32, "cpu")
    return expected


def _assert_checkpoint_files(weights_dir):
    """Assert that ``weights_dir`` is a directory holding exactly the expected, readable, non-empty files."""
    assert weights_dir.exists(), f"Expected {weights_dir} to exist"
    assert weights_dir.is_dir(), f"Expected {weights_dir} to be a directory"
    output_files = ['__0_0.distcp', '__0_1.distcp', 'common.pt', 'metadata.json', '.metadata']
    for file in output_files:
        path = weights_dir / file
        assert path.exists(), f"Expected {file} to exist"
        assert path.is_file(), f"Expected {file} to be a file"
        assert os.access(path, os.R_OK), f"Expected {file} to be readable"
        assert path.stat().st_size, f"Expected {file} to be non-empty"

    # Nothing besides the expected files may be present.
    for file in os.listdir(weights_dir):
        assert file in output_files, f"Got unexpected {file} in checkpoint directory"


def _assert_checkpoint_contents(weights_dir, expected_ckpt):
    """Load the checkpoint in ``weights_dir`` and compare it against ``expected_ckpt``.

    ``expected_ckpt`` maps each key to a ``(shape, dtype, device)`` tuple. When
    the checkpoint content metadata reports the ``dp_reshardable`` optimizer
    sharding type, optimizer entries are only checked structurally and the
    exact comparison is limited to the remaining entries.
    """
    ckpt = load_dcp(weights_dir)

    # Handle new optimizer format
    content_metadata = load_content_metadata(weights_dir)
    if content_metadata and content_metadata.get('distrib_optim_sharding_type') == 'dp_reshardable':
        optim_keys = {k for k in ckpt if k.startswith(('optimizer', 'chained_'))}
        # Sorted so that the first reported failure is deterministic.
        for optim_key in sorted(optim_keys):
            assert optim_key.split(".")[-1] in (
                "param",
                "exp_avg",
                "exp_avg_sq",
            ), f"Unexpected optimizer state name in key: {optim_key}"
            assert (
                "dp_group_idx" in optim_key and "gbuf_idx" in optim_key and "bucket_idx" in optim_key
            ), f"Unexpected dp_reshardable optimizer key structure: {optim_key}"
            # we can't check the exact size because it differs for different devices num
            assert len(ckpt[optim_key].shape) == 1, f"Expected {optim_key} to be 1-dimensional"

        # Trim state dicts for the rest of the checks to only compare model parts
        ckpt = {k: v for k, v in ckpt.items() if k not in optim_keys}
        expected_ckpt = {k: v for k, v in expected_ckpt.items() if not k.startswith('optimizer')}

    ckpt_keys = set(ckpt)
    expected_keys = set(expected_ckpt)
    # Report both directions of the difference so a mismatch can be diagnosed
    # from the assertion message alone.
    assert len(ckpt) == len(expected_ckpt), (
        f"Checkpoint length mismatch: got {len(ckpt)} entries, expected {len(expected_ckpt)}; "
        f"unexpected keys: {sorted(ckpt_keys - expected_keys)}; "
        f"missing keys: {sorted(expected_keys - ckpt_keys)}"
    )
    for key, (shape, dtype, device) in expected_ckpt.items():
        assert key in ckpt, f"Expected {key} to be in ckpt"
        tensor = ckpt[key]
        assert isinstance(tensor, torch.Tensor), f"Expected {key} to be a tensor"

        if len(shape) == 1 and key.startswith('optimizer.state'):
            # 1-D optimizer-state entries are compared against a (1, N) shape.
            assert tensor.shape == (
                1,
                shape[0],
            ), f"Expected {key} shapes to match {tensor.shape} & (1, {shape[0]})"
        else:
            assert tensor.shape == shape, f"Expected {key} shapes to match {tensor.shape} & {shape}"

        assert tensor.dtype == dtype, f"Expected {key} dtype to match {tensor.dtype} & {dtype}"
        assert str(tensor.device) == device, f"Expected {key} device to match {tensor.device} & {device}"


def main(args):
    """Train a tiny Mixtral model and validate the checkpoint it writes.

    Args:
        args: Parsed CLI namespace. The attributes read are ``devices``,
            ``max_steps``, ``data_path``, ``vocab_path``, ``merges_path``,
            ``experiment_name`` and ``experiment_dir``.

    Raises:
        AssertionError: If ``experiment_dir`` already exists before training,
            or if the checkpoint directory layout or contents differ from
            what this test expects.
    """
    strategy = MegatronStrategy(
        expert_model_parallel_size=args.devices,
        tensor_model_parallel_size=1,
        sequence_parallel=False,
        context_parallel_size=1,
        params_dtype=torch.bfloat16,
        pipeline_dtype=torch.bfloat16,
        autocast_dtype=torch.float32,
        precision=torch.bfloat16,
        ddp=McoreDDPConfig(
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=False,
            use_distributed_optimizer=True,
            check_for_nan_in_grad=True,
            bucket_size=None,
        ),
    )

    trainer = Trainer(
        log_every_n_steps=1,
        devices=args.devices,
        max_steps=args.max_steps,
        accelerator="gpu",
        strategy=strategy,
        num_sanity_val_steps=0,
        logger=None,
        limit_val_batches=1,
    )

    data = PreTrainingDataModule(
        args.data_path,
        seq_length=512,
        global_batch_size=2,
        micro_batch_size=1,
        num_workers=1,
        split='99,1,0',
        tokenizer=tokenizer(args.vocab_path, args.merges_path),
    )

    # Start from the 8x3B config and shrink it to a 2-layer, 128-hidden model.
    mixtral_config = MixtralConfig8x3B(
        num_layers=2,
        hidden_size=128,
        num_attention_heads=8,
        num_query_groups=8,
        ffn_hidden_size=320,
        kv_channels=16,
        init_method_std=0.015,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        layernorm_epsilon=1e-5,
        make_vocab_size_divisible_by=128,
        max_position_embeddings=512,
        bf16=True,
        params_dtype=torch.bfloat16,
        pipeline_dtype=torch.bfloat16,
        attention_backend=AttnBackend.unfused,
    )
    mixtral_config.overlap_param_gather_with_optimizer_step = True

    optim_config = OptimizerConfig(
        fp16=False,
        bf16=True,
        params_dtype=torch.bfloat16,
        lr=0.01,
        weight_decay=0,
        adam_beta1=0.9,
        adam_beta2=0.9,
        clip_grad=0.0,
        use_distributed_optimizer=True,
        min_lr=0.0,
        log_num_zeros_in_grad=True,
        barrier_with_L1_time=True,
    )

    opt = MegatronOptim(config=optim_config)
    model = MixtralModel(mixtral_config, optim=opt, tokenizer=data.tokenizer)

    nemo_logger = NeMoLogger(
        name=args.experiment_name,
        use_datetime_version=False,
        explicit_log_dir=args.experiment_dir,
    )

    # The run must start from a clean slate so that everything checked below
    # was produced by this invocation.
    output_path = Path(args.experiment_dir)
    assert not output_path.exists(), f"Did not expect {output_path} to exist"

    train(
        model=model,
        resume=None,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        tokenizer='data',
        optim=opt,
    )

    # Confirm checkpoint directory structure
    # NOTE(review): the checkpoint name is hard-coded; "consumed_samples=8.0"
    # presumably corresponds to the default 4 steps x global batch size 2, so
    # a non-default --max-steps would likely break this path. TODO confirm.
    weights_dir = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0-consumed_samples=8.0/weights"
    _assert_checkpoint_files(weights_dir)

    # Finally confirm checkpoint contents
    _assert_checkpoint_contents(weights_dir, _expected_checkpoint_spec())


def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``devices``, ``max_steps``, ``experiment_dir``,
        ``experiment_name``, ``data_path``, ``vocab_path`` and ``merges_path``.
    """
    # One (flag, add_argument keyword arguments) pair per supported option.
    option_specs = (
        ('--devices', {'type': int, 'default': 1, 'help': "Number of devices to use for training"}),
        ('--max-steps', {'type': int, 'default': 4, 'help': "Number of steps to train for"}),
        (
            '--experiment-dir',
            {'type': str, 'default': '/tmp/exp_dir', 'help': "directory to write results and checkpoints to"},
        ),
        ('--experiment-name', {'type': str, 'default': 'mini_mixtral_test', 'help': "name of experiment"}),
        ('--data-path', {'type': str, 'help': "Path to data file"}),
        ('--vocab-path', {'type': str, 'default': None, 'help': "Path to vocab file"}),
        ('--merges-path', {'type': str, 'default': None, 'help': "Path to merges file"}),
    )

    cli = argparse.ArgumentParser(description='Train a small Mixtral model using NeMo 2.0')
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()


# Script entry point: parse the command-line options and run the
# train-then-verify flow implemented in main().
if __name__ == "__main__":
    main(parse_args())
