
Broadcast-based allgather in host for-loop #5925

Status: Open
Priya2698 wants to merge 15 commits into main from pm/stream_broadcast

Conversation

Priya2698 (Collaborator) commented Feb 6, 2026

[Screenshot: 2026-02-09 at 1:24:11 PM]

The broadcast version is very slow, so I am not comparing timings until we integrate this with multicast.

github-actions bot commented Feb 6, 2026

Review updated until commit 1c23e50

Description

  • Add StreamBroadcast communication type for broadcast-based allgather in host for-loops

  • Implement lowerToStreamBroadcast function to handle stream-based broadcast operations

  • Add loop index invalidation logic in HostIrEvaluator for proper cache management

  • Update communication lowering to pass loop index as root parameter for stream broadcasts

  • Add comprehensive tests for column parallel linear forward with stream broadcast

Changes walkthrough

Relevant files

Enhancement (8 files)
  evaluator.cpp                    Add loop index invalidation logic for cache management           +17/-0
  lower_to_communication.cpp       Implement StreamBroadcast communication type and lowering       +61/-5
  lowering.cpp                     Update convertSingleOpToCommunication to pass root parameter    +8/-3
  ops.cpp                          Enhance error message in shardByStream function                 +4/-2
  convert_op_to_communication.cpp  Update convertSingleOpToCommunication call with root parameter  +4/-1
  communication.cpp                Add StreamBroadcast support to communication type handling      +6/-0
  lower_to_communication.h         Update function signature to accept root parameter              +6/-0
  communication.h                  Add StreamBroadcast to CommunicationType enum                   +6/-1

Tests (1 file)
  test_overlap.py                  Add tests for column parallel linear with stream broadcast      +114/-0

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Loop Index Invalidation Logic

The new loop index invalidation logic (lines 535-551) adds complex tracking of allocations and consumer values. This needs careful validation to ensure it correctly handles all edge cases in loop-dependent expressions and doesn't introduce memory safety issues or incorrect evaluation order.

// Expressions dependent on loop index and all allocations
// inside the loop body should be invalidated. We cannot
// simply use allConsumerValsOf because the loop index can be an input to
// fusion outputs or buffers allocated outside the loop.
std::unordered_set<Val*> allocations;
for (Expr* e : for_loop->body().exprs()) {
  if (auto* alloc = dynamic_cast<kir::Allocate*>(e)) {
    allocations.insert(alloc->buffer());
  }
}
expr_evaluator_.invalidate(for_loop->index());
for (auto consumer : allConsumerValsOf(for_loop->index())) {
  if (consumer->isA<TensorView>() && !allocations.contains(consumer)) {
    continue;
  }
  expr_evaluator_.invalidate(consumer);
}
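
For intuition, here is a minimal Python analogy of the rule above (not nvFuser code; run_host_loop, the "loop_index" key, and the tensor/buffer sets are purely illustrative): values derived from the loop index are recomputed each iteration, while tensors allocated outside the loop keep their cached entries.

def run_host_loop(extent, cache, index_consumers, tensors, loop_local_buffers):
    """cache: name -> cached value; index_consumers: names derived from the loop index."""
    for i in range(extent):
        cache["loop_index"] = i  # rebind the loop index for this iteration
        for name in index_consumers:
            if name in tensors and name not in loop_local_buffers:
                continue  # keep tensors allocated outside the loop cached
            cache.pop(name, None)  # force re-evaluation of loop-dependent values
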
StreamBroadcast Root Validation

The StreamBroadcast implementation requires careful validation of the root parameter. The error checking ensures root is not null, but the semantic correctness of using loop indices as roots in different parallel contexts needs thorough testing.

case CommunicationType::StreamBroadcast:
  NVF_ERROR(
      root != nullptr,
      "StreamBroadcast requires a root value passed in through lowering");
  lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root);
  break;
Performance vs Correctness Trade-off

The PR description mentions the broadcast version is very slow and timing comparisons are deferred until multicast integration. While correctness tests are added, the performance implications and when this optimization becomes beneficial need clear documentation and validation criteria.

def column_parallel_linear_forward(h: int, d: int):
    with FusionDefinition() as fd:
        inp_tv = fd.define_tensor((-1, h), contiguity=True, dtype=DataType.BFloat16)
        weight_tv = fd.define_tensor(
            (4 * h, h), contiguity=True, dtype=DataType.BFloat16
        )
        ag_out = fd.ops.set(inp_tv)
        out_tv = fd.ops.linear(ag_out, weight_tv)
        fd.add_output(out_tv)

        mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d))

        for tv in [inp_tv, weight_tv]:
            tv.set_device_mesh(mesh)
            tv.outer_split(0, d)
            tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x)

        ag_out.set_device_mesh(mesh)
        ag_out.outer_split(0, d)
        ag_out.axis(0).parallelize(nvfuser.ParallelType.stream)

        # Fusion IR before segmentation will look like this:
        #   [t, h]
        #   /\.
        #  d
        # (deviceIdx.x)
        #    |
        #    | set (lowered to StreamBroadcast. This decomposition is done manually in the definition above. It will later be done by preseg)
        #    |
        #   [t, h]                                  [4h,  h]
        #   /\                                      /\.
        #  s                                       d
        # (streamIdx)
        #                      |
        #                      | linear
        #                      |
        #                   [t, 4h, r{h}]
        #                   /\  /\.
        #                  s*   d

    return fd


@pytest.mark.mpi
def test_column_parallel_linear_forward(multidevice_test):
    # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward.
    # The difference is we are using broadcast based overlapping instead of send/recv.
    h, t = 2, 24
    d = multidevice_test.size
    if (h * 4) % d != 0:
        pytest.skip(
            f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
        )
    if t % d != 0:
        pytest.skip(
            f"Column-parallel linear requires {t} to be divisible by world size {d}."
        )

    fd = column_parallel_linear_forward(h, d)

    inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to(
        torch.bfloat16
    )
    weight_ref = torch.testing.make_tensor(
        4 * h, h, dtype=torch.int32, device="cpu"
    ).to(torch.bfloat16)

    inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0])
    weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1])

    out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight)

    with torch.profiler.profile(record_shapes=True) as prof:
        (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
    torch.testing.assert_close(out, out_ref)
    broadcast_events = [
        event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name
    ]
    assert len(broadcast_events) == (d if d > 1 else 0)


@pytest.mark.mpi
@pytest.mark.benchmark
def test_column_parallel_linear_forward_benchmark(multidevice_test, benchmark):
    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
    h, t = 8192, 8192
    d = multidevice_test.size
    if (4 * h) % d != 0:
        pytest.skip(
            f"Column-parallel linear requires {4 * h} to be divisible by world size {d}."
        )
    if t % d != 0:
        pytest.skip(
            f"Column-parallel linear requires {t} to be divisible by world size {d}."
        )

    fd = column_parallel_linear_forward(h, d)

    inp_ref = torch.randn(t, h, dtype=torch.bfloat16, device="cpu")
    weight_ref = torch.randn(4 * h, h, dtype=torch.bfloat16, device="cpu")

    inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0])
    weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1])

    warmup_fn, benchmark_fn = get_benchmark_fns(
        lambda: fd.execute(
            [inp, weight],
            _enable_options=["host_ir_lowering"],
        )
    )
    warmup_fn()
    benchmark.pedantic(benchmark_fn, rounds=5)

Test failures

  • (Medium, 6) NVFuser HostIrEvaluatorTest internal asserts (AddInLoop / InplaceUpdateInLoop) across GPU runners

    Test Name (failing across A100 / GB200 / H100 runners)   Source
    HostIrEvaluatorTest.AddInLoop                             Link
    HostIrEvaluatorTest.InplaceUpdateInLoop                   Link

Priya2698 marked this pull request as ready for review on February 9, 2026, 21:10
Priya2698 requested a review from wujingyue on February 9, 2026, 21:11

Priya2698 (Collaborator, Author) commented:

!test

greptile-apps bot commented Feb 9, 2026

Greptile Summary

Implements broadcast-based allgather in host for-loop by introducing a new StreamBroadcast communication type. When lowering detects a DIDx→Stream resharding on the same mesh, it emits a StreamBroadcast communication inside a host loop, passing the loop index as the broadcast root. On each iteration, the evaluator invalidates loop-dependent allocations and evaluates the root to determine which device broadcasts. This decomposition enables overlapping computation with communication for stream-parallel workloads.
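
For intuition, the decomposition can be sketched with plain PyTorch collectives. This is a conceptual sketch only, not the generated host IR: allgather_as_broadcasts and its arguments are made up for illustration, and the per-iteration stream handling that produces the actual overlap is omitted.

import torch
import torch.distributed as dist

def allgather_as_broadcasts(inp_shard, weight_shard, d, rank):
    # Allgather of the row-sharded input expressed as d broadcasts, one per
    # host-loop iteration, rooted at device i. The linear on shard i can run
    # while the broadcast of shard i+1 is in flight once issued on separate streams.
    out_rows = []
    for i in range(d):
        buf = inp_shard.clone() if rank == i else torch.empty_like(inp_shard)
        dist.broadcast(buf, src=i)               # device i is the root this iteration
        out_rows.append(buf @ weight_shard.t())  # [t/d, h] x [h, 4h/d] -> [t/d, 4h/d]
    return torch.cat(out_rows, dim=0)            # [t, 4h/d]: column-parallel output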

Key Changes

  • Added StreamBroadcast communication type that reuses postBroadcast implementation
  • Modified getCommunicationInfo to detect DIDx→Stream pattern and emit StreamBroadcast
  • Updated lowerSegment to pass loop index as root parameter for StreamBroadcast
  • Enhanced HostIrEvaluator::handle(ForLoop*) to invalidate allocations inside loop body on each iteration
  • Added column-parallel linear test exercising broadcast-based allgather with profiling validation
  • Improved error messages and added explicit nullptr handling for root parameter

Notes

The PR description indicates broadcast performance is currently slow and will be improved with multicast integration. Previous review threads have covered potential issues around null pointer handling, root evaluation, and test correctness, with developer responses clarifying intended behavior.

Confidence Score: 4/5

  • Safe to merge with minor caveats around performance and potential runtime edge cases
  • The implementation is architecturally sound and follows existing patterns. All major concerns from previous reviews have been addressed or explained by the development team. The core logic correctly handles the loop-based broadcast decomposition, and tests validate the functionality. The score is 4 (not 5) because: (1) performance is acknowledged as slow pending multicast integration, (2) the invalidation logic in the evaluator is complex and may have edge cases, and (3) the CUDA backend compatibility mentioned in previous threads may need verification.
  • csrc/host_ir/evaluator.cpp and csrc/host_ir/lowering.cpp warrant careful testing to ensure the invalidation logic handles all loop-dependent allocation scenarios correctly.

Important Files Changed

  Filename                                  Overview
  csrc/host_ir/lower_to_communication.cpp   Adds StreamBroadcast communication type to decompose allgather as broadcast in host for-loop; all previously flagged issues appear addressed
  csrc/host_ir/lowering.cpp                 Passes loop index as root to convertSingleOpToCommunication and adds StreamBroadcast check; removes old expr from the container in cloneWithNewOperands
  csrc/host_ir/evaluator.cpp                Adds invalidation logic for loop-dependent expressions and allocations to handle stream-parallel broadcast correctly
  tests/python/multidevice/test_overlap.py  Adds column-parallel linear tests that exercise broadcast-based allgather in stream parallel; includes profiling and benchmarking

Flowchart

flowchart TD
    A[Input Tensor<br/>Sharded on DIDx] --> B{getCommunicationInfo}
    B -->|DIDx → Stream<br/>Same Mesh| C[StreamBroadcast]
    B -->|DIDx → Stream<br/>Different Mesh| D[Gather]
    C --> E[lowerSegment<br/>in host for-loop]
    E --> F[Pass loop index as root]
    F --> G[lowerToStreamBroadcast]
    G --> H[Communication expr<br/>with Val* root]
    H --> I[HostIrEvaluator]
    I --> J{For each loop iteration}
    J --> K[Invalidate allocations]
    K --> L[Bind loop index]
    L --> M[Evaluate root Val*<br/>to int64_t]
    M --> N[postBroadcast<br/>with device at root index]
    N --> O[NCCL broadcast]
    J -->|next iteration| J
    O --> P[Output Tensor<br/>Sharded on Stream]

Last reviewed commit: 1c23e50

greptile-apps bot left a comment

8 files reviewed, 2 comments


Priya2698 (Collaborator, Author) commented:

!test

greptile-apps bot left a comment

8 files reviewed, 3 comments


greptile-apps bot commented Feb 9, 2026

Additional Comments (1)

csrc/multidevice/communication.cpp
Root validation rejects non-const

Communication::validate only enforces the root/type contract when root() is a const integral scalar. For StreamBroadcast, root is the host loop index (non-const), so hasRoot(type()) is never validated and invalid roots (e.g., non-integral or negative-at-runtime) can slip through. This can lead to runtime failures when postBroadcast interprets the root.

Consider extending validation to require root() be Index dtype for StreamBroadcast/rooted collectives even when not constant, and/or add runtime checks where the root is consumed.
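
A hypothetical runtime check along those lines, sketched in Python rather than in Communication::validate (the helper name and arguments are not from the PR):

def check_runtime_root(root_value, team):
    # Validate a root that is only known at evaluation time, e.g. a host-loop index.
    if not isinstance(root_value, int) or root_value < 0:
        raise TypeError(f"broadcast root must be a non-negative integer, got {root_value!r}")
    if root_value not in team:
        raise ValueError(f"broadcast root {root_value} is not in team {team}")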

Priya2698 (Collaborator, Author) commented:

!test

greptile-apps bot left a comment

8 files reviewed, 2 comments


Comment on lines 54 to 58
std::vector<Expr*> convertSingleOpToCommunication(
Expr* c,
DeviceIdxType my_device_idx,
Val* host_loop_index = nullptr,
const CommunicatorBackend backend = CommunicatorBackend::kNccl);

StreamBroadcast not handled

convertSingleOpToCommunication now defaults host_loop_index to nullptr, but it can return CommunicationType::StreamBroadcast from getCommunicationInfo and then hard-errors if host_loop_index == nullptr (lower_to_communication.cpp:657-663). This makes the existing ConvertOpToCommunication pass (which calls it at top-level with nullptr) crash for any fusion that produces this DIDx→Stream resharding outside the host-loop lowering path.

greptile-apps bot commented Feb 9, 2026

Additional Comments (1)

csrc/multidevice/communication.cpp
Non-constant root accepted

Communication::validate only checks root/type consistency when root() is a const integral scalar (communication.cpp:238-246). For StreamBroadcast, the root is intentionally a non-const Val* (host loop index), so this validation becomes a no-op: invalid roots (e.g., negative at runtime, wrong dtype) won’t be rejected here but later code assumes a valid rank/root. If StreamBroadcast relies on runtime root, it still needs a type/dtype/range validation path for non-const roots (at least DataType::Index and non-negative).

wujingyue (Collaborator) left a comment

It's great to see this work functionally!

  type == CommunicationType::SendRecv ||
- type == CommunicationType::AllToAll) {
+ type == CommunicationType::AllToAll ||
+ type == CommunicationType::StreamBroadcast) {
wujingyue (Collaborator):

I understood the motivation but can this be consolidated into the same Broadcast?

Priya2698 (Collaborator, Author):

I kept it separate so I don't need to check for the Stream parallel type in lowerToBroadcast when deciding the root. Posting the communication uses a common function.
I also wanted to first integrate the SendRecv-based decomposition and then reconsider the design based on what is needed for both of these comms.

"Destination allocation should be sharded on stream after "
"shardAllocationAsLoop: ",
destination);
destination->domain()->toString(0, /*loop_only=*/false));
wujingyue (Collaborator):

I guess destination is still worth printing in addition to the domain?

Priya2698 (Collaborator, Author):

destination prints only the loop domain. I added the name above so it is printed in addition to the complete domain.

  TensorView* in = communication->in();
  TensorView* out = communication->out();
- if (haveDifferentShardings(
+ if (communication->type() != CommunicationType::StreamBroadcast &&
wujingyue (Collaborator):

While I understood the motivation and that the tests pass, I'm thinking how to make this cleaner.

Is it possible to frame this as an optimization? For example, if in can be sharded on Stream in the same way as communication, insert a shardByStream.

Priya2698 (Collaborator, Author):

Yeah, I do think this should be merged into shardByStream or some other logic.
For now, I kept it simple since I am not sure how it will look with the Collective Permute representation (a composite Communication, P2P comms corresponding to SendRecv, etc.), so I took the verbose approach as an interim step.

Let me see what I can do in this PR itself.


  // This ignores device dimensions on reduction axis.
- auto producer_pt_to_did =
+ auto producer_pt_to_id =
wujingyue (Collaborator):

Suggested change
- auto producer_pt_to_id =
+ const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =

  auto producer_pt_to_id =
      mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
- auto consumer_pt_to_did =
+ auto consumer_pt_to_id =
wujingyue (Collaborator):

ditto

std::vector<Expr*> convertSingleOpToCommunication(
Expr* e,
DeviceIdxType my_device_idx,
Val* host_loop_index,
wujingyue (Collaborator):

Suggested change
- Val* host_loop_index,
+ Val* root,

Some communications (e.g. broadcast, reduce, gather, and scatter) are rooted. So far, we've been deciding the root according to device meshes. However, this use makes a case for passing in the root from the lowering process.
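
As a rough illustration of the two conventions (Python sketch; pick_root and its arguments are hypothetical, not nvFuser API): the root is either derived from the device mesh, as done so far, or supplied by the caller, as this PR does with the host-loop index.

def pick_root(sender_devices, supplied_root=None):
    if supplied_root is not None:
        return supplied_root      # e.g. the host for-loop index passed in by lowering
    assert len(sender_devices) == 1, "mesh-derived root needs a unique sending device"
    return sender_devices[0]      # classic rooted collective: the lone sender is the root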

greptile-apps bot left a comment

8 files reviewed, no comments


Priya2698 (Collaborator, Author) commented:

!test

greptile-apps bot left a comment

8 files reviewed, 2 comments


      input_tensor,
      output_tensor);
  case CommunicationType::Broadcast:
+ case CommunicationType::StreamBroadcast:

StreamBroadcast reuses broadcast logic but csrc/host_ir/evaluator.cpp:334-336 and :405-407 still reject it for the CUDA backend multicast path. If CUDA backend is intended to work with StreamBroadcast, add it to those checks; otherwise this will fail at runtime when using CommunicatorBackend::kCuda.

      input_tensor,
      output_tensor);
  case CommunicationType::Broadcast:
+ case CommunicationType::StreamBroadcast:

StreamBroadcast shares broadcast logic but csrc/multidevice/cuda_p2p.cpp:645-667 and :689-704 don't handle it in their switch statements. This will cause runtime errors with CUDA backend. Either add StreamBroadcast cases (treating them like Broadcast) or ensure CUDA backend is never used with this communication type.

greptile-apps bot left a comment

9 files reviewed, 2 comments


Comment on lines 407 to 408
IterDomain* c_stream_id =
getOrDefault(consumer_pt_to_id, ParallelType::Stream);

c_stream_id recomputed redundantly on every DID iteration

c_stream_id is fetched from consumer_pt_to_id on every iteration of the kParallelTypeDIDs loop, but its value never changes — it doesn't depend on the loop variable pt. Move it above the loop to avoid repeating the same map lookup up to three times:

Suggested change
-   IterDomain* c_stream_id =
-       getOrDefault(consumer_pt_to_id, ParallelType::Stream);
+ IterDomain* c_stream_id =
+     getOrDefault(consumer_pt_to_id, ParallelType::Stream);
+ for (ParallelType pt : kParallelTypeDIDs) {
+   IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
+   IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);

greptile-apps bot left a comment

9 files reviewed, 3 comments


Comment on lines 184 to 185
for (Expr* c : convertSingleOpToCommunication(
e, device_id, innermost.loop->index())) {

Unconditional null dereference on innermost.loop for all non-loop communications

innermost is a default-constructed LoopInfo (i.e. innermost.loop == nullptr) whenever loop_nest.empty(). This PR changed the call to convertSingleOpToCommunication to unconditionally pass innermost.loop->index(), which dereferences a null pointer for every Communication segment that sits outside a stream for-loop—not just StreamBroadcast.

Existing tests such as LowerGatherTest and LowerScatterTest (with enable_host_ir_lowering = true) exercise plain Allgather/Gather outside any loop and will crash on this path.

The fix is to pass nullptr when there is no enclosing loop:

Suggested change
- for (Expr* c : convertSingleOpToCommunication(
-     e, device_id, innermost.loop->index())) {
+ for (Expr* c : convertSingleOpToCommunication(
+     e, device_id, loop_nest.empty() ? nullptr : innermost.loop->index())) {

Comment on lines +184 to +192
Team team = receiver_mesh.vector();
comms.push_back(IrBuilder::create<Communication>(
CommunicationType::StreamBroadcast,
output_tv,
input_tv,
team,
root,
c10d::ReduceOp::RedOpType::UNUSED,
backend));

StreamBroadcast root is a raw loop index, not a device ID

root is the for-loop index—an integer in [0, d). However team is receiver_mesh.vector(), which contains absolute device IDs. Inside postBroadcast the root is used in two ways that both assume it equals an absolute device ID:

  1. if (my_device_index == root_index) — decides which device does the local copy.
  2. getRootRelativeIndex(root_index) — calls std::find(team.begin(), team.end(), root_index) and asserts the value is present in the team.

For a mesh such as DeviceMesh({4, 5, 6, 7}), loop index 0 is neither equal to any my_device_index in {4,5,6,7} nor present in the team vector, so the assert fires at runtime.

The root should be the mesh device at position loop_index, i.e. receiver_mesh.at(loop_index). Since loop_index is a Val* evaluated at runtime, one approach is to look up the mesh device ID at evaluation time (e.g. via a GetItem/helper expression on the mesh tensor), or document this as a hard requirement (mesh must be arange(d)) and add a validation check at communication-creation time.
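
To make the distinction concrete, a small Python illustration of the mesh example above (illustrative only, not PR code):

mesh = [4, 5, 6, 7]                  # DeviceMesh({4, 5, 6, 7}): absolute device IDs
for loop_index in range(len(mesh)):
    root_device = mesh[loop_index]   # 4, 5, 6, 7 are valid roots for this team
    # Using loop_index itself (0, 1, 2, 3) would not be found in the team,
    # so the lookup would only work for meshes equal to arange(d).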

Priya2698 (Collaborator, Author):

Communication can accept a Val as root or a DeviceIdx

Comment on lines 404 to 408
  for (ParallelType pt : kParallelTypeDIDs) {
-   IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
-   IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
+   IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
+   IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);
+   IterDomain* c_stream_id =
+       getOrDefault(consumer_pt_to_id, ParallelType::Stream);

c_stream_id re-fetched on every DID iteration

c_stream_id is looked up from consumer_pt_to_id on every iteration of the kParallelTypeDIDs loop (DIDx, DIDy, DIDz), but its value never changes since it doesn't depend on pt. Moving it above the loop avoids up to two extra map lookups:

Suggested change
-   for (ParallelType pt : kParallelTypeDIDs) {
-     IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
-     IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);
-     IterDomain* c_stream_id =
-         getOrDefault(consumer_pt_to_id, ParallelType::Stream);
+   IterDomain* c_stream_id =
+       getOrDefault(consumer_pt_to_id, ParallelType::Stream);
+   for (ParallelType pt : kParallelTypeDIDs) {
+     IterDomain* p_loop_did = getOrDefault(producer_pt_to_id, pt);
+     IterDomain* c_loop_did = getOrDefault(consumer_pt_to_id, pt);

Priya2698 (Collaborator, Author) commented:

!test

greptile-apps bot left a comment

9 files reviewed, 2 comments


sender_mesh,
" and ",
receiver_mesh);
Team team = receiver_mesh.vector();

team contains absolute device IDs (e.g. {4, 5, 6, 7}), but root is a loop index Val* that evaluates to 0, 1, 2, ... at runtime. In postBroadcast (communication.cpp:474), the code checks if (my_device_index == root_index) and calls getRootRelativeIndex(root_index) which does std::find(team.begin(), team.end(), root_index). Both expect root_index to be an actual device ID from the team, not a loop iteration index. For a mesh like {4,5,6,7}, loop index 0 won't match any device and the assertion at communication.cpp:255 will fire.

Root should be the device ID at the loop index position: receiver_mesh.at(loop_index). Since loop_index is evaluated at runtime, you'll need to compute this device ID lookup at evaluation time (e.g., via an IR expression that indexes into the mesh).

greptile-apps bot left a comment

9 files reviewed, no comments


