from llvmlite import ir
from numba.core import ir as numba_ir
from numba.core import (
    cgutils,
    types,
    typing,
    funcdesc,
    config,
    compiler,
    sigutils,
    utils,
)
from numba.core.compiler import (
    sanitize_compile_result_entries,
    CompilerBase,
    DefaultPassBuilder,
    Flags,
    Option,
    CompileResult,
)
from numba.core.compiler_lock import global_compiler_lock
from numba.core.compiler_machinery import (
    FunctionPass,
    LoweringPass,
    PassManager,
    register_pass,
)
from numba.core.interpreter import Interpreter
from numba.core.errors import NumbaInvalidConfigWarning
from numba.core.untyped_passes import TranslateByteCode
from numba.core.typed_passes import (
    IRLegalization,
    NativeLowering,
    AnnotateTypes,
)
from warnings import warn
from numba.cuda import nvvmutils
from numba.cuda.api import get_current_device
from numba.cuda.codegen import ExternalCodeLibrary
from numba.cuda.cudadrv import nvvm
from numba.cuda.descriptor import cuda_target
from numba.cuda.target import CUDACABICallConv
from numba.cuda import lowering


def _nvvm_options_type(x):
    if x is None:
        return None

    else:
        assert isinstance(x, dict)
        return x


def _optional_int_type(x):
    if x is None:
        return None

    else:
        assert isinstance(x, int)
        return x


class CUDAFlags(Flags):
    """Compiler flags for the CUDA pipeline: the base ``Flags`` plus the
    CUDA-specific options declared below."""

    # Either None or a dict of NVVM options (checked by _nvvm_options_type).
    nvvm_options = Option(
        type=_nvvm_options_type,
        default=None,
        doc="NVVM options",
    )
    # Set from the ``cc`` argument of compile_cuda().
    compute_capability = Option(
        type=tuple,
        default=None,
        doc="Compute Capability",
    )
    # Either None or an int (checked by _optional_int_type).
    max_registers = Option(
        type=_optional_int_type, default=None, doc="Max registers"
    )
    lto = Option(type=bool, default=False, doc="Enable Link-time Optimization")


# The CUDACompileResult (CCR) has a specially-defined entry point equal to its
# id.  This is because the entry point is used as a key into a dict of
# overloads by the base dispatcher. The id of the CCR is the only small and
# unique property of a CompileResult in the CUDA target (cf. the CPU target,
# which uses its entry_point, which is a pointer value).
#
# This does feel a little hackish, and there are two ways in which this could
# be improved:
#
# 1. We could change the core of Numba so that each CompileResult has its own
#    unique ID that can be used as a key - e.g. a count, similar to the way in
#    which types have unique counts.
# 2. At some future time when kernel launch uses a compiled function, the entry
#    point will no longer need to be a synthetic value, but will instead be a
#    pointer to the compiled function as in the CPU target.


class CUDACompileResult(CompileResult):
    """A CompileResult whose entry point is a synthetic value: its own id."""

    @property
    def entry_point(self):
        # The id of this object stands in for a real entry point so that it
        # can serve as a small, unique key for this compile result.
        return id(self)


def cuda_compile_result(**entries):
    """Build a CUDACompileResult from keyword entries.

    The entries are first passed through ``sanitize_compile_result_entries``.
    """
    sanitized = sanitize_compile_result_entries(entries)
    return CUDACompileResult(**sanitized)


@register_pass(mutates_CFG=True, analysis_only=False)
class CUDABackend(LoweringPass):
    """Lowering pass that wraps the lowering output in a CUDA compile
    result."""

    _name = "cuda_backend"

    def __init__(self):
        LoweringPass.__init__(self)

    def run_pass(self, state):
        """
        Back-end: Packages lowering output in a compile result
        """
        # The lowering output currently stored under "cr" supplies the call
        # helper and function descriptor; it is replaced by the packaged
        # result below.
        lowering_output = state["cr"]
        sig = typing.signature(state.return_type, *state.args)

        entries = dict(
            typing_context=state.typingctx,
            target_context=state.targetctx,
            typing_error=state.status.fail_reason,
            type_annotation=state.type_annotation,
            library=state.library,
            call_helper=lowering_output.call_helper,
            signature=sig,
            fndesc=lowering_output.fndesc,
        )
        state.cr = cuda_compile_result(**entries)
        return True


@register_pass(mutates_CFG=False, analysis_only=False)
class CreateLibrary(LoweringPass):
    """
    Create a CUDACodeLibrary for the NativeLowering pass to populate. The
    NativeLowering pass will create a code library if none exists, but we need
    to set it up with nvvm_options from the flags if they are present.
    """

    _name = "create_library"

    def __init__(self):
        LoweringPass.__init__(self)

    def run_pass(self, state):
        """Create the library from the compiler flags and store it on the
        state as ``state.library``."""
        flags = state.flags
        state.library = state.targetctx.codegen().create_library(
            state.func_id.func_qualname,
            nvvm_options=flags.nvvm_options,
            max_registers=flags.max_registers,
            lto=flags.lto,
        )

        # Object caching is switched on immediately so that the library can
        # be serialized.
        state.library.enable_object_caching()
        return True


@register_pass(mutates_CFG=True, analysis_only=False)
class CUDANativeLowering(NativeLowering):
    """Lowering pass for a CUDA native function IR described solely in terms of
    Numba's standard `numba.core.ir` nodes."""

    _name = "cuda_native_lowering"

    @property
    def lowering_class(self):
        # Overrides the property to name the CUDA-specific lowering class.
        return lowering.CUDALower


class CUDABytecodeInterpreter(Interpreter):
    # Based on the superclass implementation, but names the resulting variable
    # "$bool<N>" instead of "bool<N>" - see Numba PR #9888:
    # https://github.com/numba/numba/pull/9888
    #
    # This can be removed once that PR is available in an upstream Numba
    # release.
    def _op_JUMP_IF(self, inst, pred, iftrue):
        # Map the sense of the jump onto its two possible destinations.
        destinations = {
            True: inst.get_jump_target(),
            False: inst.next,
        }
        taken = destinations[iftrue]
        not_taken = destinations[not iftrue]

        # Bind the ``bool`` builtin to a "$"-prefixed variable.
        bool_name = f"$bool{inst.offset}"
        self.store(
            value=numba_ir.Global("bool", bool, loc=self.loc), name=bool_name
        )

        # Call bool() on the predicate and store the result.
        bool_call = numba_ir.Expr.call(
            self.get(bool_name), (self.get(pred),), (), loc=self.loc
        )
        cond = self.store(value=bool_call, name=f"${inst.offset}pred")

        # Branch on the stored result.
        self.current_block.append(
            numba_ir.Branch(
                cond=cond, truebr=taken, falsebr=not_taken, loc=self.loc
            )
        )


@register_pass(mutates_CFG=True, analysis_only=False)
class CUDATranslateBytecode(FunctionPass):
    """Function pass that interprets the bytecode into Numba IR using
    CUDABytecodeInterpreter."""

    _name = "cuda_translate_bytecode"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Interpret ``state["bc"]`` and store the result in
        ``state["func_ir"]``."""
        func_id, bytecode = state["func_id"], state["bc"]
        state["func_ir"] = CUDABytecodeInterpreter(func_id).interpret(bytecode)
        return True


class CUDACompiler(CompilerBase):
    """Compiler whose single pipeline is the default untyped and typed passes
    (with bytecode translation swapped for the CUDA version) followed by the
    CUDA lowering passes."""

    def define_pipelines(self):
        """Return a one-element list holding the finalized "cuda" pass
        manager."""
        pm = PassManager("cuda")

        # Rather than replicating the whole untyped passes definition in
        # numba-cuda, it seems cleaner to take the pass list and replace the
        # TranslateBytecode pass with our own.
        untyped = DefaultPassBuilder.define_untyped_pipeline(self.state)
        pm.passes.extend(
            [
                (
                    CUDATranslateBytecode
                    if pass_cls is TranslateByteCode
                    else pass_cls,
                    description,
                )
                for pass_cls, description in untyped.passes
            ]
        )

        typed = DefaultPassBuilder.define_typed_pipeline(self.state)
        pm.passes.extend(typed.passes)

        cuda_lowering = self.define_cuda_lowering_pipeline(self.state)
        pm.passes.extend(cuda_lowering.passes)

        pm.finalize()
        return [pm]

    def define_cuda_lowering_pipeline(self, state):
        """Return a finalized pass manager holding the legalisation and
        lowering passes, in the order they are run."""
        pm = PassManager("cuda_lowering")

        stages = (
            # legalise
            (IRLegalization, "ensure IR is legal prior to lowering"),
            (AnnotateTypes, "annotate types"),
            # lower
            (CreateLibrary, "create library"),
            (CUDANativeLowering, "cuda native lowering"),
            (CUDABackend, "cuda backend"),
        )
        for pass_cls, description in stages:
            pm.add_pass(pass_cls, description)

        pm.finalize()
        return pm


@global_compiler_lock
def compile_cuda(
    pyfunc,
    return_type,
    args,
    debug=False,
    lineinfo=False,
    forceinline=False,
    fastmath=False,
    nvvm_options=None,
    cc=None,
    max_registers=None,
    lto=False,
):
    """Run the CUDA compilation pipeline on ``pyfunc`` and return the compile
    result, with its library finalized.

    :param pyfunc: The Python function to compile.
    :param return_type: The return type handed to the pipeline.
    :param args: The argument types handed to the pipeline.
    :param debug: Enables debug info, the "python" error model and extended
                  lifetimes for debugging.
    :param lineinfo: Enables debug info restricted to debug directives.
    :param forceinline: Sets the ``forceinline`` flag when true.
    :param fastmath: Sets the ``fastmath`` flag when true.
    :param nvvm_options: NVVM options stored on the flags when non-empty.
    :param cc: Compute capability; required.
    :param max_registers: Stored on the flags as ``max_registers``.
    :param lto: Stored on the flags as ``lto``.
    :raises ValueError: If ``cc`` is ``None``.
    """
    if cc is None:
        raise ValueError("Compute Capability must be supplied")

    from .descriptor import cuda_target

    typing_context = cuda_target.typing_context
    target_context = cuda_target.target_context

    flags = CUDAFlags()

    # Only lower to LLVM; no native code generation and no wrappers.
    flags.no_compile = True
    flags.no_cpython_wrapper = True
    flags.no_cfunc_wrapper = True

    # debug and lineinfo both switch on debug information, but they stay
    # separate arguments so that other behavior can hang off debug alone. In
    # particular, -opt=3 is not supported with debug enabled, and lineinfo on
    # its own should leave the error model untouched.
    if debug or lineinfo:
        flags.debuginfo = True
    if lineinfo:
        flags.dbg_directives_only = True

    flags.error_model = "python" if debug else "numpy"
    if debug:
        flags.dbg_extend_lifetimes = True

    # Optional flags are only assigned when requested, so they otherwise keep
    # their defaults.
    if forceinline:
        flags.forceinline = True
    if fastmath:
        flags.fastmath = True
    if nvvm_options:
        flags.nvvm_options = nvvm_options

    flags.compute_capability = cc
    flags.max_registers = max_registers
    flags.lto = lto

    # Run compilation pipeline
    from numba.core.target_extension import target_override

    with target_override("cuda"):
        cres = compiler.compile_extra(
            typingctx=typing_context,
            targetctx=target_context,
            func=pyfunc,
            args=args,
            return_type=return_type,
            flags=flags,
            locals={},
            pipeline_class=CUDACompiler,
        )

    cres.library.finalize()
    return cres


def cabi_wrap_function(
    context, lib, fndesc, wrapper_function_name, nvvm_options
):
    """
    Wrap a Numba ABI function in a C ABI wrapper at the NVVM IR level.

    The wrapper is named ``wrapper_function_name``. It is emitted into a new
    library that links to ``lib``; that new library is finalized and returned.
    """
    # The wrapper lives in its own library, which links to the library of the
    # function being wrapped.
    wrapper_lib = lib.codegen.create_library(
        f"{lib.name}_function_",
        entry_name=wrapper_function_name,
        nvvm_options=nvvm_options,
    )
    wrapper_lib.add_linking_library(lib)

    # Function types for the C ABI wrapper and the Numba ABI callee
    argtypes, restype = fndesc.argtypes, fndesc.restype
    c_abi_fnty = CUDACABICallConv(context).get_function_type(
        restype, argtypes
    )
    numba_abi_fnty = context.call_conv.get_function_type(restype, argtypes)

    # New module, with the callee declared first
    module = context.create_module("cuda.cabi.wrapper")
    callee = ir.Function(module, numba_abi_fnty, fndesc.llvm_func_name)

    # The wrapper body is a call to the callee followed by a return of the
    # callee's return value
    wrapper = ir.Function(module, c_abi_fnty, wrapper_function_name)
    builder = ir.IRBuilder(wrapper.append_basic_block(""))

    packer = context.get_arg_packer(argtypes)
    call_args = packer.from_arguments(builder, wrapper.args)

    # The call yields (status, return_value); the status is dropped because
    # it cannot be propagated through the C ABI anyway
    _status, result = context.call_conv.call_function(
        builder, callee, restype, argtypes, call_args
    )
    builder.ret(result)

    if config.DUMP_LLVM:
        utils.dump_llvm(fndesc, module)

    wrapper_lib.add_ir_module(module)
    wrapper_lib.finalize()
    return wrapper_lib


def kernel_fixup(kernel, debug):
    """Rewrite a lowered function in place so that it can serve as a kernel.

    The function's returns become ``ret void``, stores through its first
    argument (the return value pointer) are removed, that argument is dropped
    from the function type, argument list and debug type metadata, and the
    function is marked as a kernel for NVVM.

    :param kernel: The LLVM IR function to modify.
    :param debug: When true, an exception store helper is added to the
                  kernel's module and called with the returned status value
                  before each ``ret void``.
    """
    if debug:
        exc_helper = add_exception_store_helper(kernel)

    # Pass 1 - replace:
    #
    #    ret <value>
    #
    # with:
    #
    #    exc_helper(<value>)
    #    ret void

    for block in kernel.blocks:
        for inst in block.instructions:
            if isinstance(inst, ir.Ret):
                # The instruction popped is the block's last one; this relies
                # on the return being the final instruction of its block.
                old_ret = block.instructions.pop()
                block.terminator = None

                # The original return's metadata will be set on the new
                # instructions in order to preserve debug info
                metadata = old_ret.metadata

                builder = ir.IRBuilder(block)
                if debug:
                    status_code = old_ret.operands[0]
                    exc_helper_call = builder.call(exc_helper, (status_code,))
                    exc_helper_call.metadata = metadata

                new_ret = builder.ret_void()
                new_ret.metadata = metadata

                # Need to break out so we don't carry on modifying what we are
                # iterating over. There can only be one return in a block
                # anyway.
                break

    # Pass 2: remove stores to the return value argument pointer. Every store
    # whose destination is that argument is removed; the stored value is not
    # inspected.

    return_value = kernel.args[0]

    for block in kernel.blocks:
        # Collect first, then remove, so the instruction list is not modified
        # while it is being iterated.
        stores_to_return_value = [
            inst
            for inst in block.instructions
            if isinstance(inst, ir.StoreInstr)
            and inst.operands[1] == return_value
        ]

        for store in stores_to_return_value:
            block.instructions.remove(store)

    # Replace non-void return type with void return type and remove return
    # value

    if isinstance(kernel.type, ir.PointerType):
        new_type = ir.PointerType(
            ir.FunctionType(ir.VoidType(), kernel.type.pointee.args[1:])
        )
    else:
        new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:])

    kernel.type = new_type
    kernel.return_value = ir.ReturnValue(kernel, ir.VoidType())
    kernel.args = kernel.args[1:]

    # If debug metadata is present, remove the return value from it

    if kernel_metadata := getattr(kernel, "metadata", None):
        if dbg_metadata := kernel_metadata.get("dbg", None):
            for name, value in dbg_metadata.operands:
                if name == "type":
                    for tm_name, tm_value in value.operands:
                        if tm_name == "types":
                            # Drop the first entry of the "types" operands.
                            tm_value.operands = tm_value.operands[1:]
                            if config.DUMP_LLVM:
                                tm_value._clear_string_cache()

    # Mark as a kernel for NVVM

    nvvm.set_cuda_kernel(kernel)

    if config.DUMP_LLVM:
        print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, "-"))
        print(kernel.module)
        print("=" * 80)


def add_exception_store_helper(kernel):
    """Add an exception store helper function to ``kernel``'s module and
    return it.

    Seven 32-bit integer globals are defined alongside it, each named with
    the kernel's name as a prefix: one for the error code and one each for
    the x, y and z components of the thread and block indices. The helper
    takes a 32-bit status value; when the status is neither OK nor a Python
    exception, it stores the status code in the error code global if no
    error has been recorded yet, and then records the indices.
    """
    module = kernel.module
    i32 = ir.IntType(32)

    # Create global variables for exception state

    def define_state_gv(suffix):
        gv = cgutils.add_global_variable(module, i32, kernel.name + suffix)
        gv.initializer = ir.Constant(gv.type.pointee, None)
        return gv

    errcode_gv = define_state_gv("__errcode__")
    tid_gvs, ctaid_gvs = [], []
    for dim in "xyz":
        tid_gvs.append(define_state_gv(f"__tid{dim}__"))
        ctaid_gvs.append(define_state_gv(f"__ctaid{dim}__"))

    # Create exception store helper function

    helper = ir.Function(
        module,
        ir.FunctionType(ir.VoidType(), (i32,)),
        kernel.name + "__exc_helper__",
    )
    builder = ir.IRBuilder(helper.append_basic_block(name="entry"))

    # Implement status check / exception store logic

    call_conv = cuda_target.target_context.call_conv
    status = call_conv._get_return_status(builder, helper.args[0])

    # Return straight away when the status is OK
    with cgutils.if_likely(builder, status.is_ok):
        builder.ret_void()

    with builder.if_then(builder.not_(status.is_python_exc)):
        # User exception raised
        expected = ir.Constant(errcode_gv.type.pointee, None)

        # An atomic cmpxchg stops a recorded error status from being
        # rewritten, so only the first error is kept

        exchange = builder.cmpxchg(
            errcode_gv, expected, status.code, "monotonic", "monotonic"
        )
        exchanged = builder.extract_value(exchange, 1)

        # When the exchange succeeded, save the thread and block indices
        sreg = nvvmutils.SRegBuilder(builder)
        with builder.if_then(exchanged):
            for dim, gv in zip("xyz", tid_gvs):
                builder.store(sreg.tid(dim), gv)

            for dim, gv in zip("xyz", ctaid_gvs):
                builder.store(sreg.ctaid(dim), gv)

    builder.ret_void()

    return helper


@global_compiler_lock
def compile(
    pyfunc,
    sig,
    debug=None,
    lineinfo=False,
    device=True,
    fastmath=False,
    cc=None,
    opt=None,
    abi="c",
    abi_info=None,
    output="ptx",
    forceinline=False,
    launch_bounds=None,
):
    """Compile a Python function to PTX or LTO-IR for a given set of argument
    types.

    :param pyfunc: The Python function to compile.
    :param sig: The signature representing the function's input and output
                types. If this is a tuple of argument types without a return
                type, the inferred return type is returned by this function. If
                a signature including a return type is passed, the compiled code
                will include a cast from the inferred return type to the
                specified return type, and this function will return the
                specified return type.
    :param debug: Whether to include debug info in the compiled code. If
                  ``None``, the value of ``config.CUDA_DEBUGINFO_DEFAULT`` is
                  used.
    :type debug: bool
    :param lineinfo: Whether to include a line mapping from the compiled code
                     to the source code. Usually this is used with optimized
                     code (since debug mode would automatically include this),
                     so we want debug info in the LLVM IR but only the line
                     mapping in the final output.
    :type lineinfo: bool
    :param device: Whether to compile a device function.
    :type device: bool
    :param fastmath: Whether to enable fast math flags (ftz=1, prec_sqrt=0,
                     prec_div=0, and fma=1)
    :type fastmath: bool
    :param cc: Compute capability to compile for, as a tuple
               ``(MAJOR, MINOR)``. If not supplied, defaults to the greater of
               ``config.CUDA_DEFAULT_PTX_CC`` and ``nvvm.LOWEST_CURRENT_CC``.
    :type cc: tuple
    :param opt: Whether to enable optimizations in the compiled code. If
                ``None``, optimizations are enabled unless ``config.OPT`` is
                0.
    :type opt: bool
    :param abi: The ABI for a compiled function - either ``"numba"`` or
                ``"c"``. Note that the Numba ABI is not considered stable.
                The C ABI is only supported for device functions at present.
    :type abi: str
    :param abi_info: A dict of ABI-specific options. The ``"c"`` ABI supports
                     one option, ``"abi_name"``, for providing the wrapper
                     function's name; if it is absent, the name of ``pyfunc``
                     is used. The ``"numba"`` ABI has no options.
    :type abi_info: dict
    :param output: Type of output to generate, either ``"ptx"`` or ``"ltoir"``.
    :type output: str
    :param forceinline: Enables inlining at the NVVM IR level when set to
                        ``True``. This is accomplished by adding the
                        ``alwaysinline`` function attribute to the function
                        definition. This is only valid when the output is
                        ``"ltoir"``.
    :type forceinline: bool
    :param launch_bounds: Kernel launch bounds, specified as a scalar or a tuple
                          of between one and three items. Tuple items provide:

                          - The maximum number of threads per block,
                          - The minimum number of blocks per SM,
                          - The maximum number of blocks per cluster.

                          If a scalar is provided, it is used as the maximum
                          number of threads per block. Launch bounds are only
                          applied when compiling a kernel (``device=False``).
    :type launch_bounds: int | tuple[int]
    :return: (code, resty): The compiled code and inferred return type
    :rtype: tuple
    :raises NotImplementedError: If ``abi`` or ``output`` is not one of the
                                 supported values, or if the C ABI is requested
                                 for a kernel.
    :raises ValueError: If ``forceinline`` is requested for a kernel, or for
                        an output other than ``"ltoir"``.
    :raises TypeError: If a kernel is compiled and its return type is not
                       void.
    """
    # Validate the requested combination of options before doing any work.
    if abi not in ("numba", "c"):
        raise NotImplementedError(f"Unsupported ABI: {abi}")

    if abi == "c" and not device:
        raise NotImplementedError("The C ABI is not supported for kernels")

    if output not in ("ptx", "ltoir"):
        raise NotImplementedError(f"Unsupported output type: {output}")

    if forceinline and not device:
        raise ValueError("Cannot force-inline kernels")

    if forceinline and output != "ltoir":
        raise ValueError("Can only designate forced inlining in LTO-IR")

    # Fall back to the configuration for options left as None.
    debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
    opt = (config.OPT != 0) if opt is None else opt

    # This combination only produces a warning; compilation still proceeds.
    if debug and opt:
        msg = (
            "debug=True with opt=True "
            "is not supported by CUDA. This may result in a crash"
            " - set debug=False or opt=False."
        )
        warn(NumbaInvalidConfigWarning(msg))

    lto = output == "ltoir"
    abi_info = abi_info or dict()

    nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0}

    if debug:
        nvvm_options["g"] = None

    if lto:
        nvvm_options["gen-lto"] = None

    args, return_type = sigutils.normalize_signature(sig)

    # If the user has used the config variable to specify a non-default that is
    # greater than the lowest non-deprecated one, then we should default to
    # their specified CC instead of the lowest non-deprecated one.
    MIN_CC = max(config.CUDA_DEFAULT_PTX_CC, nvvm.LOWEST_CURRENT_CC)
    cc = cc or MIN_CC

    cres = compile_cuda(
        pyfunc,
        return_type,
        args,
        debug=debug,
        lineinfo=lineinfo,
        fastmath=fastmath,
        nvvm_options=nvvm_options,
        cc=cc,
        forceinline=forceinline,
    )
    resty = cres.signature.return_type

    if resty and not device and resty != types.void:
        raise TypeError("CUDA kernel must have void return type.")

    tgt = cres.target_context

    if device:
        lib = cres.library
        if abi == "c":
            # For the C ABI, the code is generated from the wrapper library
            # rather than from the compiled function's own library.
            wrapper_name = abi_info.get("abi_name", pyfunc.__name__)
            lib = cabi_wrap_function(
                tgt, lib, cres.fndesc, wrapper_name, nvvm_options
            )
    else:
        # Kernels are rewritten in place in the compiled library.
        lib = cres.library
        kernel = lib.get_function(cres.fndesc.llvm_func_name)
        lib._entry_name = cres.fndesc.llvm_func_name
        kernel_fixup(kernel, debug)
        nvvm.set_launch_bounds(kernel, launch_bounds)

    if lto:
        code = lib.get_ltoir(cc=cc)
    else:
        code = lib.get_asm_str(cc=cc)
    return code, resty


def compile_for_current_device(
    pyfunc,
    sig,
    debug=None,
    lineinfo=False,
    device=True,
    fastmath=False,
    opt=None,
    abi="c",
    abi_info=None,
    output="ptx",
    forceinline=False,
    launch_bounds=None,
):
    """Compile a Python function to PTX or LTO-IR for a given signature for the
    current device's compute capability. This calls :func:`compile` with the
    ``cc`` value of the current device and passes every other argument
    through unchanged."""
    options = dict(
        debug=debug,
        lineinfo=lineinfo,
        device=device,
        fastmath=fastmath,
        cc=get_current_device().compute_capability,
        opt=opt,
        abi=abi,
        abi_info=abi_info,
        output=output,
        forceinline=forceinline,
        launch_bounds=launch_bounds,
    )
    return compile(pyfunc, sig, **options)


def compile_ptx(
    pyfunc,
    sig,
    debug=None,
    lineinfo=False,
    device=False,
    fastmath=False,
    cc=None,
    opt=None,
    abi="numba",
    abi_info=None,
    forceinline=False,
    launch_bounds=None,
):
    """Compile a Python function to PTX for a given signature. See
    :func:`compile`, which this calls with ``output="ptx"``. Unlike
    :func:`compile`, whose defaults compile a device function with the C ABI,
    the defaults here compile a kernel with the Numba ABI."""
    options = dict(
        debug=debug,
        lineinfo=lineinfo,
        device=device,
        fastmath=fastmath,
        cc=cc,
        opt=opt,
        abi=abi,
        abi_info=abi_info,
        output="ptx",
        forceinline=forceinline,
        launch_bounds=launch_bounds,
    )
    return compile(pyfunc, sig, **options)


def compile_ptx_for_current_device(
    pyfunc,
    sig,
    debug=None,
    lineinfo=False,
    device=False,
    fastmath=False,
    opt=None,
    abi="numba",
    abi_info=None,
    forceinline=False,
    launch_bounds=None,
):
    """Compile a Python function to PTX for a given signature for the current
    device's compute capability. See :func:`compile_ptx`, which this calls
    with the ``cc`` value of the current device."""
    options = dict(
        debug=debug,
        lineinfo=lineinfo,
        device=device,
        fastmath=fastmath,
        cc=get_current_device().compute_capability,
        opt=opt,
        abi=abi,
        abi_info=abi_info,
        forceinline=forceinline,
        launch_bounds=launch_bounds,
    )
    return compile_ptx(pyfunc, sig, **options)


def declare_device_function(name, restype, argtypes, link, use_cooperative):
    """Register an external device function with the CUDA typing and target
    contexts and return its typing template.

    :param name: Name of the external function.
    :param restype: Return type of the function.
    :param argtypes: Argument types of the function.
    :param link: Iterable of files added to the function's external code
                 library as linking files.
    :param use_cooperative: Stored on the external code library as
                            ``use_cooperative``.
    :return: The concrete typing template created for the function.
    """
    from .descriptor import cuda_target

    typing_context = cuda_target.typing_context
    target_context = cuda_target.target_context
    signature = typing.signature(restype, *argtypes)

    # The ExternFunction object is what Python code calls, and it is the key
    # under which both the typing and the lowering are registered.
    extern = ExternFunction(name, signature)

    # Typing
    template = typing.make_concrete_template(name, extern, [signature])
    typing_context.insert_user_function(extern, template)

    # Lowering
    externals = ExternalCodeLibrary(
        f"{name}_externals", target_context.codegen()
    )
    for path in link:
        externals.add_linking_file(path)
    externals.use_cooperative = use_cooperative

    # The lowering implementation for calling the external function comes
    # from ExternalFunctionDescriptor
    descriptor = funcdesc.ExternalFunctionDescriptor(name, restype, argtypes)
    target_context.insert_user_function(extern, descriptor, libs=(externals,))

    return template


class ExternFunction:
    """A descriptor that can be used to call the external function from within
    a Python kernel."""

    def __init__(self, name, sig):
        # Keep the function's name and signature as plain attributes.
        self.name, self.sig = name, sig
