Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <cuda/experimental/__stf/internal/hashtable_linearprobing.cuh>
#include <cuda/experimental/__stf/stream/stream_data_interface.cuh>
#include <cuda/experimental/__stf/utility/scope_guard.cuh>

namespace cuda::experimental::stf
{
Expand Down Expand Up @@ -77,7 +78,9 @@ public:
size_t sz = this->shape.get_capacity() * sizeof(reserved::KeyValue);

// NAIVE method !
cuda_safe_call(cudaMemcpyAsync((void*) dst, (void*) src, sz, kind, s));
// cudaMemcpyAsync is an overload set (cuda_runtime.h alternate-spelling wrapper),
// so it keeps the runtime-status cuda_try form.
cuda_try<cudaMemcpyAsync<void*, void*>>(dst, src, sz, kind, s);
}

void stream_data_allocate(
Expand All @@ -95,18 +98,26 @@ public:

if (memory_node.is_host())
{
// Fallback to a synchronous method
cuda_safe_call(cudaStreamSynchronize(stream));
cuda_safe_call(cudaHostAlloc(&base_ptr, s, cudaHostAllocMapped));
// Fallback to a synchronous method. cudaHostAlloc is an overload set
// (cuda_runtime.h templated wrapper), so it keeps the runtime-status form.
cuda_try<cudaStreamSynchronize>(stream);
base_ptr = cuda_try<cudaHostAlloc<reserved::KeyValue>>(s, cudaHostAllocMapped);
memset(base_ptr, 0xff, s);
}
else
{
cuda_safe_call(cudaMallocAsync(&base_ptr, s, stream));
// cudaMallocAsync is an overload set (templated wrapper), so it keeps the
// runtime-status form.
cuda_try(cudaMallocAsync(&base_ptr, s, stream));
// Free the buffer if the initialization below throws.
SCOPE(fail)
{
cuda_safe_call(cudaFreeAsync(base_ptr, stream));
};

// We also need to initialize the hashtable
static_assert(reserved::kEmpty == 0xffffffff, "memset expected kEmpty=0xffffffff");
cuda_safe_call(cudaMemsetAsync(base_ptr, 0xff, s, stream));
cuda_try<cudaMemsetAsync>(base_ptr, 0xff, s, stream);
}

local_desc.addr = base_ptr;
Expand All @@ -120,16 +131,19 @@ public:
cudaStream_t stream) override
{
hashtable& local_desc = this->instance(instance_id);

if (memory_node.is_host())
{
// Fallback to a synchronous method
cuda_safe_call(cudaStreamSynchronize(stream));
cuda_safe_call(cudaFreeHost(local_desc.addr));
auto cudaStreamSynchronizeResult = cudaStreamSynchronize(stream);
cuda_try<cudaFreeHost>(local_desc.addr);
cuda_try(cudaStreamSynchronizeResult);
}
else
{
cuda_safe_call(cudaFreeAsync(local_desc.addr, stream));
cuda_try<cudaFreeAsync>(local_desc.addr, stream);
}

local_desc.addr = nullptr; // not strictly necessary, but helps debugging
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,30 +232,32 @@ public:

if constexpr (dimensions == 0)
{
cuda_safe_call(cudaMemcpyAsync(dst_ptr, src_ptr, sizeof(T), kind, s));
// cudaMemcpyAsync is an overload set (cuda_runtime.h adds an alternate-spelling
// wrapper), so it keeps the runtime-status cuda_try form.
cuda_try(cudaMemcpyAsync(dst_ptr, src_ptr, sizeof(T), kind, s));
}
else if constexpr (dimensions == 1)
{
cuda_safe_call(cudaMemcpyAsync(dst_ptr, src_ptr, b.extent(0) * sizeof(T), kind, s));
cuda_try(cudaMemcpyAsync(dst_ptr, src_ptr, b.extent(0) * sizeof(T), kind, s));
}
else if constexpr (dimensions == 2)
{
cuda_safe_call(cudaMemcpy2DAsync(
cuda_try<cudaMemcpy2DAsync>(
dst_ptr,
dst_instance.stride(1) * sizeof(T),
src_ptr,
src_instance.stride(1) * sizeof(T),
b.extent(0) * sizeof(T),
b.extent(1),
kind,
s));
s);
}
else
{
// We only support higher dimensions if they are contiguous !
if ((contiguous_dims(src_instance) == dimensions) && (contiguous_dims(dst_instance) == dimensions))
{
cuda_safe_call(cudaMemcpyAsync(dst_ptr, src_ptr, b.size() * sizeof(T), kind, s));
cuda_try(cudaMemcpyAsync(dst_ptr, src_ptr, b.size() * sizeof(T), kind, s));
}
else
{
Expand All @@ -279,11 +281,7 @@ public:

::std::optional<cudaMemoryType> get_memory_type(instance_id_t instance_id) override
{
auto s = this->instance(instance_id);

cudaPointerAttributes attributes{};
cuda_safe_call(cudaPointerGetAttributes(&attributes, s.data_handle()));

const auto attributes = cuda_try<cudaPointerGetAttributes>(this->instance(instance_id).data_handle());
// Implicitly converted to an optional
return attributes.type;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public:
if (e.affine_data_place().is_host())
{
// TODO make a callback when the situation gets better
cuda_safe_call(cudaStreamSynchronize(s));
cuda_try<cudaStreamSynchronize>(s);
// slice_print(in, "in before op");
// slice_print(inout, "inout before op");

Expand Down Expand Up @@ -160,7 +160,7 @@ public:
if (e.affine_data_place().is_host())
{
// TODO make a callback when the situation gets better
cuda_safe_call(cudaStreamSynchronize(s));
cuda_try<cudaStreamSynchronize>(s);
if constexpr (dimensions == 1)
{
for (size_t i = 0; i < out.extent(0); i++)
Expand Down
Loading