# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The Pattern IR: used to describe (source) patterns of rewrite rules."""

from __future__ import annotations

import abc
import contextlib
import inspect
import itertools
from collections.abc import Mapping
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    Protocol,
    Sequence,
    TypeVar,
    Union,
)

import onnxscript.rewriter._basics as _basics
from onnxscript import ir

T = TypeVar("T")


class Pattern(Protocol[T]):  # type: ignore[misc]
    """Structural type for a predicate over values of type ``T``.

    Any object exposing a ``matches(item) -> bool`` method satisfies this
    protocol; conceptually it is a ``Callable[[T], bool]`` bound to the method
    name ``matches``.
    """

    def matches(self, item: T) -> bool:
        """Return True if ``item`` satisfies this pattern."""


class StringPattern(abc.ABC, Pattern[str]):
    """Abstract base for patterns that match against string values.

    Concrete subclasses supply both the matching predicate and a readable
    rendering used in diagnostics and pattern printing.
    """

    @abc.abstractmethod
    def matches(self, item: str) -> bool:
        """Return True if ``item`` is accepted by this pattern."""

    @abc.abstractmethod
    def __str__(self) -> str:
        """Return a human-readable rendering of the pattern."""


class StringConstantPattern(StringPattern):
    """Matches exactly one string value.

    Args:
        value: the only string accepted by this pattern.
    """

    def __init__(self, value: str):
        self._value = value

    def matches(self, item: str) -> bool:
        """Return True iff ``item`` equals the stored constant."""
        is_equal = item == self._value
        return is_equal

    def __str__(self) -> str:
        """Render the pattern as the constant string itself."""
        return self._value

    def value(self) -> str:
        """Return the constant string this pattern matches."""
        return self._value


class PrefixPattern(StringPattern):
    """Matches every string that begins with a fixed prefix.

    For example, ``PrefixPattern("pkg.torch")`` accepts ``"pkg.torch"`` and
    ``"pkg.torch.nn"``, and renders as ``"pkg.torch*"``.
    """

    def __init__(self, value: str) -> None:
        self._value = value

    def matches(self, value: str) -> bool:
        """Return True if ``value`` starts with the stored prefix."""
        prefix = self._value
        return value.startswith(prefix)

    def __str__(self) -> str:
        """Render the prefix followed by a ``*`` wildcard."""
        return f"{self._value}*"


class AttrPattern(Pattern[ir.Attr]):
    """Base class for attribute patterns; by default it accepts any attribute value.

    Args:
        name: optional pattern-variable name. When set, NodePattern.matches binds
            the node's attribute value under this name.
        can_match_none: whether the pattern may also match an attribute that is
            absent from the node.
    """

    def __init__(self, name: str | None, *, can_match_none: bool = False):
        self._name = name
        self._can_match_none = can_match_none

    @property
    def name(self) -> str | None:
        """The pattern-variable name, or None for an anonymous pattern."""
        return self._name

    @property
    def can_match_none(self) -> bool:
        """Indicates whether this pattern can match a None attribute."""
        return self._can_match_none

    def matches(self, attr: ir.Attr) -> bool:
        """Accept every attribute; subclasses narrow this."""
        del attr  # The base pattern places no constraint on the value.
        return True

    def __str__(self) -> str:
        """Render the variable name, or a unique placeholder when anonymous."""
        if self._name is None:
            return "anonymous:" + str(id(self))
        return self._name


class AttrVar(AttrPattern):
    """A pattern variable for attribute values.

    It inherits AttrPattern's accept-anything ``matches``. When it has a name,
    NodePattern.matches binds the matched attribute value under that name.

    Args:
        name: variable name, or None for an anonymous variable.
        can_match_none: whether the variable may also match an absent attribute.
    """

    def __init__(
        self,
        name: str | None,
        *,
        can_match_none: bool = False,
    ):
        super().__init__(name, can_match_none=can_match_none)


# Plain Python values that may be written directly as attribute constants in a
# pattern; _to_attr_pattern promotes them to AttrConstantPattern.
# TODO: Support tensors. Align with usage elsewhere.
SupportedAttrTypes = Union[int, float, str, Sequence[int], Sequence[float], Sequence[str]]


class AttrConstantPattern(AttrPattern):
    """Matches attributes with given value.

    Uses standard equality for matching. For list-valued attributes (INTS, FLOATS,
    STRINGS), the order of elements matters; if order is immaterial, a separate
    pattern is needed. A list-valued attribute can only match a pattern whose value
    is itself a (non-string) sequence. A scalar or a plain string never matches it.

    Args:
        value: the constant that the attribute value must equal.
    """

    def __init__(self, value: SupportedAttrTypes):
        super().__init__(None)
        self._value = value

    def matches(self, attr: ir.Attr) -> bool:
        """Return True if the attribute's value equals the stored constant."""
        if attr.type in {
            ir.AttributeType.INTS,
            ir.AttributeType.FLOATS,
            ir.AttributeType.STRINGS,
        }:
            # Calling tuple() on a scalar pattern value would raise TypeError. A str
            # would be split into characters and spuriously match a STRINGS attribute
            # listing those characters. Neither case can legitimately match.
            if isinstance(self._value, str) or not isinstance(self._value, Sequence):
                return False
            # attr.value is some Sequence type; normalize both sides to tuples.
            return tuple(attr.value) == tuple(self._value)
        return attr.value == self._value

    def __str__(self) -> str:
        return str(self._value)


def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> AttrPattern:
    """Promote a keyword argument of a pattern-builder call to an AttrPattern.

    * An AttrPattern is returned unchanged.
    * A Var becomes an AttrVar with the same name and ``can_match_none`` flag.
      Pattern variables are created as Var and only reinterpreted as attribute
      variables when they appear in an attribute position. This avoids forcing
      users to annotate which variables are attributes.
    * An int/float/str, or a sequence whose elements are all int/float or all
      str, becomes an AttrConstantPattern.

    Raises:
        ValueError: if a Var with a check method is used as an attribute, or a
            sequence is neither all-int/float nor all-str.
        TypeError: for any other kind of value.
    """
    if isinstance(value, AttrPattern):
        return value
    if isinstance(value, Var):
        # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.)
        if value.check_method is None:
            return AttrVar(value.name, can_match_none=value.can_match_none)
        raise ValueError(
            "Pattern variables used in attributes must not have check_method set."
        )
    if isinstance(value, (int, float, str)):
        return AttrConstantPattern(value)
    if not isinstance(value, Sequence):
        raise TypeError(f"Cannot convert {type(value)} to AttrPattern")
    all_numeric = all(isinstance(element, (int, float)) for element in value)
    if all_numeric or all(isinstance(element, str) for element in value):
        return AttrConstantPattern(value)
    raise ValueError("Only lists of int/float/str can be used as an AttrPattern")


class OpsetPatternBuilder:
    """Represents an opset pattern and a pattern builder.

    (i) It is used to create a NodePattern (via OpPatternBuilder).
    Example usage::

        z = op.Matmul(x, y)

    Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance
    of OpPatternBuilder. Calling `op.Matmul(x, y)` creates a NodePattern and returns
    its output value pattern (or a list of them for multi-output nodes).

    (ii) It contains a domain pattern matched against the actual opset domain used in the
    input model.

    Args:
        domain: domain pattern; a plain string is wrapped in a StringConstantPattern.
        record: if True, every NodePattern passed to ``add_node`` is kept and can be
            retrieved with ``nodes()``.
    """

    def __init__(self, domain: StringPattern | str, record: bool = False) -> None:
        if isinstance(domain, str):
            domain = StringConstantPattern(domain)
        self._domain_pattern = domain
        # None means "not recording": add_node() is a no-op and nodes() raises.
        self._nodes: list[NodePattern] | None = [] if record else None

    def domain_pattern(self) -> StringPattern:
        """Return the pattern matched against a node's domain."""
        return self._domain_pattern

    def __getattr__(self, op_name: str) -> OpPatternBuilder:
        """Treat any unknown attribute name as an op type, e.g. ``op.MatMul``.

        Python only calls this for names that normal lookup does not find. Dunder
        names are rejected so that protocol probes (``copy.deepcopy`` looking up
        ``__deepcopy__``, ``hasattr(obj, "__len__")``, ...) behave normally
        instead of receiving an op builder.

        Raises:
            AttributeError: for dunder names.
        """
        if op_name.startswith("__") and op_name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {op_name!r}"
            )
        return OpPatternBuilder(self, op_name)

    def submodule(self, name: str) -> OpPatternBuilder:
        """This method is used to match against submodule ops with prefix."""
        return OpPatternBuilder(self, PrefixPattern(name))

    def __str__(self) -> str:
        return str(self._domain_pattern)

    def add_node(self, node: NodePattern) -> None:
        """Record ``node`` if this builder was created with ``record=True``; else do nothing."""
        if self._nodes is not None:
            self._nodes.append(node)

    def nodes(self) -> Sequence[NodePattern]:
        """Return the recorded node patterns, in the order they were added.

        Raises:
            ValueError: if the builder was not created with ``record=True``.
        """
        if self._nodes is None:
            raise ValueError("Nodes were not recorded.")
        return self._nodes


# Pattern builder for the default ONNX domain ("").
onnxop = OpsetPatternBuilder(domain="")

# Pattern builder whose domain pattern accepts any domain starting with "pkg.torch".
torch_module_op = OpsetPatternBuilder(domain=PrefixPattern("pkg.torch"))


class OpPatternBuilder:
    """A utility class to build a NodePattern.

    It is used primarily to create a NodePattern.
    Example usage::

        z = op.Matmul(x, y)

    Here, `op` is an instance of OpsetPatternBuilder and `op.Matmul` is an instance
    of OpPatternBuilder. Calling `op.Matmul(x, y)` creates a NodePattern and returns
    its output value pattern (or a list of them for multi-output nodes).

    Args:
        pattern_builder: the owning OpsetPatternBuilder (supplies the domain pattern
            and records created nodes).
        op_name: op type, as a constant string or a string pattern.
    """

    def __init__(
        self,
        pattern_builder: OpsetPatternBuilder,
        op_name: str | Pattern[str],
    ) -> None:
        self.pattern_builder = pattern_builder
        self.op_name = op_name

    def __call__(
        self,
        *args,
        _domain: str | None = None,
        _version: int | None = None,
        _outputs: int | list[str | None] = 1,
        _allow_other_attributes: bool | None = None,
        _allow_other_inputs: bool | None = None,
        _check: Callable | None = None,
        **kwargs,
    ):
        """Create a NodePattern for this op and return its output pattern(s).

        Args:
            *args: input patterns; each is promoted via ``_to_value_pattern``.
            _domain: overrides the builder's domain pattern with a constant string.
            _version: not supported; passing a value raises ValueError.
            _outputs: number of outputs, or a list giving an optional variable name
                per output.
            _allow_other_attributes: forwarded to NodePattern (None selects its default).
            _allow_other_inputs: forwarded to NodePattern (None selects its default).
            _check: optional check callable stored on the NodePattern.
            **kwargs: attribute patterns; each value is promoted via ``_to_attr_pattern``.

        Returns:
            The single output ValuePattern when the node has exactly one output,
            otherwise the list of output patterns.

        Raises:
            ValueError: if ``_version`` is given, or ``_outputs`` is a negative int,
                a string, or not a sequence of str/None.
            TypeError: if ``_domain`` is not a string.
        """
        if _version is not None:
            raise ValueError(
                "The pattern builder does not support '_version' keyword argument. "
                "Version restrictions should be handled by rewrite rules."
            )
        if _domain is None:
            opset_pattern = self.pattern_builder.domain_pattern()
        elif isinstance(_domain, str):
            opset_pattern = StringConstantPattern(_domain)
        else:
            # TODO(rama): allow OpsetPatternBuilder as _domain.
            raise TypeError("_domain must be a string.")

        if isinstance(_outputs, int):
            if _outputs < 0:
                raise ValueError("_outputs must be a non-negative int or a list[str|None].")
            _outputs = [None] * _outputs
        elif (
            # A bare str is a Sequence of single-character strs and would otherwise
            # be silently accepted as one output per character.
            isinstance(_outputs, str)
            or not isinstance(_outputs, Sequence)
            or not all(isinstance(x, (str, type(None))) for x in _outputs)
        ):
            raise ValueError("_outputs must be an int or a list[str|None].")
        inputs = [_to_value_pattern(x) for x in args]
        attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()}
        node_pattern = NodePattern(
            opset_pattern,
            self.op_name,
            inputs,
            attributes,
            _outputs,
            allow_other_attributes=_allow_other_attributes,
            allow_other_inputs=_allow_other_inputs,
            check=_check,
        )
        self.pattern_builder.add_node(node_pattern)
        output_values = node_pattern.outputs
        # Unpack outputs if there is only one output, the common case.
        if len(output_values) == 1:
            return output_values[0]
        return output_values


def _to_value_pattern(
    x: ValuePattern | int | float | Callable | None,
) -> ValuePattern | None:
    """Promote an input used to construct a NodePattern to a ValuePattern.

    Example usage::

        x = op.MatMul(a, b)
        z = op.Add(x, 0)

    In this example, ``a``, ``b`` and ``x`` are ValuePatterns used to construct a
    NodePattern, while ``0`` is a constant (int) value that is automatically promoted
    to a ValuePattern. This is a shorthand for creating a Constant pattern; the user
    can more explicitly write this as::

        z = op.Add(x, op.Constant(0))

    Promotion rules:

    * ``None`` and ValuePatterns are returned unchanged.
    * An int or float, or a sequence of ints/floats, becomes a ``Constant``.
    * Any other callable becomes an anonymous ValuePattern with the callable as
      its check.

    Raises:
        ValueError: for a sequence containing non-numeric elements.
        TypeError: for any other value.
    """
    if x is None or isinstance(x, ValuePattern):
        return x
    if isinstance(x, (int, float)):
        return Constant(x)
    if isinstance(x, Sequence):
        if not all(isinstance(element, (int, float)) for element in x):
            raise ValueError("Only lists of int/float can be used as a ValuePattern")
        return Constant(x)
    if not callable(x):
        raise TypeError(f"Cannot convert {type(x)} to ValuePattern")
    return ValuePattern(None, check=x)


# The OpsetPatternBuilder used by ValuePattern's arithmetic operator overloads
# (``x + y``, ``x * y``, ...). Defaults to the ONNX-domain builder and is
# temporarily replaced through the ``pattern_builder`` context manager.
_pattern_builder: OpsetPatternBuilder = onnxop


@contextlib.contextmanager
def pattern_builder(builder: OpsetPatternBuilder):
    """Temporarily install ``builder`` as the module-level pattern builder.

    While the context is active, arithmetic operators on ValuePatterns
    (``x + y``, ``x * 2``, ...) create nodes through ``builder``. The previous
    builder is restored on exit, including when the ``with`` body raises.

    Args:
        builder: the OpsetPatternBuilder to use inside the ``with`` block.
    """
    global _pattern_builder
    prev_builder = _pattern_builder
    _pattern_builder = builder
    try:
        yield
    finally:
        # Restore even on exceptions, so a failing pattern constructor does not leave
        # a stale (e.g. recording) builder installed for later operator overloads.
        _pattern_builder = prev_builder


class ValuePattern:
    """Base class for all patterns that match against IR values.

    This is used primarily to provide operator overloadings for arithmetic
    operations, so that we can write patterns like `x + 1` and `1 + x`. The
    overloads create nodes through the module-level ``_pattern_builder``, which
    the ``pattern_builder`` context manager can temporarily replace.

    Args:
        name: optional pattern-variable name.
        check: optional check callable, exposed as ``check_method``.
        can_match_none: whether this pattern may match a None input.
    """

    def __init__(
        self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False
    ) -> None:
        self._name = name
        self._check = check
        self._can_match_none = can_match_none
        # (consumer node, input index) pairs, appended by NodePattern.__init__ for
        # every node pattern that takes this value as an input.
        self._uses: list[tuple[NodePattern, int]] = []

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern:
        """Return a fresh copy with the same name, check method and can_match_none flag.

        The copy starts with no recorded uses. ``node_map`` is unused here; it is
        accepted for signature compatibility with subclasses that reference nodes.
        """
        del node_map
        # Preserve can_match_none, consistent with Var.clone.
        return ValuePattern(
            self._name, check=self._check, can_match_none=self._can_match_none
        )

    @property
    def name(self) -> str | None:
        return self._name

    @property
    def check_method(self) -> Callable | None:
        return self._check

    @property
    def can_match_none(self) -> bool:
        """Indicates whether this variable can match a None input."""
        return self._can_match_none

    def producer(self) -> NodePattern | None:
        """Return the node pattern producing this value; None for the base class."""
        return None

    def uses(self) -> Sequence[tuple[NodePattern, int]]:
        """Return the (node, input index) pairs that consume this value."""
        return self._uses

    def append_use(self, node: NodePattern, index: int):
        """Record that ``node`` consumes this value as its ``index``-th input."""
        self._uses.append((node, index))

    def __repr__(self) -> str:
        return f"ValuePattern({self._name!r})"

    def __add__(self, other):
        return _pattern_builder.Add(self, other)

    def __radd__(self, other):
        return _pattern_builder.Add(other, self)

    def __sub__(self, other):
        return _pattern_builder.Sub(self, other)

    def __rsub__(self, other):
        return _pattern_builder.Sub(other, self)

    def __mul__(self, other):
        return _pattern_builder.Mul(self, other)

    def __rmul__(self, other):
        return _pattern_builder.Mul(other, self)

    def __truediv__(self, other):
        return _pattern_builder.Div(self, other)

    def __rtruediv__(self, other):
        return _pattern_builder.Div(other, self)

    def __pow__(self, other):
        return _pattern_builder.Pow(self, other)

    def __str__(self) -> str:
        return self._name if self._name is not None else "anonymous:" + str(id(self))


class NodePattern:
    """Represents a pattern that matches against a Node.

    This differs from a NodeOutputPattern in that it matches against a node (which
    may produce 1 or more outputs), whereas a NodeOutputPattern matches against
    a specific output of a node.

    Args:
        domain: pattern to match against the domain of the node.
        op: pattern or string constant to match against the op_type of the node.
            A plain string is wrapped in a StringConstantPattern.
        inputs: sequence of ValuePatterns (or constants) to match against the inputs of the node.
            Each entry is promoted with `_to_value_pattern`; None entries stay None.
        attributes: dictionary of attribute patterns to match against the attributes of the node.
        outputs: specifies pattern-variable-name for outputs (or None); one
            NodeOutputPattern is created per entry.
        allow_other_attributes: specifies whether other attributes (not mentioned in `attributes`)
          are allowed in the node. None means True.
        allow_other_inputs: flag stored on the pattern; None means False. It is not
          consulted by `matches` below. Presumably the matcher uses it to permit node
          inputs beyond `inputs` (TODO confirm).
        check: optional callable stored on the pattern and exposed as `check_method`.
    """

    def __init__(
        self,
        domain: StringPattern,
        op: str | Pattern[str],
        inputs: Sequence[int | float | ValuePattern | None],
        attributes: dict[str, AttrPattern],
        outputs: Sequence[str | None],
        *,
        allow_other_attributes: bool | None,
        allow_other_inputs: bool | None,
        check: Callable | None = None,
    ):
        if allow_other_attributes is None:
            # Default behavior: allow other unmatched attributes in the node.
            allow_other_attributes = True
        if allow_other_inputs is None:
            # TODO(rama): Should we default to True? For now, we preserve the current behavior.
            allow_other_inputs = False
        self.domain = domain
        self.op = StringConstantPattern(op) if isinstance(op, str) else op
        self.inputs = [_to_value_pattern(x) for x in inputs]
        self.attributes = attributes
        self.allow_other_attributes = allow_other_attributes
        self.allow_other_inputs = allow_other_inputs
        self._check = check
        # In the common case, domain and op are constants, which can be used to optimize matching.
        # Test self.op (already normalized to a pattern) rather than the raw `op`, so that
        # an op passed in as a StringConstantPattern (as clone() does) still gets an
        # identifier instead of silently losing it.
        if isinstance(self.op, StringConstantPattern) and isinstance(
            domain, StringConstantPattern
        ):
            # TODO(rama): support overloaded operators.
            overload = ""
            self._op_identifier: ir.OperatorIdentifier | None = (
                domain.value(),
                self.op.value(),
                overload,
            )
        else:
            self._op_identifier = None
        self.outputs = [NodeOutputPattern(self, i, name) for i, name in enumerate(outputs)]

        # Update uses for inputs: each input value records (this node, input position).
        for index, value in enumerate(self.inputs):
            if value is not None:
                value.append_use(self, index)

    def __str__(self) -> str:
        inputs = ", ".join(str(v) for v in self.inputs)
        outputs = ", ".join(str(v) for v in self.outputs)
        attributes = ", ".join(f"{k}={v}" for k, v in self.attributes.items())
        op = str(self.op)
        domain = str(self.domain)
        qualified_op = f"{domain}.{op}" if domain else op
        inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs
        return f"{outputs} = {qualified_op} ({inputs_and_attributes})"

    def op_identifier(self) -> ir.OperatorIdentifier | None:
        """Return the (domain, op_type, overload) triple, or None if domain/op are not constants."""
        return self._op_identifier

    @property
    def op_type(self) -> str:
        return str(self.op)

    @property
    def check_method(self) -> Callable | None:
        return self._check

    def matches(self, node: ir.Node, match: _basics.MatchResult) -> _basics.MatchResult:
        """Matches the pattern represented by self against a node.

        This is purely a local node-level match, and does not consider the subgraph rooted at the node.
        We check the domain, op_type, and attributes of the node, but not the inputs.

        Each named attribute pattern is bound (via ``match.bind``) to the node's
        attribute value. That value is None when the attribute is absent and the
        pattern allows it.

        Returns:
            The result of ``match.fail(...)`` on a mismatch; otherwise ``match`` itself.
        """
        # TODO(rama): Ensure we handle "" and "onnx.ai" correctly.
        if not self.op.matches(node.op_type):
            return match.fail(
                f"OpType mismatch: expected {self.op}, got {node.op_type}.", node
            )
        if not self.domain.matches(node.domain):
            return match.fail(
                f"Domain mismatch: expected {self.domain}, got {node.domain}.", node
            )

        for name, attr_pattern in self.attributes.items():
            attr_value = node.attributes.get(name)
            if attr_value is None:
                if not attr_pattern.can_match_none:
                    return match.fail(f"Attribute {name} not found in node.", node)
            elif not attr_pattern.matches(attr_value):
                return match.fail(
                    f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}.",
                    node,
                )
            if attr_pattern.name is not None:
                if not match.bind(attr_pattern.name, attr_value):
                    return match

        if not self.allow_other_attributes:
            for name in node.attributes:
                # TODO: Support matching default nodes for attributes.
                if name not in self.attributes:
                    return match.fail(f"Attribute {name} not expected in node.", node)

        return match

    def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePattern:
        """Copy this node (inputs cloned via ``node_map``) and register the copy in ``node_map``.

        If ``swap`` is True the node must be binary, and its two inputs are exchanged.
        Producers of this node's inputs must already be present in ``node_map``.
        """
        inputs = [(v.clone(node_map) if v is not None else None) for v in self.inputs]
        if swap:
            assert len(inputs) == 2, (
                "Internal error: commutative swap applies only to binary ops."
            )
            inputs = [inputs[1], inputs[0]]
        outputs = [value.name for value in self.outputs]
        copied = NodePattern(
            self.domain,
            self.op,
            inputs,
            self.attributes,
            outputs,
            allow_other_attributes=self.allow_other_attributes,
            allow_other_inputs=self.allow_other_inputs,
            check=self._check,
        )
        node_map[self] = copied
        return copied


class NodeOutputPattern(ValuePattern):
    """A value pattern denoting one specific output of a NodePattern.

    This is the primary pattern used to match values computed by a specific op.

    Args:
        producer: the NodePattern whose output this is.
        output_index: position of this output among the producer's outputs.
        name: optional pattern-variable name for the output.
    """

    def __init__(
        self, producer: NodePattern, output_index: int, name: str | None = None
    ) -> None:
        super().__init__(name)
        self._producer = producer
        self._output_index = output_index

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> NodeOutputPattern:
        """Return the matching output of the already-cloned producer.

        The producer must have been cloned (and registered in ``node_map``) first;
        otherwise the lookup raises KeyError.
        """
        cloned_producer = node_map[self._producer]
        return cloned_producer.outputs[self._output_index]

    @property
    def output_index(self) -> int:
        """Position of this value among its producer's outputs."""
        return self._output_index

    def producer(self) -> NodePattern:
        """Return the node pattern that produces this value."""
        return self._producer


class Var(ValuePattern):
    """Represents a pattern-variable.

    It carries an optional name, an optional check callable and a
    ``can_match_none`` flag, all of which survive cloning.
    """

    def __init__(
        self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False
    ) -> None:
        super().__init__(name, check=check, can_match_none=can_match_none)

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> Var:
        """Return a new Var with the same name, check method and can_match_none flag."""
        del node_map  # A variable does not refer to any node pattern.
        return Var(
            self.name,
            check=self.check_method,
            can_match_none=self.can_match_none,
        )


class AnyValue(ValuePattern):
    """A value pattern that places no constraint on the value it matches."""

    def __init__(self) -> None:
        super().__init__(None)

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> AnyValue:
        """Return ``self``: a single shared instance is reused rather than copied."""
        del node_map
        return self


# Shared instance; cloning it returns the same object.
ANY_VALUE = AnyValue()


class Constant(ValuePattern):
    """Represents a pattern that matches against a constant: a scalar or a list of numbers.

    Args:
        value: an int/float scalar, or a sequence of ints/floats (stored as a list copy).
        rel_tol: relative tolerance stored with the constant; presumably used by the
            matcher for approximate float comparison (TODO confirm).
        abs_tol: absolute tolerance stored with the constant (same caveat as rel_tol).
    """

    def __init__(
        self,
        value: int | float | Sequence[int] | Sequence[float],
        rel_tol: float = 1e-5,
        abs_tol: float = 1e-8,
    ) -> None:
        super().__init__(None)
        if isinstance(value, Sequence):
            value = list(value)
        self._value = value
        self._rel_tol = rel_tol
        self._abs_tol = abs_tol

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant:
        """Return an independent Constant with the same value and tolerances."""
        del node_map
        return Constant(self._value, rel_tol=self._rel_tol, abs_tol=self._abs_tol)

    @property
    def value(self) -> int | float | list[int] | list[float]:
        """The constant value (a list when constructed from a sequence)."""
        return self._value

    def __str__(self) -> str:
        return str(self._value)


class OpIdDispatchOr(ValuePattern):
    """Represents a (restricted) form of value pattern disjunction that enables deterministic matching.

    The alternative to try is selected by looking up the op identifier of the
    value's producer node, so at most one alternative is ever considered.
    """

    def __init__(
        self,
        op_to_pattern: Mapping[ir.OperatorIdentifier, tuple[Any, ValuePattern]],
        name: str | None = None,
        tag_var: str | None = None,
    ) -> None:
        """
        Initialize an OpIdDispatchOr pattern.

        Args:
            op_to_pattern: A dictionary mapping operator identifiers to tuples of tag values and patterns.
                The keys are operator identifiers, and the values are tuples containing a tag value
                and a pattern to match against.
            name: An optional variable name for the pattern. Defaults to None. If present,
                this name will be bound to the value matched by the pattern.
            tag_var: An optional variable name for the tag. Defaults to None. If present,
                it will be bound to a value indicating which alternative was matched.
        """
        super().__init__(name)
        self._op_to_pattern = op_to_pattern
        self._tag_var = tag_var

    @property
    def tag_var(self) -> str | None:
        """Returns the tag variable associated with the OrValue pattern."""
        return self._tag_var

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> OpIdDispatchOr:
        """Clone every alternative pattern, keeping the same keys, tags, name and tag_var."""
        cloned_alternatives: dict[ir.OperatorIdentifier, tuple[Any, ValuePattern]] = {}
        for op_id, entry in self._op_to_pattern.items():
            cloned_alternatives[op_id] = (entry[0], entry[1].clone(node_map))
        return OpIdDispatchOr(cloned_alternatives, self.name, self._tag_var)

    def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern] | None:
        """Returns the (tag, pattern) alternative that should be tried for the given value.

        Returns None when the value has no producer, the producer has no op
        identifier, or no alternative is registered for that identifier.
        """
        producer = value.producer()
        if producer is None:
            return None
        op_id = producer.op_identifier()
        if op_id is None:
            return None
        return self._op_to_pattern.get(op_id)


class BacktrackingOr(ValuePattern):
    """Represents an unrestricted form of OR pattern implemented using backtracking."""

    def __init__(
        self,
        values: Sequence[ValuePattern],
        name: str | None = None,
        tag_var: str | None = None,
        tag_values: Sequence[Any] | None = None,
    ) -> None:
        """
        Initialize a BacktrackingOr pattern.

        Args:
            values: A sequence of value patterns to match against.
            name: An optional variable name for the pattern. Defaults to None. If present,
                this name will be bound to the value matched by the pattern.
            tag_var: An optional variable name for the tag. Defaults to None. If present,
                it will be bound to a value (from tag_values) indicating which alternative was matched.
            tag_values: An optional sequence of values to bind to the tag_var. Defaults to None.
                If present, the length of tag_values must match the number of alternatives in values.
                In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th
                alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used.

        Raises:
            ValueError: if tag_values is given without tag_var, or with the wrong length.
        """
        super().__init__(name)
        if tag_values is None:
            tag_values = tuple(range(len(values)))
        elif tag_var is None:
            raise ValueError("tag_var must be specified if tag_values is provided.")
        elif len(tag_values) != len(values):
            raise ValueError(
                "tag_values must have the same length as the number of alternatives."
            )
        self._tag_var = tag_var
        self._tag_values = tag_values
        self._values = values

    @property
    def tag_var(self) -> str | None:
        """Returns the tag variable associated with the OrValue pattern."""
        return self._tag_var

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> BacktrackingOr:
        """Clone every alternative, keeping the name, tag variable and tag values."""
        cloned_alternatives = [alternative.clone(node_map) for alternative in self._values]
        return BacktrackingOr(
            cloned_alternatives,
            self.name,
            self._tag_var,
            self._tag_values,
        )


def OrValue(
    values: Sequence[ValuePattern],
    name: str | None = None,
    tag_var: str | None = None,
    tag_values: Sequence[Any] | None = None,
) -> ValuePattern:
    """
    Creates an OR pattern.

    When every alternative is a NodeOutputPattern whose producer has an op
    identifier, and those identifiers are pairwise distinct, an OpIdDispatchOr is
    returned; it selects the alternative by the producer's op identifier.
    Otherwise a BacktrackingOr is returned.

    Args:
        values: A sequence of value patterns to match against.
        name: An optional variable name for the pattern. Defaults to None. If present,
            this name will be bound to the value matched by the pattern.
        tag_var: An optional variable name for the tag. Defaults to None. If present,
            it will be bound to a value (from tag_values) indicating which alternative was matched.
        tag_values: An optional sequence of values to bind to the tag_var. Defaults to None.
            If present, the length of tag_values must match the number of alternatives in values.
            In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th
            alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used.

    Raises:
        ValueError: if tag_values is given without tag_var, or its length differs from
            the number of alternatives.
    """
    if tag_values is not None:
        if tag_var is None:
            raise ValueError("tag_var must be specified if tag_values is provided.")
        if len(tag_values) != len(values):
            raise ValueError(
                "tag_values must have the same length as the number of alternatives."
            )
    else:
        tag_values = tuple(range(len(values)))

    def make_op_id_or_pattern() -> OpIdDispatchOr | None:
        """Build the dispatch form, or return None if the alternatives don't allow it."""
        mapping: dict[ir.OperatorIdentifier, tuple[Any, NodeOutputPattern]] = {}
        for tag, alternative in zip(tag_values, values):
            if not isinstance(alternative, NodeOutputPattern):
                return None
            op_id = alternative.producer().op_identifier()
            if op_id is None or op_id in mapping:
                return None
            mapping[op_id] = (tag, alternative)
        return OpIdDispatchOr(mapping, name, tag_var)

    optimized_pattern = make_op_id_or_pattern()
    if optimized_pattern is not None:
        return optimized_pattern
    # BacktrackingOr rejects tag_values without a tag_var, so forward the (possibly
    # defaulted) tag_values only when a tag variable was actually given. The check
    # uses "is not None" to agree with the validation above.
    return BacktrackingOr(
        values, name, tag_var, tag_values if tag_var is not None else None
    )


def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]:
    """Returns all nodes used in a pattern, given the outputs of the pattern.

    Nodes are discovered by a depth-first walk backwards from ``outputs`` through
    NodeOutputPattern producers. The returned list is the reverse of discovery
    order. Values that are not NodeOutputPatterns (e.g. Vars, Constants, OR
    patterns) are not traversed.
    """
    node_patterns: list[NodePattern] = []
    # NodePattern uses identity equality/hash, so this set answers the same question
    # as `node_pattern in node_patterns` but in O(1) rather than O(n).
    seen: set[NodePattern] = set()

    def visit(value_patterns: Sequence[ValuePattern | None]) -> None:
        for value_pattern in value_patterns:
            if isinstance(value_pattern, NodeOutputPattern):
                node_pattern = value_pattern.producer()
                if node_pattern not in seen:
                    seen.add(node_pattern)
                    node_patterns.append(node_pattern)
                    visit(node_pattern.inputs)

    visit(outputs)
    node_patterns.reverse()
    return node_patterns


def _add_backward_slice(
    node: NodePattern,
    backward_slice: set[NodePattern],
    backward_slice_values: set[ValuePattern],
) -> None:
    """Adds all nodes in the backward slice of ``node`` to ``backward_slice``.

    The backward slice of a node is the set of all nodes reachable from it by
    following input values back to their producers. Every OR-pattern input
    (OpIdDispatchOr or BacktrackingOr) of a visited node is added to
    ``backward_slice_values``; the alternatives inside an OR pattern are not
    traversed. Nodes already present in ``backward_slice`` are not revisited.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current in backward_slice:
            continue
        backward_slice.add(current)
        for value_pattern in current.inputs:
            if isinstance(value_pattern, NodeOutputPattern):
                pending.append(value_pattern.producer())
            elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)):
                backward_slice_values.add(value_pattern)


class GraphPattern:
    """Represents a pattern that can be matched against a subgraph.

    Args:
        inputs: the pattern's input value patterns.
        outputs: the pattern's output value patterns (at least one).
        nodes: the node patterns that make up the pattern.

    Raises:
        ValueError: if ``outputs`` is empty.
        TypeError: if an output is not a ValuePattern.
        NotImplementedError: if an output is an OR pattern that is not an input of
            some node in the backward slice of the output nodes.
    """

    def __init__(
        self,
        inputs: Sequence[ValuePattern],
        outputs: Sequence[ValuePattern],
        nodes: Sequence[NodePattern],
    ) -> None:
        self._inputs = inputs
        self._outputs = outputs
        if len(outputs) == 0:
            raise ValueError("GraphPattern must have at least one output")
        self._nodes = nodes  # _nodes_in_pattern(outputs)

        # Determine the output nodes of the pattern: the producers of output values that
        # are not already covered by the backward slice of an earlier output's producer.
        # Together their backward slices cover every output-producing node. Note that
        # the set is not guaranteed minimal: a producer added for an earlier output is
        # kept even if a later output's slice also covers it.
        # A list (rather than a set of identity-hashed nodes) keeps them in the order of
        # `outputs`, so the result is reproducible across runs.
        output_nodes: list[NodePattern] = []
        covered: set[NodePattern] = set()
        choice_values_returned: set[ValuePattern] = set()
        covered_choice_values: set[ValuePattern] = set()
        for value_pattern in outputs:
            if not isinstance(value_pattern, ValuePattern):
                raise TypeError(
                    f"Invalid type {type(value_pattern)} for graph pattern output."
                )
            if isinstance(value_pattern, NodeOutputPattern):
                candidate = value_pattern.producer()
                if candidate not in covered:
                    # _add_backward_slice puts `candidate` into `covered`, so it can
                    # never be appended twice.
                    output_nodes.append(candidate)
                    _add_backward_slice(candidate, covered, covered_choice_values)
            elif isinstance(value_pattern, (OpIdDispatchOr, BacktrackingOr)):
                choice_values_returned.add(value_pattern)

        # check if all choice_values_returned are contained in covered_choice_values:
        # We don't yet support the use of a choice-value as a "root" of the search.
        # This is a limitation of the current implementation, and will be fixed in the future.
        if not (choice_values_returned <= covered_choice_values):
            raise NotImplementedError("Returning uncovered choice-values is not supported.")

        self.output_nodes: list[NodePattern] = output_nodes

    @property
    def output_node(self) -> NodePattern:
        """The unique output node.

        Raises:
            ValueError: if there is not exactly one output node.
        """
        if len(self.output_nodes) != 1:
            raise ValueError("GraphPattern does not have unique output node.")
        return self.output_nodes[0]

    def node(self, index: int) -> NodePattern:
        return self._nodes[index]

    def num_nodes(self) -> int:
        return len(self._nodes)

    def __len__(self) -> int:
        return self.num_nodes()

    @property
    def inputs(self) -> Sequence[ValuePattern]:
        return self._inputs

    @property
    def outputs(self) -> Sequence[ValuePattern]:
        return self._outputs

    def __iter__(self) -> Iterator[NodePattern]:
        return iter(self._nodes)

    def __reversed__(self) -> Iterator[NodePattern]:
        return reversed(self._nodes)

    @property
    def has_single_output_node(self) -> bool:
        return len(self.output_nodes) == 1

    @property
    def num_outputs(self) -> int:
        return len(self._outputs)

    def commute(self) -> Sequence[GraphPattern]:
        """Return variants of this pattern with the inputs of commutative ops swapped.

        Any node whose op identifier is ``("", "Add", "")`` or ``("", "Mul", "")`` may
        have its two inputs swapped. One GraphPattern is produced per combination of
        swap choices (2**k patterns for k such nodes). The no-swap combination is
        ``self``.
        """
        commutative_ops = {("", "Add", ""), ("", "Mul", "")}

        def commute_node(node: NodePattern) -> Iterable[bool]:
            if node.op_identifier() in commutative_ops:
                # Try with and without swapping inputs.
                return [False, True]
            # No swapping of inputs
            return [False]

        iteration_space = [commute_node(node) for node in self._nodes]

        def copy_graph(swap_list: Iterable[bool]) -> GraphPattern:
            if not any(swap_list):
                # No need to swap inputs of any node
                return self
            # Create a copy of the graph, with swapped inputs for the nodes that need it.
            # Nodes are cloned in order, so producers are in node_map before consumers.
            node_map: dict[NodePattern, NodePattern] = {}
            new_inputs = [v.clone(node_map) for v in self._inputs]
            new_nodes = [
                node.clone(node_map, swap) for node, swap in zip(self._nodes, swap_list)
            ]
            new_outputs = [v.clone(node_map) for v in self._outputs]
            return GraphPattern(new_inputs, new_outputs, new_nodes)

        return [copy_graph(swap_list) for swap_list in itertools.product(*iteration_space)]

    def __str__(self) -> str:
        inputs = ", ".join(str(v) for v in self._inputs)
        outputs = ", ".join(str(v) for v in self._outputs)
        nodes = "\n   ".join(str(n) for n in self._nodes)
        return f"pattern ({inputs}) {{\n   {nodes}\n   return {outputs}\n}}"


def _to_graph_pattern(pattern_constructor: Callable) -> GraphPattern:
    """Convert a pattern-construction function to a GraphPattern.

    A pattern-construction function looks like this::

        def pattern(op, x: Var, shape1: Var, shape2: Var):
            ...
            return outputs

    A pattern variable (Var) is created for every parameter after the first. The
    function is then called with a recording OpsetPatternBuilder for the default
    domain "" as its first argument, followed by those variables. While it runs,
    the same builder is installed for ValuePattern operator overloads. The returned
    value(s) are normalized to a list of ValuePatterns, which become the outputs of
    the pattern graph.

    Args:
        pattern_constructor: Callable

    Returns:
        GraphPattern: A representation of the pattern that can be matched against a subgraph.
    """
    parameter_names = list(inspect.signature(pattern_constructor).parameters)
    # The first parameter receives the builder itself, not a pattern variable.
    pattern_inputs = [Var(param_name) for param_name in parameter_names[1:]]
    builder = OpsetPatternBuilder("", record=True)
    with pattern_builder(builder):
        pattern_outputs = pattern_constructor(builder, *pattern_inputs)
    # TODO(rama): classify inputs as value/attribute vars
    # The constructor may return a single ValuePattern or a sequence of them.
    if isinstance(pattern_outputs, ValuePattern):
        pattern_outputs = [pattern_outputs]
    return GraphPattern(pattern_inputs, pattern_outputs, builder.nodes())
