# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed.ops.transformer.inference.triton.matmul_ext as matmul_ext
from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
from deepspeed.ops.transformer.inference.triton.layer_norm import layer_norm, layer_norm_residual
from deepspeed.utils.types import ActivationFuncType


def vector_matmul_func(input, weight, async_op, q_scale, q_int8, transposed_mode):
    """Plain (bias-free, no-activation) matmul through the Triton kernel.

    Only the default execution path is supported: quantization
    (``q_int8``), asynchronous launch (``async_op``) and transposed
    weights (``transposed_mode``) must all be disabled.  ``q_scale`` is
    accepted for op-binding API compatibility but is unused here.
    """
    assert not (transposed_mode or async_op or q_int8)
    return matmul_ext.matmul(input, weight, bias=None, activation="", use_triton=True)


def fused_gemm_gelu(input,
                    weight,
                    weight_scale,
                    bias,
                    weight_out,
                    weight_out_scale,
                    epsilon,
                    pre_layer_norm,
                    q_int8,
                    async_op,
                    transposed_mode,
                    use_triton_ln=True):
    """Fused feed-forward block: GEMM + GELU, then the output GEMM.

    Computes ``matmul(gelu(input @ weight + bias), weight_out)`` using the
    Triton matmul kernels.  The bias/residual/layer-norm of the *output*
    projection is applied by a later layer, so the second GEMM runs with
    no bias and no activation.

    NOTE(review): ``weight_scale``, ``weight_out_scale``, ``epsilon``,
    ``pre_layer_norm`` and ``use_triton_ln`` are accepted only for
    op-binding API compatibility and are unused on this Triton path.

    Returns:
        The output-projection result (pre bias/residual/LayerNorm).
    """
    # This path implements only the default synchronous, non-quantized
    # layout; fail loudly instead of silently ignoring unsupported modes
    # (mirrors the guards in vector_matmul_func / linear_func).
    assert not transposed_mode and not async_op and not q_int8

    # activation
    activation = "gelu"

    # intermediate fc in FF
    intm_out = matmul_ext.matmul(input, weight, bias=bias, activation=activation, use_triton=True)

    # output fc in FF
    ff_out = matmul_ext.matmul(
        intm_out,
        weight_out,
        bias=None,
        activation="",  # bias added layer with residual_add + bias + layerNorm layer
        use_triton=True)
    return ff_out


def linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, transposed_mode=False):
    """Linear projection via the Triton matmul kernel.

    The bias is applied only when ``add_bias`` is true.  Flash-attention
    output formatting (``do_flash_attn``) and transposed weights are not
    supported on this path; ``num_heads`` is accepted for API
    compatibility but unused here.
    """
    assert not (transposed_mode or do_flash_attn)
    effective_bias = bias if add_bias else None
    return matmul_ext.matmul(input, weight, bias=effective_bias, activation="", use_triton=True)


def mlp_gemm_func(input,
                  residual,
                  input_bias,
                  weight_interm,
                  weight_out,
                  bias,
                  gamma,
                  beta,
                  epsilon,
                  pre_layer_norm,
                  mlp_after_attn,
                  weight_interm_scale,
                  weight_out_scale,
                  q_int8,
                  mlp_act_func_type,
                  transposed_mode,
                  use_triton_ln=True):
    """MLP block: residual-add + LayerNorm, then the two feed-forward GEMMs.

    Pipeline: ``layer_norm(input + input_bias + residual)`` ->
    intermediate GEMM with fused bias/activation -> output GEMM (no bias,
    no activation — the following residual_add + bias + LayerNorm layer
    handles those).

    NOTE(review): ``pre_layer_norm``, ``mlp_after_attn``,
    ``weight_interm_scale``, ``weight_out_scale`` and ``q_int8`` are
    accepted for op-binding API compatibility and are unused on this
    Triton path.

    Returns:
        Tuple of (ff_out, mlp_input): the output-projection result and
        the normalized MLP input (needed by later residual math).
    """
    assert not transposed_mode

    # residual add and layerNorm after attention
    if use_triton_ln:
        mlp_input = layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)
    else:
        mlp_input = LayerNormOp.layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)

    # Convert once: the enum lookup validates the raw value (raises
    # ValueError on an unknown type) — no need to construct it per branch.
    act_type = ActivationFuncType(mlp_act_func_type)
    if act_type == ActivationFuncType.GELU:
        activation = "gelu"
    elif act_type == ActivationFuncType.ReLU:
        activation = "relu"
    else:
        activation = ""

    # intermediate fc in FF
    intm_out = matmul_ext.matmul(mlp_input, weight_interm, bias=bias, activation=activation, use_triton=True)
    # output fc in FF
    ff_out = matmul_ext.matmul(
        intm_out,
        weight_out,
        bias=None,
        activation="",  # bias added layer with residual_add + bias + layerNorm layer
        use_triton=True)

    return ff_out, mlp_input


def qkv_gemm_func(
    input,
    weight,
    q_scale,
    bias,
    gamma,
    beta,
    epsilon,
    add_bias,
    q_int8,
    transposed_mode=False,
    use_triton_ln=True,
):
    """LayerNorm followed by the QKV projection GEMM.

    ``q_scale`` and ``q_int8`` are accepted for op-binding API
    compatibility; transposed weights are not supported.

    Returns:
        Tuple of (qkv_out, qkv_input): the projected activations together
        with the normalized input.
    """
    assert not transposed_mode

    # Normalize the input, picking the Triton kernel or the op-binding
    # implementation depending on `use_triton_ln`.
    if use_triton_ln:
        normed = layer_norm(input, gamma, beta, epsilon)
    else:
        normed = LayerNormOp()(input, gamma, beta, epsilon)

    projection_bias = bias if add_bias else None
    qkv_out = matmul_ext.matmul(normed, weight, bias=projection_bias, activation="", use_triton=True)

    return qkv_out, normed
