/*
 * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// IGMMA 64xNx32 TN with int32 Accumulator with A and B from SMEM
//
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
struct Igmma_int8_int32 {};

////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x64x32
////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Igmma_int8_int32<64> {
  static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[32]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8\n"
        "{\n"
        "   %0,  %1,  %2,  %3,  %4,  %5,  %6,  %7,\n"
        "   %8,  %9, %10, %11, %12, %13, %14, %15,\n"
        "  %16, %17, %18, %19, %20, %21, %22, %23,\n"
        "  %24, %25, %26, %27, %28, %29, %30, %31 \n"
        "},\n"
        "  %32, %33, 1;\n"

        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
          "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]),
          "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]),
          "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]),
          "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]),
          "+r"(acc[30]), "+r"(acc[31])
        : "l"(desc_a), "l"(desc_b));
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x128x32
////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Igmma_int8_int32<128> {
  static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[64]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8\n"
        "{\n"
        "   %0,  %1,  %2,  %3,  %4,  %5,  %6,  %7,\n"
        "   %8,  %9, %10, %11, %12, %13, %14, %15,\n"
        "  %16, %17, %18, %19, %20, %21, %22, %23,\n"
        "  %24, %25, %26, %27, %28, %29, %30, %31,\n"
        "  %32, %33, %34, %35, %36, %37, %38, %39,\n"
        "  %40, %41, %42, %43, %44, %45, %46, %47,\n"
        "  %48, %49, %50, %51, %52, %53, %54, %55,\n"
        "  %56, %57, %58, %59, %60, %61, %62, %63 \n"
        "},\n"
        "  %64, %65, 1;\n"

        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
          "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]),
          "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]),
          "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]),
          "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]),
          "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]),
          "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]),
          "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]),
          "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]),
          "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]),
          "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63])
        : "l"(desc_a), "l"(desc_b));
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x192x32
////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Igmma_int8_int32<192> {
  static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[96]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8\n"
        "{\n"
        "    %0,   %1,   %2,   %3,   %4,   %5,   %6,   %7,\n"
        "    %8,   %9,  %10,  %11,  %12,  %13,  %14,  %15,\n"
        "   %16,  %17,  %18,  %19,  %20,  %21,  %22,  %23,\n"
        "   %24,  %25,  %26,  %27,  %28,  %29,  %30,  %31,\n"
        "   %32,  %33,  %34,  %35,  %36,  %37,  %38,  %39,\n"
        "   %40,  %41,  %42,  %43,  %44,  %45,  %46,  %47,\n"
        "   %48,  %49,  %50,  %51,  %52,  %53,  %54,  %55,\n"
        "   %56,  %57,  %58,  %59,  %60,  %61,  %62,  %63,\n"
        "   %64,  %65,  %66,  %67,  %68,  %69,  %70,  %71,\n"
        "   %72,  %73,  %74,  %75,  %76,  %77,  %78,  %79,\n"
        "   %80,  %81,  %82,  %83,  %84,  %85,  %86,  %87,\n"
        "   %88,  %89,  %90,  %91,  %92,  %93,  %94,  %95 \n"
        "},\n"
        "   %96, %97, 1;\n"

        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
          "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]),
          "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]),
          "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]),
          "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]),
          "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]),
          "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]),
          "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]),
          "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]),
          "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]),
          "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]),
          "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]),
          "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]),
          "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]),
          "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]),
          "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95])
        : "l"(desc_a), "l"(desc_b));
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x256x32
////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Igmma_int8_int32<256> {
  static inline __device__ void mma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[128]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8\n"
        "{\n"
        "    %0,   %1,   %2,   %3,   %4,   %5,   %6,   %7,\n"
        "    %8,   %9,  %10,  %11,  %12,  %13,  %14,  %15,\n"
        "   %16,  %17,  %18,  %19,  %20,  %21,  %22,  %23,\n"
        "   %24,  %25,  %26,  %27,  %28,  %29,  %30,  %31,\n"
        "   %32,  %33,  %34,  %35,  %36,  %37,  %38,  %39,\n"
        "   %40,  %41,  %42,  %43,  %44,  %45,  %46,  %47,\n"
        "   %48,  %49,  %50,  %51,  %52,  %53,  %54,  %55,\n"
        "   %56,  %57,  %58,  %59,  %60,  %61,  %62,  %63,\n"
        "   %64,  %65,  %66,  %67,  %68,  %69,  %70,  %71,\n"
        "   %72,  %73,  %74,  %75,  %76,  %77,  %78,  %79,\n"
        "   %80,  %81,  %82,  %83,  %84,  %85,  %86,  %87,\n"
        "   %88,  %89,  %90,  %91,  %92,  %93,  %94,  %95,\n"
        "   %96,  %97,  %98,  %99, %100, %101, %102, %103,\n"
        "  %104, %105, %106, %107, %108, %109, %110, %111,\n"
        "  %112, %113, %114, %115, %116, %117, %118, %119,\n"
        "  %120, %121, %122, %123, %124, %125, %126, %127 \n"
        "},\n"
        "  %128, %129, 1;\n"

        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
          "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]),
          "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]),
          "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]),
          "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]),
          "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]),
          "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]),
          "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]),
          "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]),
          "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]),
          "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]),
          "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]),
          "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]),
          "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]),
          "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]),
          "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]),
          "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]),
          "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]),
          "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]),
          "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]),
          "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]),
          "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]),
          "+r"(acc[126]), "+r"(acc[127])
        : "l"(desc_a), "l"(desc_b));
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N, bool /*ignored*/>
inline __device__ void igmma_int8_int32(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[N / 2]) {
  Igmma_int8_int32<N>::mma(desc_a, desc_b, acc);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// IGMMA 64xNx32 TN with int32 Accumulator with A from RF and B from SMEM.
//
////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N>
struct Igmma_rfa_int8_int32 {};

////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x64x32
////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Igmma_rfa_int8_int32<64> {
  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[32]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8\n"
        "{\n"
        "   %0,  %1,  %2,  %3,  %4,  %5,  %6,  %7,\n"
        "   %8,  %9, %10, %11, %12, %13, %14, %15,\n"
        "  %16, %17, %18, %19, %20, %21, %22, %23,\n"
        "  %24, %25, %26, %27, %28, %29, %30, %31 \n"
        "},\n"
        "{ %32, %33, %34, %35 }, %36, 1;\n"

        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
          "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]),
          "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]),
          "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]),
          "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]),
          "+r"(acc[30]), "+r"(acc[31])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b));
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x128x32
////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Igmma_rfa_int8_int32<128> {
  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[64]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8\n"
        "{\n"
        "   %0,  %1,  %2,  %3,  %4,  %5,  %6,  %7,\n"
        "   %8,  %9, %10, %11, %12, %13, %14, %15,\n"
        "  %16, %17, %18, %19, %20, %21, %22, %23,\n"
        "  %24, %25, %26, %27, %28, %29, %30, %31,\n"
        "  %32, %33, %34, %35, %36, %37, %38, %39,\n"
        "  %40, %41, %42, %43, %44, %45, %46, %47,\n"
        "  %48, %49, %50, %51, %52, %53, %54, %55,\n"
        "  %56, %57, %58, %59, %60, %61, %62, %63 \n"
        "},\n"
        "{ %64, %65, %66, %67 }, %68, 1;\n"

        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
          "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]),
          "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]),
          "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]),
          "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]),
          "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]),
          "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]),
          "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]),
          "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]),
          "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]),
          "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b));
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x192x32
////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Igmma_rfa_int8_int32<192> {
  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[96]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8\n"
        "{\n"
        "    %0,   %1,   %2,   %3,   %4,   %5,   %6,   %7,\n"
        "    %8,   %9,  %10,  %11,  %12,  %13,  %14,  %15,\n"
        "   %16,  %17,  %18,  %19,  %20,  %21,  %22,  %23,\n"
        "   %24,  %25,  %26,  %27,  %28,  %29,  %30,  %31,\n"
        "   %32,  %33,  %34,  %35,  %36,  %37,  %38,  %39,\n"
        "   %40,  %41,  %42,  %43,  %44,  %45,  %46,  %47,\n"
        "   %48,  %49,  %50,  %51,  %52,  %53,  %54,  %55,\n"
        "   %56,  %57,  %58,  %59,  %60,  %61,  %62,  %63,\n"
        "   %64,  %65,  %66,  %67,  %68,  %69,  %70,  %71,\n"
        "   %72,  %73,  %74,  %75,  %76,  %77,  %78,  %79,\n"
        "   %80,  %81,  %82,  %83,  %84,  %85,  %86,  %87,\n"
        "   %88,  %89,  %90,  %91,  %92,  %93,  %94,  %95 \n"
        "},\n"
        "{ %96, %97, %98, %99 }, %100, 1;\n"

        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
          "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]),
          "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]),
          "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]),
          "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]),
          "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]),
          "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]),
          "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]),
          "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]),
          "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]),
          "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]),
          "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]),
          "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]),
          "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]),
          "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]),
          "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b));
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// 64x256x32
////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Igmma_rfa_int8_int32<256> {
  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[128]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8\n"
        "{\n"
        "    %0,   %1,   %2,   %3,   %4,   %5,   %6,   %7,\n"
        "    %8,   %9,  %10,  %11,  %12,  %13,  %14,  %15,\n"
        "   %16,  %17,  %18,  %19,  %20,  %21,  %22,  %23,\n"
        "   %24,  %25,  %26,  %27,  %28,  %29,  %30,  %31,\n"
        "   %32,  %33,  %34,  %35,  %36,  %37,  %38,  %39,\n"
        "   %40,  %41,  %42,  %43,  %44,  %45,  %46,  %47,\n"
        "   %48,  %49,  %50,  %51,  %52,  %53,  %54,  %55,\n"
        "   %56,  %57,  %58,  %59,  %60,  %61,  %62,  %63,\n"
        "   %64,  %65,  %66,  %67,  %68,  %69,  %70,  %71,\n"
        "   %72,  %73,  %74,  %75,  %76,  %77,  %78,  %79,\n"
        "   %80,  %81,  %82,  %83,  %84,  %85,  %86,  %87,\n"
        "   %88,  %89,  %90,  %91,  %92,  %93,  %94,  %95,\n"
        "   %96,  %97,  %98,  %99, %100, %101, %102, %103,\n"
        "  %104, %105, %106, %107, %108, %109, %110, %111,\n"
        "  %112, %113, %114, %115, %116, %117, %118, %119,\n"
        "  %120, %121, %122, %123, %124, %125, %126, %127 \n"
        "},\n"
        "{ %128, %129, %130, %131 }, %132, 1;\n"

        : "+r"(acc[0]), "+r"(acc[1]), "+r"(acc[2]), "+r"(acc[3]), "+r"(acc[4]), "+r"(acc[5]),
          "+r"(acc[6]), "+r"(acc[7]), "+r"(acc[8]), "+r"(acc[9]), "+r"(acc[10]), "+r"(acc[11]),
          "+r"(acc[12]), "+r"(acc[13]), "+r"(acc[14]), "+r"(acc[15]), "+r"(acc[16]), "+r"(acc[17]),
          "+r"(acc[18]), "+r"(acc[19]), "+r"(acc[20]), "+r"(acc[21]), "+r"(acc[22]), "+r"(acc[23]),
          "+r"(acc[24]), "+r"(acc[25]), "+r"(acc[26]), "+r"(acc[27]), "+r"(acc[28]), "+r"(acc[29]),
          "+r"(acc[30]), "+r"(acc[31]), "+r"(acc[32]), "+r"(acc[33]), "+r"(acc[34]), "+r"(acc[35]),
          "+r"(acc[36]), "+r"(acc[37]), "+r"(acc[38]), "+r"(acc[39]), "+r"(acc[40]), "+r"(acc[41]),
          "+r"(acc[42]), "+r"(acc[43]), "+r"(acc[44]), "+r"(acc[45]), "+r"(acc[46]), "+r"(acc[47]),
          "+r"(acc[48]), "+r"(acc[49]), "+r"(acc[50]), "+r"(acc[51]), "+r"(acc[52]), "+r"(acc[53]),
          "+r"(acc[54]), "+r"(acc[55]), "+r"(acc[56]), "+r"(acc[57]), "+r"(acc[58]), "+r"(acc[59]),
          "+r"(acc[60]), "+r"(acc[61]), "+r"(acc[62]), "+r"(acc[63]), "+r"(acc[64]), "+r"(acc[65]),
          "+r"(acc[66]), "+r"(acc[67]), "+r"(acc[68]), "+r"(acc[69]), "+r"(acc[70]), "+r"(acc[71]),
          "+r"(acc[72]), "+r"(acc[73]), "+r"(acc[74]), "+r"(acc[75]), "+r"(acc[76]), "+r"(acc[77]),
          "+r"(acc[78]), "+r"(acc[79]), "+r"(acc[80]), "+r"(acc[81]), "+r"(acc[82]), "+r"(acc[83]),
          "+r"(acc[84]), "+r"(acc[85]), "+r"(acc[86]), "+r"(acc[87]), "+r"(acc[88]), "+r"(acc[89]),
          "+r"(acc[90]), "+r"(acc[91]), "+r"(acc[92]), "+r"(acc[93]), "+r"(acc[94]), "+r"(acc[95]),
          "+r"(acc[96]), "+r"(acc[97]), "+r"(acc[98]), "+r"(acc[99]), "+r"(acc[100]),
          "+r"(acc[101]), "+r"(acc[102]), "+r"(acc[103]), "+r"(acc[104]), "+r"(acc[105]),
          "+r"(acc[106]), "+r"(acc[107]), "+r"(acc[108]), "+r"(acc[109]), "+r"(acc[110]),
          "+r"(acc[111]), "+r"(acc[112]), "+r"(acc[113]), "+r"(acc[114]), "+r"(acc[115]),
          "+r"(acc[116]), "+r"(acc[117]), "+r"(acc[118]), "+r"(acc[119]), "+r"(acc[120]),
          "+r"(acc[121]), "+r"(acc[122]), "+r"(acc[123]), "+r"(acc[124]), "+r"(acc[125]),
          "+r"(acc[126]), "+r"(acc[127])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b));
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int N, bool /*ignored*/>
inline __device__ void igmma_rfa_int8_int32(uint32_t const (&a)[4], uint64_t desc_b,
                                            uint32_t (&acc)[N / 2]) {
  Igmma_rfa_int8_int32<N>::mma(a, desc_b, acc);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
