// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
 https:github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
*/

__global__ void softmax_fwd(TYPE *X __readonly __noalias __aligned(16),
                            float scale,
                            int *LUT __readonly __noalias __aligned(16),
                            TYPE *RPE __readonly __noalias __aligned(16),
                            TYPE *KP_M __readonly __noalias __aligned(16),
                            TYPE *ATTN_M __readonly __noalias __aligned(16),
                            int num_blocks,
                            int sizemax,
                            long stride_zx __multipleof(BLOCK),
                            long stride_zrpe __multipleof(BLOCK),
                            int stride_hrpe __multipleof(BLOCK),
                            int stride_srpe __multipleof(BLOCK),
                            int stride_zkpm __multipleof(BLOCK),
                            int stride_zattnm __multipleof(BLOCK)){
  int pidhm = get_program_id(0);
  int pidz = get_program_id(1);

  // create index ranges
  int rxm     = pidhm % BLOCK;
  int rbm     = pidhm / BLOCK;
  int rxn[TN] = (0 ... TN) % BLOCK;
  int rbn[TN] = (0 ... TN) / BLOCK;

  // extract information from look-up table
  int* header = LUT + rbm * 2;
  int size    = *(header + 0);
  int offset  = *(header + 1);

  bool check[TN] = rbn < size;
  int   rbmn[TN] = check ? rbn : size - 1;

  // block id and column id
  long blockid [TN]  = *(LUT + offset + rbmn*4 + 0);
  long columnid[TN]  = *(LUT + offset + rbmn*4 + 1);
  long rowid   [TN]  = *(LUT + offset + rbmn*4 + 2);
  long headid  [TN]  = *(LUT + offset + rbmn*4 + 3);

  // pointers to X
  TYPE* px[TN]  = X + pidz * stride_zx
                    + blockid * BLOCK * BLOCK
                    + rxm * BLOCK
                    + rxn;
#ifdef APPLY_RPE
  // pointers to relative position embedding
  TYPE* prpe[TN] = RPE + pidz * stride_zrpe
                            + headid * stride_hrpe
                            + columnid * BLOCK
                            + rowid * BLOCK * stride_srpe
                            + rxm * stride_srpe
                            + rxn;
#endif

#ifdef APPLY_KP_MASK
  // pointers to key padding mask
  TYPE* pkp_m[TN]  = KP_M + pidz * stride_zkpm
                          + columnid * BLOCK
                          + rxn;
#endif

#ifdef APPLY_ATTN_MASK
  // pointers to attention mask
  TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
                             + rowid * BLOCK * stride_zattnm
                             + rxm * stride_zattnm
                             + rxn;
#endif

  // load  input
  TYPE x[TN] =  check ? *px : -INFINITY;

#ifdef APPLY_RPE
  // load relative position embedding
  TYPE rpe[TN] = check ? *prpe : 0;
#endif

#ifdef APPLY_KP_MASK
  // load key-padding mask
  TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif

#ifdef APPLY_ATTN_MASK
  // load attention mask
  TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif

  // compute softmax in float
#ifdef APPLY_RPE
  float Frpe[TN] = rpe;
#endif

#ifdef APPLY_KP_MASK
  float Fkp_m[TN] = kp_m;
#endif

#ifdef APPLY_ATTN_MASK
  float Fattn_m[TN] = attn_m;
#endif

#ifdef KP_MASK_MUL
  Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
#endif

#ifdef ATTN_MASK_MUL
  Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
#endif

  float Fx[TN] = x;

#ifdef APPLY_SCALE
  Fx = Fx * scale; // apply scale
#endif

#ifdef APPLY_RPE
  Fx = Fx + Frpe; // apply relative position embedding
#endif

#ifdef APPLY_KP_MASK
  Fx = Fx + Fkp_m; // apply key padding mask
#endif

#ifdef APPLY_ATTN_MASK
  Fx = Fx + Fattn_m; // apply attention mask
#endif

  float Fxmax  = Fx[max];
  float Fy[TN] = exp(Fx - Fxmax);
  float Fysum = (check ? Fy : 0)[+];

  // write-back in half/float
  TYPE y[TN] = Fy;
  TYPE ysum = Fysum;
  *?(check)px = y / ysum;
}
