
Fix getCommunicationInfo for multi-dimensional meshes #5969

Open

wujingyue wants to merge 4 commits into main from wjy/comm

Conversation

wujingyue (Collaborator) commented Feb 18, 2026

Fixes #4604

wujingyue (Collaborator, Author) commented:

!test

github-actions bot commented Feb 18, 2026

Review updated until commit e4301a9

Description

  • Fix the getCommunicationInfo function to properly handle multi-dimensional device meshes

  • Refactor logic to correctly process cases where both producer and consumer have loop domain IDs

  • Add mesh equality checks and improve error handling for better debugging

  • Include new test case test_allgather_2d to validate 2D mesh communication functionality

Changes walkthrough

Relevant files

Bug fix
lower_to_communication.cpp: Refactor getCommunicationInfo for multi-dimensional mesh support
csrc/host_ir/lower_to_communication.cpp (+39/-49)

  • Add missing include for interface_nodes.h and update the communication.h include
  • Replace NVF_ERROR with NVF_ERROR_EQ for better error checking (see the sketch after this walkthrough)
  • Major refactoring of getCommunicationInfo to handle multi-dimensional meshes
  • Add mesh equality checks and improve the logical flow for producer/consumer ID processing
  • Enhance error messages and comments for better maintainability

Documentation
resharding.cpp: Document multi-dimensional sharding limitation
csrc/multidevice/resharding.cpp (+7/-1)

  • Add a documentation comment explaining the multi-dimensional sharding limitation
  • Include a concrete example showing a problematic 2D sharding scenario
  • Clarify that the current implementation doesn't properly handle all multi-dimensional cases

Tests
test_communication.py: Add 2D mesh allgather test case
tests/python/multidevice/test_communication.py (+31/-0)

  • Add new test function test_allgather_2d for 2D device mesh validation
  • Test the allgather operation with 2D mesh partitioning using mesh_y and mesh_x
  • Verify proper sharding and reconstruction across a multi-dimensional device layout
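One walkthrough item swaps NVF_ERROR for NVF_ERROR_EQ. The exact call sites aren't quoted in this summary, so the operands below are hypothetical, but assuming NVF_ERROR_EQ follows the usual pattern for equality-check macros (reporting both operands on mismatch), the motivation looks like this:

    // Hypothetical call site, for illustration only.
    // Before: on failure, the message only reports that the condition was false.
    NVF_ERROR(
        producer_mesh == consumer_mesh,
        "Producer and consumer meshes must match.");

    // After: on failure, both operand values are included in the message,
    // which makes mesh mismatches much easier to debug.
    NVF_ERROR_EQ(producer_mesh, consumer_mesh);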

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Logic correctness

The restructuring of the getCommunicationInfo function changes the control flow significantly. The new implementation checks haveDifferentShardings first, then processes each parallel type. This appears correct but should be carefully validated, especially the handling of LoadStoreOp vs. ReductionOp/SqueezeOp cases and the newly added continue statements.

    for (ParallelType pt : kParallelTypeDIDs) {
      if (!haveDifferentShardings(producer, consumer, {pt})) {
        continue;
      }
    
      IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
      IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
      IterDomain* p_logical_id =
          p_loop_did ? getLogicalFromLoopId(producer, p_loop_did) : nullptr;
      IterDomain* c_logical_id =
          c_loop_did ? getLogicalFromLoopId(consumer, c_loop_did) : nullptr;
    
      if (e->isA<LoadStoreOp>()) {
        if (p_loop_did && !c_loop_did) {
          CommunicationType type = same_mesh ? CommunicationType::Allgather
                                             : CommunicationType::Gather;
          fill_communication_info(type, p_logical_id, p2c.at(p_logical_id));
          continue;
        }
    
        if (!p_loop_did && c_loop_did) {
          fill_communication_info(
              CommunicationType::Scatter, c2p.at(c_logical_id), c_logical_id);
          continue;
        }
    
        if (p_loop_did && c_loop_did) {
          if (c_logical_id == p2c.at(p_logical_id)) {
            fill_communication_info(
                CommunicationType::SendRecv, p_logical_id, c_logical_id);
          } else {
            fill_communication_info(
                CommunicationType::AllToAll, p_logical_id, c_logical_id);
          }
        }
    
        continue;
      }
    
      NVF_ERROR(e->isA<ReductionOp>() || e->isA<SqueezeOp>());
      if (!p_loop_did) {
        // Not a reduction based communication.
        continue;
      }
    
      if (!c_loop_did) {
        CommunicationType type =
            same_mesh ? CommunicationType::Allreduce : CommunicationType::Reduce;
        fill_communication_info(type, p_logical_id, p2c.at(p_logical_id));
        continue;
      }
    
      fill_communication_info(
          CommunicationType::ReduceScatter, c2p.at(c_logical_id), c_logical_id);
    }
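The snippet above leans on a getOrDefault lookup to tolerate parallel types that are absent from a map. The real helper lives in nvFuser's utilities; the following is only a minimal sketch of the behavior the call sites assume, namely that a missing key yields a value-initialized result (nullptr for IterDomain* values):

    #include <unordered_map>

    // Sketch only, not nvFuser's actual implementation: look up `key` in
    // `map`, returning a value-initialized mapped_type (nullptr for pointer
    // values such as IterDomain*) when the key is absent.
    template <typename Map>
    typename Map::mapped_type getOrDefault(
        const Map& map,
        const typename Map::key_type& key) {
      auto it = map.find(key);
      return it == map.end() ? typename Map::mapped_type{} : it->second;
    }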
Comment accuracy

The added comment correctly identifies that the current code is problematic for multi-dimensional sharding, but it should be verified that the fix in lower_to_communication.cpp actually resolves this issue completely, or whether edge cases remain to be addressed.

    // This code is problematic for multi-dimensional sharding.
    // ```
    // x: [iDIDy{2}, iDIDx{2}] on mesh [[0, 1], [2, 3]]
    // y = set(x): [iDIDy{2}, i{2}] on mesh [[0], [2]]
    // ```
    // should be treated as non-resharding on DIDy.
    return true;

Test failures (partial, pipeline still running)

  • (Medium, 1) Thunder vs. eager output mismatch in the nanoGPT CUDAGraphs test (test_networks)

    Test name (GB200): thunder.tests.test_networks.test_nanogpt_complete_cudagraphs_autograd_nvfuser_cuda_thunder.dtypes.float32

wujingyue requested a review from Priya2698 February 18, 2026 07:08
wujingyue marked this pull request as ready for review February 18, 2026 07:09

wujingyue (Collaborator, Author) commented:

!test

wujingyue (Collaborator, Author) commented:

!test

greptile-apps bot (Contributor) commented Feb 18, 2026

Greptile Summary

This PR fixes getCommunicationInfo to correctly handle multi-dimensional device meshes (issue #4604) by filtering parallel types using per-DID haveDifferentShardings checks instead of the previous approach that compared all DIDs simultaneously.

  • Core fix in getCommunicationInfo: Iterates over each ParallelType in kParallelTypeDIDs and uses haveDifferentShardings(producer, consumer, {pt}) to skip dimensions where sharding hasn't changed. This prevents false matches when only one dimension of a 2D mesh changes sharding.
  • Structural improvements: Moves producer_mesh/consumer_mesh/same_mesh outside the loop, adds early continue statements for clearer control flow, pre-computes p_logical_id/c_logical_id, and simplifies the ReduceScatter path by removing a now-redundant isReduction() check.
  • Known limitation documented in resharding.cpp: haveDifferentShardings still has an early return when meshes differ that incorrectly reports "different sharding" for all DIDs. This limitation is documented with a comment but not yet fixed, meaning cross-mesh multi-dimensional scenarios remain unsupported.
  • New test test_allgather_2d: Validates allgather on a 2D mesh where DIDy (data parallelism) is unchanged while DIDx (tensor parallelism) is allgathered (sketched after this list).
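A schematic of the scenario the test covers, in the same IterDomain notation as the resharding.cpp comment quoted earlier (the extents and the trailing axis are assumed for illustration):

    // Assumed shapes, for illustration only: a 2x2 mesh where DIDy shards
    // data-parallel rows and DIDx shards tensor-parallel columns.
    //
    //   in : [iDIDy{2}, iDIDx{2}, i{n}] on mesh [[0, 1], [2, 3]]
    //   out: [iDIDy{2}, i{2},     i{n}] on the same mesh
    //
    // Per-axis filtering on this scenario (schematic outcomes):
    //   haveDifferentShardings(in, out, {DIDy}) -> false => DIDy skipped
    //   haveDifferentShardings(in, out, {DIDx}) -> true  => p_loop_did set,
    //       c_loop_did null, same mesh => a single Allgather over DIDx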

Confidence Score: 4/5

  • This PR is safe to merge: it fixes a real bug for multi-dimensional meshes with a well-tested same-mesh scenario, while correctly documenting remaining limitations.
  • The refactored logic in getCommunicationInfo is sound and well-structured. The per-parallel-type haveDifferentShardings filter is the correct approach. A minor concern is the documented-but-unfixed limitation in haveDifferentShardings for cross-mesh scenarios, but this is pre-existing and the PR makes progress on the issue. The score is 4 rather than 5 because this known limitation remains.
  • csrc/multidevice/resharding.cpp: the documented early-return limitation at line 133 could cause incorrect behavior for future cross-mesh multi-dimensional sharding use cases.

Important Files Changed

csrc/host_ir/lower_to_communication.cpp
Refactors getCommunicationInfo to use per-parallel-type haveDifferentShardings filtering, enabling correct handling of multi-dimensional meshes where only one DID changes sharding. Also moves mesh variables outside the loop, adds continue statements for clearer control flow, and simplifies the ReduceScatter path by removing a now-redundant isReduction() check.

csrc/multidevice/resharding.cpp
Adds a comment documenting a known issue: haveDifferentShardings returns true too eagerly when meshes differ, even for parallel types whose sharding hasn't actually changed. This remains a TODO for cross-mesh multi-dimensional sharding scenarios.

tests/python/multidevice/test_communication.py
Adds test_allgather_2d, which validates allgather on a 2D device mesh (DIDy for data parallelism, DIDx for tensor parallelism). Only the DIDx axis is allgathered while DIDy remains unchanged.

    Flowchart

    %%{init: {'theme': 'neutral'}}%%
    flowchart TD
        A["getCommunicationInfo(Expr* e)"] --> B["For each ParallelType pt in kParallelTypeDIDs"]
        B --> C{"haveDifferentShardings\n(producer, consumer, {pt})"}
        C -- "false (same sharding)" --> B
        C -- "true (different sharding)" --> D["Compute p_loop_did, c_loop_did,\np_logical_id, c_logical_id"]
        D --> E{"e is LoadStoreOp?"}
        E -- "Yes" --> F{"p_loop_did && !c_loop_did"}
        F -- "Yes" --> G["Allgather (same mesh)\nor Gather (diff mesh)"]
        F -- "No" --> H{"!p_loop_did && c_loop_did"}
        H -- "Yes" --> I["Scatter"]
        H -- "No" --> J{"p_loop_did && c_loop_did"}
        J -- "Same logical ID" --> K["SendRecv"]
        J -- "Diff logical ID" --> L["AllToAll"]
        E -- "No (ReductionOp/SqueezeOp)" --> M{"!p_loop_did"}
        M -- "Yes" --> N["continue\n(not reduction-based)"]
        M -- "No" --> O{"!c_loop_did"}
        O -- "Yes" --> P["Allreduce (same mesh)\nor Reduce (diff mesh)"]
        O -- "No" --> Q["ReduceScatter"]
    

    Last reviewed commit: e4301a9

greptile-apps bot (Contributor) left a comment

3 files reviewed, 1 comment

greptile-apps bot (Contributor) commented Feb 18, 2026

    Additional Comments (1)

    csrc/multidevice/resharding.cpp
    Known limitation affects callers of this function

The new code in getCommunicationInfo relies on calling haveDifferentShardings(producer, consumer, {pt}) per parallel type to skip unchanged dimensions. However, as the added comment correctly notes, the early return at line 140 means that when meshes differ (e.g., [[0,1],[2,3]] vs. [[0],[2]]), calling haveDifferentShardings(..., {DIDy}) will return true even though the DIDy sharding is unchanged.

This means getCommunicationInfo could produce incorrect results for multi-dimensional resharding across different meshes (e.g., a Gather on one DID while the other DID remains unchanged but the meshes differ). The same-mesh allgather case in the new test works because the meshes are identical. Is the cross-mesh multi-dimensional sharding case intentionally deferred as a future fix, or should it be addressed as part of this PR?
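For readers without the file open, the pattern this comment describes looks roughly as follows (a paraphrase of the described early return, not the exact source; variable names are assumed):

    // Paraphrased sketch of the early return in haveDifferentShardings
    // (csrc/multidevice/resharding.cpp). When the two meshes differ at all,
    // every queried parallel type is reported as resharded, even an axis
    // like DIDy whose sharding is actually unchanged.
    if (producer_mesh != consumer_mesh) {
      return true; // too eager: ignores which parallel types changed
    }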



Development

Successfully merging this pull request may close these issues:

  • Fix convertSingleOpToCommunication for 2D sharding
