# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch

import pytest
import torch

from nemo.collections.llm.bert.model.base import (
    BertConfig,
    BertModel,
    bert_forward_step,
    get_batch_on_this_cp_rank,
    get_packed_seq_params,
)


class TestBertBase:
    @pytest.fixture
    def sample_batch(self):
        return {
            "text": torch.randint(0, 1000, (32, 128)),  # [batch_size, seq_len]
            "padding_mask": torch.ones(32, 128),  # [batch_size, seq_len]
            "labels": torch.randint(0, 1000, (32, 128)),  # [batch_size, seq_len]
            "loss_mask": torch.ones(32, 128),  # [batch_size, seq_len]
            "types": torch.zeros(32, 128),  # [batch_size, seq_len]
        }

    @pytest.fixture
    def sample_packed_batch(self):
        return {
            "cu_seqlens": torch.tensor([0, 10, 25, 35, -1, -1]),
            "max_seqlen": torch.tensor(15),
        }

    @pytest.fixture
    def basic_config(self):
        return BertConfig(
            num_layers=6,
            hidden_size=768,
            num_attention_heads=12,
        )

    def test_get_batch_on_this_cp_rank_no_cp(self, sample_batch):
        with patch('megatron.core.parallel_state') as mock_parallel_state:
            mock_parallel_state.get_context_parallel_world_size.return_value = 1

            result = get_batch_on_this_cp_rank(sample_batch)

            # When context parallel size is 1, should return original batch unchanged
            assert result == sample_batch

    def test_get_packed_seq_params(self, sample_packed_batch):
        params = get_packed_seq_params(sample_packed_batch)

        # Check that cu_seqlens was properly trimmed using cu_seqlens_argmin
        assert params.cu_seqlens_q.shape[0] == 4  # Should trim at index 4
        assert params.max_seqlen_q == sample_packed_batch["max_seqlen"]
        assert params.qkv_format == "thd"

    def test_bert_config_initialization(self, basic_config):
        assert basic_config.num_layers == 6
        assert basic_config.hidden_size == 768
        assert basic_config.num_attention_heads == 12
        assert basic_config.bert_type == "megatron"  # default value
        assert basic_config.add_pooler is True  # default value

    def test_bert_model_initialization(self, basic_config):
        tokenizer = MagicMock()
        tokenizer.vocab_size = 30000

        model = BertModel(config=basic_config, tokenizer=tokenizer)

        assert model.config == basic_config
        assert model.tokenizer == tokenizer
        assert hasattr(model, "optim")

    def test_bert_forward_step(self, basic_config, sample_batch):
        model = BertModel(config=basic_config, tokenizer=None)
        model.module = MagicMock()

        bert_forward_step(model, sample_batch)

        # Verify model was called with correct arguments
        model.module.assert_called_once()
        call_args = model.module.call_args[1]
        assert "input_ids" in call_args
        assert "attention_mask" in call_args
        assert "lm_labels" in call_args
        assert "loss_mask" in call_args

    def test_bert_forward_step_with_tokentypes(self, basic_config, sample_batch):
        basic_config.num_tokentypes = 2
        model = BertModel(config=basic_config, tokenizer=None)
        model.module = MagicMock()

        result = bert_forward_step(model, sample_batch)

        # Verify tokentype_ids was included in forward call
        call_args = model.module.call_args[1]
        assert "tokentype_ids" in call_args

    def test_bert_forward_step_with_packed_seqs(self, basic_config, sample_batch):
        model = BertModel(config=basic_config, tokenizer=None)
        model.module = MagicMock()

        # Add packed sequence params to batch
        sample_batch["cu_seqlens"] = torch.tensor([0, 10, 20])

        result = bert_forward_step(model, sample_batch)

        # Verify packed_seq_params was included
        call_args = model.module.call_args[1]
        assert "packed_seq_params" in call_args

    def test_bert_model_training_step(self, basic_config):
        model = BertModel(config=basic_config, tokenizer=None)
        model.forward_step = MagicMock()

        batch = {"dummy": "batch"}
        output = model.training_step(batch)

        model.forward_step.assert_called_once_with(batch)

    def test_bert_model_validation_step(self, basic_config):
        model = BertModel(config=basic_config, tokenizer=None)
        model.forward_step = MagicMock()

        batch = {"dummy": "batch"}
        output = model.validation_step(batch)

        model.forward_step.assert_called_once_with(batch)

    def test_get_batch_with_context_parallel(self, sample_batch):
        with patch('megatron.core.parallel_state') as mock_parallel_state:
            mock_parallel_state.get_context_parallel_world_size.return_value = 2
            mock_parallel_state.get_context_parallel_rank.return_value = 0

            result = get_batch_on_this_cp_rank(sample_batch)

            # Verify batch was properly split for context parallel processing
            for key, val in result.items():
                if val is not None:
                    if key != "attention_mask":
                        assert val.shape[1] == sample_batch[key].shape[1]
                    else:
                        assert val.shape[2] == sample_batch[key].shape[2]
