# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Callable, Any, List
from collections import defaultdict

import torch
from torch.fx import Node, Graph

from .util import get_last_uses


def get_output_node(graph: Graph) -> Node:
    """Return the first node of ``graph`` whose opcode is ``"output"``.

    Args:
        graph: The FX graph to search.

    Returns:
        The output node.

    Raises:
        ValueError: If the graph contains no output node.
    """
    for candidate in graph.nodes:
        # Match on ``op`` rather than ``target``: a placeholder's target is the
        # name of the argument it stands for, so an input called "output"
        # would otherwise be mistaken for the graph's output node.
        if candidate.op == "output":
            return candidate
    raise ValueError("No output node found")


def move_primals_to_head(graph: Graph):
    """Build a copy of ``graph`` in which every placeholder node comes first.

    Placeholders keep their relative order, and so do all remaining nodes.
    The input graph is only read; the reordered copy is linted and returned.
    """
    # sorted() is stable, so a False/True key (placeholder / anything else)
    # hoists the placeholders without disturbing the order inside each group.
    ordered = sorted(graph.nodes, key=lambda n: n.op != "placeholder")

    reordered = Graph()
    copied = {}  # original node name -> its copy in ``reordered``

    def _lookup(arg):
        # Redirect an argument of the original node to the already-copied node.
        return copied[arg.name]

    for original in ordered:
        copied[original.name] = reordered.node_copy(original, _lookup)
    reordered.lint()

    return reordered


def add_args_process(graph: Graph,
                     node: Node,
                     fn: Callable[..., Any],
                     extra_args: List[int] = [],
                     name=None,
                     meta={}) -> List[Node]:
    """Route every Node-valued positional argument of ``node`` through ``fn``.

    For each distinct positional argument of ``node`` that is itself a Node, a
    ``call_function`` node ``fn(arg, *extra_args)`` is inserted immediately
    before ``node`` and ``node`` is rewired to consume it instead of ``arg``.
    Keyword arguments and Nodes nested inside containers are left untouched.

    Args:
        graph: Graph that owns ``node``; the new nodes are inserted into it.
        node: Node whose positional inputs are wrapped.
        fn: Call target of the inserted nodes.
        extra_args: Values appended after the wrapped argument in each call.
        name: Name hint passed to ``graph.create_node`` for each new node.
        meta: Entries copied into the ``meta`` dict of every new node.

    Returns:
        The inserted nodes, ordered by first appearance of the argument they wrap.

    Note:
        The ``extra_args`` and ``meta`` defaults are shared objects; they are
        only read here, never mutated.
    """
    # ``replace_input_with`` rewires *every* occurrence of ``arg`` in one call,
    # so a Node that fills several argument slots must be handled only once.
    # Otherwise the repeat iterations would emit extra ``fn`` calls that nothing
    # consumes. dict.fromkeys de-duplicates while keeping first-seen order.
    target_args = list(dict.fromkeys(arg for arg in node.args if isinstance(arg, Node)))
    trailing_args = tuple(extra_args)

    new_nodes = []
    with graph.inserting_before(node):
        for arg in target_args:
            new_node = graph.create_node('call_function', fn, (arg, ) + trailing_args, name=name)
            new_node.meta.update(meta)
            node.replace_input_with(arg, new_node)
            new_nodes.append(new_node)

    return new_nodes


def add_postprocess(graph: Graph,
                    node: Node,
                    fn: Callable[..., Any],
                    extra_args: List[int] = [],
                    name=None,
                    meta={}) -> Node:
    """Wrap the result of ``node`` in a call to ``fn`` and rewire its consumers.

    A ``call_function`` node ``fn(node, *extra_args)`` is inserted directly
    after ``node``; every node that previously consumed ``node`` is switched to
    consume the new node instead. The entries of ``meta`` are copied into the
    new node's ``meta`` dict.

    Pattern adapted from
    https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py

    Returns:
        The newly inserted node.
    """
    # Snapshot the consumers before the wrapper exists, so the wrapper itself
    # (which also becomes a user of ``node``) is never rewired onto itself.
    consumers = list(node.users)

    with graph.inserting_after(node):
        # extra_args carries trailing values for the call (e.g. a ds_id).
        call_args = (node, *extra_args)
        new_node = graph.create_node('call_function', fn, call_args, {}, name=name)
        for consumer in consumers:
            consumer.replace_input_with(node, new_node)

    new_node.meta.update(meta)

    return new_node


def _make_node_meta(node: Node, ds_id: int, comm: bool):
    """Assemble the meta dict attached to nodes derived from ``node``.

    The result always carries ``param_name`` (the name of ``node``), ``ds_id``
    and ``comm``; ``tensor_meta`` is copied over only when ``node.meta`` has it.
    """
    tensor_meta = {"tensor_meta": node.meta["tensor_meta"]} if "tensor_meta" in node.meta else {}
    return {"param_name": node.name, "ds_id": ds_id, "comm": comm, **tensor_meta}


def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]):
    """Insert ``dc.free_tensors`` calls after the last use of activation inputs.

    Activations are the placeholder nodes whose names appear in
    ``activation_node_names``. When an activation is passed as the first
    argument of a ``dc.reload_tensor`` node and a ``dc.wait_reload`` node with
    the same offload id exists, the ``wait_reload`` node is tracked in place of
    the placeholder. For every node that is the last user of at least one
    tracked activation carrying ``tensor_meta``, a ``dc.free_tensors`` call on
    those activations is inserted right after it.

    Args:
        graph_id: Not used by this function.
        graph: Graph to modify in place.
        activation_node_names: Names of the placeholder nodes to treat as activations.
    """
    # presumably maps each node to the node that uses it last; confirm against get_last_uses
    node_to_last_use, _ = get_last_uses(graph)
    activation_nodes_set = {n for n in graph.nodes if n.op == "placeholder" and n.name in activation_node_names}

    # Pair every offloaded activation with the wait_reload node sharing its
    # offload id (args[2] of both ops). This relies on each reload_tensor node
    # preceding its wait_reload node in graph order.
    offload_id_to_node = {}
    node_to_wait_reload = {}
    for candidate in graph.nodes:
        if candidate.target == torch.ops.dc.reload_tensor.default:
            offload_id_to_node[candidate.args[2]] = candidate.args[0]
        elif candidate.target == torch.ops.dc.wait_reload.default:
            offloaded = offload_id_to_node[candidate.args[2]]
            node_to_wait_reload[offloaded] = candidate

    # Track the wait_reload node instead of the placeholder where one exists.
    activation_nodes_set = {node_to_wait_reload.get(n, n) for n in activation_nodes_set}

    # Invert the mapping: last user -> all nodes whose final use it is.
    last_user_to_uses = defaultdict(list)
    for used, final_user in node_to_last_use.items():
        last_user_to_uses[final_user].append(used)

    def _should_free(node: Node) -> bool:
        # Only nodes that carry tensor metadata are handed to free_tensors.
        return hasattr(node, "meta") and "tensor_meta" in node.meta

    def free_tensors(tensors: List[torch.Tensor]):
        # Pure-Python stand-in for the custom op, kept for debugging: swaps the
        # storage of tensors above 10M elements for an empty tensor.
        for t in tensors:
            if t.numel() > 10_000_000:
                t.data = torch.empty([0], device=t.device, dtype=t.dtype)

    for final_user, used_nodes in last_user_to_uses.items():
        activation_args = [u for u in used_nodes if u in activation_nodes_set and _should_free(u)]
        if not activation_args:
            continue

        # The name hint lists every node whose last use is here, not only the freed ones.
        node_name = f"free_activations_{[u.name for u in used_nodes]}"
        with graph.inserting_after(final_user):
            graph.create_node('call_function',
                              torch.ops.dc.free_tensors.default, (activation_args, ), {},
                              name=node_name)

            # Python version for debugging
            # graph.create_node('call_function', free_tensors, (activation_args, ), {}, name=node_name)
