diff --git a/cuda_bindings/tests/test_nvfatbin.py b/cuda_bindings/tests/test_nvfatbin.py index 32c6e70f59..ec2f17ef95 100644 --- a/cuda_bindings/tests/test_nvfatbin.py +++ b/cuda_bindings/tests/test_nvfatbin.py @@ -121,8 +121,7 @@ def nvcc_smoke(tmpdir) -> str: return nvcc -@pytest.fixture -def CUBIN(arch): +def _build_cubin(arch): def CHECK_NVRTC(err): if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise RuntimeError(repr(err)) @@ -141,6 +140,11 @@ def CHECK_NVRTC(err): return cubin +@pytest.fixture +def CUBIN(arch): + return _build_cubin(arch) + + # create a valid LTOIR input for testing @pytest.fixture def LTOIR(arch): @@ -259,11 +263,11 @@ def test_nvfatbin_add_ptx(PTX, arch): nvfatbin.destroy(handle) -@pytest.mark.parametrize("arch", ["sm_80"], indirect=True) -def test_nvfatbin_add_cubin_ELF_SIZE_MISMATCH(CUBIN, arch): +def test_nvfatbin_add_cubin_ELF_SIZE_MISMATCH(): + cubin = _build_cubin("sm_80") handle = nvfatbin.create([], 0) with pytest.raises(nvfatbin.nvFatbinError, match="ERROR_ELF_ARCH_MISMATCH"): - nvfatbin.add_cubin(handle, CUBIN, len(CUBIN), "75", "inc") + nvfatbin.add_cubin(handle, cubin, len(cubin), "75", "inc") nvfatbin.destroy(handle) @@ -280,11 +284,11 @@ def test_nvfatbin_add_cubin(CUBIN, arch): nvfatbin.destroy(handle) -@pytest.mark.parametrize("arch", ["sm_80"], indirect=True) -def test_nvfatbin_add_cubin_ELF_ARCH_MISMATCH(CUBIN, arch): +def test_nvfatbin_add_cubin_ELF_ARCH_MISMATCH(): + cubin = _build_cubin("sm_80") handle = nvfatbin.create([], 0) with pytest.raises(nvfatbin.nvFatbinError, match="ERROR_ELF_ARCH_MISMATCH"): - nvfatbin.add_cubin(handle, CUBIN, len(CUBIN), "75", "inc") + nvfatbin.add_cubin(handle, cubin, len(cubin), "75", "inc") nvfatbin.destroy(handle) diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index 3d4059b696..54b3cc949b 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -148,7 +148,7 @@ def _cpu_array_samples(): return samples -@pytest.mark.parametrize("in_arr,", _cpu_array_samples()) +@pytest.mark.parametrize("in_arr", _cpu_array_samples()) class TestViewCPU: def test_args_viewable_as_strided_memory_cpu(self, in_arr): @args_viewable_as_strided_memory((0,))