# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import numpy as np
import pytest
import torch
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.models import EncDecCTCModel

try:
    from eff.cookbooks import NeMoCookbook

    _EFF_PRESENT_ = True
except ImportError:
    _EFF_PRESENT_ = False

# A decorator marking the EFF requirement.
requires_eff = pytest.mark.skipif(not _EFF_PRESENT_, reason="Export File Format library required to run test")


@pytest.fixture()
def asr_model():
    preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})}
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [
                {
                    'filters': 1024,
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }
            ],
        },
    }

    decoder = {
        'cls': 'nemo.collections.asr.modules.ConvASRDecoder',
        'params': {
            'feat_in': 1024,
            'num_classes': 28,
            'vocabulary': [
                ' ',
                'a',
                'b',
                'c',
                'd',
                'e',
                'f',
                'g',
                'h',
                'i',
                'j',
                'k',
                'l',
                'm',
                'n',
                'o',
                'p',
                'q',
                'r',
                's',
                't',
                'u',
                'v',
                'w',
                'x',
                'y',
                'z',
                "'",
            ],
        },
    }
    modelConfig = DictConfig(
        {'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder)}
    )

    model_instance = EncDecCTCModel(cfg=modelConfig)
    return model_instance


class TestFileIO:
    @pytest.mark.unit
    def test_to_from_config_file(self, asr_model):
        """" Test makes sure that the second instance created with the same configuration (BUT NOT checkpoint)
        has different weights. """

        with tempfile.NamedTemporaryFile() as fp:
            yaml_filename = fp.name
            asr_model.to_config_file(path2yaml_file=yaml_filename)
            next_instance = EncDecCTCModel.from_config_file(path2yaml_file=yaml_filename)

            assert isinstance(next_instance, EncDecCTCModel)

            assert len(next_instance.decoder.vocabulary) == 28
            assert asr_model.num_weights == next_instance.num_weights

            w1 = asr_model.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()
            w2 = next_instance.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()

            assert not np.array_equal(w1, w2)

    @pytest.mark.unit
    def test_save_restore_from_nemo_file(self, asr_model):
        """" Test makes sure that the second instance created from the same configuration AND checkpoint 
        has the same weights. """

        with tempfile.NamedTemporaryFile() as fp:
            filename = fp.name

            # Save model (with random artifact).
            with tempfile.NamedTemporaryFile() as artifact:
                asr_model.register_artifact(config_path="abc", src=artifact.name)
                asr_model.save_to(save_path=filename)

            # Restore the model.
            asr_model2 = EncDecCTCModel.restore_from(restore_path=filename)

            assert len(asr_model.decoder.vocabulary) == len(asr_model2.decoder.vocabulary)
            assert asr_model.num_weights == asr_model2.num_weights

            w1 = asr_model.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()
            w2 = asr_model2.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()

            assert np.array_equal(w1, w2)

    @requires_eff
    @pytest.mark.unit
    def test_eff_save_restore_from_nemo_file_encrypted(self, asr_model):
        """" Test makes sure that after encrypted save-restore the model has the same weights. """

        with tempfile.NamedTemporaryFile() as fp:
            filename = fp.name

            # Set key - use checkpoint encryption.
            NeMoCookbook.set_encryption_key("test_key")

            # Save model (with random artifact).
            with tempfile.NamedTemporaryFile() as artifact:
                asr_model.register_artifact(config_path="abc", src=artifact.name)
                asr_model.save_to(save_path=filename)

            # Try to restore the encrypted archive (weights) without the encryption key.
            NeMoCookbook.set_encryption_key(None)
            with pytest.raises(PermissionError):
                # Restore the model.
                asr_model2 = EncDecCTCModel.restore_from(restore_path=filename)

            # Restore the model.
            NeMoCookbook.set_encryption_key("test_key")
            asr_model3 = EncDecCTCModel.restore_from(restore_path=filename)
            # Reset encryption so it won't mess up with other save/restore.
            NeMoCookbook.set_encryption_key(None)

            assert asr_model.num_weights == asr_model3.num_weights

    @pytest.mark.unit
    def test_save_restore_from_nemo_file_with_override(self, asr_model, tmpdir):
        """" Test makes sure that the second instance created from the same configuration AND checkpoint
        has the same weights.

        Args:
            tmpdir: fixture providing a temporary directory unique to the test invocation.
        """
        # Name of the archive in tmp folder.
        filename = os.path.join(tmpdir, "eff.nemo")

        # Get path where the command is executed - the artifacts will be "retrieved" there.
        # (original .nemo behavior)
        cwd = os.getcwd()

        with tempfile.NamedTemporaryFile(mode='a+') as conf_fp:

            # Create a "random artifact".
            with tempfile.NamedTemporaryFile(mode="w", delete=False) as artifact:
                artifact.write("magic content 42")
            # Remember the filename of the artifact.
            _, artifact_filename = os.path.split(artifact.name)
            # Add artifact to model.
            asr_model.register_artifact(config_path="abc", src=artifact.name)
            # Save model (with "random artifact").
            asr_model.save_to(save_path=filename)

            # Modify config slightly
            cfg = asr_model.cfg
            cfg.encoder.activation = 'swish'
            yaml_cfg = OmegaConf.to_yaml(cfg)
            conf_fp.write(yaml_cfg)
            conf_fp.seek(0)

            # Restore the model.
            asr_model2 = EncDecCTCModel.restore_from(restore_path=filename, override_config_path=conf_fp.name)

            assert len(asr_model.decoder.vocabulary) == len(asr_model2.decoder.vocabulary)
            assert asr_model.num_weights == asr_model2.num_weights

            w1 = asr_model.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()
            w2 = asr_model2.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()

            assert np.array_equal(w1, w2)

            assert asr_model2.cfg.encoder.activation == 'swish'

    @pytest.mark.unit
    def test_save_model_level_pt_ckpt(self, asr_model):
        with tempfile.TemporaryDirectory() as ckpt_dir:
            nemo_file = os.path.join(ckpt_dir, 'asr.nemo')
            asr_model.save_to(nemo_file)

            # Save model level PT checkpoint
            asr_model.extract_state_dict_from(nemo_file, ckpt_dir)
            ckpt_path = os.path.join(ckpt_dir, 'model_weights.ckpt')

            assert os.path.exists(ckpt_path)

            # Restore the model.
            asr_model2 = EncDecCTCModel.restore_from(restore_path=nemo_file)

            assert len(asr_model.decoder.vocabulary) == len(asr_model2.decoder.vocabulary)
            assert asr_model.num_weights == asr_model2.num_weights

            # Change weights values
            asr_model2.encoder.encoder[0].mconv[0].conv.weight.data += 1.0

            w1 = asr_model.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()
            w2 = asr_model2.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()

            assert not np.array_equal(w1, w2)

            # Restore from checkpoint
            asr_model2.load_state_dict(torch.load(ckpt_path))

            w1 = asr_model.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()
            w2 = asr_model2.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()

            assert np.array_equal(w1, w2)

    @pytest.mark.unit
    def test_save_module_level_pt_ckpt(self, asr_model):
        with tempfile.TemporaryDirectory() as ckpt_dir:
            nemo_file = os.path.join(ckpt_dir, 'asr.nemo')
            asr_model.save_to(nemo_file)

            # Save model level PT checkpoint
            asr_model.extract_state_dict_from(nemo_file, ckpt_dir, split_by_module=True)
            encoder_path = os.path.join(ckpt_dir, 'encoder.ckpt')
            decoder_path = os.path.join(ckpt_dir, 'decoder.ckpt')
            preprocessor_path = os.path.join(ckpt_dir, 'preprocessor.ckpt')

            assert os.path.exists(encoder_path)
            assert os.path.exists(decoder_path)
            assert os.path.exists(preprocessor_path)

            # Restore the model.
            asr_model2 = EncDecCTCModel.restore_from(restore_path=nemo_file)

            assert len(asr_model.decoder.vocabulary) == len(asr_model2.decoder.vocabulary)
            assert asr_model.num_weights == asr_model2.num_weights

            # Change weights values
            asr_model2.encoder.encoder[0].mconv[0].conv.weight.data += 1.0

            w1 = asr_model.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()
            w2 = asr_model2.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()

            assert not np.array_equal(w1, w2)

            # Restore from checkpoint
            asr_model2.encoder.load_state_dict(torch.load(encoder_path))

            w1 = asr_model.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()
            w2 = asr_model2.encoder.encoder[0].mconv[0].conv.weight.data.detach().cpu().numpy()

            assert np.array_equal(w1, w2)
