From cf77bdbd23a43303e0a4d3234c2d062d14a9a2aa Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Jan 2026 02:08:34 -0800 Subject: [PATCH 01/18] first working dispatch and combine primitive for k=1 --- CMakeLists.txt | 2 + csrc/dispatch.h | 2 + csrc/host_ir/evaluator.cpp | 63 +++++ csrc/host_ir/evaluator.h | 2 + csrc/multidevice/communication.cpp | 161 +++++++++++ csrc/multidevice/communication.h | 162 +++++++++++ csrc/multidevice/dispatch_combine.cpp | 267 ++++++++++++++++++ csrc/multidevice/dispatch_combine.h | 51 ++++ .../cpp/test_multidevice_dispatch_combine.cpp | 121 ++++++++ 9 files changed, 831 insertions(+) create mode 100644 csrc/multidevice/dispatch_combine.cpp create mode 100644 csrc/multidevice/dispatch_combine.h create mode 100644 tests/cpp/test_multidevice_dispatch_combine.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 13dd918282b..b325b325d9c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -235,6 +235,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/multidevice/communication.cpp ${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp ${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp + ${NVFUSER_SRCS_DIR}/multidevice/dispatch_combine.cpp ${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp ${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp @@ -1143,6 +1144,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 3bf3b8350ff..01aa278af71 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -118,6 +118,8 @@ class Val; f(Merge); \ f(Partition); \ f(Combine); \ + f(MoEDispatch); \ + f(MoECombine); \ f(Swizzle); \ f(Swizzle2D); \ f(Resize); \ diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 2ceedfddc40..a847a9d5f99 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -25,6 +25,7 @@ #include "multidevice/allocation_utils.h" #include "multidevice/communication.h" #include "multidevice/cuda_p2p.h" +#include "multidevice/dispatch_combine.h" #include "multidevice/execution_utils.h" #include "multidevice/symmetric_tensor.h" #include "multidevice/utils.h" @@ -386,6 +387,68 @@ void HostIrEvaluator::handle(P2PCommunication* communication) { } } +void HostIrEvaluator::handle(MoEDispatch* dispatch) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + auto x = getKnownConcreteValue(dispatch->inX()).as(); + auto topk_idx = + getKnownConcreteValue(dispatch->inTopkIdx()).as(); + auto topk_weights = + getKnownConcreteValue(dispatch->inTopkWeights()).as(); + auto is_token_in_rank = + getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); + + auto result = dispatchWithCudaBackend( + x, + topk_idx, + topk_weights, + is_token_in_rank, + dispatch->numExperts(), + communicator_, + dispatch->backend()); + + expr_evaluator_.bind(dispatch->outX(), result.recv_x); + expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx); + expr_evaluator_.bind(dispatch->outTopkWeights(), result.recv_topk_weights); + expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx); + expr_evaluator_.bind(dispatch->outSrcRank(), 
result.recv_src_rank); + expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank); + expr_evaluator_.bind( + dispatch->outTokensFromRank(), result.n_tokens_from_rank); +} + +void HostIrEvaluator::handle(MoECombine* combine) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + auto x = getKnownConcreteValue(combine->inX()).as(); + auto topk_weights = + getKnownConcreteValue(combine->inTopkWeights()).as(); + auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as(); + auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as(); + auto n_tokens_to_rank = + getKnownConcreteValue(combine->inTokensToRank()).as(); + auto n_tokens_from_rank = + getKnownConcreteValue(combine->inTokensFromRank()).as(); + + auto result = combineWithCudaBackend( + x, + topk_weights, + src_idx, + src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + communicator_, + combine->backend()); + + expr_evaluator_.bind(combine->outX(), result.combined_x); + expr_evaluator_.bind( + combine->outTopkWeights(), result.combined_topk_weights); +} + void HostIrEvaluator::handle(Wait* wait) { Expr* expr = wait->communication(); auto* p2p_comm = dynamic_cast(expr); diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index 22833156cab..c1b0a70ef78 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -98,6 +98,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { void handle(LaunchKernel*) override; void handle(Communication*) override; void handle(P2PCommunication*) override; + void handle(MoEDispatch*) override; + void handle(MoECombine*) override; void handle(Wait*) override; void handle(kir::ForLoop*) override; void handle(hir::ForLoop*) override; diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 06b4ffa426c..febbd519d10 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -321,6 +321,167 @@ std::string P2PCommunication::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } +MoEDispatch::MoEDispatch( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_idx, + TensorView* out_topk_weights, + TensorView* out_src_idx, + TensorView* out_src_rank, + TensorView* out_n_tokens_to_rank, + TensorView* out_n_tokens_from_rank, + TensorView* in_x, + TensorView* in_topk_idx, + TensorView* in_topk_weights, + TensorView* in_is_token_in_rank, + int64_t num_experts, + CommunicatorBackend backend) + : Expr(passkey) { + addInput(in_x); + addInput(in_topk_idx); + addInput(in_topk_weights); + addInput(in_is_token_in_rank); + addOutput(out_x); + addOutput(out_topk_idx); + addOutput(out_topk_weights); + addOutput(out_src_idx); + addOutput(out_src_rank); + addOutput(out_n_tokens_to_rank); + addOutput(out_n_tokens_from_rank); + addDataAttribute(num_experts); + addDataAttribute(backend); + validate(); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MoEDispatch) + +std::string MoEDispatch::toInlineString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "Dispatch " << name() << " (" + << "num_experts=" << numExperts() << ", " + << "backend=" << backend() << ", " + << "in=" << inX() << ", " + << "topk_idx=" << inTopkIdx() << ", " + << "topk_weights=" << inTopkWeights() << ", " + << "is_token_in_rank=" << inIsTokenInRank() << ", " + << "out=" << outX() << ")"; + return ss.str(); +} + +std::string MoEDispatch::toString(int indent_size) const { + return toInlineString(indent_size) + 
"\n"; +} + +void MoEDispatch::validate() { + NVF_CHECK(numExperts() > 0, "num_experts must be positive."); + NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); + NVF_CHECK(inTopkIdx()->isA(), "topk_idx must be a TensorView."); + NVF_CHECK( + inTopkIdx()->getDataType().has_value() && + isIntegralType(*inTopkIdx()->getDataType()), + "topk_idx must be integral."); + NVF_CHECK( + inTopkWeights()->getDataType().has_value() && + isFloatingPointType(*inTopkWeights()->getDataType()), + "topk_weights must be floating point."); + NVF_CHECK( + inIsTokenInRank()->getDataType() == DataType::Bool, + "is_token_in_rank must be Bool."); + NVF_CHECK( + outTopkIdx()->getDataType().has_value() && + isIntegralType(*outTopkIdx()->getDataType()), + "out_topk_idx must be integral."); + NVF_CHECK( + outTopkWeights()->getDataType().has_value() && + isFloatingPointType(*outTopkWeights()->getDataType()), + "out_topk_weights must be floating point."); + NVF_CHECK( + outSrcIdx()->getDataType().has_value() && + isIntegralType(*outSrcIdx()->getDataType()), + "out_src_idx must be integral."); + NVF_CHECK( + outSrcRank()->getDataType().has_value() && + isIntegralType(*outSrcRank()->getDataType()), + "out_src_rank must be integral."); + NVF_CHECK( + outTokensToRank()->getDataType().has_value() && + isIntegralType(*outTokensToRank()->getDataType()), + "out_n_tokens_to_rank must be integral."); + NVF_CHECK( + outTokensFromRank()->getDataType().has_value() && + isIntegralType(*outTokensFromRank()->getDataType()), + "out_n_tokens_from_rank must be integral."); +} + +MoECombine::MoECombine( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_weights, + TensorView* in_x, + TensorView* in_topk_weights, + TensorView* in_src_idx, + TensorView* in_src_rank, + TensorView* in_n_tokens_to_rank, + TensorView* in_n_tokens_from_rank, + CommunicatorBackend backend) + : Expr(passkey) { + addInput(in_x); + addInput(in_topk_weights); + addInput(in_src_idx); + addInput(in_src_rank); + addInput(in_n_tokens_to_rank); + addInput(in_n_tokens_from_rank); + addOutput(out_x); + addOutput(out_topk_weights); + addDataAttribute(backend); + validate(); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MoECombine) + +std::string MoECombine::toInlineString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "Combine " << name() << " (" + << "backend=" << backend() << ", " + << "in=" << inX() << ", " + << "src_idx=" << inSrcIdx() << ", " + << "src_rank=" << inSrcRank() << ", " + << "out=" << outX() << ")"; + return ss.str(); +} + +std::string MoECombine::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + +void MoECombine::validate() { + NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); + NVF_CHECK( + inTopkWeights()->getDataType().has_value() && + isFloatingPointType(*inTopkWeights()->getDataType()), + "in_topk_weights must be floating point."); + NVF_CHECK( + inSrcIdx()->getDataType().has_value() && + isIntegralType(*inSrcIdx()->getDataType()), + "in_src_idx must be integral."); + NVF_CHECK( + inSrcRank()->getDataType().has_value() && + isIntegralType(*inSrcRank()->getDataType()), + "in_src_rank must be integral."); + NVF_CHECK( + inTokensToRank()->getDataType().has_value() && + isIntegralType(*inTokensToRank()->getDataType()), + "in_n_tokens_to_rank must be integral."); + NVF_CHECK( + inTokensFromRank()->getDataType().has_value() && + isIntegralType(*inTokensFromRank()->getDataType()), + "in_n_tokens_from_rank must be integral."); + NVF_CHECK( + 
outTopkWeights()->getDataType().has_value() && + isFloatingPointType(*outTopkWeights()->getDataType()), + "out_topk_weights must be floating point."); +} + namespace { c10::intrusive_ptr postBroadcast( Communication* communication, diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 1a7f1a1cc4c..9c880110b5e 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -174,6 +174,168 @@ class P2PCommunication : public Expr { } }; +// Dispatch represents intra-node MoE token dispatch. It shuffles tokens from +// the local rank to destination ranks based on `is_token_in_rank`. +class MoEDispatch : public Expr { + public: + using Expr::Expr; + + MoEDispatch( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_idx, + TensorView* out_topk_weights, + TensorView* out_src_idx, + TensorView* out_src_rank, + TensorView* out_n_tokens_to_rank, + TensorView* out_n_tokens_from_rank, + TensorView* in_x, + TensorView* in_topk_idx, + TensorView* in_topk_weights, + TensorView* in_is_token_in_rank, + int64_t num_experts, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + MoEDispatch(const MoEDispatch& other) = delete; + MoEDispatch& operator=(const MoEDispatch& other) = delete; + MoEDispatch(MoEDispatch&& other) = delete; + MoEDispatch& operator=(MoEDispatch&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "MoEDispatch"; + } + + TensorView* outX() const { + return output(0)->as(); + } + + TensorView* outTopkIdx() const { + return output(1)->as(); + } + + TensorView* outTopkWeights() const { + return output(2)->as(); + } + + TensorView* outSrcIdx() const { + return output(3)->as(); + } + + TensorView* outSrcRank() const { + return output(4)->as(); + } + + TensorView* outTokensToRank() const { + return output(5)->as(); + } + + TensorView* outTokensFromRank() const { + return output(6)->as(); + } + + TensorView* inX() const { + return input(0)->as(); + } + + TensorView* inTopkIdx() const { + return input(1)->as(); + } + + TensorView* inTopkWeights() const { + return input(2)->as(); + } + + TensorView* inIsTokenInRank() const { + return input(3)->as(); + } + + int64_t numExperts() const { + return attribute(0); + } + + CommunicatorBackend backend() const { + return attribute(1); + } + + private: + void validate(); +}; + +// Combine represents intra-node MoE token combine. It shuffles tokens back to +// their source ranks using `src_rank` and `src_idx`. 
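+// Conceptually the inverse of dispatch: a token that arrived tagged with
+// src_rank=r and src_idx=i is sent back to rank r, which scatters it into
+// row i of the output (via index_copy_), restoring the pre-dispatch order.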
+class MoECombine : public Expr { + public: + using Expr::Expr; + + MoECombine( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_weights, + TensorView* in_x, + TensorView* in_topk_weights, + TensorView* in_src_idx, + TensorView* in_src_rank, + TensorView* in_n_tokens_to_rank, + TensorView* in_n_tokens_from_rank, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + MoECombine(const MoECombine& other) = delete; + MoECombine& operator=(const MoECombine& other) = delete; + MoECombine(MoECombine&& other) = delete; + MoECombine& operator=(MoECombine&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "MoECombine"; + } + + TensorView* outX() const { + return output(0)->as(); + } + + TensorView* outTopkWeights() const { + return output(1)->as(); + } + + TensorView* inX() const { + return input(0)->as(); + } + + TensorView* inTopkWeights() const { + return input(1)->as(); + } + + TensorView* inSrcIdx() const { + return input(2)->as(); + } + + TensorView* inSrcRank() const { + return input(3)->as(); + } + + TensorView* inTokensToRank() const { + return input(4)->as(); + } + + TensorView* inTokensFromRank() const { + return input(5)->as(); + } + + CommunicatorBackend backend() const { + return attribute(0); + } + + private: + void validate(); +}; + // The method "post" triggers the execution of the communication. This call is // non-blocking. The communication can be posted multiple times. // It is assumed that the current device_index (given by diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp new file mode 100644 index 00000000000..7ac888c539a --- /dev/null +++ b/csrc/multidevice/dispatch_combine.cpp @@ -0,0 +1,267 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include "multidevice/dispatch_combine.h" + +#include +#include + +#include + +#include "multidevice/communicator.h" +#include "utils.h" + +namespace nvfuser { +namespace { + +CommunicatorBackend getBackendForDispatch(CommunicatorBackend backend) { + if (backend == CommunicatorBackend::kCuda) { + return CommunicatorBackend::kNccl; + } + return backend; +} + +std::vector toSplitSizes(const at::Tensor& sizes_tensor) { + auto cpu_sizes = sizes_tensor.to(at::kCPU); + auto* ptr = cpu_sizes.data_ptr(); + return std::vector(ptr, ptr + cpu_sizes.numel()); +} + +int64_t sumSplitSizes(const std::vector& splits) { + int64_t total = 0; + for (auto value : splits) { + total += value; + } + return total; +} + +at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { + if (topk.numel() == num_tokens) { + return topk.reshape({num_tokens}); + } + if (topk.dim() == 2 && topk.size(0) == num_tokens && + topk.size(1) == 1) { + return topk.reshape({num_tokens}); + } + NVF_CHECK( + false, + "Only topk=1 supported. topk_idx/weights must be shape [T] or [T, 1], got: ", + topk.sizes()); +} + +void ensureTopk1Assignment(const at::Tensor& is_token_in_rank) { + auto token_counts = is_token_in_rank.to(at::kLong).sum(1); + auto min_val = token_counts.min().item(); + auto max_val = token_counts.max().item(); + NVF_CHECK( + min_val == 1 && max_val == 1, + "Only topk=1 is supported. 
Each token must be assigned to exactly one rank."); +} + +} // namespace + +DispatchResult dispatchWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_idx, + const at::Tensor& topk_weights, + const at::Tensor& is_token_in_rank, + int64_t num_experts, + Communicator* communicator, + CommunicatorBackend backend) { + NVF_CHECK(communicator != nullptr, "Dispatch requires a valid communicator."); + NVF_CHECK(x.is_cuda(), "Dispatch input x must be on CUDA."); + NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA."); + NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA."); + NVF_CHECK( + is_token_in_rank.is_cuda(), + "Dispatch is_token_in_rank must be on CUDA."); + NVF_CHECK( + is_token_in_rank.dim() == 2, + "is_token_in_rank must be 2D [tokens, ranks], got: ", + is_token_in_rank.sizes()); + NVF_CHECK( + x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); + + const int64_t num_tokens = x.size(0); + const int64_t hidden = x.size(1); + const int64_t world_size = communicator->size(); + const int64_t my_rank = communicator->deviceId(); + NVF_CHECK( + is_token_in_rank.size(1) == world_size, + "is_token_in_rank second dim must match world size."); + NVF_CHECK(num_experts % world_size == 0, "num_experts must be divisible."); + + c10::cuda::CUDAGuard device_guard(x.device()); + ensureTopk1Assignment(is_token_in_rank); + + auto topk_idx_flat = flattenTopk(topk_idx, num_tokens); + auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); + + auto rank_for_token = + is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + auto sorted = rank_for_token.sort(); + auto sorted_indices = std::get<1>(sorted); + + auto send_x = x.index_select(0, sorted_indices); + auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); + auto send_src_idx = sorted_indices.to(at::kLong); + auto send_src_rank = at::full( + {num_tokens}, + my_rank, + at::TensorOptions().dtype(at::kLong).device(x.device())); + send_src_rank = send_src_rank.index_select(0, sorted_indices); + + auto rank_for_token_cpu = rank_for_token.to(at::kCPU); + auto n_tokens_to_rank_cpu = + at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); + auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); + auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + + CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + communicator->isBackendAvailable(actual_backend), + "Backend not available for dispatch: ", + actual_backend); + auto* pg = communicator->getWorld(actual_backend); + NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + + std::vector one_split(world_size, 1); + if (auto work = pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)) { + work->wait(); + } + + auto input_splits = toSplitSizes(n_tokens_to_rank); + auto output_splits = toSplitSizes(n_tokens_from_rank); + auto total_recv = sumSplitSizes(output_splits); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); + auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); + auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); + + if (auto work = + pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + 
recv_topk_idx, send_topk_idx, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)) { + work->wait(); + } + + const int64_t experts_per_rank = num_experts / world_size; + auto local_expert = recv_topk_idx - my_rank * experts_per_rank; + auto expert_sorted = local_expert.sort(); + auto expert_order = std::get<1>(expert_sorted); + recv_x = recv_x.index_select(0, expert_order); + recv_topk_idx = recv_topk_idx.index_select(0, expert_order); + recv_topk_weights = recv_topk_weights.index_select(0, expert_order); + recv_src_idx = recv_src_idx.index_select(0, expert_order); + recv_src_rank = recv_src_rank.index_select(0, expert_order); + + return DispatchResult{ + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank}; +} + +CombineResult combineWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_weights, + const at::Tensor& src_idx, + const at::Tensor& src_rank, + const at::Tensor& n_tokens_to_rank, + const at::Tensor& n_tokens_from_rank, + Communicator* communicator, + CommunicatorBackend backend) { + NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator."); + NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA."); + NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); + NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); + NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); + NVF_CHECK(n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); + NVF_CHECK( + n_tokens_from_rank.is_cuda(), + "Combine n_tokens_from_rank must be CUDA."); + NVF_CHECK(x.dim() == 2, "Combine expects x to be 2D [tokens, hidden]."); + NVF_CHECK( + src_idx.dim() == 1 && src_rank.dim() == 1, + "src_idx and src_rank must be 1D."); + NVF_CHECK( + n_tokens_to_rank.numel() == communicator->size(), + "n_tokens_to_rank must match world size."); + NVF_CHECK( + n_tokens_from_rank.numel() == communicator->size(), + "n_tokens_from_rank must match world size."); + + c10::cuda::CUDAGuard device_guard(x.device()); + + auto sorted = src_rank.sort(); + auto sorted_indices = std::get<1>(sorted); + auto send_x = x.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights.index_select(0, sorted_indices); + auto send_src_idx = src_idx.index_select(0, sorted_indices); + + auto input_splits = toSplitSizes(n_tokens_from_rank); + auto output_splits = toSplitSizes(n_tokens_to_rank); + auto total_recv = sumSplitSizes(output_splits); + auto hidden = x.size(1); + + CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + communicator->isBackendAvailable(actual_backend), + "Backend not available for combine: ", + actual_backend); + auto* pg = communicator->getWorld(actual_backend); + NVF_CHECK(pg != nullptr, "Combine backend is null."); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); + auto recv_src_idx = at::empty({total_recv}, src_idx.options()); + + if (auto work = + pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_topk_weights, 
send_topk_weights, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)) { + work->wait(); + } + + auto combined_x = at::empty({total_recv, hidden}, x.options()); + combined_x.index_copy_(0, recv_src_idx, recv_x); + auto combined_topk_weights = + at::empty({total_recv}, topk_weights.options()); + combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); + + return CombineResult{combined_x, combined_topk_weights}; +} + +} // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h new file mode 100644 index 00000000000..0d8f75c9f6d --- /dev/null +++ b/csrc/multidevice/dispatch_combine.h @@ -0,0 +1,51 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include "multidevice/communicator.h" +#include "visibility.h" + +namespace nvfuser { + +struct DispatchResult { + at::Tensor recv_x; + at::Tensor recv_topk_idx; + at::Tensor recv_topk_weights; + at::Tensor recv_src_idx; + at::Tensor recv_src_rank; + at::Tensor n_tokens_to_rank; + at::Tensor n_tokens_from_rank; +}; + +struct CombineResult { + at::Tensor combined_x; + at::Tensor combined_topk_weights; +}; + +NVF_API DispatchResult dispatchWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_idx, + const at::Tensor& topk_weights, + const at::Tensor& is_token_in_rank, + int64_t num_experts, + Communicator* communicator, + CommunicatorBackend backend); + +NVF_API CombineResult combineWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_weights, + const at::Tensor& src_idx, + const at::Tensor& src_rank, + const at::Tensor& n_tokens_to_rank, + const at::Tensor& n_tokens_from_rank, + Communicator* communicator, + CommunicatorBackend backend); + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp new file mode 100644 index 00000000000..be13743c8b8 --- /dev/null +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -0,0 +1,121 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include + +#include "fusion.h" +#include "host_ir/container.h" +#include "host_ir/evaluator.h" +#include "ir/all_nodes.h" +#include "multidevice/communication.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { +namespace hir { + +class DispatchCombineTest : public MultiDeviceTest {}; + +TEST_F(DispatchCombineTest, DispatchCombineTop1) { + if (!communicator_->is_available() || communicator_->size() < 2) { + GTEST_SKIP() << "This test needs at least 2 ranks."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + constexpr int64_t kNumExpertsPerRank = 2; + const int64_t num_experts = world_size * kNumExpertsPerRank; + constexpr int64_t kNumTokens = 8; + constexpr int64_t kHidden = 4; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* in_x = makeSymbolicTensor(2); + auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* in_topk_weights = makeSymbolicTensor(1); + auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool); + + auto* recv_x = makeSymbolicTensor(2); + auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_topk_weights = makeSymbolicTensor(1); + auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); + auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); + auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); + + auto* dispatch = IrBuilder::create( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + in_x, + in_topk_idx, + in_topk_weights, + in_is_token_in_rank, + num_experts, + CommunicatorBackend::kCuda); + + auto* combined_x = makeSymbolicTensor(2); + auto* combined_topk_weights = makeSymbolicTensor(1); + auto* combine = IrBuilder::create( + combined_x, + combined_topk_weights, + recv_x, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + CommunicatorBackend::kCuda); + + hic->pushBackTopLevelExprs(dispatch); + hic->pushBackTopLevelExprs(combine); + + hic->addInput(in_x); + hic->addInput(in_topk_idx); + hic->addInput(in_topk_weights); + hic->addInput(in_is_token_in_rank); + hic->addOutput(combined_x); + + HostIrEvaluator hie(std::move(hic), communicator_); + + auto float_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kFloat); + auto int_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kLong); + + auto x = at::arange(kNumTokens * kHidden, float_options) + .reshape({kNumTokens, kHidden}) + + static_cast(my_rank) * 1000.0; + auto topk_idx = + (at::arange(kNumTokens, int_options) + my_rank) % num_experts; + auto topk_weights = at::ones({kNumTokens}, float_options); + + auto token_rank = topk_idx.div(kNumExpertsPerRank, "trunc"); + auto rank_ids = at::arange(world_size, int_options); + auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); + + auto outputs = hie.runWithInput( + {{in_x, x}, + {in_topk_idx, topk_idx}, + {in_topk_weights, topk_weights}, + {in_is_token_in_rank, is_token_in_rank}}); + auto combined = outputs.back().as(); + + EXPECT_TRUE(at::allclose(combined, x)) + << "Dispatch/Combine mismatch on rank " << my_rank; +} + +} // namespace hir +} // namespace nvfuser From 66e7811afa48f0ce819a66fd3191a699842d4254 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Jan 2026 05:27:05 -0800 
Subject: [PATCH 02/18] add comments and cleanup --- csrc/host_ir/evaluator.cpp | 10 +- csrc/multidevice/communication.h | 16 +- csrc/multidevice/dispatch_combine.cpp | 152 +++++++++--------- csrc/multidevice/dispatch_combine.h | 97 +++++++++-- .../cpp/test_multidevice_dispatch_combine.cpp | 21 ++- 5 files changed, 186 insertions(+), 110 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index a847a9d5f99..5f6bb83227d 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -393,14 +393,13 @@ void HostIrEvaluator::handle(MoEDispatch* dispatch) { "A valid communicator must be provided"); auto x = getKnownConcreteValue(dispatch->inX()).as(); - auto topk_idx = - getKnownConcreteValue(dispatch->inTopkIdx()).as(); + auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as(); auto topk_weights = getKnownConcreteValue(dispatch->inTopkWeights()).as(); auto is_token_in_rank = getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); - auto result = dispatchWithCudaBackend( + auto result = doMoEDispatch( x, topk_idx, topk_weights, @@ -434,7 +433,7 @@ void HostIrEvaluator::handle(MoECombine* combine) { auto n_tokens_from_rank = getKnownConcreteValue(combine->inTokensFromRank()).as(); - auto result = combineWithCudaBackend( + auto result = doMoECombine( x, topk_weights, src_idx, @@ -445,8 +444,7 @@ void HostIrEvaluator::handle(MoECombine* combine) { combine->backend()); expr_evaluator_.bind(combine->outX(), result.combined_x); - expr_evaluator_.bind( - combine->outTopkWeights(), result.combined_topk_weights); + expr_evaluator_.bind(combine->outTopkWeights(), result.combined_topk_weights); } void HostIrEvaluator::handle(Wait* wait) { diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 9c880110b5e..a3f806b6c64 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -175,7 +175,13 @@ class P2PCommunication : public Expr { }; // Dispatch represents intra-node MoE token dispatch. It shuffles tokens from -// the local rank to destination ranks based on `is_token_in_rank`. +// the local rank to destination ranks based on `in_is_token_in_rank`. +// +// Example shapes (topk=1): +// in_x: [T, H], in_topk_idx: [T] or [T, 1], in_topk_weights: [T] or [T, 1], +// in_is_token_in_rank: [T, R] (one-hot), num_experts = R * experts_per_rank. +// Outputs are recv-aligned tensors: out_x/out_topk_*/out_src_* with [T_recv, +// ...] and out_n_tokens_to_rank/out_n_tokens_from_rank with shape [R]. class MoEDispatch : public Expr { public: using Expr::Expr; @@ -266,7 +272,13 @@ class MoEDispatch : public Expr { }; // Combine represents intra-node MoE token combine. It shuffles tokens back to -// their source ranks using `src_rank` and `src_idx`. +// their source ranks using `in_src_rank` and `in_src_idx`. +// +// Example shapes (topk=1): +// in_x: [T_recv, H], in_topk_weights: [T_recv], in_src_idx: [T_recv], +// in_src_rank: [T_recv], in_n_tokens_to_rank: [R], in_n_tokens_from_rank: +// [R]. Outputs are source-aligned: out_x/out_topk_weights with shape [T_src, +// ...]. class MoECombine : public Expr { public: using Expr::Expr; diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 7ac888c539a..738e27765d9 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ @@ -19,13 +19,6 @@ namespace nvfuser { namespace { -CommunicatorBackend getBackendForDispatch(CommunicatorBackend backend) { - if (backend == CommunicatorBackend::kCuda) { - return CommunicatorBackend::kNccl; - } - return backend; -} - std::vector toSplitSizes(const at::Tensor& sizes_tensor) { auto cpu_sizes = sizes_tensor.to(at::kCPU); auto* ptr = cpu_sizes.data_ptr(); @@ -40,32 +33,27 @@ int64_t sumSplitSizes(const std::vector& splits) { return total; } -at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { - if (topk.numel() == num_tokens) { - return topk.reshape({num_tokens}); - } - if (topk.dim() == 2 && topk.size(0) == num_tokens && - topk.size(1) == 1) { - return topk.reshape({num_tokens}); +void waitWork(const c10::intrusive_ptr& work) { + if (work) { + work->wait(); } - NVF_CHECK( - false, - "Only topk=1 supported. topk_idx/weights must be shape [T] or [T, 1], got: ", - topk.sizes()); } -void ensureTopk1Assignment(const at::Tensor& is_token_in_rank) { - auto token_counts = is_token_in_rank.to(at::kLong).sum(1); - auto min_val = token_counts.min().item(); - auto max_val = token_counts.max().item(); +at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { + const bool is_1d = topk.dim() == 1 && topk.size(0) == num_tokens; + const bool is_2d = + topk.dim() == 2 && topk.size(0) == num_tokens && topk.size(1) == 1; NVF_CHECK( - min_val == 1 && max_val == 1, - "Only topk=1 is supported. Each token must be assigned to exactly one rank."); + is_1d || is_2d, + "Only topk=1 supported. topk_idx/weights must be shape [T] or [T, 1], " + "got: ", + topk.sizes()); + return topk.reshape({num_tokens}); } } // namespace -DispatchResult dispatchWithCudaBackend( +DispatchResult doMoEDispatch( const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights, @@ -78,14 +66,12 @@ DispatchResult dispatchWithCudaBackend( NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA."); NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA."); NVF_CHECK( - is_token_in_rank.is_cuda(), - "Dispatch is_token_in_rank must be on CUDA."); + is_token_in_rank.is_cuda(), "Dispatch is_token_in_rank must be on CUDA."); NVF_CHECK( is_token_in_rank.dim() == 2, "is_token_in_rank must be 2D [tokens, ranks], got: ", is_token_in_rank.sizes()); - NVF_CHECK( - x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); + NVF_CHECK(x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); const int64_t num_tokens = x.size(0); const int64_t hidden = x.size(1); @@ -97,33 +83,49 @@ DispatchResult dispatchWithCudaBackend( NVF_CHECK(num_experts % world_size == 0, "num_experts must be divisible."); c10::cuda::CUDAGuard device_guard(x.device()); - ensureTopk1Assignment(is_token_in_rank); + NVF_CHECK( + [&]() { + auto token_counts = is_token_in_rank.to(at::kLong).sum(1); + auto min_val = token_counts.min().item(); + auto max_val = token_counts.max().item(); + return min_val == 1 && max_val == 1; + }(), + "Only topk=1 is supported. Each token must be assigned to exactly one " + "rank."); auto topk_idx_flat = flattenTopk(topk_idx, num_tokens); auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); - auto rank_for_token = - is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + // Determine destination rank per token (topk=1). 
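+  // Each row of is_token_in_rank is one-hot (enforced above), so argmax over
+  // dim 1 uniquely recovers the destination rank; e.g. a row [false, true]
+  // routes that token to rank 1.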
+ auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + // Sort tokens by destination rank for contiguous alltoall slices. auto sorted = rank_for_token.sort(); auto sorted_indices = std::get<1>(sorted); + // Reorder payloads so alltoall can send contiguous chunks per rank. auto send_x = x.index_select(0, sorted_indices); auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices); auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); + // Track original token indices and source rank for the combine step. auto send_src_idx = sorted_indices.to(at::kLong); + // All entries are identical, so no relayout is needed. auto send_src_rank = at::full( {num_tokens}, my_rank, at::TensorOptions().dtype(at::kLong).device(x.device())); - send_src_rank = send_src_rank.index_select(0, sorted_indices); + // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we + // sync/copy here. GPU-initiated comms can avoid this extra sync. auto rank_for_token_cpu = rank_for_token.to(at::kCPU); auto n_tokens_to_rank_cpu = at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); - CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + backend == CommunicatorBackend::kNccl, + "Only NCCL backend is supported for MoEDispatch."); + CommunicatorBackend actual_backend = backend; NVF_CHECK( communicator->isBackendAvailable(actual_backend), "Backend not available for dispatch: ", @@ -131,43 +133,36 @@ DispatchResult dispatchWithCudaBackend( auto* pg = communicator->getWorld(actual_backend); NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + // Exchange per-rank token counts to build split sizes for alltoall. std::vector one_split(world_size, 1); - if (auto work = pg->alltoall_base( - n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)) { - work->wait(); - } + waitWork(pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); + // Convert count tensors to CPU split vectors and size the receive buffers. auto input_splits = toSplitSizes(n_tokens_to_rank); auto output_splits = toSplitSizes(n_tokens_from_rank); auto total_recv = sumSplitSizes(output_splits); + // Allocate receive buffers for payloads and metadata. + // TODO: support preallocated buffers. auto recv_x = at::empty({total_recv, hidden}, x.options()); auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); - if (auto work = - pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_idx, send_topk_idx, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_rank, send_src_rank, output_splits, input_splits)) { - work->wait(); - } - + // Alltoall exchange payloads with per-rank splits. 
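+  // Each call below is an independent variable-split alltoall: input_splits
+  // partitions the rank-sorted send tensor and output_splits sizes the
+  // receive tensor, both derived from the count exchange above.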
+ waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)); + + // Locally reorder by expert id so each rank processes contiguous experts. const int64_t experts_per_rank = num_experts / world_size; auto local_expert = recv_topk_idx - my_rank * experts_per_rank; auto expert_sorted = local_expert.sort(); @@ -188,7 +183,7 @@ DispatchResult dispatchWithCudaBackend( n_tokens_from_rank}; } -CombineResult combineWithCudaBackend( +CombineResult doMoECombine( const at::Tensor& x, const at::Tensor& topk_weights, const at::Tensor& src_idx, @@ -202,10 +197,10 @@ CombineResult combineWithCudaBackend( NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); - NVF_CHECK(n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); NVF_CHECK( - n_tokens_from_rank.is_cuda(), - "Combine n_tokens_from_rank must be CUDA."); + n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); + NVF_CHECK( + n_tokens_from_rank.is_cuda(), "Combine n_tokens_from_rank must be CUDA."); NVF_CHECK(x.dim() == 2, "Combine expects x to be 2D [tokens, hidden]."); NVF_CHECK( src_idx.dim() == 1 && src_rank.dim() == 1, @@ -219,18 +214,23 @@ CombineResult combineWithCudaBackend( c10::cuda::CUDAGuard device_guard(x.device()); + // Sort by source rank so alltoall can send contiguous chunks per rank. auto sorted = src_rank.sort(); auto sorted_indices = std::get<1>(sorted); auto send_x = x.index_select(0, sorted_indices); auto send_topk_weights = topk_weights.index_select(0, sorted_indices); auto send_src_idx = src_idx.index_select(0, sorted_indices); + // Split sizes come from dispatch counts. auto input_splits = toSplitSizes(n_tokens_from_rank); auto output_splits = toSplitSizes(n_tokens_to_rank); auto total_recv = sumSplitSizes(output_splits); auto hidden = x.size(1); - CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + backend == CommunicatorBackend::kNccl, + "Only NCCL backend is supported for MoECombine."); + CommunicatorBackend actual_backend = backend; NVF_CHECK( communicator->isBackendAvailable(actual_backend), "Backend not available for combine: ", @@ -238,27 +238,21 @@ CombineResult combineWithCudaBackend( auto* pg = communicator->getWorld(actual_backend); NVF_CHECK(pg != nullptr, "Combine backend is null."); + // Allocate receive buffers and exchange payloads back to source ranks. 
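+  // For topk=1, total_recv equals the sum of n_tokens_to_rank, i.e. the
+  // number of tokens this rank held before dispatch, so the combined output
+  // has the same row count as the original input.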
auto recv_x = at::empty({total_recv, hidden}, x.options()); auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, src_idx.options()); - if (auto work = - pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)) { - work->wait(); - } + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); combined_x.index_copy_(0, recv_src_idx, recv_x); - auto combined_topk_weights = - at::empty({total_recv}, topk_weights.options()); + auto combined_topk_weights = at::empty({total_recv}, topk_weights.options()); combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); return CombineResult{combined_x, combined_topk_weights}; diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 0d8f75c9f6d..5714a45a818 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ @@ -15,30 +15,95 @@ namespace nvfuser { struct DispatchResult { - at::Tensor recv_x; - at::Tensor recv_topk_idx; - at::Tensor recv_topk_weights; - at::Tensor recv_src_idx; - at::Tensor recv_src_rank; - at::Tensor n_tokens_to_rank; - at::Tensor n_tokens_from_rank; + at::Tensor recv_x; // Dispatched tokens received on this rank. + at::Tensor recv_topk_idx; // Expert ids aligned with recv_x. + at::Tensor recv_topk_weights; // Gating weights aligned with recv_x. + at::Tensor recv_src_idx; // Source token indices for combine. + at::Tensor recv_src_rank; // Source ranks for combine. + at::Tensor n_tokens_to_rank; // Tokens sent to each rank (this rank's view). + at::Tensor n_tokens_from_rank; // Tokens received from each rank. }; struct CombineResult { - at::Tensor combined_x; - at::Tensor combined_topk_weights; + at::Tensor combined_x; // Combined tokens back in original order. + at::Tensor combined_topk_weights; // Combined gating weights per token. }; -NVF_API DispatchResult dispatchWithCudaBackend( - const at::Tensor& x, - const at::Tensor& topk_idx, - const at::Tensor& topk_weights, - const at::Tensor& is_token_in_rank, +// Dispatch MoE tokens to the owning ranks. Only k=1 is supported for now. +// +// Args: +// x: Token embeddings on this rank, shape [T, H]. +// topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1]. +// topk_weights: Gating weights per token (topk=1), shape [T] or [T, 1]. +// is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. +// num_experts: Total experts across all ranks (must be divisible by R). +// communicator: Communicator for alltoall exchange. +// backend: Communication backend (only NCCL is supported for now). +// +// Returns: +// DispatchResult with recv_* tensors on this rank. 
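+//   Received rows are grouped by local expert id (dispatch sorts them after
+//   the exchange), so each rank sees contiguous per-expert slices.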
+// +// Example: +// // world_size=2, num_experts=4, T=4, H=2, topk=1 +// // Experts are partitioned by rank: +// // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} +// // Rank0 holds tokens 0,1 and rank1 holds tokens 2,3 in x: +// // rank0 x = [x0, x1], rank1 x = [x2, x3] +// // token->rank: [0, 1, 1, 1] (rank0 keeps x0, sends x1; rank1 keeps x2,x3) +// // is_token_in_rank = +// // [[1, 0], +// // [0, 1], +// // [0, 1], +// // [0, 1]] +// // topk_idx = [0, 2, 3, 2] (global expert ids) +// // After dispatch on rank0: +// // recv_x has token {0} +// // recv_topk_idx aligned with recv_x (e.g., [0]) +// // recv_src_idx tells original token positions (e.g., [0]) +// // After dispatch on rank1: +// // recv_x has tokens {1, 2, 3} +// // recv_topk_idx aligned with recv_x (e.g., [2, 3, 2]) +// // recv_src_idx tells original token positions (e.g., [1, 2, 3]) +// auto out = doMoEDispatch( +// x, topk_idx, topk_weights, is_token_in_rank, 4, comm, +// CommunicatorBackend::kNccl); +NVF_API DispatchResult doMoEDispatch( + const at::Tensor& x, // [T, H] + const at::Tensor& topk_idx, // [T] or [T, 1] + const at::Tensor& topk_weights, // [T] or [T, 1] + const at::Tensor& is_token_in_rank, // [T, R] int64_t num_experts, Communicator* communicator, CommunicatorBackend backend); -NVF_API CombineResult combineWithCudaBackend( +// Combine dispatched MoE results back to original token order. +// +// Args: +// x: Token embeddings after expert compute, shape [T_recv, H]. +// topk_weights: Gating weights aligned with x, shape [T_recv]. +// src_idx: Original token indices for each row of x, shape [T_recv]. +// src_rank: Original source rank per token, shape [T_recv]. +// n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. +// n_tokens_from_rank: Tokens received from each rank (from dispatch), shape +// [R]. communicator: Communicator for alltoall exchange. backend: +// Communication backend (only NCCL is supported for now). +// +// Returns: +// CombineResult with tokens restored to original order on this rank. +// +// Example: +// // Continuing the dispatch example (experts partitioned by rank): +// // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} +// // After expert compute: +// // rank0 recv_x has token {0} with src_idx = [0], src_rank = [0] +// // rank1 recv_x has tokens {1, 2, 3} with src_idx = [1, 2, 3], +// // src_rank = [0, 1, 1] +// // n_tokens_to_rank and n_tokens_from_rank are [R] counts per rank. +// // Combine scatters results back to original token order per rank. +// auto combined = doMoECombine( +// x, topk_weights, src_idx, src_rank, n_tokens_to_rank, +// n_tokens_from_rank, comm, CommunicatorBackend::kNccl); +NVF_API CombineResult doMoECombine( const at::Tensor& x, const at::Tensor& topk_weights, const at::Tensor& src_idx, diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index be13743c8b8..0d84dbc03e0 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ @@ -32,7 +32,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { const int64_t my_rank = communicator_->deviceId(); constexpr int64_t kNumExpertsPerRank = 2; const int64_t num_experts = world_size * kNumExpertsPerRank; - constexpr int64_t kNumTokens = 8; + constexpr int64_t kNumTokens = 4; constexpr int64_t kHidden = 4; auto hic = std::make_unique(); @@ -64,7 +64,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { in_topk_weights, in_is_token_in_rank, num_experts, - CommunicatorBackend::kCuda); + CommunicatorBackend::kNccl); auto* combined_x = makeSymbolicTensor(2); auto* combined_topk_weights = makeSymbolicTensor(1); @@ -77,7 +77,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, - CommunicatorBackend::kCuda); + CommunicatorBackend::kNccl); hic->pushBackTopLevelExprs(dispatch); hic->pushBackTopLevelExprs(combine); @@ -98,14 +98,21 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto x = at::arange(kNumTokens * kHidden, float_options) .reshape({kNumTokens, kHidden}) + static_cast(my_rank) * 1000.0; - auto topk_idx = - (at::arange(kNumTokens, int_options) + my_rank) % num_experts; + auto topk_idx = at::zeros({kNumTokens}, int_options); auto topk_weights = at::ones({kNumTokens}, float_options); - auto token_rank = topk_idx.div(kNumExpertsPerRank, "trunc"); + // Asymmetric example: + // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. auto rank_ids = at::arange(world_size, int_options); + auto token_rank = at::tensor({0, 1, 1, 1}, int_options); auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); + // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. + topk_idx.index_put_({0}, 0); + topk_idx.index_put_({1}, kNumExpertsPerRank); + topk_idx.index_put_({2}, kNumExpertsPerRank + 1); + topk_idx.index_put_({3}, kNumExpertsPerRank); + auto outputs = hie.runWithInput( {{in_x, x}, {in_topk_idx, topk_idx}, From afd948d76a2377e9cc4595a47b1a05667a5546b3 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 22 Jan 2026 04:00:33 -0800 Subject: [PATCH 03/18] review --- csrc/host_ir/evaluator.cpp | 8 -- csrc/multidevice/communication.cpp | 28 +----- csrc/multidevice/communication.h | 61 +++++-------- csrc/multidevice/dispatch_combine.cpp | 85 +++++++------------ csrc/multidevice/dispatch_combine.h | 13 +-- .../cpp/test_multidevice_dispatch_combine.cpp | 16 +--- 6 files changed, 62 insertions(+), 149 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 5f6bb83227d..20e96b8c971 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -394,15 +394,12 @@ void HostIrEvaluator::handle(MoEDispatch* dispatch) { auto x = getKnownConcreteValue(dispatch->inX()).as(); auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as(); - auto topk_weights = - getKnownConcreteValue(dispatch->inTopkWeights()).as(); auto is_token_in_rank = getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); auto result = doMoEDispatch( x, topk_idx, - topk_weights, is_token_in_rank, dispatch->numExperts(), communicator_, @@ -410,7 +407,6 @@ void HostIrEvaluator::handle(MoEDispatch* dispatch) { expr_evaluator_.bind(dispatch->outX(), result.recv_x); expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx); - expr_evaluator_.bind(dispatch->outTopkWeights(), result.recv_topk_weights); expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx); expr_evaluator_.bind(dispatch->outSrcRank(), result.recv_src_rank); 
expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank); @@ -424,8 +420,6 @@ void HostIrEvaluator::handle(MoECombine* combine) { "A valid communicator must be provided"); auto x = getKnownConcreteValue(combine->inX()).as(); - auto topk_weights = - getKnownConcreteValue(combine->inTopkWeights()).as(); auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as(); auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as(); auto n_tokens_to_rank = @@ -435,7 +429,6 @@ void HostIrEvaluator::handle(MoECombine* combine) { auto result = doMoECombine( x, - topk_weights, src_idx, src_rank, n_tokens_to_rank, @@ -444,7 +437,6 @@ void HostIrEvaluator::handle(MoECombine* combine) { combine->backend()); expr_evaluator_.bind(combine->outX(), result.combined_x); - expr_evaluator_.bind(combine->outTopkWeights(), result.combined_topk_weights); } void HostIrEvaluator::handle(Wait* wait) { diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index febbd519d10..1e59226a93b 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -325,25 +325,21 @@ MoEDispatch::MoEDispatch( IrBuilderPasskey passkey, TensorView* out_x, TensorView* out_topk_idx, - TensorView* out_topk_weights, TensorView* out_src_idx, TensorView* out_src_rank, TensorView* out_n_tokens_to_rank, TensorView* out_n_tokens_from_rank, TensorView* in_x, TensorView* in_topk_idx, - TensorView* in_topk_weights, TensorView* in_is_token_in_rank, int64_t num_experts, CommunicatorBackend backend) : Expr(passkey) { addInput(in_x); addInput(in_topk_idx); - addInput(in_topk_weights); addInput(in_is_token_in_rank); addOutput(out_x); addOutput(out_topk_idx); - addOutput(out_topk_weights); addOutput(out_src_idx); addOutput(out_src_rank); addOutput(out_n_tokens_to_rank); @@ -362,7 +358,6 @@ std::string MoEDispatch::toInlineString(int indent_size) const { << "backend=" << backend() << ", " << "in=" << inX() << ", " << "topk_idx=" << inTopkIdx() << ", " - << "topk_weights=" << inTopkWeights() << ", " << "is_token_in_rank=" << inIsTokenInRank() << ", " << "out=" << outX() << ")"; return ss.str(); @@ -381,20 +376,13 @@ void MoEDispatch::validate() { isIntegralType(*inTopkIdx()->getDataType()), "topk_idx must be integral."); NVF_CHECK( - inTopkWeights()->getDataType().has_value() && - isFloatingPointType(*inTopkWeights()->getDataType()), - "topk_weights must be floating point."); - NVF_CHECK( - inIsTokenInRank()->getDataType() == DataType::Bool, + inIsTokenInRank()->getDataType().has_value() && + inIsTokenInRank()->getDataType() == DataType::Bool, "is_token_in_rank must be Bool."); NVF_CHECK( outTopkIdx()->getDataType().has_value() && isIntegralType(*outTopkIdx()->getDataType()), "out_topk_idx must be integral."); - NVF_CHECK( - outTopkWeights()->getDataType().has_value() && - isFloatingPointType(*outTopkWeights()->getDataType()), - "out_topk_weights must be floating point."); NVF_CHECK( outSrcIdx()->getDataType().has_value() && isIntegralType(*outSrcIdx()->getDataType()), @@ -416,9 +404,7 @@ void MoEDispatch::validate() { MoECombine::MoECombine( IrBuilderPasskey passkey, TensorView* out_x, - TensorView* out_topk_weights, TensorView* in_x, - TensorView* in_topk_weights, TensorView* in_src_idx, TensorView* in_src_rank, TensorView* in_n_tokens_to_rank, @@ -426,13 +412,11 @@ MoECombine::MoECombine( CommunicatorBackend backend) : Expr(passkey) { addInput(in_x); - addInput(in_topk_weights); addInput(in_src_idx); addInput(in_src_rank); addInput(in_n_tokens_to_rank); 
addInput(in_n_tokens_from_rank); addOutput(out_x); - addOutput(out_topk_weights); addDataAttribute(backend); validate(); } @@ -456,10 +440,6 @@ std::string MoECombine::toString(int indent_size) const { void MoECombine::validate() { NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); - NVF_CHECK( - inTopkWeights()->getDataType().has_value() && - isFloatingPointType(*inTopkWeights()->getDataType()), - "in_topk_weights must be floating point."); NVF_CHECK( inSrcIdx()->getDataType().has_value() && isIntegralType(*inSrcIdx()->getDataType()), @@ -476,10 +456,6 @@ void MoECombine::validate() { inTokensFromRank()->getDataType().has_value() && isIntegralType(*inTokensFromRank()->getDataType()), "in_n_tokens_from_rank must be integral."); - NVF_CHECK( - outTopkWeights()->getDataType().has_value() && - isFloatingPointType(*outTopkWeights()->getDataType()), - "out_topk_weights must be floating point."); } namespace { diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index a3f806b6c64..e9544e48e9e 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -175,13 +175,16 @@ class P2PCommunication : public Expr { }; // Dispatch represents intra-node MoE token dispatch. It shuffles tokens from -// the local rank to destination ranks based on `in_is_token_in_rank`. +// the local rank to destination ranks based on explicit routing. // // Example shapes (topk=1): -// in_x: [T, H], in_topk_idx: [T] or [T, 1], in_topk_weights: [T] or [T, 1], +// in_x: [T, H], in_topk_idx: [T] or [T, 1], // in_is_token_in_rank: [T, R] (one-hot), num_experts = R * experts_per_rank. -// Outputs are recv-aligned tensors: out_x/out_topk_*/out_src_* with [T_recv, -// ...] and out_n_tokens_to_rank/out_n_tokens_from_rank with shape [R]. +// topk_weights are intentionally not forwarded; apply them before dispatch or +// after combine. +// Outputs are recv-aligned tensors: out_x/out_topk_idx/out_src_* with +// [T_recv, ...] and out_n_tokens_to_rank/out_n_tokens_from_rank with shape +// [R]. 
class MoEDispatch : public Expr {
 public:
  using Expr::Expr;

@@ -190,17 +193,18 @@
      IrBuilderPasskey passkey,
      TensorView* out_x,
      TensorView* out_topk_idx,
-      TensorView* out_topk_weights,
      TensorView* out_src_idx,
      TensorView* out_src_rank,
      TensorView* out_n_tokens_to_rank,
      TensorView* out_n_tokens_from_rank,
      TensorView* in_x,
      TensorView* in_topk_idx,
-      TensorView* in_topk_weights,
      TensorView* in_is_token_in_rank,
      int64_t num_experts,
      CommunicatorBackend backend = CommunicatorBackend::kNccl);
+  TensorView* inIsTokenInRank() const {
+    return input(2)->as<TensorView>();
+  }

  MoEDispatch(const MoEDispatch& other) = delete;
  MoEDispatch& operator=(const MoEDispatch& other) = delete;
@@ -223,24 +227,20 @@
    return output(1)->as<TensorView>();
  }

-  TensorView* outTopkWeights() const {
-    return output(2)->as<TensorView>();
-  }
-
  TensorView* outSrcIdx() const {
-    return output(3)->as<TensorView>();
+    return output(2)->as<TensorView>();
  }

  TensorView* outSrcRank() const {
-    return output(4)->as<TensorView>();
+    return output(3)->as<TensorView>();
  }

  TensorView* outTokensToRank() const {
-    return output(5)->as<TensorView>();
+    return output(4)->as<TensorView>();
  }

  TensorView* outTokensFromRank() const {
-    return output(6)->as<TensorView>();
+    return output(5)->as<TensorView>();
  }

  TensorView* inX() const {
@@ -251,14 +251,6 @@
    return input(1)->as<TensorView>();
  }

-  TensorView* inTopkWeights() const {
-    return input(2)->as<TensorView>();
-  }
-
-  TensorView* inIsTokenInRank() const {
-    return input(3)->as<TensorView>();
-  }
-
  int64_t numExperts() const {
    return attribute<int64_t>(0);
  }
@@ -275,10 +267,9 @@
 // their source ranks using `in_src_rank` and `in_src_idx`.
 //
 // Example shapes (topk=1):
-// in_x: [T_recv, H], in_topk_weights: [T_recv], in_src_idx: [T_recv],
-// in_src_rank: [T_recv], in_n_tokens_to_rank: [R], in_n_tokens_from_rank:
-// [R]. Outputs are source-aligned: out_x/out_topk_weights with shape [T_src,
-// ...].
+// in_x: [T_recv, H], in_src_idx: [T_recv], in_src_rank: [T_recv],
+// in_n_tokens_to_rank: [R], in_n_tokens_from_rank: [R].
+// Outputs are source-aligned: out_x with shape [T_src, ...].
class MoECombine : public Expr {
 public:
  using Expr::Expr;

  MoECombine(
      IrBuilderPasskey passkey,
      TensorView* out_x,
-      TensorView* out_topk_weights,
      TensorView* in_x,
-      TensorView* in_topk_weights,
      TensorView* in_src_idx,
      TensorView* in_src_rank,
      TensorView* in_n_tokens_to_rank,
@@ -312,32 +301,24 @@
    return output(0)->as<TensorView>();
  }

-  TensorView* outTopkWeights() const {
-    return output(1)->as<TensorView>();
-  }
-
  TensorView* inX() const {
    return input(0)->as<TensorView>();
  }

-  TensorView* inTopkWeights() const {
-    return input(1)->as<TensorView>();
-  }
-
  TensorView* inSrcIdx() const {
-    return input(2)->as<TensorView>();
+    return input(1)->as<TensorView>();
  }

  TensorView* inSrcRank() const {
-    return input(3)->as<TensorView>();
+    return input(2)->as<TensorView>();
  }

  TensorView* inTokensToRank() const {
-    return input(4)->as<TensorView>();
+    return input(3)->as<TensorView>();
  }

  TensorView* inTokensFromRank() const {
-    return input(5)->as<TensorView>();
+    return input(4)->as<TensorView>();
  }

  CommunicatorBackend backend() const {
diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp
index 738e27765d9..a712b11b35d 100644
--- a/csrc/multidevice/dispatch_combine.cpp
+++ b/csrc/multidevice/dispatch_combine.cpp
@@ -39,24 +39,11 @@ void waitWork(const c10::intrusive_ptr<c10d::Work>& work) {
   }
 }
 
-at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) {
-  const bool is_1d = topk.dim() == 1 && topk.size(0) == num_tokens;
-  const bool is_2d =
-      topk.dim() == 2 && topk.size(0) == num_tokens && topk.size(1) == 1;
-  NVF_CHECK(
-      is_1d || is_2d,
-      "Only topk=1 supported. topk_idx/weights must be shape [T] or [T, 1], "
-      "got: ",
-      topk.sizes());
-  return topk.reshape({num_tokens});
-}
-
 } // namespace
 
 DispatchResult doMoEDispatch(
     const at::Tensor& x,
     const at::Tensor& topk_idx,
-    const at::Tensor& topk_weights,
     const at::Tensor& is_token_in_rank,
     int64_t num_experts,
     Communicator* communicator,
@@ -64,24 +51,27 @@ DispatchResult doMoEDispatch(
   NVF_CHECK(communicator != nullptr, "Dispatch requires a valid communicator.");
   NVF_CHECK(x.is_cuda(), "Dispatch input x must be on CUDA.");
   NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA.");
-  NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA.");
   NVF_CHECK(
       is_token_in_rank.is_cuda(), "Dispatch is_token_in_rank must be on CUDA.");
-  NVF_CHECK(
-      is_token_in_rank.dim() == 2,
-      "is_token_in_rank must be 2D [tokens, ranks], got: ",
+  NVF_CHECK_EQ(
+      is_token_in_rank.dim(),
+      2,
+      "is_token_in_rank must be [tokens, ranks], got: ",
       is_token_in_rank.sizes());
-  NVF_CHECK(x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden].");
+  NVF_CHECK_EQ(x.dim(), 2, "Dispatch expects x to be 2D [tokens, hidden].");
 
   const int64_t num_tokens = x.size(0);
   const int64_t hidden = x.size(1);
   const int64_t world_size = communicator->size();
   const int64_t my_rank = communicator->deviceId();
-  NVF_CHECK(
-      is_token_in_rank.size(1) == world_size,
+  NVF_CHECK_EQ(
+      is_token_in_rank.size(1),
+      world_size,
       "is_token_in_rank second dim must match world size.");
-  NVF_CHECK(num_experts % world_size == 0, "num_experts must be divisible.");
+  NVF_CHECK_EQ(num_experts % world_size, 0, "num_experts must be divisible.");
+  const int64_t experts_per_rank = num_experts / world_size;
 
+  // Ensure subsequent allocations/ops are on x's device.
   c10::cuda::CUDAGuard device_guard(x.device());
   NVF_CHECK(
       [&]() {
         auto token_counts = is_token_in_rank.to(at::kLong).sum(1);
         auto min_val = token_counts.min().item<int64_t>();
         auto max_val = token_counts.max().item<int64_t>();
         return min_val == 1 && max_val == 1;
       }(),
       "Only topk=1 is supported. Each token must be assigned to exactly one "
       "rank.");
 
-  auto topk_idx_flat = flattenTopk(topk_idx, num_tokens);
-  auto topk_weights_flat = flattenTopk(topk_weights, num_tokens);
+  const bool topk_is_1d = topk_idx.dim() == 1 && topk_idx.size(0) == num_tokens;
+  const bool topk_is_2d = topk_idx.dim() == 2 &&
+      topk_idx.size(0) == num_tokens && topk_idx.size(1) == 1;
+  NVF_CHECK(
+      topk_is_1d || topk_is_2d,
+      "Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ",
+      topk_idx.sizes());
+  auto topk_idx_flat = topk_idx.reshape({num_tokens});
 
   // Determine destination rank per token (topk=1).
   auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong);
 
   // Sort tokens by destination rank for contiguous alltoall slices.
-  auto sorted = rank_for_token.sort();
-  auto sorted_indices = std::get<1>(sorted);
+  auto sorted_indices = at::argsort(rank_for_token);
 
   // Reorder payloads so alltoall can send contiguous chunks per rank.
   auto send_x = x.index_select(0, sorted_indices);
   auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices);
-  auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices);
   // Track original token indices and source rank for the combine step.
   auto send_src_idx = sorted_indices.to(at::kLong);
   // All entries are identical, so no relayout is needed.
@@ -147,7 +141,6 @@
   // TODO: support preallocated buffers.
   auto recv_x = at::empty({total_recv, hidden}, x.options());
   auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options());
-  auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options());
   auto recv_src_idx = at::empty({total_recv}, send_src_idx.options());
   auto recv_src_rank = at::empty({total_recv}, send_src_rank.options());
 
@@ -155,28 +148,22 @@
   waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits));
   waitWork(pg->alltoall_base(
       recv_topk_idx, send_topk_idx, output_splits, input_splits));
-  waitWork(pg->alltoall_base(
-      recv_topk_weights, send_topk_weights, output_splits, input_splits));
   waitWork(pg->alltoall_base(
       recv_src_idx, send_src_idx, output_splits, input_splits));
   waitWork(pg->alltoall_base(
       recv_src_rank, send_src_rank, output_splits, input_splits));
 
   // Locally reorder by expert id so each rank processes contiguous experts.
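   // (Subtracting my_rank * experts_per_rank maps global expert ids to the
   // local range 0..experts_per_rank-1; sorting on those groups rows so each
   // local expert sees a contiguous slice.)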
- const int64_t experts_per_rank = num_experts / world_size; auto local_expert = recv_topk_idx - my_rank * experts_per_rank; - auto expert_sorted = local_expert.sort(); - auto expert_order = std::get<1>(expert_sorted); + auto expert_order = at::argsort(local_expert); recv_x = recv_x.index_select(0, expert_order); recv_topk_idx = recv_topk_idx.index_select(0, expert_order); - recv_topk_weights = recv_topk_weights.index_select(0, expert_order); recv_src_idx = recv_src_idx.index_select(0, expert_order); recv_src_rank = recv_src_rank.index_select(0, expert_order); return DispatchResult{ recv_x, recv_topk_idx, - recv_topk_weights, recv_src_idx, recv_src_rank, n_tokens_to_rank, @@ -185,7 +172,6 @@ DispatchResult doMoEDispatch( CombineResult doMoECombine( const at::Tensor& x, - const at::Tensor& topk_weights, const at::Tensor& src_idx, const at::Tensor& src_rank, const at::Tensor& n_tokens_to_rank, @@ -194,31 +180,29 @@ CombineResult doMoECombine( CommunicatorBackend backend) { NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator."); NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA."); - NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); NVF_CHECK( n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); NVF_CHECK( n_tokens_from_rank.is_cuda(), "Combine n_tokens_from_rank must be CUDA."); - NVF_CHECK(x.dim() == 2, "Combine expects x to be 2D [tokens, hidden]."); - NVF_CHECK( - src_idx.dim() == 1 && src_rank.dim() == 1, - "src_idx and src_rank must be 1D."); - NVF_CHECK( - n_tokens_to_rank.numel() == communicator->size(), + NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden]."); + NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D."); + NVF_CHECK_EQ(src_rank.dim(), 1, "src_rank must be 1D."); + NVF_CHECK_EQ( + n_tokens_to_rank.numel(), + communicator->size(), "n_tokens_to_rank must match world size."); - NVF_CHECK( - n_tokens_from_rank.numel() == communicator->size(), + NVF_CHECK_EQ( + n_tokens_from_rank.numel(), + communicator->size(), "n_tokens_from_rank must match world size."); c10::cuda::CUDAGuard device_guard(x.device()); // Sort by source rank so alltoall can send contiguous chunks per rank. - auto sorted = src_rank.sort(); - auto sorted_indices = std::get<1>(sorted); + auto sorted_indices = at::argsort(src_rank); auto send_x = x.index_select(0, sorted_indices); - auto send_topk_weights = topk_weights.index_select(0, sorted_indices); auto send_src_idx = src_idx.index_select(0, sorted_indices); // Split sizes come from dispatch counts. @@ -240,22 +224,17 @@ CombineResult doMoECombine( // Allocate receive buffers and exchange payloads back to source ranks. auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, src_idx.options()); waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); waitWork(pg->alltoall_base( recv_src_idx, send_src_idx, output_splits, input_splits)); // Scatter by original token index to restore local order. 
auto combined_x = at::empty({total_recv, hidden}, x.options()); combined_x.index_copy_(0, recv_src_idx, recv_x); - auto combined_topk_weights = at::empty({total_recv}, topk_weights.options()); - combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); - return CombineResult{combined_x, combined_topk_weights}; + return CombineResult{combined_x}; } } // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 5714a45a818..f924dd7347c 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -17,7 +17,6 @@ namespace nvfuser { struct DispatchResult { at::Tensor recv_x; // Dispatched tokens received on this rank. at::Tensor recv_topk_idx; // Expert ids aligned with recv_x. - at::Tensor recv_topk_weights; // Gating weights aligned with recv_x. at::Tensor recv_src_idx; // Source token indices for combine. at::Tensor recv_src_rank; // Source ranks for combine. at::Tensor n_tokens_to_rank; // Tokens sent to each rank (this rank's view). @@ -26,7 +25,6 @@ struct DispatchResult { struct CombineResult { at::Tensor combined_x; // Combined tokens back in original order. - at::Tensor combined_topk_weights; // Combined gating weights per token. }; // Dispatch MoE tokens to the owning ranks. Only k=1 is supported for now. @@ -34,7 +32,8 @@ struct CombineResult { // Args: // x: Token embeddings on this rank, shape [T, H]. // topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1]. -// topk_weights: Gating weights per token (topk=1), shape [T] or [T, 1]. +// topk_weights: Apply gating weights either before dispatch or after combine. +// They are intentionally not forwarded through dispatch/combination. // is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. // num_experts: Total experts across all ranks (must be divisible by R). // communicator: Communicator for alltoall exchange. @@ -65,12 +64,10 @@ struct CombineResult { // // recv_topk_idx aligned with recv_x (e.g., [2, 3, 2]) // // recv_src_idx tells original token positions (e.g., [1, 2, 3]) // auto out = doMoEDispatch( -// x, topk_idx, topk_weights, is_token_in_rank, 4, comm, -// CommunicatorBackend::kNccl); +// x, topk_idx, is_token_in_rank, 4, comm, CommunicatorBackend::kNccl); NVF_API DispatchResult doMoEDispatch( const at::Tensor& x, // [T, H] const at::Tensor& topk_idx, // [T] or [T, 1] - const at::Tensor& topk_weights, // [T] or [T, 1] const at::Tensor& is_token_in_rank, // [T, R] int64_t num_experts, Communicator* communicator, @@ -80,7 +77,6 @@ NVF_API DispatchResult doMoEDispatch( // // Args: // x: Token embeddings after expert compute, shape [T_recv, H]. -// topk_weights: Gating weights aligned with x, shape [T_recv]. // src_idx: Original token indices for each row of x, shape [T_recv]. // src_rank: Original source rank per token, shape [T_recv]. // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. @@ -101,11 +97,10 @@ NVF_API DispatchResult doMoEDispatch( // // n_tokens_to_rank and n_tokens_from_rank are [R] counts per rank. // // Combine scatters results back to original token order per rank. 
// auto combined = doMoECombine( -// x, topk_weights, src_idx, src_rank, n_tokens_to_rank, +// x, src_idx, src_rank, n_tokens_to_rank, // n_tokens_from_rank, comm, CommunicatorBackend::kNccl); NVF_API CombineResult doMoECombine( const at::Tensor& x, - const at::Tensor& topk_weights, const at::Tensor& src_idx, const at::Tensor& src_rank, const at::Tensor& n_tokens_to_rank, diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 0d84dbc03e0..dc22ac9f8bf 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -5,12 +5,12 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include -#include - #include #include +#include +#include + #include "fusion.h" #include "host_ir/container.h" #include "host_ir/evaluator.h" @@ -40,12 +40,10 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto* in_x = makeSymbolicTensor(2); auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); - auto* in_topk_weights = makeSymbolicTensor(1); auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool); auto* recv_x = makeSymbolicTensor(2); auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); - auto* recv_topk_weights = makeSymbolicTensor(1); auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); @@ -54,25 +52,20 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto* dispatch = IrBuilder::create( recv_x, recv_topk_idx, - recv_topk_weights, recv_src_idx, recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, in_x, in_topk_idx, - in_topk_weights, in_is_token_in_rank, num_experts, CommunicatorBackend::kNccl); auto* combined_x = makeSymbolicTensor(2); - auto* combined_topk_weights = makeSymbolicTensor(1); auto* combine = IrBuilder::create( combined_x, - combined_topk_weights, recv_x, - recv_topk_weights, recv_src_idx, recv_src_rank, n_tokens_to_rank, @@ -84,7 +77,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { hic->addInput(in_x); hic->addInput(in_topk_idx); - hic->addInput(in_topk_weights); hic->addInput(in_is_token_in_rank); hic->addOutput(combined_x); @@ -99,7 +91,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { .reshape({kNumTokens, kHidden}) + static_cast(my_rank) * 1000.0; auto topk_idx = at::zeros({kNumTokens}, int_options); - auto topk_weights = at::ones({kNumTokens}, float_options); // Asymmetric example: // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. 
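As a side note, the per-rank bookkeeping this routing produces is small enough to check by hand. A standalone sketch (plain C++, not part of the patch; kWorldSize and the routing vector mirror the test's assumptions):

#include <array>
#include <cstdio>

int main() {
  constexpr int kWorldSize = 2;
  // Every rank routes its four local tokens the same way in this test.
  const std::array<int, 4> token_rank{0, 1, 1, 1};

  std::array<int, kWorldSize> n_tokens_to_rank{};
  for (int dest : token_rank) {
    ++n_tokens_to_rank[dest]; // ends up {1, 3} on every rank
  }
  // Routing is identical on both ranks, so the receive side mirrors it:
  // rank0 gets 1 token from each sender (2 total), rank1 gets 3 tokens from
  // each sender (6 total).
  for (int r = 0; r < kWorldSize; ++r) {
    std::printf("tokens to rank %d: %d\n", r, n_tokens_to_rank[r]);
  }
  return 0;
}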
@@ -116,7 +107,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto outputs = hie.runWithInput( {{in_x, x}, {in_topk_idx, topk_idx}, - {in_topk_weights, topk_weights}, {in_is_token_in_rank, is_token_in_rank}}); auto combined = outputs.back().as(); From dda9aa7c2be35ef1e604fb12b63d8a5278834657 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 22 Jan 2026 09:33:18 -0800 Subject: [PATCH 04/18] add kernel based a2av and cuda backend for d/c --- CMakeLists.txt | 2 + csrc/multidevice/alltoallv.cu | 37 ++ csrc/multidevice/cuda_p2p.cpp | 315 ++++++++++++++++++ csrc/multidevice/cuda_p2p.h | 29 ++ csrc/multidevice/dispatch_combine.cpp | 309 +++++++++++++---- csrc/multidevice/dispatch_combine.h | 4 +- tests/cpp/test_multidevice_alltoallv.cpp | 82 +++++ .../cpp/test_multidevice_dispatch_combine.cpp | 20 +- 8 files changed, 726 insertions(+), 72 deletions(-) create mode 100644 csrc/multidevice/alltoallv.cu create mode 100644 tests/cpp/test_multidevice_alltoallv.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b325b325d9c..ff76e741b4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1144,6 +1144,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_alltoallv.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp @@ -1393,6 +1394,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/mbarrier.cu ${NVFUSER_ROOT}/runtime/memory.cu ${NVFUSER_ROOT}/runtime/multicast.cu + ${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu ${NVFUSER_ROOT}/runtime/random_numbers.cu ${NVFUSER_ROOT}/runtime/tensor_memory.cu ${NVFUSER_ROOT}/runtime/tensor.cu diff --git a/csrc/multidevice/alltoallv.cu b/csrc/multidevice/alltoallv.cu new file mode 100644 index 00000000000..9725794f838 --- /dev/null +++ b/csrc/multidevice/alltoallv.cu @@ -0,0 +1,37 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +extern "C" __global__ void alltoallv_kernel( + const unsigned char* send, + const unsigned long long* recv_ptrs, + const long long* send_offsets, + const long long* send_sizes, + const long long* recv_offsets, + long long world_size, + long long elem_size, + long long max_send_bytes) { + const long long peer = static_cast(blockIdx.y); + if (peer >= world_size) { + return; + } + const long long bytes = send_sizes[peer] * elem_size; + if (bytes == 0) { + return; + } + const long long idx = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= bytes) { + return; + } + const long long send_byte_offset = send_offsets[peer] * elem_size + idx; + const long long recv_byte_offset = recv_offsets[peer] * elem_size + idx; + auto* dst = reinterpret_cast( + static_cast(recv_ptrs[peer])); + dst[recv_byte_offset] = send[send_byte_offset]; +} + diff --git a/csrc/multidevice/cuda_p2p.cpp b/csrc/multidevice/cuda_p2p.cpp index 6ad709fa062..8804c1a7a79 100644 --- a/csrc/multidevice/cuda_p2p.cpp +++ b/csrc/multidevice/cuda_p2p.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include "multidevice/cuda_p2p.h" +#include "nvfuser_resources/alltoallv.h" #include "nvfuser_resources/multicast.h" #include "cuda_utils.h" @@ -34,6 +35,143 @@ P2pProtocol getP2pProtocol() { } namespace { +void launchAlltoallvKernel( + const void* send, + const uint64_t* recv_ptrs, + const int64_t* send_offsets, + const int64_t* send_sizes, + const int64_t* recv_offsets, + int64_t world_size, + int64_t elem_size, + int64_t max_send_bytes, + CUstream stream) { + static CUmodule module = nullptr; + static CUfunction kernel = nullptr; + + if (module == nullptr) { + nvrtcProgram prog; + NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( + &prog, + nvfuser_resources::alltoallv_cu, + "alltoallv.cu", + 0, + nullptr, + nullptr)); + + int major = 0; + int minor = 0; + int device = 0; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); + cudaDeviceProp prop; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); + major = prop.major; + minor = prop.minor; + + std::string arch_arg = "--gpu-architecture=compute_" + + std::to_string(major) + std::to_string(minor); + std::vector opts = {arch_arg.c_str(), "--std=c++17"}; + // NVRTC needs CUDA headers to compile alltoallv.cu. 
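+    // (The hardcoded /usr/local/cuda paths below assume a standard install
+    // layout; relocated CUDA toolkits would need different -I flags.)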
+ opts.push_back("-I/usr/local/cuda/include"); + opts.push_back("-I/usr/local/cuda/include/cccl"); + + nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); + if (res != NVRTC_SUCCESS) { + size_t logSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); + std::vector log(logSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); + NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data()); + } + + size_t ptxSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); + std::vector ptx(ptxSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); + NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); + + CUresult load_result = cuModuleLoadData(&module, ptx.data()); + if (load_result != CUDA_SUCCESS) { + constexpr size_t kLogSize = 8192; + char error_log[kLogSize]; + char info_log[kLogSize]; + CUjit_option options[] = { + CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + CU_JIT_INFO_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_LOG_VERBOSE}; + void* option_values[] = { + (void*)error_log, + (void*)kLogSize, + (void*)info_log, + (void*)kLogSize, + (void*)1}; + cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values); + NVF_ERROR( + false, + "Alltoallv kernel module load failed with error: ", + load_result, + "\nInfo Log:\n", + info_log, + "\nError Log:\n", + error_log); + } + + NVFUSER_CUDA_SAFE_CALL( + cuModuleGetFunction(&kernel, module, "alltoallv_kernel")); + } + + if (max_send_bytes == 0) { + return; + } + + constexpr int kThreads = 256; + const int64_t blocks_x = (max_send_bytes + kThreads - 1) / kThreads; + void* args_kernel[] = { + const_cast(static_cast(&send)), + const_cast(static_cast(&recv_ptrs)), + const_cast(static_cast(&send_offsets)), + const_cast(static_cast(&send_sizes)), + const_cast(static_cast(&recv_offsets)), + &world_size, + &elem_size, + &max_send_bytes}; + NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( + kernel, + blocks_x, + static_cast(world_size), + 1, + kThreads, + 1, + 1, + 0, + stream, + args_kernel, + nullptr)); +} + +std::vector serializeInt64Vector(const std::vector& values) { + std::vector bytes(values.size() * sizeof(int64_t)); + std::memcpy(bytes.data(), values.data(), bytes.size()); + return bytes; +} + +std::vector deserializeInt64Vector(const std::vector& bytes) { + NVF_CHECK( + bytes.size() % sizeof(int64_t) == 0, "Invalid int64 byte buffer size."); + const size_t count = bytes.size() / sizeof(int64_t); + std::vector values(count); + std::memcpy(values.data(), bytes.data(), bytes.size()); + return values; +} + +std::string alltoallvCountsKey(const std::string& tag, int64_t rank) { + return "nvfuser_alltoallv_counts_" + tag + "_" + std::to_string(rank); +} + +std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) { + return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank); +} void launchMulticastKernel( void* dst, @@ -710,4 +848,181 @@ void waitWithCudaBackend( } } +AlltoallvMetadata prepareAlltoallvMetadata( + const at::Tensor& send_counts, + const std::string& tag) { + Communicator& comm = Communicator::getInstance(); + const int64_t world_size = comm.size(); + const int64_t my_rank = comm.deviceId(); + NVF_CHECK( + send_counts.is_cuda(), "alltoallv send_counts must be CUDA tensor."); + NVF_CHECK( + send_counts.dim() == 1 && send_counts.numel() == world_size, + "alltoallv send_counts must be 1D [R]."); + + auto store = comm.getTcpStore(); + auto send_counts_cpu = send_counts.to(at::kCPU); + auto* send_ptr = 
send_counts_cpu.data_ptr(); + std::vector send_counts_vec(send_ptr, send_ptr + world_size); + + store->set( + alltoallvCountsKey(tag, my_rank), serializeInt64Vector(send_counts_vec)); + + std::vector> counts_matrix(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + auto bytes = store->get(alltoallvCountsKey(tag, rank)); + counts_matrix[rank] = deserializeInt64Vector(bytes); + NVF_CHECK( + (int64_t)counts_matrix[rank].size() == world_size, + "Invalid alltoallv counts size."); + } + comm.barrier(); + for (int64_t rank = 0; rank < world_size; ++rank) { + store->deleteKey(alltoallvCountsKey(tag, rank)); + } + + std::vector recv_counts_vec(world_size, 0); + for (int64_t sender = 0; sender < world_size; ++sender) { + recv_counts_vec[sender] = counts_matrix[sender][my_rank]; + } + + std::vector send_offsets_vec(world_size, 0); + int64_t prefix = 0; + for (int64_t rank = 0; rank < world_size; ++rank) { + send_offsets_vec[rank] = prefix; + prefix += send_counts_vec[rank]; + } + + std::vector recv_offsets_vec(world_size, 0); + for (int64_t peer = 0; peer < world_size; ++peer) { + int64_t offset = 0; + for (int64_t sender = 0; sender < my_rank; ++sender) { + offset += counts_matrix[sender][peer]; + } + recv_offsets_vec[peer] = offset; + } + + int64_t total_recv = 0; + for (auto value : recv_counts_vec) { + total_recv += value; + } + + int64_t max_recv = 0; + int64_t max_send_total = 0; + for (int64_t rank = 0; rank < world_size; ++rank) { + int64_t total = 0; + for (int64_t sender = 0; sender < world_size; ++sender) { + total += counts_matrix[sender][rank]; + } + if (total > max_recv) { + max_recv = total; + } + } + + for (int64_t rank = 0; rank < world_size; ++rank) { + int64_t total = 0; + for (int64_t dest = 0; dest < world_size; ++dest) { + total += counts_matrix[rank][dest]; + } + if (total > max_send_total) { + max_send_total = total; + } + } + + int64_t max_send = 0; + for (auto value : send_counts_vec) { + if (value > max_send) { + max_send = value; + } + } + + auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); + auto send_offsets_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + send_offsets_cpu.data_ptr(), + send_offsets_vec.data(), + world_size * sizeof(int64_t)); + auto recv_offsets_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + recv_offsets_cpu.data_ptr(), + recv_offsets_vec.data(), + world_size * sizeof(int64_t)); + auto recv_counts_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + recv_counts_cpu.data_ptr(), + recv_counts_vec.data(), + world_size * sizeof(int64_t)); + + AlltoallvMetadata metadata; + metadata.send_counts = send_counts; + metadata.recv_counts = recv_counts_cpu.to(send_counts.device()); + metadata.send_offsets = send_offsets_cpu.to(send_counts.device()); + metadata.recv_offsets = recv_offsets_cpu.to(send_counts.device()); + metadata.total_recv = total_recv; + metadata.max_recv = max_recv; + metadata.max_send_total = max_send_total; + metadata.max_send_bytes = max_send; + metadata.world_size = world_size; + return metadata; +} + +void alltoallvWithCudaBackend( + const at::Tensor& send, + const at::Tensor& recv, + const AlltoallvMetadata& metadata, + const std::vector& recv_ptrs, + CUstream stream) { + NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA."); + NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA."); + NVF_CHECK( + (int64_t)recv_ptrs.size() == metadata.world_size, + "recv_ptrs size must match world size."); + + auto cpu_options = 
at::TensorOptions().dtype(at::kLong).device(at::kCPU); + auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options); + auto* ptrs = recv_ptrs_cpu.data_ptr(); + for (int64_t rank = 0; rank < metadata.world_size; ++rank) { + ptrs[rank] = + static_cast(reinterpret_cast(recv_ptrs[rank])); + } + auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device()); + + const int64_t elem_stride = + metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1; + NVF_CHECK( + metadata.max_send_total == 0 || + send.numel() % metadata.max_send_total == 0, + "alltoallv send numel must be divisible by max_send_total."); + NVF_CHECK( + metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, + "alltoallv recv numel must be divisible by max_recv."); + + auto send_offsets = metadata.send_offsets; + auto send_counts = metadata.send_counts; + auto recv_offsets = metadata.recv_offsets; + int64_t max_send_bytes = metadata.max_send_bytes; + if (elem_stride > 1) { + send_offsets = metadata.send_offsets * elem_stride; + send_counts = metadata.send_counts * elem_stride; + recv_offsets = metadata.recv_offsets * elem_stride; + max_send_bytes = metadata.max_send_bytes * elem_stride; + } + + launchAlltoallvKernel( + send.data_ptr(), + reinterpret_cast(recv_ptrs_cuda.data_ptr()), + send_offsets.data_ptr(), + send_counts.data_ptr(), + recv_offsets.data_ptr(), + metadata.world_size, + send.element_size(), + max_send_bytes * send.element_size(), + stream); +} + +void alltoallvBarrier(const std::string& tag) { + Communicator& comm = Communicator::getInstance(); + comm.barrier(); +} + } // namespace nvfuser diff --git a/csrc/multidevice/cuda_p2p.h b/csrc/multidevice/cuda_p2p.h index 4947e4e6ee1..e9fd5828597 100644 --- a/csrc/multidevice/cuda_p2p.h +++ b/csrc/multidevice/cuda_p2p.h @@ -9,6 +9,10 @@ #include +#include +#include +#include + #include "multidevice/ipc_handle.h" namespace nvfuser { @@ -43,4 +47,29 @@ void waitWithCudaBackend( CUstream stream, int64_t root); +struct AlltoallvMetadata { + at::Tensor send_counts; // CUDA [R] + at::Tensor recv_counts; // CUDA [R] + at::Tensor send_offsets; // CUDA [R] + at::Tensor recv_offsets; // CUDA [R] + int64_t total_recv = 0; + int64_t max_recv = 0; + int64_t max_send_total = 0; + int64_t max_send_bytes = 0; + int64_t world_size = 0; +}; + +AlltoallvMetadata prepareAlltoallvMetadata( + const at::Tensor& send_counts, + const std::string& tag); + +void alltoallvWithCudaBackend( + const at::Tensor& send, + const at::Tensor& recv, + const AlltoallvMetadata& metadata, + const std::vector& recv_ptrs, + CUstream stream); + +void alltoallvBarrier(const std::string& tag); + } // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 738e27765d9..cbad812aa06 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -11,9 +11,12 @@ #include #include +#include #include #include "multidevice/communicator.h" +#include "multidevice/cuda_p2p.h" +#include "multidevice/symmetric_tensor.h" #include "utils.h" namespace nvfuser { @@ -114,53 +117,160 @@ DispatchResult doMoEDispatch( my_rank, at::TensorOptions().dtype(at::kLong).device(x.device())); - // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we - // sync/copy here. GPU-initiated comms can avoid this extra sync. + // Split metadata is exchanged via CPU (TCPStore), so we sync/copy here. 
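+  // (The bincount below yields n_tokens_to_rank: entry r counts the local
+  // tokens routed to rank r; the reverse-direction counts come from the
+  // exchange.)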
auto rank_for_token_cpu = rank_for_token.to(at::kCPU); auto n_tokens_to_rank_cpu = at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); - auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + if (backend == CommunicatorBackend::kNccl) { + NVF_CHECK( + communicator->isBackendAvailable(backend), + "Backend not available for dispatch: ", + backend); + auto* pg = communicator->getWorld(backend); + NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + + auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + std::vector one_split(world_size, 1); + waitWork(pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); + + auto input_splits = toSplitSizes(n_tokens_to_rank); + auto output_splits = toSplitSizes(n_tokens_from_rank); + auto total_recv = sumSplitSizes(output_splits); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); + auto recv_topk_weights = + at::empty({total_recv}, topk_weights_flat.options()); + auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); + auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); + + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)); + + const int64_t experts_per_rank = num_experts / world_size; + auto local_expert = recv_topk_idx - my_rank * experts_per_rank; + auto expert_sorted = local_expert.sort(); + auto expert_order = std::get<1>(expert_sorted); + recv_x = recv_x.index_select(0, expert_order); + recv_topk_idx = recv_topk_idx.index_select(0, expert_order); + recv_topk_weights = recv_topk_weights.index_select(0, expert_order); + recv_src_idx = recv_src_idx.index_select(0, expert_order); + recv_src_rank = recv_src_rank.index_select(0, expert_order); + + return DispatchResult{ + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank}; + } NVF_CHECK( - backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoEDispatch."); - CommunicatorBackend actual_backend = backend; - NVF_CHECK( - communicator->isBackendAvailable(actual_backend), - "Backend not available for dispatch: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); - NVF_CHECK(pg != nullptr, "Dispatch backend is null."); - - // Exchange per-rank token counts to build split sizes for alltoall. - std::vector one_split(world_size, 1); - waitWork(pg->alltoall_base( - n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); - - // Convert count tensors to CPU split vectors and size the receive buffers. - auto input_splits = toSplitSizes(n_tokens_to_rank); - auto output_splits = toSplitSizes(n_tokens_from_rank); - auto total_recv = sumSplitSizes(output_splits); - - // Allocate receive buffers for payloads and metadata. - // TODO: support preallocated buffers. 
- auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); - auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); - auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); - - // Alltoall exchange payloads with per-rank splits. - waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_idx, send_topk_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_rank, send_src_rank, output_splits, input_splits)); + backend == CommunicatorBackend::kCuda, + "Only CUDA and NCCL backends are supported for MoEDispatch."); + + auto metadata = + prepareAlltoallvMetadata(n_tokens_to_rank, "moe_dispatch_counts"); + auto n_tokens_from_rank = metadata.recv_counts; + const int64_t total_recv = metadata.total_recv; + const int64_t max_recv = metadata.max_recv; + + // Allocate symmetric buffers for send/recv payloads. + auto send_x_sym = SymmetricTensor::allocate( + {metadata.max_send_total, hidden}, x.scalar_type(), x.device()); + send_x_sym.narrow(0, 0, num_tokens).copy_(send_x); + auto send_topk_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_idx_flat.scalar_type(), x.device()); + send_topk_idx_sym.narrow(0, 0, num_tokens).copy_(send_topk_idx); + auto send_topk_weights_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_weights_flat.scalar_type(), x.device()); + send_topk_weights_sym.narrow(0, 0, num_tokens).copy_(send_topk_weights); + auto send_src_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, send_src_idx.scalar_type(), x.device()); + send_src_idx_sym.narrow(0, 0, num_tokens).copy_(send_src_idx); + auto send_src_rank_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, send_src_rank.scalar_type(), x.device()); + send_src_rank_sym.narrow(0, 0, num_tokens).copy_(send_src_rank); + + auto recv_x_sym = SymmetricTensor::allocate( + {max_recv, hidden}, x.scalar_type(), x.device()); + auto recv_topk_idx_sym = SymmetricTensor::allocate( + {max_recv}, topk_idx_flat.scalar_type(), x.device()); + auto recv_topk_weights_sym = SymmetricTensor::allocate( + {max_recv}, topk_weights_flat.scalar_type(), x.device()); + auto recv_src_idx_sym = SymmetricTensor::allocate( + {max_recv}, send_src_idx.scalar_type(), x.device()); + auto recv_src_rank_sym = SymmetricTensor::allocate( + {max_recv}, send_src_rank.scalar_type(), x.device()); + + SymmetricTensor recv_x_handle(recv_x_sym); + SymmetricTensor recv_topk_idx_handle(recv_topk_idx_sym); + SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym); + SymmetricTensor recv_src_idx_handle(recv_src_idx_sym); + SymmetricTensor recv_src_rank_handle(recv_src_rank_sym); + recv_x_handle.setupRemoteHandles("moe_dispatch_recv_x"); + recv_topk_idx_handle.setupRemoteHandles("moe_dispatch_recv_topk_idx"); + recv_topk_weights_handle.setupRemoteHandles("moe_dispatch_recv_topk_weights"); + recv_src_idx_handle.setupRemoteHandles("moe_dispatch_recv_src_idx"); + recv_src_rank_handle.setupRemoteHandles("moe_dispatch_recv_src_rank"); + + std::vector recv_x_ptrs(world_size); + std::vector recv_topk_idx_ptrs(world_size); + std::vector recv_topk_weights_ptrs(world_size); + 
std::vector recv_src_idx_ptrs(world_size); + std::vector recv_src_rank_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr(); + recv_topk_idx_ptrs[rank] = + recv_topk_idx_handle.remoteTensor(rank).data_ptr(); + recv_topk_weights_ptrs[rank] = + recv_topk_weights_handle.remoteTensor(rank).data_ptr(); + recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr(); + recv_src_rank_ptrs[rank] = + recv_src_rank_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = + static_cast(at::cuda::getDefaultCUDAStream().stream()); + alltoallvWithCudaBackend( + send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream); + alltoallvWithCudaBackend( + send_topk_idx_sym, + recv_topk_idx_sym, + metadata, + recv_topk_idx_ptrs, + stream); + alltoallvWithCudaBackend( + send_topk_weights_sym, + recv_topk_weights_sym, + metadata, + recv_topk_weights_ptrs, + stream); + alltoallvWithCudaBackend( + send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream); + alltoallvWithCudaBackend( + send_src_rank_sym, + recv_src_rank_sym, + metadata, + recv_src_rank_ptrs, + stream); + alltoallvBarrier("moe_dispatch_counts"); + auto recv_x = recv_x_sym.narrow(0, 0, total_recv); + auto recv_topk_idx = recv_topk_idx_sym.narrow(0, 0, total_recv); + auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv); + auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv); + auto recv_src_rank = recv_src_rank_sym.narrow(0, 0, total_recv); // Locally reorder by expert id so each rank processes contiguous experts. const int64_t experts_per_rank = num_experts / world_size; @@ -212,6 +322,7 @@ CombineResult doMoECombine( n_tokens_from_rank.numel() == communicator->size(), "n_tokens_from_rank must match world size."); + const int64_t world_size = communicator->size(); c10::cuda::CUDAGuard device_guard(x.device()); // Sort by source rank so alltoall can send contiguous chunks per rank. @@ -222,32 +333,100 @@ CombineResult doMoECombine( auto send_src_idx = src_idx.index_select(0, sorted_indices); // Split sizes come from dispatch counts. 
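   // (Note the role swap on the return path: sends are sized by what this
   // rank received during dispatch, receives by what it originally sent.)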
- auto input_splits = toSplitSizes(n_tokens_from_rank); - auto output_splits = toSplitSizes(n_tokens_to_rank); - auto total_recv = sumSplitSizes(output_splits); - auto hidden = x.size(1); + if (backend == CommunicatorBackend::kNccl) { + NVF_CHECK( + communicator->isBackendAvailable(backend), + "Backend not available for combine: ", + backend); + auto* pg = communicator->getWorld(backend); + NVF_CHECK(pg != nullptr, "Combine backend is null."); + + auto input_splits = toSplitSizes(n_tokens_from_rank); + auto output_splits = toSplitSizes(n_tokens_to_rank); + auto total_recv = sumSplitSizes(output_splits); + auto hidden = x.size(1); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); + auto recv_src_idx = at::empty({total_recv}, src_idx.options()); + + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + + auto combined_x = at::empty({total_recv, hidden}, x.options()); + combined_x.index_copy_(0, recv_src_idx, recv_x); + auto combined_topk_weights = + at::empty({total_recv}, topk_weights.options()); + combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); + + return CombineResult{combined_x, combined_topk_weights}; + } NVF_CHECK( - backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoECombine."); - CommunicatorBackend actual_backend = backend; - NVF_CHECK( - communicator->isBackendAvailable(actual_backend), - "Backend not available for combine: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); - NVF_CHECK(pg != nullptr, "Combine backend is null."); - - // Allocate receive buffers and exchange payloads back to source ranks. - auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); - auto recv_src_idx = at::empty({total_recv}, src_idx.options()); - - waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)); + backend == CommunicatorBackend::kCuda, + "Only CUDA and NCCL backends are supported for MoECombine."); + + auto metadata = + prepareAlltoallvMetadata(n_tokens_from_rank, "moe_combine_counts"); + const int64_t total_recv = metadata.total_recv; + const int64_t max_recv = metadata.max_recv; + auto hidden = x.size(1); + + // Allocate symmetric buffers for send/recv payloads. 
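+  // (Symmetric allocations must be identical on every rank, hence sizing by
+  // the global max_send_total / max_recv rather than this rank's counts.)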
+ auto send_x_sym = SymmetricTensor::allocate( + {metadata.max_send_total, hidden}, x.scalar_type(), x.device()); + send_x_sym.narrow(0, 0, x.size(0)).copy_(send_x); + auto send_topk_weights_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_weights.scalar_type(), x.device()); + send_topk_weights_sym.narrow(0, 0, x.size(0)).copy_(send_topk_weights); + auto send_src_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, src_idx.scalar_type(), x.device()); + send_src_idx_sym.narrow(0, 0, x.size(0)).copy_(send_src_idx); + + auto recv_x_sym = SymmetricTensor::allocate( + {max_recv, hidden}, x.scalar_type(), x.device()); + auto recv_topk_weights_sym = SymmetricTensor::allocate( + {max_recv}, topk_weights.scalar_type(), x.device()); + auto recv_src_idx_sym = + SymmetricTensor::allocate({max_recv}, src_idx.scalar_type(), x.device()); + + SymmetricTensor recv_x_handle(recv_x_sym); + SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym); + SymmetricTensor recv_src_idx_handle(recv_src_idx_sym); + recv_x_handle.setupRemoteHandles("moe_combine_recv_x"); + recv_topk_weights_handle.setupRemoteHandles("moe_combine_recv_topk_weights"); + recv_src_idx_handle.setupRemoteHandles("moe_combine_recv_src_idx"); + + std::vector recv_x_ptrs(world_size); + std::vector recv_topk_weights_ptrs(world_size); + std::vector recv_src_idx_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr(); + recv_topk_weights_ptrs[rank] = + recv_topk_weights_handle.remoteTensor(rank).data_ptr(); + recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = + static_cast(at::cuda::getDefaultCUDAStream().stream()); + alltoallvWithCudaBackend( + send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream); + alltoallvWithCudaBackend( + send_topk_weights_sym, + recv_topk_weights_sym, + metadata, + recv_topk_weights_ptrs, + stream); + alltoallvWithCudaBackend( + send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream); + alltoallvBarrier("moe_combine_counts"); + + auto recv_x = recv_x_sym.narrow(0, 0, total_recv); + auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv); + auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv); // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 5714a45a818..ceb0a2652b4 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -38,7 +38,7 @@ struct CombineResult { // is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. // num_experts: Total experts across all ranks (must be divisible by R). // communicator: Communicator for alltoall exchange. -// backend: Communication backend (only NCCL is supported for now). +// backend: Communication backend (CUDA or NCCL). // // Returns: // DispatchResult with recv_* tensors on this rank. @@ -86,7 +86,7 @@ NVF_API DispatchResult doMoEDispatch( // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. // n_tokens_from_rank: Tokens received from each rank (from dispatch), shape // [R]. communicator: Communicator for alltoall exchange. backend: -// Communication backend (only NCCL is supported for now). +// Communication backend (CUDA or NCCL). // // Returns: // CombineResult with tokens restored to original order on this rank. 
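Before the new test, the split bookkeeping that prepareAlltoallvMetadata implements is worth seeing in isolation. A minimal sketch with plain vectors (illustrative only; the real code exchanges the counts matrix through the TCPStore and keeps the results in tensors):

#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<int64_t>>;

struct Splits {
  std::vector<int64_t> send_offsets; // into my contiguous send buffer
  std::vector<int64_t> recv_offsets; // into each peer's recv buffer
  std::vector<int64_t> recv_counts;  // what lands in my recv buffer
};

// counts[s][d] = number of elements rank s sends to rank d.
Splits computeSplits(const Matrix& counts, int64_t my_rank) {
  const auto world = static_cast<int64_t>(counts.size());
  Splits out;
  out.send_offsets.assign(world, 0);
  out.recv_offsets.assign(world, 0);
  out.recv_counts.assign(world, 0);
  int64_t prefix = 0;
  for (int64_t d = 0; d < world; ++d) {
    out.send_offsets[d] = prefix; // my own row, prefix-summed
    prefix += counts[my_rank][d];
  }
  for (int64_t p = 0; p < world; ++p) {
    // Receivers lay incoming chunks out in sender-rank order, so my chunk
    // in peer p's buffer starts after all lower-ranked senders' chunks.
    for (int64_t s = 0; s < my_rank; ++s) {
      out.recv_offsets[p] += counts[s][p];
    }
    out.recv_counts[p] = counts[p][my_rank];
  }
  return out;
}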
diff --git a/tests/cpp/test_multidevice_alltoallv.cpp b/tests/cpp/test_multidevice_alltoallv.cpp
new file mode 100644
index 00000000000..02cb21b7892
--- /dev/null
+++ b/tests/cpp/test_multidevice_alltoallv.cpp
@@ -0,0 +1,82 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#include 
+#include 
+
+#include 
+
+#include "multidevice/cuda_p2p.h"
+#include "multidevice/symmetric_tensor.h"
+#include "tests/cpp/multidevice.h"
+
+namespace nvfuser {
+namespace hir {
+
+class AlltoallvCudaTest : public MultiDeviceTest {};
+
+TEST_F(AlltoallvCudaTest, AlltoallvAsymmetric) {
+  if (!communicator_->is_available() || communicator_->size() < 2) {
+    GTEST_SKIP() << "This test needs at least 2 ranks.";
+  }
+
+  const int64_t world_size = communicator_->size();
+  const int64_t my_rank = communicator_->deviceId();
+
+  auto int_options =
+      at::TensorOptions().device(communicator_->device()).dtype(at::kLong);
+
+  auto count_for = [](int64_t sender, int64_t dest) {
+    return (sender + dest) % 3 + 1;
+  };
+  auto send_counts = at::empty({world_size}, int_options);
+  for (int64_t dest = 0; dest < world_size; ++dest) {
+    send_counts.index_put_({dest}, count_for(my_rank, dest));
+  }
+
+  auto metadata = prepareAlltoallvMetadata(send_counts, "test_alltoallv_counts");
+  const int64_t max_recv = metadata.max_recv;
+  const int64_t total_send = send_counts.sum().item<int64_t>();
+  auto send_sym = SymmetricTensor::allocate(
+      {metadata.max_send_total}, at::kLong, communicator_->device());
+  send_sym.narrow(0, 0, total_send)
+      .copy_(at::arange(total_send, int_options) + my_rank * 1000);
+
+  auto recv_sym = SymmetricTensor::allocate(
+      {max_recv}, at::kLong, communicator_->device());
+  SymmetricTensor recv_handle(recv_sym);
+  recv_handle.setupRemoteHandles("test_alltoallv_recv");
+
+  std::vector<void*> recv_ptrs(world_size);
+  for (int64_t rank = 0; rank < world_size; ++rank) {
+    recv_ptrs[rank] = recv_handle.remoteTensor(rank).data_ptr();
+  }
+
+  auto stream = at::cuda::getDefaultCUDAStream().stream();
+  alltoallvWithCudaBackend(send_sym, recv_sym, metadata, recv_ptrs, stream);
+  alltoallvBarrier("test_alltoallv_counts");
+
+  auto recv_view = recv_sym.narrow(0, 0, metadata.total_recv);
+  std::vector<int64_t> expected_vec;
+  expected_vec.reserve(static_cast<size_t>(metadata.total_recv));
+  for (int64_t sender = 0; sender < world_size; ++sender) {
+    int64_t offset = 0;
+    for (int64_t dest = 0; dest < my_rank; ++dest) {
+      offset += count_for(sender, dest);
+    }
+    const int64_t count = count_for(sender, my_rank);
+    for (int64_t i = 0; i < count; ++i) {
+      expected_vec.push_back(offset + i + sender * 1000);
+    }
+  }
+  auto expected = at::tensor(expected_vec, int_options);
+  EXPECT_TRUE(at::equal(recv_view, expected))
+      << "Alltoallv mismatch on rank " << my_rank;
+}
+
+} // namespace hir
+} // namespace nvfuser
diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp
index 0d84dbc03e0..1a28c6e18d5 100644
--- a/tests/cpp/test_multidevice_dispatch_combine.cpp
+++ b/tests/cpp/test_multidevice_dispatch_combine.cpp
@@ -21,15 +21,21 @@ namespace nvfuser {
 namespace hir {
 
-class DispatchCombineTest : public MultiDeviceTest {};
+class DispatchCombineTest
+    : public MultiDeviceTest,
+      public ::testing::WithParamInterface<CommunicatorBackend> {};
 
-TEST_F(DispatchCombineTest, DispatchCombineTop1) {
+TEST_P(DispatchCombineTest, DispatchCombineTop1) {
   if (!communicator_->is_available() || communicator_->size() < 2) {
     GTEST_SKIP() << "This test needs at least 2 ranks.";
   }
 
   const int64_t world_size = communicator_->size();
   const int64_t my_rank = communicator_->deviceId();
+  const auto backend = GetParam();
+  if (!communicator_->isBackendAvailable(backend)) {
+    GTEST_SKIP() << "Backend " << backend << " not available.";
+  }
   constexpr int64_t kNumExpertsPerRank = 2;
   const int64_t num_experts = world_size * kNumExpertsPerRank;
   constexpr int64_t kNumTokens = 4;
@@ -64,7 +70,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) {
       in_topk_weights,
       in_is_token_in_rank,
       num_experts,
-      CommunicatorBackend::kNccl);
+      backend);
 
   auto* combined_x = makeSymbolicTensor(2);
   auto* combined_topk_weights = makeSymbolicTensor(1);
@@ -77,7 +83,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) {
       recv_src_rank,
       n_tokens_to_rank,
       n_tokens_from_rank,
-      CommunicatorBackend::kNccl);
+      backend);
 
   hic->pushBackTopLevelExprs(dispatch);
   hic->pushBackTopLevelExprs(combine);
@@ -119,10 +125,14 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) {
       {in_topk_weights, topk_weights},
       {in_is_token_in_rank, is_token_in_rank}});
   auto combined = outputs.back().as<at::Tensor>();
-
   EXPECT_TRUE(at::allclose(combined, x))
       << "Dispatch/Combine mismatch on rank " << my_rank;
 }
 
+INSTANTIATE_TEST_SUITE_P(
+    DispatchCombineBackends,
+    DispatchCombineTest,
+    ::testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kCuda));
+
 } // namespace hir
 } // namespace nvfuser

From ba6612d6724aa1d6d3f2da8704f74d6f0742999f Mon Sep 17 00:00:00 2001
From: snordmann
Date: Mon, 26 Jan 2026 06:45:31 -0800
Subject: [PATCH 05/18] minor comments

---
 csrc/multidevice/communication.h              |  2 ++
 csrc/multidevice/dispatch_combine.cpp         | 25 ++++++-------------
 csrc/multidevice/dispatch_combine.h           |  3 ++-
 .../cpp/test_multidevice_dispatch_combine.cpp |  3 +--
 4 files changed, 12 insertions(+), 21 deletions(-)

diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h
index e9544e48e9e..e801707e5f9 100644
--- a/csrc/multidevice/communication.h
+++ b/csrc/multidevice/communication.h
@@ -182,6 +182,8 @@ class P2PCommunication : public Expr {
 // in_is_token_in_rank: [T, R] (one-hot), num_experts = R * experts_per_rank.
 // topk_weights are intentionally not forwarded; apply them before dispatch or
 // after combine.
+// out_src_idx/out_src_rank are returned for the combine step to restore the
+// original token order.
 // Outputs are recv-aligned tensors: out_x/out_topk_idx/out_src_* with
 // [T_recv, ...] and out_n_tokens_to_rank/out_n_tokens_from_rank with shape
 // [R].
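The out_src_idx/out_src_rank plumbing documented above is an inverse-permutation trick. A toy sketch (plain C++ with made-up values) of the index_copy_-style scatter that combine performs:

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // Suppose tokens left this rank in rank-sorted order [0, 2, 3, 1]; the
  // payload comes back in that same order after expert processing.
  const std::vector<int64_t> src_idx{0, 2, 3, 1};
  const std::vector<double> recv_x{10.0, 30.0, 40.0, 20.0};

  // combined_x.index_copy_(0, recv_src_idx, recv_x), element by element:
  std::vector<double> combined(recv_x.size());
  for (std::size_t i = 0; i < recv_x.size(); ++i) {
    combined[static_cast<std::size_t>(src_idx[i])] = recv_x[i];
  }
  // combined == {10.0, 20.0, 30.0, 40.0}: original token order restored.
  return 0;
}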
diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index a712b11b35d..9b2edb92a4e 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -8,13 +8,10 @@ #include "multidevice/dispatch_combine.h" -#include #include -#include - +#include "exceptions.h" #include "multidevice/communicator.h" -#include "utils.h" namespace nvfuser { namespace { @@ -53,6 +50,12 @@ DispatchResult doMoEDispatch( NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA."); NVF_CHECK( is_token_in_rank.is_cuda(), "Dispatch is_token_in_rank must be on CUDA."); + NVF_CHECK( + x.device() == topk_idx.device(), + "Dispatch expects x and topk_idx on the same device."); + NVF_CHECK( + x.device() == is_token_in_rank.device(), + "Dispatch expects x and is_token_in_rank on the same device."); NVF_CHECK_EQ( is_token_in_rank.dim(), 2, @@ -71,18 +74,6 @@ DispatchResult doMoEDispatch( NVF_CHECK_EQ(num_experts % world_size, 0, "num_experts must be divisible."); const int64_t experts_per_rank = num_experts / world_size; - // Ensure subsequent allocations/ops are on x's device. - c10::cuda::CUDAGuard device_guard(x.device()); - NVF_CHECK( - [&]() { - auto token_counts = is_token_in_rank.to(at::kLong).sum(1); - auto min_val = token_counts.min().item(); - auto max_val = token_counts.max().item(); - return min_val == 1 && max_val == 1; - }(), - "Only topk=1 is supported. Each token must be assigned to exactly one " - "rank."); - const bool topk_is_1d = topk_idx.dim() == 1 && topk_idx.size(0) == num_tokens; const bool topk_is_2d = topk_idx.dim() == 2 && topk_idx.size(0) == num_tokens && topk_idx.size(1) == 1; @@ -198,8 +189,6 @@ CombineResult doMoECombine( communicator->size(), "n_tokens_from_rank must match world size."); - c10::cuda::CUDAGuard device_guard(x.device()); - // Sort by source rank so alltoall can send contiguous chunks per rank. auto sorted_indices = at::argsort(src_rank); auto send_x = x.index_select(0, sorted_indices); diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index f924dd7347c..8caf7074cb8 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -34,7 +34,8 @@ struct CombineResult { // topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1]. // topk_weights: Apply gating weights either before dispatch or after combine. // They are intentionally not forwarded through dispatch/combination. -// is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. +// is_token_in_rank: One-hot token-to-rank assignment, shape [T, R], enabling +// non-trivial device meshes or uneven expert-to-rank mappings. // num_experts: Total experts across all ranks (must be divisible by R). // communicator: Communicator for alltoall exchange. // backend: Communication backend (only NCCL is supported for now). 
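Because topk_weights stay local (per the comments above), gating is the caller's job. A hedged sketch (libtorch; applyGate is a hypothetical helper, k=1) of applying the gate after combine:

#include <ATen/ATen.h>

// `combined` is [T, H] in original token order, `topk_weights` the local [T]
// gate values that never traveled with the tokens.
at::Tensor applyGate(const at::Tensor& combined, const at::Tensor& topk_weights) {
  // Broadcast [T] -> [T, 1] so each token row is scaled by its own gate.
  return combined * topk_weights.unsqueeze(1);
}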
diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index dc22ac9f8bf..5320aa7ec1e 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -14,14 +14,13 @@ #include "fusion.h" #include "host_ir/container.h" #include "host_ir/evaluator.h" -#include "ir/all_nodes.h" #include "multidevice/communication.h" #include "tests/cpp/multidevice.h" namespace nvfuser { namespace hir { -class DispatchCombineTest : public MultiDeviceTest {}; +using DispatchCombineTest = MultiDeviceTest; TEST_F(DispatchCombineTest, DispatchCombineTop1) { if (!communicator_->is_available() || communicator_->size() < 2) { From 4693c53bf2692d27720249b1c2baaae42605bb14 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 29 Jan 2026 06:00:22 -0800 Subject: [PATCH 06/18] minor review --- csrc/multidevice/dispatch_combine.cpp | 15 ++++++++-- csrc/multidevice/dispatch_combine.h | 41 ++++++++++++++------------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 9b2edb92a4e..f069ee61163 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -67,6 +67,10 @@ DispatchResult doMoEDispatch( const int64_t hidden = x.size(1); const int64_t world_size = communicator->size(); const int64_t my_rank = communicator->deviceId(); + NVF_CHECK_EQ( + is_token_in_rank.size(0), + num_tokens, + "is_token_in_rank first dim must match number of tokens."); NVF_CHECK_EQ( is_token_in_rank.size(1), world_size, @@ -107,8 +111,9 @@ DispatchResult doMoEDispatch( auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); - NVF_CHECK( - backend == CommunicatorBackend::kNccl, + NVF_CHECK_EQ( + backend, + CommunicatorBackend::kNccl, "Only NCCL backend is supported for MoEDispatch."); CommunicatorBackend actual_backend = backend; NVF_CHECK( @@ -180,6 +185,12 @@ CombineResult doMoECombine( NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden]."); NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D."); NVF_CHECK_EQ(src_rank.dim(), 1, "src_rank must be 1D."); + NVF_CHECK_EQ( + src_idx.size(0), x.size(0), "src_idx size must match x first dimension."); + NVF_CHECK_EQ( + src_rank.size(0), + x.size(0), + "src_rank size must match x first dimension."); NVF_CHECK_EQ( n_tokens_to_rank.numel(), communicator->size(), diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 8caf7074cb8..32241e5da7f 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -44,26 +44,27 @@ struct CombineResult { // DispatchResult with recv_* tensors on this rank. 
// // Example: -// // world_size=2, num_experts=4, T=4, H=2, topk=1 -// // Experts are partitioned by rank: -// // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} -// // Rank0 holds tokens 0,1 and rank1 holds tokens 2,3 in x: -// // rank0 x = [x0, x1], rank1 x = [x2, x3] -// // token->rank: [0, 1, 1, 1] (rank0 keeps x0, sends x1; rank1 keeps x2,x3) -// // is_token_in_rank = -// // [[1, 0], -// // [0, 1], -// // [0, 1], -// // [0, 1]] -// // topk_idx = [0, 2, 3, 2] (global expert ids) -// // After dispatch on rank0: -// // recv_x has token {0} -// // recv_topk_idx aligned with recv_x (e.g., [0]) -// // recv_src_idx tells original token positions (e.g., [0]) -// // After dispatch on rank1: -// // recv_x has tokens {1, 2, 3} -// // recv_topk_idx aligned with recv_x (e.g., [2, 3, 2]) -// // recv_src_idx tells original token positions (e.g., [1, 2, 3]) +// world_size=2, num_experts=4, T=4, H=2, topk=1 +// Experts are partitioned by rank: +// rank0 owns experts {0, 1}, rank1 owns experts {2, 3} +// Rank0 holds tokens 0,1 and rank1 holds tokens 2,3 in x: +// rank0 x = [x0, x1], rank1 x = [x2, x3] +// token->rank: [0, 1, 1, 1] (rank0 keeps x0, sends x1; rank1 keeps x2,x3) +// is_token_in_rank = +// [[1, 0], +// [0, 1], +// [0, 1], +// [0, 1]] +// topk_idx = [0, 2, 3, 2] (global expert ids) +// After dispatch on rank0: +// recv_x has token {0} +// recv_topk_idx aligned with recv_x (e.g., [0]) +// recv_src_idx tells original token positions (e.g., [0]) +// After dispatch on rank1: +// recv_x has tokens {1, 2, 3} +// recv_topk_idx aligned with recv_x (e.g., [2, 2, 3]). Tokens are grouped +// by expert id for local expert processing. +// recv_src_idx tells original token positions (e.g., [1, 2, 3]) // auto out = doMoEDispatch( // x, topk_idx, is_token_in_rank, 4, comm, CommunicatorBackend::kNccl); NVF_API DispatchResult doMoEDispatch( From a81a5146e1131d5e3079a55446b6e0d7e344f6f2 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 4 Feb 2026 05:56:58 -0800 Subject: [PATCH 07/18] renaming --- csrc/dispatch.h | 4 +-- csrc/host_ir/evaluator.cpp | 8 +++--- csrc/host_ir/evaluator.h | 4 +-- csrc/multidevice/communication.cpp | 20 ++++++------- csrc/multidevice/communication.h | 28 +++++++++---------- csrc/multidevice/dispatch_combine.cpp | 8 +++--- csrc/multidevice/dispatch_combine.h | 8 +++--- .../cpp/test_multidevice_dispatch_combine.cpp | 4 +-- 8 files changed, 42 insertions(+), 42 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 68b847ffb2e..bcf35ff5e55 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -118,8 +118,8 @@ class Val; f(Merge); \ f(Partition); \ f(Combine); \ - f(MoEDispatch); \ - f(MoECombine); \ + f(MoeDispatch); \ + f(MoeCombine); \ f(Swizzle); \ f(Resize); \ f(MatmulOp); \ diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 0bfc8a451a7..b2a64b32b2a 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -387,7 +387,7 @@ void HostIrEvaluator::handle(P2PCommunication* communication) { } } -void HostIrEvaluator::handle(MoEDispatch* dispatch) { +void HostIrEvaluator::handle(MoeDispatch* dispatch) { NVF_ERROR( communicator_ != nullptr && communicator_->is_available(), "A valid communicator must be provided"); @@ -397,7 +397,7 @@ void HostIrEvaluator::handle(MoEDispatch* dispatch) { auto is_token_in_rank = getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); - auto result = doMoEDispatch( + auto result = doMoeDispatch( x, topk_idx, is_token_in_rank, @@ -414,7 +414,7 @@ void HostIrEvaluator::handle(MoEDispatch* dispatch) { 
dispatch->outTokensFromRank(), result.n_tokens_from_rank); } -void HostIrEvaluator::handle(MoECombine* combine) { +void HostIrEvaluator::handle(MoeCombine* combine) { NVF_ERROR( communicator_ != nullptr && communicator_->is_available(), "A valid communicator must be provided"); @@ -427,7 +427,7 @@ void HostIrEvaluator::handle(MoECombine* combine) { auto n_tokens_from_rank = getKnownConcreteValue(combine->inTokensFromRank()).as(); - auto result = doMoECombine( + auto result = doMoeCombine( x, src_idx, src_rank, diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index c1b0a70ef78..4a1929ba1bd 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -98,8 +98,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { void handle(LaunchKernel*) override; void handle(Communication*) override; void handle(P2PCommunication*) override; - void handle(MoEDispatch*) override; - void handle(MoECombine*) override; + void handle(MoeDispatch*) override; + void handle(MoeCombine*) override; void handle(Wait*) override; void handle(kir::ForLoop*) override; void handle(hir::ForLoop*) override; diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index ac0e815632e..d901ab57dd0 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -321,7 +321,7 @@ std::string P2PCommunication::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } -MoEDispatch::MoEDispatch( +MoeDispatch::MoeDispatch( IrBuilderPasskey passkey, TensorView* out_x, TensorView* out_topk_idx, @@ -349,9 +349,9 @@ MoEDispatch::MoEDispatch( validate(); } -NVFUSER_DEFINE_CLONE_AND_CREATE(MoEDispatch) +NVFUSER_DEFINE_CLONE_AND_CREATE(MoeDispatch) -std::string MoEDispatch::toInlineString(int indent_size) const { +std::string MoeDispatch::toInlineString(int indent_size) const { std::stringstream ss; indent(ss, indent_size) << "Dispatch " << name() << " (" << "num_experts=" << numExperts() << ", " @@ -363,11 +363,11 @@ std::string MoEDispatch::toInlineString(int indent_size) const { return ss.str(); } -std::string MoEDispatch::toString(int indent_size) const { +std::string MoeDispatch::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } -void MoEDispatch::validate() { +void MoeDispatch::validate() { NVF_CHECK(numExperts() > 0, "num_experts must be positive."); NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); NVF_CHECK(inTopkIdx()->isA(), "topk_idx must be a TensorView."); @@ -401,7 +401,7 @@ void MoEDispatch::validate() { "out_n_tokens_from_rank must be integral."); } -MoECombine::MoECombine( +MoeCombine::MoeCombine( IrBuilderPasskey passkey, TensorView* out_x, TensorView* in_x, @@ -421,9 +421,9 @@ MoECombine::MoECombine( validate(); } -NVFUSER_DEFINE_CLONE_AND_CREATE(MoECombine) +NVFUSER_DEFINE_CLONE_AND_CREATE(MoeCombine) -std::string MoECombine::toInlineString(int indent_size) const { +std::string MoeCombine::toInlineString(int indent_size) const { std::stringstream ss; indent(ss, indent_size) << "Combine " << name() << " (" << "backend=" << backend() << ", " @@ -434,11 +434,11 @@ std::string MoECombine::toInlineString(int indent_size) const { return ss.str(); } -std::string MoECombine::toString(int indent_size) const { +std::string MoeCombine::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } -void MoECombine::validate() { +void MoeCombine::validate() { NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); NVF_CHECK( inSrcIdx()->getDataType().has_value() && diff 
--git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index e801707e5f9..45c3e880e91 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -187,11 +187,11 @@ class P2PCommunication : public Expr { // Outputs are recv-aligned tensors: out_x/out_topk_idx/out_src_* with // [T_recv, ...] and out_n_tokens_to_rank/out_n_tokens_from_rank with shape // [R]. -class MoEDispatch : public Expr { +class MoeDispatch : public Expr { public: using Expr::Expr; - MoEDispatch( + MoeDispatch( IrBuilderPasskey passkey, TensorView* out_x, TensorView* out_topk_idx, @@ -208,17 +208,17 @@ class MoEDispatch : public Expr { return input(2)->as(); } - MoEDispatch(const MoEDispatch& other) = delete; - MoEDispatch& operator=(const MoEDispatch& other) = delete; - MoEDispatch(MoEDispatch&& other) = delete; - MoEDispatch& operator=(MoEDispatch&& other) = delete; + MoeDispatch(const MoeDispatch& other) = delete; + MoeDispatch& operator=(const MoeDispatch& other) = delete; + MoeDispatch(MoeDispatch&& other) = delete; + MoeDispatch& operator=(MoeDispatch&& other) = delete; NVFUSER_DECLARE_CLONE_AND_CREATE std::string toString(int indent_size = 0) const override; std::string toInlineString(int indent_size = 0) const override; const char* getOpString() const override { - return "MoEDispatch"; + return "MoeDispatch"; } TensorView* outX() const { @@ -272,11 +272,11 @@ class MoEDispatch : public Expr { // in_x: [T_recv, H], in_src_idx: [T_recv], in_src_rank: [T_recv], // in_n_tokens_to_rank: [R], in_n_tokens_from_rank: [R]. // Outputs are source-aligned: out_x with shape [T_src, ...]. -class MoECombine : public Expr { +class MoeCombine : public Expr { public: using Expr::Expr; - MoECombine( + MoeCombine( IrBuilderPasskey passkey, TensorView* out_x, TensorView* in_x, @@ -286,17 +286,17 @@ class MoECombine : public Expr { TensorView* in_n_tokens_from_rank, CommunicatorBackend backend = CommunicatorBackend::kNccl); - MoECombine(const MoECombine& other) = delete; - MoECombine& operator=(const MoECombine& other) = delete; - MoECombine(MoECombine&& other) = delete; - MoECombine& operator=(MoECombine&& other) = delete; + MoeCombine(const MoeCombine& other) = delete; + MoeCombine& operator=(const MoeCombine& other) = delete; + MoeCombine(MoeCombine&& other) = delete; + MoeCombine& operator=(MoeCombine&& other) = delete; NVFUSER_DECLARE_CLONE_AND_CREATE std::string toString(int indent_size = 0) const override; std::string toInlineString(int indent_size = 0) const override; const char* getOpString() const override { - return "MoECombine"; + return "MoeCombine"; } TensorView* outX() const { diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index f069ee61163..b8d4be8be1d 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -38,7 +38,7 @@ void waitWork(const c10::intrusive_ptr& work) { } // namespace -DispatchResult doMoEDispatch( +DispatchResult doMoeDispatch( const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& is_token_in_rank, @@ -114,7 +114,7 @@ DispatchResult doMoEDispatch( NVF_CHECK_EQ( backend, CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoEDispatch."); + "Only NCCL backend is supported for MoeDispatch."); CommunicatorBackend actual_backend = backend; NVF_CHECK( communicator->isBackendAvailable(actual_backend), @@ -166,7 +166,7 @@ DispatchResult doMoEDispatch( n_tokens_from_rank}; } -CombineResult doMoECombine( +CombineResult doMoeCombine( const at::Tensor& x, 
const at::Tensor& src_idx, const at::Tensor& src_rank, @@ -213,7 +213,7 @@ CombineResult doMoECombine( NVF_CHECK( backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoECombine."); + "Only NCCL backend is supported for MoeCombine."); CommunicatorBackend actual_backend = backend; NVF_CHECK( communicator->isBackendAvailable(actual_backend), diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 32241e5da7f..db2b7ff1dd0 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -65,9 +65,9 @@ struct CombineResult { // recv_topk_idx aligned with recv_x (e.g., [2, 2, 3]). Tokens are grouped // by expert id for local expert processing. // recv_src_idx tells original token positions (e.g., [1, 2, 3]) -// auto out = doMoEDispatch( +// auto out = doMoeDispatch( // x, topk_idx, is_token_in_rank, 4, comm, CommunicatorBackend::kNccl); -NVF_API DispatchResult doMoEDispatch( +NVF_API DispatchResult doMoeDispatch( const at::Tensor& x, // [T, H] const at::Tensor& topk_idx, // [T] or [T, 1] const at::Tensor& is_token_in_rank, // [T, R] @@ -98,10 +98,10 @@ NVF_API DispatchResult doMoEDispatch( // // src_rank = [0, 1, 1] // // n_tokens_to_rank and n_tokens_from_rank are [R] counts per rank. // // Combine scatters results back to original token order per rank. -// auto combined = doMoECombine( +// auto combined = doMoeCombine( // x, src_idx, src_rank, n_tokens_to_rank, // n_tokens_from_rank, comm, CommunicatorBackend::kNccl); -NVF_API CombineResult doMoECombine( +NVF_API CombineResult doMoeCombine( const at::Tensor& x, const at::Tensor& src_idx, const at::Tensor& src_rank, diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 5320aa7ec1e..4084afd70ab 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -48,7 +48,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); - auto* dispatch = IrBuilder::create( + auto* dispatch = IrBuilder::create( recv_x, recv_topk_idx, recv_src_idx, @@ -62,7 +62,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { CommunicatorBackend::kNccl); auto* combined_x = makeSymbolicTensor(2); - auto* combine = IrBuilder::create( + auto* combine = IrBuilder::create( combined_x, recv_x, recv_src_idx, From a0de605750009bc7ab5023b671e73efd01b87bf5 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 5 Feb 2026 04:57:23 -0800 Subject: [PATCH 08/18] add back topk_weights --- csrc/host_ir/evaluator.cpp | 8 +++ csrc/multidevice/communication.cpp | 26 +++++++++ csrc/multidevice/communication.h | 54 +++++++++++++------ csrc/multidevice/dispatch_combine.cpp | 46 +++++++++++++++- csrc/multidevice/dispatch_combine.h | 24 ++++++--- .../cpp/test_multidevice_dispatch_combine.cpp | 17 +++++- 6 files changed, 150 insertions(+), 25 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index b2a64b32b2a..7d28c6e0755 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -394,12 +394,15 @@ void HostIrEvaluator::handle(MoeDispatch* dispatch) { auto x = getKnownConcreteValue(dispatch->inX()).as(); auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as(); + auto topk_weights = + getKnownConcreteValue(dispatch->inTopkWeights()).as(); auto is_token_in_rank = 
getKnownConcreteValue(dispatch->inIsTokenInRank()).as<at::Tensor>(); auto result = doMoeDispatch( x, topk_idx, + topk_weights, is_token_in_rank, dispatch->numExperts(), communicator_, dispatch->backend()); expr_evaluator_.bind(dispatch->outX(), result.recv_x); expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx); + expr_evaluator_.bind(dispatch->outTopkWeights(), result.recv_topk_weights); expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx); expr_evaluator_.bind(dispatch->outSrcRank(), result.recv_src_rank); expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank); @@ -420,6 +424,8 @@ void HostIrEvaluator::handle(MoeCombine* combine) { "A valid communicator must be provided"); auto x = getKnownConcreteValue(combine->inX()).as<at::Tensor>(); + auto topk_weights = + getKnownConcreteValue(combine->inTopkWeights()).as<at::Tensor>(); auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as<at::Tensor>(); auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as<at::Tensor>(); auto n_tokens_to_rank = @@ -429,6 +435,7 @@ void HostIrEvaluator::handle(MoeCombine* combine) { auto result = doMoeCombine( x, + topk_weights, src_idx, src_rank, n_tokens_to_rank, @@ -437,6 +444,7 @@ expr_evaluator_.bind(combine->outX(), result.combined_x); + expr_evaluator_.bind(combine->outTopkWeights(), result.combined_topk_weights); } void HostIrEvaluator::handle(Wait* wait) { diff --git a/csrc/multidevice/communication.cpp index d901ab57dd0..b790748f957 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -325,21 +325,25 @@ MoeDispatch::MoeDispatch( IrBuilderPasskey passkey, TensorView* out_x, TensorView* out_topk_idx, + TensorView* out_topk_weights, TensorView* out_src_idx, TensorView* out_src_rank, TensorView* out_n_tokens_to_rank, TensorView* out_n_tokens_from_rank, TensorView* in_x, TensorView* in_topk_idx, + TensorView* in_topk_weights, TensorView* in_is_token_in_rank, int64_t num_experts, CommunicatorBackend backend) : Expr(passkey) { addInput(in_x); addInput(in_topk_idx); + addInput(in_topk_weights); addInput(in_is_token_in_rank); addOutput(out_x); addOutput(out_topk_idx); + addOutput(out_topk_weights); addOutput(out_src_idx); addOutput(out_src_rank); addOutput(out_n_tokens_to_rank); @@ -358,6 +362,7 @@ std::string MoeDispatch::toInlineString(int indent_size) const { << "backend=" << backend() << ", " << "in=" << inX() << ", " << "topk_idx=" << inTopkIdx() << ", " + << "topk_weights=" << inTopkWeights() << ", " << "is_token_in_rank=" << inIsTokenInRank() << ", " << "out=" << outX() << ")"; return ss.str(); @@ -375,6 +380,10 @@ void MoeDispatch::validate() { inTopkIdx()->getDataType().has_value() && isIntegralType(*inTopkIdx()->getDataType()), "topk_idx must be integral."); + NVF_CHECK( + inTopkWeights()->getDataType().has_value() && + isFloatingPointType(*inTopkWeights()->getDataType()), + "topk_weights must be floating point."); NVF_CHECK( inIsTokenInRank()->getDataType().has_value() && inIsTokenInRank()->getDataType() == DataType::Bool, @@ -383,6 +392,10 @@ void MoeDispatch::validate() { outTopkIdx()->getDataType().has_value() && isIntegralType(*outTopkIdx()->getDataType()), "out_topk_idx must be integral."); + NVF_CHECK( + outTopkWeights()->getDataType().has_value() && + isFloatingPointType(*outTopkWeights()->getDataType()), + "out_topk_weights must be floating point."); NVF_CHECK(
outSrcIdx()->getDataType().has_value() && isIntegralType(*outSrcIdx()->getDataType()), @@ -404,7 +417,9 @@ void MoeDispatch::validate() { MoeCombine::MoeCombine( IrBuilderPasskey passkey, TensorView* out_x, + TensorView* out_topk_weights, TensorView* in_x, + TensorView* in_topk_weights, TensorView* in_src_idx, TensorView* in_src_rank, TensorView* in_n_tokens_to_rank, @@ -412,11 +427,13 @@ MoeCombine::MoeCombine( CommunicatorBackend backend) : Expr(passkey) { addInput(in_x); + addInput(in_topk_weights); addInput(in_src_idx); addInput(in_src_rank); addInput(in_n_tokens_to_rank); addInput(in_n_tokens_from_rank); addOutput(out_x); + addOutput(out_topk_weights); addDataAttribute(backend); validate(); } @@ -428,6 +445,7 @@ std::string MoeCombine::toInlineString(int indent_size) const { indent(ss, indent_size) << "Combine " << name() << " (" << "backend=" << backend() << ", " << "in=" << inX() << ", " + << "topk_weights=" << inTopkWeights() << ", " << "src_idx=" << inSrcIdx() << ", " << "src_rank=" << inSrcRank() << ", " << "out=" << outX() << ")"; @@ -440,6 +458,10 @@ std::string MoeCombine::toString(int indent_size) const { void MoeCombine::validate() { NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); + NVF_CHECK( + inTopkWeights()->getDataType().has_value() && + isFloatingPointType(*inTopkWeights()->getDataType()), + "in_topk_weights must be floating point."); NVF_CHECK( inSrcIdx()->getDataType().has_value() && isIntegralType(*inSrcIdx()->getDataType()), @@ -456,6 +478,10 @@ void MoeCombine::validate() { inTokensFromRank()->getDataType().has_value() && isIntegralType(*inTokensFromRank()->getDataType()), "in_n_tokens_from_rank must be integral."); + NVF_CHECK( + outTopkWeights()->getDataType().has_value() && + isFloatingPointType(*outTopkWeights()->getDataType()), + "out_topk_weights must be floating point."); } namespace { diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 45c3e880e91..f4a1abaf667 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -179,14 +179,13 @@ class P2PCommunication : public Expr { // // Example shapes (topk=1): // in_x: [T, H], in_topk_idx: [T] or [T, 1], +// in_topk_weights: [T] or [T, 1], // in_is_token_in_rank: [T, R] (one-hot), num_experts = R * experts_per_rank. -// topk_weights are intentionally not forwarded; apply them before dispatch or -// after combine. // out_src_idx/out_src_rank are returned for the combine step to restore the // original token order. -// Outputs are recv-aligned tensors: out_x/out_topk_idx/out_src_* with -// [T_recv, ...] and out_n_tokens_to_rank/out_n_tokens_from_rank with shape -// [R]. +// Outputs are recv-aligned tensors: out_x/out_topk_idx/out_topk_weights/ +// out_src_* with [T_recv, ...] and +// out_n_tokens_to_rank/out_n_tokens_from_rank with shape [R]. 
class MoeDispatch : public Expr { public: using Expr::Expr; @@ -195,17 +194,19 @@ class MoeDispatch : public Expr { IrBuilderPasskey passkey, TensorView* out_x, TensorView* out_topk_idx, + TensorView* out_topk_weights, TensorView* out_src_idx, TensorView* out_src_rank, TensorView* out_n_tokens_to_rank, TensorView* out_n_tokens_from_rank, TensorView* in_x, TensorView* in_topk_idx, + TensorView* in_topk_weights, TensorView* in_is_token_in_rank, int64_t num_experts, CommunicatorBackend backend = CommunicatorBackend::kNccl); TensorView* inIsTokenInRank() const { - return input(2)->as(); + return input(3)->as(); } MoeDispatch(const MoeDispatch& other) = delete; @@ -229,22 +230,26 @@ class MoeDispatch : public Expr { return output(1)->as(); } - TensorView* outSrcIdx() const { + TensorView* outTopkWeights() const { return output(2)->as(); } - TensorView* outSrcRank() const { + TensorView* outSrcIdx() const { return output(3)->as(); } - TensorView* outTokensToRank() const { + TensorView* outSrcRank() const { return output(4)->as(); } - TensorView* outTokensFromRank() const { + TensorView* outTokensToRank() const { return output(5)->as(); } + TensorView* outTokensFromRank() const { + return output(6)->as(); + } + TensorView* inX() const { return input(0)->as(); } @@ -253,6 +258,10 @@ class MoeDispatch : public Expr { return input(1)->as(); } + TensorView* inTopkWeights() const { + return input(2)->as(); + } + int64_t numExperts() const { return attribute(0); } @@ -269,9 +278,10 @@ class MoeDispatch : public Expr { // their source ranks using `in_src_rank` and `in_src_idx`. // // Example shapes (topk=1): -// in_x: [T_recv, H], in_src_idx: [T_recv], in_src_rank: [T_recv], -// in_n_tokens_to_rank: [R], in_n_tokens_from_rank: [R]. -// Outputs are source-aligned: out_x with shape [T_src, ...]. +// in_x: [T_recv, H], in_topk_weights: [T_recv], in_src_idx: [T_recv], +// in_src_rank: [T_recv], in_n_tokens_to_rank: [R], in_n_tokens_from_rank: +// [R]. Outputs are source-aligned: out_x/out_topk_weights with shape [T_src, +// ...]. 
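Because only k=1 is supported, applying the gating weights after combine reduces to a row-wise scale of the combined activations; a sketch assuming the source-aligned shapes above (out_x: [T_src, H], out_topk_weights: [T_src]):

    #include <ATen/ATen.h>

    // Row-wise gating for k=1: [T_src, H] * [T_src, 1] -> [T_src, H].
    at::Tensor applyGating(
        const at::Tensor& combined_x,
        const at::Tensor& combined_topk_weights) {
      return combined_x * combined_topk_weights.unsqueeze(1);
    }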
class MoeCombine : public Expr { public: using Expr::Expr; @@ -279,7 +289,9 @@ class MoeCombine : public Expr { MoeCombine( IrBuilderPasskey passkey, TensorView* out_x, + TensorView* out_topk_weights, TensorView* in_x, + TensorView* in_topk_weights, TensorView* in_src_idx, TensorView* in_src_rank, TensorView* in_n_tokens_to_rank, @@ -303,26 +315,34 @@ class MoeCombine : public Expr { return output(0)->as(); } + TensorView* outTopkWeights() const { + return output(1)->as(); + } + TensorView* inX() const { return input(0)->as(); } - TensorView* inSrcIdx() const { + TensorView* inTopkWeights() const { return input(1)->as(); } - TensorView* inSrcRank() const { + TensorView* inSrcIdx() const { return input(2)->as(); } - TensorView* inTokensToRank() const { + TensorView* inSrcRank() const { return input(3)->as(); } - TensorView* inTokensFromRank() const { + TensorView* inTokensToRank() const { return input(4)->as(); } + TensorView* inTokensFromRank() const { + return input(5)->as(); + } + CommunicatorBackend backend() const { return attribute(0); } diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index b8d4be8be1d..043b37fd421 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -41,6 +41,7 @@ void waitWork(const c10::intrusive_ptr& work) { DispatchResult doMoeDispatch( const at::Tensor& x, const at::Tensor& topk_idx, + const at::Tensor& topk_weights, const at::Tensor& is_token_in_rank, int64_t num_experts, Communicator* communicator, @@ -48,11 +49,18 @@ DispatchResult doMoeDispatch( NVF_CHECK(communicator != nullptr, "Dispatch requires a valid communicator."); NVF_CHECK(x.is_cuda(), "Dispatch input x must be on CUDA."); NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA."); + NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA."); + NVF_CHECK( + topk_weights.is_floating_point(), + "Dispatch topk_weights must be floating point."); NVF_CHECK( is_token_in_rank.is_cuda(), "Dispatch is_token_in_rank must be on CUDA."); NVF_CHECK( x.device() == topk_idx.device(), "Dispatch expects x and topk_idx on the same device."); + NVF_CHECK( + x.device() == topk_weights.device(), + "Dispatch expects x and topk_weights on the same device."); NVF_CHECK( x.device() == is_token_in_rank.device(), "Dispatch expects x and is_token_in_rank on the same device."); @@ -86,6 +94,15 @@ DispatchResult doMoeDispatch( "Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ", topk_idx.sizes()); auto topk_idx_flat = topk_idx.reshape({num_tokens}); + const bool weights_is_1d = + topk_weights.dim() == 1 && topk_weights.size(0) == num_tokens; + const bool weights_is_2d = topk_weights.dim() == 2 && + topk_weights.size(0) == num_tokens && topk_weights.size(1) == 1; + NVF_CHECK( + weights_is_1d || weights_is_2d, + "Only topk=1 supported. topk_weights must be shape [T] or [T, 1], got: ", + topk_weights.sizes()); + auto topk_weights_flat = topk_weights.reshape({num_tokens}); // Determine destination rank per token (topk=1). auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); @@ -95,6 +112,7 @@ DispatchResult doMoeDispatch( // Reorder payloads so alltoall can send contiguous chunks per rank. auto send_x = x.index_select(0, sorted_indices); auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); // Track original token indices and source rank for the combine step. 
auto send_src_idx = sorted_indices.to(at::kLong); // All entries are identical, so no relayout is needed. @@ -137,6 +155,7 @@ DispatchResult doMoeDispatch( // TODO: support preallocated buffers. auto recv_x = at::empty({total_recv, hidden}, x.options()); auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); @@ -144,6 +163,8 @@ DispatchResult doMoeDispatch( waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); waitWork(pg->alltoall_base( recv_topk_idx, send_topk_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); waitWork(pg->alltoall_base( recv_src_idx, send_src_idx, output_splits, input_splits)); waitWork(pg->alltoall_base( @@ -154,12 +175,14 @@ DispatchResult doMoeDispatch( auto expert_order = at::argsort(local_expert); recv_x = recv_x.index_select(0, expert_order); recv_topk_idx = recv_topk_idx.index_select(0, expert_order); + recv_topk_weights = recv_topk_weights.index_select(0, expert_order); recv_src_idx = recv_src_idx.index_select(0, expert_order); recv_src_rank = recv_src_rank.index_select(0, expert_order); return DispatchResult{ recv_x, recv_topk_idx, + recv_topk_weights, recv_src_idx, recv_src_rank, n_tokens_to_rank, @@ -168,6 +191,7 @@ DispatchResult doMoeDispatch( CombineResult doMoeCombine( const at::Tensor& x, + const at::Tensor& topk_weights, const at::Tensor& src_idx, const at::Tensor& src_rank, const at::Tensor& n_tokens_to_rank, @@ -176,6 +200,10 @@ CombineResult doMoeCombine( CommunicatorBackend backend) { NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator."); NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA."); + NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); + NVF_CHECK( + topk_weights.is_floating_point(), + "Combine topk_weights must be floating point."); NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); NVF_CHECK( @@ -185,6 +213,15 @@ CombineResult doMoeCombine( NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden]."); NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D."); NVF_CHECK_EQ(src_rank.dim(), 1, "src_rank must be 1D."); + const bool weights_is_1d = + topk_weights.dim() == 1 && topk_weights.size(0) == x.size(0); + const bool weights_is_2d = topk_weights.dim() == 2 && + topk_weights.size(0) == x.size(0) && topk_weights.size(1) == 1; + NVF_CHECK( + weights_is_1d || weights_is_2d, + "topk_weights must be shape [T] or [T, 1], got: ", + topk_weights.sizes()); + auto topk_weights_flat = topk_weights.reshape({x.size(0)}); NVF_CHECK_EQ( src_idx.size(0), x.size(0), "src_idx size must match x first dimension."); NVF_CHECK_EQ( @@ -203,6 +240,7 @@ CombineResult doMoeCombine( // Sort by source rank so alltoall can send contiguous chunks per rank. auto sorted_indices = at::argsort(src_rank); auto send_x = x.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); auto send_src_idx = src_idx.index_select(0, sorted_indices); // Split sizes come from dispatch counts. @@ -224,17 +262,23 @@ CombineResult doMoeCombine( // Allocate receive buffers and exchange payloads back to source ranks. 
auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); auto recv_src_idx = at::empty({total_recv}, src_idx.options()); waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); waitWork(pg->alltoall_base( recv_src_idx, send_src_idx, output_splits, input_splits)); // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); combined_x.index_copy_(0, recv_src_idx, recv_x); + auto combined_topk_weights = + at::empty({total_recv}, topk_weights_flat.options()); + combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); - return CombineResult{combined_x}; + return CombineResult{combined_x, combined_topk_weights}; } } // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index db2b7ff1dd0..9c4f4e6a62a 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -17,6 +17,7 @@ namespace nvfuser { struct DispatchResult { at::Tensor recv_x; // Dispatched tokens received on this rank. at::Tensor recv_topk_idx; // Expert ids aligned with recv_x. + at::Tensor recv_topk_weights; // Gating weights aligned with recv_x. at::Tensor recv_src_idx; // Source token indices for combine. at::Tensor recv_src_rank; // Source ranks for combine. at::Tensor n_tokens_to_rank; // Tokens sent to each rank (this rank's view). @@ -25,6 +26,7 @@ struct DispatchResult { struct CombineResult { at::Tensor combined_x; // Combined tokens back in original order. + at::Tensor combined_topk_weights; // Combined gating weights per token. }; // Dispatch MoE tokens to the owning ranks. Only k=1 is supported for now. @@ -32,8 +34,7 @@ struct CombineResult { // Args: // x: Token embeddings on this rank, shape [T, H]. // topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1]. -// topk_weights: Apply gating weights either before dispatch or after combine. -// They are intentionally not forwarded through dispatch/combination. +// topk_weights: Gating weights per token (topk=1), shape [T] or [T, 1]. // is_token_in_rank: One-hot token-to-rank assignment, shape [T, R], enabling // non-trivial device meshes or uneven expert-to-rank mappings. // num_experts: Total experts across all ranks (must be divisible by R). @@ -59,6 +60,7 @@ struct CombineResult { // After dispatch on rank0: // recv_x has token {0} // recv_topk_idx aligned with recv_x (e.g., [0]) +// recv_topk_weights aligned with recv_x (e.g., [1.0]) // recv_src_idx tells original token positions (e.g., [0]) // After dispatch on rank1: // recv_x has tokens {1, 2, 3} @@ -66,10 +68,17 @@ struct CombineResult { // by expert id for local expert processing. // recv_src_idx tells original token positions (e.g., [1, 2, 3]) // auto out = doMoeDispatch( -// x, topk_idx, is_token_in_rank, 4, comm, CommunicatorBackend::kNccl); +// x, +// topk_idx, +// topk_weights, +// is_token_in_rank, +// 4, +// comm, +// CommunicatorBackend::kNccl); NVF_API DispatchResult doMoeDispatch( const at::Tensor& x, // [T, H] const at::Tensor& topk_idx, // [T] or [T, 1] + const at::Tensor& topk_weights, // [T] or [T, 1] const at::Tensor& is_token_in_rank, // [T, R] int64_t num_experts, Communicator* communicator, @@ -79,12 +88,14 @@ NVF_API DispatchResult doMoeDispatch( // // Args: // x: Token embeddings after expert compute, shape [T_recv, H]. 
+// topk_weights: Gating weights aligned with x, shape [T_recv] or [T_recv, 1]. // src_idx: Original token indices for each row of x, shape [T_recv]. // src_rank: Original source rank per token, shape [T_recv]. // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. // n_tokens_from_rank: Tokens received from each rank (from dispatch), shape -// [R]. communicator: Communicator for alltoall exchange. backend: -// Communication backend (only NCCL is supported for now). +// [R]. +// communicator: Communicator for alltoall exchange. +// backend: Communication backend (only NCCL is supported for now). // // Returns: // CombineResult with tokens restored to original order on this rank. @@ -99,10 +110,11 @@ NVF_API DispatchResult doMoeDispatch( // // n_tokens_to_rank and n_tokens_from_rank are [R] counts per rank. // // Combine scatters results back to original token order per rank. // auto combined = doMoeCombine( -// x, src_idx, src_rank, n_tokens_to_rank, +// x, topk_weights, src_idx, src_rank, n_tokens_to_rank, // n_tokens_from_rank, comm, CommunicatorBackend::kNccl); NVF_API CombineResult doMoeCombine( const at::Tensor& x, + const at::Tensor& topk_weights, const at::Tensor& src_idx, const at::Tensor& src_rank, const at::Tensor& n_tokens_to_rank, diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 4084afd70ab..22db18066f3 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -39,10 +39,12 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto* in_x = makeSymbolicTensor(2); auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* in_topk_weights = makeSymbolicTensor(1); auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool); auto* recv_x = makeSymbolicTensor(2); auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_topk_weights = makeSymbolicTensor(1); auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); @@ -51,20 +53,25 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto* dispatch = IrBuilder::create( recv_x, recv_topk_idx, + recv_topk_weights, recv_src_idx, recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, in_x, in_topk_idx, + in_topk_weights, in_is_token_in_rank, num_experts, CommunicatorBackend::kNccl); auto* combined_x = makeSymbolicTensor(2); + auto* combined_topk_weights = makeSymbolicTensor(1); auto* combine = IrBuilder::create( combined_x, + combined_topk_weights, recv_x, + recv_topk_weights, recv_src_idx, recv_src_rank, n_tokens_to_rank, @@ -76,8 +83,10 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { hic->addInput(in_x); hic->addInput(in_topk_idx); + hic->addInput(in_topk_weights); hic->addInput(in_is_token_in_rank); hic->addOutput(combined_x); + hic->addOutput(combined_topk_weights); HostIrEvaluator hie(std::move(hic), communicator_); @@ -90,6 +99,8 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { .reshape({kNumTokens, kHidden}) + static_cast(my_rank) * 1000.0; auto topk_idx = at::zeros({kNumTokens}, int_options); + auto topk_weights = + at::arange(kNumTokens, float_options) + static_cast(my_rank); // Asymmetric example: // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. 
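With world_size == 2, this routing makes every rank send counts n_tokens_to_rank = [1, 3]; after the counts exchange, rank 0 sees n_tokens_from_rank = [1, 1] (so T_recv = 2) and rank 1 sees [3, 3] (T_recv = 6). A sketch of deriving the send counts from the test's token_rank tensor (bincount is an assumption for illustration, not what the implementation uses):

    // counts[r] = number of local tokens routed to rank r.
    auto n_tokens_to_rank =
        at::bincount(token_rank, /*weights=*/{}, /*minlength=*/world_size);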
@@ -106,11 +117,15 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto outputs = hie.runWithInput( {{in_x, x}, {in_topk_idx, topk_idx}, + {in_topk_weights, topk_weights}, {in_is_token_in_rank, is_token_in_rank}}); - auto combined = outputs.back().as(); + auto combined = outputs[0].as(); + auto combined_weights = outputs[1].as(); EXPECT_TRUE(at::allclose(combined, x)) << "Dispatch/Combine mismatch on rank " << my_rank; + EXPECT_TRUE(at::allclose(combined_weights, topk_weights)) + << "Dispatch/Combine topk_weights mismatch on rank " << my_rank; } } // namespace hir From 74d18d1ca8bdf4c4acdb03bb223f94aca5c85021 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 9 Feb 2026 02:32:39 -0800 Subject: [PATCH 09/18] harden tests --- .../cpp/test_multidevice_dispatch_combine.cpp | 219 ++++++++++++++++++ 1 file changed, 219 insertions(+) diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 22db18066f3..b3884b9499f 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -15,6 +15,7 @@ #include "host_ir/container.h" #include "host_ir/evaluator.h" #include "multidevice/communication.h" +#include "multidevice/dispatch_combine.h" #include "tests/cpp/multidevice.h" namespace nvfuser { @@ -128,5 +129,223 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { << "Dispatch/Combine topk_weights mismatch on rank " << my_rank; } +TEST_F(DispatchCombineTest, DispatchOnlyTop1) { + if (!communicator_->is_available() || communicator_->size() < 2) { + GTEST_SKIP() << "This test needs at least 2 ranks."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + constexpr int64_t kNumExpertsPerRank = 2; + const int64_t num_experts = world_size * kNumExpertsPerRank; + constexpr int64_t kNumTokens = 4; + constexpr int64_t kHidden = 4; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* in_x = makeSymbolicTensor(2); + auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* in_topk_weights = makeSymbolicTensor(1); + auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool); + + auto* recv_x = makeSymbolicTensor(2); + auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_topk_weights = makeSymbolicTensor(1); + auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); + auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); + auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); + + auto* dispatch = IrBuilder::create( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + in_x, + in_topk_idx, + in_topk_weights, + in_is_token_in_rank, + num_experts, + CommunicatorBackend::kNccl); + + hic->pushBackTopLevelExprs(dispatch); + + hic->addInput(in_x); + hic->addInput(in_topk_idx); + hic->addInput(in_topk_weights); + hic->addInput(in_is_token_in_rank); + hic->addOutput(recv_x); + hic->addOutput(recv_topk_idx); + hic->addOutput(recv_topk_weights); + hic->addOutput(recv_src_idx); + hic->addOutput(recv_src_rank); + hic->addOutput(n_tokens_to_rank); + hic->addOutput(n_tokens_from_rank); + + HostIrEvaluator hie(std::move(hic), communicator_); + + auto float_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kFloat); + auto int_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kLong); + + auto x = 
at::arange(kNumTokens * kHidden, float_options) + .reshape({kNumTokens, kHidden}) + + static_cast(my_rank) * 1000.0; + auto topk_idx = at::zeros({kNumTokens}, int_options); + auto topk_weights = + at::arange(kNumTokens, float_options) + static_cast(my_rank); + + // Asymmetric example: + // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. + auto rank_ids = at::arange(world_size, int_options); + auto token_rank = at::tensor({0, 1, 1, 1}, int_options); + auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); + + // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. + topk_idx.index_put_({0}, 0); + topk_idx.index_put_({1}, kNumExpertsPerRank); + topk_idx.index_put_({2}, kNumExpertsPerRank + 1); + topk_idx.index_put_({3}, kNumExpertsPerRank); + + auto outputs = hie.runWithInput( + {{in_x, x}, + {in_topk_idx, topk_idx}, + {in_topk_weights, topk_weights}, + {in_is_token_in_rank, is_token_in_rank}}); + + auto expected = doMoeDispatch( + x, + topk_idx, + topk_weights, + is_token_in_rank, + num_experts, + communicator_, + CommunicatorBackend::kNccl); + + EXPECT_TRUE(at::allclose(outputs[0].as(), expected.recv_x)) + << "Dispatch recv_x mismatch on rank " << my_rank; + EXPECT_TRUE( + at::allclose(outputs[1].as(), expected.recv_topk_idx)) + << "Dispatch recv_topk_idx mismatch on rank " << my_rank; + EXPECT_TRUE( + at::allclose(outputs[2].as(), expected.recv_topk_weights)) + << "Dispatch recv_topk_weights mismatch on rank " << my_rank; + EXPECT_TRUE( + at::allclose(outputs[3].as(), expected.recv_src_idx)) + << "Dispatch recv_src_idx mismatch on rank " << my_rank; + EXPECT_TRUE( + at::allclose(outputs[4].as(), expected.recv_src_rank)) + << "Dispatch recv_src_rank mismatch on rank " << my_rank; + EXPECT_TRUE( + at::allclose(outputs[5].as(), expected.n_tokens_to_rank)) + << "Dispatch n_tokens_to_rank mismatch on rank " << my_rank; + EXPECT_TRUE( + at::allclose(outputs[6].as(), expected.n_tokens_from_rank)) + << "Dispatch n_tokens_from_rank mismatch on rank " << my_rank; +} + +TEST_F(DispatchCombineTest, CombineOnlyTop1) { + if (!communicator_->is_available() || communicator_->size() < 2) { + GTEST_SKIP() << "This test needs at least 2 ranks."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + constexpr int64_t kNumExpertsPerRank = 2; + const int64_t num_experts = world_size * kNumExpertsPerRank; + constexpr int64_t kNumTokens = 4; + constexpr int64_t kHidden = 4; + + auto float_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kFloat); + auto int_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kLong); + + auto x = at::arange(kNumTokens * kHidden, float_options) + .reshape({kNumTokens, kHidden}) + + static_cast(my_rank) * 1000.0; + auto topk_idx = at::zeros({kNumTokens}, int_options); + auto topk_weights = + at::arange(kNumTokens, float_options) + static_cast(my_rank); + + // Asymmetric example: + // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. + auto rank_ids = at::arange(world_size, int_options); + auto token_rank = at::tensor({0, 1, 1, 1}, int_options); + auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); + + // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. 
+ topk_idx.index_put_({0}, 0); + topk_idx.index_put_({1}, kNumExpertsPerRank); + topk_idx.index_put_({2}, kNumExpertsPerRank + 1); + topk_idx.index_put_({3}, kNumExpertsPerRank); + + auto dispatch_result = doMoeDispatch( + x, + topk_idx, + topk_weights, + is_token_in_rank, + num_experts, + communicator_, + CommunicatorBackend::kNccl); + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* in_x = makeSymbolicTensor(2); + auto* in_topk_weights = makeSymbolicTensor(1); + auto* in_src_idx = makeSymbolicTensor(1, DataType::Int); + auto* in_src_rank = makeSymbolicTensor(1, DataType::Int); + auto* in_n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); + auto* in_n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); + + auto* combined_x = makeSymbolicTensor(2); + auto* combined_topk_weights = makeSymbolicTensor(1); + auto* combine = IrBuilder::create( + combined_x, + combined_topk_weights, + in_x, + in_topk_weights, + in_src_idx, + in_src_rank, + in_n_tokens_to_rank, + in_n_tokens_from_rank, + CommunicatorBackend::kNccl); + + hic->pushBackTopLevelExprs(combine); + + hic->addInput(in_x); + hic->addInput(in_topk_weights); + hic->addInput(in_src_idx); + hic->addInput(in_src_rank); + hic->addInput(in_n_tokens_to_rank); + hic->addInput(in_n_tokens_from_rank); + hic->addOutput(combined_x); + hic->addOutput(combined_topk_weights); + + HostIrEvaluator hie(std::move(hic), communicator_); + + auto outputs = hie.runWithInput( + {{in_x, dispatch_result.recv_x}, + {in_topk_weights, dispatch_result.recv_topk_weights}, + {in_src_idx, dispatch_result.recv_src_idx}, + {in_src_rank, dispatch_result.recv_src_rank}, + {in_n_tokens_to_rank, dispatch_result.n_tokens_to_rank}, + {in_n_tokens_from_rank, dispatch_result.n_tokens_from_rank}}); + + auto combined = outputs[0].as(); + auto combined_weights = outputs[1].as(); + + EXPECT_TRUE(at::allclose(combined, x)) + << "Combine mismatch on rank " << my_rank; + EXPECT_TRUE(at::allclose(combined_weights, topk_weights)) + << "Combine topk_weights mismatch on rank " << my_rank; +} + } // namespace hir } // namespace nvfuser From 6b994ba6254e42435449d0fcb3af591cb910c8ad Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 9 Feb 2026 02:51:07 -0800 Subject: [PATCH 10/18] assume continuous expert-to-rank mapping and simplify API and implementation --- csrc/host_ir/evaluator.cpp | 3 -- csrc/multidevice/communication.cpp | 7 ---- csrc/multidevice/communication.h | 8 +--- csrc/multidevice/dispatch_combine.cpp | 37 +++---------------- csrc/multidevice/dispatch_combine.h | 15 ++------ .../cpp/test_multidevice_dispatch_combine.cpp | 26 +------------ 6 files changed, 13 insertions(+), 83 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 7d28c6e0755..08a49e11ede 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -396,14 +396,11 @@ void HostIrEvaluator::handle(MoeDispatch* dispatch) { auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as(); auto topk_weights = getKnownConcreteValue(dispatch->inTopkWeights()).as(); - auto is_token_in_rank = - getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); auto result = doMoeDispatch( x, topk_idx, topk_weights, - is_token_in_rank, dispatch->numExperts(), communicator_, dispatch->backend()); diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index b790748f957..02d2fdad289 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -333,14 +333,12 @@ MoeDispatch::MoeDispatch( 
TensorView* in_x, TensorView* in_topk_idx, TensorView* in_topk_weights, - TensorView* in_is_token_in_rank, int64_t num_experts, CommunicatorBackend backend) : Expr(passkey) { addInput(in_x); addInput(in_topk_idx); addInput(in_topk_weights); - addInput(in_is_token_in_rank); addOutput(out_x); addOutput(out_topk_idx); addOutput(out_topk_weights); @@ -363,7 +361,6 @@ std::string MoeDispatch::toInlineString(int indent_size) const { << "in=" << inX() << ", " << "topk_idx=" << inTopkIdx() << ", " << "topk_weights=" << inTopkWeights() << ", " - << "is_token_in_rank=" << inIsTokenInRank() << ", " << "out=" << outX() << ")"; return ss.str(); } @@ -384,10 +381,6 @@ void MoeDispatch::validate() { inTopkWeights()->getDataType().has_value() && isFloatingPointType(*inTopkWeights()->getDataType()), "topk_weights must be floating point."); - NVF_CHECK( - inIsTokenInRank()->getDataType().has_value() && - inIsTokenInRank()->getDataType() == DataType::Bool, - "is_token_in_rank must be Bool."); NVF_CHECK( outTopkIdx()->getDataType().has_value() && isIntegralType(*outTopkIdx()->getDataType()), diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index f4a1abaf667..9303e30e7d1 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -179,8 +179,8 @@ class P2PCommunication : public Expr { // // Example shapes (topk=1): // in_x: [T, H], in_topk_idx: [T] or [T, 1], -// in_topk_weights: [T] or [T, 1], -// in_is_token_in_rank: [T, R] (one-hot), num_experts = R * experts_per_rank. +// in_topk_weights: [T] or [T, 1], num_experts = R * experts_per_rank. +// Experts are assumed to be placed contiguously by rank. // out_src_idx/out_src_rank are returned for the combine step to restore the // original token order. // Outputs are recv-aligned tensors: out_x/out_topk_idx/out_topk_weights/ @@ -202,12 +202,8 @@ class MoeDispatch : public Expr { TensorView* in_x, TensorView* in_topk_idx, TensorView* in_topk_weights, - TensorView* in_is_token_in_rank, int64_t num_experts, CommunicatorBackend backend = CommunicatorBackend::kNccl); - TensorView* inIsTokenInRank() const { - return input(3)->as(); - } MoeDispatch(const MoeDispatch& other) = delete; MoeDispatch& operator=(const MoeDispatch& other) = delete; diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 043b37fd421..3776f003f76 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -42,7 +42,6 @@ DispatchResult doMoeDispatch( const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights, - const at::Tensor& is_token_in_rank, int64_t num_experts, Communicator* communicator, CommunicatorBackend backend) { @@ -53,36 +52,18 @@ DispatchResult doMoeDispatch( NVF_CHECK( topk_weights.is_floating_point(), "Dispatch topk_weights must be floating point."); - NVF_CHECK( - is_token_in_rank.is_cuda(), "Dispatch is_token_in_rank must be on CUDA."); NVF_CHECK( x.device() == topk_idx.device(), "Dispatch expects x and topk_idx on the same device."); NVF_CHECK( x.device() == topk_weights.device(), "Dispatch expects x and topk_weights on the same device."); - NVF_CHECK( - x.device() == is_token_in_rank.device(), - "Dispatch expects x and is_token_in_rank on the same device."); - NVF_CHECK_EQ( - is_token_in_rank.dim(), - 2, - "is_token_in_rank must be [tokens, ranks], got: ", - is_token_in_rank.sizes()); NVF_CHECK_EQ(x.dim(), 2, "Dispatch expects x to be 2D [tokens, hidden]."); const int64_t num_tokens = x.size(0); const int64_t hidden = 
x.size(1); const int64_t world_size = communicator->size(); const int64_t my_rank = communicator->deviceId(); - NVF_CHECK_EQ( - is_token_in_rank.size(0), - num_tokens, - "is_token_in_rank first dim must match number of tokens."); - NVF_CHECK_EQ( - is_token_in_rank.size(1), - world_size, - "is_token_in_rank second dim must match world size."); NVF_CHECK_EQ(num_experts % world_size, 0, "num_experts must be divisible."); const int64_t experts_per_rank = num_experts / world_size; @@ -104,10 +85,11 @@ DispatchResult doMoeDispatch( topk_weights.sizes()); auto topk_weights_flat = topk_weights.reshape({num_tokens}); - // Determine destination rank per token (topk=1). - auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); - // Sort tokens by destination rank for contiguous alltoall slices. - auto sorted_indices = at::argsort(rank_for_token); + // Assume contiguous expert placement: rank = expert_id / experts_per_rank. + auto topk_idx_long = topk_idx_flat.to(at::kLong); + auto rank_for_token = at::floor_divide(topk_idx_long, experts_per_rank); + // Sorting by expert id groups tokens by rank and by expert within rank. + auto sorted_indices = at::argsort(topk_idx_long); // Reorder payloads so alltoall can send contiguous chunks per rank. auto send_x = x.index_select(0, sorted_indices); @@ -170,15 +152,6 @@ DispatchResult doMoeDispatch( waitWork(pg->alltoall_base( recv_src_rank, send_src_rank, output_splits, input_splits)); - // Locally reorder by expert id so each rank processes contiguous experts. - auto local_expert = recv_topk_idx - my_rank * experts_per_rank; - auto expert_order = at::argsort(local_expert); - recv_x = recv_x.index_select(0, expert_order); - recv_topk_idx = recv_topk_idx.index_select(0, expert_order); - recv_topk_weights = recv_topk_weights.index_select(0, expert_order); - recv_src_idx = recv_src_idx.index_select(0, expert_order); - recv_src_rank = recv_src_rank.index_select(0, expert_order); - return DispatchResult{ recv_x, recv_topk_idx, diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 9c4f4e6a62a..2467f06995e 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -35,8 +35,8 @@ struct CombineResult { // x: Token embeddings on this rank, shape [T, H]. // topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1]. // topk_weights: Gating weights per token (topk=1), shape [T] or [T, 1]. -// is_token_in_rank: One-hot token-to-rank assignment, shape [T, R], enabling -// non-trivial device meshes or uneven expert-to-rank mappings. +// Experts are assumed to be placed contiguously by rank so +// rank = topk_idx / experts_per_rank. // num_experts: Total experts across all ranks (must be divisible by R). // communicator: Communicator for alltoall exchange. // backend: Communication backend (only NCCL is supported for now). 
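Concretely, with num_experts = 4 and world_size = 2, the contiguous-placement rule maps global expert ids straight to ranks; a worked sketch of the mapping described above:

    // experts_per_rank = 4 / 2 = 2
    // topk_idx         = [0, 2, 3, 2]
    // rank_for_token   = topk_idx / experts_per_rank = [0, 1, 1, 1]
    auto rank_for_token =
        at::floor_divide(topk_idx.to(at::kLong), experts_per_rank);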
@@ -50,13 +50,8 @@ struct CombineResult { // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} // Rank0 holds tokens 0,1 and rank1 holds tokens 2,3 in x: // rank0 x = [x0, x1], rank1 x = [x2, x3] -// token->rank: [0, 1, 1, 1] (rank0 keeps x0, sends x1; rank1 keeps x2,x3) -// is_token_in_rank = -// [[1, 0], -// [0, 1], -// [0, 1], -// [0, 1]] -// topk_idx = [0, 2, 3, 2] (global expert ids) +// token->rank: [0, 1, 1, 1] (via expert ids below) +// topk_idx = [0, 2, 3, 2] (global expert ids, contiguous per rank) // After dispatch on rank0: // recv_x has token {0} // recv_topk_idx aligned with recv_x (e.g., [0]) @@ -71,7 +66,6 @@ struct CombineResult { // x, // topk_idx, // topk_weights, -// is_token_in_rank, // 4, // comm, // CommunicatorBackend::kNccl); @@ -79,7 +73,6 @@ NVF_API DispatchResult doMoeDispatch( const at::Tensor& x, // [T, H] const at::Tensor& topk_idx, // [T] or [T, 1] const at::Tensor& topk_weights, // [T] or [T, 1] - const at::Tensor& is_token_in_rank, // [T, R] int64_t num_experts, Communicator* communicator, CommunicatorBackend backend); diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index b3884b9499f..198790118eb 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -41,7 +41,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto* in_x = makeSymbolicTensor(2); auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); auto* in_topk_weights = makeSymbolicTensor(1); - auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool); auto* recv_x = makeSymbolicTensor(2); auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); @@ -62,7 +61,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { in_x, in_topk_idx, in_topk_weights, - in_is_token_in_rank, num_experts, CommunicatorBackend::kNccl); @@ -85,7 +83,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { hic->addInput(in_x); hic->addInput(in_topk_idx); hic->addInput(in_topk_weights); - hic->addInput(in_is_token_in_rank); hic->addOutput(combined_x); hic->addOutput(combined_topk_weights); @@ -105,10 +102,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { // Asymmetric example: // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. - auto rank_ids = at::arange(world_size, int_options); - auto token_rank = at::tensor({0, 1, 1, 1}, int_options); - auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); - // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. 
topk_idx.index_put_({0}, 0); topk_idx.index_put_({1}, kNumExpertsPerRank); @@ -118,8 +111,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto outputs = hie.runWithInput( {{in_x, x}, {in_topk_idx, topk_idx}, - {in_topk_weights, topk_weights}, - {in_is_token_in_rank, is_token_in_rank}}); + {in_topk_weights, topk_weights}}); auto combined = outputs[0].as(); auto combined_weights = outputs[1].as(); @@ -147,7 +139,6 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { auto* in_x = makeSymbolicTensor(2); auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); auto* in_topk_weights = makeSymbolicTensor(1); - auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool); auto* recv_x = makeSymbolicTensor(2); auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); @@ -168,7 +159,6 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { in_x, in_topk_idx, in_topk_weights, - in_is_token_in_rank, num_experts, CommunicatorBackend::kNccl); @@ -177,7 +167,6 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { hic->addInput(in_x); hic->addInput(in_topk_idx); hic->addInput(in_topk_weights); - hic->addInput(in_is_token_in_rank); hic->addOutput(recv_x); hic->addOutput(recv_topk_idx); hic->addOutput(recv_topk_weights); @@ -202,10 +191,6 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { // Asymmetric example: // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. - auto rank_ids = at::arange(world_size, int_options); - auto token_rank = at::tensor({0, 1, 1, 1}, int_options); - auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); - // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. topk_idx.index_put_({0}, 0); topk_idx.index_put_({1}, kNumExpertsPerRank); @@ -215,14 +200,12 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { auto outputs = hie.runWithInput( {{in_x, x}, {in_topk_idx, topk_idx}, - {in_topk_weights, topk_weights}, - {in_is_token_in_rank, is_token_in_rank}}); + {in_topk_weights, topk_weights}}); auto expected = doMoeDispatch( x, topk_idx, topk_weights, - is_token_in_rank, num_experts, communicator_, CommunicatorBackend::kNccl); @@ -275,10 +258,6 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { // Asymmetric example: // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. - auto rank_ids = at::arange(world_size, int_options); - auto token_rank = at::tensor({0, 1, 1, 1}, int_options); - auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); - // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. topk_idx.index_put_({0}, 0); topk_idx.index_put_({1}, kNumExpertsPerRank); @@ -289,7 +268,6 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { x, topk_idx, topk_weights, - is_token_in_rank, num_experts, communicator_, CommunicatorBackend::kNccl); From 47d710fae4ee6cc2e782dabf7fa16d816afa8274 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 9 Feb 2026 03:04:28 -0800 Subject: [PATCH 11/18] simplify by enforcing 2D shapes --- csrc/multidevice/communication.h | 5 +- csrc/multidevice/dispatch_combine.cpp | 44 +++++-------- csrc/multidevice/dispatch_combine.h | 8 +-- .../cpp/test_multidevice_dispatch_combine.cpp | 64 ++++++++++--------- 4 files changed, 59 insertions(+), 62 deletions(-) diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 9303e30e7d1..1a3ddd6f1ec 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -178,8 +178,9 @@ class P2PCommunication : public Expr { // the local rank to destination ranks based on explicit routing. 
// // Example shapes (topk=1): -// in_x: [T, H], in_topk_idx: [T] or [T, 1], -// in_topk_weights: [T] or [T, 1], num_experts = R * experts_per_rank. +// in_x: [T, H], in_topk_idx: [T, 1], +// in_topk_weights: [T, 1], num_experts = R * experts_per_rank. +// For topk>1, use [T, K] for both topk inputs. // Experts are assumed to be placed contiguously by rank. // out_src_idx/out_src_rank are returned for the combine step to restore the // original token order. diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 3776f003f76..6fc5e8f533e 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -67,23 +67,15 @@ DispatchResult doMoeDispatch( NVF_CHECK_EQ(num_experts % world_size, 0, "num_experts must be divisible."); const int64_t experts_per_rank = num_experts / world_size; - const bool topk_is_1d = topk_idx.dim() == 1 && topk_idx.size(0) == num_tokens; - const bool topk_is_2d = topk_idx.dim() == 2 && - topk_idx.size(0) == num_tokens && topk_idx.size(1) == 1; NVF_CHECK( - topk_is_1d || topk_is_2d, - "Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ", + topk_idx.dim() == 2 && topk_idx.size(0) == num_tokens && topk_idx.size(1) == 1, + "Only topk=1 supported. topk_idx must be shape [T, 1], got: ", topk_idx.sizes()); auto topk_idx_flat = topk_idx.reshape({num_tokens}); - const bool weights_is_1d = - topk_weights.dim() == 1 && topk_weights.size(0) == num_tokens; - const bool weights_is_2d = topk_weights.dim() == 2 && - topk_weights.size(0) == num_tokens && topk_weights.size(1) == 1; NVF_CHECK( - weights_is_1d || weights_is_2d, - "Only topk=1 supported. topk_weights must be shape [T] or [T, 1], got: ", + topk_weights.dim() == 2 && topk_weights.size(0) == num_tokens && topk_weights.size(1) == 1, + "Only topk=1 supported. topk_weights must be shape [T, 1], got: ", topk_weights.sizes()); - auto topk_weights_flat = topk_weights.reshape({num_tokens}); // Assume contiguous expert placement: rank = expert_id / experts_per_rank. auto topk_idx_long = topk_idx_flat.to(at::kLong); @@ -93,8 +85,8 @@ DispatchResult doMoeDispatch( // Reorder payloads so alltoall can send contiguous chunks per rank. auto send_x = x.index_select(0, sorted_indices); - auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices); - auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); + auto send_topk_idx = topk_idx.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights.index_select(0, sorted_indices); // Track original token indices and source rank for the combine step. auto send_src_idx = sorted_indices.to(at::kLong); // All entries are identical, so no relayout is needed. @@ -136,8 +128,10 @@ DispatchResult doMoeDispatch( // Allocate receive buffers for payloads and metadata. // TODO: support preallocated buffers. 
auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); + auto recv_topk_idx = at::empty( + {total_recv, topk_idx.size(1)}, topk_idx.options()); + auto recv_topk_weights = at::empty( + {total_recv, topk_weights.size(1)}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); @@ -186,15 +180,10 @@ CombineResult doMoeCombine( NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden]."); NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D."); NVF_CHECK_EQ(src_rank.dim(), 1, "src_rank must be 1D."); - const bool weights_is_1d = - topk_weights.dim() == 1 && topk_weights.size(0) == x.size(0); - const bool weights_is_2d = topk_weights.dim() == 2 && - topk_weights.size(0) == x.size(0) && topk_weights.size(1) == 1; NVF_CHECK( - weights_is_1d || weights_is_2d, - "topk_weights must be shape [T] or [T, 1], got: ", + topk_weights.dim() == 2 && topk_weights.size(0) == x.size(0) && topk_weights.size(1) == 1, + "topk_weights must be shape [T, 1], got: ", topk_weights.sizes()); - auto topk_weights_flat = topk_weights.reshape({x.size(0)}); NVF_CHECK_EQ( src_idx.size(0), x.size(0), "src_idx size must match x first dimension."); NVF_CHECK_EQ( @@ -213,7 +202,7 @@ CombineResult doMoeCombine( // Sort by source rank so alltoall can send contiguous chunks per rank. auto sorted_indices = at::argsort(src_rank); auto send_x = x.index_select(0, sorted_indices); - auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights.index_select(0, sorted_indices); auto send_src_idx = src_idx.index_select(0, sorted_indices); // Split sizes come from dispatch counts. @@ -235,7 +224,8 @@ CombineResult doMoeCombine( // Allocate receive buffers and exchange payloads back to source ranks. auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); + auto recv_topk_weights = at::empty( + {total_recv, topk_weights.size(1)}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, src_idx.options()); waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); @@ -247,8 +237,8 @@ CombineResult doMoeCombine( // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); combined_x.index_copy_(0, recv_src_idx, recv_x); - auto combined_topk_weights = - at::empty({total_recv}, topk_weights_flat.options()); + auto combined_topk_weights = at::empty( + {total_recv, topk_weights.size(1)}, topk_weights.options()); combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); return CombineResult{combined_x, combined_topk_weights}; diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 2467f06995e..6bd24a76094 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -33,8 +33,8 @@ struct CombineResult { // // Args: // x: Token embeddings on this rank, shape [T, H]. -// topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1]. -// topk_weights: Gating weights per token (topk=1), shape [T] or [T, 1]. +// topk_idx: Global expert ids per token, shape [T, K] (K=1 supported). +// topk_weights: Gating weights per token, shape [T, K] (K=1 supported). 
// Experts are assumed to be placed contiguously by rank so // rank = topk_idx / experts_per_rank. // num_experts: Total experts across all ranks (must be divisible by R). @@ -51,7 +51,7 @@ struct CombineResult { // Rank0 holds tokens 0,1 and rank1 holds tokens 2,3 in x: // rank0 x = [x0, x1], rank1 x = [x2, x3] // token->rank: [0, 1, 1, 1] (via expert ids below) -// topk_idx = [0, 2, 3, 2] (global expert ids, contiguous per rank) +// topk_idx = [[0], [2], [3], [2]] (global expert ids, contiguous per rank) // After dispatch on rank0: // recv_x has token {0} // recv_topk_idx aligned with recv_x (e.g., [0]) @@ -81,7 +81,7 @@ NVF_API DispatchResult doMoeDispatch( // // Args: // x: Token embeddings after expert compute, shape [T_recv, H]. -// topk_weights: Gating weights aligned with x, shape [T_recv] or [T_recv, 1]. +// topk_weights: Gating weights aligned with x, shape [T_recv, K] (K=1). // src_idx: Original token indices for each row of x, shape [T_recv]. // src_rank: Original source rank per token, shape [T_recv]. // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 198790118eb..59cdb692868 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -39,12 +39,12 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { FusionGuard fg(hic.get()); auto* in_x = makeSymbolicTensor(2); - auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); - auto* in_topk_weights = makeSymbolicTensor(1); + auto* in_topk_idx = makeSymbolicTensor(2, DataType::Int); + auto* in_topk_weights = makeSymbolicTensor(2); auto* recv_x = makeSymbolicTensor(2); - auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); - auto* recv_topk_weights = makeSymbolicTensor(1); + auto* recv_topk_idx = makeSymbolicTensor(2, DataType::Int); + auto* recv_topk_weights = makeSymbolicTensor(2); auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); @@ -65,7 +65,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { CommunicatorBackend::kNccl); auto* combined_x = makeSymbolicTensor(2); - auto* combined_topk_weights = makeSymbolicTensor(1); + auto* combined_topk_weights = makeSymbolicTensor(2); auto* combine = IrBuilder::create( combined_x, combined_topk_weights, @@ -96,17 +96,19 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto x = at::arange(kNumTokens * kHidden, float_options) .reshape({kNumTokens, kHidden}) + static_cast(my_rank) * 1000.0; - auto topk_idx = at::zeros({kNumTokens}, int_options); + auto topk_idx = at::zeros({kNumTokens, 1}, int_options); auto topk_weights = - at::arange(kNumTokens, float_options) + static_cast(my_rank); + at::arange(kNumTokens, float_options) + .reshape({kNumTokens, 1}) + + static_cast(my_rank); // Asymmetric example: // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. 
- topk_idx.index_put_({0}, 0); - topk_idx.index_put_({1}, kNumExpertsPerRank); - topk_idx.index_put_({2}, kNumExpertsPerRank + 1); - topk_idx.index_put_({3}, kNumExpertsPerRank); + topk_idx.index_put_({0, 0}, 0); + topk_idx.index_put_({1, 0}, kNumExpertsPerRank); + topk_idx.index_put_({2, 0}, kNumExpertsPerRank + 1); + topk_idx.index_put_({3, 0}, kNumExpertsPerRank); auto outputs = hie.runWithInput( {{in_x, x}, @@ -137,12 +139,12 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { FusionGuard fg(hic.get()); auto* in_x = makeSymbolicTensor(2); - auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); - auto* in_topk_weights = makeSymbolicTensor(1); + auto* in_topk_idx = makeSymbolicTensor(2, DataType::Int); + auto* in_topk_weights = makeSymbolicTensor(2); auto* recv_x = makeSymbolicTensor(2); - auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); - auto* recv_topk_weights = makeSymbolicTensor(1); + auto* recv_topk_idx = makeSymbolicTensor(2, DataType::Int); + auto* recv_topk_weights = makeSymbolicTensor(2); auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); @@ -185,17 +187,19 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { auto x = at::arange(kNumTokens * kHidden, float_options) .reshape({kNumTokens, kHidden}) + static_cast(my_rank) * 1000.0; - auto topk_idx = at::zeros({kNumTokens}, int_options); + auto topk_idx = at::zeros({kNumTokens, 1}, int_options); auto topk_weights = - at::arange(kNumTokens, float_options) + static_cast(my_rank); + at::arange(kNumTokens, float_options) + .reshape({kNumTokens, 1}) + + static_cast(my_rank); // Asymmetric example: // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. - topk_idx.index_put_({0}, 0); - topk_idx.index_put_({1}, kNumExpertsPerRank); - topk_idx.index_put_({2}, kNumExpertsPerRank + 1); - topk_idx.index_put_({3}, kNumExpertsPerRank); + topk_idx.index_put_({0, 0}, 0); + topk_idx.index_put_({1, 0}, kNumExpertsPerRank); + topk_idx.index_put_({2, 0}, kNumExpertsPerRank + 1); + topk_idx.index_put_({3, 0}, kNumExpertsPerRank); auto outputs = hie.runWithInput( {{in_x, x}, @@ -252,17 +256,19 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { auto x = at::arange(kNumTokens * kHidden, float_options) .reshape({kNumTokens, kHidden}) + static_cast(my_rank) * 1000.0; - auto topk_idx = at::zeros({kNumTokens}, int_options); + auto topk_idx = at::zeros({kNumTokens, 1}, int_options); auto topk_weights = - at::arange(kNumTokens, float_options) + static_cast(my_rank); + at::arange(kNumTokens, float_options) + .reshape({kNumTokens, 1}) + + static_cast(my_rank); // Asymmetric example: // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. 
- topk_idx.index_put_({0}, 0); - topk_idx.index_put_({1}, kNumExpertsPerRank); - topk_idx.index_put_({2}, kNumExpertsPerRank + 1); - topk_idx.index_put_({3}, kNumExpertsPerRank); + topk_idx.index_put_({0, 0}, 0); + topk_idx.index_put_({1, 0}, kNumExpertsPerRank); + topk_idx.index_put_({2, 0}, kNumExpertsPerRank + 1); + topk_idx.index_put_({3, 0}, kNumExpertsPerRank); auto dispatch_result = doMoeDispatch( x, @@ -276,14 +282,14 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { FusionGuard fg(hic.get()); auto* in_x = makeSymbolicTensor(2); - auto* in_topk_weights = makeSymbolicTensor(1); + auto* in_topk_weights = makeSymbolicTensor(2); auto* in_src_idx = makeSymbolicTensor(1, DataType::Int); auto* in_src_rank = makeSymbolicTensor(1, DataType::Int); auto* in_n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); auto* in_n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); auto* combined_x = makeSymbolicTensor(2); - auto* combined_topk_weights = makeSymbolicTensor(1); + auto* combined_topk_weights = makeSymbolicTensor(2); auto* combine = IrBuilder::create( combined_x, combined_topk_weights, From f39daf23dff85b56b1f19022d6a4e949039af9ba Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 9 Feb 2026 03:17:56 -0800 Subject: [PATCH 12/18] lint --- csrc/multidevice/dispatch_combine.cpp | 25 ++++++++++-------- .../cpp/test_multidevice_dispatch_combine.cpp | 26 ++++++------------- 2 files changed, 22 insertions(+), 29 deletions(-) diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 6fc5e8f533e..69b7075a387 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -68,12 +68,14 @@ DispatchResult doMoeDispatch( const int64_t experts_per_rank = num_experts / world_size; NVF_CHECK( - topk_idx.dim() == 2 && topk_idx.size(0) == num_tokens && topk_idx.size(1) == 1, + topk_idx.dim() == 2 && topk_idx.size(0) == num_tokens && + topk_idx.size(1) == 1, "Only topk=1 supported. topk_idx must be shape [T, 1], got: ", topk_idx.sizes()); auto topk_idx_flat = topk_idx.reshape({num_tokens}); NVF_CHECK( - topk_weights.dim() == 2 && topk_weights.size(0) == num_tokens && topk_weights.size(1) == 1, + topk_weights.dim() == 2 && topk_weights.size(0) == num_tokens && + topk_weights.size(1) == 1, "Only topk=1 supported. topk_weights must be shape [T, 1], got: ", topk_weights.sizes()); @@ -128,10 +130,10 @@ DispatchResult doMoeDispatch( // Allocate receive buffers for payloads and metadata. // TODO: support preallocated buffers. 
auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_idx = at::empty( - {total_recv, topk_idx.size(1)}, topk_idx.options()); - auto recv_topk_weights = at::empty( - {total_recv, topk_weights.size(1)}, topk_weights.options()); + auto recv_topk_idx = + at::empty({total_recv, topk_idx.size(1)}, topk_idx.options()); + auto recv_topk_weights = + at::empty({total_recv, topk_weights.size(1)}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); @@ -181,7 +183,8 @@ CombineResult doMoeCombine( NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D."); NVF_CHECK_EQ(src_rank.dim(), 1, "src_rank must be 1D."); NVF_CHECK( - topk_weights.dim() == 2 && topk_weights.size(0) == x.size(0) && topk_weights.size(1) == 1, + topk_weights.dim() == 2 && topk_weights.size(0) == x.size(0) && + topk_weights.size(1) == 1, "topk_weights must be shape [T, 1], got: ", topk_weights.sizes()); NVF_CHECK_EQ( @@ -224,8 +227,8 @@ CombineResult doMoeCombine( // Allocate receive buffers and exchange payloads back to source ranks. auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_weights = at::empty( - {total_recv, topk_weights.size(1)}, topk_weights.options()); + auto recv_topk_weights = + at::empty({total_recv, topk_weights.size(1)}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, src_idx.options()); waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); @@ -237,8 +240,8 @@ CombineResult doMoeCombine( // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); combined_x.index_copy_(0, recv_src_idx, recv_x); - auto combined_topk_weights = at::empty( - {total_recv, topk_weights.size(1)}, topk_weights.options()); + auto combined_topk_weights = + at::empty({total_recv, topk_weights.size(1)}, topk_weights.options()); combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); return CombineResult{combined_x, combined_topk_weights}; diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 59cdb692868..d8f7c5a58eb 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -98,8 +98,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { static_cast(my_rank) * 1000.0; auto topk_idx = at::zeros({kNumTokens, 1}, int_options); auto topk_weights = - at::arange(kNumTokens, float_options) - .reshape({kNumTokens, 1}) + + at::arange(kNumTokens, float_options).reshape({kNumTokens, 1}) + static_cast(my_rank); // Asymmetric example: @@ -111,9 +110,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { topk_idx.index_put_({3, 0}, kNumExpertsPerRank); auto outputs = hie.runWithInput( - {{in_x, x}, - {in_topk_idx, topk_idx}, - {in_topk_weights, topk_weights}}); + {{in_x, x}, {in_topk_idx, topk_idx}, {in_topk_weights, topk_weights}}); auto combined = outputs[0].as(); auto combined_weights = outputs[1].as(); @@ -189,8 +186,7 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { static_cast(my_rank) * 1000.0; auto topk_idx = at::zeros({kNumTokens, 1}, int_options); auto topk_weights = - at::arange(kNumTokens, float_options) - .reshape({kNumTokens, 1}) + + at::arange(kNumTokens, float_options).reshape({kNumTokens, 1}) + static_cast(my_rank); // Asymmetric example: @@ -202,9 +198,7 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { topk_idx.index_put_({3, 0}, 
kNumExpertsPerRank); auto outputs = hie.runWithInput( - {{in_x, x}, - {in_topk_idx, topk_idx}, - {in_topk_weights, topk_weights}}); + {{in_x, x}, {in_topk_idx, topk_idx}, {in_topk_weights, topk_weights}}); auto expected = doMoeDispatch( x, @@ -216,17 +210,14 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { EXPECT_TRUE(at::allclose(outputs[0].as(), expected.recv_x)) << "Dispatch recv_x mismatch on rank " << my_rank; - EXPECT_TRUE( - at::allclose(outputs[1].as(), expected.recv_topk_idx)) + EXPECT_TRUE(at::allclose(outputs[1].as(), expected.recv_topk_idx)) << "Dispatch recv_topk_idx mismatch on rank " << my_rank; EXPECT_TRUE( at::allclose(outputs[2].as(), expected.recv_topk_weights)) << "Dispatch recv_topk_weights mismatch on rank " << my_rank; - EXPECT_TRUE( - at::allclose(outputs[3].as(), expected.recv_src_idx)) + EXPECT_TRUE(at::allclose(outputs[3].as(), expected.recv_src_idx)) << "Dispatch recv_src_idx mismatch on rank " << my_rank; - EXPECT_TRUE( - at::allclose(outputs[4].as(), expected.recv_src_rank)) + EXPECT_TRUE(at::allclose(outputs[4].as(), expected.recv_src_rank)) << "Dispatch recv_src_rank mismatch on rank " << my_rank; EXPECT_TRUE( at::allclose(outputs[5].as(), expected.n_tokens_to_rank)) @@ -258,8 +249,7 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { static_cast(my_rank) * 1000.0; auto topk_idx = at::zeros({kNumTokens, 1}, int_options); auto topk_weights = - at::arange(kNumTokens, float_options) - .reshape({kNumTokens, 1}) + + at::arange(kNumTokens, float_options).reshape({kNumTokens, 1}) + static_cast(my_rank); // Asymmetric example: From da5222041125a1615302c42eeccae2a38ddc945c Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 9 Feb 2026 03:26:43 -0800 Subject: [PATCH 13/18] remove combined_topk_weights --- csrc/host_ir/evaluator.cpp | 1 - csrc/multidevice/communication.cpp | 6 ---- csrc/multidevice/communication.h | 10 ++---- csrc/multidevice/dispatch_combine.cpp | 34 +++++++++---------- csrc/multidevice/dispatch_combine.h | 4 +-- .../cpp/test_multidevice_dispatch_combine.cpp | 13 ------- 6 files changed, 20 insertions(+), 48 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 08a49e11ede..4cf418d3c1d 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -441,7 +441,6 @@ void HostIrEvaluator::handle(MoeCombine* combine) { combine->backend()); expr_evaluator_.bind(combine->outX(), result.combined_x); - expr_evaluator_.bind(combine->outTopkWeights(), result.combined_topk_weights); } void HostIrEvaluator::handle(Wait* wait) { diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 02d2fdad289..16412a2ce7d 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -410,7 +410,6 @@ void MoeDispatch::validate() { MoeCombine::MoeCombine( IrBuilderPasskey passkey, TensorView* out_x, - TensorView* out_topk_weights, TensorView* in_x, TensorView* in_topk_weights, TensorView* in_src_idx, @@ -426,7 +425,6 @@ MoeCombine::MoeCombine( addInput(in_n_tokens_to_rank); addInput(in_n_tokens_from_rank); addOutput(out_x); - addOutput(out_topk_weights); addDataAttribute(backend); validate(); } @@ -471,10 +469,6 @@ void MoeCombine::validate() { inTokensFromRank()->getDataType().has_value() && isIntegralType(*inTokensFromRank()->getDataType()), "in_n_tokens_from_rank must be integral."); - NVF_CHECK( - outTopkWeights()->getDataType().has_value() && - isFloatingPointType(*outTopkWeights()->getDataType()), - "out_topk_weights must be floating point."); } namespace { diff 
--git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 1a3ddd6f1ec..6fb46b14f41 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -275,10 +275,9 @@ class MoeDispatch : public Expr { // their source ranks using `in_src_rank` and `in_src_idx`. // // Example shapes (topk=1): -// in_x: [T_recv, H], in_topk_weights: [T_recv], in_src_idx: [T_recv], +// in_x: [T_recv, H], in_topk_weights: [T_recv, 1], in_src_idx: [T_recv], // in_src_rank: [T_recv], in_n_tokens_to_rank: [R], in_n_tokens_from_rank: -// [R]. Outputs are source-aligned: out_x/out_topk_weights with shape [T_src, -// ...]. +// [R]. Output out_x is source-aligned with shape [T_src, ...]. class MoeCombine : public Expr { public: using Expr::Expr; @@ -286,7 +285,6 @@ class MoeCombine : public Expr { MoeCombine( IrBuilderPasskey passkey, TensorView* out_x, - TensorView* out_topk_weights, TensorView* in_x, TensorView* in_topk_weights, TensorView* in_src_idx, @@ -312,10 +310,6 @@ class MoeCombine : public Expr { return output(0)->as(); } - TensorView* outTopkWeights() const { - return output(1)->as(); - } - TensorView* inX() const { return input(0)->as(); } diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 69b7075a387..7ab967f228b 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -169,10 +169,18 @@ CombineResult doMoeCombine( CommunicatorBackend backend) { NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator."); NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA."); - NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); - NVF_CHECK( - topk_weights.is_floating_point(), - "Combine topk_weights must be floating point."); + const bool has_topk_weights = topk_weights.numel() > 0; + if (has_topk_weights) { + NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); + NVF_CHECK( + topk_weights.is_floating_point(), + "Combine topk_weights must be floating point."); + NVF_CHECK( + topk_weights.dim() == 2 && topk_weights.size(0) == x.size(0) && + topk_weights.size(1) == 1, + "topk_weights must be shape [T, 1], got: ", + topk_weights.sizes()); + } NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); NVF_CHECK( @@ -182,11 +190,6 @@ CombineResult doMoeCombine( NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden]."); NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D."); NVF_CHECK_EQ(src_rank.dim(), 1, "src_rank must be 1D."); - NVF_CHECK( - topk_weights.dim() == 2 && topk_weights.size(0) == x.size(0) && - topk_weights.size(1) == 1, - "topk_weights must be shape [T, 1], got: ", - topk_weights.sizes()); NVF_CHECK_EQ( src_idx.size(0), x.size(0), "src_idx size must match x first dimension."); NVF_CHECK_EQ( @@ -205,7 +208,6 @@ CombineResult doMoeCombine( // Sort by source rank so alltoall can send contiguous chunks per rank. auto sorted_indices = at::argsort(src_rank); auto send_x = x.index_select(0, sorted_indices); - auto send_topk_weights = topk_weights.index_select(0, sorted_indices); auto send_src_idx = src_idx.index_select(0, sorted_indices); // Split sizes come from dispatch counts. @@ -227,24 +229,20 @@ CombineResult doMoeCombine( // Allocate receive buffers and exchange payloads back to source ranks. 
auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_weights = - at::empty({total_recv, topk_weights.size(1)}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, src_idx.options()); waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); waitWork(pg->alltoall_base( recv_src_idx, send_src_idx, output_splits, input_splits)); // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); combined_x.index_copy_(0, recv_src_idx, recv_x); - auto combined_topk_weights = - at::empty({total_recv, topk_weights.size(1)}, topk_weights.options()); - combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); - return CombineResult{combined_x, combined_topk_weights}; + // topk_weights is reserved for future weighted combine. + (void)topk_weights; + + return CombineResult{combined_x}; } } // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 6bd24a76094..e8614d50d4e 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -26,7 +26,6 @@ struct DispatchResult { struct CombineResult { at::Tensor combined_x; // Combined tokens back in original order. - at::Tensor combined_topk_weights; // Combined gating weights per token. }; // Dispatch MoE tokens to the owning ranks. Only k=1 is supported for now. @@ -81,7 +80,8 @@ NVF_API DispatchResult doMoeDispatch( // // Args: // x: Token embeddings after expert compute, shape [T_recv, H]. -// topk_weights: Gating weights aligned with x, shape [T_recv, K] (K=1). +// topk_weights: Optional gating weights aligned with x, shape [T_recv, K] +// (K=1). Pass empty to skip weighting in combine. // src_idx: Original token indices for each row of x, shape [T_recv]. // src_rank: Original source rank per token, shape [T_recv]. // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. 
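Since topk_weights is now optional, a caller that applies gating itself before combine can pass an empty tensor. A hedged usage sketch against the signature above; expert_out, dispatch_result, and comm are assumed to come from a prior doMoeDispatch call, as in the tests:

    // Sketch: combine without gating weights (numel() == 0 skips the
    // weight checks; no weights are exchanged after this patch anyway).
    auto empty_weights = at::empty({0, 1}, expert_out.options());
    CombineResult combined = doMoeCombine(
        expert_out,                          // [T_recv, H], after expert compute
        empty_weights,                       // optional, empty => unweighted
        dispatch_result.recv_src_idx,        // [T_recv]
        dispatch_result.recv_src_rank,       // [T_recv]
        dispatch_result.n_tokens_to_rank,    // [R]
        dispatch_result.n_tokens_from_rank,  // [R]
        comm,
        CommunicatorBackend::kNccl);
    // combined.combined_x is back in this rank's original token order.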
diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index d8f7c5a58eb..02b36e314f3 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -65,10 +65,8 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { CommunicatorBackend::kNccl); auto* combined_x = makeSymbolicTensor(2); - auto* combined_topk_weights = makeSymbolicTensor(2); auto* combine = IrBuilder::create( combined_x, - combined_topk_weights, recv_x, recv_topk_weights, recv_src_idx, @@ -84,7 +82,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { hic->addInput(in_topk_idx); hic->addInput(in_topk_weights); hic->addOutput(combined_x); - hic->addOutput(combined_topk_weights); HostIrEvaluator hie(std::move(hic), communicator_); @@ -112,12 +109,8 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto outputs = hie.runWithInput( {{in_x, x}, {in_topk_idx, topk_idx}, {in_topk_weights, topk_weights}}); auto combined = outputs[0].as(); - auto combined_weights = outputs[1].as(); - EXPECT_TRUE(at::allclose(combined, x)) << "Dispatch/Combine mismatch on rank " << my_rank; - EXPECT_TRUE(at::allclose(combined_weights, topk_weights)) - << "Dispatch/Combine topk_weights mismatch on rank " << my_rank; } TEST_F(DispatchCombineTest, DispatchOnlyTop1) { @@ -279,10 +272,8 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { auto* in_n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); auto* combined_x = makeSymbolicTensor(2); - auto* combined_topk_weights = makeSymbolicTensor(2); auto* combine = IrBuilder::create( combined_x, - combined_topk_weights, in_x, in_topk_weights, in_src_idx, @@ -300,7 +291,6 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { hic->addInput(in_n_tokens_to_rank); hic->addInput(in_n_tokens_from_rank); hic->addOutput(combined_x); - hic->addOutput(combined_topk_weights); HostIrEvaluator hie(std::move(hic), communicator_); @@ -313,12 +303,9 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { {in_n_tokens_from_rank, dispatch_result.n_tokens_from_rank}}); auto combined = outputs[0].as(); - auto combined_weights = outputs[1].as(); EXPECT_TRUE(at::allclose(combined, x)) << "Combine mismatch on rank " << my_rank; - EXPECT_TRUE(at::allclose(combined_weights, topk_weights)) - << "Combine topk_weights mismatch on rank " << my_rank; } } // namespace hir From c089049e193e499da5e9d1c72452b8f968697952 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 9 Feb 2026 03:31:50 -0800 Subject: [PATCH 14/18] minor simplification --- csrc/multidevice/dispatch_combine.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 7ab967f228b..b55caae635e 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -109,12 +109,11 @@ DispatchResult doMoeDispatch( backend, CommunicatorBackend::kNccl, "Only NCCL backend is supported for MoeDispatch."); - CommunicatorBackend actual_backend = backend; NVF_CHECK( - communicator->isBackendAvailable(actual_backend), + communicator->isBackendAvailable(backend), "Backend not available for dispatch: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); + backend); + auto* pg = communicator->getWorld(backend); NVF_CHECK(pg != nullptr, "Dispatch backend is null."); // Exchange per-rank token counts to build split sizes for alltoall. 
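The count exchange named in the comment above determines all split sizes. A hypothetical plain-C++ sketch (no communication library; the counts mirror the tests' [0, 1, 1, 1] routing) of how each rank derives its output splits and total receive size:

    #include <cstdio>

    int main() {
      constexpr int kRanks = 2;
      // send[i][j]: tokens rank i sends to rank j. With both ranks routing
      // tokens as [0, 1, 1, 1], each rank sends 1 to rank 0 and 3 to rank 1.
      const int send[kRanks][kRanks] = {{1, 3}, {1, 3}};

      for (int me = 0; me < kRanks; ++me) {
        int total_recv = 0;
        std::printf("rank %d output_splits:", me);
        for (int peer = 0; peer < kRanks; ++peer) {
          // What the counts-alltoall delivers to `me`: each peer's send
          // count toward `me`, i.e. the transposed column of `send`.
          const int from_peer = send[peer][me];
          total_recv += from_peer;
          std::printf(" %d", from_peer);
        }
        std::printf("  (total_recv=%d)\n", total_recv);
      }
      return 0;
    }

rank0 ends up with output splits [1, 1] (2 tokens total) and rank1 with [3, 3] (6 tokens total), matching the asymmetric example in the tests.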
@@ -219,12 +218,11 @@ CombineResult doMoeCombine( NVF_CHECK( backend == CommunicatorBackend::kNccl, "Only NCCL backend is supported for MoeCombine."); - CommunicatorBackend actual_backend = backend; NVF_CHECK( - communicator->isBackendAvailable(actual_backend), + communicator->isBackendAvailable(backend), "Backend not available for combine: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); + backend); + auto* pg = communicator->getWorld(backend); NVF_CHECK(pg != nullptr, "Combine backend is null."); // Allocate receive buffers and exchange payloads back to source ranks. From 490200fe3712749f3c964c67a42fb980a2c84685 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 9 Feb 2026 05:36:54 -0800 Subject: [PATCH 15/18] remove (in|out|send)_src_rank --- csrc/host_ir/evaluator.cpp | 3 -- csrc/multidevice/communication.cpp | 13 --------- csrc/multidevice/communication.h | 26 +++++------------ csrc/multidevice/dispatch_combine.cpp | 29 +++++++------------ csrc/multidevice/dispatch_combine.h | 8 ++--- .../cpp/test_multidevice_dispatch_combine.cpp | 16 ++-------- 6 files changed, 23 insertions(+), 72 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 4cf418d3c1d..e77b8908a82 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -409,7 +409,6 @@ void HostIrEvaluator::handle(MoeDispatch* dispatch) { expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx); expr_evaluator_.bind(dispatch->outTopkWeights(), result.recv_topk_weights); expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx); - expr_evaluator_.bind(dispatch->outSrcRank(), result.recv_src_rank); expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank); expr_evaluator_.bind( dispatch->outTokensFromRank(), result.n_tokens_from_rank); @@ -424,7 +423,6 @@ void HostIrEvaluator::handle(MoeCombine* combine) { auto topk_weights = getKnownConcreteValue(combine->inTopkWeights()).as(); auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as(); - auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as(); auto n_tokens_to_rank = getKnownConcreteValue(combine->inTokensToRank()).as(); auto n_tokens_from_rank = @@ -434,7 +432,6 @@ void HostIrEvaluator::handle(MoeCombine* combine) { x, topk_weights, src_idx, - src_rank, n_tokens_to_rank, n_tokens_from_rank, communicator_, diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 16412a2ce7d..9557374bf6c 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -327,7 +327,6 @@ MoeDispatch::MoeDispatch( TensorView* out_topk_idx, TensorView* out_topk_weights, TensorView* out_src_idx, - TensorView* out_src_rank, TensorView* out_n_tokens_to_rank, TensorView* out_n_tokens_from_rank, TensorView* in_x, @@ -343,7 +342,6 @@ MoeDispatch::MoeDispatch( addOutput(out_topk_idx); addOutput(out_topk_weights); addOutput(out_src_idx); - addOutput(out_src_rank); addOutput(out_n_tokens_to_rank); addOutput(out_n_tokens_from_rank); addDataAttribute(num_experts); @@ -393,10 +391,6 @@ void MoeDispatch::validate() { outSrcIdx()->getDataType().has_value() && isIntegralType(*outSrcIdx()->getDataType()), "out_src_idx must be integral."); - NVF_CHECK( - outSrcRank()->getDataType().has_value() && - isIntegralType(*outSrcRank()->getDataType()), - "out_src_rank must be integral."); NVF_CHECK( outTokensToRank()->getDataType().has_value() && isIntegralType(*outTokensToRank()->getDataType()), @@ -413,7 +407,6 @@ MoeCombine::MoeCombine( TensorView* 
in_x, TensorView* in_topk_weights, TensorView* in_src_idx, - TensorView* in_src_rank, TensorView* in_n_tokens_to_rank, TensorView* in_n_tokens_from_rank, CommunicatorBackend backend) @@ -421,7 +414,6 @@ MoeCombine::MoeCombine( addInput(in_x); addInput(in_topk_weights); addInput(in_src_idx); - addInput(in_src_rank); addInput(in_n_tokens_to_rank); addInput(in_n_tokens_from_rank); addOutput(out_x); @@ -438,7 +430,6 @@ std::string MoeCombine::toInlineString(int indent_size) const { << "in=" << inX() << ", " << "topk_weights=" << inTopkWeights() << ", " << "src_idx=" << inSrcIdx() << ", " - << "src_rank=" << inSrcRank() << ", " << "out=" << outX() << ")"; return ss.str(); } @@ -457,10 +448,6 @@ void MoeCombine::validate() { inSrcIdx()->getDataType().has_value() && isIntegralType(*inSrcIdx()->getDataType()), "in_src_idx must be integral."); - NVF_CHECK( - inSrcRank()->getDataType().has_value() && - isIntegralType(*inSrcRank()->getDataType()), - "in_src_rank must be integral."); NVF_CHECK( inTokensToRank()->getDataType().has_value() && isIntegralType(*inTokensToRank()->getDataType()), diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 6fb46b14f41..ce217a01adc 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -182,8 +182,8 @@ class P2PCommunication : public Expr { // in_topk_weights: [T, 1], num_experts = R * experts_per_rank. // For topk>1, use [T, K] for both topk inputs. // Experts are assumed to be placed contiguously by rank. -// out_src_idx/out_src_rank are returned for the combine step to restore the -// original token order. +// out_src_idx is returned for the combine step to restore the original token +// order. // Outputs are recv-aligned tensors: out_x/out_topk_idx/out_topk_weights/ // out_src_* with [T_recv, ...] and // out_n_tokens_to_rank/out_n_tokens_from_rank with shape [R]. @@ -197,7 +197,6 @@ class MoeDispatch : public Expr { TensorView* out_topk_idx, TensorView* out_topk_weights, TensorView* out_src_idx, - TensorView* out_src_rank, TensorView* out_n_tokens_to_rank, TensorView* out_n_tokens_from_rank, TensorView* in_x, @@ -235,16 +234,12 @@ class MoeDispatch : public Expr { return output(3)->as(); } - TensorView* outSrcRank() const { - return output(4)->as(); - } - TensorView* outTokensToRank() const { - return output(5)->as(); + return output(4)->as(); } TensorView* outTokensFromRank() const { - return output(6)->as(); + return output(5)->as(); } TensorView* inX() const { @@ -272,11 +267,11 @@ class MoeDispatch : public Expr { }; // Combine represents intra-node MoE token combine. It shuffles tokens back to -// their source ranks using `in_src_rank` and `in_src_idx`. +// their source ranks using `in_src_idx`. // // Example shapes (topk=1): // in_x: [T_recv, H], in_topk_weights: [T_recv, 1], in_src_idx: [T_recv], -// in_src_rank: [T_recv], in_n_tokens_to_rank: [R], in_n_tokens_from_rank: +// in_n_tokens_to_rank: [R], in_n_tokens_from_rank: // [R]. Output out_x is source-aligned with shape [T_src, ...]. 
class MoeCombine : public Expr { public: @@ -288,7 +283,6 @@ class MoeCombine : public Expr { TensorView* in_x, TensorView* in_topk_weights, TensorView* in_src_idx, - TensorView* in_src_rank, TensorView* in_n_tokens_to_rank, TensorView* in_n_tokens_from_rank, CommunicatorBackend backend = CommunicatorBackend::kNccl); @@ -322,16 +316,12 @@ class MoeCombine : public Expr { return input(2)->as(); } - TensorView* inSrcRank() const { - return input(3)->as(); - } - TensorView* inTokensToRank() const { - return input(4)->as(); + return input(3)->as(); } TensorView* inTokensFromRank() const { - return input(5)->as(); + return input(4)->as(); } CommunicatorBackend backend() const { diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index b55caae635e..ded8d3fd1c7 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -63,7 +63,6 @@ DispatchResult doMoeDispatch( const int64_t num_tokens = x.size(0); const int64_t hidden = x.size(1); const int64_t world_size = communicator->size(); - const int64_t my_rank = communicator->deviceId(); NVF_CHECK_EQ(num_experts % world_size, 0, "num_experts must be divisible."); const int64_t experts_per_rank = num_experts / world_size; @@ -89,13 +88,8 @@ DispatchResult doMoeDispatch( auto send_x = x.index_select(0, sorted_indices); auto send_topk_idx = topk_idx.index_select(0, sorted_indices); auto send_topk_weights = topk_weights.index_select(0, sorted_indices); - // Track original token indices and source rank for the combine step. + // Track original token indices for the combine step. auto send_src_idx = sorted_indices.to(at::kLong); - // All entries are identical, so no relayout is needed. - auto send_src_rank = at::full( - {num_tokens}, - my_rank, - at::TensorOptions().dtype(at::kLong).device(x.device())); // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we // sync/copy here. GPU-initiated comms can avoid this extra sync. @@ -134,7 +128,6 @@ DispatchResult doMoeDispatch( auto recv_topk_weights = at::empty({total_recv, topk_weights.size(1)}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); - auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); // Alltoall exchange payloads with per-rank splits. 
waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); @@ -144,15 +137,12 @@ DispatchResult doMoeDispatch( recv_topk_weights, send_topk_weights, output_splits, input_splits)); waitWork(pg->alltoall_base( recv_src_idx, send_src_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_rank, send_src_rank, output_splits, input_splits)); return DispatchResult{ recv_x, recv_topk_idx, recv_topk_weights, recv_src_idx, - recv_src_rank, n_tokens_to_rank, n_tokens_from_rank}; } @@ -161,7 +151,6 @@ CombineResult doMoeCombine( const at::Tensor& x, const at::Tensor& topk_weights, const at::Tensor& src_idx, - const at::Tensor& src_rank, const at::Tensor& n_tokens_to_rank, const at::Tensor& n_tokens_from_rank, Communicator* communicator, @@ -181,20 +170,14 @@ CombineResult doMoeCombine( topk_weights.sizes()); } NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); - NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); NVF_CHECK( n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); NVF_CHECK( n_tokens_from_rank.is_cuda(), "Combine n_tokens_from_rank must be CUDA."); NVF_CHECK_EQ(x.dim(), 2, "Combine expects x to be 2D [tokens, hidden]."); NVF_CHECK_EQ(src_idx.dim(), 1, "src_idx must be 1D."); - NVF_CHECK_EQ(src_rank.dim(), 1, "src_rank must be 1D."); NVF_CHECK_EQ( src_idx.size(0), x.size(0), "src_idx size must match x first dimension."); - NVF_CHECK_EQ( - src_rank.size(0), - x.size(0), - "src_rank size must match x first dimension."); NVF_CHECK_EQ( n_tokens_to_rank.numel(), communicator->size(), @@ -204,6 +187,16 @@ CombineResult doMoeCombine( communicator->size(), "n_tokens_from_rank must match world size."); + // Reconstruct source ranks from per-rank counts. alltoall_base concatenates + // received chunks in rank order, so this matches the receive layout. + auto src_rank = at::arange( + n_tokens_from_rank.numel(), + at::TensorOptions().dtype(at::kLong).device(x.device())) + .repeat_interleave(n_tokens_from_rank.to(at::kLong)); + NVF_CHECK_EQ( + src_rank.size(0), + x.size(0), + "Reconstructed src_rank must match x first dimension."); // Sort by source rank so alltoall can send contiguous chunks per rank. auto sorted_indices = at::argsort(src_rank); auto send_x = x.index_select(0, sorted_indices); diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index e8614d50d4e..5e08f9cdca1 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -19,7 +19,6 @@ struct DispatchResult { at::Tensor recv_topk_idx; // Expert ids aligned with recv_x. at::Tensor recv_topk_weights; // Gating weights aligned with recv_x. at::Tensor recv_src_idx; // Source token indices for combine. - at::Tensor recv_src_rank; // Source ranks for combine. at::Tensor n_tokens_to_rank; // Tokens sent to each rank (this rank's view). at::Tensor n_tokens_from_rank; // Tokens received from each rank. }; @@ -83,7 +82,6 @@ NVF_API DispatchResult doMoeDispatch( // topk_weights: Optional gating weights aligned with x, shape [T_recv, K] // (K=1). Pass empty to skip weighting in combine. // src_idx: Original token indices for each row of x, shape [T_recv]. -// src_rank: Original source rank per token, shape [T_recv]. // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. // n_tokens_from_rank: Tokens received from each rank (from dispatch), shape // [R]. 
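The repeat_interleave reconstruction above relies on alltoall_base concatenating received chunks in rank order. A hypothetical plain-C++ equivalent (the counts here are made up purely for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
      // n_tokens_from_rank, as produced by dispatch; hypothetical R = 3.
      const std::vector<long> n_tokens_from_rank = {2, 0, 3};

      // arange(R).repeat_interleave(n_tokens_from_rank): rank r labels the
      // next n_tokens_from_rank[r] received rows.
      std::vector<long> src_rank;
      for (long r = 0; r < static_cast<long>(n_tokens_from_rank.size()); ++r) {
        src_rank.insert(src_rank.end(), n_tokens_from_rank[r], r);
      }

      for (long v : src_rank) {
        std::printf("%ld ", v); // prints: 0 0 2 2 2
      }
      std::printf("\n");
      return 0;
    }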
@@ -97,19 +95,17 @@ NVF_API DispatchResult doMoeDispatch( // // Continuing the dispatch example (experts partitioned by rank): // // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} // // After expert compute: -// // rank0 recv_x has token {0} with src_idx = [0], src_rank = [0] +// // rank0 recv_x has token {0} with src_idx = [0] // // rank1 recv_x has tokens {1, 2, 3} with src_idx = [1, 2, 3], -// // src_rank = [0, 1, 1] // // n_tokens_to_rank and n_tokens_from_rank are [R] counts per rank. // // Combine scatters results back to original token order per rank. // auto combined = doMoeCombine( -// x, topk_weights, src_idx, src_rank, n_tokens_to_rank, +// x, topk_weights, src_idx, n_tokens_to_rank, // n_tokens_from_rank, comm, CommunicatorBackend::kNccl); NVF_API CombineResult doMoeCombine( const at::Tensor& x, const at::Tensor& topk_weights, const at::Tensor& src_idx, - const at::Tensor& src_rank, const at::Tensor& n_tokens_to_rank, const at::Tensor& n_tokens_from_rank, Communicator* communicator, diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 02b36e314f3..a41eae9afac 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -46,7 +46,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto* recv_topk_idx = makeSymbolicTensor(2, DataType::Int); auto* recv_topk_weights = makeSymbolicTensor(2); auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); - auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); @@ -55,7 +54,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { recv_topk_idx, recv_topk_weights, recv_src_idx, - recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, in_x, @@ -70,7 +68,6 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { recv_x, recv_topk_weights, recv_src_idx, - recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, CommunicatorBackend::kNccl); @@ -136,7 +133,6 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { auto* recv_topk_idx = makeSymbolicTensor(2, DataType::Int); auto* recv_topk_weights = makeSymbolicTensor(2); auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); - auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); @@ -145,7 +141,6 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { recv_topk_idx, recv_topk_weights, recv_src_idx, - recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, in_x, @@ -163,7 +158,6 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { hic->addOutput(recv_topk_idx); hic->addOutput(recv_topk_weights); hic->addOutput(recv_src_idx); - hic->addOutput(recv_src_rank); hic->addOutput(n_tokens_to_rank); hic->addOutput(n_tokens_from_rank); @@ -210,13 +204,11 @@ TEST_F(DispatchCombineTest, DispatchOnlyTop1) { << "Dispatch recv_topk_weights mismatch on rank " << my_rank; EXPECT_TRUE(at::allclose(outputs[3].as(), expected.recv_src_idx)) << "Dispatch recv_src_idx mismatch on rank " << my_rank; - EXPECT_TRUE(at::allclose(outputs[4].as(), expected.recv_src_rank)) - << "Dispatch recv_src_rank mismatch on rank " << my_rank; EXPECT_TRUE( - at::allclose(outputs[5].as(), expected.n_tokens_to_rank)) + at::allclose(outputs[4].as(), expected.n_tokens_to_rank)) << "Dispatch n_tokens_to_rank mismatch on rank " << my_rank; EXPECT_TRUE( - 
at::allclose(outputs[6].as(), expected.n_tokens_from_rank)) + at::allclose(outputs[5].as(), expected.n_tokens_from_rank)) << "Dispatch n_tokens_from_rank mismatch on rank " << my_rank; } @@ -267,7 +259,6 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { auto* in_x = makeSymbolicTensor(2); auto* in_topk_weights = makeSymbolicTensor(2); auto* in_src_idx = makeSymbolicTensor(1, DataType::Int); - auto* in_src_rank = makeSymbolicTensor(1, DataType::Int); auto* in_n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); auto* in_n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); @@ -277,7 +268,6 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { in_x, in_topk_weights, in_src_idx, - in_src_rank, in_n_tokens_to_rank, in_n_tokens_from_rank, CommunicatorBackend::kNccl); @@ -287,7 +277,6 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { hic->addInput(in_x); hic->addInput(in_topk_weights); hic->addInput(in_src_idx); - hic->addInput(in_src_rank); hic->addInput(in_n_tokens_to_rank); hic->addInput(in_n_tokens_from_rank); hic->addOutput(combined_x); @@ -298,7 +287,6 @@ TEST_F(DispatchCombineTest, CombineOnlyTop1) { {{in_x, dispatch_result.recv_x}, {in_topk_weights, dispatch_result.recv_topk_weights}, {in_src_idx, dispatch_result.recv_src_idx}, - {in_src_rank, dispatch_result.recv_src_rank}, {in_n_tokens_to_rank, dispatch_result.n_tokens_to_rank}, {in_n_tokens_from_rank, dispatch_result.n_tokens_from_rank}}); From 3828247777acfa75a80f6f5967d5212ac595c114 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 10 Feb 2026 11:07:48 -0800 Subject: [PATCH 16/18] lint --- csrc/multidevice/alltoallv.cu | 1 - csrc/multidevice/cuda_p2p.h | 5 +--- tests/cpp/test_multidevice_alltoallv.cpp | 10 ++++---- .../cpp/test_multidevice_dispatch_combine.cpp | 23 +++++++------------ 4 files changed, 13 insertions(+), 26 deletions(-) diff --git a/csrc/multidevice/alltoallv.cu b/csrc/multidevice/alltoallv.cu index 9725794f838..be2b5d80373 100644 --- a/csrc/multidevice/alltoallv.cu +++ b/csrc/multidevice/alltoallv.cu @@ -34,4 +34,3 @@ extern "C" __global__ void alltoallv_kernel( static_cast(recv_ptrs[peer])); dst[recv_byte_offset] = send[send_byte_offset]; } - diff --git a/csrc/multidevice/cuda_p2p.h b/csrc/multidevice/cuda_p2p.h index e9fd5828597..514195c0746 100644 --- a/csrc/multidevice/cuda_p2p.h +++ b/csrc/multidevice/cuda_p2p.h @@ -7,11 +7,8 @@ // clang-format on #pragma once -#include - #include -#include -#include +#include #include "multidevice/ipc_handle.h" diff --git a/tests/cpp/test_multidevice_alltoallv.cpp b/tests/cpp/test_multidevice_alltoallv.cpp index 02cb21b7892..2d93b0e0092 100644 --- a/tests/cpp/test_multidevice_alltoallv.cpp +++ b/tests/cpp/test_multidevice_alltoallv.cpp @@ -6,9 +6,6 @@ */ // clang-format on #include -#include - -#include #include "multidevice/cuda_p2p.h" #include "multidevice/symmetric_tensor.h" @@ -38,7 +35,8 @@ TEST_F(AlltoallvCudaTest, AlltoallvAsymmetric) { send_counts.index_put_({dest}, count_for(my_rank, dest)); } - auto metadata = prepareAlltoallvMetadata(send_counts, "test_alltoallv_counts"); + auto metadata = + prepareAlltoallvMetadata(send_counts, "test_alltoallv_counts"); const int64_t max_recv = metadata.max_recv; const int64_t total_send = send_counts.sum().item(); auto send_sym = SymmetricTensor::allocate( @@ -46,8 +44,8 @@ TEST_F(AlltoallvCudaTest, AlltoallvAsymmetric) { send_sym.narrow(0, 0, total_send) .copy_(at::arange(total_send, int_options) + my_rank * 1000); - auto recv_sym = SymmetricTensor::allocate( - {max_recv}, at::kLong, communicator_->device()); + auto 
recv_sym = + SymmetricTensor::allocate({max_recv}, at::kLong, communicator_->device()); SymmetricTensor recv_handle(recv_sym); recv_handle.setupRemoteHandles("test_alltoallv_recv"); diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 4a42803b1e4..16bc818fb61 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -33,7 +33,8 @@ TEST_P(DispatchCombineTest, DispatchCombineTop1) { const int64_t world_size = communicator_->size(); const int64_t my_rank = communicator_->deviceId(); const auto backend = GetParam(); - if (backend != CommunicatorBackend::kCuda && !communicator_->isBackendAvailable(backend)) { + if (backend != CommunicatorBackend::kCuda && + !communicator_->isBackendAvailable(backend)) { GTEST_SKIP() << "Backend " << backend << " not available."; } constexpr int64_t kNumExpertsPerRank = 2; @@ -124,7 +125,8 @@ TEST_P(DispatchCombineTest, DispatchOnlyTop1) { const int64_t world_size = communicator_->size(); const int64_t my_rank = communicator_->deviceId(); const auto backend = GetParam(); - if (backend != CommunicatorBackend::kCuda && !communicator_->isBackendAvailable(backend)) { + if (backend != CommunicatorBackend::kCuda && + !communicator_->isBackendAvailable(backend)) { GTEST_SKIP() << "Backend " << backend << " not available."; } constexpr int64_t kNumExpertsPerRank = 2; @@ -198,12 +200,7 @@ TEST_P(DispatchCombineTest, DispatchOnlyTop1) { {{in_x, x}, {in_topk_idx, topk_idx}, {in_topk_weights, topk_weights}}); auto expected = doMoeDispatch( - x, - topk_idx, - topk_weights, - num_experts, - communicator_, - backend); + x, topk_idx, topk_weights, num_experts, communicator_, backend); EXPECT_TRUE(at::allclose(outputs[0].as(), expected.recv_x)) << "Dispatch recv_x mismatch on rank " << my_rank; @@ -230,7 +227,8 @@ TEST_P(DispatchCombineTest, CombineOnlyTop1) { const int64_t world_size = communicator_->size(); const int64_t my_rank = communicator_->deviceId(); const auto backend = GetParam(); - if (backend != CommunicatorBackend::kCuda && !communicator_->isBackendAvailable(backend)) { + if (backend != CommunicatorBackend::kCuda && + !communicator_->isBackendAvailable(backend)) { GTEST_SKIP() << "Backend " << backend << " not available."; } constexpr int64_t kNumExpertsPerRank = 2; @@ -260,12 +258,7 @@ TEST_P(DispatchCombineTest, CombineOnlyTop1) { topk_idx.index_put_({3, 0}, kNumExpertsPerRank); auto dispatch_result = doMoeDispatch( - x, - topk_idx, - topk_weights, - num_experts, - communicator_, - backend); + x, topk_idx, topk_weights, num_experts, communicator_, backend); auto hic = std::make_unique(); FusionGuard fg(hic.get()); From c5add1b01a24148c1fa6e6217dda5199215fe7be Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 16 Feb 2026 08:22:07 -0800 Subject: [PATCH 17/18] minor comment --- csrc/multidevice/dispatch_combine.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 5aeedc023f8..e320ef341cb 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -152,8 +152,8 @@ DispatchResult doMoeDispatch( n_tokens_from_rank}; } - NVF_CHECK( - backend == CommunicatorBackend::kCuda, + NVF_CHECK_EQ( + backend, CommunicatorBackend::kCuda, "Only CUDA and NCCL backends are supported for MoeDispatch."); auto metadata = From 374c8b30bd528002dbdd79c81a1c58c34bf1c376 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 16 Feb 
2026 08:24:15 -0800 Subject: [PATCH 18/18] lint --- csrc/multidevice/dispatch_combine.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index e320ef341cb..55f3a19c134 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -153,7 +153,8 @@ DispatchResult doMoeDispatch( } NVF_CHECK_EQ( - backend, CommunicatorBackend::kCuda, + backend, + CommunicatorBackend::kCuda, "Only CUDA and NCCL backends are supported for MoeDispatch."); auto metadata =