# ruff: noqa: ANN001, ANN201, ERA001, N803, N806
import triton
import triton.language as tl


@triton.jit
def fused_add_round_kernel(
    x_ptr,
    output_ptr,  # contents will be added to the output
    seed,
    n_elements,
    EXPONENT_BIAS,
    MANTISSA_BITS,
    BLOCK_SIZE: tl.constexpr,
):
    """
    A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding
    and add them to bfloat16 output weights. Might be used to upcast original model weights
    and to further add them to precalculated deltas coming from LoRAs.
    """
    # Get program ID and compute offsets
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    rand_vals = tl.rand(seed, offsets) - 0.5

    x = tl.cast(x, tl.float16)
    delta = tl.load(output_ptr + offsets, mask=mask)
    delta = tl.cast(delta, tl.float16)
    x = x + delta

    x_bits = tl.cast(x, tl.int16, bitcast=True)

    # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
    # normal numbers and -14 for subnormals.
    fp16_exponent_bits = (x_bits & 0x7C00) >> 10
    fp16_normals = fp16_exponent_bits > 0
    fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)

    # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
    exponent = fp16_exponent + EXPONENT_BIAS
    MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
    exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
    exponent = tl.where(exponent < 0, 0, exponent)

    # Normal ULP exponent, expressed as an fp16 exponent field:
    # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
    # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
    # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
    eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))

    # Calculate epsilon in the target dtype
    eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)

    # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
    # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
    # 16 - EXPONENT_BIAS - MANTISSA_BITS
    eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
    eps = tl.where(exponent > 0, eps_normal, eps_subnormal)

    # Apply zero mask to epsilon
    eps = tl.where(x == 0, 0.0, eps)

    # Apply stochastic rounding
    output = tl.cast(x + rand_vals * eps, tl.bfloat16)

    # Store the result
    tl.store(output_ptr + offsets, output, mask=mask)
