# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
import tempfile

import lightning.pytorch as pl
import numpy as np
import pytest
import soundfile as sf
import torch
from omegaconf import DictConfig, ListConfig

from nemo.collections.asr.data import audio_to_label
from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel, configs
from nemo.utils.config_utils import assert_dataclass_signature_match


@pytest.fixture()
def speech_classification_model():
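    """Build a small EncDecClassificationModel (one Jasper block, 30 dummy classes) for unit tests."""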
    preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': {}}
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [
                {
                    'filters': 32,
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }
            ],
        },
    }

    decoder = {
        'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
        'params': {
            'feat_in': 32,
            'num_classes': 30,
        },
    }

    model_config = DictConfig(
        {
            'preprocessor': DictConfig(preprocessor),
            'encoder': DictConfig(encoder),
            'decoder': DictConfig(decoder),
            'labels': ListConfig(["dummy_cls_{}".format(i + 1) for i in range(30)]),
        }
    )
    model = EncDecClassificationModel(cfg=model_config)
    return model


@pytest.fixture()
def frame_classification_model():
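    """Build a small EncDecFrameClassificationModel with an MLP decoder and an SGD optimizer config."""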
    preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': {}}
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [
                {
                    'filters': 32,
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }
            ],
        },
    }

    decoder = {
        'cls': 'nemo.collections.common.parts.MultiLayerPerceptron',
        'params': {
            'hidden_size': 32,
            'num_classes': 5,
        },
    }

    optim = {
        'name': 'sgd',
        'lr': 0.01,
        'weight_decay': 0.001,
        'momentum': 0.9,
    }

    model_config = DictConfig(
        {
            'preprocessor': DictConfig(preprocessor),
            'encoder': DictConfig(encoder),
            'decoder': DictConfig(decoder),
            'optim': DictConfig(optim),
            'labels': ListConfig(["0", "1"]),
        }
    )
    model = EncDecFrameClassificationModel(cfg=model_config)
    return model


class TestEncDecClassificationModel:
    @pytest.mark.unit
    def test_constructor(self, speech_classification_model):
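        """The parameter count should match the hand-computed conv/batchnorm/decoder totals, and the
        model should survive a round trip through its config dict."""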
        asr_model = speech_classification_model.train()

        conv_cnt = (64 * 32 * 1 + 32) + (64 * 1 * 1 + 32)  # separable kernel + bias + pointwise kernel + bias
        bn_cnt = (4 * 32) * 2  # 2 batchnorm layers x 4 tensors (weight, bias, running mean/var) x 32 channels
        dec_cnt = 32 * 30 + 30  # fc + bias

        param_count = conv_cnt + bn_cnt + dec_cnt
        assert asr_model.num_weights == param_count

        # Check to/from config_dict:
        confdict = asr_model.to_config_dict()
        instance2 = EncDecClassificationModel.from_config_dict(confdict)

        assert isinstance(instance2, EncDecClassificationModel)

    @pytest.mark.unit
    def test_forward(self, speech_classification_model):
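        """Forwarding samples one at a time and as a batch of 4 should yield numerically identical logits."""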
        asr_model = speech_classification_model.eval()

        asr_model.preprocessor.featurizer.dither = 0.0
        asr_model.preprocessor.featurizer.pad_to = 0

        input_signal = torch.randn(size=(4, 512))
        length = torch.randint(low=321, high=500, size=[4])

        with torch.no_grad():
            # batch size 1
            logprobs_instance = []
            for i in range(input_signal.size(0)):
                logprobs_ins = asr_model.forward(
                    input_signal=input_signal[i : i + 1], input_signal_length=length[i : i + 1]
                )
                logprobs_instance.append(logprobs_ins)
            logprobs_instance = torch.cat(logprobs_instance, 0)

            # batch size 4
            logprobs_batch = asr_model.forward(input_signal=input_signal, input_signal_length=length)

        assert logprobs_instance.shape == logprobs_batch.shape
        diff = torch.mean(torch.abs(logprobs_instance - logprobs_batch))
        assert diff <= 1e-6
        diff = torch.max(torch.abs(logprobs_instance - logprobs_batch))
        assert diff <= 1e-6

    @pytest.mark.unit
    def test_vocab_change(self, speech_classification_model):
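        """Re-applying the same labels must not change the weight count; adding 3 labels should grow the
        decoder's final linear layer by 3 * (feat_in + 1) weights."""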
        asr_model = speech_classification_model.train()

        old_labels = copy.deepcopy(asr_model._cfg.labels)
        nw1 = asr_model.num_weights
        asr_model.change_labels(new_labels=old_labels)
        # No change
        assert nw1 == asr_model.num_weights
        new_labels = copy.deepcopy(old_labels)
        new_labels.append('dummy_cls_31')
        new_labels.append('dummy_cls_32')
        new_labels.append('dummy_cls_33')
        asr_model.change_labels(new_labels=new_labels)
        # fully connected + bias
        assert asr_model.num_weights == nw1 + 3 * (asr_model.decoder._feat_in + 1)

    @pytest.mark.unit
    def test_transcription(self, speech_classification_model, test_data_dir):
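        """transcribe() should honor the configured top-k value(s) and, with logprobs=True, return raw
        per-class scores whose shape is independent of top-k."""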
        # Ground truth labels = ["yes", "no"]
        audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
        audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames]

        model = speech_classification_model.eval()

        # Test Top 1 classification transcription
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([1])

        # Test Top 5 classification transcription
        model._accuracy.top_k = [5]  # set top k to 5 for accuracy calculation
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([5])

        # Test Top 1 and Top 5 classification transcription
        model._accuracy.top_k = [1, 5]
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([2, 1])
        assert results[1].shape == torch.Size([2, 5])
        assert model._accuracy.top_k == [1, 5]

        # Test log probs extraction
        model._accuracy.top_k = [1]
        results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
        assert len(results) == 2
        assert results[0].shape == torch.Size([len(model.cfg.labels)])

        # Test log probs extraction remains same for any top_k
        model._accuracy.top_k = [5]
        results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
        assert len(results) == 2
        assert results[0].shape == torch.Size([len(model.cfg.labels)])

    @pytest.mark.unit
    def test_EncDecClassificationDatasetConfig_for_AudioToSpeechLabelDataset(self):
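        """AudioToSpeechLabelDataset and EncDecClassificationDatasetConfig signatures should stay in
        sync, modulo the explicitly ignored and remapped arguments."""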
        # ignore arguments that only exist because the dataclass config is generic across dataset types
        IGNORE_ARGS = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'tarred_audio_filepaths',
            'shuffle',
            'pin_memory',
            'drop_last',
            'tarred_shard_strategy',
            'shuffle_n',
            # `featurizer` is supplied at runtime
            'featurizer',
            # additional ignored arguments
            'vad_stream',
            'int_values',
            'sample_rate',
            'normalize_audio',
            'augmentor',
            'bucketing_batch_size',
            'bucketing_strategy',
            'bucketing_weights',
        ]

        REMAP_ARGS = {'trim_silence': 'trim'}

        result = assert_dataclass_signature_match(
            audio_to_label.AudioToSpeechLabelDataset,
            configs.EncDecClassificationDatasetConfig,
            ignore_args=IGNORE_ARGS,
            remap_args=REMAP_ARGS,
        )
        signatures_match, cls_subset, dataclass_subset = result

        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None


class TestEncDecFrameClassificationModel(TestEncDecClassificationModel):
    @pytest.mark.parametrize(["logits_len", "labels_len"], [(20, 10), (21, 10), (19, 10), (20, 9), (20, 11)])
    @pytest.mark.unit
    def test_reshape_labels(self, frame_classification_model, logits_len, labels_len):
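        """reshape_labels() should resize the labels' time dimension to match the logits and return
        per-sample label lengths equal to the logits lengths."""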
        model = frame_classification_model.eval()

        # the parametrized ints set the padded time dimension of logits and labels
        logits = torch.ones(4, logits_len, 2)
        labels = torch.ones(4, labels_len)
        # reuse the names for the per-sample valid lengths consumed by reshape_labels
        logits_len = torch.tensor([6, 7, 8, 9])
        labels_len = torch.tensor([5, 6, 7, 8])
        labels_new, labels_len_new = model.reshape_labels(
            logits=logits, labels=labels, logits_len=logits_len, labels_len=labels_len
        )
        assert labels_new.size(1) == logits.size(1)
        assert torch.equal(labels_len_new, torch.tensor([6, 7, 8, 9]))

    @pytest.mark.unit
    def test_EncDecClassificationDatasetConfig_for_AudioToMultiSpeechLabelDataset(self):
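        """Same signature check as above, but against the multi-label dataset used for frame classification."""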
        # ignore arguments that only exist because the dataclass config is generic across dataset types
        IGNORE_ARGS = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'tarred_audio_filepaths',
            'shuffle',
            'pin_memory',
            'drop_last',
            'tarred_shard_strategy',
            'shuffle_n',
            # `featurizer` is supplied at runtime
            'featurizer',
            # additional ignored arguments
            'vad_stream',
            'int_values',
            'sample_rate',
            'normalize_audio',
            'augmentor',
            'bucketing_batch_size',
            'bucketing_strategy',
            'bucketing_weights',
            'delimiter',
            'normalize_audio_db',
            'normalize_audio_db_target',
            'window_length_in_sec',
            'shift_length_in_sec',
        ]

        REMAP_ARGS = {'trim_silence': 'trim'}

        result = assert_dataclass_signature_match(
            audio_to_label.AudioToMultiLabelDataset,
            configs.EncDecClassificationDatasetConfig,
            ignore_args=IGNORE_ARGS,
            remap_args=REMAP_ARGS,
        )
        signatures_match, cls_subset, dataclass_subset = result

        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None

    @pytest.mark.unit
    def test_frame_classification_model(self, frame_classification_model: EncDecFrameClassificationModel):
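        """End-to-end smoke test: write a tiny manifest of random audio with frame-level labels, then
        fit the model for one epoch with a Lightning trainer."""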
        with tempfile.TemporaryDirectory() as temp_dir:
            # generate 1 second of random audio at a 16 kHz sample rate
            audio = np.random.randn(16000 * 1)
            # save the audio
            audio_path = os.path.join(temp_dir, "audio.wav")
            sf.write(audio_path, audio, 16000)

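            # frame-level targets as a space-separated label string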
            dummy_labels = "0 0 0 0 1 1 1 1 0 0 0 0"

            dummy_sample = {
                "audio_filepath": audio_path,
                "offset": 0.0,
                "duration": 1.0,
                "label": dummy_labels,
            }

            # write a manifest with four copies of the dummy sample
            manifest_path = os.path.join(temp_dir, "dummy_manifest.json")
            with open(manifest_path, "w") as f:
                for _ in range(4):
                    f.write(json.dumps(dummy_sample) + "\n")

            dataloader_cfg = {
                "batch_size": 2,
                "manifest_filepath": manifest_path,
                "sample_rate": 16000,
                "num_workers": 0,
                "shuffle": False,
                "labels": ["0", "1"],
            }

            trainer_cfg = {
                "max_epochs": 1,
                "devices": 1,
                "accelerator": "auto",
            }

            optim = {
                'name': 'sgd',
                'lr': 0.01,
                'weight_decay': 0.001,
                'momentum': 0.9,
            }

            trainer = pl.Trainer(**trainer_cfg)
            frame_classification_model.set_trainer(trainer)
            frame_classification_model.setup_optimization(DictConfig(optim))
            frame_classification_model.setup_training_data(dataloader_cfg)
            frame_classification_model.setup_validation_data(dataloader_cfg)

            trainer.fit(frame_classification_model)
