// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
 https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
*/

// Fused block-sparse softmax backward.
//
// For each program instance this kernel loads TN elements of X and DX, computes
//     y = x * (dx - sum(x * dx)) * scale
// in float precision, and stores y back into DX in place.
//
// Parameters:
//   X          - input values, only read by this kernel
//                (presumably the forward softmax output -- confirm against the caller).
//   scale      - scalar multiplied into the result.
//   DX         - incoming gradient; read and then overwritten with the result, so it
//                must NOT be annotated __readonly.
//   LUT        - look-up table. For each block-row index `rbm` it holds a two-int header
//                {size, offset}; the block records start at `LUT + offset`, are 4 ints
//                wide, and the first int of each record is used as the block id.
//   sizemax    - not referenced in the kernel body.
//   stride_zx  - element stride applied to program id 1 when indexing X.
//   stride_zdx - element stride applied to program id 1 when indexing DX.
//
// TYPE, BLOCK and TN are compile-time definitions supplied from outside this file.
__global__ void softmax_bwd(TYPE * X __readonly __noalias __aligned(16),
                            float scale,
                            TYPE* DX __noalias __aligned(16),
                            int* LUT,
                            int sizemax,
                            long stride_zx __multipleof(BLOCK),
                            long stride_zdx __multipleof(BLOCK)) {
    int pidhm = get_program_id(0);
    int pidz = get_program_id(1);

    // create index ranges
    // program id 0 is split into an index inside a block (rxm) and a
    // block-row index (rbm) that selects the LUT header
    int rxm = pidhm % BLOCK;
    int rbm = pidhm / BLOCK;
    // the TN lanes are split the same way: position inside a block (rxn)
    // and index of the block within this row's LUT records (rbn)
    int rxn[TN] = (0 ... TN) % BLOCK;
    int rbn[TN] = (0 ... TN) / BLOCK;

    // extract information from look-up table
    int* header = LUT + rbm * 2;
    int size    = *(header + 0);
    int offset  = *(header + 1);

    // bounds checking on lut
    // lanes past the last record are clamped to the last valid record so the
    // LUT read below stays in range; `check` masks them out of loads and stores
    // NOTE(review): if size == 0 the clamped index is -1 and the unmasked LUT read
    // lands before `offset` -- presumably callers guarantee size >= 1; confirm.
    bool check[TN] = rbn < size;
    int rbmn[TN] = check ? rbn : size - 1;

    // initialize pointers to block-sparse input
    long blockid[TN] = *(LUT + offset + rbmn*4);

    TYPE* px[TN] = X + pidz * stride_zx
                         + blockid * BLOCK * BLOCK
                         + rxm * BLOCK
                         + rxn;

    TYPE* pdx[TN] = DX + pidz * stride_zdx
                           + blockid * BLOCK * BLOCK
                           + rxm * BLOCK
                           + rxn;

    // compute fused softmax backward
    // masked-out lanes read as 0 so they contribute nothing to the reduction
    TYPE x[TN] = check ? *px : 0;
    TYPE dx[TN] = check ? *pdx : 0;
    // arithmetic is carried out in float, then converted back to TYPE
    float Fdx[TN] = dx;
    float Fx[TN] = x;
    float Fxdx[TN] = Fdx*Fx;
    float Fxdxsum = Fxdx[+];
    float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
    TYPE y[TN] = Fy;

    // write-back
    // in-place store into DX, restricted to the lanes where `check` is true
    *? (check)pdx = y;
}
