# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rewrite rules to eliminate redundant ScatterND operations.

This module contains two rewrite rules:

1. ScatterAllDynamic: Identifies ScatterND(data, indices, updates) that can be replaced by Identity(updates)
   when the indices are computed dynamically using Range operations but represent a complete update
   of an entire axis. This is generated by the translation of `x[:, ...] = y` in PyTorch.

2. ScatterAllStatic: Identifies ScatterND(data, indices, updates) that can be replaced by Identity(updates)
   when the indices are statically known constants in the form [[0], [1], ..., [n-1]] covering
   the entire first dimension of the data tensor.

Both rules detect when the scatter-update ends up being an assignment of a new value to the entire tensor.
"""

from __future__ import annotations

import onnx_ir as ir

import onnxscript.rewriter
from onnxscript.rewriter import _ir_utils
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


class ScatterAllDynamic(RewriteRuleClassBase):
    """Rewrite rule for eliminating redundant ScatterND with dynamically computed indices.

    Matches ScatterND(transposed_data, indices, updates, reduction="none") where
    ``indices`` is computed as Unsqueeze(Range(0, Shape(data)[axis], 1), [-1]), and
    replaces it by Identity(updates) when ``data.shape[axis]`` is known to equal
    ``transposed_data.shape[0]``, i.e. when the indices cover the entire first
    dimension of the tensor being updated.
    """

    def __init__(self):
        # NOTE(review): remove_nodes=False presumably keeps the matched nodes in
        # the graph after the rewrite -- confirm against RewriteRuleClassBase.
        super().__init__(remove_nodes=False)

    def pattern(self, op, data, axis, transposed_data, updates):
        """Pattern to match ScatterND whose indices are Range(0, data.shape[axis], 1)."""
        # Construct update-indices spanning an entire axis:
        shape = op.Shape(data, start=0)
        dim = op.Gather(shape, axis, axis=0)
        full_range = op.Range(0, dim, 1)
        full_range_2d = op.Unsqueeze(full_range, [-1])
        # The update is applied to the data transposed to bring the updated axis to the front:
        return op.ScatterND(transposed_data, full_range_2d, updates, reduction="none")

    def check(self, context, data, axis, transposed_data, **_):
        """Check that the indices span the full first dimension of ``transposed_data``.

        That is: check that ``data.shape[axis]`` matches ``transposed_data.shape[0]``.

        Returns:
            True when the rule applies, otherwise a failed MatchResult describing
            why it does not.
        """
        result = onnxscript.rewriter.MatchResult()
        axis_value = _ir_utils.get_singleton_value(axis)
        if not isinstance(axis_value, int):
            return result.fail("Axis value must be a constant integer.", axis)
        shape: ir.Shape | None = data.shape
        if shape is None:
            return result.fail("Data shape is not statically known.", data)
        try:
            updated_dim_value = shape[axis_value]
        except IndexError:
            # An axis outside the rank of data must reject the match rather than
            # let the exception escape from the check.
            return result.fail(
                "Axis value is out of range for the shape of the data.", [axis, data]
            )
        transposed_data_shape: ir.Shape | None = transposed_data.shape
        if transposed_data_shape is None:
            return result.fail(
                "Transposed data shape is not statically known.", transposed_data
            )
        try:
            actual_dim_value = transposed_data_shape[0]
        except IndexError:
            # A rank-0 tensor has no first dimension to compare against.
            return result.fail(
                "Transposed data has no first dimension.", transposed_data
            )
        if not _ir_utils.same_dim(updated_dim_value, actual_dim_value):
            # The first dimension of the transposed data does not match the updated dimension,
            # so we cannot apply this rule.
            return result.fail(
                "The first dimension of the transposed data does not match the updated dimension.",
                [data, transposed_data],
            )
        return True

    def rewrite(self, op, updates, **_):
        """Replace ScatterND with Identity since updates covers the entire tensor."""
        return op.Identity(updates)


class ScatterAllStatic(RewriteRuleClassBase):
    """Rewrite rule for eliminating redundant ScatterND with statically known indices.

    This handles the case where indices are constant values in the form [[0], [1], ..., [n-1]]
    that update the entire first dimension of the data tensor, so that
    ScatterND(data, indices, updates) can be replaced by Identity(updates).

    The replacement is only valid when ScatterND applies no reduction: with a
    reduction of "add", "mul", "min" or "max" the output combines ``data`` with
    ``updates`` instead of being overwritten by ``updates``.
    """

    def pattern(self, op, data, indices, updates):
        """Pattern to match ScatterND with static indices."""
        # The output is named so that check() can reach the matched node and
        # inspect its 'reduction' attribute.
        return op.ScatterND(data, indices, updates, _outputs=["scatter_out"])

    def check(self, context, data, indices, updates, scatter_out=None, **_):
        """Check if the ScatterND is redundant due to static indices covering entire tensor.

        Args:
            context: Match context supplied by the rewriter (unused).
            data: The tensor being updated.
            indices: The scatter indices; must be a constant.
            updates: The values written into ``data``.
            scatter_out: The output of the matched ScatterND node. When it is
                not supplied the 'reduction' attribute cannot be inspected and
                that validation is skipped.

        Returns:
            True when the rule applies, otherwise a failed MatchResult describing
            why it does not.
        """
        result = onnxscript.rewriter.MatchResult()

        # To validate data can be replaced directly by updates, we need to check the following:
        # 0. no reduction is applied, otherwise the output depends on 'data' as well
        if scatter_out is not None:
            scatter_node = scatter_out.producer()
            reduction = (
                None if scatter_node is None else scatter_node.attributes.get("reduction")
            )
            if reduction is not None and reduction.as_string() != "none":
                return result.fail(
                    "ScatterND with a reduction combines 'data' and 'updates'.",
                    scatter_out,
                )

        # 1. they have the same shape
        if data.shape is None:
            return result.fail("The value 'data' shape is not statically known.", data)
        if updates.shape is None:
            return result.fail("The value 'updates' shape is not statically known.", updates)
        if not _ir_utils.same_shape(data.shape, updates.shape):
            return result.fail(
                "The shape of 'data' and 'updates' are different.", [data, updates]
            )

        # 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
        if indices.const_value is None:
            return result.fail("The value 'indices' is not statically known.", indices)
        if len(data.shape) == 0:
            # A rank-0 tensor has no first dimension to enumerate.
            return result.fail("The value 'data' has no first dimension.", data)
        first_dim = data.shape[0]
        if not isinstance(first_dim, int):
            # A symbolic dimension cannot be expanded into the expected index list.
            return result.fail(
                "The first dimension of 'data' is not statically known.", data
            )
        expected_indices = [[i] for i in range(first_dim)]
        actual_indices = indices.const_value.numpy().tolist()
        if actual_indices != expected_indices:
            return result.fail("The 'indices' is not referring to the whole data.", indices)

        return True

    def rewrite(self, op, updates, **_):
        """Replace ScatterND with Identity since updates covers entire tensor."""
        return op.Identity(updates)


# Rule instances created through the `rule()` factory of each rule class.
no_op_dynamic_scatter_nd_rule = ScatterAllDynamic.rule()
no_op_static_scatter_nd_rule = ScatterAllStatic.rule()

# Rule set bundling both redundant-ScatterND elimination rules.
rules = RewriteRuleSet([no_op_dynamic_scatter_nd_rule, no_op_static_scatter_nd_rule])
