# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
from collections import Counter
from pathlib import Path
from typing import Dict, Optional, Union

import torch


def is_nemo2_checkpoint(checkpoint_path: str) -> bool:
    """
    Checks if the checkpoint is in NeMo 2.0 format.
    Args:
        checkpoint_path (str): Path to a checkpoint.
    Returns:
        bool: True if the path points to a NeMo 2.0 checkpoint; otherwise false.
    """

    ckpt_path = Path(checkpoint_path)
    return (ckpt_path / 'context').is_dir()


def prepare_directory_for_export(
    model_dir: Union[str, Path], delete_existing_files: bool, subdir: Optional[str] = None
) -> None:
    """
    Prepares model_dir path for the TensorRTT-LLM / vLLM export.
    Makes sure that the model_dir directory exists and is empty.

    Args:
        model_dir (str): Path to the target directory for the export.
        delete_existing_files (bool): Attempt to delete existing files if they exist.
        subdir (Optional[str]): Subdirectory to create inside the model_dir.

    Returns:
        None
    """
    model_path = Path(model_dir)

    if model_path.exists():
        if delete_existing_files:
            shutil.rmtree(model_path)
        elif any(model_path.iterdir()):
            raise RuntimeError(f"There are files in {model_path} folder: try setting delete_existing_files=True.")

    if subdir is not None:
        model_path /= subdir
    model_path.mkdir(parents=True, exist_ok=True)


def is_nemo_tarfile(path: str) -> bool:
    """
    Checks if the path exists and points to packed NeMo 1 checkpoint.

    Args:
        path (str): Path to possible checkpoint.
    Returns:
        bool: NeMo 1 checkpoint exists and is in '.nemo' format.
    """
    checkpoint_path = Path(path)
    return checkpoint_path.exists() and checkpoint_path.suffix == '.nemo'


# Copied from nemo.collections.nlp.parts.utils_funcs to avoid introducing extra NeMo dependencies:
def torch_dtype_from_precision(precision: Union[int, str], megatron_amp_O2: bool = True) -> torch.dtype:
    """
    Mapping from PyTorch Lighthing (PTL) precision types to corresponding PyTorch parameter data type.

    Args:
        precision (Union[int, str]): The PTL precision type used.
        megatron_amp_O2 (bool): A flag indicating if Megatron AMP O2 is enabled.

    Returns:
        torch.dtype: The corresponding PyTorch data type based on the provided precision.
    """
    if not megatron_amp_O2:
        return torch.float32

    if precision in ['bf16', 'bf16-mixed']:
        return torch.bfloat16
    elif precision in [16, '16', '16-mixed']:
        return torch.float16
    elif precision in [32, '32', '32-true']:
        return torch.float32
    else:
        raise ValueError(f"Could not parse the precision of '{precision}' to a valid torch.dtype")


def get_model_device_type(module: torch.nn.Module) -> str:
    """Find the device type the model is assigned to and ensure consistency."""
    # Collect device types of all parameters and buffers
    param_device_types = {param.device.type for param in module.parameters()}
    buffer_device_types = {buffer.device.type for buffer in module.buffers()}
    all_device_types = param_device_types.union(buffer_device_types)

    if len(all_device_types) > 1:
        raise ValueError(
            f"Model parameters and buffers are on multiple device types: {all_device_types}. "
            "Ensure all parameters and buffers are on the same device type."
        )

    # Return the single device type, or default to 'cpu' if no parameters or buffers
    return all_device_types.pop() if all_device_types else "cpu"


def get_example_inputs(tokenizer) -> Dict[str, torch.Tensor]:
    """Gets example data to feed to the model during ONNX export.

    Returns:
        Dictionary of tokenizer outputs.
    """
    example_inputs = dict(
        tokenizer(
            ["example query one", "example query two"],
            ["example passage one", "example passage two"],
            return_tensors="pt",
        )
    )

    return example_inputs


def validate_fp8_network(network) -> None:
    """Checks the network to ensure it's compatible with fp8 precison.

    Raises:
        ValueError if netowrk doesn't container Q/DQ FP8 layers
    """

    import tensorrt as trt

    quantize_dequantize_layers = []
    for layer in network:
        if layer.type in {trt.LayerType.QUANTIZE, trt.LayerType.DEQUANTIZE}:
            quantize_dequantize_layers.append(layer)
    if not quantize_dequantize_layers:
        error_msg = "No Quantize/Dequantize layers found"
        raise ValueError(error_msg)
    quantize_dequantize_layer_dtypes = Counter(layer.precision for layer in quantize_dequantize_layers)
    if trt.DataType.FP8 not in quantize_dequantize_layer_dtypes:
        error_msg = "Found Quantize/Dequantize layers. But none with FP8 precision."
        raise ValueError(error_msg)
