# --------------------------------------------------------------------------
# ⚠️ WARNING - AUTO-GENERATED CODE - DO NOT EDIT ⚠️
# ⚙️ Generated by 'python -m opgen'
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# pylint: disable=W0221,W0222,R0901,W0237
# mypy: disable-error-code=override
# ruff: noqa: N801
# --------------------------------------------------------------------------

from __future__ import annotations

from typing import Optional, Sequence, TypeVar

from onnx import TensorProto
from onnx.defs import get_schema

from onnxscript.onnx_opset._impl.opset_ai_onnx_ml4 import Opset_ai_onnx_ml4
from onnxscript.onnx_types import DOUBLE, FLOAT, FLOAT16
from onnxscript.values import Op, Opset


class Opset_ai_onnx_ml5(Opset_ai_onnx_ml4):
    """Operator set for domain "ai.onnx.ml", version 5.

    Inherits every operator from ``Opset_ai_onnx_ml4`` and defines
    ``TreeEnsemble`` (version 5).
    """

    # NOTE(review): this module is generated by `python -m opgen` (see the file
    # header) and the TreeEnsemble docstring text appears to be derived from the
    # ONNX operator schema. Documentation corrections made here should also be
    # applied at their source, or they will be lost on regeneration.

    def __new__(cls):
        # Construct via Opset.__new__ directly with this opset's domain and
        # version, rather than via the ml4 parent's __new__ (which presumably
        # binds version 4).
        return Opset.__new__(cls, "ai.onnx.ml", 5)

    # Allowed element types for the input `X` and for the result of
    # TreeEnsemble: DOUBLE, FLOAT or FLOAT16 (both share this TypeVar).
    T_TreeEnsemble = TypeVar("T_TreeEnsemble", DOUBLE, FLOAT, FLOAT16)

    def TreeEnsemble(
        self,
        X: T_TreeEnsemble,
        *,
        aggregate_function: int = 1,
        leaf_targetids: Sequence[int],
        leaf_weights: TensorProto,
        membership_values: Optional[TensorProto] = None,
        n_targets: Optional[int] = None,
        nodes_falseleafs: Sequence[int],
        nodes_falsenodeids: Sequence[int],
        nodes_featureids: Sequence[int],
        nodes_hitrates: Optional[TensorProto] = None,
        nodes_missing_value_tracks_true: Optional[Sequence[int]] = None,
        nodes_modes: TensorProto,
        nodes_splits: TensorProto,
        nodes_trueleafs: Sequence[int],
        nodes_truenodeids: Sequence[int],
        post_transform: int = 0,
        tree_roots: Sequence[int],
    ) -> T_TreeEnsemble:
        r"""[🌐 ai.onnx.ml::TreeEnsemble(5)](https://onnx.ai/onnx/operators/onnx_aionnxml_TreeEnsemble.html#treeensemble-5 "Online Documentation")


            Tree Ensemble operator.  Returns the regressed values for each input in a batch.
            Inputs have dimensions `[N, F]` where `N` is the input batch size and `F` is the number of input features.
            Outputs have dimensions `[N, num_targets]` where `N` is the batch size and `num_targets` is the number of targets, which is configured by the `n_targets` attribute.

            The encoding of this attribute is split along interior nodes and the leaves of the trees. Notably, attributes with the prefix `nodes_*` are associated with interior nodes, and attributes with the prefix `leaf_*` are associated with leaves.
            The attributes `nodes_*` must all have the same length and encode a sequence of tuples, as defined by taking all the `nodes_*` fields at a given position.

            All fields prefixed with `leaf_*` represent tree leaves, and similarly define tuples of leaves and must have identical length.

            This operator can be used to implement both the previous `TreeEnsembleRegressor` and `TreeEnsembleClassifier` nodes.
            The `TreeEnsembleRegressor` node maps directly to this node and requires changing how the nodes are represented.
            The `TreeEnsembleClassifier` node can be implemented by adding an `ArgMax` node after this node to determine the top class.
            To encode class labels, a `LabelEncoder` or `GatherND` operator may be used.


        Args:
            X: Input of shape [Batch Size, Number of Features]

            aggregate_function: Defines how to aggregate leaf values within a target.
                <br>One of 'AVERAGE' (0), 'SUM' (1), 'MIN' (2) or 'MAX' (3); defaults
                to 'SUM' (1).

            leaf_targetids: The index of the target that this leaf contributes to (this
                must be in range `[0, n_targets)`).

            leaf_weights: The weight for each leaf.

            membership_values: Members to test membership of for each set membership
                node. List all of the members to test against in the order that the
                'BRANCH_MEMBER' mode appears in `nodes_modes`, delimited by `NaN`s.
                Will have the same number of sets of values as nodes with mode
                'BRANCH_MEMBER'. This may be omitted if the node doesn't contain any
                'BRANCH_MEMBER' nodes.

            n_targets: The total number of targets.

            nodes_falseleafs: 1 if the false branch is a leaf for each node and 0 if
                it is an interior node. To represent a tree that is a leaf (only has
                one node), one can do so by having a single `nodes_*` entry with true
                and false branches referencing the same `leaf_*` entry.

            nodes_falsenodeids: If `nodes_falseleafs` is false at an entry, this
                represents the position of the false branch node. This position can be
                used to index into a `nodes_*` entry. If `nodes_falseleafs` is true, it
                is an index into the leaf_* attributes.

            nodes_featureids: Feature id for each node.

            nodes_hitrates: Popularity of each node, used for performance and may be
                omitted.

            nodes_missing_value_tracks_true: For each node, define whether to follow the
                true branch (if attribute value is 1) or false branch (if attribute
                value is 0) in the presence of a NaN input feature. This attribute may
                be left undefined and the default value is false (0) for all nodes.

            nodes_modes: The comparison operation performed by the node. This is encoded
                as an enumeration of 0 ('BRANCH_LEQ'), 1 ('BRANCH_LT'), 2
                ('BRANCH_GTE'), 3 ('BRANCH_GT'), 4 ('BRANCH_EQ'), 5 ('BRANCH_NEQ'), and
                6 ('BRANCH_MEMBER'). Note this is a tensor of type uint8.

            nodes_splits: Thresholds to do the splitting on for each node with mode that
                is not 'BRANCH_MEMBER'.

            nodes_trueleafs: 1 if the true branch is a leaf for each node and 0 if it
                is an interior node. To represent a tree that is a leaf (only has one
                node), one can do so by having a single `nodes_*` entry with true and
                false branches referencing the same `leaf_*` entry.

            nodes_truenodeids: If `nodes_trueleafs` is false at an entry, this
                represents the position of the true branch node. This position can be
                used to index into a `nodes_*` entry. If `nodes_trueleafs` is true, it
                is an index into the leaf_* attributes.

            post_transform: Indicates the transform to apply to the score. <br>One of
                'NONE' (0), 'SOFTMAX' (1), 'LOGISTIC' (2), 'SOFTMAX_ZERO' (3) or
                'PROBIT' (4), defaults to 'NONE' (0)

            tree_roots: Index into `nodes_*` for the root of each tree. The tree
                structure is derived from the branching of each node.

        Returns:
            Output of shape [Batch Size, Number of Targets], constrained to the same
            element type as `X` (DOUBLE, FLOAT or FLOAT16).
        """

        schema = get_schema("TreeEnsemble", 5, "ai.onnx.ml")
        op = Op(self, "TreeEnsemble", schema)
        # The positional input goes through self._prepare_inputs (defined in a
        # base class, not visible here). Every attribute is forwarded by keyword,
        # including optional ones left as None. NOTE(review): presumably Op treats
        # None as "attribute not set"; confirm in onnxscript.values.Op.
        return op(
            *self._prepare_inputs(schema, X),
            aggregate_function=aggregate_function,
            leaf_targetids=leaf_targetids,
            leaf_weights=leaf_weights,
            membership_values=membership_values,
            n_targets=n_targets,
            nodes_falseleafs=nodes_falseleafs,
            nodes_falsenodeids=nodes_falsenodeids,
            nodes_featureids=nodes_featureids,
            nodes_hitrates=nodes_hitrates,
            nodes_missing_value_tracks_true=nodes_missing_value_tracks_true,
            nodes_modes=nodes_modes,
            nodes_splits=nodes_splits,
            nodes_trueleafs=nodes_trueleafs,
            nodes_truenodeids=nodes_truenodeids,
            post_transform=post_transform,
            tree_roots=tree_roots,
        )
