import numpy as np

from numba.cuda.testing import skip_on_cudasim, CUDATestCase
from numba import cuda, float64
import unittest


def kernel_func(x):
    """Store the constant 1 into the first element of ``x``.

    The tests in this file compile this function as a CUDA kernel with
    ``cuda.jit`` and inspect the PTX generated for it.
    """
    x[0] = 1


def device_func(x, y, z):
    """Return ``x * y + z``.

    The tests in this file compile this function as a CUDA device function
    and look for the ``fma.rn.f64`` instruction in the PTX generated for it.
    """
    return x * y + z


# Fragments of code that are removed from kernel_func's PTX when optimization
# is on. Previously this list was longer when kernel wrappers were used. If
# the test function were more complex, it might be possible to isolate
# additional fragments of PTX whose absence / presence we could check for, but
# removal of the use of local memory is a good indicator that optimization was
# applied.
removed_by_opt = ("__local_depot0",)


@skip_on_cudasim("Simulator does not optimize code")
class TestOptimization(CUDATestCase):
    """Check the effect of ``cuda.jit``'s ``opt`` flag on generated PTX.

    Kernels are compiled both eagerly (with an explicit signature) and lazily
    (on first launch), and a device function is compiled eagerly; each is
    built with the default settings and with ``opt=False``.
    """

    def _assert_opt_fragments(self, ptx, present):
        """Check every fragment of ``removed_by_opt`` against *ptx*.

        Each fragment is checked in its own subtest. It must occur in *ptx*
        when *present* is true and must not occur otherwise.
        """
        check = self.assertIn if present else self.assertNotIn
        for fragment in removed_by_opt:
            with self.subTest(fragment=fragment):
                check(fragment, ptx)

    @staticmethod
    def _launch_and_get_ptx(kernel):
        """Launch *kernel* once on a float64 array and return its PTX."""
        x = np.zeros(1, dtype=np.float64)
        kernel[1, 1](x)

        # Grab the PTX for the one definition that has just been jitted;
        # only the values of the mapping are needed, not its signature keys.
        return next(iter(kernel.inspect_asm().values()))

    def test_eager_opt(self):
        # Optimization should occur by default
        sig = (float64[::1],)
        kernel = cuda.jit(sig)(kernel_func)
        self._assert_opt_fragments(kernel.inspect_asm()[sig], present=False)

    def test_eager_noopt(self):
        # Optimization disabled
        sig = (float64[::1],)
        kernel = cuda.jit(sig, opt=False)(kernel_func)
        self._assert_opt_fragments(kernel.inspect_asm()[sig], present=True)

    def test_lazy_opt(self):
        # Optimization should occur by default
        kernel = cuda.jit(kernel_func)
        ptx = self._launch_and_get_ptx(kernel)
        self._assert_opt_fragments(ptx, present=False)

    def test_lazy_noopt(self):
        # Optimization disabled
        kernel = cuda.jit(opt=False)(kernel_func)
        ptx = self._launch_and_get_ptx(kernel)
        self._assert_opt_fragments(ptx, present=True)

    def test_device_opt(self):
        # Optimization should occur by default
        sig = (float64, float64, float64)
        device = cuda.jit(sig, device=True)(device_func)
        ptx = device.inspect_asm(sig)
        self.assertIn("fma.rn.f64", ptx)

    def test_device_noopt(self):
        # Optimization disabled
        sig = (float64, float64, float64)
        device = cuda.jit(sig, device=True, opt=False)(device_func)
        ptx = device.inspect_asm(sig)
        # Fused-multiply adds should be disabled when not optimizing
        self.assertNotIn("fma.rn.f64", ptx)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
