[Bug fix] Fake quantized model save after HF accelerate hooks are added #906
📝 Walkthrough

Updates `TensorQuantizer` to exclude HuggingFace/accelerate-related attributes (`_hf_hook`, `_old_forward`, `forward`) from save/restore metadata. Adds a test verifying these attributes are excluded and the model state remains picklable. Publicly exports `TensorQuantizer` from the quantization module.
Signed-off-by: realAsma <akuriparambi@nvidia.com>
f084240 to 27a0fb6
🧹 Nitpick comments (1)
tests/unit/torch/quantization/plugins/test_accelerate.py (1)
58-82: Consider adding a round-trip save/restore test and `properties_only=True` coverage.

The current test validates exclusion and picklability, but doesn't verify that `set_from_modelopt_state` still correctly restores a quantizer when accelerate hooks are active, nor that the `properties_only=True` variant is equally clean. Both gaps are low-risk (the skip set is shared by all code paths), but explicit coverage would prevent regressions.

♻️ Suggested additions
```python
# Also verify properties_only=True path
state_props_only = tq.get_modelopt_state(properties_only=True)
leaked_props = accelerate_attrs & state_props_only.keys()
assert not leaked_props, f"Accelerate attributes leaked (properties_only=True): {leaked_props}"
pickle.dumps(state_props_only)

# Round-trip: restore a fresh TQ from the saved state
tq2 = TensorQuantizer()
add_hook_to_module(tq2, ModelHook())
tq2.set_from_modelopt_state(state)
assert tq2.num_bits == tq.num_bits
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/quantization/plugins/test_accelerate.py` around lines 58 - 82, Add tests to assert get_modelopt_state(properties_only=True) also omits accelerate-injected attributes and is picklable, and perform a round-trip save/restore to ensure set_from_modelopt_state works when hooks are present: call TensorQuantizer.get_modelopt_state(properties_only=True) and verify it does not contain "_hf_hook", "_old_forward", or "forward" and that pickle.dumps succeeds, then create a fresh TensorQuantizer, apply add_hook_to_module, call set_from_modelopt_state(state) using the previously saved state, and assert key properties (e.g., num_bits) match the original; reference TensorQuantizer, get_modelopt_state, set_from_modelopt_state, properties_only, and add_hook_to_module to locate the code under test.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #906   +/-   ##
=======================================
  Coverage   73.54%   73.54%
=======================================
  Files         205      205
  Lines       22000    22000
=======================================
  Hits        16179    16179
  Misses       5821     5821
=======================================
```

☔ View full report in Codecov by Sentry.
What does this PR do?
Type of change: Bug fix
Overview: Fix `AttributeError: Can't get local object 'add_hook_to_module.<locals>.new_forward'` when saving a quantized model a second time after restoring it with `device_map="auto"`.

When a model is loaded with `device_map="auto"`, accelerate's `add_hook_to_module` patches every submodule (including `TensorQuantizer` instances) and injects three instance attributes: `_hf_hook`, `_old_forward`, and `forward` (a `functools.partial` wrapping a local function). These are not picklable and were leaking into the modelopt state dict collected by `get_modelopt_state()`, causing `torch.save` to fail.

This PR adds the three accelerate-injected attributes to `TensorQuantizer._skip_properties_for_save_restore` so they are excluded from the serialized state, matching the existing pattern used for HuggingFace and DeepSpeed attributes.

Usage
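The failure mode and the skip-set fix described above can be sketched end-to-end. All names below (`add_hook_to_module`, `FakeTensorQuantizer`, `get_modelopt_state`) are simplified stand-ins for illustration, not accelerate's or modelopt's actual implementations:

```python
import functools
import pickle


def add_hook_to_module(module):
    """Mimics accelerate's hook pattern: binds a local function onto forward."""
    old_forward = module.forward

    def new_forward(*args, **kwargs):  # local function -> not picklable
        return old_forward(*args, **kwargs)

    module._hf_hook = object()
    module._old_forward = old_forward
    module.forward = functools.partial(new_forward)


class FakeTensorQuantizer:
    """Hypothetical stand-in for TensorQuantizer."""

    _skip_properties_for_save_restore = {"_hf_hook", "_old_forward", "forward"}

    def __init__(self):
        self.num_bits = 8

    def forward(self, x):
        return x

    def get_modelopt_state(self):
        # The fix: drop hook-injected attributes before serialization.
        return {
            k: v
            for k, v in vars(self).items()
            if k not in self._skip_properties_for_save_restore
        }


tq = FakeTensorQuantizer()
add_hook_to_module(tq)

# Without the skip set, the injected partial makes the state unpicklable.
try:
    pickle.dumps(vars(tq))
except (AttributeError, pickle.PicklingError):
    print("raw __dict__ is not picklable")

# With the skip set, saving works again.
state = tq.get_modelopt_state()
assert set(state) == {"num_bits"}
pickle.dumps(state)
print("filtered state pickles fine")
```

The key detail is that `functools.partial` pickles by reference to its wrapped function, and a function defined inside another function has no importable qualified name, so pickling it always fails.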
Testing
Added `test_tensor_quantizer_modelopt_state_with_accelerate_hook` in `tests/unit/torch/quantization/plugins/test_accelerate.py` that verifies accelerate hook attributes are excluded from modelopt state and the state dict remains picklable.

Before your PR is "Ready for review"
Additional Information
The root cause is in accelerate's `add_hook_to_module`, which defines `new_forward` as a local function and binds it via `functools.partial` onto `module.forward`. Since local functions cannot be pickled, any `TensorQuantizer` that has been hooked by accelerate becomes unserializable unless these attributes are excluded.

Summary by CodeRabbit
Bug Fixes
- Accelerate hook attributes (`_hf_hook`, `_old_forward`, `forward`) are now excluded from saved quantizer state, so fake-quantized models loaded with `device_map="auto"` can be saved again.

Tests
- New test verifying accelerate hook attributes are excluded from modelopt state and the state dict remains picklable.

Public API
- `TensorQuantizer` is now publicly exported from the quantization module.