# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Used to partition the activations stored for backward propagation,
which reduces memory consumption.
Also implements CPU checkpointing and contiguous memory checkpointing,
which reduce memory consumption and memory fragmentation.

Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
"""

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import copy
import torch
import contextlib
from deepspeed import comm as dist
import weakref

import mmap
from torch import _C

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime import compiler

# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False

# MP parameters
# Model-parallel unit object. The seed helpers below read it through
# bwc_tensor_model_parallel_rank(mpu) and mpu.get_data_parallel_rank().
mpu = None

# Default values: a single model-parallel rank with no process group.
mp_rank = 0
mp_size = 1
mp_group = None

# Model Parameters
# Number of slots allocated per contiguous checkpoint buffer
# (see partition_activations / get_partitioned_activations_for_backward).
num_layers = None

# Checkpointing buffers
# contiguous_data_buffers[i] is a list of `num_layers` flat tensors for the
# i-th checkpointed activation; data_offsets[i] is the next slot to fill.
contiguous_data_buffers = []
data_offsets = []

# contiguous_size_buffers[i] is one flat tensor holding the saved shapes of the
# i-th checkpointed activation; size_offsets[i] is the next free element offset.
contiguous_size_buffers = []
size_offsets = []

# Timer collection; created on first use when PROFILE_TIME is set.
timers = None

# optimization flags
PARTITION_ACTIVATIONS = False
CPU_CHECKPOINT = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


def detach_variable(inputs, device=None):
    """Return detached copies of the tensors in ``inputs``.

    Every tensor is optionally moved to ``device``, detached from its autograd
    graph, and given the same ``requires_grad`` flag as the original, so it can
    act as a fresh leaf for recomputation. Non-tensor entries are passed
    through unchanged.

    Arguments:
        inputs (tuple): Tensors and/or arbitrary objects.
        device: Optional target device for the tensors. ``None`` keeps every
            tensor on its current device.

    Returns:
        tuple: Same length and order as ``inputs``.

    Raises:
        RuntimeError: If ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Build a single message string; passing the type name as a second
        # positional argument would render the error as a tuple repr.
        raise RuntimeError(
            f"Only tuple of tensors is supported. Got Unsupported input type: {type(inputs).__name__}")

    out = []
    for inp in inputs:
        if not isinstance(inp, torch.Tensor):
            out.append(inp)
            continue

        # Remember the flag before detaching: detach() always yields a tensor
        # with requires_grad=False.
        requires_grad = inp.requires_grad

        x = inp.to(device=device) if device is not None else inp
        x = x.detach()
        x.requires_grad = requires_grad
        out.append(x)
    return tuple(out)


def _set_cuda_rng_state(new_state, device=-1):
    """Install ``new_state`` as the RNG state of the current accelerator device.

    Arguments:
        new_state (torch.ByteTensor): The desired state
        device: Device selector. ``-1`` (the default) means the accelerator's
            default device; a string or an int is converted to a torch.device.
    This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.
    """
    has_legacy_setter = callable(getattr(_C, '_cuda_setRNGState', None))

    if not has_legacy_setter:
        # newer PyTorch: normalise `device` to a torch.device, then set the
        # state on that device's default generator.
        if device == -1:
            device = torch.device(get_accelerator().device_name())
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(get_accelerator().device_name(), device)

        def cb():
            idx = device.index
            if idx is None:
                # No explicit index on the device: fall back to the current one.
                idx = get_accelerator().current_device()
            get_accelerator().default_generator(idx).set_state(new_state)
    else:
        # older PyTorch: use the private C setter under the requested device.
        def cb():
            with get_accelerator().device(device):
                _C._cuda_setRNGState(new_state)

    # The accelerator decides when to run the callback.
    get_accelerator().lazy_call(cb)


class CudaRNGStatesTracker:
    """Registry of named accelerator RNG states.

    `add` seeds a fresh RNG state and files it under a name without
    disturbing the state currently in use. `fork` temporarily switches to
    a named state, lets the caller draw random numbers from it, then
    stores the advanced state back and restores the previously active one.
    """

    def __init__(self):
        # name -> rng state
        self.states_ = {}
        # Seeds already handed to `add`; kept only to reject duplicates.
        self.seeds_ = set()

    def reset(self):
        """Forget every tracked state and every recorded seed."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Return a shallow copy of the name -> state mapping.

        The copy keeps references to the state objects themselves, so later
        rebinding of entries in the tracker does not affect the snapshot."""
        snapshot = copy.copy(self.states_)
        return snapshot

    def set_states(self, states):
        """Replace the tracked mapping with `states`.

        The mapping is stored as given; no validation is done, for speed."""
        self.states_ = states

    def add(self, name, seed):
        """Create a state from `seed` and track it under `name`.

        Raises an Exception when the seed or the name is already in use."""
        # Reject a seed that was used before.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        self.seeds_.add(seed)
        # Reject a name that is already tracked.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Remember the active state so seeding below leaves no trace.
        state_before_seeding = get_accelerator().get_rng_state()
        # Seed, then capture the resulting state under `name`.
        get_accelerator().manual_seed(seed)
        self.states_[name] = get_accelerator().get_rng_state()
        # Put the active state back.
        _set_cuda_rng_state(state_before_seeding)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Run the body of the `with` block under the state tracked as `name`.

        On exit (normal or via exception) the advanced state is saved back
        under `name` and the state active before the block is restored."""
        # The state must have been registered with `add`.
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Save the active state, then switch to the tracked one.
        state_before_fork = get_accelerator().get_rng_state()
        _set_cuda_rng_state(self.states_[name])
        try:
            yield
        finally:
            # Keep the advanced state so the next fork continues from here.
            self.states_[name] = get_accelerator().get_rng_state()
            # Restore the state that was active when the fork started.
            _set_cuda_rng_state(state_before_fork)


# Module-wide RNG tracker instance; handed out by get_cuda_rng_tracker() and
# reset/seeded by the model_parallel_* seed helpers.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Return the module-wide `CudaRNGStatesTracker` instance."""
    tracker = _CUDA_RNG_STATE_TRACKER
    return tracker


def model_parallel_cuda_manual_seed(seed):
    """Seed the default and the model-parallel accelerator RNG states.

    Call this after model parallelism has been initialized, and do not call
    get_accelerator().manual_seed afterwards: this function replaces it.

    Two RNG states are set up:
        default state: seeded with `seed` itself. Intended for data
                       parallelism: identical within a set of model parallel
                       GPUs, different across model parallel groups. Used
                       for example for dropout in the non-model-parallel
                       regions.
        model-parallel state: seeded with `seed + 2718 + tp_rank` and tracked
                              under the default tracker name. Differs among
                              the GPUs of a model parallel set, but is the
                              same across data parallel groups. Used for
                              example for dropout in model parallel regions.
    """
    tp_rank = bwc_tensor_model_parallel_rank(mpu)

    # Data parallel gets the original seed.
    data_parallel_seed = seed
    # 2718 is just for fun and any POSITIVE value will work.
    model_parallel_seed = seed + 2718 + tp_rank

    # Only global rank 0 reports the chosen seeds.
    if dist.get_rank() == 0:
        message = ('> initializing model parallel cuda seeds on global rank {}, '
                   'model parallel rank {}, and data parallel rank {} with '
                   'model parallel seed: {} and data parallel seed: {}'.format(dist.get_rank(), tp_rank,
                                                                               mpu.get_data_parallel_rank(),
                                                                               model_parallel_seed,
                                                                               data_parallel_seed))
        logger.info(message)

    # Start from an empty tracker, seed the default state, then register the
    # model parallel state.
    _CUDA_RNG_STATE_TRACKER.reset()
    get_accelerator().manual_seed(data_parallel_seed)
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)


def model_parallel_reconfigure_tp_seed(seed):
    """Re-seed the tracked model-parallel RNG state from `seed`.

    The seed actually applied is `seed + 2718 + tp_rank`. Seeding happens
    inside a fork of the default tracked state, so the re-seeded state is
    stored back in the tracker and the currently active RNG state is restored
    afterwards.
    """
    tp_seed = seed + 2718 + bwc_tensor_model_parallel_rank(mpu)
    with _CUDA_RNG_STATE_TRACKER.fork():
        get_accelerator().manual_seed(tp_seed)


def get_partition_start(item):
    """Return the flat element offset at which this rank's partition of `item` starts.

    The offset is `floor(item.numel() * mp_rank / mp_size)`, computed with
    integer arithmetic only so it stays exact for any element count (true
    division would round once the values exceed float precision).

    Arguments:
        item: Object exposing `numel()`, e.g. a torch.Tensor.

    Returns:
        int: Start offset into the flattened `item`.
    """
    return (item.numel() * mp_rank) // mp_size


def get_partition_size(item):
    """Return how many elements of `item` each model-parallel rank holds.

    Uses floor division after checking divisibility, so the result is an
    exact int regardless of magnitude.

    Arguments:
        item: Object exposing `numel()`, e.g. a torch.Tensor.

    Returns:
        int: `item.numel() // mp_size`.

    Raises:
        AssertionError: If the element count is not divisible by `mp_size`.
    """
    size = item.numel()
    assert size % mp_size == 0, "Doesn't handle if partition activation if item is not divisible by mp size"
    return size // mp_size


def gather_partitioned_activations(tensors, device=None):
    """Rebuild full activations from (partition, shape) pairs.

    `tensors` is a flat sequence laid out as item0, size0, item1, size1, ...
    Items that are not checkpointable activations are passed through as they
    are. Other items are restored to the shape stored in their size tensor:
    without model parallelism by a plain view (moved to `device` if given),
    otherwise by all-gathering the partitions of every rank in `mp_group`.

    Arguments:
        tensors (list/tuple): Alternating items and size tensors; the length
            must be even.
        device: Optional device for the rebuilt activations. When None, the
            device of each item is used.

    Returns:
        tuple: One entry per (item, size) pair.
    """
    global mp_rank, mp_size, mp_group
    assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'

    # don't need to do all_gather if model parallel is not enabled
    model_parallel_off = mp_group is None or mp_size == 1

    gathered = []
    for pair_idx in range(len(tensors) // 2):
        item = tensors[2 * pair_idx]
        size = tensors[2 * pair_idx + 1]

        if not is_activation_to_checkpoint(item):
            gathered.append(item)
            continue

        if model_parallel_off:
            restored = item.view(list(size.numpy()))
            if device is not None:
                restored = restored.to(device)
            gathered.append(restored)
            continue

        partition_size = item.numel()
        target_device = item.device if device is None else device
        flat_tensor = torch.zeros([partition_size * mp_size], dtype=item.dtype, device=target_device)

        # Drop the local partition into this rank's slice of the flat buffer,
        # then let all_gather fill in the slices of the other ranks.
        own_slice = flat_tensor.narrow(0, partition_size * mp_rank, partition_size)
        own_slice.copy_(item)
        dist.all_gather_into_tensor(flat_tensor, own_slice, group=mp_group)

        # Re-point the original tensor object at the gathered, reshaped data.
        item.data = flat_tensor.view(list(size.numpy())).data
        gathered.append(item)

    return tuple(gathered)


def extract_tensors(all_objects):
    """
    Split a list/tuple into its tensors and its non-tensors, plus a per-position flag
    recording which group each original element went to, so the input can be rebuilt later.
    Relative order is kept inside both groups.

    Parameters:
        all_objects (list/tuple): Objects containing tensors and non-tensors to be split.

    Returns:
        tuple: (tensors, non-tensors, flags). Each of the three is a tuple when the
        input is exactly a tuple, and a list otherwise.

    """
    tensors = [obj for obj in all_objects if torch.is_tensor(obj)]
    others = [obj for obj in all_objects if not torch.is_tensor(obj)]
    flags = list(map(torch.is_tensor, all_objects))

    if type(all_objects) is not tuple:
        return tensors, others, flags
    return tuple(tensors), tuple(others), tuple(flags)


def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
    """
    Interleave tensors and non-tensors back into one tuple, following a per-position flag list.

    When PARTITION_ACTIVATIONS is set, the flag that directly follows a tensor flag is
    dropped first: it belongs to the size entry stored next to each flattened tensor.

    Parameters:
        tensor_objects (list/tuple): Tensors to merge.
        non_tensor_objects (list/tuple): Non-tensors to merge.
        tensor_flags (list/tuple): Indicates whether each position in output is a tensor.

    Returns:
        tuple: Merge of tensors and non-tensors
    """
    if PARTITION_ACTIVATIONS:
        # remove the flags that are assigned to the size of the flattened tensors
        effective_flags = []
        remaining = iter(tensor_flags)
        for flag in remaining:
            effective_flags.append(flag)
            if flag:
                # Skip the entry right after a tensor flag, whatever it is.
                next(remaining, None)
    else:
        effective_flags = tensor_flags

    # Each source is consumed strictly in order, one element per matching flag.
    next_tensor = 0
    next_other = 0
    merged = []
    for wants_tensor in effective_flags:
        if wants_tensor:
            merged.append(tensor_objects[next_tensor])
            next_tensor += 1
        else:
            merged.append(non_tensor_objects[next_other])
            next_other += 1

    return tuple(merged)


def is_activation_to_checkpoint(item):
    """
        Tell whether `item` is an activation that should be checkpointed:
        a floating point tensor with at least `mp_size` elements that does
        not carry a `no_checkpointing` attribute equal to True.
    """
    # A missing attribute counts as "checkpointing allowed". The comparison is
    # deliberately `== False` and not a truthiness test, so only values equal
    # to False opt back in.
    checkpointing_allowed = getattr(item, 'no_checkpointing', False) == False

    if not torch.is_tensor(item):
        return False
    if not item.is_floating_point():
        return False
    if item.numel() < mp_size:
        return False
    return checkpointing_allowed


def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
    """Replace each checkpointable activation in `args` by this rank's partition of it.

    Items rejected by `is_activation_to_checkpoint` are passed through
    unchanged. For the others, the slice owned by this model-parallel rank is
    cloned from the flattened tensor and then either
      * copied into the next free slot of the module-level contiguous buffers
        (`contiguous_checkpoint`), allocated on CPU when `cpu_checkpoint` is
        set and on the partition's device otherwise, or
      * kept as is, moved to CPU when `cpu_checkpoint` is set.

    Arguments:
        args (iterable): Inputs of the checkpointed function.
        cpu_checkpoint (bool): Store the partitions on CPU.
        contiguous_checkpoint (bool): Store the partitions in the
            pre-allocated contiguous buffers.

    Returns:
        list: One entry per element of `args`, in the same order.
    """
    global contiguous_data_buffers, data_offsets

    inputs = []
    num_non_fp_tensors = 0

    for arg_index, item in enumerate(args):
        if not is_activation_to_checkpoint(item):
            inputs.append(item)
            num_non_fp_tensors += 1
            continue

        # Buffer index: position of this item among the checkpointed items only.
        i = arg_index - num_non_fp_tensors
        partition_size = get_partition_size(item)
        partition = item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), partition_size).clone()

        if not contiguous_checkpoint:
            # Honour the `cpu_checkpoint` argument here, consistently with the
            # contiguous path, rather than the module-level flag.
            inputs.append(partition.cpu() if cpu_checkpoint else partition)
            continue

        buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device

        # Allocate the per-layer slots the first time this index is seen, or
        # again when its entry has been set to None.
        is_new_index = i >= len(contiguous_data_buffers)
        if is_new_index or contiguous_data_buffers[i] is None:
            tensor_list = [
                torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
                for _ in range(num_layers)
            ]
            if is_new_index:
                contiguous_data_buffers.append(tensor_list)
                data_offsets.append(0)
            else:
                contiguous_data_buffers[i] = tensor_list
                data_offsets[i] = 0

        slot = contiguous_data_buffers[i][data_offsets[i]].data

        # Because the 'new_empty' returns uninitialized pages,
        # the pages need to be populated during the cudaMemcpy time
        # which increases the data copy time. To avoid this, we
        # pre-populate these pages by simply writing 0 ahead of
        # the actual cudaMemcpy operation time. Due to the
        # previously launched GPU kernels, there is a small
        # window of time here for CPUs to populate pages asynchronously.
        elements_per_page = int(mmap.PAGESIZE / slot.element_size())
        slot[range(0, slot.shape[0], elements_per_page)] = 0

        contiguous_partition = slot.copy_(partition.data)
        data_offsets[i] = data_offsets[i] + 1
        inputs.append(contiguous_partition)

    return inputs


def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint):
    """Build the flat (arg, size) list that is stashed for the backward pass.

    For every element of `args` two entries are produced: the element itself
    and a tensor holding its shape (None for non-tensors). For checkpointable
    activations the element's data is swapped for an empty placeholder and the
    matching partition from `inputs` is attached as `saved_data`. With
    `contiguous_checkpoint`, the shape is copied into the module-level
    contiguous size buffers and a view of that copy is stored instead.

    Arguments:
        args (iterable): Original inputs of the checkpointed function.
        inputs (iterable): Partitioned counterparts of `args`, same order.
        contiguous_checkpoint (bool): Keep shapes in the contiguous size buffers.

    Returns:
        list: arg0, size0, arg1, size1, ...
    """
    global contiguous_size_buffers, size_offsets

    new_args = []
    passthrough_count = 0

    for position, (arg, inp) in enumerate(zip(args, inputs)):
        # Capture the shape first: the data is replaced further down.
        size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None

        if not is_activation_to_checkpoint(arg):
            new_args.append(arg)
            new_args.append(size)
            passthrough_count += 1
            continue

        # Release the full activation and keep only the partition around.
        arg.data = torch.empty([], device=arg.device).data
        arg.saved_data = inp.data
        new_args.append(arg)

        if not contiguous_checkpoint:
            new_args.append(size)
            continue

        # Buffer index: position among the checkpointed activations only.
        i = position - passthrough_count
        numel = size.numel()

        index_is_new = i >= len(contiguous_size_buffers)
        if index_is_new or contiguous_size_buffers[i] is None:
            size_buffer = torch.tensor(()).new_empty([numel * num_layers], dtype=size.dtype, device=size.device)
            if index_is_new:
                contiguous_size_buffers.append(size_buffer)
                size_offsets.append(0)
            else:
                contiguous_size_buffers[i] = size_buffer
                size_offsets[i] = 0

        # Copy the shape into the next free window of the buffer.
        window = contiguous_size_buffers[i].narrow(0, size_offsets[i], numel)
        contiguous_size = window.data.copy_(size.data).view_as(size)
        size_offsets[i] = size_offsets[i] + numel
        new_args.append(contiguous_size)

    return new_args


def get_cpu_activations_for_backward(args, inputs):
    """Prepare `args` for stashing when activations are checkpointed without partitioning.

    Every checkpointable activation has its data swapped for an empty
    placeholder and gets the data of its counterpart in `inputs` attached as
    `saved_data`. All other elements are passed through untouched.

    Arguments:
        args (iterable): Original inputs of the checkpointed function.
        inputs (iterable): Copies of `args` to keep for backward, same order.

    Returns:
        list: The elements of `args` (paired up to the shorter of the two
        iterables), in order.
    """
    new_args = []
    for arg, inp in zip(args, inputs):
        if is_activation_to_checkpoint(arg):
            arg.data = torch.empty([], device=arg.device).data
            arg.saved_data = inp.data
        new_args.append(arg)

    return new_args


class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       the following changes:
           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`  #ignore-cuda
           2) the states in the model parallel tracker are also properly
              tracked/set/reset.
           3) Activation partitioning, contiguous memory optimization
           4) CPU Checkpointing
           5) Profile forward and backward functions
    """

    @staticmethod
    def forward(ctx, run_function, all_outputs, *args):
        """Run `run_function(*args)` without building a graph and stash what backward needs.

        Arguments:
            ctx: autograd context; receives the stashed inputs and RNG states.
            run_function: Callable to checkpoint.
            all_outputs (list): Extended in place with every output of
                `run_function`, including non-tensor outputs.
            *args: Inputs for `run_function`.

        Returns:
            The output tensor itself when `run_function` returns a single
            tensor, otherwise a tuple with only the tensor outputs.
        """
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            # Tensors and non-tensors are stored separately on ctx, together
            # with the flags needed to interleave them again in backward().
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
            ctx.deepspeed_saved_tensors = tensor_args
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            get_accelerator().synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers(FORWARD_GLOBAL_TIMER).start()

        ctx.run_function = run_function
        # NOTE(review): buffer_0/buffer_1 and their offsets are not defined in the
        # visible part of this module — confirm whether these names are still needed.
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        global PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        cuda_device = get_accelerator().current_device_name()
        # NOTE(review): transport_stream is created but not used in this method.
        transport_stream = get_accelerator().Stream(device=cuda_device)

        # `inputs` holds what is kept for backward (partitions or CPU copies).
        # It is only assigned, and only read, when one of the two flags is set.
        if PARTITION_ACTIVATIONS:
            inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
        elif CPU_CHECKPOINT:
            inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

        # just in case something funky is happening such as reuse of inputs
        inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        if PARTITION_ACTIVATIONS:
            new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
            assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
            save_args_for_backward(*new_args)
        elif CPU_CHECKPOINT:
            new_args = get_cpu_activations_for_backward(args, inputs)
            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers(FORWARD_GLOBAL_TIMER).stop()
            timers.log([FORWARD_GLOBAL_TIMER])
        if SYNCHRONIZE:
            get_accelerator().synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()]
        ctx.mark_non_differentiable(*non_grad_outputs)

        # Every output (tensor or not) is reported through `all_outputs`;
        # only tensors are returned from the autograd function itself.
        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)

    @staticmethod
    def backward(ctx, *grads):
        """Recompute the forward pass under the saved RNG states and backpropagate `grads`.

        Arguments:
            ctx: autograd context filled by forward().
            *grads: Gradients for the tensor outputs returned by forward().

        Returns:
            tuple: None for `run_function` and `all_outputs`, followed by one
            entry per original input: its `.grad` for tensors, None otherwise.

        Raises:
            RuntimeError: If torch.autograd._is_checkpoint_valid() is False.
        """
        global timers
        see_memory_usage("In backward", force=False)
        # removing pointers to the contiguous buffer memory
        # so that they can be garbage collected once the checkpoints
        # have been used
        if SYNCHRONIZE:
            get_accelerator().synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            # NOTE(review): this loop only rebinds the loop variable; the
            # references are actually dropped by the reassignments below.
            for buffers in contiguous_data_buffers:
                buffers = []

            # frees up all the pointers to the checkpoints except for the ones
            # stored by save for backward
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        see_memory_usage("In backward checkpointing code", force=False)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global PARTITION_ACTIVATIONS
        cuda_device = get_accelerator().current_device_name()
        # NOTE(review): transport_stream is only referenced by commented-out code below.
        transport_stream = get_accelerator().Stream(device=cuda_device)
        # Rebuild deepspeed_saved_tensors
        # forward() replaced the data of checkpointed inputs with a placeholder and
        # attached the kept copy as `saved_data`; put that copy back in place.
        for t in ctx.deepspeed_saved_tensors:
            if t is not None and hasattr(t, 'saved_data') and t.saved_data is not None:
                t.data = t.saved_data.to(t.device)
                t.saved_data = None

        if PARTITION_ACTIVATIONS:
            # with get_accelerator().stream(transport_stream):
            inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors,
                                                    device=cuda_device if CPU_CHECKPOINT else None)
            detached_inputs = detach_variable(inputs)
        elif CPU_CHECKPOINT:
            inputs = move_to_device(ctx.deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
            detached_inputs = detach_variable(inputs)
        else:
            inputs = ctx.deepspeed_saved_tensors
            detached_inputs = detach_variable(inputs)

        # Add non tensor input args
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                        non_tensor_objects=ctx.non_tensor_args,
                                        tensor_flags=ctx.tensor_flags)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = get_accelerator().get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # if PARTITION_ACTIVATIONS:
        #     current_stream=get_accelerator().current_stream()
        #     current_stream.wait_stream(transport_stream)

        see_memory_usage("In backward checkpointing code before forward", force=False)

        # Recompute the forward pass, this time recording the autograd graph.
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        see_memory_usage("In backward checkpointing code after forward", force=False)
        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Filter out non tensor outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)

        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
        output_tensors = []
        grad_tensors = []
        for out, grad in zip(outputs, grads):
            if out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)

        see_memory_usage("In backward checkpointing code before backward", force=False)

        torch.autograd.backward(output_tensors, grad_tensors)

        # Force clear our stashed tensors to prevent a memory leak in certain scenarios
        ctx.deepspeed_saved_tensors = None
        ctx.non_tensor_args = None
        ctx.tensor_flags = None

        see_memory_usage("After backward checkpointing code after backward", force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            get_accelerator().synchronize()
        ret_list = [None, None]  # no gradients for forward()'s `run_function` and `all_outputs` arguments
        for inp in detached_inputs:
            if torch.is_tensor(inp):
                ret_list.append(inp.grad)
            else:
                ret_list.append(None)

        return tuple(ret_list)


def non_reentrant_checkpoint(function, *args):
    """This function is union of `torch.utils.checkpoint._checkpoint_without_reentrant` and `CheckpointFunction` in this module

    This function is aim to solve the back probagation error raised from all input requires no grad.
    * has already been implemented in pytorch for a while, the solution is stable at most time except for jit module mode.
    * can help to solve the issue which is hacked by `deepspeed.runtime.pipe.module.PipelineModule._is_checkpointable`

    Main modifications compared to the implementation of torch:
    1. adapt to the signature of `checkpoint` function in this module
    2. solve the non-deterministic by random state management consistent with deepspeed `CheckpointFunction`
    3. when there is partition or cpu checkpointing, gather them in the unpack_hook during back probagation
    4. make all after backward blocks in the hook which will executed after all leaf nodes backward execution.
    5. above 4. is inspired by `torch.autograd.graph.register_multi_grad_hook`, which is only implemented after 2.0.0
    """
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    deepspeed_saved_tensors = None
    non_tensor_args = None
    tensor_flags = None

    def save_args_for_backward(*all_args):
        """keep this function to reduce the modification from original implementation"""
        nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        deepspeed_saved_tensors = tensor_args
        non_tensor_args = non_tensor_args
        tensor_flags = tensor_flags

    if SYNCHRONIZE:
        get_accelerator().synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers(FORWARD_GLOBAL_TIMER).start()

    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    global PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    cuda_device = get_accelerator().current_device_name()
    transport_stream = get_accelerator().Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
    elif CPU_CHECKPOINT:
        inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)

    # just in case something funky is happening such as reuse of inputs
    inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)

    # Copy the rng states.
    fwd_cpu_rng_state = torch.get_rng_state()
    fwd_cuda_rng_state = get_accelerator().get_rng_state()
    fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    if PARTITION_ACTIVATIONS:
        new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
        assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
        save_args_for_backward(*new_args)
    elif CPU_CHECKPOINT:
        new_args = get_cpu_activations_for_backward(args, inputs)
        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    class Holder():
        """the place holder object used as activations to save memory"""
        pass

    # weakref seems utilized to discover the tensor deletion before a whole
    # forward backward pair loop finished
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []
    leaf_tensors = []
    backward_visited_leaf_nodes = 0

    def checkpoint_pack(tensor_from_forward):
        """used to record the activation order in the `weak_holder_list`

        the activation order in holder list is consistent between the first forward and recomputing forward.
        * the jit compiled forward will break the order consistency *
        """
        res = Holder()
        weak_holder_list.append(weakref.ref(res))

        # if this is a leaf tensor, save it for backward progression trace
        # leaf tensor used to be input or parameters, which is not activations and
        # has no memory overhead
        if tensor_from_forward.requires_grad and tensor_from_forward.is_leaf:
            leaf_tensors.append(tensor_from_forward)
        return res

    def checkpoint_unpack(holder_from_backward):
        """Unpack hook: return the recomputed activation that belongs to a holder.

        When ``storage`` is empty, the checkpointed ``function`` is re-run with
        gradients enabled under a second pair of saved-tensor hooks
        (``replay_pack`` / ``replay_unpack``). ``replay_pack`` fills ``storage``
        with detached activations keyed by the holders that ``checkpoint_pack``
        recorded. The requested entry is then looked up and returned.

        Arguments:
            holder_from_backward: the ``Holder`` that ``checkpoint_pack``
                returned for the tensor autograd now asks for.

        Return:
            The detached tensor stored for ``holder_from_backward``.

        Raises:
            RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` is False,
                or if ``holder_from_backward`` has no entry in ``storage`` after
                the (possible) recomputation.
        """
        nonlocal deepspeed_saved_tensors, non_tensor_args, tensor_flags

        # if storage holds nothing yet (e.g. the first step of backward propagation),
        # recompute the graph and save all the activations with the same order as
        # `checkpoint_pack` does
        if len(storage) == 0:
            # number of tensors saved so far during the replay; used as a 1-based
            # index into `weak_holder_list`
            unpack_counter = 0

            def replay_pack(tensor_from_replay):
                """Pack hook for the replay: store the k-th saved tensor under the k-th holder.

                Returns None in every case, so autograd is handed no tensor to keep
                for the replayed graph.
                """
                nonlocal unpack_counter
                unpack_counter += 1

                # the weak reference yields None once its holder has been collected;
                # in that case there is nothing to store the activation under
                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                detached_activations = tensor_from_replay.detach()
                storage[weak_holder_list[unpack_counter - 1]()] = detached_activations

                return

            def replay_unpack(none_value):
                """Unpack hook for the replay: always raises, the replayed graph must not be backpropagated."""
                raise RuntimeError("You are calling backwards on a tensor that is never exposed.")

            global timers
            see_memory_usage("In backward", force=False)
            if SYNCHRONIZE:
                get_accelerator().synchronize()
            if PROFILE_TIME:
                timers('backward').start()

            if CONTIGUOUS_CHECKPOINTING:
                global data_offsets, size_offsets
                global contiguous_data_buffers, contiguous_size_buffers

                # NOTE(review): this loop only rebinds the loop variable `buffers`;
                # it does not modify `contiguous_data_buffers` or its elements.
                for buffers in contiguous_data_buffers:
                    buffers = []

                # removing pointers to the contiguous buffer memory
                # so that they can be garbage collected once the checkpoints
                # have been used:
                # frees up all the pointers to the checkpoints except for the ones
                # stored by save for backward
                contiguous_data_buffers = []
                contiguous_size_buffers = []
                data_offsets = []
                size_offsets = []

            see_memory_usage("In backward checkpointing code", force=False)
            if not torch.autograd._is_checkpoint_valid():
                raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                   "please use .backward() if possible")

            global PARTITION_ACTIVATIONS
            cuda_device = get_accelerator().current_device_name()
            # NOTE(review): `transport_stream` is created but not used by any active
            # line in this function (its only use is in the commented-out line below).
            transport_stream = get_accelerator().Stream(device=cuda_device)

            # gather inputs which is partitioned or checkpointed before first forward
            if PARTITION_ACTIVATIONS:
                # with get_accelerator().stream(transport_stream):
                inputs = gather_partitioned_activations(deepspeed_saved_tensors,
                                                        device=cuda_device if CPU_CHECKPOINT else None)
                detached_inputs = detach_variable(inputs)
            elif CPU_CHECKPOINT:
                inputs = move_to_device(deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
                detached_inputs = detach_variable(inputs)
            else:
                inputs = deepspeed_saved_tensors
                detached_inputs = detach_variable(inputs)

            # Add non tensor input args
            detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                            non_tensor_objects=non_tensor_args,
                                            tensor_flags=tensor_flags)

            # Store the current states.
            bwd_cpu_rng_state = torch.get_rng_state()
            bwd_cuda_rng_state = get_accelerator().get_rng_state()
            bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

            # Set the states to what it used to be before the forward pass.
            # (the fwd_* values come from the enclosing function, outside this closure)
            torch.set_rng_state(fwd_cpu_rng_state)
            _set_cuda_rng_state(fwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(fwd_cuda_rng_state_tracker)

            see_memory_usage("In backward checkpointing code before forward", force=False)
            # re-run the forward only for its side effect of calling `replay_pack`;
            # the outputs themselves are discarded
            with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(replay_pack, replay_unpack):
                _unused = function(*detached_inputs)

            see_memory_usage("In backward checkpointing code after forward", force=False)
            # Set the states back to what it was at the start of this function.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

            # drop this closure's references to the saved inputs now that the
            # recomputation has consumed them
            deepspeed_saved_tensors = None
            non_tensor_args = None
            tensor_flags = None

        if holder_from_backward not in storage:
            raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                               " recomputation being triggered in between, this is not currently supported.")

        return storage[holder_from_backward]

    def after_backward_hook(_nonuse_grads):
        """the hook registered to all leaf tensors"""
        nonlocal leaf_tensors, backward_visited_leaf_nodes
        backward_visited_leaf_nodes += 1

        if backward_visited_leaf_nodes == len(leaf_tensors):
            see_memory_usage("After backward checkpointing code after backward", force=False)

            if PROFILE_TIME:
                timers('backward').stop()
                timers.log(['backward'])
            if SYNCHRONIZE:
                get_accelerator().synchronize()

    with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack):
        outputs = function(*inputs_cuda)
    if PROFILE_TIME or SYNCHRONIZE:
        for leaf_tensor in leaf_tensors:
            leaf_tensor.register_hook(after_backward_hook)

    see_memory_usage("After running forward on the layer", force=False)

    if PROFILE_TIME:
        timers(FORWARD_GLOBAL_TIMER).stop()
        timers.log([FORWARD_GLOBAL_TIMER])
    if SYNCHRONIZE:
        get_accelerator().synchronize()

    all_outputs = []
    if torch.is_tensor(outputs):
        all_outputs += [outputs]
    else:
        all_outputs += outputs

    if len(all_outputs) == 1:
        return all_outputs[0]
    else:
        return tuple(all_outputs)


@compiler.disable  # WA from Pytorch repo for compile + zero 3 accuracy issue
def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint.

    Arguments:
        function: the callable to run under activation checkpointing
        *args: positional arguments forwarded to ``function``

    Return:
        The single collected output if there is exactly one, otherwise a
        tuple of all collected outputs.
    """
    # The return value of `apply` is discarded; the results are read back from
    # the list that is passed in as the second argument.
    collected = []
    CheckpointFunction.apply(function, collected, *args)
    return collected[0] if len(collected) == 1 else tuple(collected)


def partition_activations_in_checkpoint(partition_activation):
    """Set the module-level PARTITION_ACTIVATIONS flag and log the new value on rank 0.

    Arguments:
        partition_activation: value stored in PARTITION_ACTIVATIONS

    Return:
        None
    """
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation

    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        logger.info(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")


def set_num_layers(nlayers):
    """Store ``nlayers`` in the module-level ``num_layers`` setting.

    Arguments:
        nlayers: value to record as the number of layers

    Return:
        None
    """
    global num_layers

    num_layers = nlayers


def reset():
    """Resets memory buffers related to contiguous memory optimizations.
    Should be called during eval when multiple forward propagations are
    computed without any backward propagation that usually clears these
    buffers.

    Does nothing unless CONTIGUOUS_CHECKPOINTING is enabled.

    Arguments:
        None

    Return:
        None
    """
    global data_offsets, size_offsets
    global contiguous_data_buffers, contiguous_size_buffers

    if not CONTIGUOUS_CHECKPOINTING:
        return

    # Rebinding the module-level names to fresh lists drops this module's
    # references to the checkpoint buffers and offsets, leaving only the
    # references held by whatever was saved for backward.
    # (A previous per-element loop `for buffers in ...: buffers = []` was
    # removed: it only rebound the loop variable and never cleared anything.)
    contiguous_data_buffers = []
    contiguous_size_buffers = []
    data_offsets = []
    size_offsets = []


def _configure_using_config_file(config, mpu=None):
    """Load the activation-checkpointing settings from a DeepSpeed config.

    Builds a ``DeepSpeedConfig`` from ``config`` (forwarding ``mpu``), logs its
    activation-checkpointing section on rank 0, and copies that section's
    fields into the module-level option flags.

    Arguments:
        config: passed as the first argument to ``DeepSpeedConfig``
        mpu: Optional: forwarded to ``DeepSpeedConfig`` as ``mpu``

    Return:
        None
    """
    global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
        CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    ckpt_config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config

    if dist.get_rank() == 0:
        logger.info(ckpt_config.repr())

    # Mirror each config field into its module-level flag.
    PARTITION_ACTIVATIONS = ckpt_config.partition_activations
    CONTIGUOUS_CHECKPOINTING = ckpt_config.contiguous_memory_optimization
    num_layers = ckpt_config.number_checkpoints
    CPU_CHECKPOINT = ckpt_config.cpu_checkpointing
    SYNCHRONIZE = ckpt_config.synchronize_checkpoint_boundary
    PROFILE_TIME = ckpt_config.profile


def _configure_defaults():

    global mpu, num_layers, deepspeed_checkpointing_enabled

    global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
        CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    PARTITION_ACTIVATIONS = False
    CONTIGUOUS_CHECKPOINTING = False
    num_layers = False
    CPU_CHECKPOINT = False
    SYNCHRONIZE = False
    PROFILE_TIME = False
    deepspeed_checkpointing_enabled = True


def configure(
    mpu_,
    deepspeed_config=None,
    partition_activations=None,
    contiguous_checkpointing=None,
    num_checkpoints=None,
    checkpoint_in_cpu=None,
    synchronize=None,
    profile=None,
):
    """Configure DeepSpeed Activation Checkpointing.

    Settings are resolved in three layers: the built-in defaults first, then
    the values from ``deepspeed_config`` when one is given, then every keyword
    argument that is not None.

    Arguments:
        mpu_: Optional: An object that implements the following methods
            get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size

        deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to
            configure DeepSpeed Activation Checkpointing

        partition_activations: Optional: Partitions activation checkpoint across model parallel
            GPUs when enabled. By default False. Will overwrite deepspeed_config if provided

        contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory
            buffer. Works only with homogeneous checkpoints when partition_activations is enabled.
            Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if
            provided

        num_checkpoints: Optional: Number of activation checkpoints stored during the forward
            propagation of the model. Used to calculate the buffer size for contiguous_checkpointing
            Will overwrite deepspeed_config if provided

        checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with
            partition_activation. Default is false. Will overwrite deepspeed_config if provided

        synchronize: Optional: Performs get_accelerator().synchronize() at the beginning and end of
            each call to deepspeed.checkpointing.checkpoint for both forward and backward pass.
            By default false. Will overwrite deepspeed_config if provided

        profile: Optional: Logs the forward and backward time for each
            deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config
            if provided

    Returns:
        None
    """
    global mpu, num_layers, deepspeed_checkpointing_enabled
    global mp_rank, mp_size, mp_group
    global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
        CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME

    # Layer 1: built-in defaults.
    _configure_defaults()

    if mpu_ is not None:
        mpu = mpu_

    # Layer 2: values from the config, if one was supplied.
    if deepspeed_config is not None:
        _configure_using_config_file(deepspeed_config, mpu=mpu)

    # Layer 3: explicit keyword arguments win over everything above.
    if partition_activations is not None:
        PARTITION_ACTIVATIONS = partition_activations
    if contiguous_checkpointing is not None:
        CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing
    if num_checkpoints is not None:
        num_layers = num_checkpoints
    if checkpoint_in_cpu is not None:
        CPU_CHECKPOINT = checkpoint_in_cpu
    if synchronize is not None:
        SYNCHRONIZE = synchronize
    if profile is not None:
        PROFILE_TIME = profile

    # Contiguous checkpointing has two prerequisites.
    if CONTIGUOUS_CHECKPOINTING:
        assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config"
        assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"

    # Cache the model-parallel topology, preferring the tensor-model-parallel
    # accessors when the mpu object provides them.
    if mpu is not None:
        has_tensor_parallel_api = hasattr(mpu, 'get_tensor_model_parallel_rank')
        if has_tensor_parallel_api:
            mp_rank = mpu.get_tensor_model_parallel_rank()
            mp_size = mpu.get_tensor_model_parallel_world_size()
            mp_group = mpu.get_tensor_model_parallel_group()
        else:
            mp_rank = mpu.get_model_parallel_rank()
            mp_size = mpu.get_model_parallel_world_size()
            mp_group = mpu.get_model_parallel_group()

    see_memory_usage("After configuration", force=False)

    # Only rank 0 prints the resolved configuration.
    if dist.get_rank() == 0:
        summary_lines = (
            "Activation Checkpointing Information",
            f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}",
            f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers",
            f"----Synchronization {SYNCHRONIZE}",
            f"----Profiling time in checkpointing {PROFILE_TIME}",
        )
        for summary_line in summary_lines:
            logger.info(summary_line)


def is_configured():
    """Report whether DeepSpeed activation checkpointing has been configured.

    Reads the module-level ``deepspeed_checkpointing_enabled`` flag.

    Arguments:
        None

    Return:
        True if configured, else False
    """
    return deepspeed_checkpointing_enabled
