pggPL (Collaborator) commented Dec 30, 2025

Description

MCore's fused wgrad accumulation feature requires setting the grad_added_to_main_grad attribute on the weight's Python object. This means the original Python object must be accessible and modifiable during the backward pass.
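The contract above can be illustrated with a minimal sketch. This is a hypothetical illustration of the fused-wgrad-accumulation pattern, not the actual MCore or TE code; `main_grad` and `grad_added_to_main_grad` are the attribute names from the description, while the function name and shapes are invented for the example:

```python
import torch

# Hypothetical sketch: the backward path folds the weight gradient directly
# into the weight's fp32 main_grad buffer, then flags the *Python object* so
# downstream hooks know no separate .grad accumulation is needed.
def backward_with_fused_wgrad(weight: torch.nn.Parameter,
                              grad_output: torch.Tensor,
                              inp: torch.Tensor) -> None:
    wgrad = grad_output.t() @ inp            # compute dW = dY^T @ X
    weight.main_grad.add_(wgrad)             # accumulate into the main grad buffer
    weight.grad_added_to_main_grad = True    # flag set on the original Python object
    # no wgrad is returned to autograd: the optimizer reads main_grad instead

weight = torch.nn.Parameter(torch.randn(4, 3))
weight.main_grad = torch.zeros(4, 3)         # buffer normally created by MCore
backward_with_fused_wgrad(weight, torch.randn(5, 4), torch.randn(5, 3))
```

Because the flag lives on the Python object (not the tensor storage), any mechanism that replaces the object between forward and backward breaks this contract.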

Currently, weights are saved via save_for_backward, with the assumption that no hooks substitute them with different tensors (e.g., during CPU offload/reload). For CPU offloading, we work around this by saving weights directly on ctx. However, this approach is incompatible with non-TE CPU offloading scenarios and potentially conflicts with FSDP, which also manages weight tensors.

This PR addresses these issues by saving weak references to weights for the backward pass instead. When modifications to the original Python object are needed (e.g., setting grad_added_to_main_grad), the weakref is dereferenced and the modification is applied. This is done conditionally, only when MCore FSDP or MCore fused wgrad accumulation is enabled.

Changes:

  • Replace direct weight references with weakref in forward pass
  • Dereference weakrefs in backward pass only when fuse_wgrad_accumulation is enabled
  • Remove CPU offloading workarounds that saved weights directly on ctx
  • Apply consistent pattern across linear.py, layernorm_linear.py, grouped_linear.py, and layernorm_mlp.py
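The steps above can be sketched as a toy `autograd.Function`. This is a simplified illustration of the weakref pattern, assuming a plain matmul forward; the class name, argument list, and shapes are invented for the example and the real modules (`linear.py` etc.) are considerably more involved:

```python
import weakref

import torch


class _LinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, fuse_wgrad_accumulation):
        # Tensor contents go through save_for_backward (and may be swapped by
        # pack/unpack hooks, e.g. during CPU offload); the original Python
        # object is kept only as a weak reference.
        ctx.save_for_backward(inp, weight)
        ctx.weight_ref = weakref.ref(weight)
        ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight_tensor = ctx.saved_tensors
        dgrad = grad_out @ weight_tensor
        wgrad = grad_out.t() @ inp
        if ctx.fuse_wgrad_accumulation:
            weight = ctx.weight_ref()            # dereference only when needed
            if weight is not None:               # object may have been collected
                weight.main_grad.add_(wgrad)
                weight.grad_added_to_main_grad = True
            wgrad = None                         # grad already folded into main_grad
        return dgrad, wgrad, None


x = torch.randn(5, 3, requires_grad=True)
w = torch.nn.Parameter(torch.randn(4, 3))
w.main_grad = torch.zeros(4, 3)                  # buffer normally created by MCore
out = _LinearSketch.apply(x, w, True)
out.sum().backward()                             # sets w.grad_added_to_main_grad; w.grad stays None
```

Note that the weakref is dereferenced only under the `fuse_wgrad_accumulation` branch, matching the PR's conditional approach, and the `None` check guards against the weight object having been garbage-collected before backward runs.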

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 4 commits December 30, 2025 16:14
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>