Conversation
|
!test |
|
Review updated until commit 6574bc8 Description
|
| Relevant files | |||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Enhancement |
| ||||||||||||
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Logic Error in postSendRecv
|
Greptile SummaryThis PR refactors the communication lowering code to consistently use relative device indices instead of absolute device IDs. The Key changes:
Confidence Score: 5/5
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[lowerToScatter/Gather/Broadcast/SendRecv/Reduce] --> B[Convert absolute device ID to relative index]
B --> C[getRelativeIndex team, root_device_id]
C --> D[Create Communication object with relative root_index]
D --> E[postSingleCommunication called]
E --> F[Convert my_device_index to relative]
F --> G[getRelativeIndex team, my_device_index]
G --> H[Call specific post method with relative indices]
H --> I{Communication Type}
I -->|Broadcast| J[postBroadcast: use root_index directly]
I -->|Gather| K[postGather: use root_index directly]
I -->|Scatter| L[postScatter: use root_index directly]
I -->|SendRecv| M[postSendRecv: use root_index directly]
I -->|Reduce| N[postReduce: use root_index directly]
J --> O[backend->broadcast with rootRank: root_index]
K --> P[backend->gather with rootRank: root_index]
L --> Q[backend->scatter with rootRank: root_index]
M --> R[backend->send/recv with relative receiver_index]
N --> S[backend->reduce with rootRank: root_index]
Last reviewed commit: 6574bc8 |
Additional Comments (2)
After the changes in You'll need to add |
Additional Comments (1)
After this PR, This works only by coincidence because To be consistent with the new convention, this should pass the relative index: |
Additional Comments (1)
After this PR, To be consistent with the rest of the PR's intent, consider passing the relative index explicitly: |
Additional Comments (1)
For correctness and to actually test the relative-index contract, pass the relative index: |
| if (std::find(team.begin(), team.end(), my_device_index) == team.end()) { | ||
| return nullptr; | ||
| } | ||
| my_device_index = getRelativeIndex(team, my_device_index); |
There was a problem hiding this comment.
I can also do this in caller methods such that postSingleCommunication receives all relative indices.
There was a problem hiding this comment.
Using the same name for both absolute and relative index is a bit confusing. Maybe my_index_in_team of type int64_t will be clearer?
There was a problem hiding this comment.
Also, you may want to consider using _index for and only for relative indices. E.g. my_device vs my_device_index/my_index_in_team and root vs root_index. You sort of followed that convention in other methods already, e.g., sender vs sender_index.
|
!test |
|
!test |
| team.size(), 2, "SendRecv's team size is expected to be atmost 2."); | ||
|
|
||
| if (sender == receiver) { | ||
| if (team.size() == 1 || (team[0] == team[1])) { |
There was a problem hiding this comment.
| if (team.size() == 1 || (team[0] == team[1])) { | |
| if (team.size() == 1 || team[0] == team[1]) { |
| if (std::find(team.begin(), team.end(), my_device_index) == team.end()) { | ||
| return nullptr; | ||
| } | ||
| my_device_index = getRelativeIndex(team, my_device_index); |
There was a problem hiding this comment.
Using the same name for both absolute and relative index is a bit confusing. Maybe my_index_in_team of type int64_t will be clearer?
| if (std::find(team.begin(), team.end(), my_device_index) == team.end()) { | ||
| return nullptr; | ||
| } | ||
| my_device_index = getRelativeIndex(team, my_device_index); |
There was a problem hiding this comment.
Also, you may want to consider using _index for and only for relative indices. E.g. my_device vs my_device_index/my_index_in_team and root vs root_index. You sort of followed that convention in other methods already, e.g., sender vs sender_index.
| c10::intrusive_ptr<c10d::Work> postGather( | ||
| Communication* communication, | ||
| DeviceIdxType my_device_index, | ||
| DeviceIdxType root_index, |
There was a problem hiding this comment.
| int64_t root_index, |
For readability, I'd use DeviceIdxType for device IDs and int64_t for relative indices.
| for (auto i : arange(communication->team().size())) { | ||
| if (root_relative_index == static_cast<DeviceIdxType>(i) && | ||
| !communication->in()->getDeviceMesh().has(root_index)) { | ||
| if (root_index == static_cast<DeviceIdxType>(i) && |
There was a problem hiding this comment.
for (auto i : arange(std::ssize(communication->team())) {
if (root_index == i) {
| receiver = (team[0] == sender ? team[1] : team[0]); | ||
| } | ||
| NVF_ERROR_LE( | ||
| team.size(), 2, "SendRecv's team size is expected to be atmost 2."); |
There was a problem hiding this comment.
| team.size(), 2, "SendRecv's team size is expected to be atmost 2."); | |
| team.size(), 2); |
For cases such as broadcast-based allgather in a host loop, the root index is the for-loop index, which may not be the absolute device ID. I am changing all lowering methods to use relative root index which is what the backends use as well.