/*
 * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

namespace fmha {

// Host-side parameters for ALiBi (attention with linear biases).
//
// Stores the head count rounded down to a power of two, the precomputed
// factor -4 / h_pow_2, a scale applied after the bias, and optional offsets
// for when the head / sequence dimensions are partitioned across ranks.
struct AlibiParams {
  // Returns the largest power of two that is <= x for x > 0 (and 0 for x == 0).
  // Works by smearing the highest set bit into every lower bit position
  // (shifts of 1, 2, 4, 8, 16 cover a 32-bit int), then keeping only the top bit.
  constexpr static int round_down_to_power_two(int x) {
    for (int shift = 1; shift <= 16; shift <<= 1) {
      x |= x >> shift;
    }
    return x - (x >> 1);
  }

  // All fields keep their in-class defaults (zero).
  AlibiParams() = default;

  // h: head count used to derive h_pow_2 and alibi_neg4_div_h.
  // scale_after_alibi: stored as-is (defaults to 1).
  // Members are initialized in declaration order, so h_pow_2 is ready before
  // alibi_neg4_div_h reads it.
  AlibiParams(int h, float scale_after_alibi = 1.f)
      : h_pow_2(round_down_to_power_two(h)),
        alibi_neg4_div_h(-4.0f / static_cast<float>(h_pow_2)),
        scale_after_alibi(scale_after_alibi) {}

  // Partitioned variant: the derived quantities are computed from the total
  // head count h * tp_size, and this rank's offsets are h * rank (heads) and
  // s * rank (sequence positions).
  // NOTE(review): presumably h / s are the per-rank head count and sequence
  // length under tensor parallelism — confirm against callers.
  AlibiParams(int h, int s, int tp_size, int rank, float scale_after_alibi = 1.f)
      : AlibiParams(h * tp_size, scale_after_alibi) {
    // A delegating constructor cannot also use mem-initializers, so the
    // offsets are assigned here.
    head_idx_offset = rank * h;
    sequence_pos_offset = rank * s;
  }

  // Member order and types are significant for the struct layout; keep as-is.
  int h_pow_2{};
  float alibi_neg4_div_h{};
  float scale_after_alibi{};
  // These offsets could instead be derived at runtime from a single `int rank`
  // (as `num_heads * rank` and `s * rank`). Storing them explicitly avoids
  // assuming how the work is partitioned downstream. For example, a split
  // along the head dimension only would leave s as the full sequence.
  int head_idx_offset = 0;
  int sequence_pos_offset = 0;
};

}  // namespace fmha
