# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import ABC, abstractmethod
from packaging import version as pkg_version
import torch


class MetaTensorContainer(ABC):
    """
    Container feature that bypasses the regular tensor handling of the
    parent container while the container is flagged as holding meta tensors.

    The flag (`is_meta`) is read from `self.qkvw.is_meta` in
    `initialize_tensors()`. While it is set, `apply_tensor_parallelism()`,
    `copy_data_to_new_module()` and `transpose()` do not delegate to the next
    class in the MRO; otherwise they simply forward to it.

    NOTE: If you are using this feature with a container that
    also inherits from `HybridEngineContainer`, ensure that `MetaTensorContainer`
    is inherited before `HybridEngineContainer` in the class definition.
    """

    def __init__(self, **kwargs):
        """
        Check the installed torch version and initialize the feature flags.

        All keyword arguments are forwarded unchanged to the next class in the MRO.

        Raises:
            NotImplementedError: if the installed torch is older than 1.10.
        """
        if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__):
            raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")
        super().__init__(**kwargs)
        # Updated from the actual tensors in initialize_tensors().
        self.is_meta = False
        # Not read inside this class; presumably consumed by the checkpoint
        # loading code elsewhere -- TODO confirm.
        self.ckpt_load_enabled = True

    def initialize_tensors(self, enable_training=False):
        """
        Run the parent tensor initialization, then record whether the
        container holds meta tensors (taken from `self.qkvw.is_meta`).
        """
        super().initialize_tensors(enable_training=enable_training)
        self.is_meta = self.qkvw.is_meta

    def apply_tensor_parallelism(self, mp_replace, **kwargs):
        """
        Delegate to the parent implementation unless the container is meta.

        In the meta case nothing is sharded here; only the attention biases
        that the container does not have (`qkvb` / `dense_b` is None) are
        cleared on the target module.
        """
        if self.is_meta:
            if self.qkvb is None:
                self.module.attention.attn_qkvb = None
            if self.dense_b is None:
                self.module.attention.attn_ob = None
        else:
            super().apply_tensor_parallelism(mp_replace, **kwargs)

    def copy_data_to_new_module(self):
        """
        Delegate to the parent implementation unless the container is meta.

        In the meta case the only action is to mirror `attn_nw` / `attn_nb`
        onto `self.module.mlp` when `attn_nw` is None.
        """
        if self.is_meta:
            # NOTE(review): this branch only runs when attn_nw is None, so
            # module.mlp.attn_nw is always set to None here -- confirm that
            # this is the intended condition.
            if self.attn_nw is None:
                self.module.mlp.attn_nw = self.attn_nw
                self.module.mlp.attn_nb = self.attn_nb
        else:
            super().copy_data_to_new_module()

    def transpose(self):
        """Delegate to the parent `transpose()` only when the container is not meta."""
        if not self.is_meta:
            super().transpose()

    @abstractmethod
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        """
        Load all the transformer parameters from the checkpoint file (sd).

        Concrete model containers must override this method. The arguments
        are not used by this base implementation, which always raises.

        Besides the parameter names, an implementation needs two more
        settings to read the data correctly from the checkpoint and to split
        the qkv heads in the right order. They are NOT arguments of this
        method (presumably they are passed to the copy helpers an
        implementation calls -- verify against a concrete container):
            1. `use_load_prefix` (Default: False): this specifies
                whether we need to use the name of the first abstraction
                layer of the model when searching for the parameter's name
                in a checkpoint file. For more information on how this
                is used please see
                https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/load_checkpoint.py
            2. `split_qkv` (Default: True): we use this flag when splitting
                the qkv parameter into heads. If it is False, it means the heads
                of q, k, and v are stored together and need to be split in the
                DeepSpeed-Inference API.

        Raises:
            NotImplementedError: always, when the base implementation is called.
        """
        # Adjacent string literals are used instead of a backslash
        # continuation inside the literal, which would embed the source
        # indentation in the message.
        raise NotImplementedError("A load_params() function must be defined in the model container "
                                  "when inheriting the MetaTensorContainer feature")
