# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import nullcontext
from enum import Enum
from typing import Callable, Dict, Optional, Type

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F

from nemo.utils import CastToFloat, CastToFloatAll, logging

try:
    import onnxruntime

    ort_available = True
except (ImportError, ModuleNotFoundError):
    ort_available = False


class ExportFormat(Enum):
    """Deployment formats a Neural Module can be exported to."""

    # ONNX graph export.
    ONNX = 1
    # TorchScript export.
    TORCHSCRIPT = 2


# Maps a lower-case file extension (including the leading dot) to the export
# format it selects; looked up by get_export_format().
_EXT_DICT = {
    ".pt": ExportFormat.TORCHSCRIPT,
    ".ts": ExportFormat.TORCHSCRIPT,
    ".onnx": ExportFormat.ONNX,
}


class TorchRMSNorm(nn.Module):
    """
    Pure-PyTorch RMS normalization: scales the input by the reciprocal root of the
    mean of its squares over the last dimension, then multiplies by ``weight``.
    No mean is subtracted and no bias is added.

    Args:
        weight: scale tensor, stored as given (not copied)
        eps: constant added to the mean square before taking the reciprocal square root
    """

    def __init__(self, weight, eps=1e-6):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # The mean of squares is always accumulated in float32.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normed = hidden_states * inv_rms

        # Drop back to the weight's precision only when the weight is half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(self.weight.dtype)

        return self.weight * normed


class LinearWithBiasSkip(nn.Module):
    """
    Linear layer that always returns an ``(output, bias)`` pair.

    When ``skip_bias_add`` is set, the bias is not applied and is handed back to the
    caller as the second element; otherwise it is added and the second element is None.
    """

    def __init__(self, weight, bias, skip_bias_add):
        super().__init__()
        self.bias = bias
        self.weight = weight
        self.skip_bias_add = skip_bias_add

    def forward(self, x, weight=None):
        # An explicitly passed weight overrides the stored one for this call only.
        effective_weight = self.weight if weight is None else weight
        if not self.skip_bias_add:
            return F.linear(x, effective_weight, self.bias), None
        return F.linear(x, effective_weight), self.bias


def get_export_format(filename: str):
    """
    Pick the export format from the extension of ``filename`` (case-insensitive).

    Args:
        filename: path or name of the export file
    Returns:
        the matching entry of _EXT_DICT
    Raises:
        ValueError: if the extension is not a key of _EXT_DICT
    """
    extension = os.path.splitext(filename)[1].lower()
    try:
        export_format = _EXT_DICT[extension]
    except KeyError:
        raise ValueError(f"Export file {filename} extension does not correspond to any export format!")
    return export_format


def augment_filename(output: str, prepend: str):
    """
    Prefix the file-name component of ``output`` with ``"<prepend>-"``.

    The special value ``'self'`` leaves ``output`` untouched.

    Args:
        output: path of the output file
        prepend: prefix to put in front of the file name, or ``'self'``
    Returns:
        the resulting path; the directory part is unchanged
    """
    if prepend != 'self':
        directory, basename = os.path.split(output)
        return os.path.join(directory, f"{prepend}-{basename}")
    return output


def forward_method(self):
    """
    Return the callable to use for export: ``self.forward_for_export`` when the
    object has one, otherwise ``self.forward``.
    """
    return self.forward_for_export if hasattr(self, "forward_for_export") else self.forward


def wrap_forward_method(self):
    """
    Make ``forward`` of the object's class point at ``forward_for_export``, if defined.

    NOTE: the class itself is patched, so every instance of that class is affected.

    Returns:
        ``(forward_for_export, previous_forward)`` when the class defines
        ``forward_for_export``; ``(None, None)`` otherwise (nothing is changed).
    """
    cls = type(self)
    if not hasattr(cls, "forward_for_export"):
        return None, None

    export_forward = cls.forward_for_export
    previous_forward = cls.forward
    cls.forward = export_forward
    return export_forward, previous_forward


def parse_input_example(input_example):
    """
    Split an input example into positional and keyword arguments.

    Args:
        input_example: iterable of inputs; if its last element is a dict, that dict
            is taken as the keyword arguments
    Returns:
        ``(input_list, input_dict)``: a new list with the positional inputs and the
        keyword-argument dict (empty if there was none). The argument is not modified.
    """
    input_list = list(input_example)
    input_dict = {}
    # Process possible kwargs; the emptiness guard keeps an empty example from
    # raising IndexError on input_list[-1].
    if input_list and isinstance(input_list[-1], dict):
        input_dict = input_list.pop()
    return input_list, input_dict


def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list):
    """
    Build the feed dictionary for an onnxruntime session from example inputs.

    Values are converted with ``.cpu().numpy()`` (presumably torch tensors - confirm
    against callers).

    Args:
        ort_input_names: input names of the ONNX graph
        input_names: input names of the exported model, or empty/None if unknown
        input_dict: keyword inputs, keyed by input name
        input_list: positional inputs; NOT modified by this function
    Returns:
        dict mapping ONNX input name -> numpy array
    """
    # Work on a private copy so the caller's list is left untouched.
    positional = list(input_list)

    if not input_names:
        # No names to match on: pair graph inputs with positional then keyword values, in order.
        positional.extend(input_dict.values())
        return {name: value.cpu().numpy() for name, value in zip(ort_input_names, positional)}

    odict = {}
    # Walk the names back to front so positional values can be consumed from the end.
    for k in reversed(input_names):
        val = None
        if k in input_dict:
            val = input_dict[k].cpu().numpy()
        elif positional:
            val = positional.pop().cpu().numpy()
        # Inputs that are not present in the ONNX graph are skipped.
        if val is not None and k in ort_input_names:
            odict[k] = val
    return odict


def verify_torchscript(model, output, input_examples, check_tolerance=0.01):
    """
    Compare the TorchScript file at ``output`` against ``model.forward`` on every example.

    Args:
        model: the original module
        output: path of the saved TorchScript file
        input_examples: iterable of input examples (see parse_input_example)
        check_tolerance: tolerance forwarded to run_ts_and_compare
    Returns:
        True if every compared example matched
    """
    verified = True
    for example in input_examples:
        args, kwargs = parse_input_example(example)
        # We disable autocast here to make sure exported TS will run under Triton or other C++ env
        with torch.amp.autocast('cuda', enabled=False):
            expected = model.forward(*args, **kwargs)
            loaded = torch.jit.load(output)
            # After the first failure the remaining examples are no longer compared.
            if verified:
                verified = run_ts_and_compare(loaded, args, kwargs, expected, check_tolerance)
    outcome = "SUCCESS" if verified else "FAIL"
    logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + outcome)
    return verified


def verify_runtime(model, output, input_examples, input_names, check_tolerance=0.01):
    """
    Run the ONNX file at ``output`` through onnxruntime and compare with ``model.forward``.

    If onnxruntime could not be imported, only ``onnx.checker.check_model`` is run and
    the function returns None.

    Args:
        model: the original module
        output: path of the saved ONNX file
        input_examples: iterable of input examples (see parse_input_example)
        input_names: input names of the exported model
        check_tolerance: tolerance forwarded to run_ort_and_compare
    Returns:
        True if every compared example matched, False otherwise; None when unverified
    """
    onnx_model = onnx.load(output)
    ort_input_names = [graph_input.name for graph_input in onnx_model.graph.input]

    if not ort_available:
        logging.warning(f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n")
        onnx.checker.check_model(onnx_model, full_check=True)
        return

    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
    sess = onnxruntime.InferenceSession(
        onnx_model.SerializeToString(), sess_options=session_options, providers=['CUDAExecutionProvider']
    )
    # The session holds its own serialized copy; release the protobuf.
    del onnx_model

    verified = True
    for example in input_examples:
        args, kwargs = parse_input_example(example)
        expected = model.forward(*args, **kwargs)
        if not isinstance(expected, tuple):
            expected = (expected,)
        feeds = to_onnxrt_input(ort_input_names, input_names, kwargs, args)
        # After the first failure the remaining examples are no longer compared.
        if verified:
            verified = run_ort_and_compare(sess, feeds, expected, check_tolerance)
    outcome = "SUCCESS" if verified else "FAIL"
    logging.info(f"ONNX generated at {output} verified with onnxruntime : " + outcome)
    return verified


def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01):
    """
    Run a TorchScript model and compare its outputs with the expected ones.

    Args:
        ts_model: callable TorchScript module
        ts_input_list: positional inputs
        ts_input_dict: keyword inputs
        output_example: expected output(s) - a tensor, or a sequence whose tensor
            entries are compared (non-tensor entries are skipped)
        check_tolerance: used as both rtol and atol for torch.allclose
    Returns:
        True if every compared output is close to its expected value
    """
    # Verify the model can be read, and is valid
    ts_out = ts_model(*ts_input_list, **ts_input_dict)

    # A single-output model returns a bare tensor. Wrap it so the loop below compares
    # it as a whole instead of iterating along its first dimension (which also fails
    # outright for 0-d tensors).
    if torch.is_tensor(ts_out):
        ts_out = (ts_out,)
    if torch.is_tensor(output_example):
        output_example = (output_example,)

    all_good = True
    for i, out in enumerate(ts_out):
        expected = output_example[i]
        if not torch.is_tensor(expected):
            continue

        tout = out.to('cpu')
        logging.debug(f"Checking output {i}, shape: {expected.shape}:\n")
        try:
            this_good = torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance)
        except Exception:  # e.g. a size mismatch; reported as a mismatch rather than raised
            this_good = False
        if not this_good:
            logging.info(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
            all_good = False
    return all_good


def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
    """
    Run an onnxruntime session and compare its outputs with the expected ones.

    Args:
        sess: object with a ``run(output_names, feed)`` method returning numpy arrays
        ort_input: feed dictionary for the session
        output_example: indexable of expected outputs; only tensor entries are compared
        check_tolerance: rtol for torch.allclose; atol is 100 times this value
    Returns:
        True if every compared output is close to its expected value
    """
    # Verify the model can be read, and is valid
    results = sess.run(None, ort_input)
    success = True
    for idx, ort_result in enumerate(results):
        reference = output_example[idx]
        if not torch.is_tensor(reference):
            continue

        actual = torch.from_numpy(ort_result)
        logging.debug(f"Checking output {idx}, shape: {reference.shape}:\n")
        try:
            matches = torch.allclose(actual, reference.cpu(), rtol=check_tolerance, atol=100 * check_tolerance)
        except Exception:  # there may be size mismatch and it may be OK
            matches = False
        if matches:
            continue
        logging.info(
            f"onnxruntime results mismatch! PyTorch(expected, {reference.shape}):\n{reference}\nONNXruntime, {actual.shape}:\n{actual}"
        )
        success = False
    return success


apex_available = True

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
    from apex.normalization import MixedFusedRMSNorm
    from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as MCoreFusedLayerNorm
    from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
    from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear

    def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
        """
        Replaces a fused layer norm (Apex FusedLayerNorm / MixedFusedLayerNorm / FastLayerNorm or
        Megatron-core FusedLayerNorm) with nn.LayerNorm. This is required for ONNX export.
        Args:
           n: the fused layer norm pytorch module to replace
        Returns:
           Equivalent LayerNorm module, or None if n is not one of the supported types
        """
        if isinstance(n, (FusedLayerNorm, MixedFusedLayerNorm)):
            shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
        elif isinstance(n, MCoreFusedLayerNorm):
            shape, eps, affine = n.weight.shape, n.eps, True
        elif isinstance(n, FastLayerNorm):
            shape, eps, affine = n.weight.shape, n.epsilon, True
        else:
            return None

        # A module without parameters (e.g. a non-affine norm) has nothing to take the
        # device/dtype from; fall back to the nn.LayerNorm defaults instead of letting
        # next() raise StopIteration.
        p = next(n.parameters(), None)
        device = p.device if p is not None else None
        dtype = p.dtype if p is not None else None

        mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=device, dtype=dtype)
        mod.load_state_dict(n.state_dict(), strict=True)

        return mod

    def replace_MixedFusedRMSNorm(n: nn.Module):
        """
        Replaces Apex's MixedFusedRMSNorm with equivalent Pytorch layer. This is required for ONNX export.
        Args:
           n: the MixedFusedRMSNorm pytorch module to replace
        Returns:
           Equivalent TorchRMSNorm module, or None if n is not a MixedFusedRMSNorm
        """
        # Check the type first so that unsupported modules are rejected without touching their parameters.
        if not isinstance(n, MixedFusedRMSNorm):
            return None

        p = next(n.parameters())
        return TorchRMSNorm(n.state_dict()['weight'], n.eps).to(p.device)

    def replace_ParallelLinear(n: nn.Module) -> nn.Module:
        """
        Replaces Megatron-core's ColumnParallelLinear or RowParallelLinear with LinearWithBiasSkip
        Args:
           n: the nn.Module pytorch module to replace
        Returns:
           Equivalent LinearWithBiasSkip module
        Raises:
           ValueError: if n is neither ColumnParallelLinear nor RowParallelLinear
        """
        if not isinstance(n, (ColumnParallelLinear, RowParallelLinear)):
            raise ValueError("This function can only change the ColumnParallelLinear or RowParallelLinear module.")

        dev = next(n.parameters()).device
        mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)

        # strict=False: keys of n's state dict that the replacement does not have are ignored
        mod.load_state_dict(n.state_dict(), strict=False)
        return mod

    def replace_FusedScaleMaskSoftmax(n: nn.Module) -> nn.Module:
        """
        Replaces Megatron-core's FusedScaleMaskSoftmax with a copy that has the fusion disabled.
        This is required for ONNX export.
        Args:
           n: the FusedScaleMaskSoftmax module to replace
        Returns:
           New FusedScaleMaskSoftmax module built from n's settings; n itself (with a warning)
           if it is not a FusedScaleMaskSoftmax
        """
        if not isinstance(n, FusedScaleMaskSoftmax):
            logging.warning(f"This function can only change the FusedScaleMaskSoftmax module, got: {n.__class__}")
            return n

        # disable the fusion only: every argument is copied from n except the fourth, forced to False
        mod = FusedScaleMaskSoftmax(
            n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
        )

        return mod

    # Module class name -> replacement function, in the form consumed by replace_modules().
    default_Apex_replacements = {
        "FusedLayerNorm": replace_FusedLayerNorm,
        "MixedFusedLayerNorm": replace_FusedLayerNorm,
        "MCoreFusedLayerNorm": replace_FusedLayerNorm,
        "FastLayerNorm": replace_FusedLayerNorm,
        "RowParallelLinear": replace_ParallelLinear,
        "ColumnParallelLinear": replace_ParallelLinear,
        "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax,
        "MixedFusedRMSNorm": replace_MixedFusedRMSNorm,
    }

except Exception:
    # Deliberately broad and silent: any failure while importing Apex / Megatron-core
    # simply means there is nothing to replace.
    default_Apex_replacements = {}
    apex_available = False


def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
    """
    Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same attributes.
    No weights are copied.
    Args:
        BaseT : module type to replace
        DestT : destination module type
    Returns:
        swap function to replace BaseT module with DestT
    """

    def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
        if isinstance(mod, BaseT):
            # The values of the attributes named in __constants__ (None when missing)
            # are passed positionally to the DestT constructor.
            ctor_args = [getattr(mod, attr, None) for attr in mod.__constants__]
            return DestT(*ctor_args)
        return None

    return expansion_fn


def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
    """
    Builds a new MatchedScaleMaskSoftmax from the settings of n, with fusion disabled,
    so that the layer is exportable.
    Args:
        n: module to replace
    Returns:
        exportable module
    """
    # including the import here to avoid circular imports
    from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax

    # disabling fusion for the MatchedScaleMaskSoftmax: the fourth argument is forced to False
    fusion_enabled = False
    return MatchedScaleMaskSoftmax(
        n.input_in_fp16,
        n.input_in_bf16,
        n.attn_mask_type,
        fusion_enabled,
        n.mask_func,
        n.softmax_in_fp32,
        n.scale,
    )


def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
    """
    Generic function generator that wraps a module by calling DestT on it.
    Args:
        BaseT : module type to replace (not used by the generated function; no type check is done)
        DestT : wrapper type (or any callable) applied to the module
    Returns:
        swap function returning DestT(mod)
    """

    def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
        return DestT(mod)

    return expansion_fn


def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]):
    """
    Replaces nested submodules of model, addressed by "dot paths" (e.g. "encoder.layers.0.norm"),
    with the given replacements. Paths may go through arbitrary levels of children.

    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.

    Args:
        model : top level module
        mapping : dot path -> replacement module
    Returns:
        model, modified in-place
    """
    for dotted_path, replacement in mapping.items():
        *ancestors, leaf = dotted_path.split(".")
        owner = model
        # Descend to the direct parent of the module being replaced.
        for child_name in ancestors:
            owner = owner._modules[child_name]  # noqa
        owner._modules[leaf] = replacement  # noqa

    return model


def replace_modules(
    model: nn.Module, expansions: Optional[Dict[str, Callable[[nn.Module], Optional[nn.Module]]]] = None
) -> nn.Module:
    """
    Top-level function to replace modules in model, specified by class name with a desired replacement.
    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
    Args:
        model : top level module
        expansions : replacement dictionary: module class name -> replacement function generator.
            None or an empty dictionary leaves the model untouched.
    Returns:
        model, possibly modified in-place
    """
    # Nothing to replace; this also covers the default expansions=None.
    if not expansions:
        return model

    mapping: Dict[str, nn.Module] = {}
    for name, m in model.named_modules():
        expander = expansions.get(type(m).__name__)
        if expander is None:
            continue
        swapped = expander(m)
        # Compare against None explicitly: container modules define __len__, so an
        # empty replacement container is falsy but still a valid replacement.
        if swapped is not None:
            mapping[name] = swapped

    if mapping:
        logging.info(f"Swapped {len(mapping)} modules")
        swap_modules(model, mapping)
    return model


def script_module(m: nn.Module):
    """Compile m with torch.jit.script and return the scripted result."""
    scripted = torch.jit.script(m)
    return scripted


# Module class name -> replacement function; applied as the last replacement pass
# by replace_for_export(). Empty by default.
script_replacements = {}


def replace_for_export(model: nn.Module) -> nn.Module:
    """
    Top-level function to replace 'default set' of modules in model, called from _prepare_for_export.
    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
    Args:
        model : top level module
    Returns:
        model, possibly modified in-place
    """
    default_replacements = {
        # wrap_module ignores its first argument, so None is fine here
        "MatchedScaleMaskSoftmax": wrap_module(None, replace_MatchedScaleMaskSoftmax),
    }

    replace_modules(model, default_Apex_replacements)
    replace_modules(model, default_replacements)
    # This one has to be the last
    replace_modules(model, script_replacements)
    # Return the model as the signature and docstring promise (previously returned None).
    return model


def add_casts_around_norms(model: nn.Module):
    """
    Function to put additional to/from float32 casts around operations known to require full precision.
    It was used with an extra post-parse script to have TRT preserve extra precision when --fp16 needed.
    Should not be needed with TRT 8.6.1 or later.

    The model is modified in place: the norm layers listed below are wrapped with
    CastToFloat, and MaskedInstanceNorm1d with CastToFloatAll.
    """
    from nemo.collections.tts.modules.submodules import MaskedInstanceNorm1d

    float_wrapped_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.InstanceNorm1d)
    # Keyed by class name, which is what replace_modules matches on.
    cast_replacements = {
        norm_type.__name__: wrap_module(norm_type, CastToFloat) for norm_type in float_wrapped_types
    }
    cast_replacements["MaskedInstanceNorm1d"] = wrap_module(MaskedInstanceNorm1d, CastToFloatAll)
    replace_modules(model, cast_replacements)


def rename_onnx_io(output, input_names, output_names):
    """
    Rename the graph inputs and outputs of the ONNX file at ``output`` and save it back in place.

    Graph inputs/outputs are matched to the new names by position; references to the
    old names in the inputs and outputs of the graph's nodes are updated as well.

    Args:
        output: path of the ONNX file (read, then overwritten)
        input_names: new names for the graph inputs, in order
        output_names: new names for the graph outputs, in order
    """
    model_proto = onnx.load(output)
    graph = model_proto.graph

    # old tensor name -> new tensor name (outputs are added after inputs)
    renames = {old.name: new for old, new in zip(graph.input, input_names)}
    renames.update((old.name, new) for old, new in zip(graph.output, output_names))

    for node in graph.node:
        # Iterate over snapshots, since the repeated fields are updated by index.
        for slot, tensor_name in enumerate(tuple(node.input)):
            if tensor_name in renames:
                node.input[slot] = renames[tensor_name]
        for slot, tensor_name in enumerate(tuple(node.output)):
            if tensor_name in renames:
                node.output[slot] = renames[tensor_name]

    for position, new_name in enumerate(input_names):
        graph.input[position].name = new_name
    for position, new_name in enumerate(output_names):
        graph.output[position].name = new_name
    onnx.save(model_proto, output)
