The restructuring of the getCommunicationInfo function changes the control flow significantly: the new implementation checks haveDifferentShardings first and then processes each parallel type. This appears correct but should be carefully validated, especially the handling of the LoadStoreOp versus ReductionOp/SqueezeOp cases and the newly added continue statements.
```cpp
for (ParallelType pt : kParallelTypeDIDs) {
  if (!haveDifferentShardings(producer, consumer, {pt})) {
    continue;
  }

  IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
  IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
  IterDomain* p_logical_id =
      p_loop_did ? getLogicalFromLoopId(producer, p_loop_did) : nullptr;
  IterDomain* c_logical_id =
      c_loop_did ? getLogicalFromLoopId(consumer, c_loop_did) : nullptr;

  if (e->isA<LoadStoreOp>()) {
    if (p_loop_did && !c_loop_did) {
      CommunicationType type = same_mesh ? CommunicationType::Allgather
                                         : CommunicationType::Gather;
      fill_communication_info(type, p_logical_id, p2c.at(p_logical_id));
      continue;
    }
    if (!p_loop_did && c_loop_did) {
      fill_communication_info(
          CommunicationType::Scatter, c2p.at(c_logical_id), c_logical_id);
      continue;
    }
    if (p_loop_did && c_loop_did) {
      if (c_logical_id == p2c.at(p_logical_id)) {
        fill_communication_info(
            CommunicationType::SendRecv, p_logical_id, c_logical_id);
      } else {
        fill_communication_info(
            CommunicationType::AllToAll, p_logical_id, c_logical_id);
      }
    }
    continue;
  }

  NVF_ERROR(e->isA<ReductionOp>() || e->isA<SqueezeOp>());
  if (!p_loop_did) {
    // Not a reduction based communication.
    continue;
  }
  if (!c_loop_did) {
    CommunicationType type =
        same_mesh ? CommunicationType::Allreduce : CommunicationType::Reduce;
    fill_communication_info(type, p_logical_id, p2c.at(p_logical_id));
    continue;
  }
  fill_communication_info(
      CommunicationType::ReduceScatter, c2p.at(c_logical_id), c_logical_id);
}
```
The added comment correctly identifies that the current code is problematic for multi-dimensional sharding, but it should be verified whether the fix in lower_to_communication.cpp resolves the issue completely or whether edge cases remain.
```cpp
// This code is problematic for multi-dimensional sharding.
// ```
// x: [iDIDy{2}, iDIDx{2}] on mesh [[0, 1], [2, 3]]
// y = set(x): [iDIDy{2}, i{2}] on mesh [[0], [2]]
// ```
// should be treated as non-resharding on DIDy.
return true;
```
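To make the intent of that comment concrete, here is a minimal standalone sketch, not nvFuser code: the Sharding alias, the axis-name strings, and haveDifferentShardingsPerAxis are illustrative assumptions. It models the per-axis comparison the comment argues for: x and y above agree on which logical dimension DIDy shards, so a check restricted to DIDy reports no resharding even though the meshes differ.

```cpp
#include <iostream>
#include <map>
#include <optional>
#include <string>

// A tensor's sharding reduced to: DID axis name -> logical dimension it
// shards (nullopt means the axis does not shard anything).
using Sharding = std::map<std::string, std::optional<int>>;

// Compare only the requested DID axis, ignoring whether the meshes differ.
bool haveDifferentShardingsPerAxis(
    const Sharding& producer,
    const Sharding& consumer,
    const std::string& axis) {
  auto get = [&](const Sharding& s) -> std::optional<int> {
    auto it = s.find(axis);
    if (it == s.end()) {
      return std::nullopt;
    }
    return it->second;
  };
  return get(producer) != get(consumer);
}

int main() {
  // x: [iDIDy{2}, iDIDx{2}] on mesh [[0, 1], [2, 3]]
  Sharding x = {{"DIDy", 0}, {"DIDx", 1}};
  // y = set(x): [iDIDy{2}, i{2}] on mesh [[0], [2]]
  Sharding y = {{"DIDy", 0}, {"DIDx", std::nullopt}};

  // DIDy shards logical dimension 0 in both tensors, so a per-axis check
  // reports no resharding on DIDy even though the meshes differ.
  std::cout << haveDifferentShardingsPerAxis(x, y, "DIDy") << "\n";  // 0
  // DIDx is sharded in x but replicated in y: this is the gather axis.
  std::cout << haveDifferentShardingsPerAxis(x, y, "DIDx") << "\n";  // 1
  return 0;
}
```

In these terms, the early return on differing meshes is what currently prevents haveDifferentShardings(producer, consumer, {DIDy}) from answering false in the cross-mesh case.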
Test failures (partial, pipeline still running)
(Medium, 1) Thunder vs. eager output mismatch in nanoGPT CUDA CUDAGraphs test (test_networks)
This PR fixes getCommunicationInfo to correctly handle multi-dimensional device meshes (issue #4604) by filtering parallel types using per-DID haveDifferentShardings checks instead of the previous approach that compared all DIDs simultaneously.
- Core fix in getCommunicationInfo: iterates over each ParallelType in kParallelTypeDIDs and uses haveDifferentShardings(producer, consumer, {pt}) to skip dimensions where sharding hasn't changed. This prevents false matches when only one dimension of a 2D mesh changes sharding.
- Structural improvements: moves producer_mesh/consumer_mesh/same_mesh outside the loop, adds early continue statements for clearer control flow, pre-computes p_logical_id/c_logical_id, and simplifies the ReduceScatter path by removing a now-redundant isReduction() check.
- Known limitation documented in resharding.cpp: haveDifferentShardings still has an early return when meshes differ that incorrectly reports "different sharding" for all DIDs. This limitation is documented with a comment but not yet fixed, meaning cross-mesh multi-dimensional scenarios remain unsupported.
- New test test_allgather_2d: validates allgather on a 2D mesh where DIDy (data parallelism) is unchanged while DIDx (tensor parallelism) is allgathered; a data-movement sketch follows below.
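For intuition about what the new test exercises, here is a small standalone sketch; the shard sizes, element values, and rank numbering are assumptions, and the gather is simulated locally rather than issued through any real collective. On a [[0, 1], [2, 3]] mesh, the allgather runs only within each DIDy row, so the DIDy split is preserved while the DIDx shards are concatenated.

```cpp
#include <array>
#include <iostream>
#include <vector>

int main() {
  constexpr int kMeshY = 2;     // DIDy: data-parallel rows, left untouched
  constexpr int kMeshX = 2;     // DIDx: tensor-parallel columns, allgathered
  constexpr int kShardLen = 3;  // elements per DIDx shard (assumed size)

  // Per-rank shards, indexed [didy][didx]; rank id = didy * kMeshX + didx,
  // matching the mesh [[0, 1], [2, 3]].
  std::array<std::array<std::vector<int>, kMeshX>, kMeshY> shard{};
  for (int y = 0; y < kMeshY; ++y) {
    for (int x = 0; x < kMeshX; ++x) {
      for (int i = 0; i < kShardLen; ++i) {
        shard[y][x].push_back(100 * y + 10 * x + i);
      }
    }
  }

  // Allgather along DIDx only: ranks within the same mesh row exchange
  // shards; rows never communicate, so the DIDy sharding is preserved.
  for (int y = 0; y < kMeshY; ++y) {
    std::vector<int> gathered;
    for (int x = 0; x < kMeshX; ++x) {
      gathered.insert(gathered.end(), shard[y][x].begin(), shard[y][x].end());
    }
    std::cout << "row " << y << " (ranks " << y * kMeshX << " and "
              << y * kMeshX + 1 << ") each end up with " << gathered.size()
              << " elements\n";
  }
  return 0;
}
```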
Confidence Score: 4/5
This PR is safe to merge — it fixes a real bug for multi-dimensional meshes with a well-tested same-mesh scenario, while correctly documenting remaining limitations.
The refactored logic in getCommunicationInfo is sound and well structured, and the per-parallel-type haveDifferentShardings filter is the correct approach. The one remaining concern is the documented-but-unfixed limitation in haveDifferentShardings for cross-mesh scenarios, but that limitation is pre-existing and the PR makes progress on the issue. The score is 4 rather than 5 because it remains unresolved.
csrc/multidevice/resharding.cpp — the documented early-return limitation at line 133 could cause incorrect behavior for future cross-mesh multi-dimensional sharding use cases.
Important Files Changed
| Filename | Overview |
| --- | --- |
| csrc/host_ir/lower_to_communication.cpp | Refactors getCommunicationInfo to use per-parallel-type haveDifferentShardings filtering, enabling correct handling of multi-dimensional meshes where only one DID changes sharding. Also moves mesh variables outside the loop, adds continue statements for clearer control flow, and simplifies the ReduceScatter path by removing a now-redundant isReduction() check. |
| csrc/multidevice/resharding.cpp | Adds a comment documenting a known issue: haveDifferentShardings returns true too eagerly when meshes differ, even for parallel types that haven't actually changed sharding. This remains a TODO for cross-mesh multi-dimensional sharding scenarios. |
| tests/python/multidevice/test_communication.py | Adds test_allgather_2d, which validates allgather on a 2D device mesh (DIDy for data parallelism, DIDx for tensor parallelism). Only the DIDx axis is allgathered while DIDy remains unchanged. |
Flowchart
```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["getCommunicationInfo(Expr* e)"] --> B["For each ParallelType pt in kParallelTypeDIDs"]
    B --> C{"haveDifferentShardings\n(producer, consumer, {pt})"}
    C -- "false (same sharding)" --> B
    C -- "true (different sharding)" --> D["Compute p_loop_did, c_loop_did,\np_logical_id, c_logical_id"]
    D --> E{"e is LoadStoreOp?"}
    E -- "Yes" --> F{"p_loop_did && !c_loop_did"}
    F -- "Yes" --> G["Allgather (same mesh)\nor Gather (diff mesh)"]
    F -- "No" --> H{"!p_loop_did && c_loop_did"}
    H -- "Yes" --> I["Scatter"]
    H -- "No" --> J{"p_loop_did && c_loop_did"}
    J -- "Same logical ID" --> K["SendRecv"]
    J -- "Diff logical ID" --> L["AllToAll"]
    E -- "No (ReductionOp/SqueezeOp)" --> M{"!p_loop_did"}
    M -- "Yes" --> N["continue\n(not reduction-based)"]
    M -- "No" --> O{"!c_loop_did"}
    O -- "Yes" --> P["Allreduce (same mesh)\nor Reduce (diff mesh)"]
    O -- "No" --> Q["ReduceScatter"]
```
csrc/multidevice/resharding.cpp: Known limitation affects callers of this function
The new code in getCommunicationInfo relies on calling haveDifferentShardings(producer, consumer, {pt}) per parallel type to skip unchanged dimensions. However, this early return at line 140 means that when meshes differ (e.g., [[0,1],[2,3]] vs [[0],[2]]), calling haveDifferentShardings(..., {DIDy}) will return true even though DIDy sharding is unchanged — as the comment correctly notes.
This means getCommunicationInfo could produce incorrect results for multi-dimensional resharding across different meshes (e.g., a Gather on one DID while the other DID remains unchanged but the meshes differ). The same-mesh allgather case in the new test works because the meshes are identical. Is the cross-mesh multi-dimensional sharding case intentionally deferred as a future fix, or should it be addressed as part of this PR?
Fixes #4604