Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions tests/pytorch/distributed/run_fsdp2_fused_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,17 +506,13 @@ def test_dcp_output_parity(recipe=None, async_save=False):
else:
model_state = model.state_dict()

save_state = {"model": model_state, "optimizer": optimizer.state_dict()}

if not async_save:
dcp.save(
{"model": model_state, "optimizer": optimizer.state_dict()},
checkpoint_id=checkpoint_dir,
)
future = None
dcp.save(save_state, checkpoint_id=checkpoint_dir)
else:
future = dcp.async_save(
{"model": model_state, "optimizer": optimizer.state_dict()},
checkpoint_id=checkpoint_dir,
)
future = dcp.async_save(save_state, checkpoint_id=checkpoint_dir)
future.result() # Block on async save completion

# ── Build a fresh model and load the checkpoint ──────────────────
model2 = _build_model(fp8_init=True, recipe=recipe)
Expand Down Expand Up @@ -545,9 +541,6 @@ def test_dcp_output_parity(recipe=None, async_save=False):

state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()}

if async_save:
future.result() # Block on async save completion

dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
model2.load_state_dict(
state_to_load["model"],
Expand All @@ -572,15 +565,15 @@ def test_dcp_output_parity(recipe=None, async_save=False):
ref_output,
rtol=0.05,
atol=0.1,
msg="Fresh model loaded from DCP checkpoint produces different output",
msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)
else:
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0,
atol=0,
msg="Fresh model loaded from DCP checkpoint produces different output",
msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)

# ── Verify one more training step produces identical results ─────
Expand Down
10 changes: 0 additions & 10 deletions tests/pytorch/distributed/test_torch_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,6 @@ def test_fsdp2_dcp_output_parity(fp_recipe):
@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
def test_fsdp2_dcp_output_parity_async(fp_recipe):
"""DCP save/load round-trip into a fresh model produces identical outputs."""
if fp_recipe in ("DelayedScaling", "Float8CurrentScaling"):
pytest.xfail(
f"async DCP save/load with {fp_recipe} uses StateDictStager._offload_tensor() which "
"tries to deep-copy the tensor's underlying storage. Float8Tensor is a wrapper subclass"
"(_make_wrapper_subclass) with data_ptr() == 0 (empty storage). The staging code at "
"line 215 skips the storage copy for wrapper subclasses, creating a plain tensor with "
"uninitialized garbage data. The actual FP8 data (in _data, _scale_inv attributes) is "
"deep-copied but ignored by DCP when writing."
)

if fp_recipe == "MXFP8BlockScaling":
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
Expand Down
16 changes: 16 additions & 0 deletions transformer_engine/pytorch/quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == torch.ops.aten.view.default:
raise NotImplementedError("{cls.__name__} class does not support tensor views")

# New empty op (used by DCP async staging to create CPU copies)
if func == torch.ops.aten.new_empty.default:
tensor = args[0]
size = args[1]
dtype = kwargs.get("dtype", tensor.dtype)
device = kwargs.get("device", tensor.device)
pin_memory = kwargs.get("pin_memory", False)
out = tensor._quantizer.make_empty(
shape=torch.Size(size),
dtype=dtype,
device=device,
requires_grad=tensor.requires_grad,
pin_memory=pin_memory,
)
Comment on lines +570 to +576
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AttributeError when _quantizer is None

tensor._quantizer can be None for Float8Tensor objects deserialized via the GPU path (_make_in_reduce_ex), which does not pass a quantizer argument. If a second async DCP save is attempted after a load/save round-trip, new_empty will be dispatched on the deserialized tensor, causing AttributeError: 'NoneType' object has no attribute 'make_empty'.

A guard is needed before calling make_empty:

if func == torch.ops.aten.new_empty.default:
    tensor = args[0]
    size = args[1]
    dtype = kwargs.get("dtype", tensor.dtype)
    device = kwargs.get("device", tensor.device)
    pin_memory = kwargs.get("pin_memory", False)
    if tensor._quantizer is None:
        raise RuntimeError(
            f"{type(tensor).__name__} does not have a quantizer; "
            "cannot create new_empty QuantizedTensor"
        )
    out = tensor._quantizer.make_empty(
        shape=torch.Size(size),
        dtype=dtype,
        device=device,
        requires_grad=tensor.requires_grad,
        pin_memory=pin_memory,
    )
    return out

return out

# Empty like op
if func == torch.ops.aten.empty_like.default:
tensor = args[0]
Expand Down
20 changes: 19 additions & 1 deletion transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def make_empty(
requires_grad=requires_grad,
data_transpose=data_transpose,
quantizer=self,
device=device,
)

def calibrate(self, tensor: torch.Tensor) -> None:
Expand Down Expand Up @@ -379,6 +380,7 @@ def make_empty(
requires_grad=requires_grad,
data_transpose=data_transpose,
quantizer=self,
device=device,
)

def calibrate(self, tensor: torch.Tensor) -> None:
Expand Down Expand Up @@ -953,6 +955,15 @@ def is_cuda(self):
return self._transpose.is_cuda
raise RuntimeError("Both data and transpose are None")

@property
def is_cpu(self):
"""Return whether the tensor is on CPU."""
if self._data is not None:
return self._data.is_cpu
if self._transpose is not None:
return self._transpose.is_cpu
raise RuntimeError("Both data and transpose are None")

@classmethod
def _make_in_reduce_ex(
cls,
Expand All @@ -977,7 +988,14 @@ def _make_in_reduce_ex(
)

def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects"""
"""Custom pickling to remove references to FP8 metadata objects

CPU Float8Tensors are serialized as dequantized plain tensors
for compatibility with torch.load(weights_only=True), which is
used by DCP async save staging.
"""
if self.is_cpu:
return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
Comment on lines +997 to +998
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_cpu raises RuntimeError when both _data and _transpose are None

The is_cpu property raises RuntimeError("Both data and transpose are None") when both _data and _transpose are None — for example, after prepare_for_saving() or clear() sets them to None. The pre-existing GPU path in __reduce_ex__ handled this gracefully by passing None to _make_in_reduce_ex. The new check unconditionally calls self.is_cpu, introducing a regression for any Float8Tensor pickled in a cleared state.

A safe guard would be:

def __reduce_ex__(self, protocol: int) -> tuple:
    if self._data is not None and self._data.is_cpu:
        return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
    return (
        Float8Tensor._make_in_reduce_ex,
        (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
    )

This covers only the case where _data is a CPU tensor (the staged DCP scenario) and falls through to the original path for all other cases, including _data is None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, didn't I just change this from the other comment starting with "CPU-only transpose path not handled in reduce_ex"?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right — I suggested changing to if self.is_cpu: to handle the CPU transpose path, but that creates a new issue when both _data and _transpose are None.

The is_cpu property raises RuntimeError("Both data and transpose are None") in that case, which the original GPU path handled gracefully by passing None to _make_in_reduce_ex.

The fix needs to handle all three cases:

  1. _data on CPU → dequantize
  2. _transpose on CPU (no _data) → dequantize
  3. Both None → fall through to original path

Here's a clean solution using try-except:

Suggested change
if self.is_cpu:
return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects
CPU Float8Tensors are serialized as dequantized plain tensors
for compatibility with torch.load(weights_only=True), which is
used by DCP async save staging.
"""
try:
is_cpu_tensor = self.is_cpu
except RuntimeError:
# Both _data and _transpose are None
is_cpu_tensor = False
if is_cpu_tensor:
return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
return (
Float8Tensor._make_in_reduce_ex,
(self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
)

return (
Float8Tensor._make_in_reduce_ex,
(self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ...quantized_tensor import QuantizedTensorStorage, Quantizer

from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch

from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor

Expand All @@ -35,6 +35,13 @@ def forward(
if tensor._data is not None:
if tensor._data.numel() == 0:
return torch.empty_like(tensor._data, dtype=dtype)
if tensor._data.is_cpu:
# CPU fallback: reinterpret uint8 as FP8, cast to target dtype, scale
fp8_torch_dtype = TE_DType_To_Torch[tensor._fp8_dtype]
return (
tensor._data.view(fp8_torch_dtype).float()
* tensor._scale_inv.to(tensor._data.device)
).to(dtype)
# Cast from FP8
return tex.dequantize(tensor, te_dtype)

Expand Down Expand Up @@ -132,6 +139,7 @@ def get_metadata(self) -> Dict[str, Any]:
"fp8_dtype": self._fp8_dtype,
"data_transpose": self._transpose,
"quantizer": self._quantizer,
"device": self.device,
"fake_dtype": self._dtype,
Comment on lines 141 to 143
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_metadata() raises when tensor is in cleared state

Adding "device": self.device is correct for the normal lifecycle, but Float8TensorStorage.device raises RuntimeError("Float8TensorStorage has no data!") when both _data and _transpose are None — exactly the state left by prepare_for_saving() or clear().

Before this PR, get_metadata() returned None for data and data_transpose without raising. Now any call to get_metadata() (e.g., via make_like()) on a cleared tensor would raise instead of propagating gracefully.

A safe guard:

"device": self._data.device if self._data is not None
          else (self._transpose.device if self._transpose is not None else None),

}

Expand Down