Reduce-based MM+RS in MultiDeviceExecutor #5923
Conversation
Greptile Overview

Greptile Summary

Implements reduce-collective-based MM+RS (matmul + reduce-scatter) lowering for the MultiDeviceExecutor, complementing the existing P2P-based approach. The key change enables stream parallelization on reduction axes, where each stream performs a matmul slice and then reduces results using NCCL collectives with the stream index as the root.

Key changes:
Critical issue:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Stream0
participant Stream1
participant StreamN
participant Device0
participant Device1
participant DeviceN
Note over Stream0,StreamN: FOR streamIdx in range(D)
Stream0->>Device0: SetCurrentStream(Stream0)
Stream0->>Device0: HirAliasSelect(A, axis=stream, idx=0)
Stream0->>Device0: HirAliasSelect(B, axis=stream, idx=0)
Stream0->>Device0: matmul(A_slice, B_slice)
Stream0->>DeviceN: Communication(Reduce, root=streamIdx=0)
Stream0->>Stream0: Wait(Reduce)
par Stream1 execution
Stream1->>Device1: SetCurrentStream(Stream1)
Stream1->>Device1: HirAliasSelect(A, axis=stream, idx=1)
Stream1->>Device1: HirAliasSelect(B, axis=stream, idx=1)
Stream1->>Device1: matmul(A_slice, B_slice)
Stream1->>DeviceN: Communication(Reduce, root=streamIdx=1)
Stream1->>Stream1: Wait(Reduce)
and StreamN execution
StreamN->>DeviceN: SetCurrentStream(StreamN)
StreamN->>DeviceN: HirAliasSelect(A, axis=stream, idx=N)
StreamN->>DeviceN: HirAliasSelect(B, axis=stream, idx=N)
StreamN->>DeviceN: matmul(A_slice, B_slice)
StreamN->>DeviceN: Communication(Reduce, root=streamIdx=N)
StreamN->>StreamN: Wait(Reduce)
end
Note over Stream0,StreamN: All streams synchronize with main stream
My plan after this PR is to implement an NVLS-based reduce for this algorithm, as well as an NVLS-based ReduceScatter for the default no-overlap algorithm.
// If no stream axis found, check the input tensor for a stream axis
// If found, use that. If not, keep the expression as is
if (stream_axis == nullptr) {
new_top_level_exprs.push_back(expr);
continue;
auto* input = expr->input(0)->as<TensorView>();
IterDomain* input_stream_axis = getStreamAxis(input->getLoopDomain());
There was a problem hiding this comment.
Unsafe cast on input(0)
When stream_axis == nullptr, the new logic unconditionally does expr->input(0)->as<TensorView>() (line ~223). This will crash if the first input is not a TensorView (e.g., scalar params/indices/predicates in some host ops). Previously this path didn’t assume anything about inputs. This needs a type check (or use ir_utils::filterByType<TensorView>(expr->inputs())) before casting.
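A minimal sketch of the kind of guard being suggested, reusing names from the diff; the exact surrounding control flow is an assumption:

```cpp
// Hedged sketch: only treat the first input as a TensorView when it is one;
// otherwise fall back to the previous behavior of keeping the expr as is.
if (stream_axis == nullptr) {
  Val* first_input = expr->inputs().empty() ? nullptr : expr->input(0);
  if (first_input == nullptr || !first_input->isA<TensorView>()) {
    new_top_level_exprs.push_back(expr);
    continue;
  }
  auto* input = first_input->as<TensorView>();
  IterDomain* input_stream_axis = getStreamAxis(input->getLoopDomain());
  // ... proceed with input_stream_axis as in the PR.
}
```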
Additional Comments (1)
The same block controls all exprs that hit the fallback path in this function.

Additional Comments (1)
This also interacts with the detection logic; the same issue appears there.
    body_expr);
NVF_ERROR(
    body_expr->as<ReductionOp>()->getReductionOpType() == BinaryOpType::Add,
    "expected a reduce operation but got ",
Change the error message to a more informative one (the check is on the reduction op type, but the message says "expected a reduce operation").
Btw, it shouldn't be a problem to support the other reduction ops.
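For example, something along these lines (a sketch only; the exact wording and what to include in the message are the author's call):

```cpp
// Sketch of a more specific check and message (wording is an assumption).
auto* reduction = body_expr->as<ReductionOp>();
NVF_ERROR(
    reduction->getReductionOpType() == BinaryOpType::Add,
    "The reduce-based MM+RS lowering currently only supports sum reductions, "
    "but got reduction op type ",
    reduction->getReductionOpType(),
    " in ",
    body_expr->toString());
```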
    stream_axis->getIterType() == IterType::Broadcast,
    "Stream axis ",
    stream_axis,
    " should be an iteration or broadcast axis.");
We stream parallelize the reduced axis in the sum op.
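A sketch of the relaxed check this implies, assuming the change simply adds IterType::Reduction to the accepted types (the actual diff may differ):

```cpp
// Assumed relaxation: a reduction axis may also be stream parallelized.
NVF_ERROR(
    stream_axis->getIterType() == IterType::Iteration ||
        stream_axis->getIterType() == IterType::Broadcast ||
        stream_axis->getIterType() == IterType::Reduction,
    "Stream axis ",
    stream_axis,
    " should be an iteration, broadcast, or reduction axis.");
```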
Force-pushed e05bbc8 to 9d84835 (Compare)
!test
… order to tell whether we're lowering for the p2p algorithm or the reduce-based algorithm
…s to be stream parallelized for MM+RS reduce-based
Force-pushed 6fd37f8 to 3ada821 (Compare)
!test
EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get()));
}
Since the reduced axis can now be stream parallelized, this test fails, so I removed it here.
if_sending_to_self->elseBody().pushBack(send);
break;
if (params.offset_stream_indexing_by_rank) {
  // Lower to MM + RS p2p based algorithm
This block is the same as before; the only difference is the indentation level.
NVF_THROW(
    "Unsupported communicator backend for lowering stream parallel "
    "type into p2p: ",
// Lower to the MM+RS reduce-collective-based algorithm
This block is the core change.
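Roughly, the reduce-based lowering described by the summary and the sequence diagram has the following shape. This is a hypothetical sketch: the Communication constructor arguments, RedOpType::SUM, and the per-stream names (matmul_slice, output_slice, team, num_streams) are assumptions, not the PR's actual code.

```cpp
// For each stream: reduce the local matmul slice to the device whose index
// equals the stream index (no swizzle, root == streamIdx), then wait on it.
for (int64_t stream_idx = 0; stream_idx < num_streams; ++stream_idx) {
  // matmul_slice / output_slice stand for the per-stream tensors produced by
  // the HirAliasSelect + matmul steps shown in the sequence diagram.
  auto* reduce = IrBuilder::create<Communication>(
      CommunicationType::Reduce,
      /*out=*/output_slice,
      /*in=*/matmul_slice,
      /*team=*/team,
      /*root=*/stream_idx,
      RedOpType::SUM);
  new_top_level_exprs.push_back(reduce);
  new_top_level_exprs.push_back(IrBuilder::create<hir::Wait>(reduce));
}
```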
auto index = (indexed_id->isBroadcast() || input.size(axis) == 1)
    ? 0
    : expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
I removed the broadcast op from the fusion. It failed here because it was trying to select on the D axis, which is 1 locally.
> I removed the broadcast op from the fusion.

So why add here the case of a broadcast axis (which, btw, looks good to me)?

> It failed here because it was trying to select on the D axis which is 1 locally

So why not detect whether the axis is sharded? I think checking that the axis is of size 1 is not correct. Firstly, if the dimension is DIDx, then the symbolic size will be D, not 1. Secondly, if the axis is neither broadcast nor sharded but just happens to be of size 1, then we want to error out.
Does that make sense?
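A sketch of the suggested alternative, assuming isDeviceDim() is the right way to test whether the selected axis is sharded (DIDx):

```cpp
// Branch on broadcast/sharded axes rather than on the local extent being 1,
// so a genuinely size-1, non-sharded axis still goes through (and errors out
// in) the normal index-evaluation path.
auto index = (indexed_id->isBroadcast() || indexed_id->isDeviceDim())
    ? 0
    : expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
```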
This PR is a follow-up to Sam's Broadcast-based pipeline PR. Instead of broadcasting for AG+MM, this PR handles the same flow but with a reduce for the following MM+RS fusion:
The fusion gets lowered into the host IR shown in the sequence diagram above. The key idea is that there is no "swizzle": the root of each reduce communication is the streamIdx.
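As a purely illustrative sketch (not the PR's actual fusion; the shapes, sharding layout, and the test utility makeContigTensor are assumptions), such an MM+RS fusion could look like this in nvFuser's C++ API, with the K dimension split across D devices and the partial products summed over that axis:

```cpp
// Hypothetical MM+RS fusion sketch; assumes nvFuser headers and the test
// utility makeContigTensor are available.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* a = makeContigTensor(3); // assumed layout [D, M, K/D]
TensorView* b = makeContigTensor(3); // assumed layout [D, K/D, N]
fusion.addInput(a);
fusion.addInput(b);

TensorView* partial = matmul(a, b);  // [D, M, N]: one partial product per slice
TensorView* out = sum(partial, {0}); // reduce over the D (sharded K) axis
fusion.addOutput(out);
```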