diff --git a/cuda_bindings/cuda/bindings/_internal/utils.pyx b/cuda_bindings/cuda/bindings/_internal/utils.pyx index a56ef35357..879a10e621 100644 --- a/cuda_bindings/cuda/bindings/_internal/utils.pyx +++ b/cuda_bindings/cuda/bindings/_internal/utils.pyx @@ -120,7 +120,14 @@ cdef int get_nested_resource_ptr(nested_resource[ResT] &in_out_ptr, object obj, nested_ptr.reset(nested_vec, True) for i, obj_i in enumerate(obj): if ResT is char: - obj_i_bytes = ((obj_i)).encode() + obj_i_type = type(obj_i) + if obj_i_type is str: + obj_i_bytes = obj_i.encode("utf-8") + elif obj_i_type is bytes: + obj_i_bytes = obj_i + else: + raise TypeError( + f"Expected str or bytes, got {obj_i_type.__name__}") str_len = (len(obj_i_bytes)) + 1 # including null termination deref(nested_res_vec)[i].resize(str_len) obj_i_ptr = (obj_i_bytes) diff --git a/cuda_bindings/tests/test_nvjitlink.py b/cuda_bindings/tests/test_nvjitlink.py index 3bfeb8d35a..9aa39de1ac 100644 --- a/cuda_bindings/tests/test_nvjitlink.py +++ b/cuda_bindings/tests/test_nvjitlink.py @@ -99,6 +99,13 @@ def test_create_and_destroy(option): nvjitlink.destroy(handle) +@pytest.mark.parametrize("option", ARCHITECTURES) +def test_create_and_destroy_bytes_options(option): + handle = nvjitlink.create(1, [f"-arch={option}".encode()]) + assert handle != 0 + nvjitlink.destroy(handle) + + @pytest.mark.parametrize("option", ARCHITECTURES) def test_complete_empty(option): handle = nvjitlink.create(1, [f"-arch={option}"]) diff --git a/cuda_bindings/tests/test_nvvm.py b/cuda_bindings/tests/test_nvvm.py index 05fec9767d..51151333ab 100644 --- a/cuda_bindings/tests/test_nvvm.py +++ b/cuda_bindings/tests/test_nvvm.py @@ -115,7 +115,9 @@ def test_get_buffer_empty(get_size, get_buffer): assert buffer == b"\x00" -@pytest.mark.parametrize("options", [[], ["-opt=0"], ["-opt=3", "-g"]]) +@pytest.mark.parametrize( + "options", [[], ["-opt=0"], ["-opt=3", "-g"], [b"-opt=0"], [b"-opt=3", b"-g"], ["-opt=3", b"-g"]] +) def test_compile_program_with_minimal_nvvm_ir(minimal_nvvmir, options): # noqa: F401, F811 with nvvm_program() as prog: nvvm.add_module_to_program(prog, minimal_nvvmir, len(minimal_nvvmir), "FileNameHere.ll") @@ -135,7 +137,9 @@ def test_compile_program_with_minimal_nvvm_ir(minimal_nvvmir, options): # noqa: assert ".visible .entry kernel()" in buffer.decode() -@pytest.mark.parametrize("options", [[], ["-opt=0"], ["-opt=3", "-g"]]) +@pytest.mark.parametrize( + "options", [[], ["-opt=0"], ["-opt=3", "-g"], [b"-opt=0"], [b"-opt=3", b"-g"], ["-opt=3", b"-g"]] +) def test_verify_program_with_minimal_nvvm_ir(minimal_nvvmir, options): # noqa: F401, F811 with nvvm_program() as prog: nvvm.add_module_to_program(prog, minimal_nvvmir, len(minimal_nvvmir), "FileNameHere.ll")