# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import triton
import os
from filelock import FileLock
import deepspeed.ops.transformer.inference.triton.triton_matmul_kernel as triton_matmul_kernel
import pickle
from io import open
import deepspeed
from pathlib import Path
import atexit
import subprocess


# -----------------------------------------------------------------------------
# util class/functions for triton
def is_nfs_path(path):
    """Return True when *path* lives on an NFS-mounted filesystem.

    Uses ``df -T`` to discover the filesystem type; on Windows (no ``df``)
    or when the command fails, conservatively returns False.
    """
    if os.name == 'nt':
        return False

    # Resolve to an absolute path before asking df about it.
    abs_path = os.path.abspath(path)

    try:
        df_output = subprocess.check_output(['df', '-T', abs_path], encoding='utf-8')
    except subprocess.CalledProcessError:
        # df could not stat the path; treat as non-NFS.
        return False

    # 'df -T' prints a header row followed by one data row; the filesystem
    # type is the second column of the data row.
    rows = df_output.strip().split('\n')
    if len(rows) <= 1:
        return False
    fs_type = rows[1].split()[1].lower()
    return 'nfs' in fs_type


class TritonCacheDir:
    """Helpers for locating the Triton autotune cache dir and warning about NFS."""

    # Ensures the NFS warning is printed at most once per process.
    _warning_printed = False

    @staticmethod
    def warn_if_nfs(cache_dir):
        """Emit a one-time warning when *cache_dir* appears to be NFS-backed."""
        should_warn = is_nfs_path(cache_dir) and not TritonCacheDir._warning_printed
        if should_warn:
            print(
                f"Warning: The cache directory for DeepSpeed Triton autotune, {cache_dir}, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path."
            )
            TritonCacheDir._warning_printed = True
        return

    @staticmethod
    def default_cache_dir():
        """Return the default location for persisted autotune tables."""
        return os.path.join(Path.home(), ".triton", "autotune")


def bias_add_activation(C, bias=None, activation=""):
    """Apply an optional bias add (in place) and an activation to matmul output.

    Args:
        C: matmul output tensor; mutated in place when ``bias`` is given.
        bias: optional tensor broadcastable to ``C``.
        activation: one of ``""`` (none), ``"relu"``, ``"leaky_relu"``,
            ``"gelu"``, ``"sigmoid"``. Unknown strings are silently ignored.

    Returns:
        The (possibly new) activated tensor.
    """
    if bias is not None:
        C += bias  # in-place: intentionally mutates the caller's tensor
    # activation
    if activation == "relu":
        # bugfix: torch.nn.Relu does not exist -- the module is torch.nn.ReLU,
        # so this branch previously raised AttributeError at runtime.
        relu = torch.nn.ReLU()
        C = relu(C)
    elif activation == "leaky_relu":
        leaky_relu = torch.nn.LeakyReLU(0.01)
        C = leaky_relu(C)
    elif activation == "gelu":
        # sigmoid-based GELU approximation: x * sigmoid(1.702 * x)
        sigmoid = torch.nn.Sigmoid()
        C = sigmoid(1.702 * C) * C
    elif activation == "sigmoid":
        sigmoid = torch.nn.Sigmoid()
        C = sigmoid(C)
    return C


class AutotuneCacheManager:
    """
        Cache manager for autotune

    Persists a kernel's autotune table as a pickle file under the Triton
    cache directory; writes are guarded by a file lock and done atomically
    (tmp file + rename) so concurrent processes don't corrupt the cache.
    """

    def __init__(self, key):
        """key: unique name for the kernel whose autotune table is cached."""
        self.key = key
        self.file_path = None
        self.lock_path = None
        # if caching is enabled, get the lock and bin path
        self.cache_dir = os.environ.get('TRITON_CACHE_DIR', TritonCacheDir.default_cache_dir())
        TritonCacheDir.warn_if_nfs(self.cache_dir)
        if self.cache_dir:
            os.makedirs(self.cache_dir, exist_ok=True)
            self.file_path = os.path.join(self.cache_dir, self.key + ".pickle")
            self.lock_path = self.file_path + ".lock"

    def has_file(self):
        """Return True when a cached table exists on disk."""
        return self.file_path is not None and os.path.exists(self.file_path)

    def put(self, table):
        """Atomically persist *table* to the cache file under a file lock."""
        if self.file_path:
            assert self.lock_path is not None
            with FileLock(self.lock_path):
                with open(self.file_path + ".tmp", 'wb') as handle:
                    pickle.dump(table, handle)
                os.replace(self.file_path + ".tmp", self.file_path)

    def load(self):
        """Return the cached table, or None when caching is disabled or no file exists.

        bugfix: previously called os.path.exists(self.file_path) without
        checking for None first, which raises TypeError when caching is
        disabled (empty cache_dir leaves file_path as None).
        """
        if self.file_path and os.path.exists(self.file_path):
            with open(self.file_path, 'rb') as handle:
                return pickle.load(handle)
        return None


# -----------------------------------------------------------------------------
# triton matmul class


class MatmulExt(torch.autograd.Function):
    """
        a wrapper class that can call different triton matmul kernels depending on the input parameters
    """

    @staticmethod
    def forward(A, B, bias=None, activation="", use_triton=True, update_autotune_table=False):
        """
            A: input, activation matrix A (2d, or 3d with batch on axis 0)
            B: input, weight matrix B
        """
        batch_size = 0
        M = 0

        # A batched 3d activation is flattened to 2d for the kernel and the
        # batch axis is restored on the output below.
        if len(A.shape) == 3:
            assert A.is_contiguous(), "matrix A must be contiguous"
            batch_size, M, K = A.shape
            A = A.view(-1, K)

        # fp16 activation and fp16 weight matmul into fp16 output
        kernel = fp16_matmul
        C = kernel.forward(A, B, use_triton=use_triton, bias=bias, activation=activation)

        if kernel and update_autotune_table:
            kernel._update_autotune_table()

        if batch_size > 0:
            C = C.view(batch_size, M, -1)

        return C


class TritonMatmul(torch.autograd.Function):
    """
        triton matmul kernel superclass

    Provides the reference (pure torch) forward pass and the autotune-table
    persistence helpers shared by concrete kernel subclasses.
    """

    def __init__(self):
        pass

    @staticmethod
    def _ref_forward(A, B, ref_dtype=torch.float32):
        """Reference matmul computed in *ref_dtype* (used for parity checks)."""
        return torch.matmul(A.type(ref_dtype), B.type(ref_dtype))

    @staticmethod
    def _read_autotune_table(cache_key, triton_kernel):
        """Load a cached autotune table (if present) into the kernel's cache."""
        cached = AutotuneCacheManager(cache_key).load()
        if cached:
            triton_kernel.cache = cached

    @staticmethod
    def _write_autotune_table(cache_key, triton_kernel):
        """Persist the kernel's current autotune cache to disk."""
        AutotuneCacheManager(cache_key).put(triton_kernel.cache)

    @staticmethod
    def _update_autotune_table(cache_key, triton_kernel):
        """Merge the kernel's in-memory autotune cache into the on-disk table."""
        merged = AutotuneCacheManager(cache_key).load()
        if merged is None:
            merged = dict()
        merged.update(triton_kernel.cache)  # always overwrite with the new autotune results
        AutotuneCacheManager(cache_key).put(merged)

    @staticmethod
    def forward(
            A,
            B,
            ref_dtype=torch.float32,  # fp32 only
            bias=None,
            activation=""):
        """Plain torch matmul in *ref_dtype*, then optional bias/activation."""
        product = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
        return bias_add_activation(product, bias, activation)


class Fp16Matmul(TritonMatmul):
    """
        fp16 matrix multiplication kernel
        dtypes: fp16 x fp16 = fp16
    """

    # Triton kernels imported from triton_matmul_kernel; _cache_stride buckets
    # problem sizes (M//32 etc.) when keying cached autotune configurations.
    _2d_kernel = triton_matmul_kernel._fp_matmul
    _4d_kernel = triton_matmul_kernel.matmul_4d_kernel
    _cache_stride = 32

    def __init__(self, read_cache=True):
        # read_cache: when True, seed the kernels' autotune caches from disk.
        super().__init__()
        if read_cache:
            __class__._read_autotune_table()

    def skip_autotune(self):
        # Restrict each kernel to its first config so no autotune sweep runs.
        __class__._2d_kernel.configs = [__class__._2d_kernel.configs[0]]
        __class__._4d_kernel.configs = [__class__._4d_kernel.configs[0]]

    @staticmethod
    def forward(A, B, use_triton=True, bias=None, activation=""):
        # 2d matmul C = activation(A @ B + bias) via the triton kernel.
        # NOTE(review): the use_triton=False fallback ignores bias/activation
        # (returns plain A @ B) -- confirm whether callers rely on that.
        if use_triton:
            device = A.device
            # handle non-contiguous inputs if necessary
            if A.stride(0) > 1 and A.stride(1) > 1:
                A = A.contiguous()
            if B.stride(0) > 1 and B.stride(1) > 1:
                B = B.contiguous()
            # checks constraints
            assert A.shape[1] == B.shape[0], "incompatible dimensions"
            M, K = A.shape
            _, N = B.shape
            # allocates output
            C = torch.empty((M, N), device=device, dtype=A.dtype)
            # accumulator types
            ACC_TYPE = triton.language.float32 if A.dtype in [torch.float16, torch.bfloat16, torch.float32
                                                              ] else triton.language.int32
            # launch kernel
            grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
            # the M//stride, N//stride, K//stride args bucket the problem size
            # for autotune-cache keying (see _cache_stride above)
            __class__._2d_kernel[grid](A,
                                       B,
                                       C,
                                       M,
                                       N,
                                       K,
                                       bias,
                                       A.stride(0),
                                       A.stride(1),
                                       B.stride(0),
                                       B.stride(1),
                                       C.stride(0),
                                       C.stride(1),
                                       M // __class__._cache_stride,
                                       N // __class__._cache_stride,
                                       K // __class__._cache_stride,
                                       GROUP_M=8,
                                       ACC_TYPE=ACC_TYPE,
                                       BIAS_ADD=(0 if bias is None else 1),
                                       ACTIVATION=activation)
        else:
            C = torch.matmul(A, B)
        return C

    @staticmethod
    def _matmul_4d(a, b):
        # Batched 4d matmul c[b,h] = a[b,h] @ b[b,h] for (B, H, M, K) x (B, H, K, N).
        assert a.shape[-1] == b.shape[-2], "incompatible dimensions"
        assert a.is_contiguous(), "matrix A must be contiguous"
        assert b.is_contiguous(), "matrix B must be contiguous"

        B, H, M, K = a.shape
        B, H, K, N = b.shape

        assert K > 1, "inner-product dimension K should be larger than 1"

        c = torch.empty((B, H, M, N), device=a.device, dtype=a.dtype)

        # one program per (M-block x N-block) tile, per head, per batch
        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
            H,
            B,
        )

        __class__._4d_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            M // __class__._cache_stride,
            N // __class__._cache_stride,
            K // __class__._cache_stride,
            a.stride(0),
            a.stride(1),
            a.stride(2),
            a.stride(3),
            b.stride(0),
            b.stride(1),
            b.stride(2),
            b.stride(3),
            c.stride(0),
            c.stride(1),
            c.stride(2),
            c.stride(3),
            scale=-1.0,
            MASK=False,
        )
        return c

    @staticmethod
    def _score_4d_matmul(input, head_size, input_mask, scale=-1.0):
        # Attention-score matmul on a fused QKV tensor: slices out Q and K and
        # launches the 4d kernel. input is (batches, seq, 3*d_model); output
        # is (B, H, M, N). input_mask is accepted but unused here (MASK=False).
        assert input.is_contiguous(), "matrix input must be contiguous"

        batches = input.shape[0]
        d_model = input.shape[-1] // 3
        num_of_heads = d_model // head_size

        # fused QKV layout: [q | k | v] along the last axis
        q = input[:, :, :d_model]
        k = input[:, :, d_model:d_model * 2]

        q = q.view(batches, -1, num_of_heads, head_size)
        k = k.view(batches, -1, num_of_heads, head_size)

        # checks constraints
        assert q.shape == k.shape, "incompatible dimensions"
        B, M, H, K = q.shape
        B, N, H, K = k.shape

        assert K > 1, "inner-product dimension K should be larger than 1"

        # allocates output
        output = torch.empty((B, H, M, N), device=q.device, dtype=q.dtype)
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            H,
            B,
        )
        # NOTE: the stride arguments below are deliberately permuted -- q's
        # strides are passed in (0,2,1,3) order and k's in (0,2,3,1) order, so
        # the kernel appears to read q as (B,H,M,K) and k transposed as
        # (B,H,K,N), i.e. q @ k^T per head, without materializing copies.
        # TODO(review): confirm against the _4d_kernel parameter order.
        __class__._4d_kernel[grid](
            q,
            k,
            output,
            M,
            N,
            K,
            M // __class__._cache_stride,
            N // __class__._cache_stride,
            K // __class__._cache_stride,
            q.stride(0),
            q.stride(2),
            q.stride(1),
            q.stride(3),
            k.stride(0),
            k.stride(2),
            k.stride(3),
            k.stride(1),
            output.stride(0),
            output.stride(1),
            output.stride(2),
            output.stride(3),
            scale=scale,
            MASK=False,
        )
        return output

    @staticmethod
    def _context_4d_matmul(prob, input, head_size):
        # Attention-context matmul: multiplies attention probabilities
        # (B, H, M, K) by V sliced from the fused QKV tensor, returning a
        # (batches, seq, d_model) tensor.
        assert prob.is_contiguous(), "matrix prob must be contiguous"
        assert input.is_contiguous(), "matrix input must be contiguous"

        batches = input.shape[0]
        d_model = input.shape[-1] // 3
        num_of_heads = d_model // head_size

        # V occupies the last third of the fused QKV layout
        v = input[:, :, d_model * 2:]

        v = v.view(batches, -1, num_of_heads, head_size)

        # checks constraints
        assert (prob.shape[0] == v.shape[0] and prob.shape[1] == v.shape[2] and prob.shape[2] == v.shape[1]
                and prob.shape[3] == v.shape[1]), "incompatible dimensions"
        B, H, M, K = prob.shape
        B, K, H, N = v.shape

        assert K > 1, "inner-product dimension K should be larger than 1"

        # allocates output
        output = torch.empty((B, M, H, N), device=v.device, dtype=v.dtype)
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            H,
            B,
        )

        # As above, v's and output's strides are permuted (0,2,1,3) so the
        # kernel reads/writes the head-major view without explicit transposes.
        __class__._4d_kernel[grid](
            prob,
            v,
            output,
            M,
            N,
            K,
            M // __class__._cache_stride,
            N // __class__._cache_stride,
            K // __class__._cache_stride,
            prob.stride(0),
            prob.stride(1),
            prob.stride(2),
            prob.stride(3),
            v.stride(0),
            v.stride(2),
            v.stride(1),
            v.stride(3),
            # Here we also transpose the output when writing to memory.
            output.stride(0),
            output.stride(2),
            output.stride(1),
            output.stride(3),
            scale=-1,
            MASK=False,
        )
        return output.view(batches, -1, d_model)

    @staticmethod
    def _ref_forward(A, B, ref_dtype=torch.float32, bias=None, activation=""):
        # Pure-torch reference implementation used by _check_parity.
        C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
        C = bias_add_activation(C, bias, activation)
        return C

    @staticmethod
    def _check_parity(A,
                      B,
                      output_dtype,
                      SA=None,
                      SB=None,
                      qblock_size=None,
                      ref_dtype=torch.float32,
                      tol=0.01,
                      use_triton=True,
                      bias=None,
                      activation=""):
        # Compare the triton kernel's output against the torch reference and
        # assert closeness within rtol=tol. SA/SB/qblock_size/output_dtype are
        # accepted but unused in this implementation.
        torch_output = __class__._ref_forward(A, B, ref_dtype=ref_dtype, bias=bias, activation=activation)
        triton_output = __class__.forward(A, B, use_triton=use_triton, bias=bias, activation=activation)
        assert torch.allclose(triton_output.cpu().type(torch_output.dtype), torch_output.cpu(), rtol=tol)
        print(f"{__class__.__name__}: PASSed the parity check")
        return triton_output, torch_output

    @staticmethod
    def _read_autotune_table():
        # Seed both kernels' autotune caches from their on-disk tables.
        TritonMatmul._read_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
        TritonMatmul._read_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)

    @staticmethod
    def _write_autotune_table():
        # Overwrite the on-disk tables with the current in-memory caches.
        TritonMatmul._write_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
        TritonMatmul._write_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)

    @staticmethod
    def _update_autotune_table():
        # Merge the in-memory caches into the on-disk tables (new entries win).
        TritonMatmul._update_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
        TritonMatmul._update_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)


# -----------------------------------------------------------------------------
# mapping
# Public entry points: real implementations when triton is available,
# otherwise None placeholders.
if not deepspeed.HAS_TRITON:
    fp16_matmul = None
    matmul = None
    matmul_4d = None
    score_4d_matmul = None
    context_4d_matmul = None
else:
    fp16_matmul = Fp16Matmul()
    matmul = MatmulExt.forward
    matmul_4d = fp16_matmul._matmul_4d
    score_4d_matmul = fp16_matmul._score_4d_matmul
    context_4d_matmul = fp16_matmul._context_4d_matmul


@atexit.register
def matmul_ext_update_autotune_table():
    """Persist accumulated autotune results to disk at interpreter exit."""
    if not deepspeed.HAS_TRITON:
        return
    fp16_matmul._update_autotune_table()