# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

try:
    import torch.utils._pytree as pytree
    from torch._functorch.aot_autograd import create_aot_dispatcher_function
    from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
    from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
    from torch._inductor.virtualized import V
    from torch._inductor.scheduler import Scheduler

    original_create_aot_dispatcher_function = create_aot_dispatcher_function
except ImportError:
    pass

from .util import get_input_nodes
from .graph_param import DSGraphParamManager


def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
    """Wrap an AOTAutograd compiler so the DeepCompile graph pass runs before it.

    Args:
        original_compiler: Compiler normally handed to AOTAutograd, called as ``original_compiler(gm, inputs)``.
        dc_compiler: DeepCompile pass called as ``dc_compiler(gm, fake_inputs)`` before ``original_compiler``.
            Only whether its result is ``None`` is used; ``None`` aborts compilation (the symint case).
        z3_partition: When true, inputs whose placeholder names are ZeRO-3 parameter names are replaced by
            zero-element fake tensors (same dtype, device and fake mode) before ``original_compiler`` runs.
        graph_id: Key into ``graph_param_manager``.
        graph_param_manager: Indexable by ``graph_id``; the entry must expose ``param_names`` (forward) and
            ``get_bwd_mapping(graph)`` (backward).
        bwd: Whether the wrapped compiler handles the backward graph.

    Returns:
        ``wrapped_compiler(gm, fake_inputs)``. It returns ``None`` when ``dc_compiler`` returns ``None``, and
        otherwise whatever ``original_compiler`` returns.
    """

    def _shrink_param_inputs(gm, fake_inputs):
        """Return ``fake_inputs`` as a tuple with each ZeRO-3 parameter input swapped for an empty fake tensor."""
        # Imported once per compilation instead of once per matching input.
        from torch._subclasses.fake_tensor import is_fake
        from torch._dynamo.utils import to_fake_tensor

        manager = graph_param_manager[graph_id]
        if bwd:
            param_nodes_bw, _ = manager.get_bwd_mapping(gm.graph)
            param_names = {n.name for n in param_nodes_bw}
        else:
            # A set keeps the per-input membership test below O(1).
            param_names = set(manager.param_names)

        patched_inputs = []
        for in_node, in_v in zip(get_input_nodes(gm.graph), fake_inputs):
            if in_node.name not in param_names:
                patched_inputs.append(in_v)
                continue
            assert is_fake(in_v), f"Input {in_v} should be fake tensor"
            # Keep dtype/device/fake mode but drop all elements (presumably the size a released ZeRO-3
            # parameter has at runtime -- confirm against the runtime partitioner).
            empty = torch.empty([0], dtype=in_v.dtype, device=in_v.device)
            patched_inputs.append(to_fake_tensor(empty, in_v.fake_mode))
        return tuple(patched_inputs)

    def wrapped_compiler(gm, fake_inputs):
        mod_graph = dc_compiler(gm, fake_inputs)

        # For symint case
        if mod_graph is None:
            return None

        # Inductor validates input sizes estimated by the first trace, where ds tensors are materialized,
        # so parameter inputs are patched to avoid the validation error.
        patched_inputs = _shrink_param_inputs(gm, fake_inputs) if z3_partition else fake_inputs
        return original_compiler(gm, patched_inputs)

    return wrapped_compiler


def wrap_partition_fn(partition_fn, real_inputs, param_indices):
    """Decorate an AOTAutograd partition function so parameter placeholders carry empty metadata.

    The returned callable runs ``partition_fn(*args, **kwargs)`` unchanged and builds a ``DSGraphParamManager``
    from the forward module's graph, ``real_inputs`` and ``param_indices``. In both the forward and the backward
    graphs it then sets ``meta["val"]`` of every placeholder whose name is in the manager's ``param_names`` to a
    real zero-element tensor (``torch.empty([0])``) with the same dtype and device. Finally it returns the
    ``(fw_module, bw_module)`` pair.
    """

    def _empty_param_placeholders(graph, manager):
        # Only placeholder nodes that the manager reports as parameters are rewritten.
        for node in graph.nodes:
            if node.op != "placeholder":
                continue
            if node.name not in manager.param_names:
                continue
            old_val = node.meta["val"]
            node.meta["val"] = torch.empty([0], dtype=old_val.dtype, device=old_val.device)

    def wrapped_partition_fn(*args, **kwargs):
        fw_module, bw_module = partition_fn(*args, **kwargs)

        # Parameter names are derived from the forward graph and applied to both graphs.
        manager = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)
        for module in (fw_module, bw_module):
            _empty_param_placeholders(module.graph, manager)

        return fw_module, bw_module

    return wrapped_partition_fn


def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
                                         param_indices, param_manager):

    from torch._dynamo.backends.common import AotAutograd
    import functools

    def patch_aotautograd():
        # Unpatch if it was already patched
        if hasattr(AotAutograd, "__original_init"):
            AotAutograd.__init__ = AotAutograd.__original_init

        original_init = AotAutograd.__init__

        @functools.wraps(original_init)
        def patched_init(self, **kwargs):
            kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"],
                                                   make_fw_graph,
                                                   z3_partition,
                                                   graph_id,
                                                   param_manager,
                                                   bwd=False)
            kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"],
                                                   make_bw_graph,
                                                   z3_partition,
                                                   graph_id,
                                                   param_manager,
                                                   bwd=True)
            kwargs["inference_compiler"] = kwargs["fw_compiler"]

            if z3_partition:
                kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)

            original_init(self, **kwargs)

        AotAutograd.__original_init = original_init
        AotAutograd.__init__ = patched_init

    patch_aotautograd()


def register_custom_ops():
    """Register Inductor lowerings for DeepCompile's ``torch.ops.dc`` ops.

    Each op is lowered as a ``FallbackKernel`` and registered via ``add_needs_realized_inputs``. Depending on
    the per-op flags:

    * input buffers and/or output buffers are added to ``V.graph.never_reuse_buffers``;
    * for ``force_free_input`` ops, the generated wrapper code sets the variable holding the first argument to
      ``None`` right after the call.

    Finally, ``Scheduler.dead_node_elimination`` is replaced with a no-op. This happens once per process,
    guarded by ``Scheduler.is_dc_patched``.
    """
    import re

    # Matches generated argument expressions such as ``reinterpret_tensor(buf0, ...)``.
    reinterpret_re = re.compile(r"reinterpret_tensor\((\w+),")

    def fallback_handler_no_reuse(kernel,
                                  never_reuse_input,
                                  never_reuse_output,
                                  force_free_input,
                                  add_to_fallback_set=True):
        """Build a lowering handler that emits ``kernel`` as a fallback call.

        Args:
            kernel: The op overload to lower.
            never_reuse_input: Add every IRNode argument's buffer to ``never_reuse_buffers``.
                Only honoured when ``force_free_input`` is true, since only ``CustomDCKernel`` applies it.
            never_reuse_output: Add every produced buffer to ``never_reuse_buffers``.
            force_free_input: Use ``CustomDCKernel``, whose codegen frees the first argument after the call.
            add_to_fallback_set: Also add ``kernel`` to Inductor's ``fallbacks`` set.
        """
        if add_to_fallback_set:
            fallbacks.add(kernel)

        class CustomDCKernel(FallbackKernel):
            """FallbackKernel that can pin its inputs and release its first input after the call."""

            def __init__(self, op, *args, **kwargs):
                super().__init__(op, *args, **kwargs)

                def add_to_never_reuse(x):
                    if isinstance(x, IRNode):
                        assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}"
                        V.graph.never_reuse_buffers.add(x.get_name())

                if never_reuse_input:
                    pytree.tree_map(add_to_never_reuse, args)

            def get_var_name_for_arg(self, arg: str):
                """Return the wrapper variable named by a codegen'd argument string, or ``None``."""
                if arg.isidentifier():
                    return arg

                match = reinterpret_re.match(arg)
                if match:
                    return match.group(1)
                return None

            def codegen(self, wrapper):
                if not force_free_input:
                    return super().codegen(wrapper)

                self.codegen_comment(wrapper)
                args = [*self.codegen_args(), *self.codegen_kwargs()]

                V.graph.wrapper_code.generate_fallback_kernel(self, args)
                if isinstance(self.layout, Layout):
                    self.codegen_size_asserts(wrapper)

                # Drop the wrapper's reference to the first argument right after the call.
                var_name = self.get_var_name_for_arg(args[0]) if args else None
                if var_name:
                    wrapper.writeline(f"{var_name} = None")

                self.codegen_unbacked_symbol_defs(wrapper)

        def wrap_tensors(x):
            out = TensorBox.create(x) if isinstance(x, IRNode) else x
            # Only values that name a buffer can be excluded from reuse; skip non-buffer leaves such as ints.
            if out is not None and never_reuse_output and hasattr(out, "get_name"):
                V.graph.never_reuse_buffers.add(out.get_name())
            return out

        def handler(*args, **kwargs):
            kernel_cls = CustomDCKernel if force_free_input else FallbackKernel
            return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs))

        return handler

    def register_fallback_no_reuse(op_overload,
                                   never_reuse_input=False,
                                   never_reuse_output=False,
                                   force_free_input=False):
        add_needs_realized_inputs(op_overload)
        return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse(
            op_overload,
            never_reuse_input=never_reuse_input,
            never_reuse_output=never_reuse_output,
            force_free_input=force_free_input))

    # Inductor reuses buffers when it can. For these ops we opt out by registering their input and/or output
    # buffers in never_reuse_buffers.
    # NOTE(review): a previous comment said memory may still be reused in some cases and that inputs are cloned
    # for some ops. No cloning happens in this function; confirm where (or whether) that is done.
    register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True)
    register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False)
    register_fallback_no_reuse(torch.ops.dc.reduce_grad.default,
                               never_reuse_input=True,
                               never_reuse_output=True,
                               force_free_input=True)
    register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)

    if not getattr(Scheduler, "is_dc_patched", False):
        Scheduler.is_dc_patched = True
        # Disable scheduler dead-node elimination. Presumably this keeps side-effecting dc ops whose outputs are
        # unused (e.g. release/free) from being dropped -- confirm.
        Scheduler.dead_node_elimination = lambda _: None
