Conversation

Review updated until commit 1c23e50

Description

Relevant files:
- Enhancement: 8 files
- Tests: 1 file
PR Reviewer Guide

Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Loop Index Invalidation Logic

!test
Greptile Summary

Implements broadcast-based allgather in the host for-loop by introducing a new CommunicationType::StreamBroadcast.

Key Changes

Notes

The PR description indicates broadcast performance is currently slow and will be improved with multicast integration. Previous review threads have covered potential issues around null pointer handling, root evaluation, and test correctness, with developer responses clarifying intended behavior.

Confidence Score: 4/5
Important Files Changed
Flowchart

flowchart TD
A[Input Tensor<br/>Sharded on DIDx] --> B{getCommunicationInfo}
B -->|DIDx → Stream<br/>Same Mesh| C[StreamBroadcast]
B -->|DIDx → Stream<br/>Different Mesh| D[Gather]
C --> E[lowerSegment<br/>in host for-loop]
E --> F[Pass loop index as root]
F --> G[lowerToStreamBroadcast]
G --> H[Communication expr<br/>with Val* root]
H --> I[HostIrEvaluator]
I --> J{For each loop iteration}
J --> K[Invalidate allocations]
K --> L[Bind loop index]
L --> M[Evaluate root Val*<br/>to int64_t]
M --> N[postBroadcast<br/>with device at root index]
N --> O[NCCL broadcast]
J -->|next iteration| J
O --> P[Output Tensor<br/>Sharded on Stream]
Last reviewed commit: 1c23e50
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
!test
Additional Comments (1)
Consider extending validation to require …
!test
std::vector<Expr*> convertSingleOpToCommunication(
    Expr* c,
    DeviceIdxType my_device_idx,
    Val* host_loop_index = nullptr,
    const CommunicatorBackend backend = CommunicatorBackend::kNccl);
StreamBroadcast not handled
convertSingleOpToCommunication now defaults host_loop_index to nullptr, but it can return CommunicationType::StreamBroadcast from getCommunicationInfo and then hard-errors if host_loop_index == nullptr (lower_to_communication.cpp:657-663). This makes the existing ConvertOpToCommunication pass (which calls it at top-level with nullptr) crash for any fusion that produces this DIDx→Stream resharding outside the host-loop lowering path.
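One way to make the precondition explicit at the point where it is violated would be a guarded error; the sketch below is an assumption about placement (inside the StreamBroadcast branch of convertSingleOpToCommunication) and wording, not the PR's actual code:

```cpp
// Hedged sketch: report a clear, actionable error when a DIDx->Stream
// resharding reaches convertSingleOpToCommunication without an enclosing
// host for-loop, instead of the generic hard error. Variable names follow
// the signature quoted in this thread; the message is illustrative.
NVF_ERROR(
    host_loop_index != nullptr,
    "StreamBroadcast needs the index of the enclosing host for-loop; "
    "this DIDx->Stream resharding cannot be lowered by the top-level "
    "ConvertOpToCommunication pass: ",
    c->toString());
```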
Additional Comments (1)
wujingyue left a comment
It's great to see this work functionally!
  type == CommunicationType::SendRecv ||
- type == CommunicationType::AllToAll) {
+ type == CommunicationType::AllToAll ||
+ type == CommunicationType::StreamBroadcast) {
I understood the motivation but can this be consolidated into the same Broadcast?
I kept it separate so I don't need to check for the Stream ParallelType in lowerToBroadcast when deciding the root. Posting the communication uses a common function.
I also wanted to first integrate SendRecv-based decomposition and then reconsider the design based on what is needed for both of these comms.
| "Destination allocation should be sharded on stream after " | ||
| "shardAllocationAsLoop: ", | ||
| destination); | ||
| destination->domain()->toString(0, /*loop_only=*/false)); |
I guess destination is still worth printing in addition to the domain?
Printing destination only shows the loop domain. I added the name above so it is printed in addition to the complete domain.
  TensorView* in = communication->in();
  TensorView* out = communication->out();
- if (haveDifferentShardings(
+ if (communication->type() != CommunicationType::StreamBroadcast &&
While I understand the motivation and that the tests pass, I'm thinking about how to make this cleaner.
Is it possible to frame this as an optimization? For example, if in can be sharded on Stream in the same way as communication, insert a shardByStream.
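For illustration only, a rough sketch of how this "optimization" framing could look; canShardOnStreamLike is a hypothetical helper, and the shardByStream call below is an assumed signature rather than the existing API:

```cpp
// Hedged sketch: rather than special-casing StreamBroadcast, pre-shard the
// producer along Stream whenever its layout allows it, then let the generic
// haveDifferentShardings() path decide whether a communication is needed.
TensorView* in = communication->in();
TensorView* out = communication->out();
if (canShardOnStreamLike(in, out)) {  // hypothetical helper
  in = shardByStream(in, /*stream_index=*/loop_index);  // assumed signature
}
if (haveDifferentShardings(in, out)) {
  // ... lower to a communication as before ...
}
```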
Yeah, I do think this should be merged into shardByStream or some other logic.
For now, I kept it simple since I am not sure what it will look like with the Collective Permute representation (a composite Communication, P2P comms corresponding to SendRecv, etc.), so I took the verbose approach as an interim step.
Let me see what I can do in this PR itself.
  // This ignores device dimensions on reduction axis.
- auto producer_pt_to_did =
+ auto producer_pt_to_id =
Suggested change:
- auto producer_pt_to_id =
+ const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =
  auto producer_pt_to_id =
      mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
- auto consumer_pt_to_did =
+ auto consumer_pt_to_id =
std::vector<Expr*> convertSingleOpToCommunication(
    Expr* e,
    DeviceIdxType my_device_idx,
    Val* host_loop_index,
Suggested change:
- Val* host_loop_index,
+ Val* root,
Some communications (e.g. broadcast, reduce, gather, and scatter) are rooted. So far, we've been deciding the root according to device meshes. However, this use case makes an argument for passing the root in from the lowering process.
!test
        input_tensor,
        output_tensor);
    case CommunicationType::Broadcast:
    case CommunicationType::StreamBroadcast:
StreamBroadcast reuses the broadcast logic, but csrc/host_ir/evaluator.cpp:334-336 and :405-407 still reject it on the CUDA backend multicast path. If the CUDA backend is intended to work with StreamBroadcast, add it to those checks; otherwise this will fail at runtime when using CommunicatorBackend::kCuda.
        input_tensor,
        output_tensor);
    case CommunicationType::Broadcast:
    case CommunicationType::StreamBroadcast:
StreamBroadcast shares the broadcast logic, but csrc/multidevice/cuda_p2p.cpp:645-667 and :689-704 don't handle it in their switch statements. This will cause runtime errors with the CUDA backend. Either add StreamBroadcast cases (treating them like Broadcast) or ensure the CUDA backend is never used with this communication type.
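A sketch of the first option, assuming StreamBroadcast should simply follow the existing Broadcast handling; the switch below paraphrases the comment rather than copying csrc/multidevice/cuda_p2p.cpp:

```cpp
// Hedged sketch: let StreamBroadcast fall through to the Broadcast case on
// the CUDA backend. Other case labels are elided for brevity.
switch (communication->type()) {
  case CommunicationType::Broadcast:
  case CommunicationType::StreamBroadcast:
    // ... same multicast/broadcast posting logic as Broadcast ...
    break;
  default:
    NVF_THROW(
        "Unsupported communication type for the CUDA backend: ",
        communication->type());
}
```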
    IterDomain* c_stream_id =
        getOrDefault(consumer_pt_to_id, ParallelType::Stream);
c_stream_id recomputed redundantly on every DID iteration
c_stream_id is fetched from consumer_pt_to_id on every iteration of the kParallelTypeDIDs loop, but its value never changes — it doesn't depend on the loop variable pt. Move it above the loop to avoid repeating the same map lookup up to three times:
Suggested change (hoist the lookup above the loop):
  IterDomain* c_stream_id =
      getOrDefault(consumer_pt_to_id, ParallelType::Stream);
  for (ParallelType pt : kParallelTypeDIDs) {
    IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
    IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);
csrc/host_ir/lowering.cpp (Outdated)
    for (Expr* c : convertSingleOpToCommunication(
             e, device_id, innermost.loop->index())) {
Unconditional null dereference on innermost.loop for all non-loop communications
innermost is a default-constructed LoopInfo (i.e. innermost.loop == nullptr) whenever loop_nest.empty(). This PR changed the call to convertSingleOpToCommunication to unconditionally pass innermost.loop->index(), which dereferences a null pointer for every Communication segment that sits outside a stream for-loop—not just StreamBroadcast.
Existing tests such as LowerGatherTest and LowerScatterTest (with enable_host_ir_lowering = true) exercise plain Allgather/Gather outside any loop and will crash on this path.
The fix is to pass nullptr when there is no enclosing loop:
Suggested change:
- for (Expr* c : convertSingleOpToCommunication(
-          e, device_id, innermost.loop->index())) {
+ for (Expr* c : convertSingleOpToCommunication(
+          e, device_id, loop_nest.empty() ? nullptr : innermost.loop->index())) {
  Team team = receiver_mesh.vector();
  comms.push_back(IrBuilder::create<Communication>(
      CommunicationType::StreamBroadcast,
      output_tv,
      input_tv,
      team,
      root,
      c10d::ReduceOp::RedOpType::UNUSED,
      backend));
StreamBroadcast root is a raw loop index, not a device ID
root is the for-loop index, an integer in [0, d). However, team is receiver_mesh.vector(), which contains absolute device IDs. Inside postBroadcast the root is used in two ways that both assume it equals an absolute device ID:
- if (my_device_index == root_index) decides which device does the local copy.
- getRootRelativeIndex(root_index) calls std::find(team.begin(), team.end(), root_index) and asserts the value is present in the team.
For a mesh such as DeviceMesh({4, 5, 6, 7}), loop index 0 is neither equal to any my_device_index in {4,5,6,7} nor present in the team vector, so the assert fires at runtime.
The root should be the mesh device at position loop_index, i.e. receiver_mesh.at(loop_index). Since loop_index is a Val* evaluated at runtime, one approach is to look up the mesh device ID at evaluation time (e.g. via a GetItem/helper expression on the mesh tensor), or document this as a hard requirement (mesh must be arange(d)) and add a validation check at communication-creation time.
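One possible shape of the evaluation-time lookup, sketched under assumptions: the accessors (root(), team()) and the expression-evaluator call are guesses about this codebase's API, not verified signatures:

```cpp
// Hedged sketch: resolve the Stream loop index to an absolute device ID just
// before posting the broadcast, so postBroadcast keeps comparing against
// real device IDs.
const int64_t loop_index =
    expr_evaluator.evaluate(communication->root()).as<int64_t>();
const Team& team = communication->team();  // absolute device IDs, e.g. {4, 5, 6, 7}
NVF_ERROR(
    loop_index >= 0 && loop_index < static_cast<int64_t>(team.size()),
    "Stream loop index out of range for the receiver mesh: ",
    loop_index);
const DeviceIdxType root_device = team.at(loop_index);
// Pass root_device (not loop_index) to postBroadcast / getRootRelativeIndex.
```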
Communication can accept a Val as root or a DeviceIdx
  for (ParallelType pt : kParallelTypeDIDs) {
-   IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
-   IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
+   IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
+   IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);
+   IterDomain* c_stream_id =
+       getOrDefault(consumer_pt_to_id, ParallelType::Stream);
c_stream_id re-fetched on every DID iteration
c_stream_id is looked up from consumer_pt_to_id on every iteration of the kParallelTypeDIDs loop (DIDx, DIDy, DIDz), but its value never changes since it doesn't depend on pt. Moving it above the loop avoids up to two extra map lookups:
Suggested change (hoist the lookup above the loop):
  IterDomain* c_stream_id =
      getOrDefault(consumer_pt_to_id, ParallelType::Stream);
  for (ParallelType pt : kParallelTypeDIDs) {
    IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
    IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);
!test
      sender_mesh,
      " and ",
      receiver_mesh);
  Team team = receiver_mesh.vector();
team contains absolute device IDs (e.g. {4, 5, 6, 7}), but root is a loop index Val* that evaluates to 0, 1, 2, ... at runtime. In postBroadcast (communication.cpp:474), the code checks if (my_device_index == root_index) and calls getRootRelativeIndex(root_index) which does std::find(team.begin(), team.end(), root_index). Both expect root_index to be an actual device ID from the team, not a loop iteration index. For a mesh like {4,5,6,7}, loop index 0 won't match any device and the assertion at communication.cpp:255 will fire.
Root should be the device ID at the loop index position: receiver_mesh.at(loop_index). Since loop_index is evaluated at runtime, you'll need to compute this device ID lookup at evaluation time (e.g., via an IR expression that indexes into the mesh).
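If the hard-requirement route is taken instead, a hedged sketch of a creation-time check (placement inside the StreamBroadcast lowering is an assumption):

```cpp
// Hedged sketch: require the receiver mesh to be {0, ..., d-1} so a loop
// index can double as a device ID; fail at communication-creation time
// instead of asserting deep inside postBroadcast.
for (int64_t i = 0; i < static_cast<int64_t>(team.size()); ++i) {
  NVF_ERROR(
      team.at(i) == static_cast<DeviceIdxType>(i),
      "StreamBroadcast currently assumes the receiver mesh is an iota mesh; "
      "got device ",
      team.at(i),
      " at position ",
      i);
}
```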
The broadcast version is very slow, so I am not comparing timings until we integrate this with multicast.