[DeepCompile] fix gather params in dynamo skipped frames for ZeRO3#8059
[DeepCompile] fix gather params in dynamo skipped frames for ZeRO3#8059XAheli wants to merge 3 commits into
Conversation
Signed-off-by: ahpoddar <ahpoddar@redhat.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 04bf85696a
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| for p in self.module.parameters(): | ||
| if hasattr(p, "ds_status") and p.ds_status == ZeroParamStatus.AVAILABLE and not p.ds_persist: | ||
| p.partition() |
There was a problem hiding this comment.
Don't free eager-gathered weights before backward
When a Dynamo-skipped frame executes eagerly and touches a ZeRO-3 parameter whose backward needs the weight (for example a Linear after an embedding, where grad-input must be computed), this loop immediately calls p.partition() after forward. In DeepCompile the ZeRO module backward hooks have been removed, and deepcompile_backward_prologue() only starts the compiled runtime, so there is no eager fallback to re-gather that saved weight before the eager autograd node runs; backward will see the released [0] parameter/storage or compute from invalid state. The fallback-gathered params need to stay available until their eager backward use has completed, or get a matching backward-time gather/release path.
Useful? React with 👍 / 👎.
| if _dc_z3_eager_fallback: | ||
| for p in self.module.parameters(): | ||
| if hasattr(p, "ds_status") and p.ds_status == ZeroParamStatus.AVAILABLE and not p.ds_persist: | ||
| p.partition() |
There was a problem hiding this comment.
Thanks for fixing this!
One question: I was wondering why we need to walk through all parameters and free those that are still gathered at this point? Does it mean the parameters gathered outside the compiled graphs are all alive till this point? If so, it can increase the peak GPU memory usage, which can hurt training efficiency in some cases.
|
Thanks @XAheli for digging into this! The root-cause analysis is correct, and keeping One issue is the post-forward release loop just moves the failure into backward (this confirms the earlier P1 review comment). With DeepCompile actually enabled and the parameters not persistent ( Autograd saves leaf parameters by reference and reads Another issue is that this PR releases gathered parameters based on global Your tests didn't catch these issues because To make this concrete, I've opened a PR against your branch implementing the rework: XAheli#1. In summary it:
On @eternalNight's question: This still increases the peak memory, though the cost is now bounded to the fallback-gathered set. But I think it would avoid errors and keep correctness. Can you take a look at the rework, adjust it as needed, and merge it into your branch if it looks reasonable? (if you'd prefer to address these issues in your own way, that works just as well) Either way, once the backward-safe release Thanks again for working on this! |
|
@tohtana @eternalNight thanks a lot for the detailed review :) I'll take a deeper look and push the changes soon! |
Fixes #7942
Root cause: When
init_z3()initializes DeepCompile it removes all three parameter-gathering mechanisms (ZeROOrderedDict, module hooks, engine forward hooks) and relies entirely on compiled FX graph ops for allgather/release. buttorch._dynamomay skip entire frames when it detects graph breaks in for/while loops. Skipped frames execute eagerly with no gathering mechanism, so parameters stay partitioned at shape[0].Testing
Validated on 2× H200 with ZeRO3 + DeepCompile:
Test plan
pre-commit runpasses on all changed filestests/torch_compile/test_compile.pypasses (2 GPU, ZeRO-3)test_deepcompile_skipped_frame.pypassescc @tohtana @eternalNight