# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
import torch
import types
from typing import List, Tuple, Union
from dataclasses import dataclass
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE)


@dataclass
class SubparamShape:
    """Layout description for a parameter that packs several sub-parameters.

    The checkpoint loader in this module reads ``shape`` and ``partition_dim``:
    the full tensor is viewed with ``shape`` (tuple entries collapsed to their
    sum), and the entry at ``shape[partition_dim]`` gives the size of each
    sub-parameter along the partitioned dimension.
    """
    # Patterns associated with this layout. Not read by the loader visible in
    # this module -- presumably regexes matched against parameter names by the
    # checkpoint producer; verify against the caller.
    patterns: List[str]
    # One entry per tensor dimension. An int is a plain dimension size; a tuple
    # lists the sizes of the sub-parameters packed along that dimension.
    shape: Tuple[Union[Tuple[int, ...], int], ...]
    # Index into ``shape`` of the dimension that is split across TP ranks.
    partition_dim: int


def _list_hp_keys(folder):
    """Return the checkpoint keys in ``folder``: one per file named ``<key>.pt``."""
    # We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt".
    # The whole file name must match so that e.g. "fp32.pt.bak" or "notes.ptx" are ignored.
    pattern = re.compile(r'(.+)\.pt')
    hp_keys = []
    # Sorted so keys are visited in the same order regardless of directory listing order.
    for file in sorted(os.listdir(folder)):
        match = pattern.fullmatch(file)
        if match:
            hp_keys.append(match.group(1))
    return hp_keys


def _extract_tp_slice(full_hp_param, ckpt_dict, tp_rank, tp_world_size):
    """Return the (unflattened) slice of ``full_hp_param`` that belongs to ``tp_rank``."""
    sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None)
    # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1
    # we have to do exactly the same in reverse.
    # special case is when a single parameter is effectively a container for multiple
    # sub parameters (more details at PARAM_N_SUB_PARAMS definition)
    chunk_dim = ckpt_dict.get(CAT_DIM, 0)
    n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)

    if sub_param_shape:
        partition_dim = sub_param_shape.partition_dim
        sub_dim_sizes = sub_param_shape.shape[partition_dim]
        if not isinstance(sub_dim_sizes, tuple):
            sub_dim_sizes = (sub_dim_sizes, )

        # Collapse per-sub-parameter sizes into one size per dimension to get the full view shape.
        partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape]
        full_hp_param = full_hp_param.view(partition_shape)

        # Take this rank's chunk of every sub-parameter, then glue the chunks back together.
        offset = 0
        merged_chunks = []
        for sub_dim_size in sub_dim_sizes:
            sub_param = full_hp_param.narrow(partition_dim, offset, sub_dim_size)
            merged_chunks.append(sub_param.chunk(tp_world_size, dim=partition_dim)[tp_rank])
            offset += sub_dim_size
        return torch.cat(merged_chunks, dim=partition_dim)

    if n_sub_params > 1:
        sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
        sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
        return torch.cat(sub_params_tp_slice, dim=chunk_dim)

    # this performs the opposite of cat when merging TP slices
    return full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
    """Load this parameter's high-precision state from per-key ``.pt`` files in ``folder``.

    Every ``<key>.pt`` file in ``folder`` is loaded. The ``step`` file is returned as-is.
    For every other key the stored full tensor is sliced for ``tp_rank``, flattened and
    narrowed to ``self._hp_mapping.lp_fragment_address``. The ``FP32_WEIGHT_KEY`` fragment
    is copied into ``self._hp_mapping.get_hp_fragment()``; all other fragments are stored
    (cloned and detached) in ``self._hp_mapping.optim_fragment[key]``, which is reset to an
    empty dict at the start of the call.

    Args:
        self: the parameter this function is bound to; must expose ``_hp_mapping``,
            ``shape`` and ``numel()``.
        folder: directory holding the ``<key>.pt`` files for this parameter.
        tp_rank: tensor-parallel rank whose slice should be extracted.
        tp_world_size: number of tensor-parallel ranks the full tensor is split across.

    Returns:
        The object loaded from ``step.pt`` if that file exists, otherwise ``None``.

    Raises:
        AssertionError: if tensor sizes are inconsistent with the target slice/fragment.
    """
    hp_mapping = self._hp_mapping
    hp_mapping.optim_fragment = {}

    step = None
    for key in _list_hp_keys(folder):
        ckpt_file = os.path.join(folder, f"{key}.pt")
        # NOTE(security): weights_only=False unpickles arbitrary objects; only load
        # checkpoints that come from a trusted source.
        ckpt_dict = torch.load(ckpt_file, weights_only=False)

        if key == "step":
            step = ckpt_dict
            continue

        full_hp_param = ckpt_dict[PARAM]

        # Slices that were averaged when merging are stored with the shape of a single
        # slice, so the inverse operation is an exact copy. No extra data is required from
        # the client: if the shapes already match we treat the tensor as unpartitioned.
        # The decision is made per key so it cannot leak into the keys processed afterwards.
        if full_hp_param.shape == self.shape:
            key_tp_rank, key_tp_world_size = 0, 1
        else:
            key_tp_rank, key_tp_world_size = tp_rank, tp_world_size

        # special case for word_embeddings weights which get padded differently depending on TP degree.
        # the converter to universal currently strips the original padding completely so the saved
        # weight is padding-free and we just need to add new padding depending on the target TP
        # degree
        if ckpt_dict.get(VOCAB_TENSOR, False):
            # In the absence of data passed from the user wrt new padded vocab specific to tp degree
            # we can again derive that data by reverse engineering the target shapes like so:
            padded_target_vocab_size = self.shape[0] * key_tp_world_size
            assert padded_target_vocab_size >= full_hp_param.shape[0], \
                f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
            if padded_target_vocab_size > full_hp_param.shape[0]:
                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                # Pad spec is (last-dim left, last-dim right, dim-0 left, dim-0 right):
                # append zero rows at the end of dim 0 only.
                full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)

        full_param_numel = full_hp_param.numel()
        tp_slice_numel = self.numel()
        assert full_param_numel == key_tp_world_size * tp_slice_numel, \
            f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {key_tp_world_size}'

        tp_hp_slice = _extract_tp_slice(full_hp_param, ckpt_dict, key_tp_rank, key_tp_world_size).flatten()

        # Keep only the part of the flattened slice that this fragment address covers.
        lp_frag_address = hp_mapping.lp_fragment_address
        tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)

        if key == FP32_WEIGHT_KEY:
            dst_tensor = hp_mapping.get_hp_fragment()
            assert dst_tensor.numel() == lp_frag_address.numel, \
                f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
            dst_tensor.data.copy_(tp_hp_fragment.data)
        else:
            assert tp_hp_fragment.numel() == lp_frag_address.numel, \
                f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'

            hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()

    return step


def enable_universal_checkpoint(param_list):
    """Attach ``load_hp_checkpoint_state`` as a bound method to each parameter.

    After this call every object in ``param_list`` exposes
    ``param.load_hp_checkpoint_state(folder, tp_rank, tp_world_size)``.
    """
    bind_method = types.MethodType
    for target in param_list:
        # Bind per instance so the loader receives the parameter itself as ``self``.
        target.load_hp_checkpoint_state = bind_method(load_hp_checkpoint_state, target)
