import os
import sys
import subprocess
import threading
from numba import cuda
from numba.cuda.testing import (
    unittest,
    CUDATestCase,
    skip_on_cudasim,
    skip_under_cuda_memcheck,
)
from numba.tests.support import captured_stdout


class TestCudaDetect(CUDATestCase):
    """Smoke test for ``cuda.detect()``."""

    def test_cuda_detect(self):
        """Call ``cuda.detect()`` and check the text it prints."""
        # Capture whatever detect() writes to stdout while it runs.
        with captured_stdout() as stream:
            cuda.detect()
        printed = stream.getvalue()

        # Both fragments must appear somewhere in the printed report.
        for fragment in ("Found", "CUDA devices"):
            self.assertIn(fragment, printed)


@skip_under_cuda_memcheck("Hangs cuda-memcheck")
class TestCUDAFindLibs(CUDATestCase):
    """Exercise CUDA library lookup driven by an environment variable.

    The kernel is compiled and launched in a child interpreter so the
    modified environment only applies to that child process.
    """

    def run_cmd(self, cmdline, env):
        """Run ``cmdline`` in a subprocess and capture its output.

        Parameters
        ----------
        cmdline : list of str
            Command and arguments handed to ``subprocess.Popen``.
        env : dict
            Environment mapping for the child process.

        Returns
        -------
        tuple of (str, str)
            Decoded stdout and stderr of the child. The exit status is not
            inspected here; callers only receive the captured text.
        """
        popen = subprocess.Popen(
            cmdline, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )
        try:
            # finish in 5 minutes or kill it
            out, err = popen.communicate(timeout=5 * 60.0)
        except subprocess.TimeoutExpired:
            popen.kill()
            # Reap the killed child and collect whatever it produced.
            out, err = popen.communicate()
        return out.decode(), err.decode()

    def run_test_in_separate_process(self, envvar, envvar_value):
        """Compile and launch a trivial kernel in a child interpreter.

        Parameters
        ----------
        envvar : str
            Name of the environment variable to set for the child.
        envvar_value : object
            Value for ``envvar``; converted with ``str()``.

        Returns
        -------
        tuple of (str, str)
            Decoded stdout and stderr of the child, as given by ``run_cmd``.
        """
        env_copy = os.environ.copy()
        env_copy[envvar] = str(envvar_value)
        # The leading ``if 1:`` lets the snippet stay indented in this file.
        code = """if 1:
            from numba import cuda
            @cuda.jit('(int64,)')
            def kernel(x):
                pass
            kernel(1,)
            """
        cmdline = [sys.executable, "-c", code]
        return self.run_cmd(cmdline, env_copy)

    @skip_on_cudasim("Simulator does not hit device library search code path")
    @unittest.skipIf(not sys.platform.startswith("linux"), "linux only")
    def test_cuda_find_lib_errors(self):
        """
        This tests that the find_libs works as expected in the case of an
        environment variable being used to set the path.

        ``NUMBA_CUDA_DRIVER`` is pointed at an existing system directory and
        a kernel is launched in a child process. Only the presence of the
        captured stdout/stderr is asserted. The test is skipped when no
        candidate directory exists.
        """
        # one of these is likely to exist on linux, it's also unlikely that
        # someone has extracted the contents of libdevice into here!
        candidates = (
            os.path.join(os.path.sep, name) for name in ("lib", "lib64")
        )
        # Pick the first candidate that really exists, or None when none do.
        looking_for = next(
            (path for path in candidates if os.path.exists(path)), None
        )

        # This is the testing part, the test will only run if there's a valid
        # path in which to look
        if looking_for is None:
            self.skipTest("neither /lib nor /lib64 exists on this system")

        out, err = self.run_test_in_separate_process(
            "NUMBA_CUDA_DRIVER", looking_for
        )
        self.assertIsNotNone(out)
        self.assertIsNotNone(err)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
