# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
from torch import nn
from torch.nn import init
import deepspeed.comm as dist
from .utils import TopKBinarizer, SymQuantizer, AsymQuantizer, TernaryQuantizer, BinaryQuantizer
from deepspeed.utils import logger

# Module-wide handle to the model-parallel unit. It is read by the communication helpers
# in this file (_reduce, _split, _gather) and stays None until a parallel layer's
# constructor (e.g. ColumnParallelLinear_Compress.__init__) assigns the `mpu` it was given.
g_mpu = None


class QuantAct(nn.Module):
    """
    Activation quantizer that keeps one running (min, max) range for all inputs.

    The range is refreshed with a momentum update only while the module is in training mode.
    Outside of training the stored range is reused unchanged for every token/image, which
    generally costs some accuracy but gives better latency.
    Parameters:
    ----------
    act_range_momentum : float, default 0.95
        Momentum for updating the activation quantization range.
    quant_mode : str, default 'symmetric'
        'symmetric' selects SymQuantizer; any other value selects AsymQuantizer.
    """

    def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
        super().__init__()

        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        self.act_function = SymQuantizer.apply if quant_mode == 'symmetric' else AsymQuantizer.apply

        # Slot 0 holds the running minimum, slot 1 the running maximum.
        self.register_buffer('x_min_max', torch.zeros(2))

    def forward(self, x, num_bits, *args):
        """
        x: the activation that we need to quantize
        num_bits: the number of bits we need to quantize the activation to
        *args: ignored; accepted only so the call signature lines up with the other quantization functions
        """
        if self.training:
            momentum = self.act_range_momentum
            observed = (x.data.min(), x.data.max())

            # Equal bounds are treated as "range not initialised yet": seed it from this batch.
            if self.x_min_max[0] == self.x_min_max[1]:
                for slot, value in enumerate(observed):
                    self.x_min_max[slot] = value

            # Exponential moving average; with momentum == 0 the range simply follows the batch.
            for slot, value in enumerate(observed):
                self.x_min_max[slot] = self.x_min_max[slot] * momentum + value * (1 - momentum)

        lower, upper = self.x_min_max[0], self.x_min_max[1]
        return self.act_function(x, num_bits, lower, upper)


class Embedding_Compress(nn.Embedding):
    """
    Embedding layer with optional weight quantization.

    Quantization metadata (start_bits, target_bits, q_period) is stored as attributes on
    ``self.weight``. The flag ``weight_quantization_enabled`` starts as False and is not
    changed anywhere in this class; it is expected to be toggled from outside.
    """

    def __init__(self, *kargs, **kwargs):
        # Forward keyword arguments too, so nn.Embedding options such as padding_idx=...
        # can be passed by name exactly as with the parent class.
        super(Embedding_Compress, self).__init__(*kargs, **kwargs)
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False

    def extra_repr(self):
        """Return the layer summary, including the target bit-width (None when not configured)."""
        return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
            self.num_embeddings, self.embedding_dim, self.weight.target_bits)

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        """
        Record the quantization schedule on the weight and, for in-forward mode, pick a quantizer.

        start_bits / target_bits / quantization_period: stored on ``self.weight``.
        weight_quantization_enabled_in_forward: when truthy, a quantizer is selected from target_bits.
        quantization_type: 'symmetric' or anything else (asymmetric); 1- and 2-bit require 'symmetric'.
        num_groups: accepted for interface parity but not used; embeddings use one group per row.
        """
        self.weight.start_bits = start_bits
        self.weight.target_bits = target_bits
        self.weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
        if self.weight_quantization_enabled_in_forward:
            logger.warning(
                "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
            )
            if self.weight.target_bits >= 3:
                if quantization_type == 'symmetric':
                    self.weight_quantizer = SymQuantizer.apply
                else:
                    self.weight_quantizer = AsymQuantizer.apply
            elif self.weight.target_bits == 2:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
                self.weight_quantizer = TernaryQuantizer.apply
            elif self.weight.target_bits == 1:
                assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
                self.weight_quantizer = BinaryQuantizer.apply
            # for embedding, we always use token-wise quantization
            self.weight_quantize_num_groups = self.weight.size(0)

    def fix_weight_quantization(self):
        """Overwrite the weight with its quantized values and turn in-forward quantization off."""
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def forward(self, input):
        """Look up embeddings, quantizing the weight on the fly only when both enable flags are set."""
        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
        else:
            weight = self.weight

        out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
                                      self.scale_grad_by_freq, self.sparse)
        return out


class LinearLayer_Compress(nn.Linear):
    """
    Linear layer with compression.
    """

    def __init__(self, *kargs, bias=True, **kwargs):
        """
        Build the underlying nn.Linear and reset every compression setting to "off".

        *kargs / **kwargs are forwarded to nn.Linear, so its arguments may be given
        positionally or by keyword (e.g. in_features=..., out_features=...).
        bias: forwarded to nn.Linear.
        """
        super(LinearLayer_Compress, self).__init__(*kargs, bias=bias, **kwargs)
        # Pruning / activation-quantization methods stay None until the matching enable_* call.
        self.sparse_pruning_method = None
        self.row_pruning_method = None
        self.head_pruning_method = None
        self.activation_quantization_method = None
        # Quantization schedule is carried on the weight parameter itself.
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        # Runtime switches consulted by forward(); none of them is turned on in this class.
        self.weight_quantization_enabled_in_forward = False
        self.weight_quantization_enabled = False
        self.sparse_pruning_enabled = False
        self.row_pruning_enabled = False
        self.head_pruning_enabled = False
        self.activation_quantization_enabled = False

    def extra_repr(self):
        """Summarise layer dimensions plus which compression features have been configured."""
        has_bias = self.bias is not None
        sparse_on = self.sparse_pruning_method is not None
        row_on = self.row_pruning_method is not None
        head_on = self.head_pruning_method is not None
        act_quant_on = self.activation_quantization_method is not None
        return (f'in_features={self.in_features}, out_features={self.out_features}, bias={has_bias}, '
                f'sparse pruning={sparse_on}, row pruning={row_on}, head pruning={head_on}, '
                f'activation quantization={act_quant_on}, weight_quantization={self.weight.target_bits}')

    def enable_sparse_pruning(self, ratio, method):
        # Here, we support two cases: L1 norm based pruning and topk based pruning
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('sparse_pruning_mask', mask)

    def enable_row_pruning(self, ratio, method):
        # Here, we support two cases: L1 norm based pruning and topk based pruning
        self.row_pruning_ratio = ratio
        self.row_pruning_method = method

        if method == 'l1':
            # compute the l1 norm of each column
            weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=1)
            mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False)
            mask = mask.view(-1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.row_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1))
            self.row_mask_scores.data = self.row_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.row_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('row_pruning_mask', mask)

    def enable_head_pruning(self, ratio, method, num_heads):
        # Here, we support only topk based pruning
        self.num_heads = num_heads
        self.head_pruning_ratio = ratio
        self.head_pruning_method = method

        if method not in ['topk']:
            raise NotImplementedError
        else:
            self.head_pruning_ratio = ratio
            self.head_pruning_scores = nn.Parameter(torch.Tensor(1,
                                                                 self.num_heads))  # we apply the pruning to O matrix
            self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))

    def fix_sparse_pruning_helper(self):
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_row_col_pruning_helper(self, mask=None, dim_reduction=False):
        # This function is used for row/col pruning
        # particularly, if we have two back-to-back layers, F1 and F2; when
        # we remove rows from F1, we also need to remove columns from F2
        # However, if we only have one layer, F1, then we only need to mask pruned
        # rows as 0 in F1
        if mask is None:
            mask = self.get_mask(pruning_type='row').bool()
            if dim_reduction:
                start_bits = self.weight.start_bits
                target_bits = self.weight.target_bits
                q_period = self.weight.q_period
                self.weight = nn.Parameter(self.weight.data[mask.view(-1), :])
                self.weight.start_bits = start_bits
                self.weight.target_bits = target_bits
                self.weight.q_period = q_period
                if self.bias is not None:
                    self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                self.out_features = self.weight.size(0)
            else:
                self.weight.data = self.weight.data * mask.view(-1, 1)
                if self.bias is not None:
                    self.bias.data = self.bias.data * mask.view(-1)

            del self.row_pruning_mask
            if self.row_pruning_method == 'topk':
                del self.row_mask_scores
            self.row_pruning_method = None
        else:
            # this is generally for column pruning
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            self.weight = nn.Parameter(self.weight.data[:, mask.view(-1)])
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            self.in_features = self.weight.size(1)
            mask = None
        self.row_pruning_enabled = False
        return mask

    def fix_head_pruning_helper(self, mask=None, num_heads=None, dim_reduction=False):
        # similar as row/col pruning, head pruning also needs to prune QKV which is associated with O matrix
        num_heads = num_heads if num_heads else self.num_heads
        if mask is None:
            if self.head_pruning_method == 'topk':
                mask = self.get_mask(pruning_type='head').bool()
                if dim_reduction:
                    shape = self.weight.size(0)
                    start_bits = self.weight.start_bits
                    target_bits = self.weight.target_bits
                    q_period = self.weight.q_period
                    self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads,
                                                                            -1)[mask.view(-1), :].reshape(-1,
                                                                                                          shape).t())
                    self.weight.start_bits = start_bits
                    self.weight.target_bits = target_bits
                    self.weight.q_period = q_period
                else:

                    shape = self.weight.size()
                    self.weight.data = (self.weight.data.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(
                        shape[1], shape[0]).t()

                if self.head_pruning_method == 'topk':
                    del self.head_pruning_scores
                self.head_pruning_method = None
            else:
                raise NotImplementedError
        else:
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            shape = self.weight.size(1)
            self.weight = nn.Parameter(self.weight.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape))
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            if self.bias is not None:
                self.bias = nn.Parameter(self.bias.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1))
        self.head_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='row'):
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            elif self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            else:
                raise NotImplementedError
        if pruning_type == 'row':
            if self.row_pruning_method == 'l1':
                return self.row_pruning_mask.to(self.weight.device)
            elif self.row_pruning_method == 'topk':
                return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'head':
            if self.head_pruning_method == 'topk':
                return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        """
        Record the quantization schedule on the weight and, for in-forward mode, pick a quantizer.

        start_bits / target_bits / quantization_period: stored on ``self.weight``.
        weight_quantization_enabled_in_forward: when falsy, nothing beyond the bookkeeping happens.
        quantization_type: 'symmetric' or anything else (asymmetric); 1- and 2-bit targets
                           assert that it is 'symmetric'.
        num_groups: stored as ``weight_quantize_num_groups`` (in-forward mode only).
        """
        weight = self.weight
        weight.start_bits = start_bits
        weight.target_bits = target_bits
        weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward

        if not self.weight_quantization_enabled_in_forward:
            return

        logger.warning(
            "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
        )
        bits = self.weight.target_bits
        if bits >= 3:
            self.weight_quantizer = SymQuantizer.apply if quantization_type == 'symmetric' else AsymQuantizer.apply
        elif bits == 2:
            assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
            self.weight_quantizer = TernaryQuantizer.apply
        elif bits == 1:
            assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
            self.weight_quantizer = BinaryQuantizer.apply
        self.weight_quantize_num_groups = num_groups

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        """
        Configure activation quantization.

        bits: must be 4 or 8 (asserted).
        quantization_type: 'symmetric' or anything else (asymmetric).
        range_calibration: 'static' uses a QuantAct module that tracks a running range;
                           any other value uses the raw Sym/Asym quantizer function.
        """
        assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"

        if range_calibration == 'static':
            quantizer = QuantAct(quant_mode=quantization_type)
        elif quantization_type == 'symmetric':
            quantizer = SymQuantizer.apply
        else:
            quantizer = AsymQuantizer.apply
        self.activation_quantizer = quantizer

    def head_pruning_reshape(self, w, mask):
        shape = w.shape
        return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()

    def forward(self, input, skip_bias_add=False):

        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
            bias = self.bias
        else:
            weight = self.weight
            bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.row_pruning_enabled and self.row_pruning_method:
            mask = self.get_mask(pruning_type='row')
            weight = weight * mask.view(-1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.head_pruning_enabled and self.head_pruning_method:
            mask = self.get_mask(pruning_type='head')
            weight = self.head_pruning_reshape(weight, mask)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                num_groups = input.numel() // input.size(-1)
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        if skip_bias_add:
            # used for mpu linear layers
            output = nn.functional.linear(input, weight, None)
            return output, bias
        else:
            output = nn.functional.linear(input, weight, bias)
            return output


class Conv2dLayer_Compress(nn.Conv2d):
    """
    Conv2D layer with compression.
    """

    def __init__(self, *kargs):
        super(Conv2dLayer_Compress, self).__init__(*kargs)
        self.sparse_pruning_method = None
        self.channel_pruning_method = None
        self.activation_quantization_method = None
        self.weight.start_bits = None
        self.weight.target_bits = None
        self.weight.q_period = None
        self.weight_quantization_enabled_in_forward = False
        self.sparse_pruning_enabled = False
        self.channel_pruning_enabled = False
        self.activation_quantization_enabled = False

    def __repr__(self):
        """Describe the convolution (non-default options only) followed by the compression flags."""
        pieces = ['{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}']

        # Optional fields are shown only when they differ from their defaults.
        if self.padding != (0, ) * len(self.padding):
            pieces.append(', padding={padding}')
        if self.dilation != (1, ) * len(self.dilation):
            pieces.append(', dilation={dilation}')
        if self.output_padding != (0, ) * len(self.output_padding):
            pieces.append(', output_padding={output_padding}')
        if self.groups != 1:
            pieces.append(', groups={groups}')
        if self.bias is None:
            pieces.append(', bias=False')
        if self.padding_mode != 'zeros':
            pieces.append(', padding_mode={padding_mode}')

        conv_part = ''.join(pieces).format(**self.__dict__)
        compress_part = ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
            self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
            self.activation_quantization_method is not None, self.weight.target_bits)
        return conv_part + compress_part

    def enable_sparse_pruning(self, ratio, method):
        self.sparse_pruning_ratio = ratio
        self.sparse_pruning_method = method
        if method == 'l1':
            weight_norm = torch.abs(self.weight.data)
            mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
            mask = mask.view(self.weight.size())
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
            self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('sparse_pruning_mask', mask)

    def enable_channel_pruning(self, ratio, method):
        # Here, we support two cases: L1 norm based pruning and topk based pruning
        self.channel_pruning_ratio = ratio
        self.channel_pruning_method = method

        if method == 'l1':
            # compute the l1 norm of each conv2d kernel (the last three dimension)
            weight_norm = torch.linalg.norm(self.weight.data, ord=1, dim=[1, 2, 3])
            mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False)
            mask = mask.view(-1, 1, 1, 1)
            mask = mask.to(self.weight.device)
        elif method == 'topk':
            self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
            self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
            init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
            mask = None
        else:
            raise NotImplementedError

        self.register_buffer('channel_pruning_mask', mask)

    def fix_sparse_pruning_helper(self):
        mask = self.get_mask(pruning_type='sparse')
        self.weight.data = self.weight.data * mask
        del self.sparse_pruning_mask
        if self.sparse_pruning_method == 'topk':
            del self.sparse_mask_scores
        self.sparse_pruning_method = None
        self.sparse_pruning_enabled = False
        return None

    def fix_channel_pruning_helper(self, mask=None, dim_reduction=False):
        if mask is None:
            if self.channel_pruning_method in ['l1', 'topk']:
                mask = self.get_mask(pruning_type='channel').bool()
                if dim_reduction:
                    start_bits = self.weight.start_bits
                    target_bits = self.weight.target_bits
                    q_period = self.weight.q_period
                    self.weight = nn.Parameter(self.weight.data[mask.view(-1), ...])
                    self.weight.start_bits = start_bits
                    self.weight.target_bits = target_bits
                    self.weight.q_period = q_period
                    if self.bias is not None:
                        self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
                else:
                    self.weight.data = self.weight.data * mask.view(-1, 1, 1, 1)
                    if self.bias is not None:
                        self.bias.data = self.bias.data * mask.view(-1)
                del self.channel_pruning_mask
                if self.channel_pruning_method == 'topk':
                    del self.channel_mask_scores
                self.channel_pruning_method = None
            else:
                raise NotImplementedError
        else:
            start_bits = self.weight.start_bits
            target_bits = self.weight.target_bits
            q_period = self.weight.q_period
            self.weight = nn.Parameter(self.weight.data[:, mask.view(-1), ...])
            self.weight.start_bits = start_bits
            self.weight.target_bits = target_bits
            self.weight.q_period = q_period
            mask = None
        self.channel_pruning_enabled = False
        return mask

    def get_mask(self, pruning_type='sparse'):
        if pruning_type == 'sparse':
            if self.sparse_pruning_method == 'l1':
                return self.sparse_pruning_mask.to(self.weight.device)
            elif self.sparse_pruning_method == 'topk':
                return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
            else:
                raise NotImplementedError
        elif pruning_type == 'channel':
            if self.channel_pruning_method == 'l1':
                return self.channel_pruning_mask.to(self.weight.device)
            elif self.channel_pruning_method == 'topk':
                return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def fix_weight_quantization(self):
        self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                                 self.weight_quantize_num_groups).data
        self.weight_quantization_enabled_in_forward = False
        return None

    def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
                                   weight_quantization_enabled_in_forward, quantization_type, num_groups):
        """
        Record the quantization schedule on the weight and, for in-forward mode, pick a quantizer.

        start_bits / target_bits / quantization_period: stored on ``self.weight``.
        weight_quantization_enabled_in_forward: when falsy, nothing beyond the bookkeeping happens;
                                                when truthy, target_bits must be >= 4 (asserted).
        quantization_type: 'symmetric' or anything else (asymmetric).
        num_groups: stored as ``weight_quantize_num_groups`` (in-forward mode only).
        """
        weight = self.weight
        weight.start_bits = start_bits
        weight.target_bits = target_bits
        weight.q_period = quantization_period
        self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward

        if not self.weight_quantization_enabled_in_forward:
            return

        assert self.weight.target_bits >= 4, 'Only >=4 bits weight quantization are supported during forward pass for now'
        logger.warning(
            "************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
        )
        self.weight_quantizer = SymQuantizer.apply if quantization_type == 'symmetric' else AsymQuantizer.apply
        self.weight_quantize_num_groups = num_groups

    def enable_activation_quantization(self, bits, quantization_type, range_calibration):
        """
        Configure activation quantization.

        bits: must be 4 or 8 (asserted).
        quantization_type: 'symmetric' or anything else (asymmetric).
        range_calibration: 'static' uses a QuantAct module that tracks a running range;
                           any other value uses the raw Sym/Asym quantizer function.
        """
        assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
        self.activation_quantization_bits = bits
        self.activation_quantization_method = f"{quantization_type}_{range_calibration}"

        if range_calibration == 'static':
            quantizer = QuantAct(quant_mode=quantization_type)
        elif quantization_type == 'symmetric':
            quantizer = SymQuantizer.apply
        else:
            quantizer = AsymQuantizer.apply
        self.activation_quantizer = quantizer

    def forward(self, input):

        if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
            weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
                                           self.weight_quantize_num_groups)
            bias = self.bias
        else:
            weight = self.weight
            bias = self.bias

        if self.sparse_pruning_enabled and self.sparse_pruning_method:
            mask = self.get_mask(pruning_type='sparse')
            weight = weight * mask.view(self.weight.size())

        if self.channel_pruning_enabled:
            mask = self.get_mask(pruning_type='channel')
            weight = weight * mask.view(-1, 1, 1, 1)
            if bias is not None:
                bias = bias * mask.view(-1)

        if self.activation_quantization_enabled:
            if 'dynamic' in self.activation_quantization_method:
                num_groups = input.numel() // input[0].numel()
            else:
                num_groups = 1
            input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)

        return nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)


class BNLayer_Compress(nn.BatchNorm2d):
    """BatchNorm2d whose per-channel tensors can be shrunk to match a channel-pruned convolution."""

    def fix_channel_pruning_helper(self, mask, dim_reduction=True):
        """
        Keep only the channels selected by ``mask``.

        mask: index/boolean tensor; it is flattened and used to index every per-channel tensor.
        dim_reduction: accepted for interface parity with the other layers; not used here.

        Tensors that do not exist (affine=False leaves weight/bias as None,
        track_running_stats=False leaves the running statistics as None) are skipped.
        """
        keep = mask.view(-1)

        if self.weight is not None:
            self.weight = nn.Parameter(self.weight.data[keep])
        if self.bias is not None:
            self.bias = nn.Parameter(self.bias.data[keep])
        if self.running_mean is not None:
            self.running_mean = self.running_mean[keep]
        if self.running_var is not None:
            self.running_var = self.running_var[keep]

        # Keep the reported feature count in sync with whichever per-channel tensor exists.
        remaining = [t.numel() for t in (self.weight, self.running_mean) if t is not None]
        if remaining:
            self.num_features = remaining[0]


def _reduce(input_):
    """All-reduce the input tensor across model parallel group.

    The tensor is reduced in place and returned; with a group of size 1 it is returned untouched.
    """
    mp_group = g_mpu.get_model_parallel_group()

    # Nothing to communicate when the group holds a single rank.
    if dist.get_world_size(group=mp_group) != 1:
        dist.all_reduce(input_, group=mp_group)

    return input_


def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor into equal chunks along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor;
                        must divide the size of the last dimension (asserted).
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
    Returns a tuple of chunks.
    """
    last_dim = tensor.dim() - 1
    last_dim_extent = tensor.size()[last_dim]
    assert last_dim_extent % num_partitions == 0
    chunk_width = last_dim_extent // num_partitions

    pieces = torch.split(tensor, chunk_width, dim=last_dim)
    if not contiguous_split_chunks:
        # torch.split hands back views, which are generally not contiguous.
        return pieces

    return tuple(piece.contiguous() for piece in pieces)


def _split(input_):
    """Split the tensor along its last dimension and keep the
    slice that belongs to this rank of the model parallel group."""
    mp_group = g_mpu.get_model_parallel_group()
    world_size = dist.get_world_size(group=mp_group)

    # A single-rank group owns the whole tensor.
    if world_size == 1:
        return input_

    chunks = split_tensor_along_last_dim(input_, world_size)

    # The chunks are views from torch.split, so make this rank's slice contiguous.
    own_rank = dist.get_rank(group=mp_group)
    return chunks[own_rank].contiguous()


def _gather(input_):
    """Gather tensors from every rank of the model parallel group and
    concatenate them along the last dimension."""
    mp_group = g_mpu.get_model_parallel_group()
    world_size = dist.get_world_size(group=mp_group)

    # A single-rank group already holds the full tensor.
    if world_size == 1:
        return input_

    own_rank = dist.get_rank(group=mp_group)

    # This rank's slot holds the input itself; every other slot is a receive buffer.
    gathered = [input_ if idx == own_rank else torch.empty_like(input_) for idx in range(world_size)]
    dist.all_gather(gathered, input_, group=mp_group)

    return torch.cat(gathered, dim=input_.dim() - 1).contiguous()


class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region.

    Forward is the identity; backward hands the incoming gradient to ``_reduce``.
    """

    @staticmethod
    def forward(ctx, input_):
        # Identity: the input is returned as-is.
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient goes through _reduce (an all-reduce over the model parallel group).
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region.

    Forward returns ``_reduce(input_)``. Backward is the identity: the
    incoming gradient is returned unchanged.
    """

    @staticmethod
    def forward(ctx, input_):
        # The reduction itself lives in `_reduce`, defined elsewhere in this file.
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity: the gradient is passed through untouched.
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank.

    Forward returns ``_split(input_)``; backward applies the opposite
    operation, ``_gather``, to the incoming gradient.
    """

    @staticmethod
    def forward(ctx, input_):
        # Keep this rank's slice along the last dimension.
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Gather the per-rank gradient slices and concatenate them along the
        # last dimension.
        return _gather(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate.

    Forward returns ``_gather(input_)``; backward applies the opposite
    operation, ``_split``, to the incoming gradient.
    """

    @staticmethod
    def forward(ctx, input_):
        # Collect every rank's tensor and concatenate along the last dimension.
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        # Keep only the gradient slice that belongs to this rank.
        return _split(grad_output)


# -----------------
# Helper functions.
# -----------------


def copy_to_model_parallel_region(input_):
    """Apply ``_CopyToModelParallelRegion`` to ``input_``.

    The forward pass returns ``input_`` unchanged; the backward pass passes
    the gradient through ``_reduce``.
    """
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_model_parallel_region(input_):
    """Apply ``_ReduceFromModelParallelRegion`` to ``input_``.

    The forward pass returns ``_reduce(input_)``; the backward pass returns
    the gradient unchanged.
    """
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_model_parallel_region(input_):
    """Apply ``_ScatterToModelParallelRegion`` to ``input_``.

    The forward pass keeps this rank's slice of ``input_`` along the last
    dimension (``_split``); the backward pass gathers the gradient slices
    (``_gather``).
    """
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_model_parallel_region(input_):
    """Apply ``_GatherFromModelParallelRegion`` to ``input_``.

    The forward pass gathers ``input_`` from the model parallel group and
    concatenates along the last dimension (``_gather``); the backward pass
    keeps this rank's slice of the gradient (``_split``).
    """
    return _GatherFromModelParallelRegion.apply(input_)


class ColumnParallelLinear_Compress(LinearLayer_Compress):
    """
    Linear layer whose output features are partitioned across the model parallel ranks. Each rank builds a
    ``LinearLayer_Compress`` that maps ``input_size`` features to ``output_size // world_size`` features.
    Parameters:
    ----------
    mpu : object
        Model parallel unit. It must provide ``get_model_parallel_world_size()``; it is also stored in the
        module-level ``g_mpu``, through which the communication helpers obtain the model parallel group.
    input_size : int
        Number of input features.
    output_size : int
        Total number of output features; must be divisible by the model parallel world size.
    bias : bool, default True
        Forwarded to the parent constructor.
    gather_output : bool, default True
        If True, the per-rank outputs are gathered and concatenated along the last dimension in ``forward``.
    skip_bias_add : bool, default False
        If True, ``forward`` asks the parent layer for its ``(output, bias)`` pair and returns that bias
        alongside the output.
    """

    def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
        # The communication helpers in this module reach the mpu through this global.
        global g_mpu
        g_mpu = mpu

        # Remember the constructor arguments.
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add

        # Every rank owns an equal share of the output features.
        num_partitions = mpu.get_model_parallel_world_size()
        assert output_size % num_partitions == 0
        self.output_size_per_partition = output_size // num_partitions

        super().__init__(self.input_size, self.output_size_per_partition, bias=bias)

    def forward(self, input_):
        """
        Run the local linear layer on ``input_`` and return ``(output, bias)``. ``bias`` is None unless
        ``skip_bias_add`` is set.
        """
        # Identity in the forward pass; its backward passes the gradient through _reduce.
        replicated_input = copy_to_model_parallel_region(input_)

        # Local matrix multiply on this rank's share of the output features.
        if self.skip_bias_add:
            # NOTE(review): the positional True is presumably the parent's skip_bias_add flag -- confirm.
            local_output, bias = super().forward(replicated_input, True)
        else:
            local_output, bias = super().forward(replicated_input), None

        # Optionally all-gather the per-rank outputs along the last dimension.
        output = gather_from_model_parallel_region(local_output) if self.gather_output else local_output
        return output, bias


class RowParallelLinear_Compress(LinearLayer_Compress):
    """
    Linear layer whose input features are partitioned across the model parallel ranks. Each rank builds a
    ``LinearLayer_Compress`` that maps ``input_size // world_size`` features to ``output_size`` features; the
    per-rank results are combined with ``reduce_from_model_parallel_region`` in ``forward``.
    Parameters:
    ----------
    mpu : object
        Model parallel unit. It must provide ``get_model_parallel_world_size()``; it is also stored in the
        module-level ``g_mpu``, through which the communication helpers obtain the model parallel group.
    input_size : int
        Total number of input features; must be divisible by the model parallel world size.
    output_size : int
        Number of output features.
    bias : bool, default True
        Forwarded to the parent constructor.
    input_is_parallel : bool, default False
        If True, the input passed to ``forward`` is used as is; otherwise it is first split along its last
        dimension and only this rank's slice is used.
    skip_bias_add : bool, default False
        If True, ``forward`` does not add the bias to the output and returns it separately instead.
    """

    def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
        # The communication helpers in this module reach the mpu through this global.
        global g_mpu
        g_mpu = mpu

        # Remember the constructor arguments.
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.skip_bias_add = skip_bias_add

        # Every rank owns an equal share of the input features.
        num_partitions = mpu.get_model_parallel_world_size()
        assert input_size % num_partitions == 0
        self.input_size_per_partition = input_size // num_partitions

        super().__init__(self.input_size_per_partition, self.output_size, bias=bias)

    def forward(self, input_):
        """
        Run the local linear layer on this rank's slice of the input, reduce the result across the model
        parallel region and return ``(output, output_bias)``. ``output_bias`` is None unless ``skip_bias_add``
        is set.
        """
        # Use the input directly if it is already partitioned, otherwise keep only this rank's slice.
        if self.input_is_parallel:
            local_input = input_
        else:
            local_input = scatter_to_model_parallel_region(input_)

        # Local matrix multiply; the parent is always asked for the (output, bias) pair so the bias can be
        # applied after the reduction rather than on every rank's partial result.
        partial_output, bias = super().forward(local_input, True)

        # Combine the partial results from all partitions.
        reduced = reduce_from_model_parallel_region(partial_output)

        # Caller wants the bias handed back instead of applied.
        if self.skip_bias_add:
            return reduced, bias

        # Bias is applied here when there is one.
        if bias is None:
            return reduced, None
        return reduced + bias, None
