diff --git a/tests/pytorch/distributed/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/run_fsdp2_fused_adam.py index 0439bf1b5a..7302755e7f 100644 --- a/tests/pytorch/distributed/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/run_fsdp2_fused_adam.py @@ -506,17 +506,13 @@ def test_dcp_output_parity(recipe=None, async_save=False): else: model_state = model.state_dict() + save_state = {"model": model_state, "optimizer": optimizer.state_dict()} + if not async_save: - dcp.save( - {"model": model_state, "optimizer": optimizer.state_dict()}, - checkpoint_id=checkpoint_dir, - ) - future = None + dcp.save(save_state, checkpoint_id=checkpoint_dir) else: - future = dcp.async_save( - {"model": model_state, "optimizer": optimizer.state_dict()}, - checkpoint_id=checkpoint_dir, - ) + future = dcp.async_save(save_state, checkpoint_id=checkpoint_dir) + future.result() # Block on async save completion # ── Build a fresh model and load the checkpoint ────────────────── model2 = _build_model(fp8_init=True, recipe=recipe) @@ -545,9 +541,6 @@ def test_dcp_output_parity(recipe=None, async_save=False): state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()} - if async_save: - future.result() # Block on async save completion - dcp.load(state_to_load, checkpoint_id=checkpoint_dir) model2.load_state_dict( state_to_load["model"], @@ -572,7 +565,7 @@ def test_dcp_output_parity(recipe=None, async_save=False): ref_output, rtol=0.05, atol=0.1, - msg="Fresh model loaded from DCP checkpoint produces different output", + msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}", ) else: torch.testing.assert_close( @@ -580,7 +573,7 @@ def test_dcp_output_parity(recipe=None, async_save=False): ref_output, rtol=0, atol=0, - msg="Fresh model loaded from DCP checkpoint produces different output", + msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}", ) # ── Verify one more training step produces identical results ───── diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index b10f31ea07..f51887f799 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -171,16 +171,6 @@ def test_fsdp2_dcp_output_parity(fp_recipe): @pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs") def test_fsdp2_dcp_output_parity_async(fp_recipe): """DCP save/load round-trip into a fresh model produces identical outputs.""" - if fp_recipe in ("DelayedScaling", "Float8CurrentScaling"): - pytest.xfail( - f"async DCP save/load with {fp_recipe} uses StateDictStager._offload_tensor() which " - "tries to deep-copy the tensor's underlying storage. Float8Tensor is a wrapper subclass" - "(_make_wrapper_subclass) with data_ptr() == 0 (empty storage). The staging code at " - "line 215 skips the storage copy for wrapper subclasses, creating a plain tensor with " - "uninitialized garbage data. The actual FP8 data (in _data, _scale_inv attributes) is " - "deep-copied but ignored by DCP when writing." - ) - if fp_recipe == "MXFP8BlockScaling": pytest.xfail( "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 07171914f5..ba706bd2cb 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -560,6 +560,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") + # New empty op (used by DCP async staging to create CPU copies) + if func == torch.ops.aten.new_empty.default: + tensor = args[0] + size = args[1] + dtype = kwargs.get("dtype", tensor.dtype) + device = kwargs.get("device", tensor.device) + pin_memory = kwargs.get("pin_memory", False) + out = tensor._quantizer.make_empty( + shape=torch.Size(size), + dtype=dtype, + device=device, + requires_grad=tensor.requires_grad, + pin_memory=pin_memory, + ) + return out + # Empty like op if func == torch.ops.aten.empty_like.default: tensor = args[0] diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 9cc00855cd..c35cbe6b7e 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -152,6 +152,7 @@ def make_empty( requires_grad=requires_grad, data_transpose=data_transpose, quantizer=self, + device=device, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -379,6 +380,7 @@ def make_empty( requires_grad=requires_grad, data_transpose=data_transpose, quantizer=self, + device=device, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -953,6 +955,15 @@ def is_cuda(self): return self._transpose.is_cuda raise RuntimeError("Both data and transpose are None") + @property + def is_cpu(self): + """Return whether the tensor is on CPU.""" + if self._data is not None: + return self._data.is_cpu + if self._transpose is not None: + return self._transpose.is_cpu + raise RuntimeError("Both data and transpose are None") + @classmethod def _make_in_reduce_ex( cls, @@ -977,7 +988,14 @@ def _make_in_reduce_ex( ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" + """Custom pickling to remove references to FP8 metadata objects + + CPU Float8Tensors are serialized as dequantized plain tensors + for compatibility with torch.load(weights_only=True), which is + used by DCP async save staging. + """ + if self.is_cpu: + return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol) return ( Float8Tensor._make_in_reduce_ex, (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 0fb7966c2f..0d9afd56d6 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -14,7 +14,7 @@ from ...quantized_tensor import QuantizedTensorStorage, Quantizer -from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor @@ -35,6 +35,13 @@ def forward( if tensor._data is not None: if tensor._data.numel() == 0: return torch.empty_like(tensor._data, dtype=dtype) + if tensor._data.is_cpu: + # CPU fallback: reinterpret uint8 as FP8, cast to target dtype, scale + fp8_torch_dtype = TE_DType_To_Torch[tensor._fp8_dtype] + return ( + tensor._data.view(fp8_torch_dtype).float() + * tensor._scale_inv.to(tensor._data.device) + ).to(dtype) # Cast from FP8 return tex.dequantize(tensor, te_dtype) @@ -132,6 +139,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "data_transpose": self._transpose, "quantizer": self._quantizer, + "device": self.device, "fake_dtype": self._dtype, }