# --------------------------------------------------------------------------
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
# ruff: noqa: N801
# --------------------------------------------------------------------------

from __future__ import annotations

from typing import Optional, Sequence, Tuple, TypeVar, Union

from onnx import TensorProto
from onnx.defs import get_schema
from typing_extensions import TypeAlias

from onnxscript.onnx_opset._impl.opset_ai_onnx_ml2 import Opset_ai_onnx_ml2
from onnxscript.onnx_types import DOUBLE, FLOAT, INT32, INT64, STRING
from onnxscript.values import Op, Opset


class Opset_ai_onnx_ml3(Opset_ai_onnx_ml2):
    """Operator set for the ``ai.onnx.ml`` domain at version 3.

    Extends ``Opset_ai_onnx_ml2`` and defines the version-3 signatures of
    ``TreeEnsembleClassifier`` and ``TreeEnsembleRegressor``. Each operator
    method fetches its schema, wraps it in an ``Op`` bound to this opset, and
    forwards the input together with every attribute.
    """

    def __new__(cls):
        """Create the opset instance registered as domain ``ai.onnx.ml``, version 3."""
        domain, version = "ai.onnx.ml", 3
        return Opset.__new__(cls, domain, version)

    # Element types accepted for the classifier input ``X``.
    T1_TreeEnsembleClassifier = TypeVar(
        "T1_TreeEnsembleClassifier",
        DOUBLE,
        FLOAT,
        INT32,
        INT64,
    )

    # Element types used in the first annotated result of the classifier.
    T2_TreeEnsembleClassifier: TypeAlias = Union[INT64, STRING]

    def TreeEnsembleClassifier(
        self,
        X: T1_TreeEnsembleClassifier,
        *,
        base_values: Optional[Sequence[float]] = None,
        base_values_as_tensor: Optional[TensorProto] = None,
        class_ids: Optional[Sequence[int]] = None,
        class_nodeids: Optional[Sequence[int]] = None,
        class_treeids: Optional[Sequence[int]] = None,
        class_weights: Optional[Sequence[float]] = None,
        class_weights_as_tensor: Optional[TensorProto] = None,
        classlabels_int64s: Optional[Sequence[int]] = None,
        classlabels_strings: Optional[Sequence[str]] = None,
        nodes_falsenodeids: Optional[Sequence[int]] = None,
        nodes_featureids: Optional[Sequence[int]] = None,
        nodes_hitrates: Optional[Sequence[float]] = None,
        nodes_hitrates_as_tensor: Optional[TensorProto] = None,
        nodes_missing_value_tracks_true: Optional[Sequence[int]] = None,
        nodes_modes: Optional[Sequence[str]] = None,
        nodes_nodeids: Optional[Sequence[int]] = None,
        nodes_treeids: Optional[Sequence[int]] = None,
        nodes_truenodeids: Optional[Sequence[int]] = None,
        nodes_values: Optional[Sequence[float]] = None,
        nodes_values_as_tensor: Optional[TensorProto] = None,
        post_transform: str = "NONE",
    ) -> Tuple[T2_TreeEnsembleClassifier, FLOAT]:
        r"""[🌐 ai.onnx.ml::TreeEnsembleClassifier(3)](https://onnx.ai/onnx/operators/onnx_aionnxml_TreeEnsembleClassifier.html#treeensembleclassifier-3 "Online Documentation")

        Tree ensemble classifier: returns the top class for each of N inputs.

        The ``nodes_*`` attributes are parallel sequences that must all have the
        same length; the entries found at one index across them form the tuple
        that defines a node.

        In the same way, the ``class_*`` attributes are tuples of votes at the
        leaves. A leaf may have multiple votes, and each vote is weighted by the
        associated ``class_weights`` entry.

        One and only one of ``classlabels_strings`` or ``classlabels_int64s``
        will be defined; ``class_ids`` are indices into that list.

        Any attribute ending in ``_as_tensor`` can be used instead of the
        attribute of the same name without the suffix if the element type is
        double and not float.

        Args:
            X: Input of shape [N,F].

            base_values: Base values for classification, added to the final
                class score. The size must be the same as the classes, or the
                attribute can be left unassigned (assumed 0).

            base_values_as_tensor: Tensor form of ``base_values``: base values
                for classification, added to the final class score. The size
                must be the same as the classes, or the attribute can be left
                unassigned (assumed 0).

            class_ids: The index of the class list that each weight is for.

            class_nodeids: Node id that this weight is for.

            class_treeids: The id of the tree that this node is in.

            class_weights: The weight for the class in class_id.

            class_weights_as_tensor: Tensor form of ``class_weights``: the
                weight for the class in class_id.

            classlabels_int64s: Class labels if using integer labels. One and
                only one of the ``classlabels_*`` attributes must be defined.

            classlabels_strings: Class labels if using string labels. One and
                only one of the ``classlabels_*`` attributes must be defined.

            nodes_falsenodeids: Child node if expression is false.

            nodes_featureids: Feature id for each node.

            nodes_hitrates: Popularity of each node, used for performance; may
                be omitted.

            nodes_hitrates_as_tensor: Tensor form of ``nodes_hitrates``:
                popularity of each node, used for performance; may be omitted.

            nodes_missing_value_tracks_true: For each node, what to do when a
                value is missing (NaN): take the 'true' or the 'false' branch
                based on the value in this array. May be left undefined, in
                which case the default is false (0) for all nodes.

            nodes_modes: The node kind, that is, the comparison to make at the
                node. There is no comparison to make at a leaf node. One of
                'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT',
                'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'.

            nodes_nodeids: Node id for each node. Ids may restart at zero for
                each tree, but it is not required to.

            nodes_treeids: Tree id for each node.

            nodes_truenodeids: Child node if expression is true.

            nodes_values: Thresholds to do the splitting on for each node.

            nodes_values_as_tensor: Tensor form of ``nodes_values``: thresholds
                to do the splitting on for each node.

            post_transform: The transform to apply to the score. One of 'NONE',
                'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO' or 'PROBIT'. Defaults to
                'NONE'.

        Returns:
            The result of invoking the ``TreeEnsembleClassifier`` op, annotated
            as a pair of an INT64-or-STRING value and a FLOAT value.
        """

        op_schema = get_schema("TreeEnsembleClassifier", 3, "ai.onnx.ml")
        tree_classifier = Op(self, "TreeEnsembleClassifier", op_schema)
        # Attributes are forwarded unchanged, in schema (alphabetical) order.
        attributes = {
            "base_values": base_values,
            "base_values_as_tensor": base_values_as_tensor,
            "class_ids": class_ids,
            "class_nodeids": class_nodeids,
            "class_treeids": class_treeids,
            "class_weights": class_weights,
            "class_weights_as_tensor": class_weights_as_tensor,
            "classlabels_int64s": classlabels_int64s,
            "classlabels_strings": classlabels_strings,
            "nodes_falsenodeids": nodes_falsenodeids,
            "nodes_featureids": nodes_featureids,
            "nodes_hitrates": nodes_hitrates,
            "nodes_hitrates_as_tensor": nodes_hitrates_as_tensor,
            "nodes_missing_value_tracks_true": nodes_missing_value_tracks_true,
            "nodes_modes": nodes_modes,
            "nodes_nodeids": nodes_nodeids,
            "nodes_treeids": nodes_treeids,
            "nodes_truenodeids": nodes_truenodeids,
            "nodes_values": nodes_values,
            "nodes_values_as_tensor": nodes_values_as_tensor,
            "post_transform": post_transform,
        }
        return tree_classifier(*self._prepare_inputs(op_schema, X), **attributes)

    # Element types accepted for the regressor input ``X``.
    T_TreeEnsembleRegressor = TypeVar(
        "T_TreeEnsembleRegressor",
        DOUBLE,
        FLOAT,
        INT32,
        INT64,
    )

    def TreeEnsembleRegressor(
        self,
        X: T_TreeEnsembleRegressor,
        *,
        aggregate_function: str = "SUM",
        base_values: Optional[Sequence[float]] = None,
        base_values_as_tensor: Optional[TensorProto] = None,
        n_targets: Optional[int] = None,
        nodes_falsenodeids: Optional[Sequence[int]] = None,
        nodes_featureids: Optional[Sequence[int]] = None,
        nodes_hitrates: Optional[Sequence[float]] = None,
        nodes_hitrates_as_tensor: Optional[TensorProto] = None,
        nodes_missing_value_tracks_true: Optional[Sequence[int]] = None,
        nodes_modes: Optional[Sequence[str]] = None,
        nodes_nodeids: Optional[Sequence[int]] = None,
        nodes_treeids: Optional[Sequence[int]] = None,
        nodes_truenodeids: Optional[Sequence[int]] = None,
        nodes_values: Optional[Sequence[float]] = None,
        nodes_values_as_tensor: Optional[TensorProto] = None,
        post_transform: str = "NONE",
        target_ids: Optional[Sequence[int]] = None,
        target_nodeids: Optional[Sequence[int]] = None,
        target_treeids: Optional[Sequence[int]] = None,
        target_weights: Optional[Sequence[float]] = None,
        target_weights_as_tensor: Optional[TensorProto] = None,
    ) -> FLOAT:
        r"""[🌐 ai.onnx.ml::TreeEnsembleRegressor(3)](https://onnx.ai/onnx/operators/onnx_aionnxml_TreeEnsembleRegressor.html#treeensembleregressor-3 "Online Documentation")

        Tree ensemble regressor: returns the regressed values for each input
        in N.

        All ``nodes_*`` attributes are fields of a tuple of tree nodes. They
        are assumed to have the same length, and an index i decodes the tuple
        across them. Each node id can appear only once for each tree id.

        All ``target_*`` attributes are tuples of votes at the leaves. A leaf
        may have multiple votes, and each vote is weighted by the associated
        ``target_weights`` entry.

        Any attribute ending in ``_as_tensor`` can be used instead of the
        attribute of the same name without the suffix if the element type is
        double and not float. All trees must have their node ids start at 0
        and increment by 1.

        The mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT,
        BRANCH_EQ, BRANCH_NEQ, LEAF.

        Args:
            X: Input of shape [N,F].

            aggregate_function: How to aggregate leaf values within a target.
                One of 'AVERAGE', 'SUM', 'MIN', 'MAX'. Defaults to 'SUM'.

            base_values: Base values for regression, added to the final
                prediction after applying ``aggregate_function``. The size must
                be the same as the classes, or the attribute can be left
                unassigned (assumed 0).

            base_values_as_tensor: Tensor form of ``base_values``: base values
                for regression, added to the final prediction after applying
                ``aggregate_function``. The size must be the same as the
                classes, or the attribute can be left unassigned (assumed 0).

            n_targets: The total number of targets.

            nodes_falsenodeids: Child node if expression is false.

            nodes_featureids: Feature id for each node.

            nodes_hitrates: Popularity of each node, used for performance; may
                be omitted.

            nodes_hitrates_as_tensor: Tensor form of ``nodes_hitrates``:
                popularity of each node, used for performance; may be omitted.

            nodes_missing_value_tracks_true: For each node, what to do in the
                presence of a NaN: take the 'true' branch (attribute value 1)
                or the 'false' branch (attribute value 0) based on the value in
                this array. May be left undefined, in which case the default is
                false (0) for all nodes.

            nodes_modes: The node kind, that is, the comparison to make at the
                node. There is no comparison to make at a leaf node. One of
                'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT',
                'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'.

            nodes_nodeids: Node id for each node. Node ids must restart at zero
                for each tree and increase sequentially.

            nodes_treeids: Tree id for each node.

            nodes_truenodeids: Child node if expression is true.

            nodes_values: Thresholds to do the splitting on for each node.

            nodes_values_as_tensor: Tensor form of ``nodes_values``: thresholds
                to do the splitting on for each node.

            post_transform: The transform to apply to the score. One of 'NONE',
                'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO' or 'PROBIT'. Defaults to
                'NONE'.

            target_ids: The index of the target that each weight is for.

            target_nodeids: The node id of each weight.

            target_treeids: The id of the tree that each node is in.

            target_weights: The weight for each target.

            target_weights_as_tensor: Tensor form of ``target_weights``: the
                weight for each target.

        Returns:
            The result of invoking the ``TreeEnsembleRegressor`` op, annotated
            as a FLOAT value.
        """

        op_schema = get_schema("TreeEnsembleRegressor", 3, "ai.onnx.ml")
        tree_regressor = Op(self, "TreeEnsembleRegressor", op_schema)
        # Attributes are forwarded unchanged, in schema (alphabetical) order.
        attributes = {
            "aggregate_function": aggregate_function,
            "base_values": base_values,
            "base_values_as_tensor": base_values_as_tensor,
            "n_targets": n_targets,
            "nodes_falsenodeids": nodes_falsenodeids,
            "nodes_featureids": nodes_featureids,
            "nodes_hitrates": nodes_hitrates,
            "nodes_hitrates_as_tensor": nodes_hitrates_as_tensor,
            "nodes_missing_value_tracks_true": nodes_missing_value_tracks_true,
            "nodes_modes": nodes_modes,
            "nodes_nodeids": nodes_nodeids,
            "nodes_treeids": nodes_treeids,
            "nodes_truenodeids": nodes_truenodeids,
            "nodes_values": nodes_values,
            "nodes_values_as_tensor": nodes_values_as_tensor,
            "post_transform": post_transform,
            "target_ids": target_ids,
            "target_nodeids": target_nodeids,
            "target_treeids": target_treeids,
            "target_weights": target_weights,
            "target_weights_as_tensor": target_weights_as_tensor,
        }
        return tree_regressor(*self._prepare_inputs(op_schema, X), **attributes)
