diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3852d1144..9b401a335 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -156,6 +156,10 @@ class TensorQuantizer(nn.Module): "_padding", # Extra flags added by huggingface "_is_hf_initialized", + # Extra flags added by accelerate + "_hf_hook", + "_old_forward", + "forward", # Extra flags added by deepspeed "ds_external_parameters", "all_parameters", diff --git a/tests/unit/torch/quantization/plugins/test_accelerate.py b/tests/unit/torch/quantization/plugins/test_accelerate.py index 0c81ba457..df5a4701d 100644 --- a/tests/unit/torch/quantization/plugins/test_accelerate.py +++ b/tests/unit/torch/quantization/plugins/test_accelerate.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle + import pytest import torch import torch.nn as nn import modelopt.torch.quantization as mtq -from modelopt.torch.quantization.nn import QuantLinearConvBase +from modelopt.torch.quantization.nn import QuantLinearConvBase, TensorQuantizer try: from accelerate.hooks import ModelHook, add_hook_to_module @@ -51,3 +53,30 @@ def test_linear_with_accelerate_monkey_patched_forward(): assert module_test.input_quantizer.amax is not None assert module_test.weight_quantizer.amax is not None + + +def test_tensor_quantizer_modelopt_state_with_accelerate_hook(): + """Verify accelerate hook attributes are excluded from modelopt state. + + When accelerate's add_hook_to_module patches a TensorQuantizer, it adds + _hf_hook, _old_forward, and an instance-level forward (a functools.partial + wrapping a local function). These must be excluded from the modelopt state + dict, otherwise torch.save / pickle will fail with: + AttributeError: Can't get local object 'add_hook_to_module..new_forward' + """ + tq = TensorQuantizer() + add_hook_to_module(tq, ModelHook()) + + # The hook should have injected these instance attributes + assert hasattr(tq, "_hf_hook") + assert hasattr(tq, "_old_forward") + assert "forward" in tq.__dict__ + + # None of the accelerate attributes should appear in the modelopt state + state = tq.get_modelopt_state() + accelerate_attrs = {"_hf_hook", "_old_forward", "forward"} + leaked = accelerate_attrs & state.keys() + assert not leaked, f"Accelerate attributes leaked into modelopt state: {leaked}" + + # The state dict must be picklable (torch.save uses pickle internally) + pickle.dumps(state)