
[PyTorch debug] Fix issue with tp_group=None #2733

Open
pggPL wants to merge 3 commits into NVIDIA:main from pggPL:fix_debug_hang

Conversation

pggPL (Collaborator) commented Mar 4, 2026

Description

If no TP is used, then tp_size=1 and tp_group=None inside TE modules.
In the debug feature utility get_reduction_params(), when TEDebugState.weight_tensor_tp_group_reduce=True, we previously set reduction_group = tp_group unconditionally. With tp_group=None, downstream distributed collectives interpret this as the default/world group, which can cause unintended cross-rank reduction for weight stats.

This PR fixes that behavior by making TP-group reduction explicit and safe:

  • tp_size is now passed to DebugQuantizer at creation time (from all module call sites: Linear, LayerNormLinear, LayerNormMLP, GroupedLinear) and forwarded through the inspect_tensor API args;
  • get_reduction_params() now takes tp_size and uses tp_size > 1 (instead of tp_group is not None) to determine whether tensor parallelism is active — this is the correct check since tp_group=None is ambiguous (it means the world process group in torch.distributed);
  • if weight_tensor_tp_group_reduce=True and tp_size > 1 but tp_group is None, skip reduction for weight stats;
  • if weight_tensor_tp_group_reduce=True and tp_size > 1 and tp_group is available, reduce in that TP group;
  • tp_size is added to the backward-compat kwargs filtering in call_feature (api.py), so custom user features whose inspect_tensor / inspect_tensor_postquantize signatures do not include tp_size continue to work without modification.
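The resulting selection logic can be sketched in plain Python (a minimal approximation: the real get_reduction_params() lives in transformer_engine/debug/features/utils and reads weight_tensor_tp_group_reduce from TEDebugState, which is modeled here as a keyword argument):

```python
# Hypothetical sketch of the fixed reduction-group selection. Names follow
# the PR description; the real Transformer Engine code differs in detail.

def get_reduction_params(tensor_name, tp_group, tp_size,
                         weight_tensor_tp_group_reduce=True):
    """Decide whether, and in which group, to reduce stats for a tensor."""
    if tensor_name != "weight":
        # Non-weight stats keep the default behavior: reduce in the
        # world group (group=None means "world" to torch.distributed).
        return {"skip_reduction": False, "reduction_group": None}
    if weight_tensor_tp_group_reduce and tp_size > 1:
        if tp_group is not None:
            # TP is active and a group handle exists: reduce within it.
            return {"skip_reduction": False, "reduction_group": tp_group}
        # TP logically active but no group handle: skip rather than let
        # tp_group=None fall through as an unintended world-group reduce.
        return {"skip_reduction": True, "reduction_group": None}
    # tp_size == 1 (or TP-group reduction disabled): nothing to reduce
    # across TP ranks for weight stats.
    return {"skip_reduction": True, "reduction_group": None}
```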

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL (Collaborator, Author) commented Mar 4, 2026

/te-ci pytorch L1

greptile-apps bot (Contributor) commented Mar 4, 2026

Greptile Summary

This PR fixes a bug in the get_reduction_params() debug utility where tp_group=None (the state when no tensor parallelism is used, i.e. tp_size=1) was being passed directly as reduction_group to downstream distributed collectives. Because group=None in torch.distributed means the default world process group, this caused unintended cross-rank weight-stats reductions when the user never requested them.

The fix threads a new tp_size: int argument from each TE module through DebugQuantizer and into the three inspect_tensor feature implementations (LogTensorStats, LogFp8TensorStats, LogNvfp4TensorStats), which all forward it to get_reduction_params. The guarding condition TEDebugState.weight_tensor_tp_group_reduce and tp_size > 1 prevents the erroneous reduction_group = tp_group assignment when TP is effectively disabled. A secondary guard (if tp_group is not None) additionally protects against the unlikely case where tp_size > 1 yet tp_group is somehow None. The backward-compatibility shim in TransformerEngineAPI.call_feature strips tp_size from kwargs for third-party inspect_tensor / inspect_tensor_postquantize implementations that don't yet accept the new parameter.

Key changes:

  • get_reduction_params() now accepts tp_size and only sets reduction_group = tp_group when both tp_size > 1 and tp_group is not None; otherwise skip_reduction = True for weight tensors
  • DebugQuantizer.__init__ stores tp_size and propagates it through inspect_tensor kwargs
  • All four TE module classes (Linear, LayerNormLinear, LayerNormMLP, GroupedLinear) pass self.tp_size to DebugQuantizer
  • TransformerEngineAPI.call_feature now also filters tp_size for inspect_tensor_postquantize calls for backward compatibility
  • When tp_size > 1 but tp_group is None, reduction is silently skipped — a warning would help users identify this misconfiguration during debugging (see inline comment)
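The backward-compat filtering in call_feature can be illustrated with a small standalone sketch (assumed shape; the actual api.py implementation differs in structure and handles more than this one argument):

```python
# Hypothetical sketch of the kwargs-stripping shim: if a feature's
# inspect_tensor(_postquantize) signature does not name tp_size, drop it
# so older third-party features keep working unchanged.

def call_feature_compat(call, feat_config, layer_name, **kwargs):
    kwargs_copy = kwargs
    if call.__name__ in ("inspect_tensor", "inspect_tensor_postquantize"):
        if "tp_size" not in call.__code__.co_varnames:
            kwargs_copy = kwargs.copy()
            kwargs_copy.pop("tp_size", None)
    return call(feat_config, layer_name, **kwargs_copy)
```

Note that co_varnames only lists named parameters, so a legacy feature accepting a catch-all **kwargs would also have tp_size stripped; for a compatibility shim that is the safe direction.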

Confidence Score: 4/5

  • This PR is safe to merge; it is a well-scoped bug fix with no breaking interface changes and correct logic throughout.
  • The core fix in get_reduction_params is logically correct and directly addresses the described bug. tp_size is properly propagated through the entire call chain with a safe default of 1. Backward compatibility is maintained via the call_feature kwarg-stripping mechanism. The only notable gap is the silent skip when tp_size > 1 and tp_group is None, which could obscure misconfigurations during debugging. There is also a minor dead-code path in api.py. Neither of these is a correctness issue.
  • transformer_engine/debug/features/utils/__init__.py — review the silent-skip branch; transformer_engine/debug/features/api.py — dead kwargs_copy assignment in the else/inspect_tensor_postquantize path

Important Files Changed

Filename Overview
transformer_engine/debug/features/utils/__init__.py Core bug fix: adds tp_size parameter and guards reduction on tp_size > 1 + tp_group is not None to prevent unintended world-group reductions when TP is disabled.
transformer_engine/debug/pytorch/debug_quantization.py Correctly adds tp_size to DebugQuantizer.__init__ and propagates it into the inspect_tensor kwargs dict; no issues.
transformer_engine/debug/features/api.py Backward-compatibility shim to strip tp_size from inspect_tensor and inspect_tensor_postquantize kwargs when the feature implementation's signature doesn't accept it; minor dead-code path for the postquantize case.
transformer_engine/debug/features/log_fp8_tensor_stats.py Adds tp_size: int = 1 to inspect_tensor and passes it to get_reduction_params; clean and consistent.
transformer_engine/debug/features/log_tensor_stats.py Adds tp_size: int = 1 to inspect_tensor and forwards it to get_reduction_params.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[DebugQuantizer._call_inspect_tensor_api] --> B[debug_api.transformer_engine.inspect_tensor]
    B --> C[TransformerEngineAPI.call_feature]
    C --> D{call.__name__ == 'inspect_tensor'?}
    D -- Yes --> E[Strip tp_size if not in co_varnames]
    D -- No --> F{call.__name__ == 'inspect_tensor_postquantize'?}
    F -- Yes --> G[Strip tp_size if not in co_varnames]
    F -- No --> H[Pass kwargs as-is]
    E --> I[feature.inspect_tensor feat_config, layer_name, **kwargs_copy]
    G --> I
    H --> I
    I --> J[get_reduction_params tensor_name, tp_group, tp_size]
    J --> K{tensor_name == 'weight'?}
    K -- No --> L[skip_reduction=False, reduction_group=global, reduce_within_microbatch=True]
    K -- Yes --> M{weight_tensor_tp_group_reduce AND tp_size > 1?}
    M -- No --> N[skip_reduction=True]
    M -- Yes --> O{tp_group is not None?}
    O -- Yes --> P[reduction_group = tp_group]
    O -- No --> Q[skip_reduction=True NOTE: TP group missing]
```

Comments Outside Diff (2)

  1. transformer_engine/debug/features/utils/__init__.py, line 28-30 (link)

    Silent skip may obscure misconfigured TP state

    When tp_size > 1 but tp_group is None, reduction is silently skipped. In practice this combination indicates a misconfiguration (TP is logically enabled, but the group handle was never supplied), so weight stats will be silently dropped rather than reduced. Emitting a warning here would help users discover this edge case during debugging, rather than wondering why weight statistics look unexpected.

    Note: warnings would need to be imported at the top of the file.
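    For illustration, a hedged sketch of such a warning (warn_missing_tp_group is a hypothetical helper; the actual fix would emit the warning inline in the skip branch of get_reduction_params):

```python
import warnings

def warn_missing_tp_group(layer_name, tp_size):
    # Hypothetical helper: surface the tp_size > 1 / tp_group is None
    # case instead of skipping the reduction silently.
    warnings.warn(
        f"[TE debug] {layer_name}: tp_size={tp_size} but tp_group is None; "
        "skipping TP-group reduction of weight stats. Pass the TP process "
        "group to the module if cross-rank reduction is expected.",
        stacklevel=2,
    )
```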

  2. transformer_engine/debug/features/api.py, line 490-501 (link)

    Dead kwargs_copy assignment in else branch for inspect_tensor_postquantize

    When call.__name__ == "inspect_tensor_postquantize", execution first hits the else branch at line 491 which sets kwargs_copy = kwargs (a bare reference), but this value is immediately overwritten by kwargs_copy = kwargs.copy() at line 498. The first assignment is unreachable/dead for this path. Consider restructuring with elif to make the two branches mutually exclusive and avoid the unused assignment:

    (Remove the preceding else: kwargs_copy = kwargs block and change the second if to elif.)
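    A sketch of the suggested elif restructuring (function name and surrounding structure are assumed, not copied from api.py):

```python
# Hypothetical restructuring: mutually exclusive branches mean kwargs_copy
# is assigned exactly once per path, with no dead first assignment.

def select_kwargs(call, kwargs):
    if (call.__name__ == "inspect_tensor"
            and "tp_size" not in call.__code__.co_varnames):
        kwargs_copy = kwargs.copy()
        kwargs_copy.pop("tp_size", None)
    elif (call.__name__ == "inspect_tensor_postquantize"
            and "tp_size" not in call.__code__.co_varnames):
        kwargs_copy = kwargs.copy()
        kwargs_copy.pop("tp_size", None)
    else:
        kwargs_copy = kwargs
    return kwargs_copy
```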

Last reviewed commit: add5be3

pggPL (Collaborator, Author) commented Mar 5, 2026

/te-ci pytorch L1

pggPL and others added 2 commits March 10, 2026 10:52
…rams

Use tp_size to determine whether tensor parallelism is active instead of
checking tp_group is None (which is ambiguous since None means world group
in torch.distributed). Also add tp_size to the backward-compat kwargs
filtering in call_feature so custom features without tp_size in their
inspect_tensor signature continue to work.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL (Collaborator, Author) commented Mar 10, 2026

/te-ci pytorch L1
