diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index a77078fb..c06393f5 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -139,28 +139,29 @@ Buffer::Buffer(int rank, num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), + device_id([&]() { + int id = -1; + CUDA_CHECK(cudaGetDevice(&id)); + return id; + }()), enable_shrink(enable_shrink), low_latency_mode(low_latency_mode), explicitly_destroy(explicitly_destroy), + comm_stream([&]() { + auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); + paddle::distributed::ProcessGroup* pg = map->get(context_ring_id); + const auto& place = phi::GPUPlace(device_id); + comm_ctx = + reinterpret_cast(pg) + ->GetOrCreateCommContext(place, + phi::distributed::CommType::ALLTOALL); + calc_ctx = reinterpret_cast( + reinterpret_cast(pg) + ->GetDeviceContext(place, true)); + return make_cuda_stream(comm_ctx->GetStream(), device_id); + }()), shared_memory_allocator(use_fabric) { - CUDA_CHECK(cudaGetDevice(&device_id)); - auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance(); - paddle::distributed::ProcessGroup* pg = map->get(context_ring_id); - const auto& place = phi::GPUPlace(device_id); - comm_ctx = - reinterpret_cast(pg) - ->GetOrCreateCommContext(place, phi::distributed::CommType::ALLTOALL); - // Construct at::cuda::CUDAStream from raw cudaStream_t - cudaStream_t raw_stream = comm_ctx->GetStream(); - c10::StreamId sid = static_cast(reinterpret_cast(raw_stream)); - comm_stream.emplace(c10::Stream(c10::Stream::UNSAFE, - c10::Device(c10::DeviceType::CUDA, device_id), - sid)); - calc_ctx = reinterpret_cast( - reinterpret_cast(pg) - ->GetDeviceContext(place, true)); - // Metadata memory int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); @@ -209,12 +210,12 @@ Buffer::Buffer(int rank, reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); // No need to synchronize, will do a full device sync during `sync` - CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream.value().stream())); + CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); } // Create 32 MiB workspace CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); - CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream.value().stream())); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); // MoE counter CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); @@ -301,7 +302,7 @@ void Buffer::destroy() { if (num_nvl_bytes > 0) { // Barrier - intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream.value().stream()); + intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream); CUDA_CHECK(cudaDeviceSynchronize()); // Close remote IPC @@ -419,14 +420,14 @@ Buffer::get_dispatch_layout( auto compute_stream = make_cuda_stream(calc_ctx->stream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() and async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream.value(), previous_event.value()); + stream_wait(comm_stream, previous_event.value()); } else { - stream_wait(comm_stream.value(), compute_stream); + stream_wait(comm_stream, compute_stream); } auto num_tokens = static_cast(topk_idx.size(0)), num_topk = static_cast(topk_idx.size(1)); @@ -446,24 +447,24 @@ Buffer::get_dispatch_layout( num_topk, num_ranks, num_experts, - comm_stream.value().stream()); + comm_stream); // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream.value()); + event = EventHandle(comm_stream); for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { - t.record_stream(comm_stream.value()); + t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {num_tokens_per_rdma_rank}) { - to.has_value() ? to->record_stream(comm_stream.value()) : void(); + to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream.value()); + stream_wait(compute_stream, comm_stream); } // Switch back compute stream @@ -581,14 +582,14 @@ Buffer::intranode_dispatch(const torch::Tensor& x, auto compute_stream = make_cuda_stream(calc_ctx->stream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream.value(), previous_event.value()); + stream_wait(comm_stream, previous_event.value()); } else { - stream_wait(comm_stream.value(), compute_stream); + stream_wait(comm_stream, compute_stream); } // Create handles (only return for non-cached mode) @@ -607,7 +608,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, // Copy rank prefix matrix and clean flags intranode::cached_notify_dispatch( - rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream.value().stream()); + rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); } else { rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); @@ -636,7 +637,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, - comm_stream.value().stream(), + comm_stream, num_channels); if (num_worst_tokens > 0) { @@ -730,7 +731,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, buffer_ptrs_gpu, rank, num_ranks, - comm_stream.value().stream(), + comm_stream, config.num_sms, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); @@ -738,10 +739,10 @@ Buffer::intranode_dispatch(const torch::Tensor& x, // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream.value()); + event = EventHandle(comm_stream); if (!skip_x_record_stream) { for (auto& t : {x, recv_x}) { - t.record_stream(comm_stream.value()); + t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } @@ -752,7 +753,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, recv_src_idx, recv_channel_prefix_matrix, send_head}) { - t.record_stream(comm_stream.value()); + t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } @@ -766,12 +767,12 @@ Buffer::intranode_dispatch(const torch::Tensor& x, recv_topk_idx, recv_topk_weights, recv_x_scales}) { - to.has_value() ? to->record_stream(comm_stream.value()) : void(); + to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream.value()); + stream_wait(compute_stream, comm_stream); } // Switch back compute stream @@ -832,14 +833,14 @@ std::tuple, std::optionalstream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream.value(), previous_event.value()); + stream_wait(comm_stream, previous_event.value()); } else { - stream_wait(comm_stream.value(), compute_stream); + stream_wait(comm_stream, compute_stream); } int num_topk = 0; @@ -866,7 +867,7 @@ std::tuple, std::optional>({bias_0, bias_1}); @@ -905,7 +906,7 @@ std::tuple, std::optional, std::optional event; if (async) { - event = EventHandle(comm_stream.value()); + event = EventHandle(comm_stream); if (!skip_x_record_stream) { - x.record_stream(comm_stream.value()); + x.record_stream(comm_stream); if (allocate_on_comm_stream) x.record_stream(compute_stream); } for (auto& t : {src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { - t.record_stream(comm_stream.value()); + t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) { - to.has_value() ? to->record_stream(comm_stream.value()) : void(); + to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream.value()); + stream_wait(compute_stream, comm_stream); } // Switch back compute stream @@ -1074,14 +1075,14 @@ Buffer::internode_dispatch(const torch::Tensor& x, auto compute_stream = make_cuda_stream(calc_ctx->stream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream.value(), previous_event.value()); + stream_wait(comm_stream, previous_event.value()); } else { - stream_wait(comm_stream.value(), compute_stream); + stream_wait(comm_stream, compute_stream); } // Create handles (only return for non-cached mode) @@ -1119,7 +1120,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, - comm_stream.value().stream(), + comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, true, @@ -1160,7 +1161,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, config.num_max_nvl_chunked_recv_tokens, barrier_signal_ptrs_gpu, rank, - comm_stream.value().stream(), + comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, low_latency_mode); @@ -1267,14 +1268,14 @@ Buffer::internode_dispatch(const torch::Tensor& x, rank, num_ranks, cached_mode, - comm_stream.value().stream(), + comm_stream, num_channels, low_latency_mode); // Wait streams std::optional event; if (async) { - event = EventHandle(comm_stream.value()); + event = EventHandle(comm_stream); for (auto& t : {x, is_token_in_rank, recv_x, @@ -1282,7 +1283,7 @@ Buffer::internode_dispatch(const torch::Tensor& x, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { - t.record_stream(comm_stream.value()); + t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } @@ -1304,12 +1305,12 @@ Buffer::internode_dispatch(const torch::Tensor& x, send_rdma_head, send_nvl_head, recv_src_meta}) { - to.has_value() ? to->record_stream(comm_stream.value()) : void(); + to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream.value()); + stream_wait(compute_stream, comm_stream); } // Switch back compute stream @@ -1392,14 +1393,14 @@ std::tuple, std::optionalstream(), device_id); if (allocate_on_comm_stream) { EP_HOST_ASSERT(previous_event.has_value() && async); - deep_ep::SetAllocatorStreamForGPUContext(comm_stream.value().stream(), calc_ctx); + deep_ep::SetAllocatorStreamForGPUContext(comm_stream, calc_ctx); } // Wait previous tasks to be finished if (previous_event.has_value()) { - stream_wait(comm_stream.value(), previous_event.value()); + stream_wait(comm_stream, previous_event.value()); } else { - stream_wait(comm_stream.value(), compute_stream); + stream_wait(comm_stream, compute_stream); } // Top-k checks @@ -1439,7 +1440,7 @@ std::tuple, std::optional, std::optional event; if (async) { - event = EventHandle(comm_stream.value()); + event = EventHandle(comm_stream); for (auto& t : {x, src_meta, is_combined_token_in_rank, @@ -1502,17 +1503,17 @@ std::tuple, std::optionalrecord_stream(comm_stream.value()) : void(); + to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); } } else { - stream_wait(compute_stream, comm_stream.value()); + stream_wait(compute_stream, comm_stream); } // Switch back compute stream @@ -1612,7 +1613,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream auto compute_stream = at::cuda::getCurrentCUDAStream(); - auto launch_stream = return_recv_hook ? compute_stream : comm_stream.value(); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; EP_HOST_ASSERT(not(async and return_recv_hook)); if (not return_recv_hook) stream_wait(launch_stream, compute_stream); @@ -1757,7 +1758,7 @@ std::tuple, std::optional(comm_stream)); #if defined(PADDLE_WITH_CUDA) return phi::CUDAStream(phi::GPUPlace(device_id), s); diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 93759f9d..726a2a6a 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -86,10 +86,10 @@ struct Buffer { shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS]; // Stream for communication - std::optional comm_stream; - - phi::distributed::NCCLCommContext* comm_ctx; - phi::GPUContext* calc_ctx; + at::cuda::CUDAStream comm_stream; + + phi::distributed::NCCLCommContext* comm_ctx = nullptr; + phi::GPUContext* calc_ctx = nullptr; // After IPC/NVSHMEM synchronization, this flag will be true bool available = false; @@ -151,15 +151,10 @@ struct Buffer { torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; - at::cuda::CUDAStream get_comm_stream() const { - return comm_stream.value(); + torch::Stream get_comm_stream() const { + return comm_stream; } - // Helper to get raw stream for CUDA APIs - cudaStream_t get_comm_stream_raw() const { - return comm_stream.value().stream(); - } - void sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt);