# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h
# from Cutlass C++ to Cute-DSL.
import math
import operator
from typing import Type, Optional

import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute

from flash_attn.cute import utils


class FlashAttentionBackwardPreprocess:
    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        m_block_size: int = 128,
        num_threads: int = 128,
    ):
        """
        All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension
        should be a multiple of 8.

        :param head_dim: head dimension
        :type head_dim: int
        :param m_block_size: m block size
        :type m_block_size: int
        :param num_threads: number of threads
        :type num_threads: int
        """
        self.dtype = dtype
        self.m_block_size = m_block_size
        # padding head_dim to a multiple of 32 as k_block_size
        hdim_multiple_of = 32
        self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
        self.check_hdim_oob = head_dim != self.head_dim_padded
        self.num_threads = num_threads

    @staticmethod
    def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool:
        """Check if the kernel can be implemented with the given parameters.

        :param dtype: data type
        :type dtype: cutlass.Numeric
        :param head_dim: head dimension
        :type head_dim: int
        :param m_block_size: m block size
        :type m_block_size: int
        :param num_threads: number of threads
        :type num_threads: int

        :return: True if the kernel can be implemented, False otherwise
        :rtype: bool
        """
        if dtype not in [cutlass.Float16, cutlass.BFloat16]:
            return False
        if head_dim % 8 != 0:
            return False
        if num_threads % 32 != 0:
            return False
        if num_threads < m_block_size:  # For multiplying lse with log2
            return False
        return True

    def _setup_attributes(self):
        # ///////////////////////////////////////////////////////////////////////////////
        # GMEM Tiled copy:
        # ///////////////////////////////////////////////////////////////////////////////
        # Thread layouts for copies
        # We want kBlockKGmem to be a power of 2 so that when we do the summing,
        # it's just between threads in the same warp
        gmem_k_block_size = (
            128
            if self.head_dim_padded % 128 == 0
            else (
                64
                if self.head_dim_padded % 64 == 0
                else (32 if self.head_dim_padded % 32 == 0 else 16)
            )
        )
        universal_copy_bits = 128
        async_copy_elems = universal_copy_bits // self.dtype.width
        # atom_universal_copy: universal copy atom for O & dO load
        atom_universal_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            self.dtype,
            num_bits_per_copy=universal_copy_bits,
        )
        # tOdO_layout: thread layout for O & dO load
        self.gmem_threads_per_row = gmem_k_block_size // async_copy_elems
        assert self.num_threads % self.gmem_threads_per_row == 0
        tOdO_layout = cute.make_ordered_layout(
            (self.num_threads // self.gmem_threads_per_row, self.gmem_threads_per_row),
            order=(1, 0),
        )
        # Value layouts for copies
        vOdO_layout = cute.make_layout((1, async_copy_elems))
        self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
            atom_universal_copy, tOdO_layout, vOdO_layout
        )
        self.gmem_tiled_copy_dO = cute.make_tiled_copy_tv(
            atom_universal_copy, tOdO_layout, vOdO_layout
        )

        async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width
        atom_universal_copy_accum = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            cutlass.Float32,
            num_bits_per_copy=universal_copy_bits,
        )
        assert (
            self.m_block_size * self.head_dim_padded // async_copy_elems_accum
        ) % self.num_threads == 0
        self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
            atom_universal_copy_accum,
            cute.make_layout(self.num_threads),
            cute.make_layout(async_copy_elems_accum),
        )

    @cute.jit
    def __call__(
        self,
        mO: cute.Tensor,
        mdO: cute.Tensor,
        mdPsum: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        mLSElog2: Optional[cute.Tensor],
        mdQaccum: Optional[cute.Tensor],
        stream: cuda.CUstream,
    ):
        # Get the data type and check if it is fp16 or bf16
        if cutlass.const_expr(not (mO.element_type == mdO.element_type)):
            raise TypeError("All tensors must have the same data type")
        if cutlass.const_expr(not mO.element_type in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        if cutlass.const_expr(not mdPsum.element_type in [cutlass.Float32]):
            raise TypeError("dPsum tensor must be Float32")
        if cutlass.const_expr(mdQaccum is not None):
            if cutlass.const_expr(not mdQaccum.element_type in [cutlass.Float32]):
                raise TypeError("dQaccum tensor must be Float32")
        if cutlass.const_expr(mLSE is not None):
            assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided"
            if cutlass.const_expr(not mLSE.element_type in [cutlass.Float32]):
                raise TypeError("LSE tensor must be Float32")
            if cutlass.const_expr(not mLSElog2.element_type in [cutlass.Float32]):
                raise TypeError("LSElog2 tensor must be Float32")

        self._setup_attributes()

        # grid_dim: (m_block, num_head, batch_size)
        grid_dim = (
            cute.ceil_div(mO.shape[1], self.m_block_size),
            cute.size(mO.shape[2]),
            cute.size(mO.shape[0]),
        )
        self.kernel(
            mO,
            mdO,
            mdPsum,
            mLSE,
            mLSElog2,
            mdQaccum,
            self.gmem_tiled_copy_O,
            self.gmem_tiled_copy_dO,
            self.gmem_tiled_copy_dQaccum,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mO: cute.Tensor,
        mdO: cute.Tensor,
        mdPsum: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        mLSElog2: Optional[cute.Tensor],
        mdQaccum: Optional[cute.Tensor],
        gmem_tiled_copy_O: cute.TiledCopy,
        gmem_tiled_copy_dO: cute.TiledCopy,
        gmem_tiled_copy_dQaccum: cute.TiledCopy,
    ):
        # Thread index, block index
        tidx, _, _ = cute.arch.thread_idx()
        m_block, num_head, batch_size = cute.arch.block_idx()

        # ///////////////////////////////////////////////////////////////////////////////
        # Get the appropriate tiles for this thread block.
        # ///////////////////////////////////////////////////////////////////////////////
        blkOdO_shape = (self.m_block_size, self.head_dim_padded)
        # (m_block_size, head_dim)
        gO = cute.local_tile(mO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0))
        gdO = cute.local_tile(mdO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0))

        gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
        gmem_thr_copy_dO = gmem_tiled_copy_dO.get_slice(tidx)
        # (CPY_Atom, CPY_M, CPY_K)
        tOgO = gmem_thr_copy_O.partition_S(gO)
        tOgdO = gmem_thr_copy_dO.partition_S(gdO)

        # ///////////////////////////////////////////////////////////////////////////////
        # Predicate: Mark indices that need to copy when problem_shape isn't a multiple
        # of tile_shape
        # ///////////////////////////////////////////////////////////////////////////////
        # Construct identity layout for KV
        cOdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
        tOcO = gmem_thr_copy_O.partition_S(cOdO)
        t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cOdO)
        tOpO = utils.predicate_k(tOcO, limit=mO.shape[3])
        tOcdO = gmem_thr_copy_dO.partition_S(cOdO)
        t0OcdO = gmem_thr_copy_dO.get_slice(0).partition_S(cOdO)
        tOpdO = utils.predicate_k(tOcdO, limit=mdO.shape[3])

        seqlen_q = mO.shape[1]
        seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size)

        if cutlass.const_expr(mLSE is not None):
            gLSE = cute.local_tile(
                mLSE[batch_size, num_head, None], (self.m_block_size,), (m_block,)
            )
            lse = cutlass.Float32.inf
            if tidx < seqlen_q - m_block * self.m_block_size:
                lse = gLSE[tidx]

        tOrO = cute.make_fragment_like(tOgO)
        tOrdO = cute.make_fragment_like(tOgdO)
        assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0])
        assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1])
        assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2])
        for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):
            # Instead of using tOcO, we using t0OcO and subtract the offset from the limit
            # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time.
            if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]:
                cute.copy(
                    gmem_thr_copy_O,
                    tOgO[None, m, None],
                    tOrO[None, m, None],
                    pred=tOpO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None,
                )
                cute.copy(
                    gmem_thr_copy_dO,
                    tOgdO[None, m, None],
                    tOrdO[None, m, None],
                    pred=tOpdO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None,
                )
        # Sum across the "k" dimension
        dpsum = (tOrO.load().to(cutlass.Float32) * tOrdO.load().to(cutlass.Float32)).reduce(
            cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)
        )
        dpsum = utils.warp_reduce(dpsum, operator.add, width=self.gmem_threads_per_row)
        dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), cutlass.Float32)
        dP_sum.store(dpsum)

        # Write dPsum from rmem -> gmem
        gdPsum = cute.local_tile(
            mdPsum[batch_size, num_head, None], (self.m_block_size,), (m_block,)
        )
        # Only the thread corresponding to column 0 writes out the lse to gmem
        if tOcO[0, 0, 0][1] == 0:
            for m in cutlass.range(cute.size(dP_sum), unroll_full=True):
                row = tOcO[0, m, 0][0]
                gdPsum[row] = dP_sum[m] if row < mO.shape[1] - m_block * self.m_block_size else 0.0

        # Clear dQaccum
        if cutlass.const_expr(mdQaccum is not None):
            blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,)
            gdQaccum = cute.local_tile(
                mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,)
            )
            gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
            tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)
            zero = cute.make_fragment_like(tQgQaccum)
            zero.fill(0.0)
            cute.copy(gmem_tiled_copy_dQaccum, zero, tQgQaccum)

        if cutlass.const_expr(mLSE is not None):
            gLSElog2 = cute.local_tile(
                mLSElog2[batch_size, num_head, None], (self.m_block_size,), (m_block,)
            )
            LOG2_E = math.log2(math.e)
            if tidx < seqlen_q_rounded - m_block * self.m_block_size:
                gLSElog2[tidx] = lse * LOG2_E if lse != -cutlass.Float32.inf else 0.0
