# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile

import nemo_run as run
import pytest
import torch

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.api import _validate_config
from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel


class TestValidateConfig:
    """Checks that ``_validate_config`` accepts the baseline setup and rejects inconsistent ones."""

    def reset_configs(self):
        """Return a fresh ``(model, data, trainer)`` triple used as the baseline for every case.

        Each test validates this baseline unchanged before mutating it, so the
        baseline itself is expected to pass ``_validate_config``.
        """
        model = LlamaModel(config=run.Config(Llama3Config8B))
        data = llm.MockDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2)
        trainer = nl.Trainer(strategy=nl.MegatronStrategy())
        return model, data, trainer

    def _assert_rejected(self, *, model_cfg=None, data_attrs=None, strategy_attrs=None):
        """Apply attribute overrides to a fresh baseline and expect validation to fail.

        The baseline is built and mutated *outside* the ``pytest.raises`` context so
        that only an ``AssertionError`` raised by ``_validate_config`` itself can
        satisfy the expectation; a failure during setup surfaces as a test error
        instead of a false pass.

        Args:
            model_cfg: Mapping of attribute name -> value set on ``model.config``.
            data_attrs: Mapping of attribute name -> value set on the data module.
            strategy_attrs: Mapping of attribute name -> value set on ``trainer.strategy``.
        """
        model, data, trainer = self.reset_configs()
        for attr, value in (model_cfg or {}).items():
            setattr(model.config, attr, value)
        for attr, value in (data_attrs or {}).items():
            setattr(data, attr, value)
        for attr, value in (strategy_attrs or {}).items():
            setattr(trainer.strategy, attr, value)

        with pytest.raises(AssertionError):
            _validate_config(model, data, trainer)

    def test_model_validation(self):
        """Zeroing any of the checked model config fields must fail validation."""
        _validate_config(*self.reset_configs())

        self._assert_rejected(model_cfg={"seq_length": 0})
        self._assert_rejected(model_cfg={"num_layers": 0})
        self._assert_rejected(model_cfg={"hidden_size": 0})
        self._assert_rejected(model_cfg={"num_attention_heads": 0})
        self._assert_rejected(model_cfg={"ffn_hidden_size": 0})

    def test_data_validation(self):
        """Zero or inconsistent batch/sequence settings on the data module must fail validation."""
        _validate_config(*self.reset_configs())

        self._assert_rejected(data_attrs={"micro_batch_size": 0})
        self._assert_rejected(data_attrs={"global_batch_size": 0})
        self._assert_rejected(data_attrs={"seq_length": 0})
        # NOTE(review): presumably rejected because 128 is not a multiple of 3 - confirm.
        self._assert_rejected(data_attrs={"micro_batch_size": 3, "global_batch_size": 128})

    def test_trainer_validatiopn(self):
        """Parallelism settings on the trainer strategy are validated and, in some cases, reset.

        NOTE: the method name carries a typo ("validatiopn"); it is left unchanged so
        existing test node IDs keep working.
        """
        _validate_config(*self.reset_configs())

        # Basic validation
        self._assert_rejected(strategy_attrs={"tensor_model_parallel_size": 0})
        self._assert_rejected(strategy_attrs={"pipeline_model_parallel_size": 0})
        self._assert_rejected(strategy_attrs={"context_parallel_size": 0})

        # DP validation
        self._assert_rejected(
            strategy_attrs={"tensor_model_parallel_size": 8, "pipeline_model_parallel_size": 2},
        )
        self._assert_rejected(
            strategy_attrs={"tensor_model_parallel_size": 3, "pipeline_model_parallel_size": 2},
        )
        self._assert_rejected(
            data_attrs={"global_batch_size": 3, "micro_batch_size": 1},
            strategy_attrs={"tensor_model_parallel_size": 2, "pipeline_model_parallel_size": 2},
        )

        # TP/SP validation: validation is expected to switch sequence parallelism off here.
        model, data, trainer = self.reset_configs()
        trainer.strategy.tensor_model_parallel_size = 1
        trainer.strategy.sequence_parallel = True
        _validate_config(model, data, trainer)
        assert trainer.strategy.sequence_parallel is False

        # PP/VP validation
        self._assert_rejected(
            strategy_attrs={"pipeline_model_parallel_size": 2, "pipeline_dtype": None},
        )

        # With a pipeline size of 1, validation is expected to clear the VP size and pipeline dtype.
        model, data, trainer = self.reset_configs()
        trainer.strategy.pipeline_model_parallel_size = 1
        trainer.strategy.virtual_pipeline_model_parallel_size = 2
        trainer.strategy.pipeline_dtype = torch.bfloat16
        _validate_config(model, data, trainer)
        assert trainer.strategy.virtual_pipeline_model_parallel_size is None
        assert trainer.strategy.pipeline_dtype is None

        # CP validation
        self._assert_rejected(
            model_cfg={"seq_length": 5},
            strategy_attrs={"context_parallel_size": 2},
        )
        self._assert_rejected(
            model_cfg={"seq_length": 2},
            strategy_attrs={"context_parallel_size": 2},
        )

        # EP validation
        self._assert_rejected(
            model_cfg={"num_moe_experts": None},
            strategy_attrs={"expert_model_parallel_size": 2},
        )
        self._assert_rejected(
            model_cfg={"num_moe_experts": 3},
            strategy_attrs={"expert_model_parallel_size": 2},
        )


class TestImportCkpt:
    """Tests for ``llm.import_ckpt`` when the output path already exists."""

    def test_output_path_exists_no_overwrite(self):
        """Test that an error is raised when the output path exists and overwrite is set to False."""

        # Build the fixtures outside ``pytest.raises`` so that only a FileExistsError
        # coming from ``import_ckpt`` itself can satisfy the expectation
        # (``tempfile`` can raise FileExistsError on its own when no directory name is usable).
        model = llm.LlamaModel(config=llm.Llama32Config1B())
        with tempfile.TemporaryDirectory() as output_path:
            # The temporary directory already exists on disk at this point.
            with pytest.raises(FileExistsError):
                llm.import_ckpt(
                    model=model,
                    source="hf://meta-llama/Llama-3.2-1B",
                    output_path=output_path,
                    overwrite=False,
                )


class TestExportCkpt:
    """Tests for ``llm.export_ckpt`` when the output path already exists."""

    def test_output_path_exists_no_overwrite(self):
        """Test that an error is raised when the output path exists and overwrite is set to False."""

        # Create both directories outside ``pytest.raises`` so that only a FileExistsError
        # coming from ``export_ckpt`` itself can satisfy the expectation
        # (``tempfile`` can raise FileExistsError on its own when no directory name is usable).
        with tempfile.TemporaryDirectory() as output_path, tempfile.TemporaryDirectory() as path:
            # ``output_path`` already exists on disk; ``path`` is an empty directory
            # passed as the checkpoint location.
            with pytest.raises(FileExistsError):
                llm.export_ckpt(
                    path=path,
                    target="hf",
                    output_path=output_path,
                    overwrite=False,
                )
