# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple, Union, cast

import torch
from torch import nn

from .layer import MoE


def has_moe_layers(m: nn.Module) -> Tuple[bool, int]:
    """Report whether ``m`` contains an ``MoE`` layer.

    Args:
        m (nn.Module):
            The module whose ``modules()`` are scanned.

    Returns:
        Tuple[bool, int]:
        ``(True, num_experts)`` taken from the first ``MoE`` instance yielded by
        ``m.modules()``, or ``(False, 0)`` when there is none.
    """
    for submodule in m.modules():
        if isinstance(submodule, MoE):
            # Only the first MoE layer is inspected; the scan stops here.
            return True, submodule.num_experts
    return False, 0


def is_moe_param(param: torch.Tensor) -> bool:
    """Return ``True`` only when ``param`` has an ``allreduce`` attribute that is falsy.

    A tensor without the attribute, or with a truthy ``allreduce``, is not
    considered an MoE (expert) parameter.
    """
    # A missing attribute defaults to True, so it yields False after negation.
    return not getattr(param, "allreduce", True)


def split_params_into_shared_and_expert_params(
        params: List[torch.nn.Parameter]) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
    """Partition ``params`` into non-expert (shared) and expert parameters.

    Args:
        params (List[torch.nn.Parameter]):
            The parameters to split

    Returns:
        Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
        list of non MoE params, list of MoE params (input order is kept within each list)
    """
    non_moe: List[nn.Parameter] = []
    moe: List[nn.Parameter] = []

    for param in params:
        bucket = moe if is_moe_param(param) else non_moe
        bucket.append(param)
    return non_moe, moe


def split_params_grads_into_shared_and_expert_params(
        group: List[torch.nn.Parameter]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Split grad of parameters into grads of non-expert params
    and grads of expert params. This is useful while computing
    grad-norms for clipping and overflow detection

    Args:
        group (List[torch.nn.Parameter]):
            The group of parameters to split

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]:
        list of gradients for non MoE params, list of gradients of MoE params
    """
    non_moe_grads: List[torch.Tensor] = []
    moe_grads: List[torch.Tensor] = []

    for param in group:
        grad = param.grad
        if grad is None:
            # Parameters without a gradient contribute to neither list.
            continue
        bucket = moe_grads if is_moe_param(param) else non_moe_grads
        # Each gradient is converted to its parameter's dtype before being collected.
        bucket.append(grad.to(param.dtype))
    return non_moe_grads, moe_grads


def _chunk_params_by_numel(params: List[nn.Parameter],
                           max_group_size: Union[int, float]) -> List[List[nn.Parameter]]:
    """Greedily pack ``params``, in order, into chunks of at most ``max_group_size`` elements.

    A parameter whose own ``numel()`` exceeds ``max_group_size`` gets a chunk of
    its own. Empty chunks are never produced.
    """
    chunks: List[List[nn.Parameter]] = []
    cur_chunk: List[nn.Parameter] = []
    cur_size = 0

    for param in params:
        numel = param.numel()
        # Close the running chunk only if it already holds something; otherwise an
        # oversized first parameter would leave an empty chunk behind.
        if cur_chunk and cur_size + numel > max_group_size:
            chunks.append(cur_chunk)
            cur_chunk = []
            cur_size = 0
        cur_chunk.append(param)
        cur_size += numel

    if cur_chunk:
        chunks.append(cur_chunk)
    return chunks


def split_params_into_different_moe_groups_for_optimizer(
        param_groups: Union[Dict[str, Any], Tuple[Dict[str, Any], ...], List[Dict[str, Any]]],
        max_group_size: Union[int, float, None] = 178956971) -> List[Dict[str, Any]]:
    """Split parameters into different MoE groups for optimizer

    MoE parameters (see ``is_moe_param``) are moved out of their original param
    group into new groups keyed by the parameter's ``group_name``. Each new group
    copies every other key of the group it came from and sets ``'name'`` to that
    group name and ``'moe'`` to ``True``.

    Note:
        The input is modified: ``'params'`` of every input group is replaced by the
        list of its non-MoE parameters, and when ``param_groups`` is a list the new
        MoE groups are appended to that same list. Every input group must have a
        ``'name'`` key whenever at least one MoE parameter is present.

    Args:
        param_groups (Union[Dict[str, Any], Tuple[Dict[str, Any], ...], List[Dict[str, Any]]])
            The list of parameter groups to split
        max_group_size (Union[int, float, None])
            Upper bound on the summed ``numel()`` of one emitted MoE group; larger
            groups are split into several. A single parameter above the bound
            gets a group of its own. ``None`` disables splitting; in that
            case every created MoE group is appended as is, even if it has no params.

    Returns:
        List[Dict[str, Any]]:
        list of MoE/non-MoE groups for optimizer

    Raises:
        ValueError: if ``param_groups`` is not a dict, tuple or list.
    """
    if isinstance(param_groups, tuple):
        param_groups = list(param_groups)  # Tuple cannot be modified
    elif isinstance(param_groups, dict):
        param_groups = [param_groups]
    elif not isinstance(param_groups, list):
        raise ValueError(f"Unknown param group type of {type(param_groups)}")

    # gather all data parallel group names
    data_parallel_group_names: Set[str] = set()
    for param_group in param_groups:
        for param in cast(List[nn.Parameter], param_group["params"]):
            if is_moe_param(param):
                data_parallel_group_names.add(param.group_name)

    # Iterate names in sorted order: set iteration order of strings depends on
    # per-process hash randomization, which would make the output order unstable.
    ordered_group_names = sorted(data_parallel_group_names)

    # Create the param MoE groups, leave param assign to next step
    group_moe: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
    for param_group in param_groups:
        for key in ordered_group_names:
            group_moe[param_group['name']][key] = {
                **param_group,
                'name': key,
                'moe': True,
                'params': [],
            }

    # Assign param
    for param_group in param_groups:
        new_params: List[nn.Parameter] = []

        for param in cast(List[nn.Parameter], param_group['params']):
            if is_moe_param(param):
                group_moe[param_group['name']][param.group_name]['params'].append(param)
            else:
                new_params.append(param)
        param_group['params'] = new_params

    # Flatten the moe groups
    for moe_group in group_moe.values():
        for moe_param_group in moe_group.values():
            if max_group_size is None:
                param_groups.append(moe_param_group)
                continue

            moe_params = cast(List[nn.Parameter], moe_param_group['params'])
            for chunk in _chunk_params_by_numel(moe_params, max_group_size):
                param_groups.append({**moe_param_group, 'params': chunk})

    return param_groups


def is_moe_param_group(param_group):
    """Return the value stored under ``'moe'`` in ``param_group``, or ``False`` if the key is absent."""
    moe_marker = param_group.get('moe', False)
    return moe_marker


def configure_moe_param_groups(model_parameters: List):
    """Turn the parameters destined for the optimizer into MoE compatible param groups.

    Args:
        model_parameters (List):
            Either a list of tensors or a list of param group dicts. How the
            list is handled is decided by its first element only.

    Returns:
        - the same (empty) list when ``model_parameters`` is empty;
        - for a list of tensors, the result of
          ``split_params_into_different_moe_groups_for_optimizer`` applied to a single
          group named ``'dense-params'`` holding all of them;
        - for a list of dicts where no dict has a ``'moe'`` key, the result of
          ``split_params_into_different_moe_groups_for_optimizer`` on that list;
        - for a list of dicts where some dict has a ``'moe'`` key, the list unchanged.

    Raises:
        AssertionError: if ``model_parameters`` is not a list.
        TypeError: if any element is neither a ``torch.Tensor`` nor a ``dict``.
    """
    # Kept as an assert so the exception type seen by existing callers does not change.
    assert isinstance(model_parameters, list), "model_parameters must be a list"

    for p in model_parameters:
        # match torch.optim.Optimizer expectations,
        # see: https://github.com/pytorch/pytorch/blob/2ffab6e663b9c6951048b8c8ba82d2cc5ca5c2fc/torch/optim/optimizer.py#L270-L272
        if not isinstance(p, (torch.Tensor, dict)):
            raise TypeError("param argument that would be given to the optimizer should be "
                            f"an iterable of Tensors or dicts, but got {type(p)}")

    # Nothing to split; hand the empty list back instead of failing on the index below.
    if not model_parameters:
        return model_parameters

    # peek at the first element to determine how to proceed
    first = model_parameters[0]

    # Case 1: model_parameters is a list of tensors (torch.nn.Parameter included)
    #   -> need to create moe compatible param groups
    # Checked against torch.Tensor, the same type the validation above accepts, so
    # plain tensors do not fall through both cases.
    if isinstance(first, torch.Tensor):
        param_group = {'params': model_parameters, 'name': 'dense-params'}
        return split_params_into_different_moe_groups_for_optimizer(param_group)

    # Case 2: model_parameters is a list of param groups List[dict]
    #   -> moe compatible param groups might already exist, if not create them
    if any('moe' in param_group for param_group in model_parameters):
        # moe groups exist, nothing to do
        return model_parameters

    # there are no moe groups created
    return split_params_into_different_moe_groups_for_optimizer(model_parameters)
