// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
DeepSpeed note: code taken and adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a of
 https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
*/

/*
 * Block-sparse matrix-multiplication kernel template (legacy Triton-C).
 * Each program computes one TM x TN tile of C = A @ B.
 *
 * Template macros (not defined in this file; presumably supplied by the
 * host code that compiles it -- confirm against the caller):
 *   NAME, TYPE, TM, TN, TK, BLOCK, TZ,
 *   STRIDE_AM, STRIDE_AK, STRIDE_BN, STRIDE_BK, STRIDE_CM, STRIDE_CN.
 *
 * Mode selection: exactly one of SDD, DSD or DDS is expected to be defined.
 * Reading the letters as (C, A, B) matches the code below, where the 'S'
 * operand is addressed as BLOCK x BLOCK blocks located through `lut`:
 *   SDD: C is written block-wise (offset bkid * BLOCK * BLOCK); A and B dense.
 *   DSD: A is read block-wise (offset from lut * BLOCK * BLOCK); B and C dense.
 *   DDS: B is read block-wise (offset from lut * BLOCK * BLOCK); A and C dense.
 *
 * Parameters:
 *   A, B, C            operand / result base pointers.
 *   lda, ldb, ldc      not referenced in this body; strides come from the
 *                      STRIDE_* macros instead.
 *   stride_z{a,b,c}    per-operand offset multiplied by program id 2 (pidz).
 *   stride_h{a,b,c}    per-operand offset multiplied by an index read from
 *                      lut (z in SDD mode, depth otherwise).
 *   DS0                bound for masking the dense side: rows of A in DDS,
 *                      columns of B and C in DSD. DS1 is not referenced here.
 *   SDD_K              SDD only: reduction length, split across TZ programs
 *                      along program id 0.
 *   SDD_off_width      SDD only: offset added to program id 1 before
 *                      indexing lut.
 *   lut                look-up table describing the sparsity layout
 *                      (format differs per mode, see the prologue).
 *   locks, nlocks      lock words and counters used when several programs
 *                      accumulate into the same output tile (lockid != 0).
 */
__global__ void NAME (TYPE* A __readonly  __noalias __aligned(16),
                        TYPE* B __readonly  __noalias __aligned(16),
                        TYPE* C __noalias __aligned(16),
                        int lda __multipleof(8),
                        int ldb __multipleof(8),
                        int ldc __multipleof(8),
                        long stride_za __multipleof(8),
                        long stride_zb __multipleof(8),
                        long stride_zc __multipleof(8),
                        long stride_ha __multipleof(8),
                        long stride_hb __multipleof(8),
                        long stride_hc __multipleof(8),
                        int DS0, int DS1,
                        int SDD_K __multipleof(16),
                        int SDD_off_width,
                        int* lut, int* locks, int nlocks) {
    /* ---------------- */
    /*    Prologue      */
    /* ---------------- */
    // program ids
    int pid0 = get_program_id(0);
    int pid1 = get_program_id(1);
    int pidz = get_program_id(2);
#ifdef SDD
    // load LUT header
    // SDD layout: program id 1 (shifted by SDD_off_width) selects a segment of
    // (TM/BLOCK)*(TN/BLOCK) entries of 4 ints each. Entry (bm, bn) lives at
    // header + (bm*(TN/BLOCK) + bn)*4 and holds {z, i, j, bkid}:
    //   z    - read from the first entry only; used as the h-index of A and B,
    //   i, j - block-row of A / block-column of B for that sub-block,
    //   bkid - index of the output block in C (read in the epilogue).
    pid1 = pid1 + SDD_off_width;
    int blockidm[TM] = (0 ... TM) / BLOCK;
    int blockidn[TN] = (0 ... TN) / BLOCK;
    int offlutm[TM]  = blockidm*(TN/BLOCK)*4;
    int offlutn[TN]  = blockidn*4;
    int *header      = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
    int z            = *(header + 0);
    int i[TM]        = *(header + 1 + offlutm);
    int j[TN]        = *(header + 2 + offlutn);
    // Split-K: program id 0 handles the K-slice [pid0*AS1, (pid0+1)*AS1).
    // With TZ > 1, the TZ partial results are combined in the epilogue under
    // lock 1, and maxid = TZ is the number of expected contributions.
    int AS1 = SDD_K / TZ;
    int lockid = select(TZ > 1, 1, 0);
    int offka  = pid0 * AS1;
    int offkb  = pid0 * AS1;
    int offmc  = 0;
    int offnc  = 0;
    int offpa  = 0;
    int offpb  = 0;
    int maxid = TZ;
    int offhc = 0;
    int offha = z;
    int offhb = z;
    // Dense row/column indices: block index times BLOCK plus position in block.
    int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
    int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
#else
    // load LUT header
    // DSD/DDS layout: program id 0 selects a 6-int header
    //   {offset, AS1, column, depth, lockid, maxid}
    // where AS1 is this program's reduction length, column selects the output
    // tile on the sparse side, depth is the h-index, and lockid/maxid drive the
    // epilogue accumulation. lut + offset points at a list of int pairs
    // consumed two at a time (first pair here, the rest in the inner loop).
    int *header = lut + pid0 * 6;
    int offset = *(header + 0);
    int AS1    = *(header + 1);
    int column = *(header + 2);
    int depth  = *(header + 3);
    int lockid = *(header + 4);
    int maxid  = *(header + 5);
    int *pinc  = lut + offset;
    int offhc = depth;
#ifdef DSD
    // output offset
    // Program id 1 walks the dense N dimension; `column` picks the M tile.
    int offnc = pid1 * TN;
    int offmc = column * TM;
    int offpc = 0;
    // dense input offset
    // First pair: pinc[0] = starting K offset in B, pinc[1] = index of the
    // first A block (scaled to elements by BLOCK * BLOCK below).
    int offnb = pid1 * TN;
    int offkb __multipleof(8) = *pinc;
    int offpb = 0;
    // sparse input offset
    int offma = 0;
    int offka = 0;
    long offpa __multipleof(8) = *(pinc + 1);
    offpa = offpa * BLOCK * BLOCK;
    int offha = 0;
    int offhb = depth;
#endif
#ifdef DDS
    // output offset
    // Program id 1 walks the dense M dimension; `column` picks the N tile.
    int offmc = pid1 * TM;
    int offnc = column * TN;
    int offpc = 0;
    // dense input offset
    // First pair: pinc[0] = starting K offset in A, pinc[1] = index of the
    // first B block (scaled to elements by BLOCK * BLOCK below).
    int offma = pid1 * TM;
    int offka __multipleof(8) = *pinc;
    int offpa = 0;
    // sparse input offset
    int offnb = 0;
    int offkb = 0;
    long offpb __multipleof(8) = *(pinc + 1);
    offpb = offpb * BLOCK * BLOCK;
    int offha = depth;
    int offhb = 0;
#endif
    int ram[TM] = offma + 0 ... TM;
    int rbn[TN] = offnb + 0 ... TN;
#endif
    // initialize a, b pointers
    int rka[TK] = offka + 0 ... TK;
    int rkb[TK] = offkb + 0 ... TK;
    TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
    TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
    // pre-fetch
    // Masks on the non-K dimension: the dense side is bounded by DS0 where it
    // applies; otherwise loads are skipped entirely when AS1 == 0 (no work).
#ifdef DDS
    bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
    bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
    bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
    bool checkbn[TK, TN] = AS1 > 0;
#endif
    TYPE a[TM, TK] = checkam ? *pa : 0;
    TYPE b[TK, TN] = checkbn ? *pb : 0;

    /* ---------------- */
    /*    Inner Loop    */
    /* ---------------- */
    // create result tile
    // Accumulate in float regardless of TYPE; converted back to TYPE after the loop.
    float acc[TM, TN] = 0;
    int step = TK;
    // Software-pipelined: multiply the tiles loaded on the previous step,
    // advance the pointers, then prefetch the next tiles.
    for(int k = AS1; k > 0; k -= step) {
      acc += a @ b;
      // update pointers
#ifdef SDD
      int inc_a = TK * STRIDE_AK;
      int inc_b = TK * STRIDE_BK;
#else
      // Next pair from the LUT list. The dense-side value is scaled by its K
      // stride; the sparse-side value is added as-is.
      // NOTE(review): unlike the prologue, the sparse-side increment is not
      // multiplied by BLOCK * BLOCK, so presumably the LUT stores these
      // increments already in elements -- confirm against the LUT builder.
      pinc += 2;
#ifdef DSD
      int inc_b __multipleof(8) = *pinc;
      int inc_a __multipleof(8) = *(pinc + 1);
      inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
      int inc_a __multipleof(8) = *pinc;
      int inc_b __multipleof(8) = *(pinc + 1);
      inc_a = inc_a * STRIDE_AK;
#endif
#endif
      pa += inc_a;
      pb += inc_b;
      // pre-fetch
      // Only load when another iteration follows (k - TK > 0).
      bool checkak[TM, TK] = k > TK;
      bool checkbk[TK, TN] = k > TK;
      bool checka[TM, TK] = checkam && checkak;
      bool checkb[TK, TN] = checkbk && checkbn;
      a = *?(checka)pa;
      b = *?(checkb)pb;
    }
    TYPE c[TM, TN] = acc;

    /* ---------------- */
    /*    Epilogue      */
    /* ---------------- */
    // initialize c pointers
#ifdef SDD
    bool checkc[TM, TN] = 1;
    // rematerialize
    // Recompute the per-element LUT offsets (rather than keeping them live
    // through the loop) and read bkid, the 4th int of each LUT entry; each
    // output BLOCK x BLOCK block of C starts at bkid * BLOCK * BLOCK.
    int rr_blockidm[TM]  = (0 ... TM) / BLOCK;
    int rr_blockidn[TN]  = (0 ... TN) / BLOCK;
    int rr_offlutm[TM]   = rr_blockidm*(TN/BLOCK)*4;
    int rr_offlutn[TN]   = rr_blockidn*4;
    int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
    int bkid[TM, TN]     = *(header + off_bkid);
    long offpc[TM, TN]   = bkid * BLOCK * BLOCK;
    // range within blocks
    int   rcm[TM]    = (0 ... TM) % BLOCK;
    int   rcn[TN]    = (0 ... TN) % BLOCK;
#else
    int   rcm[TM]    = offmc + 0 ... TM;
    int   rcn[TN]    = offnc + 0 ... TN;
#ifdef DSD
    bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
    bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
    TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
    // write-back directly
    if(lockid == 0) {
      *?(checkc) pc = c;
    }
    // accumulate partial result using spin-locks
    // `locks` holds get_num_programs(2)*get_num_programs(1)*nlocks lock words,
    // followed by the same number of counters; the slot is chosen by
    // (program id 2, program id 1, lockid - 1). Both halves presumably need to
    // be zero-initialized by the caller (0 = unlocked / no contribution yet).
    else {
      int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
      int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
      // Spin until atomic_cas(plock, 0, 1) returns something other than 1,
      // i.e. until this program takes the lock.
      for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
      // The first contributor overwrites C; later ones add into it.
      int count = *pcount;
      if(count == 0)
        *?(checkc) pc = c;
      else
        *?(checkc) pc = c + *?(checkc)pc;
      // The counter wraps modulo maxid, so it returns to 0 after all maxid
      // contributions have been made.
      // NOTE(review): there is no explicit memory fence between the store to
      // pc and the lock release; correctness relies on the ordering semantics
      // of Triton-C's atomics -- confirm.
      atomic_xchg(pcount, (count + 1) % maxid);
      atomic_xchg(plock, 0);
    }
  }
