from typing import Tuple, Optional, Callable
from functools import lru_cache, partial

from torch import Tensor

import cutlass
import cutlass.cute as cute
from cutlass import Int32, Float32, Boolean, const_expr
from cutlass.cute.runtime import make_ptr
import cutlass.utils.hopper_helpers as sm90_utils_og
import cutlass.utils.blackwell_helpers as sm100_utils

from quack.compile_utils import make_fake_tensor as fake_tensor
from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters, torch2cute_dtype_map
from quack.activation import act_fn_map
from quack.gemm_act import GemmActMixin
from quack.gemm_sm90 import GemmSm90
from quack.gemm_sm100 import GemmSm100
from quack.gemm_tvm_ffi_utils import (
    div_for_dtype,
    perm3d,
    get_majors,
    get_dtypes,
    make_scheduler_args,
    make_fake_scheduler_args,
    cached_compile,
    compile_gemm_kernel,
)
from quack.tile_scheduler import TriangularTileScheduler
from quack.varlen_utils import VarlenManager
import quack.copy_utils as copy_utils


class GemmSymmetricMixin(GemmActMixin):
    """GEMM-with-activation mixin used for symmetric GEMM.

    Overrides two hooks of ``GemmActMixin``:

    * ``get_scheduler_class`` always returns ``TriangularTileScheduler``.
    * ``epilogue`` stores every epilogue subtile both to D (when present) and
      to the PostAct tensor. ``gemm_symmetric`` in this module passes ``D.mT``
      as PostAct, so the PostAct store writes the transposed (mirrored) tile
      of D.

    Concrete kernels combine it with an architecture base class
    (``GemmSymmetricSm90`` / ``GemmSymmetricSm100``).
    """

    def get_scheduler_class(self, varlen_m: bool = False):
        """Return ``TriangularTileScheduler``; ``varlen_m`` is ignored.

        NOTE(review): ``varlen_m`` is presumably accepted only to match the
        base-class signature — confirm.
        """
        return TriangularTileScheduler

    @cute.jit
    def epilogue(
        self,
        params: GemmActMixin.EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        epi_pipeline: cutlass.pipeline.PipelineAsync,
        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
        epi_read_state: cutlass.pipeline.PipelineState,
        epi_producer_state: cutlass.pipeline.PipelineState,
        epi_tile: cute.Tile,
        load_acc_subtile: Callable,
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor],
        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
        tiled_copy_r2s: cute.TiledCopy,
        tRS_sD: cute.Tensor,
        tiled_copy_s2r: Optional[cute.TiledCopy],
        tSR_rC: Optional[cute.Tensor],
        tSR_sC: Optional[cute.Tensor],
        copy_D: Optional[Callable],
        copy_C: Optional[Callable],
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tile_scheduler,
        tidx: Int32,
        is_tma_warp: Boolean,
    ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
        """Store one output tile to D and, mirrored, to PostAct.

        Flow:

        1. ``epi_begin`` is called once for the tile.
        2. If ``copy_C`` is given, up to ``self.epi_c_stage`` C subtiles are
           prefetched into shared memory (only the TMA warp issues the copies).
        3. For each epilogue subtile (N index varies fastest):
           load the accumulator into ``tRS_rD``; if C is present, wait for the
           C subtile, copy it from shared memory into ``tSR_rC``, release the
           slot and prefetch the C subtile ``self.epi_c_stage`` ahead; call
           ``epi_visit_subtile`` to obtain the PostAct registers; copy D (if
           present) and PostAct into shared-memory buffer ``epi_buffer``; the
           TMA warp then issues the global stores.
        4. ``epi_end`` is called once for the tile.

        The PostAct global store is skipped when
        ``pid_m // cluster_shape_mnk[0] == pid_n // cluster_shape_mnk[1]``
        (see the "don't write twice" comment below).

        Returns:
            ``(epi_read_state, epi_producer_state)`` after being advanced.
        """
        has_C = const_expr(tRS_rC is not None)
        has_D = const_expr(copy_D is not None)

        tma_atom_postact = params.tma_atom_postact
        mPostAct_mnl = params.mPostAct_mnl
        sRowVec, sColVec, sPostAct = epi_smem_tensors
        # Register->shared store op for PostAct. The SM100 variant is built
        # from the TMEM-load tiled copy (tiled_copy_t2r); the SM90 one is not.
        get_smem_store_op = (
            partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
            if self.arch == 100
            else sm90_utils_og.sm90_get_smem_store_op
        )
        copy_atom_postact_r2s = get_smem_store_op(
            self.postact_layout, self.postact_dtype, self.acc_dtype
        )
        # Tile the PostAct register->shared copy like the D copy (make_tiled_copy_S)
        tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
        tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
        batch_idx = tile_coord_mnkl[3]
        # Global-memory store closure for PostAct, offset to this tile's batch
        copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
            tma_atom_postact,
            varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
            self.cta_tile_shape_postact_mn,
            params.epi_tile_postact,
            sPostAct,
            tile_coord_mnkl,
        )

        # We iterate over epi tiles in the N dimension first before the M dimension
        epi_tile_shape = cute.zipped_divide(
            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
        ).shape[1]
        # N has stride 1, so get_hier_coord(epi_idx) yields (m, n) with n fastest
        epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
        epi_tile_num = cute.size(epi_tile_shape)
        # Subtiles stored for tiles already executed by this scheduler; offsets
        # the epi_stage shared-memory buffer index so it continues across tiles.
        num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

        epi_tensors = self.epi_begin(
            params,
            epi_smem_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            epilogue_barrier,
            tidx,
        )

        # Prologue: prefetch the first min(epi_tile_num, epi_c_stage) C subtiles.
        # Only the TMA warp issues loads; every thread advances the producer
        # state (the advance is outside the is_tma_warp branch).
        if const_expr(copy_C is not None):
            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()

        for epi_idx in cutlass.range_constexpr(epi_tile_num):
            # The global memory coordinate for the current epi tile
            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
            # Copy from acc to D registers
            load_acc_subtile(tRS_rD, epi_idx)
            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
            if const_expr(has_C):
                # Wait for this C subtile, copy it shared->registers, then free its slot
                epi_pipeline.consumer_wait(epi_read_state)
                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                # Fence to make sure shared memory read is visible to TMA load
                cute.arch.fence_view_async_shared()
                cute.arch.sync_warp()
                with cute.arch.elect_one():
                    epi_pipeline.consumer_release(epi_read_state)
                epi_read_state.advance()
            # Keep the C pipeline full: prefetch the subtile epi_c_stage ahead
            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()
            # Per-subtile epilogue hook; its result is what gets stored to PostAct
            tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
            # TMA warp acquires a store buffer; the barrier holds all epilogue threads until then
            if is_tma_warp:
                epi_store_pipeline.producer_acquire()
            epilogue_barrier.arrive_and_wait()
            # Copy from D registers to shared memory
            epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
            if const_expr(has_D):
                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
            # Copy PostAct registers into the same buffer index of sPostAct
            cute.copy(
                tiled_copy_postact_r2s,
                tiled_copy_postact_r2s.retile(tRS_rPostAct),
                tRS_sPostAct[None, None, None, epi_buffer],
            )
            pid_m = tile_coord_mnkl[0]
            pid_n = tile_coord_mnkl[1]
            # Fence and barrier to make sure shared memory store is visible to TMA store
            cute.arch.fence_view_async_shared()
            epilogue_barrier.arrive_and_wait()
            # Copy from shared memory to global memory
            if is_tma_warp:
                # Tile coordinates at cluster granularity. Equal values presumably
                # mark the diagonal, where the mirrored PostAct store is skipped.
                square_tile_m = pid_m // self.cluster_shape_mnk[0]
                square_tile_n = pid_n // self.cluster_shape_mnk[1]
                if const_expr(has_D):
                    copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
                if square_tile_m != square_tile_n:  # don't write twice to the same tile
                    copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
                epi_store_pipeline.producer_commit()

        self.epi_end(
            params,
            epi_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            tidx,
        )

        return epi_read_state, epi_producer_state


class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90):
    """SM90 symmetric GEMM: ``GemmSm90`` plus the triangular-scheduler / mirrored-store epilogue mixin."""


class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
    """SM100-class symmetric GEMM: ``GemmSm100`` plus the triangular-scheduler / mirrored-store epilogue mixin."""


@lru_cache(maxsize=None)
def _compile_gemm_symmetric(
    a_dtype,
    b_dtype,
    d_dtype,
    c_dtype,
    c_major,
    postact_dtype,
    a_major,
    b_major,
    d_major,
    postact_major,
    tile_shape_mn,
    cluster_shape_mnk,
    pingpong,
    persistent,
    has_semaphore,
    alpha_mode,
    beta_mode,
    device_capacity,
):
    """Return the compiled symmetric-GEMM kernel for one configuration.

    Results are memoized twice: by ``lru_cache`` on the arguments and by
    ``cached_compile`` on a key made of ``"gemm_symmetric"`` followed by every
    argument in order. Fake tensors with symbolic extents stand in for the
    real operands; A, B, D, C and PostAct all share one symbolic ``m`` because
    the output is square.

    ``alpha_mode`` / ``beta_mode``: 0 -> None, 1 -> placeholder ``Float32``
    value, any other value -> ``Float32`` pointer in global memory.
    """
    config = (
        a_dtype,
        b_dtype,
        d_dtype,
        c_dtype,
        c_major,
        postact_dtype,
        a_major,
        b_major,
        d_major,
        postact_major,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        has_semaphore,
        alpha_mode,
        beta_mode,
        device_capacity,
    )
    gemm_cls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100

    # Symmetric GEMM: m == n, so one symbolic extent serves both output modes.
    dim = cute.sym_int()
    depth = cute.sym_int()
    batch = cute.sym_int()

    def _leading(major, inner):
        # 1 when the operand's major equals ``inner``, otherwise 0.
        return int(major == inner)

    m_a = fake_tensor(
        a_dtype,
        (dim, depth, batch),
        leading_dim=_leading(a_major, "k"),
        divisibility=div_for_dtype(a_dtype),
    )
    m_b = fake_tensor(
        b_dtype,
        (dim, depth, batch),
        leading_dim=_leading(b_major, "k"),
        divisibility=div_for_dtype(b_dtype),
    )
    m_d = fake_tensor(
        d_dtype,
        (dim, dim, batch),
        leading_dim=_leading(d_major, "n"),
        divisibility=div_for_dtype(d_dtype),
    )
    m_c = fake_tensor(
        c_dtype,
        (dim, dim, batch),
        leading_dim=_leading(c_major, "n"),
        divisibility=div_for_dtype(c_dtype) if c_dtype else 1,
    )
    # PostAct is D.mT at launch time, so its major is passed in separately.
    m_postact = fake_tensor(
        postact_dtype,
        (dim, dim, batch),
        leading_dim=_leading(postact_major, "n"),
        divisibility=div_for_dtype(postact_dtype),
    )

    def _fake_scalar(mode):
        if mode == 0:
            return None
        if mode == 1:
            return Float32(1.0)
        return make_ptr(Float32, 0, cute.AddressSpace.gmem, assumed_align=4)

    # Identity activation (act_fn_map[None]): PostAct is just the mirrored D.
    epi_args = gemm_cls.EpilogueArguments(
        m_postact,
        act_fn_map[None],
        alpha=_fake_scalar(alpha_mode),
        beta=_fake_scalar(beta_mode),
    )
    scheduler_args = make_fake_scheduler_args(has_semaphore, False, batch)

    def _build():
        return compile_gemm_kernel(
            gemm_cls,
            a_dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            pingpong,
            persistent,
            False,
            device_capacity,
            m_a,
            m_b,
            m_d,
            m_c,
            epi_args,
            scheduler_args,
            None,  # no varlen arguments
        )

    return cached_compile(("gemm_symmetric", *config), _build)


def gemm_symmetric(
    A: Tensor,  # (l, m, k)
    B: Tensor,  # (l, m, k)
    D: Optional[Tensor],  # (l, m, m)
    C: Optional[Tensor],  # (l, m, m)
    tile_count_semaphore: Optional[Tensor],  # (1,)
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
) -> None:
    # Transpose D so the "activation" is a write to the mirrored tile
    PostAct = D.mT

    A_p, B_p, D_p, C_p = perm3d(A, B, D, C)
    PostAct_p = PostAct.permute(1, 2, 0) if PostAct.ndim == 3 else PostAct
    a_major, b_major, d_major, c_major = get_majors(A_p, B_p, D_p, C_p)
    a_dtype, b_dtype, d_dtype, c_dtype = get_dtypes(A, B, D, C)
    postact_dtype = torch2cute_dtype_map[PostAct.dtype]
    # PostAct = D.mT has swapped major: if D is n-major, PostAct is m-major
    postact_major = "n" if PostAct_p.stride(1) == 1 else "m"

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10, 11], "Only SM90, SM100, and SM110 are supported"

    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    alpha_mode = 2 if isinstance(alpha, Tensor) else (1 if alpha != 1.0 else 0)
    beta_mode = 2 if isinstance(beta, Tensor) else (1 if beta != 1.0 else 0)

    compiled_fn = _compile_gemm_symmetric(
        a_dtype,
        b_dtype,
        d_dtype,
        c_dtype,
        c_major,
        postact_dtype,
        a_major,
        b_major,
        d_major,
        postact_major,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        alpha_mode,
        beta_mode,
        device_capacity,
    )

    from quack.cache_utils import COMPILE_ONLY

    if COMPILE_ONLY:
        return

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0

    def scalar_arg(scalar, mode):
        if mode == 0:
            return None
        elif mode == 1:
            return Float32(scalar)
        else:
            return scalar.data_ptr()

    epi_args = GemmActMixin.EpilogueArguments(
        PostAct_p,
        None,  # act_fn is Constexpr, baked in at compile time
        alpha=scalar_arg(alpha, alpha_mode),
        beta=scalar_arg(beta, beta_mode),
    )
    scheduler_args = make_scheduler_args(
        max_active_clusters,
        max_swizzle_size,
        tile_count_semaphore,
    )
    varlen_args = None

    if device_capacity[0] > 9:
        compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None)
    else:
        compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args)
