# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import os
import copy
import collections
import json
from abc import ABC, abstractmethod

from deepspeed.utils import logger
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine

from .weight_quantizer import WeightQuantization

AUTO_MODULE_KEY = 'auto'


class SDLoaderFactory:
    """Factory helpers that turn a checkpoint description into a state-dict loader."""

    @staticmethod
    def get_sd_loader_json(json_file, checkpoint_engine):
        """Build a loader from a JSON checkpoint description.

        Args:
            json_file: either a path (``str``) to a JSON file or an already-parsed ``dict``.
                The description must contain the keys ``'type'``, ``'checkpoints'`` and ``'version'``.
            checkpoint_engine: forwarded unchanged to ``get_sd_loader``.

        Returns:
            The description dict itself when its type (case-insensitive) is ``'bloom'`` or
            ``'ds_model'``; otherwise whatever ``get_sd_loader`` returns for that type.

        Raises:
            AssertionError: if ``json_file`` is neither ``str`` nor ``dict``, or the type is unsupported.
            KeyError: if a required key is missing from the description.
        """
        if isinstance(json_file, str):
            with open(json_file) as f:
                data = json.load(f)
        else:
            assert isinstance(json_file, dict), 'json_file must be a path (str) or a dict'
            data = json_file

        # All three keys are read up front, so a description missing any of them fails
        # with KeyError regardless of its type.
        sd_type = data['type']
        ckpt_list = data['checkpoints']
        version = data['version']

        # These types are not wrapped in a loader here; the caller receives the raw description.
        if sd_type.lower() in ['bloom', 'ds_model']:
            return data
        return SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine, sd_type, version)

    @staticmethod
    def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None):
        """Return the loader matching ``sd_type``.

        Only ``'Megatron'`` (exact, case-sensitive match) is supported.

        Raises:
            AssertionError: for any other ``sd_type``.
        """
        if sd_type == 'Megatron':
            return MegatronSDLoader(ckpt_list, version, checkpoint_engine)
        # Raised explicitly (instead of `assert False`) so the error is not stripped under
        # `python -O`, which would make this function silently return None.
        raise AssertionError('{} checkpoint type is not supported'.format(sd_type))


class SDLoaderBase(ABC):
    """Common logic for loading a list of model-parallel checkpoint files.

    ``load`` compares the number of checkpoint files with the requested model-parallel world size
    and either loads one file directly, or delegates to the subclass-provided ``merge_state_dict`` /
    ``split_state_dict`` when the two differ.
    """

    def __init__(self, ckpt_list, version, checkpoint_engine):
        """Store the configuration and validate the checkpoint list.

        Args:
            ckpt_list: non-empty list of checkpoint file paths (presumably ordered by
                model-parallel rank, since entries are selected by rank-derived index).
            version: checkpoint version, stored as-is on ``self.version``.
            checkpoint_engine: object exposing ``load(path, map_location=...)``. A
                ``TorchCheckpointEngine`` is created when ``None`` is passed.

        Note:
            Construction calls ``check_ckpt_list``, which loads the first checkpoint file.
        """
        self.module_key = None
        self.ckpt_list = ckpt_list
        self.version = version
        self.checkpoint_engine = TorchCheckpointEngine() if checkpoint_engine is None else checkpoint_engine
        self.check_ckpt_list()

    def load(self,
             mp_world_size,
             mp_rank,
             module_key=AUTO_MODULE_KEY,
             is_pipe_parallel=False,
             quantize=False,
             quantize_bits=8,
             quantize_groups=64,
             mlp_extra_grouping=True):
        """Load the state dict for ``mp_rank``, merging or splitting checkpoint files if required.

        There are multiple cases to handle here for both training and inference:

        1. PipeModule loading mp_rank_*.pt files, is_pipe_parallel=True, module_key is not None
            a. if no mp_size/pp_size resizing occurs, for both training & inference, load
               the mp_rank related checkpoint directly.
            b. if mp_size/pp_size resizing occurs, only Megatron model inference is supported.
               In this case each mp_rank_*.pt has the same content, so the first checkpoint
               file (idx=0) is loaded to avoid idx exceeding the file list boundary.

        2. PipeModule loading layer_*.pt files, is_pipe_parallel=True, module_key is None
            a. if no mp_size resizing occurs, for both training & inference, load
               the mp_rank related checkpoint directly.
            b. if mp_size resizing occurs, only Megatron model inference is supported;
               checkpoint file(s) are merged/split according to mp_rank, mp_world_size and
               the checkpoint file list.

        3. Non-PipeModule loading mp_rank_*.pt files, is_pipe_parallel=False
            Same as case (2).

        Args:
            mp_world_size: target model-parallel world size.
            mp_rank: model-parallel rank of the caller.
            module_key: key under which the module state dict lives inside the checkpoint;
                ``AUTO_MODULE_KEY`` picks ``'module'`` or ``'model'`` automatically, ``None`` means
                the checkpoint itself is the module state dict. Stored on ``self.module_key``.
            is_pipe_parallel: see the cases above.
            quantize: whether to quantize weights with ``WeightQuantization``.
            quantize_bits, quantize_groups, mlp_extra_grouping: forwarded to the quantizer.

        Returns:
            ``(load_path, sd, (all_scales, merge_count))``. ``load_path`` is the checkpoint-list
            entry selected for this rank (also in the merge/split cases), ``all_scales`` is ``None``
            when ``quantize`` is false on the direct-load path, and ``merge_count`` is 1 unless
            ``merge_state_dict`` reports otherwise.
        """
        self.module_key = module_key
        num_ckpt = len(self.ckpt_list)
        idx = mp_rank * num_ckpt // mp_world_size

        # Case 1b: every file has the same content, so pretend no resizing happens and read file 0.
        if is_pipe_parallel and module_key is not None and mp_world_size != num_ckpt:
            mp_world_size = num_ckpt
            idx = 0

        load_path = self.ckpt_list[idx]

        merge_count = 1
        if num_ckpt == mp_world_size:
            assert os.path.exists(load_path), f'checkpoint file {load_path} does not exist'
            #logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}')
            sd = self.checkpoint_engine.load(load_path, map_location=lambda storage, loc: storage)

            if quantize:
                quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)
                sd_module, all_scales = quantizer.sd_quantize_megatron(self.get_module(sd), quantize_bits,
                                                                       quantize_groups)
                # Keep the returned value: when module_key is None, set_module cannot update `sd`
                # in place and hands back the new module instead.
                sd = self.set_module(sd, sd_module)
            else:
                all_scales = None
        elif num_ckpt > mp_world_size:
            sd, all_scales, merge_count = self.merge_state_dict(mp_world_size, mp_rank, quantize, quantize_bits,
                                                                quantize_groups, mlp_extra_grouping)
        else:
            sd, all_scales = self.split_state_dict(mp_world_size, mp_rank, quantize, quantize_bits, quantize_groups,
                                                   mlp_extra_grouping)
        return load_path, sd, (all_scales, merge_count)

    def get_merge_state_dicts(self, mp_world_size, mp_rank):
        """Load the consecutive checkpoint files that must be merged for ``mp_rank``.

        The checkpoint count must be a multiple of ``mp_world_size``; rank ``r`` receives files
        ``[r * k, (r + 1) * k)`` where ``k = num_ckpt // mp_world_size``.

        Returns:
            List of loaded state dicts, in checkpoint-list order.
        """
        num_ckpt = len(self.ckpt_list)
        assert num_ckpt % mp_world_size == 0, 'Invalid checkpoints and world size for sd merge'

        num_to_merge = num_ckpt // mp_world_size
        first = num_to_merge * mp_rank
        ckpt_list = self.ckpt_list[first:first + num_to_merge]

        logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}")
        sd_list = [self.checkpoint_engine.load(ckpt, map_location=lambda storage, loc: storage) for ckpt in ckpt_list]
        return sd_list

    def get_split_state_dict(self, mp_world_size, mp_rank):
        """Load the single checkpoint file that ``mp_rank`` takes its slice from.

        ``mp_world_size`` must be a multiple of the checkpoint count; each file is shared by
        ``num_to_split = mp_world_size // num_ckpt`` consecutive ranks.

        Returns:
            ``(sd, num_to_split, ckpt_offset)`` where ``ckpt_offset`` is this rank's position
            among the ranks sharing the file.
        """
        num_ckpt = len(self.ckpt_list)
        assert mp_world_size % num_ckpt == 0, 'Invalid checkpoints and world size for sd split'

        num_to_split = mp_world_size // num_ckpt
        ckpt_index = mp_rank // num_to_split
        ckpt_offset = mp_rank % num_to_split

        logger.info(f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}")

        sd = self.checkpoint_engine.load(self.ckpt_list[ckpt_index], map_location=lambda storage, loc: storage)

        return sd, num_to_split, ckpt_offset

    def _choose_module_key(self, sd):
        """Return ``'module'`` or ``'model'``, whichever single one of them is a key of ``sd``.

        Raises:
            AssertionError: if ``sd`` contains both keys or neither.
        """
        assert not ('module' in sd
                    and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
        assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' or 'module' keys, not sure how to proceed"
        if 'module' in sd:
            return 'module'
        elif 'model' in sd:
            return 'model'

    def get_module(self, sd):
        """Return the module state dict stored in checkpoint ``sd`` according to ``self.module_key``.

        ``None`` means ``sd`` itself is the module state dict; ``AUTO_MODULE_KEY`` resolves the key
        via ``_choose_module_key``; any other value is used as the key directly.
        """
        if self.module_key is None:
            return sd
        elif self.module_key == AUTO_MODULE_KEY:
            return sd[self._choose_module_key(sd)]
        else:
            return sd[self.module_key]

    def set_module(self, sd, module):
        """Install ``module`` as the module state dict of checkpoint ``sd`` and return the result.

        When ``self.module_key`` is ``None`` the input ``sd`` is left untouched and ``module`` is
        returned, so callers must always use the return value.
        """
        if self.module_key is None:
            sd = module
        elif self.module_key == AUTO_MODULE_KEY:
            sd[self._choose_module_key(sd)] = module
        else:
            sd[self.module_key] = module
        return sd

    def check_ckpt_list(self):
        """Validate the checkpoint list by loading its first file.

        Raises:
            AssertionError: if the list is empty, or if the first checkpoint stores an
                ``'mp_world_size'`` entry that differs from the number of checkpoint files.
        """
        #logger.info(f'checkpoint file list: {self.ckpt_list}')
        assert len(self.ckpt_list) > 0, 'checkpoint file list is empty'

        sd = self.checkpoint_engine.load(self.ckpt_list[0], map_location=lambda storage, loc: storage)

        # check checkpoint count is same with saved mp_world_size
        if 'mp_world_size' in sd.keys():
            assert len(self.ckpt_list) == sd[
                'mp_world_size'], f"checkpoint count {len(self.ckpt_list)} is different from saved mp_world_size {sd['mp_world_size']}"

    @abstractmethod
    def merge_state_dict(self, mp_world_size, mp_rank, quantize, quantize_bits, groups, mlp_extra_grouping):
        """Merge several checkpoint files into the state dict for ``mp_rank``.

        ``load`` unpacks the result as ``(sd, all_scales, merge_count)``.
        """
        pass

    @abstractmethod
    def split_state_dict(self, mp_world_size, mp_rank, quantize, quantize_bits, groups, mlp_extra_grouping):
        """Extract the slice of one checkpoint file that belongs to ``mp_rank``.

        ``load`` unpacks the result as ``(sd, all_scales)``.
        """
        pass

    @abstractmethod
    def sanity_check(self, ckpt_file_name):
        """Verify that ``ckpt_file_name`` has the layout the loader expects."""
        pass


class MegatronSDLoader(SDLoaderBase):
    """State-dict loader that can merge or split Megatron model-parallel checkpoints.

    How parameters are treated, by key (example keys/shapes from one checkpoint):

    ## Q/K/V data need special processing
    key: transformer.layers.0.attention.query_key_value.weight, shape: torch.Size([3192, 4256])
    key: transformer.layers.0.attention.query_key_value.bias, shape: torch.Size([3192])

    ## merge or split on axis=0
    key: word_embeddings.weight, shape: torch.Size([12672, 4256])
    key: transformer.layers.0.mlp.dense_h_to_4h.bias, shape: torch.Size([4256])
    key: transformer.layers.0.mlp.dense_h_to_4h.weight, shape: torch.Size([4256, 4256])

    ## merge or split on axis=1
    key: transformer.layers.0.attention.dense.weight, shape: torch.Size([4256, 1064])
    key: transformer.layers.0.mlp.dense_4h_to_h.weight, shape: torch.Size([4256, 4256])

    ## no change required
    key: transformer.layers.0.mlp.dense_4h_to_h.bias, shape: torch.Size([4256])
    key: transformer.final_layernorm.weight, shape: torch.Size([4256])
    key: transformer.final_layernorm.bias, shape: torch.Size([4256])
    key: transformer.layers.0.attention.dense.bias, shape: torch.Size([4256])
    key: transformer.layers.0.post_attention_layernorm.weight, shape: torch.Size([4256])
    key: transformer.layers.0.post_attention_layernorm.bias, shape: torch.Size([4256])
    key: transformer.layers.0.input_layernorm.weight, shape: torch.Size([4256])
    key: transformer.layers.0.input_layernorm.bias, shape: torch.Size([4256])
    key: position_embeddings.weight, shape: torch.Size([1024, 4256])
    """

    def __init__(self, ckpt_list, version, checkpoint_engine):
        super().__init__(ckpt_list, version, checkpoint_engine)

    def merge_query_key_value(self, param_list, ckpt_ver):
        """Merge the Q/K/V parameters of several model-parallel partitions into one tensor.

        Up to now we found 3 Q/K/V parameter formats in different Megatron checkpoint versions:

        1. version 0, there is no version information saved in checkpoint.
            format: [(3 * np * hn), h]
        2. version 1.0
            format: [(np * hn * 3), h]
        3. version 2.0
            format: [(np * 3 * hn), h]

        h: hidden size
        n: number of attention heads
        p: number of model parallel partitions
        np: n/p
        hn: h/n

        Args:
            param_list: per-partition tensors, in partition order.
            ckpt_ver: checkpoint version (0, 1.0 or 2.0).

        Returns:
            The merged tensor, concatenated along dim 0.

        Raises:
            AssertionError: if ``ckpt_ver`` is unsupported, or (version 0) the first dimension
                of the first tensor is not divisible by 3.
        """
        if ckpt_ver == 0:
            # [(3 * np * hn), h]: each partition stores its Q, K and V blocks back to back, so
            # gather all Q blocks, then all K blocks, then all V blocks.
            assert param_list[0].shape[0] % 3 == 0
            size_qkv = param_list[0].shape[0] // 3
            split_tensors = [torch.split(param, size_qkv, dim=0) for param in param_list]

            tensors = [torch.cat([parts[i] for parts in split_tensors], dim=0) for i in range(3)]
            return torch.cat(tensors, dim=0)

        if ckpt_ver == 1.0 or ckpt_ver == 2.0:
            # [(np * hn * 3), h] or [(np * 3 * hn), h]: partitions are outermost, plain concat.
            return torch.cat(param_list, dim=0)

        # Raised explicitly (not `assert False`) so the error survives `python -O`.
        raise AssertionError(f'checkpoint version: {ckpt_ver} is not supported')

    def split_query_key_value(self, param, num_to_split, offset, ckpt_ver):
        """Extract one partition's Q/K/V parameters from a merged tensor.

        Up to now we found 3 Q/K/V parameter formats in different Megatron checkpoint versions:

        1. version 0, there is no version information saved in checkpoint.
            format: [(3 * np * hn), h]
        2. version 1.0
            format: [(np * hn * 3), h]
        3. version 2.0
            format: [(np * 3 * hn), h]

        h: hidden size
        n: number of attention heads
        p: number of model parallel partitions
        np: n/p
        hn: h/n

        Args:
            param: tensor to split along dim 0.
            num_to_split: number of equal partitions to cut ``param`` into.
            offset: index of the partition to return.
            ckpt_ver: checkpoint version (0, 1.0 or 2.0).

        Returns:
            The slice of ``param`` belonging to partition ``offset``.

        Raises:
            AssertionError: if ``ckpt_ver`` is unsupported or the sizes are not evenly divisible.
        """
        if ckpt_ver == 0:
            # [(3 * np * hn), h]: cut into Q, K, V first, then take this partition's share of each.
            assert param.shape[0] % 3 == 0
            size_qkv = param.shape[0] // 3
            split_tensors = torch.split(param, size_qkv, dim=0)

            assert split_tensors[0].shape[0] % num_to_split == 0
            split_size = split_tensors[0].shape[0] // num_to_split

            tensors = [torch.split(split_tensors[i], split_size, dim=0)[offset] for i in range(3)]
            return torch.cat(tensors, dim=0)

        if ckpt_ver == 1.0 or ckpt_ver == 2.0:
            # [(np * hn * 3), h] or [(np * 3 * hn), h]: partitions are outermost, plain slice.
            assert param.shape[0] % num_to_split == 0
            size_qkv = param.shape[0] // num_to_split
            return torch.split(param, size_qkv, dim=0)[offset]

        # Raised explicitly (not `assert False`) so the error survives `python -O`.
        raise AssertionError(f'checkpoint version: {ckpt_ver} is not supported')

    def merge_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize=False,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
        """Merge the checkpoint files assigned to ``mp_rank`` into a single state dict.

        Runs ``sanity_check`` on the first checkpoint file, loads the files selected by
        ``get_merge_state_dicts`` and combines every module parameter according to its key
        (concatenate on dim 1, Q/K/V handling, concatenate on dim 0, or take the first copy).
        Non-module entries are deep-copied from the first loaded checkpoint.

        Returns:
            ``(ds_sd, all_scales, merge_count)`` where ``all_scales`` is ``None`` unless
            ``quantize`` is set and ``merge_count`` is the number of checkpoints merged.
        """
        self.sanity_check(self.ckpt_list[0])

        sd_list = self.get_merge_state_dicts(mp_world_size, mp_rank)
        ds_sd = copy.deepcopy(sd_list[0])
        new_client_sd = collections.OrderedDict()

        client_sd_list = [self.get_module(sd) for sd in sd_list]
        keys = client_sd_list[0].keys()

        ckpt_ver = self.get_checkpoint_version(ds_sd)
        logger.info(f"checkpoint version: {ckpt_ver}")

        quantizer = None
        all_scales = None
        if quantize:
            quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)

        for key in keys:
            value_list = [sd[key] for sd in client_sd_list]

            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
                if quantize:
                    value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key, merge_dim=1)
                new_client_sd[key] = torch.cat(value_list, dim=1)
            elif "attention.query_key_value" in key:
                if quantize:
                    # With quantization the Q/K/V tensors (weight and bias alike) are concatenated
                    # directly on dim 0, without the version-specific reordering; only the weight
                    # is quantized.
                    if "attention.query_key_value.weight" in key:
                        value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key)
                    new_client_sd[key] = torch.cat(value_list, dim=0)
                else:
                    new_client_sd[key] = self.merge_query_key_value(value_list, ckpt_ver)
            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key:
                # NOTE(review): split_state_dict also treats "final_linear.weight" as a dim-0
                # parameter, while here it falls through to "take the first copy" — confirm
                # whether that asymmetry is intended.
                if quantize and "mlp.dense_h_to_4h.weight" in key:
                    value_list = quantizer.Quantize(value_list, quantize_bits, groups, key=key)
                new_client_sd[key] = torch.cat(value_list, dim=0)
            else:
                # Parameter is not partitioned by this loader: keep the first checkpoint's copy.
                new_client_sd[key] = value_list[0]

        if quantize:
            all_scales = quantizer.merge_scales()
        ds_sd = self.set_module(ds_sd, new_client_sd)

        return ds_sd, all_scales, len(client_sd_list)

    def split_state_dict(self,
                         mp_world_size,
                         mp_rank,
                         quantize=False,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
        """Build the state dict for ``mp_rank`` by slicing one checkpoint file.

        Loads the file selected by ``get_split_state_dict`` and, per module parameter key, takes
        this rank's slice on dim 1, via the Q/K/V splitter, or on dim 0; any other parameter is
        kept whole. Non-module entries are deep-copied from the loaded checkpoint.

        Returns:
            ``(ds_sd, all_scales)`` where ``all_scales`` is ``None`` unless ``quantize`` is set.
        """
        # Unlike merge_state_dict, the sanity check is disabled on this path.
        #self.sanity_check(self.ckpt_list[0])

        sd, num_to_split, ckpt_offset = self.get_split_state_dict(mp_world_size, mp_rank)
        ds_sd = copy.deepcopy(sd)
        new_client_sd = collections.OrderedDict()

        client_sd = self.get_module(sd)

        ckpt_ver = self.get_checkpoint_version(ds_sd)
        logger.info(f"checkpoint version: {ckpt_ver}")

        quantizer = None
        all_scales = None
        if quantize:
            quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, mp_size=mp_world_size)

        for key in client_sd.keys():
            value = client_sd[key]

            if "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
                assert value.shape[1] % num_to_split == 0
                split_size = value.shape[1] // num_to_split
                if quantize:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = torch.split(value, split_size, dim=1)[ckpt_offset]
            elif "attention.query_key_value" in key:
                if quantize and "attention.query_key_value.weight" in key:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = self.split_query_key_value(value, num_to_split, ckpt_offset, ckpt_ver)
            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key or "final_linear.weight" in key:
                assert value.shape[0] % num_to_split == 0
                split_size = value.shape[0] // num_to_split
                if quantize and "mlp.dense_h_to_4h.weight" in key:
                    q_vals = quantizer.Quantize([value], quantize_bits, groups, key)
                    value = q_vals[0]
                new_client_sd[key] = torch.split(value, split_size, dim=0)[ckpt_offset]
            else:
                # Parameter is not partitioned by this loader: keep it whole.
                new_client_sd[key] = value

        if quantize:
            all_scales = quantizer.merge_scales_split(num_to_split)

        ds_sd = self.set_module(ds_sd, new_client_sd)

        return ds_sd, all_scales

    def sanity_check(self, ckpt_file_name):
        """Assert that the checkpoint contains the parameter families this loader merges.

        Loads ``ckpt_file_name`` and checks that, for each expected partial key, at least one
        key of the module state dict contains it as a substring.

        Raises:
            AssertionError: naming the first partial key that is not found.
        """
        keys_to_check = [
            "attention.dense.weight", "mlp.dense_4h_to_h.weight", "attention.query_key_value",
            "mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.bias"
        ]

        sd = self.checkpoint_engine.load(ckpt_file_name, map_location=lambda storage, loc: storage)

        # partial_key is a sub-string of one key in the sd
        def check_key_exist(partial_key, sd):
            return any(partial_key in k for k in sd.keys())

        for key in keys_to_check:
            assert check_key_exist(key,
                                   self.get_module(sd)), f'key: {key} is not found in the checkpoint {ckpt_file_name}'

    def get_checkpoint_version(self, state_dict):
        """Return the checkpoint version to use for Q/K/V handling.

        The version given at construction wins; otherwise the ``'checkpoint_version'`` entry of
        ``state_dict`` is used, defaulting to 0 when that entry is absent.
        """
        # Use 0 if version info doesn't exist
        return self.version if self.version is not None else state_dict.get('checkpoint_version', 0)
