# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence

import numpy as np
import onnx

from onnxscript import ir, tensor
from onnxscript.ir import _schemas

if TYPE_CHECKING:
    from onnxscript._internal import converter

# Conversions from Python values to ONNX are used by both the script converter and the
# eager-mode runtime, and the two must stay consistent. The script converter converts
# Python values into ONNX TensorProto, while the runtime converts Python values into
# the ONNXScript runtime's value representation (based on Tensor).

# Repeated (list-valued) attribute types. An empty Python list is only accepted
# as an attribute value when the requested attribute type is one of these.
_REPEATED_ATTRIBUTE_TYPES = frozenset(
    (
        ir.AttributeType.FLOATS,
        ir.AttributeType.INTS,
        ir.AttributeType.STRINGS,
        ir.AttributeType.TENSORS,
        ir.AttributeType.GRAPHS,
        ir.AttributeType.SPARSE_TENSORS,
        ir.AttributeType.TYPE_PROTOS,
    )
)


def pyvalue_to_onnx_attribute(
    key: str,
    value: Any,
    name_generator: Callable[[], str],
    attr_type: ir.AttributeType | None = None,
) -> ir.Attr:
    """Create an IR attribute (``ir.Attr``) named ``key`` from a Python value.

    * An empty list is accepted as a value only when ``attr_type`` is given and
      is a repeated (list) attribute type, because nothing can be inferred from
      an empty list.
    * When ``attr_type`` is TENSOR, scalar values like 1.0 and lists like
      [1, -1] may be given directly; they are converted into a 0-D or 1-D
      tensor respectively. An ``onnx.TensorProto`` value is not converted here
      and is handed to ``ir.convenience.convert_attribute`` instead.
    * All other values are delegated to ``ir.convenience.convert_attribute``.

    Args:
        key: Name of the attribute.
        value: The Python value of the attribute.
        name_generator: Called to produce the name of the tensor created from
            ``value`` when ``attr_type`` is TENSOR.
        attr_type: The requested attribute type, or None to let
            ``ir.convenience.convert_attribute`` decide.

    Returns:
        The constructed attribute.

    Raises:
        ValueError: If ``value`` is an empty list and ``attr_type`` is None or
            not a repeated attribute type.
    """
    # TODO(justinchuby): Remove this function and use onnx-ir directly.
    if isinstance(value, list) and not value:
        # Empty list value: the attribute type must come from the caller.
        if attr_type is None:
            raise ValueError("Attribute type must be specified for empty list value.")
        if attr_type not in _REPEATED_ATTRIBUTE_TYPES:
            raise ValueError("Empty list value is only allowed for repeated attribute types.")
        return ir.Attr(name=key, type=attr_type, value=[])
    if attr_type == ir.AttributeType.TENSOR and not isinstance(value, onnx.TensorProto):
        # Wrap raw scalars/lists given for a TENSOR attribute into a named tensor.
        return ir.AttrTensor(name=key, value=ir.tensor(value, name=name_generator()))
    return ir.convenience.convert_attribute(key, value, attr_type=attr_type)


# Utilities to convert python values into onnxscript tensors.


def _promotable(x) -> bool:
    """Checks if a runtime parameter value needs to be promoted into an onnxscript value.
    This is the runtime-equivalent of the promotion of literal constants into ONNX values
    in the static converter.
    """
    if isinstance(x, (bool, int, float)):
        return True
    if isinstance(x, list) and x:
        # Note: This is meant to handle valid scenarios correctly. No attempt is
        # made yet to capture all invalid usages in runtime mode.
        return _promotable(x[0])
    return False


def _get_dtype(pyvalue):
    """Return np.dtype to use when converting a python value to an onnxscript tensor.
    Note that int constants are treated as int64, as that is the common type in ONNX
    for shape/index values.
    """
    if isinstance(pyvalue, bool):
        return np.bool_
    elif isinstance(pyvalue, int):
        return np.int64
    elif isinstance(pyvalue, float):
        return np.float32
    elif isinstance(pyvalue, list):
        if pyvalue:
            # TODO: What to do about lists with mixed value types, like [1, 2.0]?
            # Should at least produce an error/warning message.
            return _get_dtype(pyvalue[0])
        raise ValueError("Cannot determine target type for empty list")
    raise TypeError(f"Value of unexpected type {type(pyvalue)}")


def cast_pyvalue_to_os_tensor(pyvalue, dtype=None):
    """Promote a Python value into an onnxscript tensor when possible.

    Values accepted by ``_promotable`` are wrapped into a ``tensor.Tensor``.
    These are bool/int/float scalars, or non-empty lists whose first leaf
    element is such a scalar. Any other value is returned unchanged.

    Args:
        pyvalue: The value to promote.
        dtype: Optional numpy dtype for the resulting tensor. It only matters
            when a promotion actually happens; when None, it is derived from
            the value by ``_get_dtype``.
    """
    if not _promotable(pyvalue):
        return pyvalue
    target_dtype = _get_dtype(pyvalue) if dtype is None else dtype
    return tensor.Tensor(np.array(pyvalue, dtype=target_dtype))


def cast_inputs(
    get_type_info: Callable[[Any], Any],
    cast: Callable[[Any, Any], Any],
    op_signature: _schemas.OpSignature | None,
    args,
) -> tuple[Any, ...]:
    """Uses schema specification to support a limited form of auto-casting.

    * Scalars are promoted to tensors.
    * Further, they are cast to the required type when used in ops with other
      tensor inputs that are required to be of the same type.
      Thus, in "A+1" or "Add(A, 1)", the value 1 will be converted to the same
      type as A.

    This is used by the converter in a static mode, as well as by the
    eager-mode execution in a dynamic mode. The two modes differ only in the
    ``get_type_info`` and ``cast`` callbacks they supply.

    Args:
        get_type_info: Returns the type information an actual argument
            contributes to its type variable's binding, or None if it
            contributes nothing.
        cast: Called as ``cast(arg, binding)`` for every argument, where
            ``binding`` is the type info bound to the argument's type variable
            (or None). Its results form the returned tuple.
        op_signature: Signature of the op being called, or None when it is
            unknown (an error or a custom op).
        args: The actual arguments, in positional order.

    Returns:
        A tuple with one ``cast`` result per argument, in the same order.

    Raises:
        ValueError: If more arguments are given than the signature accepts.
    """
    if op_signature is None:
        # Either an error or a custom op. No checks/casts in this case.
        return tuple(cast(x, None) for x in args)

    # Only the signature's input parameters are consulted; attributes play no
    # part in auto-casting.
    expected_inputs = op_signature.inputs
    num_formals = len(expected_inputs)
    # We make two passes. In the first pass, we identify known type-bindings for
    # type-variables: eg., {'T1' : np.float32, 'T2' : np.int32}.
    # In the second pass, we use these bindings to cast scalar-values to
    # tensors of appropriate types. The two passes are needed to handle cases
    # like "Add(1, X)" where 1 must be cast to the same type as X.
    # Bound values are whatever get_type_info returns (e.g. a numpy dtype in
    # eager mode, an ir.Value in the static converter).
    type_bindings: dict[Optional[str], Any] = {}
    args_typevars: list[tuple[Any, Optional[str]]] = []
    for i, x in enumerate(args):
        if i < num_formals:
            expected = expected_inputs[i]
        elif num_formals > 0 and expected_inputs[-1].variadic:
            # Extra actuals are absorbed by a trailing variadic parameter. The
            # num_formals check keeps signatures without inputs from hitting
            # an IndexError here; they must reach the ValueError below.
            expected = expected_inputs[-1]
            if not expected.homogeneous:
                # Heterogeneous variadic actuals share no type variable.
                args_typevars.append((x, None))
                continue
        else:
            raise ValueError(
                f"Number of actual parameters {len(args)} "
                f"exceeds number of formal parameters {num_formals}."
            )
        typevar = expected.type_constraint.name
        if "(" not in typevar:
            # typevar is an identifier, like "T". Names containing "(" are
            # presumably concrete type strings and are never bound.
            type_info = get_type_info(x)
            if type_info is not None:
                # If several actuals bind the same variable, the last one wins.
                type_bindings[typevar] = type_info
        args_typevars.append((x, typevar))
    return tuple(cast(x, type_bindings.get(typevar)) for x, typevar in args_typevars)


def dynamic_cast_inputs(op_signature: _schemas.OpSignature, args):
    """Autocast the inputs of an op call during eager-mode execution.

    Only onnxscript ``tensor.Tensor`` arguments contribute type bindings,
    through their ``dtype``. Every argument is then passed through
    ``cast_pyvalue_to_os_tensor`` with the dtype bound to its type variable.
    """

    def _dtype_of(value):
        # Plain Python values carry no dtype, so they cannot fix a type variable.
        if isinstance(value, tensor.Tensor):
            return value.dtype
        return None

    return cast_inputs(_dtype_of, cast_pyvalue_to_os_tensor, op_signature, args)


def static_cast_inputs(
    converter_: converter.Converter,
    op_signature: Optional[_schemas.OpSignature],
    args: Sequence[Optional[ir.Value]],
) -> tuple[str, ...]:
    """Used for autocast during script translation.

    This transforms expressions like "Add(X, 1)" into "Add(X, CastLike(1, X))".
    Polymorphic constants (like 0 and 1) are cast to the type of the other
    operands as needed.

    Args:
        converter_: The converter used to query castability and to emit nodes.
        op_signature: Signature of the op being called, or None if unknown.
        args: The translated arguments; None stands for an omitted input.

    Returns:
        One entry per argument, in order:
        * None for an omitted (None) argument.
        * For a castable constant with a cast target, the result of
          ``converter_.emit1`` for the emitted CastLike node.
        * The argument itself otherwise.
    """
    # NOTE(review): the ``tuple[str, ...]`` / ``Optional[str]`` annotations look
    # stale. Arguments needing no cast are returned unchanged as ir.Value
    # objects. Confirm what ``converter_.emit1`` returns before tightening them.

    def get_type_info(x: Optional[ir.Value]) -> Optional[ir.Value]:
        """Return x if it can serve as the target type of a cast, else None.

        A target type is the second argument of CastLike. In the expression
        "Add(X, 1)", 1 is castable, while X can serve as the target type.
        """
        return None if x is None or converter_.is_castable(x.name) else x

    def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]:
        """Emit CastLike(x, y) when x is a castable constant and y is known."""
        if x is None:
            return None
        if converter_.is_castable(x.name) and y is not None:
            # Polymorphic constant x is cast to the type of y:
            x_cast = converter_.generate_unique_name(f"{x.name}_cast")
            return converter_.emit1([x_cast], "CastLike", [x, y])
        return x

    return cast_inputs(get_type_info, cast_like, op_signature, args)
