
use relative indices in post methods#5974

Open
Priya2698 wants to merge 9 commits into main from pm/relative_index

Conversation


@Priya2698 Priya2698 commented Feb 18, 2026

For cases such as a broadcast-based allgather in a host loop, the root index is the for-loop index, which may not be the absolute device ID. This PR changes all lowering methods to use the relative root index, which is what the backends also use.

@Priya2698
Collaborator Author

!test


github-actions bot commented Feb 18, 2026

Review updated until commit 6574bc8

Description

  • Replace absolute device IDs with relative indices in communication operations

  • Move getRelativeIndex function from Communication class to utils module

  • Update post communication methods to work with relative indices

  • Simplify SendRecv logic using relative team indices

Changes walkthrough

Relevant files

Enhancement

lower_to_communication.cpp (csrc/host_ir/lower_to_communication.cpp): Update communication lowering to use relative indices

  • Replace root parameter with getRelativeIndex(team, root) in Scatter,
    Gather, Broadcast, SendRecv, and Reduce operations
  • Update Communication creation to use relative indices instead of
    absolute device IDs
  • +8/-7

communication.cpp (csrc/multidevice/communication.cpp): Remove getRelativeIndex from Communication class

  • Remove getRelativeIndex method from Communication class
  • Delete the method implementation that converted absolute rank to
    relative index
  • +0/-7

post_communication.cpp (csrc/multidevice/post_communication.cpp): Update post communication methods for relative indices

  • Update all post methods to receive relative indices directly instead
    of converting them
  • Simplify postSendRecv logic using relative team indices
  • Add getRelativeIndex call in postSingleCommunication for proper index
    conversion
  • Remove redundant getRelativeIndex calls throughout post methods
  • +31/-39

utils.cpp (csrc/multidevice/utils.cpp): Add getRelativeIndex utility function

  • Add getRelativeIndex function that converts absolute device rank to
    relative team index
  • Implement error checking to ensure rank exists in team
  • +6/-0

communication.h (csrc/multidevice/communication.h): Remove getRelativeIndex from Communication header

  • Remove getRelativeIndex method declaration from Communication class
  • Clean up header file to reflect moved functionality
  • +0/-5

utils.h (csrc/multidevice/utils.h): Add getRelativeIndex declaration to utils header

  • Add getRelativeIndex function declaration with team and rank
    parameters
  • Document function purpose for converting absolute to relative indices
  • +3/-0

Tests

test_multidevice_communications.cpp (tests/cpp/test_multidevice_communications.cpp): Update test to use relative indices

  • Update SendRecv test to use getRelativeIndex when creating
    communication
  • Ensure test uses relative indices consistent with new implementation
  • +6/-1

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Logic Error in postSendRecv

    The postSendRecv function has a significant logic change where it now assumes root_index is already a relative index (line 366-367). However, the function signature still takes root_index as a parameter, and it's unclear if callers are providing relative or absolute indices. This could cause incorrect behavior if callers still pass absolute device IDs.

    // All indices are relative and not absolute device IDs.
    DeviceIdxType sender_index = root_index;
    DeviceIdxType receiver_index = 1 - sender_index;
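The reviewer's concern about callers aside, the arithmetic itself is sound for a two-member team once indices are relative: 0 and 1 are the only slots, so 1 - sender_index is always the other one. A hypothetical helper illustrating this (receiverIndexOf is not a function in the PR):

```cpp
#include <cstdint>

using DeviceIdxType = int64_t;

// For a two-member team with relative indices 0 and 1, the receiver occupies
// whichever slot the sender does not.
DeviceIdxType receiverIndexOf(DeviceIdxType sender_index) {
  return 1 - sender_index;
}
```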
    Potential Index Out of Bounds

    In postGather (line 160-161) and postScatter (line 225), the code uses team.at(root_index) which assumes root_index is a valid array index. If root_index is still an absolute device ID rather than a relative index, this could cause out-of-bounds access.

    if (root_index == static_cast<DeviceIdxType>(i) &&
        !communication->in()->getDeviceMesh().has(team.at(root_index))) {
    Inconsistent Index Usage

    There's inconsistent handling of root_index across different post functions. Some functions like postBroadcast and postReduce still use root_index directly in backend calls (lines 132, 282), while others have been updated. This inconsistency could lead to bugs if the expectation isn't clearly documented.

      return backend->broadcast(tensors, {.rootRank = root_index});
    }
    
    c10::intrusive_ptr<c10d::Work> postGather(
        Communication* communication,
        DeviceIdxType my_device_index,
        DeviceIdxType root_index,
        c10d::Backend* backend,
        at::Tensor input_tensor,
        at::Tensor output_tensor) {
      const Team& team = communication->team();
      if (my_device_index == root_index &&
          !communication->in()->getDeviceMesh().has(team.at(root_index))) {
        // This is likely a suboptimal way to allocate tensors for nccl. To benefit
        // from zero copy
        // (https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html),
        // tensors for nccl should be `ncclMemAlloc`ed and be `ncclCommRegister`ed.
        // https://github.com/pytorch/pytorch/issues/124807 is one proposal trying
        // to partially address this problem.
        input_tensor = at::empty_like(output_tensor.slice(0, 0, 1));
      }
      std::vector<at::Tensor> input_tensors({input_tensor});
    
      std::vector<std::vector<at::Tensor>> output_tensors;
      if (my_device_index == root_index) {
        output_tensors.resize(1);
        int64_t j = 0;
        for (auto i : arange(communication->team().size())) {
          if (root_index == static_cast<DeviceIdxType>(i) &&
              !communication->in()->getDeviceMesh().has(team.at(root_index))) {
            output_tensors[0].push_back(input_tensor);
            continue;
          }
          output_tensors[0].push_back(output_tensor.slice(0, j, j + 1));
          j++;
        }
    
        assertBufferCount(output_tensors[0], communication->team().size());
        assertBuffersHaveSameSize(input_tensors, output_tensors[0]);
      }
    
      return backend->gather(
          output_tensors, input_tensors, {.rootRank = root_index});
    }
    
    c10::intrusive_ptr<c10d::Work> postAllgather(
        Communication* communication,
        DeviceIdxType my_device_index,
        DeviceIdxType root_index,
        c10d::Backend* backend,
        at::Tensor input_tensor,
        at::Tensor output_tensor) {
      // input and output tensors maybe strided (tensor with shape [m, n, k] and
      // strides [1, k*m, m]), so we flatten them to match the ProcessGroupNCCL
      // contiguity requirements. Presegmentation pass `makeReshardingContiguous`
      // ensures that the tvs are contiguous. CommunicationExecutor and
      // HostIrEvaluator validate the tensor against the tv allocation domain.
    
      NVF_ERROR(
          isTvContiguous(communication->in()), "Input tensor is not contiguous");
      NVF_ERROR(
          isTvContiguous(communication->out()), "Output tensor is not contiguous");
    
      auto flattened_output_tensor = viewAsCompact(output_tensor);
      auto flattened_input_tensor = viewAsCompact(input_tensor);
      auto splits = at::tensor_split(
          flattened_output_tensor, communication->team_size(), /*dim=*/0);
      assertBuffersHaveSameSize({flattened_input_tensor}, splits);
    
      // allgather primitive in c10d induces extra buffering time to copy out the
      // received tensors into user buffer. It is therefore always preferable to use
      // _allgather_base, which does not perform any extra copy at the cost of
      // assuming that the receive buffers are placed contiguously. See #2384 for an
      // illustration.
      return backend->_allgather_base(
          flattened_output_tensor, flattened_input_tensor, {});
    }
    
    c10::intrusive_ptr<c10d::Work> postScatter(
        Communication* communication,
        DeviceIdxType my_device_index,
        DeviceIdxType root_index,
        c10d::Backend* backend,
        at::Tensor input_tensor,
        at::Tensor output_tensor) {
      NVF_ERROR(
          isTvContiguous(communication->in()), "Input tensor is not contiguous");
      NVF_ERROR(
          isTvContiguous(communication->out()), "Output tensor is not contiguous");
    
      auto output_device_mesh = communication->out()->getDeviceMesh();
      const Team& team = communication->team();
      NVF_ERROR(
          output_device_mesh.has(team.at(root_index)),
          "root_index ",
          team.at(root_index),
          " is not in the output device mesh ",
          output_device_mesh,
          ".");
    
      std::vector<std::vector<at::Tensor>> input_tensors;
    
      output_tensor = viewAsCompact(output_tensor);
      std::vector<at::Tensor> output_tensors({output_tensor});
    
      if (my_device_index == root_index) {
        auto splits = at::tensor_split(
            viewAsCompact(input_tensor),
            output_device_mesh.size(),
            /*dim=*/0);
    
        input_tensors.resize(1);
        for (const auto& split : splits) {
          input_tensors.front().push_back(split);
        }
    
        assertBufferCount(input_tensors[0], output_device_mesh.size());
        assertBuffersHaveSameSize(input_tensors[0], output_tensors);
      }
    
      return backend->scatter(
          output_tensors, input_tensors, {.rootRank = root_index});
    }
    
    c10::intrusive_ptr<c10d::Work> postReduce(
        Communication* communication,
        DeviceIdxType my_device_index,
        DeviceIdxType root_index,
        c10d::Backend* backend,
        at::Tensor input_tensor,
        at::Tensor output_tensor) {
      at::Tensor tensor;
      const Team& team = communication->team();
      if (my_device_index == root_index) {
        if (communication->in()->getDeviceMesh().has(team.at(root_index))) {
          doLocalCopy(output_tensor, input_tensor);
          tensor = output_tensor;
        } else {
          NVF_ERROR(
              output_tensor.scalar_type() == at::kFloat,
              "only float tensors are supported");
          output_tensor.fill_(getInitialValue<float>(communication->reduceOp()));
          tensor = output_tensor;
        }
      } else {
        tensor = input_tensor;
      }
      std::vector<at::Tensor> tensors({tensor});
    
      c10d::ReduceOptions options = {
          .reduceOp = communication->reduceOp(), .rootRank = root_index};


    greptile-apps bot commented Feb 18, 2026

    Greptile Summary

    This PR refactors the communication lowering code to consistently use relative device indices instead of absolute device IDs. The getRelativeIndex method was moved from the Communication class to a standalone utility function in utils.cpp/h, and all communication creation sites in lower_to_communication.cpp now convert absolute device IDs to relative indices before passing them to Communication objects. In post_communication.cpp, all post methods now directly use the relative root_index parameter instead of calling getRelativeIndex again, since the index is already relative.

    Key changes:

    • Moved getRelativeIndex from Communication class to utils as a standalone function
    • Updated all lowering methods (scatter, gather, broadcast, sendrecv, reduce) to convert root device IDs to relative indices
    • Simplified post methods by removing redundant getRelativeIndex calls
    • Properly handles self-send case with duplicate team check (team[0] == team[1])
    • Added clarifying comment in postSendRecv that indices are relative

    Confidence Score: 5/5

    • This PR is safe to merge - it's a well-structured refactoring with consistent changes across the codebase
    • The refactoring is systematic and addresses the stated goal of using relative indices consistently. All previously identified issues (missing header declaration, extra closing brace, self-send handling) have been properly addressed in the current code. The changes maintain backward compatibility in behavior while improving code clarity.
    • No files require special attention

    Important Files Changed

    • csrc/multidevice/post_communication.cpp: Refactored to use relative indices throughout; properly handles self-send case with team duplicate check
    • csrc/host_ir/lower_to_communication.cpp: Updated all communication creation sites to convert absolute device IDs to relative indices using getRelativeIndex
    • csrc/multidevice/utils.cpp: Moved getRelativeIndex from Communication class to utils with proper const Team& parameter
    • csrc/multidevice/utils.h: Added declaration for getRelativeIndex helper function

    Flowchart

    %%{init: {'theme': 'neutral'}}%%
    flowchart TD
        A[lowerToScatter/Gather/Broadcast/SendRecv/Reduce] --> B[Convert absolute device ID to relative index]
        B --> C[getRelativeIndex team, root_device_id]
        C --> D[Create Communication object with relative root_index]
        D --> E[postSingleCommunication called]
        E --> F[Convert my_device_index to relative]
        F --> G[getRelativeIndex team, my_device_index]
        G --> H[Call specific post method with relative indices]
        H --> I{Communication Type}
        I -->|Broadcast| J[postBroadcast: use root_index directly]
        I -->|Gather| K[postGather: use root_index directly]
        I -->|Scatter| L[postScatter: use root_index directly]
        I -->|SendRecv| M[postSendRecv: use root_index directly]
        I -->|Reduce| N[postReduce: use root_index directly]
        J --> O[backend->broadcast with rootRank: root_index]
        K --> P[backend->gather with rootRank: root_index]
        L --> Q[backend->scatter with rootRank: root_index]
        M --> R[backend->send/recv with relative receiver_index]
        N --> S[backend->reduce with rootRank: root_index]
    

    Last reviewed commit: 6574bc8

    @greptile-apps greptile-apps bot left a comment

    5 files reviewed, 4 comments



    greptile-apps bot commented Feb 18, 2026

    Additional Comments (2)

    csrc/multidevice/post_communication.cpp
    Calls removed member function

    postScatter still calls communication->getRelativeIndex(root_index), but the getRelativeIndex member function was removed from the Communication class in this PR. Since root_index is now already a relative index (converted in postSingleCommunication at line 465), this should just pass root_index directly, consistent with the pattern applied to postBroadcast, postGather, and postReduce.

          {.rootRank = root_index});
    

    csrc/multidevice/post_communication.cpp
    Uses relative index as device ID

    After the changes in postSingleCommunication (line 465), root_index is now a relative team index (0-based), not a global device ID. However, output_device_mesh.has(root_index) checks whether a device ID exists in the mesh, not a relative index. This should be output_device_mesh.has(team.at(root_index)) to first resolve the relative index back to a device ID, consistent with the pattern used in postBroadcast (line 116), postGather (line 144), and postReduce (line 267).

    You'll need to add const Team& team = communication->team(); at the beginning of this function as well.

    @greptile-apps greptile-apps bot left a comment

    5 files reviewed, 1 comment

    @greptile-apps greptile-apps bot left a comment

    7 files reviewed, 2 comments


    greptile-apps bot commented Feb 18, 2026

    Additional Comments (1)

    tests/cpp/test_multidevice_communications.cpp
    Test passes absolute index as root_index

    After this PR, postSingleCommunication expects root_index to be a relative index (consistent with communication->root() now storing relative indices). However, this test passes the absolute sender (value 1) as root_index.

    This works only by coincidence because team = {0, 1} happens to have device 1 at relative position 1. If the test used different device IDs (e.g., team = {3, 5}), the absolute sender ID would not equal its relative index.

    To be consistent with the new convention, this should pass the relative index:

        auto work = postSingleCommunication(
            communication, rank, backend, input_tensor, output_tensor, getRelativeIndex(team, sender));
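The coincidence described above can be made concrete with a standalone stand-in for the helper (redefined here for self-containment; assert replaces the project's error macro):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

using DeviceIdxType = int64_t;
using Team = std::vector<DeviceIdxType>;

// Stand-in for the PR's utility: absolute device rank -> position in team.
DeviceIdxType getRelativeIndex(const Team& team, DeviceIdxType rank) {
  auto it = std::find(team.begin(), team.end(), rank);
  assert(it != team.end() && "rank is not in the team");
  return std::distance(team.begin(), it);
}

// With team = {0, 1}, device 1's absolute ID equals its relative index (1),
// so passing the absolute ID happens to work. With team = {3, 5}, device 5's
// relative index is 1 while its absolute ID is 5, and the two diverge.
```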
    

    @greptile-apps greptile-apps bot left a comment

    7 files reviewed, 1 comment


    @Priya2698 Priya2698 marked this pull request as draft February 18, 2026 20:08
    @greptile-apps greptile-apps bot left a comment

    7 files reviewed, 2 comments


    if (std::find(team.begin(), team.end(), my_device_index) == team.end()) {
    return nullptr;
    }
    my_device_index = getRelativeIndex(team, my_device_index);
    Collaborator Author

    I can also do this in caller methods such that postSingleCommunication receives all relative indices.

    Collaborator

    Using the same name for both absolute and relative index is a bit confusing. Maybe my_index_in_team of type int64_t will be clearer?

    Collaborator

    Also, you may want to consider using _index for and only for relative indices. E.g. my_device vs my_device_index/my_index_in_team and root vs root_index. You sort of followed that convention in other methods already, e.g., sender vs sender_index.

    @Priya2698
    Collaborator Author

    !test

    @Priya2698
    Collaborator Author

    !test

    @Priya2698 Priya2698 marked this pull request as ready for review February 19, 2026 02:09
    @Priya2698 Priya2698 requested a review from wujingyue February 19, 2026 02:10
    @greptile-apps greptile-apps bot left a comment

    7 files reviewed, no comments

    team.size(), 2, "SendRecv's team size is expected to be atmost 2.");

    if (sender == receiver) {
    if (team.size() == 1 || (team[0] == team[1])) {
    Collaborator

    Suggested change
    if (team.size() == 1 || (team[0] == team[1])) {
    if (team.size() == 1 || team[0] == team[1]) {


    c10::intrusive_ptr<c10d::Work> postGather(
    Communication* communication,
    DeviceIdxType my_device_index,
    DeviceIdxType root_index,
    Collaborator

    Suggested change
    int64_t root_index,

    For readability, I'd use DeviceIdxType for device IDs and int64_t for relative indices.

    Comment on lines 159 to +160
    for (auto i : arange(communication->team().size())) {
    if (root_relative_index == static_cast<DeviceIdxType>(i) &&
    !communication->in()->getDeviceMesh().has(root_index)) {
    if (root_index == static_cast<DeviceIdxType>(i) &&
    @wujingyue wujingyue Feb 19, 2026

    for (auto i : arange(std::ssize(communication->team()))) {
      if (root_index == i) {
    

    receiver = (team[0] == sender ? team[1] : team[0]);
    }
    NVF_ERROR_LE(
    team.size(), 2, "SendRecv's team size is expected to be atmost 2.");
    Collaborator

    Suggested change
    team.size(), 2, "SendRecv's team size is expected to be atmost 2.");
    team.size(), 2);

