# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np
import torch
import wrapt
from transformers import AutoModel, AutoTokenizer

from nemo.deploy import ITritonDeployable
from nemo.export.utils import get_example_inputs, get_model_device_type, is_nemo2_checkpoint, validate_fp8_network
from nemo.utils import logging

if TYPE_CHECKING:
    import tensorrt as trt


@wrapt.decorator
def noop_decorator(func):
    """No op decorator"""

    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


use_pytriton = True
batch = noop_decorator
try:
    from pytriton.decorators import batch
except Exception:
    logging.warning("PyTriton is not available.")
    use_pytriton = False


use_onnxruntime = True
try:
    import onnxruntime
except Exception:
    logging.warning("onnxruntime is not available.")
    use_onnxruntime = False


use_trt = True
try:
    import tensorrt as trt
except ImportError:
    logging.warning("tensorrt is not available")
    use_trt = False


# pylint: disable=line-too-long
class OnnxLLMExporter(ITritonDeployable):
    """
    Exports models to ONNX and run fast inference.

    Example:
        from nemo.export.onnx_llm_exporter import OnnxLLMExporter

        onnx_llm_exporter = OnnxLLMExporter(
            onnx_model_dir="/path/for/onnx_model/files",
            model_name_or_path="/path/for/model/files",
        )

        onnx_llm_exporter.export(
            input_names=["input_ids", "attention_mask", "dimensions"],
            output_names=["embeddings"],
        )

        output = onnx_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])
        print("output: ", output)
    """

    def __init__(
        self,
        onnx_model_dir: str,
        model: Optional[torch.nn.Module] = None,
        tokenizer=None,
        model_name_or_path: str = None,
        load_runtime: bool = True,
    ):
        """
        Initializes the ONNX Exporter.

        Args:
            onnx_model_dir (str): path for storing the ONNX model files.
            model (Optional[torch.nn.Module]): torch model.
            tokenizer (HF or NeMo tokenizer): tokenizer class.
            model_name_or_path (str): a path for ckpt or HF model ID
            load_runtime (bool): load ONNX runtime if there is any exported model available in
                                 the onnx_model_dir folder.
        """
        self.onnx_model_dir = onnx_model_dir
        self.model_name_or_path = model_name_or_path
        self.onnx_model_path = str(Path(onnx_model_dir) / "model.onnx")
        self.model = model
        self.tokenizer = tokenizer
        self.model_input_names = None
        self.model_output_names = None
        self.onnx_runtime_session = None
        self.calibration_data = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.quant_max_batch_size = None

        if self.model_name_or_path is not None:
            if model is not None:
                raise ValueError("A model was also passed but it will be overridden.")

            if Path(self.model_name_or_path).is_dir():
                if is_nemo2_checkpoint(self.model_name_or_path):
                    raise NotImplementedError("NeMo 2.0 checkpoint will be supported later.")
                else:
                    self._load_hf_model()

        if load_runtime:
            self._load_runtime()

    def _load_runtime(self):
        if use_onnxruntime:
            if Path(self.onnx_model_path).exists():
                self.onnx_runtime_session = onnxruntime.InferenceSession(self.onnx_model_path)
                self.model_input_names = [input.name for input in self.onnx_runtime_session.get_inputs()]
                self.model_output_names = [output.name for output in self.onnx_runtime_session.get_outputs()]
                self.tokenizer = AutoTokenizer.from_pretrained(
                    Path(self.onnx_model_dir) / "tokenizer", trust_remote_code=True
                )

    def _load_hf_model(self):
        self.model = AutoModel.from_pretrained(
            self.model_name_or_path,
            trust_remote_code=True,
        ).eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)

    def export(
        self,
        input_names: list,
        output_names: list,
        example_inputs: dict = None,
        opset: int = 20,
        dynamic_axes_input: Optional[dict] = None,
        dynamic_axes_output: Optional[dict] = None,
        export_dtype: str = "fp32",
        verbose: bool = False,
    ):
        """
        Performs ONNX conversion from a PyTorch model.

        Args:
            input_names (list): input parameter names of the model that ONNX will export will use.
            output_names (list): output parameter names of the model that ONNX will export will use.
            example_inputs (dict): example input for the model to build the engine.
            opset (int): ONNX opset version. Default is 20.
            dynamic_axes_input (dict): Variable length axes for the input.
            dynamic_axes_output (dict): Variable length axes for the output.
            export_dtype (str): Export dtype, fp16 or fp32.
            verbose (bool): Enable verbose or not.
        """

        self._export_to_onnx(
            input_names=input_names,
            example_inputs=example_inputs,
            output_names=output_names,
            opset=opset,
            dynamic_axes_input=dynamic_axes_input,
            dynamic_axes_output=dynamic_axes_output,
            export_dtype=export_dtype,
            verbose=verbose,
        )
        self._load_runtime()

    def _export_to_onnx(
        self,
        input_names: list,
        output_names: list,
        example_inputs: dict = None,
        opset: int = 20,
        dynamic_axes_input: Optional[dict] = None,
        dynamic_axes_output: Optional[dict] = None,
        export_dtype: Union[torch.dtype, str] = "fp32",
        verbose: bool = False,
    ):

        if example_inputs is None:
            example_inputs = get_example_inputs(self.tokenizer)

        if "dimensions" in input_names:
            example_inputs["dimensions"] = torch.tensor([1] * example_inputs["input_ids"].shape[0])

        if isinstance(export_dtype, str):
            export_dtype = {"fp16": torch.float16, "fp32": torch.float32}[export_dtype]

        self.model.to(export_dtype)

        Path(self.onnx_model_dir).mkdir(parents=True, exist_ok=True)

        with torch.autocast(device_type=get_model_device_type(self.model), dtype=export_dtype):
            torch.onnx.export(
                model=self.model,
                args=(example_inputs,),
                f=self.onnx_model_path,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes={**dynamic_axes_input, **dynamic_axes_output},
                verbose=verbose,
                opset_version=opset,
            )
        logging.info(f"Successfully exported PyTorch model to ONNX model {self.onnx_model_path}")

        existing_directory_path = Path(self.onnx_model_dir) / "tokenizer"
        existing_directory_path.mkdir(exist_ok=True)
        self.tokenizer.save_pretrained(existing_directory_path)

    def export_onnx_to_trt(
        self,
        trt_model_dir: str,
        profiles=None,
        override_layernorm_precision_to_fp32: bool = False,
        override_layers_to_fp32: List = None,
        trt_dtype: str = "fp16",
        profiling_verbosity: str = "layer_names_only",
        trt_builder_flags: List["trt.BuilderFlag"] = None,
    ) -> None:
        """Performs TensorRT conversion from an ONNX model.

        Args:
            trt_model_dir: path to store the TensorRT model.
            profiles: TensorRT profiles.
            override_layernorm_precision_to_fp32 (bool): whether to convert layers to fp32 or not.
            override_layers_to_fp32 (List): Layer names to be converted to fp32.
            trt_dtype (str): "fp16" or "fp32".
            profiling_verbosity (str): Profiling verbosity. Default is "layer_names_only".
            trt_builder_flags (List[trt.BuilderFlag]): TRT specific flags.
        """
        logging.info(f"Building TRT engine from ONNX model ({self.onnx_model_path})")
        trt_logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(trt_logger)
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        config = builder.create_builder_config()
        parser = trt.OnnxParser(network, trt_logger)

        # we use parse_from_file() instead of parse() because it can be used for both single
        # file models as well as externally stored models (required when model >2GiB)
        if not parser.parse_from_file(self.onnx_model_path):
            logging.warning("ONNX model could not be parsed")
            for error in range(parser.num_errors):
                logging.error(parser.get_error(error))
            return

        if profiles:
            for profile in profiles:
                optimization_profile = builder.create_optimization_profile()

                for i in range(network.num_inputs):
                    in_tensor = network.get_input(i)
                    optimization_profile.set_shape(
                        in_tensor.name,
                        min=profile[in_tensor.name][0],
                        opt=profile[in_tensor.name][1],
                        max=profile[in_tensor.name][2],
                    )

                config.add_optimization_profile(optimization_profile)

        if trt_dtype == "fp16":
            logging.info("Setting Build Flag FP16")
            config.set_flag(trt.BuilderFlag.FP16)
        elif trt_dtype == "fp8":
            # With FP8 export we want to also enable FP16 layers as a fallback instead of FP32
            logging.info("Setting Build Flag FP8 and FP16")
            config.set_flag(trt.BuilderFlag.FP8)
            config.set_flag(trt.BuilderFlag.FP16)
            validate_fp8_network(network)

        # patch network
        if override_layernorm_precision_to_fp32:
            logging.info("Overriding TensorRT network LayerNorm precision to float32.")
            self._override_layernorm_precision_to_fp32(network)

        if override_layers_to_fp32:
            logging.info("Overriding some layers to float32.")
            self._override_layers_to_fp32(network, override_layers_to_fp32)

        try:
            config.profiling_verbosity = {
                "detailed": trt.ProfilingVerbosity.DETAILED,
                "layer_names_only": trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
                "none": trt.ProfilingVerbosity.NONE,
            }[profiling_verbosity]
        except KeyError:
            error_msg = "Unknown profiling verbosity value."
            raise ValueError(error_msg)
        logging.info(f"Setting Profiling Verbosity to {config.profiling_verbosity}")

        if trt_builder_flags is not None:
            for flag in trt_builder_flags:
                config.set_flag(flag)

        engine_string = builder.build_serialized_network(network, config)
        if engine_string is None:
            raise Exception("Failed to serialize the TensorRT Engine. Please check the " "TensorRT logs for details")

        trt_model_path = Path(trt_model_dir)
        trt_model_path.mkdir(parents=True, exist_ok=True)
        trt_model_path = trt_model_path / "model.plan"
        trt_model_path.write_bytes(engine_string)
        logging.info(f"Successfully exported ONNX model ({self.onnx_model_path}) " f"to TRT engine ({trt_model_path})")

    def _override_layer_precision_to_fp32(self, layer: "trt.ILayer") -> None:
        layer.precision = trt.float32
        layer.set_output_type(0, trt.float32)

    def _override_layers_to_fp32(self, network: "trt.INetworkDefinition", fp32_layer_patterns: list[str]) -> None:
        for i in range(network.num_layers):
            layer = network.get_layer(i)
            layer_name = layer.name
            if any(layer_name.startswith(pattern) for pattern in fp32_layer_patterns) and layer.precision in {
                trt.float32,
                trt.float16,
            }:
                if layer.type in {trt.LayerType.CAST}:
                    logging.info(f"Skipping overriding {layer.type} layer {i} {layer_name} dtype")
                    continue
                if any(
                    layer.get_input(input_idx).dtype in {trt.float32, trt.float16}
                    for input_idx in range(layer.num_inputs)
                ):
                    # Note: Assigning to layer.precision (even the same value) sets precision_is_set=True,
                    # which prevents TensorRT from changing this layer's precision
                    layer.precision = trt.float32
                    logging.info(f"Setting layer {i} {layer_name} (type: {layer.type}) precision to FP32")
                for j in range(layer.num_outputs):
                    if layer.get_output_type(j) in {trt.float32, trt.float16}:
                        layer.set_output_type(j, trt.float32)
                        logging.info(f"Setting layer {i} {layer_name} (type: {layer.type}) output type {j} to FP32")

    def _override_layernorm_precision_to_fp32(self, network: "trt.INetworkDefinition") -> None:
        """Set the precision of LayerNorm subgraphs to FP32 to preserve accuracy.

        - https://nvbugs/4478448 (Mistral)
        - https://nvbugs/3802112 (T5)

        Args:
            network: tensorrt.INetworkDefinition
        """
        # Logic originally from OSS T5 HF export script:
        # https://gitlab-master.nvidia.com/TensorRT/Public/oss/-/blob/77495ec/demo/HuggingFace/T5/export.py
        pow_ops = {}
        for layer_index, layer in enumerate(network):
            if layer.type == trt.LayerType.IDENTITY:
                all_fp32 = all(
                    [
                        layer.output_type_is_set(o) and layer.get_output_type(o) == trt.float32
                        for o in range(layer.num_outputs)
                    ]
                )
                if all_fp32:
                    if layer.get_input(0).dtype == trt.float32:
                        layer.precision = trt.float32

            if layer.type == trt.LayerType.ELEMENTWISE:
                layer.__class__ = getattr(trt, "IElementWiseLayer")
                if layer.op == trt.ElementWiseOperation.POW:
                    pow_ops[layer] = layer_index
                    self._override_layer_precision_to_fp32(layer)

        for _, index in pow_ops.items():
            # Iterate from few layers before pow to include residual add and cast op.
            # Iterate till 10 layers after pow op to include all
            # operations included in layer norm.
            START_OFFSET = 4
            END_OFFSET = 12
            for i in range(index - START_OFFSET, index + END_OFFSET):
                layer = network.get_layer(i)
                if layer.type == trt.LayerType.REDUCE:
                    self._override_layer_precision_to_fp32(layer)

                if layer.type == trt.LayerType.ELEMENTWISE:
                    layer.__class__ = getattr(trt, "IElementWiseLayer")
                    if layer.op == trt.ElementWiseOperation.SUM:
                        self._override_layer_precision_to_fp32(layer)

                if layer.type == trt.LayerType.UNARY:
                    layer.__class__ = getattr(trt, "IUnaryLayer")
                    if layer.op == trt.UnaryOperation.SQRT:
                        self._override_layer_precision_to_fp32(layer)

                if layer.type == trt.LayerType.ELEMENTWISE:
                    layer.__class__ = getattr(trt, "IElementWiseLayer")
                    if layer.op == trt.ElementWiseOperation.DIV:
                        self._override_layer_precision_to_fp32(layer)

                if layer.type == trt.LayerType.ELEMENTWISE:
                    layer.__class__ = getattr(trt, "IElementWiseLayer")
                    if layer.op == trt.ElementWiseOperation.PROD:
                        self._override_layer_precision_to_fp32(layer)

    def forward(self, inputs: Union[List, Dict], dimensions: Optional[List] = None):
        """Run inference for a given input.

        Args:
            inputs (Union[List, Dict]): Input for the model. If list, it should be a list of strings.
                If dict, it should be a dictionary with keys as the model input names.
            dimensions (Optional[List]): The dimensions parameter of the model. Required if the model
                was exported to accept dimensions parameter and inputs is given as a list of strings.

        Returns:
            np.ndarray: Model output.
        """

        if self.onnx_runtime_session is None:
            warnings.warn("ONNX Runtime is not available. Please install the onnxruntime-gpu and try again.")
            return None

        if isinstance(inputs, List):
            if "dimensions" in self.model_input_names and dimensions is None:
                raise ValueError("Dimensions should be provided for list input.")
            inputs = dict(self.tokenizer(inputs))
            inputs["dimensions"] = dimensions

        output = self.onnx_runtime_session.run(self.model_output_names, inputs)
        return output[0]

    @property
    def get_model(self):
        """Returns the model"""

        return self.model

    @property
    def get_tokenizer(self):
        """Returns the tokenizer"""

        return self.tokenizer

    @property
    def get_model_input_names(self):
        """Returns the model input names"""

        return self.model_input_names

    @property
    def get_triton_input(self):
        """Get triton input"""

        raise NotImplementedError("This function will be implemented later.")

    @property
    def get_triton_output(self):
        """Get triton output"""

        raise NotImplementedError("This function will be implemented later.")

    @batch
    def triton_infer_fn(self, **inputs: np.ndarray):
        """PyTriton inference function"""

        raise NotImplementedError("This function will be implemented later.")
