# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import numbers
from typing import Optional, Sequence

import numpy as np
import onnx
import onnx_ir as ir

from onnxscript import tensor


def external_tensor(
    name: str,
    data_type: int,
    dims: Sequence[int],
    location: str,
    offset: Optional[int] = None,
    length: Optional[int] = None,
    checksum: Optional[str] = None,
    basepath: Optional[str] = None,
) -> onnx.TensorProto:
    """Build a TensorProto whose tensor-data is stored in an external file.

    The proto is marked with ``data_location = EXTERNAL`` and carries one
    ``external_data`` key/value entry per supplied piece of metadata. The
    "location" entry is always written; the other entries are written only
    when the corresponding argument is not None. All values are stored as
    strings.

    Args:
        name: name of the tensor
        data_type: data type of tensor element
        dims: shape of the tensor
        location: location of the external file (relative path)
        offset: offset in the file where the tensor-data starts
        length: number of bytes containing the data
        checksum: SHA1 digest of the file
        basepath: basepath combined with location to form the full path

    Returns:
        TensorProto

    See https://github.com/onnx/onnx/blob/main/docs/ExternalData.md for more details.
    """
    proto = onnx.TensorProto()
    proto.name = name
    proto.data_type = data_type
    proto.dims.extend(dims)
    proto.data_location = onnx.TensorProto.EXTERNAL

    # Gather the metadata in the order it is emitted: location first, then
    # whichever optional fields were provided.
    metadata = [("location", location)]
    if offset is not None:
        metadata.append(("offset", int(offset)))
    if length is not None:
        metadata.append(("length", int(length)))
    if checksum is not None:
        metadata.append(("checksum", checksum))
    if basepath is not None:
        metadata.append(("basepath", basepath))

    for key, value in metadata:
        record = proto.external_data.add()
        record.key = key
        record.value = str(value)
    return proto


def value_to_type_proto(val):
    """Infer the ONNX TypeProto that describes a python-value.

    Supported values:
        * numpy arrays and ``tensor.Tensor``: tensor type with the value's
          dtype and shape.
        * python ``int``: scalar INT32 tensor. NOTE: ``bool`` is a subclass of
          ``int``, so python bools are reported as INT32 as well.
        * python ``float`` / ``np.float32``: scalar FLOAT tensor.
        * ``list``: sequence type whose element type is inferred from the
          first item only.
        * any other ``numbers.Number``: scalar tensor with the dtype numpy
          picks for the value.

    Raises:
        ValueError: if the value is none of the above.
    """
    if isinstance(val, (np.ndarray, tensor.Tensor)):
        dtype_code = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)  # noqa: TID251
        dims = val.shape
    elif isinstance(val, int):
        dtype_code, dims = onnx.TensorProto.INT32, []
    elif isinstance(val, (float, np.float32)):
        dtype_code, dims = onnx.TensorProto.FLOAT, []
    elif isinstance(val, list):
        if len(val) == 0:
            # Edge-case. No element is available to infer the ONNX type from;
            # callers should be using a typed-value instead. Fall back to a
            # sequence of float tensors of unknown shape.
            item_type = onnx.helper.make_tensor_type_proto(  # noqa: TID251
                onnx.TensorProto.FLOAT, None
            )
        else:
            item_type = value_to_type_proto(val[0])
        return onnx.helper.make_sequence_type_proto(item_type)  # noqa: TID251
    elif isinstance(val, numbers.Number):
        as_array = np.array(val)
        dtype_code = onnx.helper.np_dtype_to_tensor_dtype(as_array.dtype)  # noqa: TID251
        dims = []
    else:
        raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
    return onnx.helper.make_tensor_type_proto(dtype_code, dims)  # noqa: TID251


def value_to_type(val):
    """Return the ONNX IR type and shape describing a python-value.

    Args:
        val: A numpy array, ``tensor.Tensor``, python/numpy scalar, or a list
            of such values.

    Returns:
        A ``(type, shape)`` pair:
            * numpy arrays and ``tensor.Tensor``: ``ir.TensorType`` of the
              value's dtype, and ``val.shape``.
            * python ``int``: INT32 tensor type, shape ``[]``.
            * python ``float`` / ``np.float32``: FLOAT tensor type, shape ``[]``.
            * non-empty ``list``: ``ir.SequenceType`` wrapping the type of the
              first element, paired with the first element's shape.
            * empty ``list``: sequence of FLOAT tensors, shape ``None``.
            * any other ``numbers.Number``: tensor type with the dtype numpy
              picks for the value, shape ``[]``.

    Raises:
        ValueError: If ``val`` is not one of the supported kinds.
    """
    if isinstance(val, (np.ndarray, tensor.Tensor)):
        elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)  # noqa: TID251
        return ir.TensorType(elem_type), val.shape
    if isinstance(val, int):
        return ir.TensorType(onnx.TensorProto.INT32), []
    if isinstance(val, (float, np.float32)):
        return ir.TensorType(onnx.TensorProto.FLOAT), []
    if isinstance(val, list):
        if len(val) > 0:
            # Do NOT name this local `type`: binding `type` anywhere in the
            # function makes it a local everywhere, which would turn the
            # `type(val)` call in the error message below into an
            # UnboundLocalError instead of the intended ValueError.
            item_type, item_shape = value_to_type(val[0])
            return ir.SequenceType(item_type), item_shape
        # Edge-case. Cannot determine a suitable ONNX type for an empty list.
        # Should be using a typed-value instead.
        # Treated as a sequence of tensors of float-type.
        return ir.SequenceType(ir.TensorType(onnx.TensorProto.FLOAT)), None
    if isinstance(val, numbers.Number):
        nparray = np.array(val)
        elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)  # noqa: TID251
        return ir.TensorType(elem_type), []
    raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")


def value_to_ir_value(name: str, val) -> ir.Value:
    """Build a named ir.Value whose type and shape are inferred from a python-value.

    The type and shape come from ``value_to_type(val)``; any error it raises
    for an unsupported value propagates to the caller.
    """
    value_type, value_shape = value_to_type(val)
    return ir.Value(name=name, shape=value_shape, type=value_type)


def values_to_value_infos(name_values):
    """Convert (name, value) pairs into a list of ValueInfoProto.

    Pairs whose value is None are skipped; the remaining pairs keep their
    input order. Each value's type is inferred with ``value_to_type_proto``.
    """
    value_infos = []
    for name, val in name_values:
        if val is None:
            continue
        type_proto = value_to_type_proto(val)
        value_infos.append(onnx.helper.make_value_info(name, type_proto))  # noqa: TID251
    return value_infos


def values_to_ir_values(name_values):
    """Convert (name, value) pairs into a list of ir.Value.

    Pairs whose value is None are skipped; the remaining pairs keep their
    input order. Each ir.Value is built with ``value_to_ir_value``.
    """
    ir_values = []
    for name, val in name_values:
        if val is not None:
            ir_values.append(value_to_ir_value(name, val))
    return ir_values
