# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
# pylint: disable=C0415,R0912,R0913,R0914,R0915,W0611,W0603
"""Every class imported in this module defines an implementation of
an operator of the main domain. Any class name uses `_` to specify a
version defined in a specific opset. The class name without `_`
defines the current implementation. If an operator has no class
with `_`, it means the implementation is valid for every opset.
The operator may have been updated to support more types but that
did not change the implementation.
"""

from __future__ import annotations

import textwrap
from typing import Any

from onnx import FunctionProto, NodeProto, TypeProto
from onnx.defs import get_schema, onnx_opset_version
from onnx.onnx_cpp2py_export.defs import SchemaError
from onnx.reference.op_run import (
    OpFunction,
    OpRun,
    RuntimeContextError,
    RuntimeImplementationError,
)
from onnx.reference.ops._helpers import build_registered_operators_any_domain
from onnx.reference.ops.op_abs import Abs
from onnx.reference.ops.op_acos import Acos
from onnx.reference.ops.op_acosh import Acosh
from onnx.reference.ops.op_add import Add
from onnx.reference.ops.op_affine_grid import AffineGrid
from onnx.reference.ops.op_and import And
from onnx.reference.ops.op_argmax import ArgMax_1, ArgMax_12
from onnx.reference.ops.op_argmin import ArgMin_1, ArgMin_12
from onnx.reference.ops.op_asin import Asin
from onnx.reference.ops.op_asinh import Asinh
from onnx.reference.ops.op_atan import Atan
from onnx.reference.ops.op_atanh import Atanh
from onnx.reference.ops.op_attention import Attention
from onnx.reference.ops.op_attribute_has_value import AttributeHasValue
from onnx.reference.ops.op_average_pool import (
    AveragePool_1,
    AveragePool_7,
    AveragePool_11,
    AveragePool_19,
)
from onnx.reference.ops.op_batch_normalization import (
    BatchNormalization_6,
    BatchNormalization_9,
    BatchNormalization_14,
)
from onnx.reference.ops.op_bernoulli import Bernoulli
from onnx.reference.ops.op_bitshift import BitShift
from onnx.reference.ops.op_bitwise_and import BitwiseAnd
from onnx.reference.ops.op_bitwise_not import BitwiseNot
from onnx.reference.ops.op_bitwise_or import BitwiseOr
from onnx.reference.ops.op_bitwise_xor import BitwiseXor
from onnx.reference.ops.op_blackman_window import BlackmanWindow
from onnx.reference.ops.op_cast import Cast_1, Cast_19
from onnx.reference.ops.op_cast_like import CastLike_15, CastLike_19
from onnx.reference.ops.op_ceil import Ceil
from onnx.reference.ops.op_celu import Celu
from onnx.reference.ops.op_center_crop_pad import CenterCropPad
from onnx.reference.ops.op_clip import Clip_6, Clip_11
from onnx.reference.ops.op_col2im import Col2Im
from onnx.reference.ops.op_compress import Compress
from onnx.reference.ops.op_concat import Concat
from onnx.reference.ops.op_concat_from_sequence import ConcatFromSequence
from onnx.reference.ops.op_constant import (
    Constant_1,
    Constant_9,
    Constant_11,
    Constant_12,
)
from onnx.reference.ops.op_constant_of_shape import ConstantOfShape
from onnx.reference.ops.op_conv import Conv
from onnx.reference.ops.op_conv_integer import ConvInteger
from onnx.reference.ops.op_conv_transpose import ConvTranspose
from onnx.reference.ops.op_cos import Cos
from onnx.reference.ops.op_cosh import Cosh
from onnx.reference.ops.op_cum_sum import CumSum
from onnx.reference.ops.op_deform_conv import DeformConv
from onnx.reference.ops.op_depth_to_space import DepthToSpace
from onnx.reference.ops.op_dequantize_linear import (
    DequantizeLinear_19,
    DequantizeLinear_21,
)
from onnx.reference.ops.op_det import Det
from onnx.reference.ops.op_dft import DFT_17, DFT_20
from onnx.reference.ops.op_div import Div
from onnx.reference.ops.op_dropout import Dropout_7, Dropout_12
from onnx.reference.ops.op_dynamic_quantize_linear import DynamicQuantizeLinear
from onnx.reference.ops.op_einsum import Einsum
from onnx.reference.ops.op_elu import Elu
from onnx.reference.ops.op_equal import Equal
from onnx.reference.ops.op_erf import Erf
from onnx.reference.ops.op_exp import Exp
from onnx.reference.ops.op_expand import Expand
from onnx.reference.ops.op_eyelike import EyeLike
from onnx.reference.ops.op_flatten import Flatten
from onnx.reference.ops.op_floor import Floor
from onnx.reference.ops.op_gather import Gather
from onnx.reference.ops.op_gather_elements import GatherElements
from onnx.reference.ops.op_gathernd import GatherND
from onnx.reference.ops.op_gemm import Gemm_6, Gemm_7
from onnx.reference.ops.op_global_average_pool import GlobalAveragePool
from onnx.reference.ops.op_global_max_pool import GlobalMaxPool
from onnx.reference.ops.op_greater import Greater
from onnx.reference.ops.op_greater_or_equal import GreaterOrEqual
from onnx.reference.ops.op_grid_sample import GridSample
from onnx.reference.ops.op_gru import GRU
from onnx.reference.ops.op_hamming_window import HammingWindow
from onnx.reference.ops.op_hann_window import HannWindow
from onnx.reference.ops.op_hard_sigmoid import HardSigmoid
from onnx.reference.ops.op_hardmax import Hardmax
from onnx.reference.ops.op_identity import Identity
from onnx.reference.ops.op_if import If
from onnx.reference.ops.op_image_decoder import ImageDecoder
from onnx.reference.ops.op_instance_normalization import InstanceNormalization
from onnx.reference.ops.op_isinf import IsInf
from onnx.reference.ops.op_isnan import IsNaN
from onnx.reference.ops.op_layer_normalization import LayerNormalization
from onnx.reference.ops.op_leaky_relu import LeakyRelu
from onnx.reference.ops.op_less import Less
from onnx.reference.ops.op_less_or_equal import LessOrEqual
from onnx.reference.ops.op_log import Log
from onnx.reference.ops.op_log_softmax import LogSoftmax
from onnx.reference.ops.op_loop import Loop
from onnx.reference.ops.op_lp_normalization import LpNormalization
from onnx.reference.ops.op_lp_pool import LpPool
from onnx.reference.ops.op_lrn import LRN
from onnx.reference.ops.op_lstm import LSTM
from onnx.reference.ops.op_matmul import MatMul
from onnx.reference.ops.op_matmul_integer import MatMulInteger
from onnx.reference.ops.op_max import Max
from onnx.reference.ops.op_max_pool import MaxPool
from onnx.reference.ops.op_max_unpool import MaxUnpool
from onnx.reference.ops.op_mean import Mean
from onnx.reference.ops.op_mel_weight_matrix import MelWeightMatrix
from onnx.reference.ops.op_min import Min
from onnx.reference.ops.op_mod import Mod
from onnx.reference.ops.op_mul import Mul
from onnx.reference.ops.op_neg import Neg
from onnx.reference.ops.op_negative_log_likelihood_loss import NegativeLogLikelihoodLoss
from onnx.reference.ops.op_non_max_suppression import NonMaxSuppression
from onnx.reference.ops.op_non_zero import NonZero
from onnx.reference.ops.op_not import Not
from onnx.reference.ops.op_one_hot import OneHot
from onnx.reference.ops.op_optional import Optional
from onnx.reference.ops.op_optional_get_element import OptionalGetElement
from onnx.reference.ops.op_optional_has_element import OptionalHasElement
from onnx.reference.ops.op_or import Or
from onnx.reference.ops.op_pad import Pad_1, Pad_2, Pad_11, Pad_18
from onnx.reference.ops.op_pow import Pow
from onnx.reference.ops.op_prelu import PRelu
from onnx.reference.ops.op_qlinear_conv import QLinearConv
from onnx.reference.ops.op_qlinear_matmul import QLinearMatMul
from onnx.reference.ops.op_quantize_linear import (
    QuantizeLinear_10,
    QuantizeLinear_19,
    QuantizeLinear_21,
)
from onnx.reference.ops.op_random_normal import RandomNormal
from onnx.reference.ops.op_random_normal_like import RandomNormalLike
from onnx.reference.ops.op_random_uniform import RandomUniform
from onnx.reference.ops.op_random_uniform_like import RandomUniformLike
from onnx.reference.ops.op_range import Range
from onnx.reference.ops.op_reciprocal import Reciprocal
from onnx.reference.ops.op_reduce_l1 import ReduceL1_1, ReduceL1_18
from onnx.reference.ops.op_reduce_l2 import ReduceL2_1, ReduceL2_18
from onnx.reference.ops.op_reduce_log_sum import ReduceLogSum_1, ReduceLogSum_18
from onnx.reference.ops.op_reduce_log_sum_exp import (
    ReduceLogSumExp_1,
    ReduceLogSumExp_18,
)
from onnx.reference.ops.op_reduce_max import ReduceMax_1, ReduceMax_18
from onnx.reference.ops.op_reduce_mean import ReduceMean_1, ReduceMean_18
from onnx.reference.ops.op_reduce_min import ReduceMin_1, ReduceMin_18
from onnx.reference.ops.op_reduce_prod import ReduceProd_1, ReduceProd_18
from onnx.reference.ops.op_reduce_sum import ReduceSum_1, ReduceSum_13
from onnx.reference.ops.op_reduce_sum_square import (
    ReduceSumSquare_1,
    ReduceSumSquare_18,
)
from onnx.reference.ops.op_regex_full_match import RegexFullMatch
from onnx.reference.ops.op_relu import Relu
from onnx.reference.ops.op_reshape import Reshape_5, Reshape_14
from onnx.reference.ops.op_resize import Resize
from onnx.reference.ops.op_reverse_sequence import ReverseSequence
from onnx.reference.ops.op_rms_normalization import RMSNormalization
from onnx.reference.ops.op_rnn import RNN_7, RNN_14
from onnx.reference.ops.op_roi_align import RoiAlign
from onnx.reference.ops.op_rotary_embedding import RotaryEmbedding
from onnx.reference.ops.op_round import Round
from onnx.reference.ops.op_scan import Scan
from onnx.reference.ops.op_scatter_elements import ScatterElements
from onnx.reference.ops.op_scatternd import ScatterND
from onnx.reference.ops.op_selu import Selu
from onnx.reference.ops.op_sequence_at import SequenceAt
from onnx.reference.ops.op_sequence_construct import SequenceConstruct
from onnx.reference.ops.op_sequence_empty import SequenceEmpty
from onnx.reference.ops.op_sequence_erase import SequenceErase
from onnx.reference.ops.op_sequence_insert import SequenceInsert
from onnx.reference.ops.op_sequence_length import SequenceLength
from onnx.reference.ops.op_sequence_map import SequenceMap
from onnx.reference.ops.op_shape import Shape_1, Shape_15
from onnx.reference.ops.op_shrink import Shrink
from onnx.reference.ops.op_sigmoid import Sigmoid
from onnx.reference.ops.op_sign import Sign
from onnx.reference.ops.op_sin import Sin
from onnx.reference.ops.op_sinh import Sinh
from onnx.reference.ops.op_size import Size
from onnx.reference.ops.op_slice import Slice_1, Slice_10
from onnx.reference.ops.op_softmax import Softmax
from onnx.reference.ops.op_softmax_cross_entropy_loss import SoftmaxCrossEntropyLoss
from onnx.reference.ops.op_softplus import Softplus
from onnx.reference.ops.op_softsign import Softsign
from onnx.reference.ops.op_space_to_depth import SpaceToDepth
from onnx.reference.ops.op_split import Split_2, Split_11, Split_13, Split_18
from onnx.reference.ops.op_split_to_sequence import SplitToSequence
from onnx.reference.ops.op_sqrt import Sqrt
from onnx.reference.ops.op_squeeze import Squeeze_1, Squeeze_11, Squeeze_13
from onnx.reference.ops.op_stft import STFT
from onnx.reference.ops.op_string_concat import StringConcat
from onnx.reference.ops.op_string_normalizer import StringNormalizer
from onnx.reference.ops.op_string_split import StringSplit
from onnx.reference.ops.op_sub import Sub
from onnx.reference.ops.op_sum import Sum
from onnx.reference.ops.op_tan import Tan
from onnx.reference.ops.op_tanh import Tanh
from onnx.reference.ops.op_tfidf_vectorizer import TfIdfVectorizer
from onnx.reference.ops.op_thresholded_relu import ThresholdedRelu
from onnx.reference.ops.op_tile import Tile
from onnx.reference.ops.op_topk import TopK_1, TopK_10, TopK_11
from onnx.reference.ops.op_transpose import Transpose
from onnx.reference.ops.op_trilu import Trilu
from onnx.reference.ops.op_unique import Unique
from onnx.reference.ops.op_unsqueeze import Unsqueeze_1, Unsqueeze_11, Unsqueeze_13
from onnx.reference.ops.op_upsample import Upsample
from onnx.reference.ops.op_where import Where
from onnx.reference.ops.op_xor import Xor


def _build_registered_operators() -> dict[str, dict[int | None, type[OpRun]]]:
    return build_registered_operators_any_domain(globals().copy())


def load_op(
    domain: str,
    op_type: str,
    version: None | int = None,
    custom: Any = None,
    node: None | NodeProto = None,
    input_types: None | list[TypeProto] = None,
    expand: bool = False,
    evaluator_cls: type | None = None,
) -> Any:
    """Loads the implemented for a specified operator.

    Args:
        domain: domain
        op_type: oprator type
        version: requested version
        custom: custom implementation (like a function)
        node: used if no implementation was found and the operator
            defines a function which is context dependant
        input_types: used if no implementation was found and the
            operator defines a function which is context dependant
        expand: use the function implemented in the schema instead of
            its reference implementation
        evaluator_cls: evaluator to use

    Returns:
        class
    """
    global _registered_operators  # noqa: PLW0603
    schema = None
    if _registered_operators is None:
        _registered_operators = _build_registered_operators()  # type: ignore[assignment]
    assert _registered_operators is not None
    if custom is not None:
        return lambda *args: OpFunction(*args, impl=custom)  # type: ignore
    if version is None:
        version = onnx_opset_version()
    if domain != "":
        raise ValueError(f"Domain must be '' not {domain!r}.")
    if op_type in _registered_operators and not expand:  # type: ignore
        found = True
    else:
        # maybe the operator can be replaced by a function
        try:
            schema = get_schema(op_type, version, domain)  # type: ignore
        except SchemaError:
            raise NotImplementedError(
                f"No registered schema for operator {op_type!r} "
                f"and domain {domain!r}. Did you recompile the sources after updating the repository?"
            ) from None
        if schema.has_function:  # type: ignore
            body = schema.function_body  # type: ignore
            assert evaluator_cls is not None, (
                f"evaluator_cls must be specified to implement operator {op_type!r} from domain {domain!r}"
            )
            sess = evaluator_cls(body)
            return lambda *args, sess=sess: OpFunction(*args, impl=sess)  # type: ignore
        if schema.has_context_dependent_function:  # type: ignore
            if node is None or input_types is None:
                raise RuntimeContextError(
                    f"No registered implementation for operator {op_type!r} "
                    f"and domain {domain!r}, the operator has a context dependent function. "
                    f"but argument node or input_types is not defined (input_types={input_types})."
                )
            body = schema.get_context_dependent_function(  # type: ignore
                node.SerializeToString(), [it.SerializeToString() for it in input_types]
            )
            proto = FunctionProto()
            proto.ParseFromString(body)
            assert evaluator_cls is not None, (
                f"evaluator_cls must be specified to evaluate function {proto.name!r}"
            )
            sess = evaluator_cls(proto)
            return lambda *args, sess=sess: OpFunction(*args, impl=sess)  # type: ignore
        found = False
    if not found:
        available = "\n".join(textwrap.wrap(", ".join(sorted(_registered_operators))))  # type: ignore
        has_function = schema.has_function if schema else None  # type: ignore
        has_context_dependent_function = (
            schema.has_context_dependent_function if schema else None  # type: ignore
        )
        raise RuntimeImplementationError(
            f"No registered implementation for operator {op_type!r} "
            f"and domain {domain!r}, schema.has_function is {has_function}, "
            f"schema.has_context_dependent_function is {has_context_dependent_function}. "
            f"You may either add one or skip the test in "
            f"'reference_evaluator_bakcend_test.py'. Available implementations:\n{available}"
        )
    impl = _registered_operators[op_type]
    if None not in impl:
        raise RuntimeError(
            f"No default implementation for operator {op_type!r} "
            f"and domain {domain!r}, found "
            f"{', '.join(map(str, impl))}."
        )
    if version is None or len(impl) == 1:
        cl = impl[None]
    else:
        best = -1
        for v in impl:
            if v is None:
                continue
            if best < v <= version:
                best = v
        if best == -1:
            raise RuntimeError(
                f"No implementation for operator {op_type!r} "
                f"domain {domain!r} and version {version!r}, found "
                f"{', '.join(map(str, impl))}."
            )
        cl = impl[best]
    if cl is None:
        available = "\n".join(textwrap.wrap(", ".join(sorted(_registered_operators))))  # type: ignore
        raise ValueError(
            f"Not registered implementation for operator {op_type!r}, "
            f"domain {domain!r}, and {version!r} in\n{available}"
        )
    return cl


_registered_operators: dict[str, dict[int | None, OpRun]] | None = None
