# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from nemo.core import NeuralModule
from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins
from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin


class DefaultModule(NeuralModule):
    def __init__(self):
        super().__init__()

        self.fc = torch.nn.Linear(50, 50)
        self.bn = torch.nn.BatchNorm1d(50)

    def forward(self, x):
        x = self.fc(x)
        x = self.bn(x)
        out = x
        return out

    def num_params(self):
        num: int = 0
        for p in self.parameters():
            if p.requires_grad:
                num += p.numel()
        return num


class DefaultModuleAdapter(DefaultModule, AdapterModuleMixin):
    def forward(self, x):
        x = super(DefaultModuleAdapter, self).forward(x)

        if self.is_adapter_available():
            # For testing purposes, cache the adapter names
            self._adapter_names = self.get_enabled_adapters()
            # call forward over model adapters, summing them up
            x = self.forward_enabled_adapters(x)

        return x


def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre'):
    cfg = {
        '_target_': 'nemo.collections.common.parts.adapter_modules.LinearAdapter',
        'in_features': in_features,
        'dim': dim,
        'norm_position': norm_pos,
    }
    return cfg


def get_classpath(cls):
    return f'{cls.__module__}.{cls.__name__}'


if adapter_mixins.get_registered_adapter(DefaultModule) is None:
    adapter_mixins.register_adapter(DefaultModule, DefaultModuleAdapter)


class TestAdapterMixin:
    @pytest.mark.unit
    def test_module_registered_adapter_by_class_path(self):
        classpath = get_classpath(DefaultModule)
        adapter_meta = adapter_mixins.get_registered_adapter(classpath)
        assert adapter_meta is not None
        assert adapter_meta.base_class == DefaultModule
        assert adapter_meta.adapter_class == DefaultModuleAdapter

    @pytest.mark.unit
    def test_module_registered_adapter_by_class(self):
        adapter_meta = adapter_mixins.get_registered_adapter(DefaultModule)
        assert adapter_meta is not None
        assert adapter_meta.base_class == DefaultModule
        assert adapter_meta.adapter_class == DefaultModuleAdapter

    @pytest.mark.unit
    def test_module_registered_adapter_by_adapter_class(self):
        adapter_meta = adapter_mixins.get_registered_adapter(DefaultModuleAdapter)
        assert adapter_meta is not None
        assert adapter_meta.base_class == DefaultModule
        assert adapter_meta.adapter_class == DefaultModuleAdapter

    @pytest.mark.unit
    def test_single_adapter(self):
        model = DefaultModuleAdapter()
        original_num_params = model.num_params()

        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg())
        new_num_params = model.num_params()
        assert new_num_params > original_num_params

    @pytest.mark.unit
    def test_multiple_adapter(self):
        model = DefaultModuleAdapter()
        original_num_params = model.num_params()

        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg())
        new_num_params = model.num_params()
        assert new_num_params > original_num_params

        original_num_params = new_num_params
        model.add_adapter(name='adapter_1', cfg=get_adapter_cfg())
        new_num_params = model.num_params()
        assert new_num_params > original_num_params

    @pytest.mark.unit
    def test_forward_linear_pre(self):
        torch.random.manual_seed(0)
        x = torch.randn(2, 50)

        model = DefaultModuleAdapter()
        origial_output = model(x)

        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg())
        new_output = model(x)

        assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5

    @pytest.mark.unit
    def test_forward_linear_post(self):
        torch.random.manual_seed(0)
        x = torch.randn(2, 50)

        model = DefaultModuleAdapter()
        origial_output = model(x)

        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg(norm_pos='post'))
        new_output = model(x)

        assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5

    @pytest.mark.unit
    def test_multi_adapter_forward(self):
        torch.random.manual_seed(0)
        x = torch.randn(2, 50)

        model = DefaultModuleAdapter()
        origial_output = model(x)

        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg())
        model.add_adapter(name='adapter_1', cfg=get_adapter_cfg())
        new_output = model(x)

        assert model._adapter_names == ['adapter_0', 'adapter_1']
        assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5

    @pytest.mark.unit
    def test_multi_adapter_partial_forward(self):
        torch.random.manual_seed(0)
        x = torch.randn(2, 50)

        model = DefaultModuleAdapter()
        origial_output = model(x)

        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg())
        model.add_adapter(name='adapter_1', cfg=get_adapter_cfg())

        model.set_enabled_adapters(name='adapter_0', enabled=False)
        new_output = model(x)

        assert model._adapter_names == ['adapter_1']
        assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5

    @pytest.mark.unit
    def test_forward_unfrozen_adapters(self):
        model = DefaultModuleAdapter()
        original_num_params = model.num_params()

        dim = 10
        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg(dim=dim))
        model.freeze()
        model.unfreeze_enabled_adapters()

        assert original_num_params == 2650

        original_params = 0
        adapter_params = 0
        for name, param in model.named_parameters():
            if 'adapter' not in name:
                assert param.requires_grad is False
                original_params += param.numel()
            else:
                assert param.requires_grad is True
                adapter_params += param.numel()

        for mname, module in model.named_modules():
            if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
                assert module.track_running_stats is False

        assert original_params > adapter_params

    @pytest.mark.unit
    def test_forward_linear_no_strategy(self):
        torch.random.manual_seed(0)
        x = torch.randn(2, 50)

        model = DefaultModuleAdapter()
        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg())

        # delete the strategy
        adapter_module = model.adapter_layer[model.get_enabled_adapters()[0]]
        del adapter_module.adapter_strategy

        with pytest.raises(AttributeError):
            _ = model(x)

    @pytest.mark.unit
    def test_forward_linear_replaced_strategy(self):
        class MultiplyAdapterStrategy(adapter_mixin_strategies.AbstractAdapterStrategy):
            def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module: AdapterModuleMixin):
                out = adapter(input)
                return input * out

        torch.random.manual_seed(0)
        x = torch.randn(2, 50)

        model = DefaultModuleAdapter()
        model.add_adapter(name='adapter_0', cfg=get_adapter_cfg())

        # modify the strategy
        adapter_module = model.adapter_layer[model.get_enabled_adapters()[0]]
        adapter_module.adapter_strategy = MultiplyAdapterStrategy()

        out = model(x)
        # result of adapter is zero tensor, output multiplied by adapter result should be zero
        assert (out > 0.0).any() == torch.tensor(False)
