# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from collections import OrderedDict
import torch
import sys
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed import comm as dist
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from packaging import version as pkg_version
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
                                     align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter,
                                     see_memory_usage, graph_process, get_norm_with_moe_layers)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.utils.torch import register_grad_hook
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                            SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
                                            PARAM_SLICE_MAPPINGS)

# Bind `fragment_address` as an attribute of this module object.
# NOTE(review): presumably kept so that previously serialized objects that reference
# `fragment_address` through this module's path can still be resolved on load — confirm.
setattr(sys.modules[__name__], 'fragment_address', fragment_address)


def print_rank_0(message, debug=False, force=False):
    """Print ``message`` on rank 0 only, and only if ``debug`` or ``force`` is set.

    With both flags left at their default of ``False`` nothing is printed on any rank.
    """
    # The rank is always queried first, independent of the flags.
    on_rank_zero = dist.get_rank() == 0
    if not on_rank_zero:
        return
    if debug or force:
        print(message)


class BF16_Optimizer(ZeROOptimizer):

    def __init__(self,
                 init_optimizer,
                 param_names,
                 mpu=None,
                 clip_grad=0.0,
                 norm_type=2,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 timers=None,
                 grad_acc_dtype=None,
                 graph_harvesting=False,
                 immediate_grad_update=True,
                 has_moe_layers=False):
        """Wrap ``init_optimizer`` so that it updates flat high-precision partitions of bf16 params.

        Args:
            init_optimizer: optimizer to wrap. Unless it is a ``DummyOptim``, its ``param_groups``
                are rewritten to reference the flat fp32 partitions owned by this rank.
            param_names: mapping indexed by low-precision parameter that yields its name; used
                when building the parameter slice mappings.
            mpu: model-parallel helper object, forwarded to the norm and clipping utilities.
            clip_grad: maximum global gradient norm; clipping happens in ``step`` only when > 0.
            norm_type: norm type forwarded to the gradient-norm utilities.
            allgather_bucket_size: converted to ``int`` and forwarded to ``all_gather_dp_groups``.
            dp_process_group: data-parallel process group used for every param group (MoE param
                groups get their own expert data-parallel group).
            timers: stored on the instance.
            grad_acc_dtype: dtype of the gradient accumulation buffers. Must be ``torch.float32``
                or ``torch.bfloat16``; the default ``None`` is rejected by the assertion below.
            graph_harvesting: forwarded as ``use_graph`` to the norm/clipping utilities and used
                to route gradient accumulation through ``graph_process``.
            immediate_grad_update: when True, per-parameter hooks accumulate gradients into the
                high-precision buffers and ``update_hp_grads`` becomes a no-op.
            has_moe_layers: when True, MoE param groups are validated and tracked separately.
        """
        super().__init__()
        see_memory_usage('begin bf16_optimizer', force=True)
        self.timers = timers
        self.optimizer = init_optimizer
        self.param_names = param_names
        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

        assert grad_acc_dtype in [torch.float32, torch.bfloat16
                                  ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
        self.grad_acc_dtype = grad_acc_dtype
        self.immediate_grad_update = immediate_grad_update

        self.clip_grad = clip_grad
        self.norm_type = norm_type
        self.mpu = mpu
        self.allgather_bucket_size = int(allgather_bucket_size)
        self.dp_process_group = dp_process_group
        self.dp_rank = dist.get_rank(group=self.dp_process_group)
        self.has_moe_layers = has_moe_layers
        self.non_expert_gradients = []
        # One process group entry per param group; MoE groups are overridden below.
        self.real_dp_process_group = [dp_process_group] * len(self.optimizer.param_groups)
        if self.has_moe_layers:
            self._configure_moe_settings()

        # Use torch (un)flatten ops
        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors

        # align nccl all-gather send buffers to 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment/sizeof(fp16) = 2

        # Build BF16/FP32 groups
        self.bf16_groups = []
        self.bf16_groups_flat = []
        self.bf16_partitioned_groups = []

        self.fp32_groups_flat_partition = []

        # Maintain different fp32 gradients views for convenience
        self.fp32_groups_gradients = []
        self.fp32_groups_gradient_dict = {}
        self.fp32_groups_gradients_flat = []
        self.fp32_groups_actual_gradients_flat = []
        self.fp32_groups_gradient_flat_partition = []
        self.fp32_groups_has_gradients = []

        self.group_paddings = []
        self.graph_harvesting = graph_harvesting

        # Always define the hook list so code that reads it (e.g. destroy()) also works when
        # no real optimizer is set up; _setup_for_real_optimizer() re-creates and fills it.
        self._grad_acc_hooks = []
        if self.using_real_optimizer:
            self._setup_for_real_optimizer()

        see_memory_usage('end bf16_ optimizer', force=True)

    def destroy(self):
        """Drop the links and hooks this optimizer installed on the low-precision parameters.

        Clears the ``_hp_mapping`` attribute of every parameter in ``bf16_groups`` and removes
        all gradient accumulation hooks. Works when no real optimizer was set up (there are no
        groups or hooks then) and may be called more than once.
        """
        # Walk the groups this optimizer actually built. `bf16_groups` is only populated by
        # `_setup_for_real_optimizer`, so indexing it by `optimizer.param_groups` position can
        # fail when that setup was skipped.
        for group in self.bf16_groups:
            for p in group:
                if getattr(p, '_hp_mapping', None):
                    p._hp_mapping = None

        # `_grad_acc_hooks` may be absent if `_setup_for_real_optimizer` never ran.
        hooks = getattr(self, '_grad_acc_hooks', [])
        for hook in hooks:
            hook.remove()
        # Forget the removed handles so a repeated destroy() does not touch them again.
        hooks.clear()
        print_rank_0("Removed grad acc hooks")

    def _configure_moe_settings(self):
        """Validate MoE param groups and prepare the expert-specific bookkeeping.

        Requires at least one param group marked as MoE, requires every parameter inside such a
        group to be a MoE parameter, points the group's entry in ``real_dp_process_group`` at its
        expert data-parallel group, and creates one empty gradient list per expert group name in
        ``self.expert_gradients``.
        """
        param_groups = self.optimizer.param_groups
        assert any(
            is_moe_param_group(pg) for pg in param_groups
        ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"

        for idx, pg in enumerate(param_groups):
            if not is_moe_param_group(pg):
                continue
            assert all(is_moe_param(p) for p in pg['params']), "All params in MoE group must be MoE params"
            # MoE groups are partitioned over their expert data-parallel group.
            self.real_dp_process_group[idx] = groups._get_expert_data_parallel_group(pg['name'])

        self.expert_gradients = {}
        if self.has_moe_layers:
            self.expert_gradients.update(
                (ep_name, []) for ep_name in groups._get_expert_data_parallel_group_dict().keys())

    def _setup_for_real_optimizer(self):
        """Build the flat bf16 buffers, this rank's fp32 partitions and the gradient buffers.

        For every optimizer param group this:
          * flattens the trainable low-precision params into one aligned flat tensor and rebinds
            each param's ``.data`` to a view of it,
          * splits the flat tensor into equally sized per-rank partitions,
          * clones this rank's partition as the high-precision (float) master copy,
          * allocates a flat gradient accumulation buffer of dtype ``grad_acc_dtype`` plus several
            views of it (per-param, unpadded, and this rank's partition),
          * replaces the group's ``params`` with the single flat high-precision partition.

        Afterwards it optionally installs gradient accumulation hooks, links low- and
        high-precision params, enables universal checkpointing and records the slice mappings.
        """
        # World size of the (possibly expert-specific) data-parallel group of each param group.
        self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group]

        for i, param_group in enumerate(self.optimizer.param_groups):
            real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
            see_memory_usage(f'before initializing group {i}', force=True)

            # This rank's index within the group's data-parallel process group.
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])

            # grab the original list
            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
            self.bf16_groups.append(trainable_parameters)

            # create flat bf16 params
            # NOTE(review): the alignment argument presumably makes the flat length divisible by
            # the world size so the partitions below are equal — confirm in `align_dense_tensors`.
            self.bf16_groups_flat.append(
                self._flatten_dense_tensors_aligned(self.bf16_groups[i],
                                                    self.nccl_start_alignment_factor * real_dp_world_size))
            # Make bf16 params point to flat tensor storage
            self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
                                                     flat_tensor=self.bf16_groups_flat[i])

            # divide flat weights into equal sized partitions
            partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
            bf16_dp_partitions = [
                self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
                for dp_index in range(real_dp_world_size)
            ]
            self.bf16_partitioned_groups.append(bf16_dp_partitions)

            # create fp32 params partition
            # Only this rank's partition is copied to float; it becomes the tensor the wrapped
            # optimizer updates (see the `param_group['params']` assignment below).
            self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
            self.fp32_groups_flat_partition[i].requires_grad = True

            num_elem_list = [t.numel() for t in self.bf16_groups[i]]

            # create fp32 gradients
            # Despite the "fp32" naming, the buffer dtype is `self.grad_acc_dtype`.
            fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)
            self.fp32_groups_gradients_flat.append(fp32_flat_buffer)
            if self.has_moe_layers and is_moe_param_group(param_group):
                self.expert_gradients[param_group['name']].append(fp32_flat_buffer)
            else:
                self.non_expert_gradients.append(fp32_flat_buffer)

            # track individual fp32 gradients for entire model
            # Each entry is a view into the flat gradient buffer, one per low-precision param.
            fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
                                                     num_elem_list=num_elem_list)
            self.fp32_groups_gradients.append(fp32_gradients)
            self.fp32_groups_gradient_dict[i] = fp32_gradients

            # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
            length_without_padding = sum(num_elem_list)
            self.fp32_groups_actual_gradients_flat.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))

            # flat tensor corresponding to gradient partition
            self.fp32_groups_gradient_flat_partition.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size))

            # track fp32 gradient updates
            self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))

            # Record padding required for alignment
            # Only the last rank of the group records a non-zero padding; all others record 0.
            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                padding = self.bf16_groups_flat[i].numel() - length_without_padding
            else:
                padding = 0

            self.group_paddings.append(padding)

            # update optimizer param groups to reference fp32 params partition
            param_group['params'] = [self.fp32_groups_flat_partition[i]]

            see_memory_usage(f'after initializing group {i}', force=True)

        # Hook handles are collected here so they can be removed later.
        self._grad_acc_hooks = []
        if self.immediate_grad_update:
            self.create_grad_acc_hooks()

        # Need optimizer states initialized before linking lp to optimizer state
        self._link_all_hp_params()
        # Optimizer state is linked lazily, after the first step() call.
        self._hp_optimizer_states_linked = False
        self._enable_universal_checkpoint()
        self._param_slice_mappings = self._create_param_mapping()

    def _enable_universal_checkpoint(self):
        """Call ``enable_universal_checkpoint`` once per group of low-precision parameters."""
        for group_params in self.bf16_groups:
            enable_universal_checkpoint(param_list=group_params)

    def _create_param_mapping(self):
        param_mapping = []
        for i, _ in enumerate(self.optimizer.param_groups):
            param_mapping_per_group = OrderedDict()
            for lp in self.bf16_groups[i]:
                if lp._hp_mapping is not None:
                    lp_name = self.param_names[lp]
                    param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
            param_mapping.append(param_mapping_per_group)

        return param_mapping

    def _link_all_hp_params(self):
        """Link each group's low-precision params to this rank's flat high-precision partition.

        The partition start and size passed to ``link_hp_params`` are derived from this rank's
        index in the group's data-parallel process group and the flat buffer length.
        """
        for group_idx in range(len(self.optimizer.param_groups)):
            dp_group = self.real_dp_process_group[group_idx]
            world_size = dist.get_world_size(group=dp_group)

            # Link bf16 and fp32 params in partition
            rank_in_group = dist.get_rank(group=dp_group)
            shard_numel = self.bf16_groups_flat[group_idx].numel() // world_size
            link_hp_params(lp_param_list=self.bf16_groups[group_idx],
                           flat_hp_partition=self.fp32_groups_flat_partition[group_idx],
                           gradient_dict=self.fp32_groups_gradient_dict,
                           offload_gradient_dict=None,
                           use_offload=False,
                           param_group_index=group_idx,
                           partition_start=rank_in_group * shard_numel,
                           partition_size=shard_numel,
                           dp_group=dp_group)

    def _lazy_init_hp_params_optimizer_state(self):
        if not self._hp_optimizer_states_linked:
            for i, _ in enumerate(self.optimizer.param_groups):
                lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i],
                                                    self.optimizer.state)
            self._hp_optimizer_states_linked = True

    def _split_flat_tensor(self, flat_tensor, num_elem_list):
        assert sum(num_elem_list) <= flat_tensor.numel()
        tensor_list = []
        offset = 0
        for num_elem in num_elem_list:
            dense_tensor = torch.narrow(flat_tensor, 0, offset, num_elem)
            tensor_list.append(dense_tensor)
            offset += num_elem

        return tensor_list

    def _update_storage_to_flattened_tensor(self, tensor_list, flat_tensor):
        updated_params = self.unflatten(flat_tensor, tensor_list)
        for p, q in zip(tensor_list, updated_params):
            p.data = q.data

    def _flatten_dense_tensors_aligned(self, tensor_list, alignment):
        """Flatten ``tensor_list`` after passing it through ``align_dense_tensors``."""
        aligned_tensors = align_dense_tensors(tensor_list, alignment)
        return self.flatten(aligned_tensors)

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimizer step on the high-precision partitions and refresh the bf16 params.

        Computes the global gradient norm (stored in ``self._global_grad_norm``), optionally
        clips the gradients, hands this rank's gradient partition to the wrapped optimizer,
        steps it, copies the result back to the low-precision params and clears the
        high-precision gradient buffers.

        Raises:
            NotImplementedError: if ``closure`` is given.
        """
        if closure is not None:
            raise NotImplementedError(f'{self.__class__} does not support closure.')

        dense_grads, expert_grads = self.get_grads_for_norm()
        dense_norm = get_global_norm_of_tensors(input_tensors=dense_grads,
                                                mpu=self.mpu,
                                                norm_type=self.norm_type,
                                                use_graph=self.graph_harvesting)
        if self.has_moe_layers:
            total_norm = get_norm_with_moe_layers(dense_norm,
                                                  mpu=self.mpu,
                                                  expert_tensors=expert_grads,
                                                  norm_type=self.norm_type)
        else:
            total_norm = dense_norm

        self._global_grad_norm = total_norm

        assert total_norm > 0.
        if self.clip_grad > 0.:
            clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
                                        max_norm=self.clip_grad,
                                        global_norm=total_norm,
                                        mpu=self.mpu,
                                        use_graph=self.graph_harvesting)

        for hp_param, grad_shard in zip(self.fp32_groups_flat_partition, self.fp32_groups_gradient_flat_partition):
            # In case of grad acc dtype different than FP32, need to cast to high precision.
            if grad_shard.dtype != hp_param.dtype:
                hp_param.grad = grad_shard.to(hp_param.dtype)
            else:
                hp_param.grad = grad_shard

        self.optimizer.step()

        # Drop the grad references when they were set from a non-fp32 accumulation buffer.
        if self.grad_acc_dtype is not torch.float32:
            for hp_param in self.fp32_groups_flat_partition:
                hp_param.grad = None

        # We need to link optimizer state after the first step() call
        self._lazy_init_hp_params_optimizer_state()

        self.update_lp_params()

        self.clear_hp_grads()

    def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
        """Perform a backward pass and copy the low-precision gradients to the
        high-precision copy.

        We copy/accumulate to the high-precision grads now to prevent accumulating in the
        bf16 grads after successive backward() calls (i.e., grad accumulation steps > 1)

        The low-precision grads are deallocated during this procedure.
        """
        self.clear_lp_grads()
        loss.backward(retain_graph=retain_graph, **bwd_kwargs)

        if update_hp_grads:
            self.update_hp_grads(clear_lp_grads=clear_lp_grads)

    @torch.no_grad()
    def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):
        if lp.grad is None:
            return

        hp_grad = self.fp32_groups_gradients[group_idx][param_idx]
        assert hp_grad is not None, \
            f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{group_idx}][{param_idx}]'

        hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
        lp._hp_grad = hp_grad
        self.fp32_groups_has_gradients[group_idx][param_idx] = True

        # clear gradients
        if clear_lp_grads:
            lp.grad.zero_()

    @torch.no_grad()
    def _update_hp_grads_func(self, clear_lp_grads=False):
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                self._update_hp_grad(lp, i, j, clear_lp_grads)

    @torch.no_grad()
    def update_hp_grads(self, clear_lp_grads=False):
        if self.immediate_grad_update:
            return

        if self.graph_harvesting:
            graph_process(False, self._update_hp_grads_func, clear_lp_grads)
        else:
            self._update_hp_grads_func(clear_lp_grads)
        #cpu op
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if lp.grad is None:
                    continue
                self.fp32_groups_has_gradients[i][j] = True

    @torch.no_grad()
    def get_grads_for_reduction(self):
        if self.has_moe_layers:
            return self.non_expert_gradients, self.expert_gradients
        return self.non_expert_gradients, {}

    @torch.no_grad()
    def get_grads_for_norm(self, for_clipping=False):
        """Collect the high-precision gradients used for norm computation or clipping.

        Only parameters flagged as having a gradient are considered.

        Args:
            for_clipping: when True, return every such gradient in one flat list.

        Returns:
            If ``for_clipping``: ``list[Tensor]`` with all gradients.
            Otherwise: ``(non_expert_grads, expert_grads)`` where ``non_expert_grads`` is a
            ``list[Tensor]`` and ``expert_grads`` is a ``dict`` mapping the MoE param group
            name to a ``list[Tensor]``. In this mode pipeline-replicated parameters and
            parameters duplicated across tensor-parallel ranks are skipped.
        """
        # expert group name -> grads
        expert_grads_for_norm = {}

        # grads
        non_expert_grads_for_norm = []
        all_grads_for_clip = []

        tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        assert len(self.bf16_groups) == len(self.optimizer.param_groups)
        for i, group in enumerate(self.bf16_groups):
            for j, lp in enumerate(group):
                if for_clipping:
                    # Clipping touches every gradient that was produced; no de-duplication.
                    if self.fp32_groups_has_gradients[i][j]:
                        all_grads_for_clip.append(self.fp32_groups_gradients[i][j])
                    continue

                if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
                    continue

                # skip duplicated parameters. perform norm only on cards with tp_rank=0.
                # non-duplicated parameters include:
                # - Parameters with tp: Use allreducesum of mp_group.
                # - Moe Parameters with ep: Use allreducesum of ep_group.
                counted_on_this_rank = (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp))
                if not counted_on_this_rank:
                    continue

                if not self.fp32_groups_has_gradients[i][j]:
                    continue

                param_group = self.optimizer.param_groups[i]
                if self.has_moe_layers and is_moe_param_group(param_group):
                    expert_grads_for_norm.setdefault(param_group['name'], []).append(self.fp32_groups_gradients[i][j])
                else:
                    non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j])

        if for_clipping:
            return all_grads_for_clip
        return non_expert_grads_for_norm, expert_grads_for_norm

    @torch.no_grad()
    def update_lp_params(self):
        """Copy this rank's high-precision partitions into the bf16 buffers, then all-gather.

        After the local copy, ``all_gather_dp_groups`` is called on the flat bf16 buffers and
        their per-rank partitions.
        """
        paired_partitions = zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)
        for group_idx, (lp_partitions, hp_partition) in enumerate(paired_partitions):
            own_rank = dist.get_rank(group=self.real_dp_process_group[group_idx])
            lp_partitions[own_rank].data.copy_(hp_partition.data)

        all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
                             partitioned_param_groups=self.bf16_partitioned_groups,
                             dp_process_group=self.real_dp_process_group,
                             start_alignment_factor=self.nccl_start_alignment_factor,
                             allgather_bucket_size=self.allgather_bucket_size)

    def clear_hp_grads(self):
        for flat_gradients in self.fp32_groups_gradients_flat:
            flat_gradients.zero_()

        for i, group in enumerate(self.fp32_groups_gradients):
            self.fp32_groups_has_gradients[i] = [False] * len(group)

    def clear_lp_grads(self, set_to_none=False):

        # using zero_() fixed memory address for graph replay
        if self.graph_harvesting:
            assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"

        zero_grads_list = []
        for group in self.bf16_groups:
            for param in group:
                if set_to_none:
                    param.grad = None
                elif param.grad is not None:
                    if param.grad.grad_fn is not None:
                        param.grad.detach_()
                    zero_grads_list.append(param.grad)
        if not set_to_none and len(zero_grads_list) > 0:
            torch._foreach_zero_(zero_grads_list)

    def zero_grad(self, set_to_none=True):
        self.clear_lp_grads(set_to_none)
        self.clear_hp_grads()

    def state_dict(self):
        """Return the checkpoint state of this optimizer wrapper as a dict.

        Contains the clip value, the wrapped optimizer's state dict, this rank's flat
        high-precision partitions, the group paddings, the partition counts, the DeepSpeed
        version and the parameter slice mappings.
        """
        return {
            CLIP_GRAD: self.clip_grad,
            BASE_OPTIMIZER_STATE: self.optimizer.state_dict(),
            SINGLE_PARTITION_OF_FP32_GROUPS: self.fp32_groups_flat_partition,
            GROUP_PADDINGS: self.group_paddings,
            PARTITION_COUNT: self.partition_count,
            DS_VERSION: version,
            PARAM_SLICE_MAPPINGS: self._param_slice_mappings,
        }

    # Restore base optimizer fp32 weights from bfloat16 weights
    def _restore_from_bit16_weights(self):
        """Overwrite each high-precision partition with this rank's current bf16 partition."""
        paired_partitions = zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)
        for group_idx, (lp_partitions, hp_partition) in enumerate(paired_partitions):
            own_rank = dist.get_rank(group=self.real_dp_process_group[group_idx])
            hp_partition.data.copy_(lp_partitions[own_rank].data)

    def refresh_fp32_params(self):
        """Public entry point that re-syncs the high-precision partitions from the bf16 weights."""
        return self._restore_from_bit16_weights()

    def load_state_dict(self,
                        state_dict_list,
                        checkpoint_folder=None,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False,
                        load_serial=None,
                        param_shapes=None):
        if checkpoint_folder:
            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
        else:
            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)

    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
        """Restore state from the entry of ``state_dict_list`` that belongs to this dp rank.

        Args:
            state_dict_list: state dicts indexed by data-parallel rank.
            load_optimizer_states: when True, load the wrapped optimizer's state and re-link
                the low/high-precision params afterwards.
            load_from_fp32_weights: when True, copy the saved high-precision partitions into
                the current ones (zero-padded if the saved tensor is shorter).
        """
        rank_in_dp_group = dist.get_rank(group=self.dp_process_group)
        rank_sd = state_dict_list[rank_in_dp_group]

        saved_version = rank_sd.get(DS_VERSION, False)
        assert saved_version, "Empty ds_version in checkpoint, not clear how to proceed"
        # The parsed version is not used further; parsing still rejects malformed versions.
        pkg_version.parse(saved_version)

        self.clip_grad = rank_sd.get(CLIP_GRAD, self.clip_grad)

        if load_optimizer_states:
            print("_load_legacy_checkpoint current_rank_sd[BASE_OPTIMIZER_STATE]")
            self.optimizer.load_state_dict(rank_sd[BASE_OPTIMIZER_STATE])

        if load_from_fp32_weights:
            saved_partitions = rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
            for hp_partition, saved in zip(self.fp32_groups_flat_partition, saved_partitions):
                padded = _get_padded_tensor(saved, hp_partition.numel())
                hp_partition.data.copy_(padded.data)

        # Re-linking happens only after the optional weight copy above.
        if load_optimizer_states:
            self._link_all_hp_params()

    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
        """Load high-precision state for the ``bf16_groups`` params from ``checkpoint_folder``.

        ``load_optimizer_states`` and ``load_from_fp32_weights`` are accepted but not used here.
        """
        lp_groups_attr = "bf16_groups"
        self.load_hp_checkpoint_state_from_checkpoint_dir(lp_groups_attr, checkpoint_folder)

    def _load_global_state(self, sd):
        pass

    @property
    def param_groups(self):
        """Forward the wrapped optimizer's parameters."""
        return self.optimizer.param_groups

    @property
    def state(self):
        """Forward the wrapped optimizer's states."""
        return self.optimizer.state

    def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
        assert self.immediate_grad_update
        self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False)

    def create_grad_acc_hooks(self):
        """Register a gradient hook on every low-precision param that requires grad.

        Each hook accumulates that parameter's gradient into its high-precision buffer; the
        returned handles are appended to ``self._grad_acc_hooks``.
        """

        def make_hook(param, i, j):
            # A factory gives each hook its own (param, i, j) binding instead of sharing the
            # loop variables.
            def accumulate_hp_grads_and_remove_lp(*notneeded):
                self.accumulate_hp_grads_and_remove_lp(param, i, j)

            return accumulate_hp_grads_and_remove_lp

        for i, param_group in enumerate(self.bf16_groups):
            for j, param in enumerate(param_group):
                if not param.requires_grad:
                    continue
                self._grad_acc_hooks.append(register_grad_hook(param, make_hook(param, i, j)))


def _get_padded_tensor(src_tensor, size):
    """Return ``src_tensor`` zero-padded at the end to ``size`` elements.

    If ``src_tensor`` already has at least ``size`` elements it is returned unchanged (same
    object, not truncated). Otherwise a new 1-D tensor of length ``size`` with the same dtype
    and device is returned, holding the source values followed by zeros.
    """
    src_numel = src_tensor.numel()
    if src_numel >= size:
        return src_tensor

    padded = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
    # Copy through `.data` into the leading slice; the tail keeps its zeros.
    torch.narrow(padded, 0, 0, src_numel).data.copy_(src_tensor.data)
    return padded
