# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Automatic Tensor Parallelism
import re

from torch import nn
from .replace_policy import replace_policies
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import *
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device, copy=True):
    """Return ``tensor`` placed on ``device``.

    Meta tensors carry no data, so for them an uninitialized tensor with the same
    shape/dtype is allocated on ``device`` instead of a transfer.

    Args:
        tensor: tensor to move.
        device: target device.
        copy: forwarded to ``Tensor.to``; when True a fresh tensor is always created.
    """
    if not tensor.is_meta:
        # A fresh copy (copy=True) lets a large source (e.g. the full tensor a shard was split from)
        # be freed; this replaces an earlier clone() and also covers the cpu --> cpu case, where
        # to() would otherwise hand back a view that keeps the full tensor referenced.
        return tensor.to(device, copy=copy)
    return torch.empty_like(tensor, device=device)


class ReplaceWithTensorSlicing:
    """Copy checkpoint tensors into destination tensors, taking this rank's slice when sharded.

    When ``src`` is larger than ``dst`` along a sharded dimension, the slice that belongs to
    ``self.gpu_index`` is copied. Results are wrapped in ``nn.Parameter`` objects, and a
    ``scale`` attribute on ``src`` (if present) is carried over to the result.
    """

    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
        """
        Args:
            mp_group: model-parallel process group; this process's rank in it selects the slice.
                When ``None``, slice 0 is used and no ``dist`` call is made.
            mp_size: model-parallel world size (stored; not read by the methods of this class).
            out_dim: dimension of ``dst`` treated as the output dimension.
            in_dim: dimension of ``dst`` treated as the input dimension.
        """
        self.gpu_index = dist.get_rank(group=mp_group) if mp_group is not None else 0
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mp_size = mp_size

    def merge_assert(self, dim1, dim2):
        """Assert ``dim1`` (source size) > ``dim2`` (destination size): tensors may only be split here."""
        assert dim1 > dim2, ('Merging tensors is not allowed here! Please use deepspeed load_checkpoint '
                             'for merging your checkpoints before replacing the transformer layer with '
                             'inference-kernels')

    def strided_copy(self,
                     dst: Optional[torch.Tensor],
                     src: Optional[torch.Tensor],
                     num_splits: int,
                     int8: bool = False,
                     allocate_tensor: bool = False):
        """Copy a fused tensor (e.g. fused QKV) into ``dst``, taking this rank's slice of every sub-block.

        ``src`` is split into ``num_splits`` equal sub-blocks along its outer dimension; each
        sub-block is sliced for this rank and the slices are concatenated into ``dst``.

        Args:
            dst: destination tensor (its shape defines the per-rank size).
            src: source tensor; when ``None``, ``None`` is returned.
            num_splits: number of fused sub-blocks in ``src``.
            int8: when True the outer dimension is 0 instead of -1.
            allocate_tensor: when True, write into a fresh ``empty_like(dst)`` instead of ``dst``.

        Returns:
            ``nn.Parameter`` holding the copied data.

        Raises:
            RuntimeError: if ``src`` cannot be copied into ``dst`` when no slicing is needed.
        """
        if src is None:
            return src
        src_shape = src.shape
        dst_shape = dst.shape

        outer_dim = 0 if int8 else -1

        if allocate_tensor:
            dst = torch.empty_like(dst)

        src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
                # Same size along the split dimension: plain copy, no slicing needed.
                try:
                    dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
                except RuntimeError as err:
                    # Previously this printed the shapes and called exit(), killing the whole process.
                    raise RuntimeError(f"strided_copy: cannot copy source of shape {tuple(src_shape)} "
                                       f"into destination of shape {tuple(dst_shape)}") from err
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst
            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
            qkv_size = dst_shape[self.out_dim] // num_splits
            # Split every fused sub-block into per-rank pieces, then regroup the pieces by rank.
            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
            weight_split = [
                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
            ]
            dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
                weight_split[self.gpu_index].shape)
        else:
            if src_shape[0] == dst_shape[0]:
                # NOTE(review): unlike the other paths this keeps the default requires_grad and does not
                # copy `scale`; kept as-is to preserve behavior.
                return torch.nn.parameter.Parameter(src)
            qkv_size = dst_shape[0] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
            dst.data.copy_(bias_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst

    def copy(self, dst, src, int8=False, allocate_tensor=False):
        """Copy ``src`` into ``dst``, taking this rank's contiguous slice when ``src`` is larger.

        For 2-D tensors the slice is taken along the input dimension if that differs, otherwise
        along the output dimension. For other ranks the slice is taken along the first dimension.

        Args:
            dst: destination tensor; must not be a meta tensor.
            src: source tensor; when ``None``, ``None`` is returned.
            int8: when True the roles of source dims 0 and 1 are swapped.
            allocate_tensor: when True, write into a fresh ``empty_like(dst)`` instead of ``dst``.

        Returns:
            ``nn.Parameter`` (``requires_grad=False``) holding the copied data.
        """
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        if allocate_tensor:
            dst = torch.empty_like(dst)
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
        dst_shape = dst.shape
        if (len(src_shape) == 2 and len(dst_shape) == 2):

            if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
            elif src_shape[inner_dim] != dst_shape[self.in_dim]:
                # Slice this rank's part along the source's inner dimension.
                self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                start = self.gpu_index * dst_shape[self.in_dim]
                end = (self.gpu_index + 1) * dst_shape[self.in_dim]
                dst.data.copy_(src[:, start:end] if inner_dim == 1 else src[start:end, :])
            else:
                # Slice this rank's part along the source's outer dimension.
                self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                start = self.gpu_index * dst_shape[self.out_dim]
                end = (self.gpu_index + 1) * dst_shape[self.out_dim]
                dst.data.copy_(src[:, start:end] if outer_dim == 1 else src[start:end, :])
        else:
            if src_shape[0] == dst_shape[0]:
                dst = src if src.dtype == dst.dtype else dst.data.copy_(src)
            else:
                dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])
        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst


class Loading():
    """Helpers that materialize meta tensors and load checkpoint tensors into modules.

    The methods take no instance and are called directly on the class, e.g. ``Loading.load(...)``.
    """

    @staticmethod
    def is_load_module(module):
        """Return True if ``module`` is a layer whose own tensors are loaded from the state dict."""
        load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
        load_layer_names = [
            "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
            "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
            "Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm",
            "Qwen3RMSNorm", "Qwen3MoeRMSNorm", "DeepseekV2RMSNorm", "DeepseekV3RMSNorm",
            "DeepseekV2YarnRotaryEmbedding", "DeepseekV3YarnRotaryEmbedding", "MoEGate"
        ]
        return module.__class__ in load_layers or module._get_name() in load_layer_names

    @staticmethod
    def load_buffer(module, state_dict, prefix):
        """Materialize meta buffers of ``module`` on CPU and copy ``state_dict[prefix + name]`` into them.

        Buffers without a matching key are left as they are; buffers registered as ``None`` are skipped.
        """
        for name in module._buffers.keys():
            buf = module._buffers[name]
            if buf is None:
                # register_buffer(name, None) leaves no tensor to materialize or load into.
                continue
            if buf.data.is_meta:
                # NOTE(review): the replacement is stored in _buffers as an nn.Parameter; kept as-is.
                module._buffers[name] = torch.nn.parameter.Parameter(data=torch.empty_like(buf.data, device="cpu"),
                                                                     requires_grad=buf.data.requires_grad)
            if prefix + name in state_dict.keys():
                module._buffers[name].data.copy_(state_dict[prefix + name])

    @staticmethod
    def load(module, state_dict, prefix, mp_group=None):
        """Load ``weight``/``bias`` (or ``norm.weight``/``norm.bias``) of ``module`` from ``state_dict``.

        Keys are ``prefix + 'weight'`` / ``prefix + 'bias'``. Tensors go through
        ``ReplaceWithTensorSlicing``, which takes this rank's slice when the checkpoint tensor is larger;
        ``query_key_value`` weights use a 3-way strided copy. Meta tensors are first replaced by empty
        CPU tensors because they cannot be copied into.
        """
        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        if hasattr(module, 'weight'):
            if module.weight.data.is_meta:
                # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, device="cpu"),
                                                             requires_grad=module.weight.data.requires_grad)
            # Load the checkpoint weight whether or not it started on meta (previously non-meta weights
            # were never loaded); this matches the bias and norm branches below.
            if 'query_key_value' in prefix:
                module.weight = mp_replace.strided_copy(module.weight.data,
                                                        state_dict[prefix + 'weight'],
                                                        num_splits=3)
            else:
                module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
        else:
            if hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
                if module.norm.weight.data.is_meta:
                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                    module.norm.weight = torch.nn.parameter.Parameter(
                        data=torch.empty_like(module.norm.weight.data, device="cpu"),
                        requires_grad=module.norm.weight.data.requires_grad)
                module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])

        if prefix + 'bias' in state_dict.keys():
            if hasattr(module, 'bias'):
                if module.bias.data.is_meta:
                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                    module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
                                                               requires_grad=module.bias.data.requires_grad)
                module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
            else:
                if hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
                    if module.norm.bias.data.is_meta:
                        # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                        module.norm.bias = torch.nn.parameter.Parameter(
                            data=torch.empty_like(module.norm.bias.data, device="cpu"),
                            requires_grad=module.norm.bias.data.requires_grad)
                    module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])


class AutoTP():
    """Automatic tensor parallelism: infer and apply TP replacement of a model's linear/embedding layers.

    The static helpers (``tp_parser``, ``supported``, ...) inspect a model to build a policy. An
    instance then walks one module (``_replace_module``) and swaps its linear/embedding children
    for tensor-parallel layers, optionally loading tensors from ``state_dict`` along the way.
    """

    def __init__(self,
                 module,
                 all_reduce_linears,
                 prefix,
                 state_dict,
                 linear_layer_setting,
                 orig_layer_impl,
                 keep_module_on_host=False):
        """
        Args:
            module: module whose children are replaced.
            all_reduce_linears: names of linear layers routed to all-reduce variants in ``_replace``.
            prefix: key prefix used to look up this module's tensors in ``state_dict``.
            state_dict: checkpoint tensors to load while replacing, or ``None``.
            linear_layer_setting: optional ``(linear_cls[, embedding_cls])`` overriding the default policies.
            orig_layer_impl: original block class; compared against GPT-2's block to select Conv1D handling.
            keep_module_on_host: forwarded to ``TensorParallel_Layer.set_keep_module_on_host``.
        """
        self.module = module
        self.all_reduce_linears = all_reduce_linears
        self.prefix = prefix
        self.state_dict = state_dict

        # Filled in by set_tensor_parallel_config / update_linear_policies.
        self.mp_size = None
        self.mp_group = None
        self.linear_layer_setting = linear_layer_setting
        self.orig_layer_impl = orig_layer_impl
        self.linear_policies = None
        self.conv_linear_layer = False
        TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)

    @staticmethod
    def in_module_list(module, module_list):
        """Return True if ``module_list`` holds a module whose class name equals ``module``'s class name."""
        return any(type(item).__name__ == type(module).__name__ for item in module_list)

    @staticmethod
    def get_module_list(model):
        """Collect representative modules (e.g. transformer blocks) stored in ``nn.ModuleList`` children.

        Members of each ``nn.ModuleList`` child are added unless a module with the same class name was
        already collected; other children are searched recursively and their results appended as-is.
        """
        mlist = []
        for child in model.children():
            if isinstance(child, nn.ModuleList):
                for module in child.children():
                    if not AutoTP.in_module_list(module, mlist):
                        mlist.append(module)
            else:
                mlist.extend(AutoTP.get_module_list(child))
        return mlist

    @staticmethod
    def supported(model):
        """Return False if the model family parsed from ``str(model)`` is known to be unsupported by AutoTP.

        Raises:
            AssertionError: if no family name can be parsed from the model's repr.
        """
        unsupported = ['deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
            key = re.search(r": (.*?)Stack", model)
        if key is None:
            key = re.match(r"(.*?)Model", model)
        assert key is not None, "Not able to determine model policy automatically. Please provide policy."
        return key.group(1).lower() not in unsupported

    @staticmethod
    def get_layers(parent, module):
        """Flatten ``module`` into a list of linear-layer names and ``"ln"`` norm markers, in definition order.

        Linear layers are named ``parent + "." + key``; on recursion the child's own key becomes the
        parent, so only the immediate parent name is kept.
        """
        layer_list = []
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list.append(parent + "." + key)
            elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                layer_list.append("ln")
            else:
                layer_list.extend(AutoTP.get_layers(key, submodule))
        return layer_list

    @staticmethod
    def update_policy_list(policy_list, new_module, new_gems):
        """Add ``(type(new_module), new_gems)`` to ``policy_list`` in place and return it.

        If the module type is already present, its gems are merged with ``new_gems`` into a set
        (removing duplicates).
        """
        for i, policy in enumerate(policy_list):
            # if module already exists in policy, combine gems and remove duplicates
            if policy[0] == type(new_module):
                policy_list[i] = (type(new_module), set(new_gems + policy[1]))
                return policy_list
        policy_list.append((type(new_module), new_gems))
        return policy_list

    @staticmethod
    def kernel_supported(module_list):
        """Return True if any module in ``module_list`` matches a layer class of a kernel-injection policy."""
        policy = []
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                policy.extend(plcy._orig_layer_class)
            elif plcy._orig_layer_class is not None:
                policy.append(plcy._orig_layer_class)
        return any(child.__class__ in policy for child in module_list)

    @staticmethod
    def tp_parser(model):
        """Infer an AutoTP policy for ``model`` from its layer names.

        For each block class found by ``get_module_list``, collect the linear layers that appear to need an
        all-reduce after them (cf. the Mixtral note below): the linear right before each norm marker, plus
        layers matched by the name heuristics below.

        Returns:
            list of ``(block_class, layer_names)`` tuples (see ``update_policy_list``).
        """
        policy_list = []
        layer_list = []
        gem_list = []

        module_list = AutoTP.get_module_list(model)
        assert AutoTP.supported(model), (
            "AutoTP not supported for model. Please use kernel injection since container policy for model exists."
            if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy.")
        # str(model) renders the entire module tree; compute it once rather than once per layer.
        model_repr = str(model)
        norm_layer_name_list = ['LayerNorm', 'layer_norm', 'ln_1', 'ln_2']
        #ln_1 , ln_2 for Qwen
        for module in module_list:
            module_type = str(type(module))
            for key, submodule in module._modules.items():
                if isinstance(submodule, nn.Linear):
                    layer_list.append("." + key)
                elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
                    layer_list.append("ln")
                else:
                    layer_list.extend(AutoTP.get_layers(key, submodule))
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    # For i == 0 this reads layer_list[-1], the block's last layer; presumably intentional for
                    # pre-norm blocks, whose last linear precedes the next block's leading norm.
                    if layer_list[i - 1] != 'ln':
                        gem_list.append(layer_list[i - 1])
                elif 'out_proj' in layer:
                    gem_list.append(layer)
                elif 'o_proj' in layer:
                    gem_list.append(layer)
                elif 'down_proj' in layer:
                    gem_list.append(layer)
                elif 'attention.dense' in layer and 'GPTNeoX' in model_repr:
                    gem_list.append(layer)
                elif 'self_attention.dense' in layer and 'falcon' in module_type:
                    # this is a hack to get the right linear layer for this model!
                    gem_list.append(layer)
                # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
                elif 'w2' in layer and 'Mixtral' in module_type:
                    gem_list.append(layer)
                elif 'self_attn.dense' in layer and 'Phi' in module_type:
                    gem_list.append(layer)
                elif 'self_attention.dense' in layer and 'ChatGLM' in model_repr:
                    gem_list.append(layer)
                elif 'dense_4h_to_h' in layer and 'ChatGLM' in model_repr:
                    gem_list.append(layer)

            layer_list = []
            if gem_list:
                policy_list = AutoTP.update_policy_list(policy_list, module, list(set(gem_list)))
                gem_list = []
        assert len(policy_list), (
            "AutoTP not supported for model. Please use kernel injection since container policy for model exists."
            if AutoTP.kernel_supported(module_list) else
            "Not able to determine model policy automatically. Please provide policy.")
        return policy_list

    def set_tensor_parallel_config(self, mp_size, mp_group):
        """Set the TP world size and group; in AutoTP training mode they come from ``groups`` instead."""
        if is_autotp_training_mode():
            self.mp_group = groups.get_tensor_model_parallel_group()
            self.mp_size = groups.get_tensor_model_parallel_world_size()
            return

        self.mp_size = mp_size
        self.mp_group = mp_group

    def _replace(self, child, name, conv_linear_layer):
        """Return the tensor-parallel replacement for the linear layer ``child``.

        This function only routes layers to TP layer classes; the sharding itself is done by those
        classes.

        Args:
            child: linear module to replace.
            name: dotted name built by the caller (e.g. ``"mlp.gate"``), matched against the rules below.
            conv_linear_layer: unused; ``self.conv_linear_layer`` is consulted instead.

        Returns:
            ``child`` itself for layers that must stay unsharded, otherwise a new TP layer, or ``None``
            if ``child`` is already marked as replaced.
        """
        # NOTE(review): the caller assigns the return value back onto the parent module, so the
        # already-replaced path installs None; kept as-is to preserve behavior.
        if getattr(child, "replaced", False) == True:
            return

        # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
        if "mlp.gate" == name or "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
            ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
            return child
        # For Yuan model
        if 'Yuan' in str(self.module):
            if 'v_proj' in name:
                return Yuan_LinearLayer(child, self.mp_group)

            elif 'o_proj' in name:
                return Yuan_LinearAllreduce(child, self.mp_group)

        # For MLP including chunk layer.
        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
            return GateUpPack_LinearLayer(child, self.mp_group)
        # For Arctic model, bypass to all_reduce replacement for w2 weights
        arctic_w2_all_reduce_linear = 'Arctic' in str(self.module) and 'w2' in name
        # For MoE MLP model, e.g., deepseek and jamba
        down_proj = 'down_proj' in name
        if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:

            setattr(child, "replaced", True)
            if self.conv_linear_layer:
                return Conv_LinearALlreduce(child, self.mp_group, name=name)
            elif name == "lm_head" or name == 'embed_out':
                return LmHeadLinearAllreduce(child, self.mp_group)

            return LinearAllreduce(child, self.mp_group, name=name)

        setattr(child, "replaced", True)
        if self.conv_linear_layer:
            # The conv (Conv1D) layer must be returned; previously it was discarded and the child was
            # re-wrapped as a plain LinearLayer below.
            return conv_LinearLayer(child, self.mp_group)
        elif require_tp_fused_qkvw(name, self.mp_size):
            #Check and handle fused qkv for TP
            return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

        return LinearLayer(child, self.mp_group, name=name)

    def _slice_embedding(self, child, name, conv_linear_layer):
        """Return a new ``nn.Embedding`` holding this rank's slice of ``child.weight`` along dim 1.

        ``conv_linear_layer`` is unused. Returns ``None`` if ``child`` is already marked as replaced.
        """
        if getattr(child, "replaced", False) == True:
            return
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

        if hasattr(child.weight, 'ds_tensor'):
            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
        else:
            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size, name), dim=1)
        data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
        data = torch.nn.parameter.Parameter(data, requires_grad=False)

        new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size, name))
        new_embedding.weight.data.copy_(data)
        setattr(child, "replaced", True)
        return new_embedding

    def update_mp_params(self, child):
        """Shard head-count / hidden-size style attributes of ``child`` by ``self.mp_size``, at most once."""
        if getattr(child, "replaced", False) == True:
            return
        param_list = [
            "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", "all_head_size",
            "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", "d_model",
            "num_attention_heads_per_partition", "num_multi_query_groups_per_partition", "hidden_size_per_partition"
        ]
        # embed_dim is not sharded for Yuan modules. Decide this once: str(child) renders the whole
        # subtree, and the list must not be mutated while it is being iterated.
        if "Yuan" in str(child):
            param_list.remove('embed_dim')
        for param in param_list:
            if hasattr(child, param):
                param_val = getattr(child, param)
                setattr(child, param, get_shard_size(param_val, self.mp_size))
        setattr(child, "replaced", True)

    def update_linear_policies(self):
        """Build ``self.linear_policies``, mapping module classes to their replacement method."""
        self.conv_linear_layer = False
        if self.linear_layer_setting is not None:
            self.linear_policies = {self.linear_layer_setting[0]: self._replace}
            if len(self.linear_layer_setting) == 2:
                self.linear_policies.update({self.linear_layer_setting[1]: self._slice_embedding})
        else:
            import transformers
            if self.orig_layer_impl is transformers.models.gpt2.modeling_gpt2.GPT2Block:
                try:
                    conv1d_cls = transformers.pytorch_utils.Conv1D
                except (ImportError, AttributeError):
                    # Conv1D could not be located; fall back to plain nn.Linear handling (non-conv mode).
                    self.linear_policies = {nn.Linear: self._replace}
                else:
                    # Only switch to conv handling once Conv1D has actually been resolved.
                    self.conv_linear_layer = True
                    self.linear_policies = {conv1d_cls: self._replace}
            else:
                self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding}

    def _replace_module(self, r_module, prev_name='', prev_class_name=''):
        """Recursively replace linear/embedding children of ``r_module`` with TP layers, in place.

        Load-able children are first filled from ``self.state_dict`` (children with no matching keys are
        skipped entirely); other children get their TP attributes updated and are recursed into.

        Returns:
            ``r_module`` (mutated in place).
        """
        for name, child in r_module.named_children():
            # Dotted path used to find this child's tensors in the state dict.
            if prev_class_name == "":
                class_name = prev_name
            elif prev_name == "":
                class_name = prev_class_name
            else:
                class_name = prev_class_name + '.' + prev_name
            checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
            if Loading.is_load_module(child) and self.state_dict is not None:
                if any(checking_key in item for item in self.state_dict):
                    Loading.load(child, self.state_dict, checking_key, self.mp_group)
                else:
                    continue
            if len(child._buffers) != 0 and self.state_dict is not None:
                Loading.load_buffer(child, self.state_dict, checking_key)
            if child.__class__ in self.linear_policies:
                setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                              self.conv_linear_layer))
            elif any(isinstance(child, lp) for lp in self.linear_policies):
                # Added for falcon model support
                # Note: isinstance will account for class inheritance, child.__class__ does not
                key = next(lp for lp in self.linear_policies if isinstance(child, lp))
                setattr(r_module, name, self.linear_policies[key](child, prev_name + '.' + name,
                                                                  self.conv_linear_layer))
            else:
                self.update_mp_params(child)
                self._replace_module(child, name, class_name)
        return r_module

    def get_model_num_kv_heads(self, config):
        """Return the first non-None KV/attention head count found on ``config``, or ``None``."""
        # multi_query_group_num is for chatglm2 & chatglm3
        kv_head_names = [
            'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads',
            'attention_heads'
        ]
        for name in kv_head_names:
            if hasattr(config, name):
                num_kv_heads = getattr(config, name)
                if num_kv_heads is not None:
                    return num_kv_heads
        return None

    def _replace_last_linear_module(self, r_module):
        """Replace ``r_module.lm_head`` (or ``embed_out``) via the matching linear policy, if one applies."""
        if hasattr(r_module, "lm_head"):
            name = "lm_head"
            child = r_module.lm_head
        elif hasattr(r_module, "embed_out"):
            name = "embed_out"
            child = r_module.embed_out
        else:
            return r_module
        if child.__class__ in self.linear_policies:
            setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
        return r_module
