Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/matrix.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ workflows:
# args: '--preset libcudacxx --lit-tests "cuda/utility/basic_any.pass.cpp"' }
#
override:

- {jobs: ['test'], project: 'cudax', std: 'max', gpu: 'h100' }
pull_request:
# Old CTK: Oldest/newest supported host compilers:
- {jobs: ['build'], std: 'minmax', ctk: '12.0', cxx: ['gcc7', 'gcc12', 'clang14', 'msvc2019', 'msvc14.39']}
Expand Down
48 changes: 48 additions & 0 deletions cudax/include/cuda/experimental/__library/library_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cuda/__runtime/ensure_current_context.h>
#include <cuda/std/__cstddef/types.h>
#include <cuda/std/__exception/cuda_error.h>
#include <cuda/std/__type_traits/is_function.h>

#include <cuda/std/__cccl/prologue.h>

Expand Down Expand Up @@ -197,6 +198,53 @@ public:
return library_symbol_info{reinterpret_cast<void*>(__dptr), __size};
}

//! @brief Checks if the library contains a unified function with the given name
//!
//! @param __name The name of the unified function to check for
//!
//! @return true if the library contains a unified function with the given name, false otherwise
//!
//! @throws cuda_error if the library could not be queried for the unified function
//!
//! @note Unified functions require PTX target at least 9.0
[[nodiscard]] bool has_unified_function(const char* __name) const
{
  // The driver call requires an out parameter; the address itself is not needed here.
  void* __discarded_addr{};
  const auto __status = ::cuda::__driver::__libraryGetUnifiedFunctionNoThrow(__discarded_addr, __library_, __name);
  if (__status == ::cudaSuccess)
  {
    return true;
  }
  if (__status == ::cudaErrorSymbolNotFound)
  {
    // A missing symbol is the expected "not present" answer, not an error.
    return false;
  }
  // Any other status indicates the query itself failed.
  ::cuda::__throw_cuda_error(__status, "Failed to get the unified function from library");
}

//! @brief Gets a pointer to a unified function from the library
//!
//! @tparam _Signature The function signature of the unified function, in the form `R(Args...)`
//!
//! @param __name The name of the unified function to retrieve
//!
//! @return A pointer to the unified function, cast to `_Signature*`
//!
//! @throws cuda_error if the unified function could not be found in the library
//!
//! @note Unified functions require PTX target at least 9.0
template <class _Signature>
[[nodiscard]] _Signature* unified_function(const char* __name) const
{
  static_assert(::cuda::std::is_function_v<_Signature>,
                "_Signature must be a function signature in form of R(Args...).");

  void* __fn_ptr{};
  if (const auto __res = ::cuda::__driver::__libraryGetUnifiedFunctionNoThrow(__fn_ptr, __library_, __name);
      __res != ::cudaSuccess)
  {
    ::cuda::__throw_cuda_error(__res, "Failed to get the unified function from the library");
  }
  // The driver hands back an opaque address; the caller-supplied signature gives it a callable type.
  return reinterpret_cast<_Signature*>(__fn_ptr);
}

//! @brief Gets the CUlibrary handle
//!
//! @return The CUlibrary handle wrapped by this `library_ref`
Expand Down
8 changes: 8 additions & 0 deletions cudax/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,18 @@ cudax_add_catch2_test(test_target kernel
cudax_add_catch2_test(test_target library_ref
library/library_ref.cu
)
# NOTE(review): presumably the test compiles PTX / links CUBINs at run time,
# hence the PTX compiler and nvJitLink dependencies — confirm against the test source.
target_link_libraries(
${test_target}
PRIVATE CUDA::nvptxcompiler_static CUDA::nvJitLink
)

cudax_add_catch2_test(test_target library
library/library.cu
)
# NOTE(review): same run-time compilation/linking dependencies as the
# library_ref test above — confirm both stay in sync.
target_link_libraries(
${test_target}
PRIVATE CUDA::nvptxcompiler_static CUDA::nvJitLink
)

cudax_add_catch2_test(test_target algorithm
algorithm/fill.cu
Expand Down
129 changes: 45 additions & 84 deletions cudax/test/library/library.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

Expand All @@ -21,94 +21,21 @@

#include <testing.cuh>

// extern "C" __constant__ int const_data;
//
// extern "C" __device__ int global_data;
//
// extern "C" __managed__ int managed_data;
//
// extern "C" __global__ void kernel(int* array, int n)
// {
// __shared__ int shared[32];
// int tid = blockDim.x * blockIdx.x + threadIdx.x;
// if (tid < n)
// {
// shared[threadIdx.x] = array[tid];
// __syncthreads();
// array[tid] = shared[threadIdx.x + 1 % 32] + const_data;
// }
// }

constexpr char library_src[] = R"(
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-32267302
// Cuda compilation tools, release 12.0, V12.0.140
// Based on NVVM 7.0.1
//

.version 8.0
.target sm_75
.address_size 64

// .globl kernel
.const .align 4 .u32 const_data;
.global .align 4 .u32 global_data;
.global .attribute(.managed) .align 4 .u32 managed_data;
// _ZZ6kernelE6shared has been demoted

.visible .entry kernel(
.param .u64 kernel_param_0,
.param .u32 kernel_param_1
)
{
.reg .pred %p<2>;
.reg .b32 %r<13>;
.reg .b64 %rd<5>;
// demoted variable
.shared .align 4 .b8 _ZZ6kernelE6shared[128];

ld.param.u64 %rd1, [kernel_param_0];
ld.param.u32 %r3, [kernel_param_1];
mov.u32 %r4, %ntid.x;
mov.u32 %r5, %ctaid.x;
mov.u32 %r1, %tid.x;
mad.lo.s32 %r2, %r4, %r5, %r1;
setp.ge.s32 %p1, %r2, %r3;
@%p1 bra $L__BB0_2;

cvta.to.global.u64 %rd2, %rd1;
mul.wide.s32 %rd3, %r2, 4;
add.s64 %rd4, %rd2, %rd3;
ld.global.u32 %r6, [%rd4];
shl.b32 %r7, %r1, 2;
mov.u32 %r8, _ZZ6kernelE6shared;
add.s32 %r9, %r8, %r7;
st.shared.u32 [%r9], %r6;
bar.sync 0;
ld.const.u32 %r10, [const_data];
ld.shared.u32 %r11, [%r9+4];
add.s32 %r12, %r10, %r11;
st.global.u32 [%rd4], %r12;

$L__BB0_2:
ret;

}
)";
#include "library_cubin.h"

C2H_CCCLRT_TEST("Library", "[library]")
{
constexpr char kernel_name[] = "kernel";
constexpr char global_symbol_name[] = "global_data";
constexpr char const_symbol_name[] = "const_data";
constexpr char managed_symbol_name[] = "managed_data";
const cuda::device_ref device{0};
const auto cc = device.attribute(cuda::device_attributes::compute_capability);

CUlibrary lib1_native = ::cuda::__driver::__libraryLoadData(library_src, nullptr, nullptr, 0, nullptr, nullptr, 0);
CUlibrary lib2_native = ::cuda::__driver::__libraryLoadData(library_src, nullptr, nullptr, 0, nullptr, nullptr, 0);
// unified function requires at least sm90
const auto with_unified_function = (cc >= cuda::compute_capability{90});
const auto lib_src = make_library_cubin(cc);

const cuda::device_ref device{0};
CUlibrary lib1_native =
::cuda::__driver::__libraryLoadData(lib_src.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0);
CUlibrary lib2_native =
::cuda::__driver::__libraryLoadData(lib_src.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0);

// Types
{
Expand Down Expand Up @@ -300,6 +227,40 @@ C2H_CCCLRT_TEST("Library", "[library]")
(void) lib.release(); // prevent library unload in destructor
}

// Has unified function
if (with_unified_function)
{
STATIC_REQUIRE(
cuda::std::is_same_v<decltype(cuda::std::declval<cudax::library>().has_unified_function(unified_function_name)),
bool>);

cudax::library lib = cudax::library::from_native_handle(lib1_native);
CUDAX_REQUIRE(lib.has_unified_function(unified_function_name));
CUDAX_REQUIRE(!lib.has_unified_function("non_existent_unified"));

(void) lib.release(); // prevent library unload in destructor
}

// Get unified function
if (with_unified_function)
{
STATIC_REQUIRE(cuda::std::is_same_v<
decltype(cuda::std::declval<cudax::library_ref>().unified_function<int()>(unified_function_name)),
int (*)()>);

cudax::library lib = cudax::library::from_native_handle(lib1_native);
auto unified_fn = lib.unified_function<int()>(unified_function_name);

void* unified_fn_addr;
CUDAX_REQUIRE(
::cuda::__driver::__libraryGetUnifiedFunctionNoThrow(unified_fn_addr, lib1_native, managed_symbol_name)
== cudaSuccess);

CUDAX_REQUIRE(reinterpret_cast<void*>(unified_fn) == unified_fn_addr);

(void) lib.release(); // prevent library unload in destructor
}

// Get handle
{
STATIC_REQUIRE(cuda::std::is_same_v<decltype(cuda::std::declval<cudax::library>().get()), CUlibrary>);
Expand Down
Loading
Loading