[PyTorch debug] Fix issue with `tp_group=None` #2733
Conversation
/te-ci pytorch
Greptile Summary: This PR fixes a bug in `get_reduction_params()` where `reduction_group = tp_group` was set unconditionally, even when `tp_group=None` — which `torch.distributed` interprets as the default/world group. The fix threads a new `tp_size` argument through `DebugQuantizer` and the `inspect_tensor` API so that active tensor parallelism is detected via `tp_size > 1`, and reduction is skipped when TP is active but `tp_group` is unavailable. Key changes: `tp_size` is passed to `DebugQuantizer` from all module call sites, `get_reduction_params()` checks `tp_size > 1` instead of `tp_group is not None`, and `tp_size` is added to the backward-compat kwargs filtering in `call_feature`.
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[DebugQuantizer._call_inspect_tensor_api] --> B[debug_api.transformer_engine.inspect_tensor]
    B --> C[TransformerEngineAPI.call_feature]
    C --> D{call.__name__ == 'inspect_tensor'?}
    D -- Yes --> E[Strip tp_size if not in co_varnames]
    D -- No --> F{call.__name__ == 'inspect_tensor_postquantize'?}
    F -- Yes --> G[Strip tp_size if not in co_varnames]
    F -- No --> H[Pass kwargs as-is]
    E --> I[feature.inspect_tensor feat_config, layer_name, **kwargs_copy]
    G --> I
    H --> I
    I --> J[get_reduction_params tensor_name, tp_group, tp_size]
    J --> K{tensor_name == 'weight'?}
    K -- No --> L[skip_reduction=False, reduction_group=global, reduce_within_microbatch=True]
    K -- Yes --> M{weight_tensor_tp_group_reduce AND tp_size > 1?}
    M -- No --> N[skip_reduction=True]
    M -- Yes --> O{tp_group is not None?}
    O -- Yes --> P[reduction_group = tp_group]
    O -- No --> Q[skip_reduction=True NOTE: TP group missing]
```
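The reduction-parameter branch of the flowchart can be sketched in Python. The function name and the `weight_tensor_tp_group_reduce` flag follow the PR text, but this is an illustrative sketch under those assumptions, not TE's exact code (in TE the flag lives on `TEDebugState`, and the return shape may differ).

```python
from typing import Optional, Tuple


def get_reduction_params(
    tensor_name: str,
    tp_group: Optional[object],  # torch.distributed process group, or None
    tp_size: int,
    weight_tensor_tp_group_reduce: bool = True,
) -> Tuple[bool, Optional[object], bool]:
    """Return (skip_reduction, reduction_group, reduce_within_microbatch).

    reduction_group=None stands in for the global/default group here.
    """
    if tensor_name != "weight":
        # Non-weight tensors: reduce across the global group as before.
        return False, None, True
    if not (weight_tensor_tp_group_reduce and tp_size > 1):
        # TP inactive (tp_size == 1) or per-TP-group reduction disabled.
        return True, None, True
    if tp_group is not None:
        # TP active and a concrete group was provided: reduce within it.
        return False, tp_group, True
    # TP active (tp_size > 1) but tp_group missing: skipping is safer than
    # passing None, which torch.distributed treats as the world group.
    return True, None, True
```

The key change is the `tp_size > 1` guard: before this PR, the `tp_group is not None` check alone could not distinguish "no TP" from "TP active but group not handed down", so `None` leaked into collectives as the world group.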
…rams

Use `tp_size` to determine whether tensor parallelism is active instead of checking `tp_group is None` (which is ambiguous, since `None` means the world group in `torch.distributed`). Also add `tp_size` to the backward-compat kwargs filtering in `call_feature` so custom features without `tp_size` in their `inspect_tensor` signature continue to work.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
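The backward-compat filtering described in this commit can be sketched as follows. `filter_kwargs_for_feature` and `legacy_inspect_tensor` are hypothetical names for illustration; per the flowchart, TE's `call_feature` performs the equivalent check against the callee's `co_varnames` inline.

```python
def filter_kwargs_for_feature(fn, kwargs):
    # Keep only kwargs that appear among fn's declared parameters
    # (co_varnames up to co_argcount covers positional/keyword params),
    # mirroring the "strip tp_size if not in co_varnames" step.
    accepted = fn.__code__.co_varnames[: fn.__code__.co_argcount]
    return {k: v for k, v in kwargs.items() if k in accepted}


# A custom feature written before tp_size existed (hypothetical example):
def legacy_inspect_tensor(config, layer_name, tensor, tp_group=None):
    return list(locals())


kwargs = {"tensor": "t", "tp_group": None, "tp_size": 4}
filtered = filter_kwargs_for_feature(legacy_inspect_tensor, kwargs)
# tp_size is stripped, so the legacy signature still works unmodified:
result = legacy_inspect_tensor("cfg", "layer1", **filtered)
```

Filtering at the dispatch boundary means user-defined features never see arguments newer than their signature, so adding `tp_size` to the API is not a breaking change.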
for more information, see https://pre-commit.ci
Description
If no TP is used, then `tp_size=1` and `tp_group=None` inside TE modules. In the debug feature utility `get_reduction_params()`, when `TEDebugState.weight_tensor_tp_group_reduce=True`, we previously set `reduction_group = tp_group` unconditionally. With `tp_group=None`, downstream distributed collectives interpret this as the default/world group, which can cause unintended cross-rank reduction for weight stats.

This PR fixes that behavior by making TP-group reduction explicit and safe:

- `tp_size` is now passed to `DebugQuantizer` at creation time (from all module call sites: `Linear`, `LayerNormLinear`, `LayerNormMLP`, `GroupedLinear`) and forwarded through the `inspect_tensor` API args;
- `get_reduction_params()` now takes `tp_size` and uses `tp_size > 1` (instead of `tp_group is not None`) to determine whether tensor parallelism is active; this is the correct check, since `tp_group=None` is ambiguous (it means the world process group in `torch.distributed`);
- if `weight_tensor_tp_group_reduce=True` and `tp_size > 1` but `tp_group is None`, skip reduction for weight stats;
- if `weight_tensor_tp_group_reduce=True` and `tp_size > 1` and `tp_group` is available, reduce in that TP group;
- `tp_size` is added to the backward-compat kwargs filtering in `call_feature` (`api.py`), so custom user features whose `inspect_tensor` / `inspect_tensor_postquantize` signatures do not include `tp_size` continue to work without modification.

Type of change
Checklist: