Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ class TensorQuantizer(nn.Module):
"_padding",
# Extra flags added by huggingface
"_is_hf_initialized",
# Extra flags added by accelerate
"_hf_hook",
"_old_forward",
"forward",
# Extra flags added by deepspeed
"ds_external_parameters",
"all_parameters",
Expand Down
31 changes: 30 additions & 1 deletion tests/unit/torch/quantization/plugins/test_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle

import pytest
import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import QuantLinearConvBase
from modelopt.torch.quantization.nn import QuantLinearConvBase, TensorQuantizer

try:
from accelerate.hooks import ModelHook, add_hook_to_module
Expand Down Expand Up @@ -51,3 +53,30 @@ def test_linear_with_accelerate_monkey_patched_forward():

assert module_test.input_quantizer.amax is not None
assert module_test.weight_quantizer.amax is not None


def test_tensor_quantizer_modelopt_state_with_accelerate_hook():
"""Verify accelerate hook attributes are excluded from modelopt state.
When accelerate's add_hook_to_module patches a TensorQuantizer, it adds
_hf_hook, _old_forward, and an instance-level forward (a functools.partial
wrapping a local function). These must be excluded from the modelopt state
dict, otherwise torch.save / pickle will fail with:
AttributeError: Can't get local object 'add_hook_to_module.<locals>.new_forward'
"""
tq = TensorQuantizer()
add_hook_to_module(tq, ModelHook())

# The hook should have injected these instance attributes
assert hasattr(tq, "_hf_hook")
assert hasattr(tq, "_old_forward")
assert "forward" in tq.__dict__

# None of the accelerate attributes should appear in the modelopt state
state = tq.get_modelopt_state()
accelerate_attrs = {"_hf_hook", "_old_forward", "forward"}
leaked = accelerate_attrs & state.keys()
assert not leaked, f"Accelerate attributes leaked into modelopt state: {leaked}"

# The state dict must be picklable (torch.save uses pickle internally)
pickle.dumps(state)