# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Basic types for the pattern matching and rewriter API."""

from __future__ import annotations

import dataclasses
import enum
from collections import defaultdict
from typing import TYPE_CHECKING, Any, MutableSequence, Sequence, Union

from onnxscript import ir

if TYPE_CHECKING:
    import onnxscript.rewriter._pattern_ir as _pattern_ir
    import onnxscript.rewriter._rewrite_rule as _rewrite_rule


class MatchFailureInfo:
    """Describes why a pattern match failed and what it failed on.

    Attributes:
        reason: Human-readable explanation of the failure (may be empty).
        failure_sources: Tuple of the nodes and/or values passed as the
            sources of the failure (empty when none were given).
    """

    def __init__(self, reason: str = "", *failure_source: ir.Node | ir.Value):
        """Record the failure reason and the offending nodes/values.

        Args:
            reason: Explanation of the failure.
            *failure_source: Zero or more ``ir.Node`` / ``ir.Value`` objects
                the failure is attributed to.
        """
        # Sanity check only: this is an assert, so it is skipped under ``-O``.
        assert all(isinstance(item, (ir.Node, ir.Value)) for item in failure_source), (
            f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}"
        )
        self.reason = reason
        # The var-args arrive as a tuple, which is stored as-is.
        self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source

    def __str__(self):
        """Return a one-line summary containing the reason and the sources."""
        reason, sources = self.reason, self.failure_sources
        return f"MatchFailureInfo(reason={reason!r}, failure_sources={sources!r})"


class MatchFailureError(MatchFailureInfo, Exception):
    """Exception form of a pattern match failure.

    Raising the failure, rather than returning it, lets match failures be
    handled compositionally (for example, during the condition-checking phase
    of a pattern match): utility functions can simply raise, and their callers
    do not have to test for and forward a failure result explicitly.
    """

    def __init__(self, reason: str = "", *failure_source: ir.Node | ir.Value):
        """Create the error.

        Args:
            reason: Explanation of the failure; also used as the exception argument.
            *failure_source: Nodes and/or values the failure is attributed to.
        """
        # Each base is initialised explicitly because they take different
        # arguments: MatchFailureInfo gets the reason plus the sources,
        # Exception gets the reason only.
        MatchFailureInfo.__init__(self, reason, *failure_source)
        Exception.__init__(self, reason)


class MatchResult:
    """Mutable state threaded through the pattern-matching algorithm.

    A match either succeeds or fails. A successful match exposes the nodes
    that matched the pattern together with the bindings established for the
    pattern's variables.

    Example:
    ::
        def pattern(x, shape1, shape2):
            t1 = op.Reshape(x, shape1)
            t2 = op.Reshape(t1, shape2)
            return t2
    The pattern above matches a sequence of two Reshape ops. The matched nodes
    are the two Reshape ops, and the bindings hold the values bound to the
    variables `x`, `shape1`, and `shape2`.
    """

    def __init__(self) -> None:
        # A stack of partial matches: OR patterns need backtracking, so each
        # alternative being tried gets its own frame on top of the stack.
        self._partial_matches: list[PartialMatchResult] = [PartialMatchResult()]

    def __repr__(self) -> str:
        """Returns a string representation of the match result."""
        if self._partial_matches:
            return (
                f"MatchResult(success={bool(self)}, reason={self.reason!r}, nodes={self.nodes!r})"
            )
        return "MatchResult()"

    @property
    def _current_match(self) -> PartialMatchResult:
        """The partial match on top of the stack."""
        return self._partial_matches[-1]

    def enter_new_match(self) -> None:
        """Pushes a fresh sub-match used to try one of several alternatives."""
        self._partial_matches.append(PartialMatchResult())

    def abandon_current_match(self) -> PartialMatchResult:
        """Pops and returns the current sub-match because its alternative failed.

        Raises:
            ValueError: If only the top-level match is on the stack.
        """
        if len(self._partial_matches) >= 2:
            return self._partial_matches.pop()
        raise ValueError("No match to abandon.")

    def merge_current_match(self) -> None:
        """Pops a successful sub-match and folds it into its parent.

        Raises:
            ValueError: If there is no sub-match to merge, or if the popped
                sub-match is not successful.
        """
        if len(self._partial_matches) < 2:
            raise ValueError("No match to merge.")
        # The sub-match is removed from the stack before its status is checked.
        finished = self._partial_matches.pop()
        if not finished:
            raise ValueError("Current match is not successful.")
        # After the pop, the parent is the new top of the stack.
        self._current_match.merge(finished)

    def __bool__(self) -> bool:
        """Returns True if the current match is successful."""
        return bool(self._current_match)

    def fail(
        self,
        reason: str = "",
        failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
    ) -> MatchResult:
        """Marks the current (sub-)match as failed and returns ``self``.

        Args:
            reason: Explanation of the failure.
            failure_source: Optional node/value, or list of them, that the
                failure is attributed to.
        """
        self._current_match.fail(reason, failure_source)
        return self

    @property
    def reason(self) -> str:
        """The failure reason recorded on the current match."""
        return self._current_match.reason

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """The nodes matched by the current match."""
        return self._current_match.nodes

    def bind_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node):
        """Records ``node`` as matched and binds it to ``pattern_node``."""
        self.add_node(node)
        self._current_match.node_bindings[pattern_node] = node

    def add_node(self, node: ir.Node) -> None:
        """Appends ``node`` to the matched nodes of the current match."""
        self._current_match.add_node(node)

    def bind_value(self, pattern_value: _pattern_ir.ValuePattern, value: Any) -> bool:
        """Binds ``pattern_value`` to ``value``.

        Named pattern values are bound through :meth:`bind` using their name;
        unnamed ones are keyed by the pattern value itself.

        Returns:
            True if the binding was added or agrees with an existing one;
            False (after failing the current match) if the pattern value is
            already bound to a different value.
        """
        # TODO(rama): Simplify the following. We currently bind values to
        # pattern variables in two different ways: via their name, or via the
        # pattern-value itself.
        name = pattern_value.name
        if name is not None:
            return self.bind(name, value)

        # Unnamed pattern value: look for an earlier binding in any frame.
        for frame in self._partial_matches:
            bound_values = frame.value_bindings
            if pattern_value not in bound_values:
                continue
            existing = bound_values[pattern_value]
            # TODO(rama): Use appropriate equality-check here.
            if existing == value:
                return True
            self._current_match.fail(
                f"Binding failure: {pattern_value} bound to two different values.",
                [existing, value],
            )
            return False

        self._current_match.value_bindings[pattern_value] = value
        return True

    def bind(self, var: str, value: Any) -> bool:
        """Binds the pattern variable named ``var`` to ``value``.

        Returns:
            True if the binding was added or agrees with an existing one;
            False (after failing the current match) if ``var`` is already
            bound to a different value in any frame of the stack.
        """
        for frame in self._partial_matches:
            bound_vars = frame.bindings
            if var not in bound_vars:
                continue
            existing = bound_vars[var]
            # TODO(rama): Use appropriate equality-check here.
            if existing == value:
                return True
            self._current_match.fail(
                f"Binding failure: {var} bound to two different values.",
                [existing, value],
            )
            return False

        self._current_match.bindings[var] = value
        return True

    @property
    def bindings(self) -> dict[str, Any]:
        """The name-to-value bindings; available only at the top-level match."""
        if len(self._partial_matches) > 1:
            raise ValueError("Bindings can be accessed only at the top-level match.")
        return self._current_match.bindings

    @property
    def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
        """The pattern-value bindings; available only at the top-level match."""
        if len(self._partial_matches) > 1:
            raise ValueError("Value bindings can be accessed only at the top-level match.")
        return self._current_match.value_bindings

    @property
    def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
        """The pattern-node bindings; available only at the top-level match."""
        if len(self._partial_matches) > 1:
            raise ValueError("Node bindings can be accessed only at the top-level match.")
        return self._current_match.node_bindings

    @property
    def outputs(self) -> MutableSequence[ir.Value]:
        """The matched output values; available only at the top-level match."""
        if len(self._partial_matches) > 1:
            raise ValueError("Outputs can be accessed only at the top-level match.")
        return self._current_match.outputs

    @property
    def failure_nodes_and_values(self) -> list[Union[ir.Node, ir.Value]]:
        """The nodes and values recorded as the cause of the current failure."""
        return self._current_match._failure_nodes_and_values

    def lookup_node(self, pattern_node: _pattern_ir.NodePattern) -> ir.Node | None:
        """Returns the node bound to ``pattern_node`` in any frame, else None."""
        for frame in self._partial_matches:
            bound_nodes = frame.node_bindings
            if pattern_node in bound_nodes:
                return bound_nodes[pattern_node]
        return None

    def num_matched_nodes(self) -> int:
        """Returns the total number of node bindings across all frames."""
        total = 0
        for frame in self._partial_matches:
            total += len(frame.node_bindings)
        return total


class PartialMatchResult:
    """The state object used by the pattern-matching algorithm for a sub-match.

    It tracks whether the sub-match is still successful, the nodes matched so
    far, the variable/value/node bindings established, the output values, and,
    once failed, the reason and the nodes/values the failure is attributed to.
    """

    def __init__(self) -> None:
        # A sub-match starts out successful and stays so until fail() is called.
        self._success: bool = True
        # For a successful match, _matched_nodes is a list of values that matched the pattern.
        # These include the internal nodes of the pattern that were matched, but not
        # the leaves (sub-trees) that match against the variables in the pattern.
        # These represent the values that will be replaced by the replacement pattern.
        self._matched_nodes: MutableSequence[ir.Node] = []
        # For a successful match, bindings is a dictionary of mapping pattern-variable-names
        # to values.
        self._bindings: dict[str, Any] = {}
        # Bindings keyed by the pattern value / pattern node objects themselves.
        self._value_bindings: dict[_pattern_ir.ValuePattern, ir.Value] = {}
        self._node_bindings: dict[_pattern_ir.NodePattern, ir.Node] = {}

        self._outputs: list[ir.Value] = []
        # For a failed match, _reason is a string that describes the reason for the failure.
        self._reason: str = ""
        # Track the node(s) or value(s) that caused the failure.
        self._failure_nodes_and_values: list[Union[ir.Node, ir.Value]] = []

    def __bool__(self) -> bool:
        """Returns True while the sub-match has not been marked as failed."""
        return self._success

    def fail(
        self,
        reason: str = "",
        failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
    ) -> None:
        """Marks the sub-match as failed.

        Args:
            reason: Explanation of the failure; replaces any previous reason.
            failure_source: Optional node/value, or list of them, to add to
                the recorded failure sources. A list is added element-wise;
                anything else is added as a single entry.
        """
        self._success = False
        self._reason = reason
        if failure_source is not None:
            if isinstance(failure_source, list):
                self._failure_nodes_and_values.extend(failure_source)
            else:
                self._failure_nodes_and_values.append(failure_source)

    @property
    def reason(self) -> str:
        """The most recently recorded failure reason (empty if none)."""
        return self._reason

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """An immutable snapshot (tuple) of the nodes matched so far."""
        return tuple(self._matched_nodes)

    def add_node(self, node: ir.Node) -> None:
        """Adds a node to the list of matched nodes."""
        self._matched_nodes.append(node)

    @property
    def bindings(self) -> dict[str, Any]:
        """The live dictionary mapping pattern-variable names to values."""
        return self._bindings

    @property
    def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
        """The live dictionary mapping pattern values to matched values."""
        return self._value_bindings

    @property
    def outputs(self) -> MutableSequence[ir.Value]:
        """The live list of output values of the match."""
        return self._outputs

    @property
    def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
        """The live dictionary mapping pattern nodes to matched nodes."""
        return self._node_bindings

    def merge(self, other: PartialMatchResult) -> None:
        """Merges a successful sub-match for an alternative with the parent one.

        All three kinds of bindings (by name, by pattern value and by pattern
        node) and the matched nodes of ``other`` are folded into this match.

        Args:
            other: The successful sub-match to merge into this one.

        Raises:
            NotImplementedError: If either match has failed.
        """
        if self._success and other._success:
            # Merge the two successful matches. Matching algorithm responsible for ensuring
            # that the two matches are compatible. No need to check for conflicts here.
            self._bindings.update(other._bindings)
            # Value and node bindings made inside the sub-match must survive the
            # merge too; otherwise they vanish when the sub-match is popped.
            self._value_bindings.update(other._value_bindings)
            self._node_bindings.update(other._node_bindings)
            self._matched_nodes.extend(other.nodes)
            # Note: outputs should be set only at end of the (top-level) match. There
            # should be no outputs in the sub-match.
            assert not other._outputs
        else:
            # This should not happen currently.
            raise NotImplementedError("Merging failed matches is not yet supported.")


class MatchStatus(enum.IntEnum):
    """Outcome of a pattern-matching operation, as an integer-valued enum."""

    # No successful match found for entire pattern graph.
    NO_MATCH = 0
    # Subsequent validation check failed.
    CONDITION_FAILED = 1
    # Replacement subgraph could not be created.
    REPLACEMENT_FAILED = 2
    # A successful match was found.
    SUCCESS = 3


@dataclasses.dataclass
class MatchInfo:
    """A match result together with its root node, container and status."""

    # The state produced by the matching algorithm.
    match_result: MatchResult
    # The node the match was attempted at.
    root_node: ir.Node
    # The graph or function passed in alongside the node.
    container: ir.Graph | ir.Function
    # How far the match attempt got.
    status: MatchStatus

    def score(self) -> int:
        """Return a score for the match.

        The score is the number of matched nodes plus 100 times the integer
        value of the status.
        """
        matched_count = len(self.match_result.nodes)
        status_weight = int(self.status.value) * 100
        return matched_count + status_weight

    def print(self):
        """Print a report of the match to standard output.

        The report is framed by separator lines and lists the status, for a
        non-successful match the failure reason and failure sources, and
        finally the matched nodes.
        """
        rule_line = "-" * 80
        print(rule_line)
        print(f"Status: {self.status.name}")

        if self.status != MatchStatus.SUCCESS:
            why = self.match_result.reason
            if not why:
                print("Graph matching failed.")
            elif self.status == MatchStatus.CONDITION_FAILED:
                print(f"Graph matching failed due to failing check condition : {why}")
            else:
                print(f"Graph matching failed: {why}")

            culprits = self.match_result.failure_nodes_and_values
            # The heading is printed even when no failure sources were recorded.
            print("Failure at or around nodes/values:")
            if culprits:
                for culprit in culprits:
                    culprit.display()

        print("Matched nodes:")
        # Imported here, inside the method, rather than at module level.
        import onnxscript.rewriter._ir_utils as ir_utils

        ir_utils.display_nodes(self.match_result.nodes)
        print(rule_line)


class MatchContext:
    """A read-only view of the context in which a pattern matched.

    It bundles the model, the graph/function, the root node and the match
    result, and exposes them (plus the matched nodes and output values)
    through read-only properties.
    """

    def __init__(
        self,
        model: ir.Model,
        graph_or_function: ir.Graph | ir.Function,
        root: ir.Node,
        match_result: MatchResult,
    ) -> None:
        """Initialize the pattern match context.

        Args:
            model: The model being matched.
            graph_or_function: The graph or function being matched.
            root: The root node of the matching subgraph.
            match_result: The match result containing matched nodes and outputs.
        """
        self._model = model
        self._graph_or_function = graph_or_function
        self._root = root
        self._match_result = match_result

    @property
    def model(self) -> ir.Model:
        """The model being matched."""
        return self._model

    @property
    def graph_or_function(self) -> ir.Graph | ir.Function:
        """The graph or function being matched."""
        return self._graph_or_function

    @property
    def root(self) -> ir.Node:
        """The root node of the matching subgraph."""
        return self._root

    @property
    def output_values(self) -> Sequence[ir.Value]:
        """The output values of the matching subgraph, taken from the match result."""
        return self._match_result.outputs

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """All the nodes of the matching subgraph, taken from the match result."""
        return self._match_result.nodes

    def display(self, *, in_graph_order: bool = True) -> None:
        """Display the nodes in the pattern match context.

        Does nothing when the match contains no nodes.

        Args:
            in_graph_order: If True, display nodes in the order they appear in the
                graph/function. If False, display nodes in the order they appear
                in the match result.
        """
        matched = self.nodes
        if not matched:
            return

        if in_graph_order:
            # Walk the graph/function lazily and keep only the matched nodes,
            # so they come out in graph order.
            to_show = (node for node in self._graph_or_function if node in matched)
        else:
            # Match order: use the matched nodes as they are.
            to_show = matched

        for node in to_show:
            node.display()


class MatchingTracer:
    """A debugging helper that traces the matching of patterns against a graph.

    For every rule it keeps the highest-scoring match attempts seen so far, so
    that the best one can be reported once matching is finished.
    """

    def __init__(self) -> None:
        # Maps each rule to its best-scoring matches; all entries of one list
        # share the same score.
        self._best_matches_map: dict[_rewrite_rule.RewriteRule, list[MatchInfo]] = defaultdict(
            list
        )

    @property
    def best_matches_map(self) -> dict[_rewrite_rule.RewriteRule, list[MatchInfo]]:
        """The mapping from rule to its best-scoring matches."""
        return self._best_matches_map

    def log(
        self,
        rule: _rewrite_rule.RewriteRule,
        container: ir.Graph | ir.Function,
        node: ir.Node,
        match_result: MatchResult,
        status: MatchStatus,
    ) -> None:
        """Record a match attempt for ``rule`` if it is among the best so far.

        Attempts scoring zero are ignored. An attempt scoring below the
        current best for the rule is dropped, one scoring above replaces the
        stored matches, and one scoring equal is appended to them.

        Args:
            rule: The rule that was being matched.
            container: The graph or function the node belongs to.
            node: The node the match was attempted at.
            match_result: The result of the match attempt.
            status: The status of the match attempt.
        """
        candidate = MatchInfo(match_result, node, container, status)
        candidate_score = candidate.score()
        if candidate_score == 0:
            return

        kept = self._best_matches_map[rule]
        if kept:
            kept_score = kept[0].score()
            if candidate_score < kept_score:
                return
            if candidate_score > kept_score:
                # A strictly better match supersedes everything stored so far.
                kept.clear()
        kept.append(candidate)

    def report(self) -> None:
        """Print the single best match across all rules, if there is one.

        Only the first stored match of each rule is considered; on equal
        scores the rule encountered first wins.
        """
        top_score = 0
        top_match = None
        top_rule = None
        for rule, matches in self._best_matches_map.items():
            if not matches:
                continue
            contender = matches[0]
            contender_score = contender.score()
            if contender_score > top_score:
                top_score, top_match, top_rule = contender_score, contender, rule

        if top_score > 0:
            print(f"Rule: {top_rule}")
            top_match.print()
        else:
            print("No matches found.")
