// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */
// clang-format on

#ifndef CUTE_DSL_RUNTIME_H
#define CUTE_DSL_RUNTIME_H

#include <stddef.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

typedef struct BinaryModule CuteDSLRT_Module_t;
typedef struct BinaryFunction CuteDSLRT_Function_t;

typedef enum {
  // Indicates that the API call completed successfully without any errors.
  CuteDSLRT_Success = 0,
  CuteDSLRT_Error_Success = 0,
  // Indicates that a CUDA error occurred during the API call.
  CuteDSLRT_Error_CudaError = 1,
  // Indicates that the given binary bytes are illegal or invalid.
  CuteDSLRT_Error_InvalidBinary = 2,
  // Indicates that the binary module's metadata is invalid. This is likely due
  // to an internal error; please report this issue.
  CuteDSLRT_Error_InvalidMetadata = 3,
  // Indicates that the binary module's version is not compatible with the
  // current runtime.
  CuteDSLRT_Error_InvalidVersion = 4,
  // Indicates that the cuda library is already loaded. This is likely happened
  // when the executor is created twice.
  CuteDSLRT_Error_LibraryAlreadyLoaded = 5,
  // Indicates that the cuda kernel is not found. This is likely due to an
  // internal error; please report this issue.
  CuteDSLRT_Error_KernelNotFound = 6,
  // Indicates that the arguments of the API call are invalid. Likely NULL
  // pointer or wrong size.
  CuteDSLRT_Error_InvalidArguments = 7,
  // Indicates that no binary was loaded before execution.
  CuteDSLRT_Error_NoBinaryLoaded = 8,
  // Indicates that a binary was already loaded into the executor.
  CuteDSLRT_Error_BinaryAlreadyLoaded = 9,
} CuteDSLRT_Error_t;

/**
 * @brief Creates a new CUTE DSL runtime module from the given binary bytes.
 *
 * This function loads the provided binary module from the given bytes,
 * initializing all necessary runtime structures and CUDA libraries for
 * execution. The bytes could be derived from an object file dumped by python
 * side. The binary bytes are no longer needed after the module is created.
 * The returned pointer is an opaque handle to the module instance,
 * which could be destroyed with `CuteDSLRT_Module_Destroy` when no longer
 * needed.
 *
 * @param module_obj         Opaque pointer to the module instance.
 * @param binary_bytes       Pointer to the binary bytes.
 * @param binary_bytes_size  Size of the binary bytes.
 * @param shared_libs        Array of paths to shared libraries.
 * @param shared_libs_size   Size of the shared_libs array.
 * @return Error code due to runtime failure.
 */
__attribute__((visibility("default"))) CuteDSLRT_Error_t
CuteDSLRT_Module_Create_From_Bytes(CuteDSLRT_Module_t **module_obj,
                                   const unsigned char *binary_bytes,
                                   size_t binary_bytes_size,
                                   const char **shared_libs,
                                   size_t shared_libs_size);

/**
 * @brief Gets the function from the module using the provided function prefix.
 *
 * This function gets the function from the module using the provided function
 * prefix. The function prefix must be a valid function prefix.
 * If the function is not found, the function will return an error.
 *
 * @param func Opaque pointer to the function instance.
 * @param module_obj Opaque pointer to the module instance.
 * @param function_prefix Function prefix to get the function from the module.
 * @return Error code due to runtime failure.
 */
__attribute__((visibility("default"))) CuteDSLRT_Error_t
CuteDSLRT_Module_Get_Function(CuteDSLRT_Function_t **func,
                              CuteDSLRT_Module_t *module_obj,
                              const char *function_prefix);

/**
 * @brief Executes the compiled function using the provided arguments.
 *
 * This function launches the compiled function on the device using the
 * specified arguments and function handle. The arguments must match the
 * function signature expected by the compiled module. The function handle must
 * be valid and could be either the handle `CuteDSLRT_Function_t` or more
 * extendable handle in the future.
 * If execution fails, the function will raise an runtime error.
 *
 * @param func   Opaque pointer to the function.
 * @param args       Array of pointers to arguments to pass to the function.
 * @param num_args   Number of arguments in the args array.
 * @return Error code due to runtime failure.
 */
__attribute__((visibility("default"))) CuteDSLRT_Error_t
CuteDSLRT_Function_Run(void *func, void **args, size_t num_args);

/**
 * @brief Destroys a CUTE DSL runtime module instance.
 *
 * This function releases all resources associated with the module instance
 * previously created by `CuteDSLRT_Module_Create_xxx`. This includes
 * unloading any loaded CUDA modules and cleaning up all internal runtime
 * structures. After calling this function, the module pointer will be set to
 * nullptr. The `module_obj` handle must come from
 * `CuteDSLRT_Module_Create_xxx`.
 *
 * @param module_obj   Opaque pointer to the module instance to destroy.
 */
__attribute__((visibility("default"))) CuteDSLRT_Error_t
CuteDSLRT_Module_Destroy(CuteDSLRT_Module_t *module_obj);

/**
 * @brief Returns a string containing the name of an error code in the enum.
 *
 * If the error code is not recognized, "Unknown" is returned.
 *
 * @param error              The error code
 * @return char* pointer to a NULL-terminated string
 */
__attribute__((visibility("default"))) const char *
CuteDSLRT_GetErrorName(CuteDSLRT_Error_t error);

/**
 * @brief Returns a string containing a description of an error code in the
 * enum.
 *
 * If the error code is not recognized, "unrecognized error code" is returned.
 *
 * @param error              The error code
 * @return char* pointer to a NULL-terminated string
 */
__attribute__((visibility("default"))) const char *
CuteDSLRT_GetErrorString(CuteDSLRT_Error_t error);

#ifdef __cplusplus
}
#endif

#endif /*CUTE_DSL_RUNTIME_H*/
