
[PyTorch] Fix CPU offloading for bulk-allocated quantized tensors#2716

Draft
lhb8125 wants to merge 28 commits into NVIDIA:main from
lhb8125:hongbinl/offload_activation_cuda_graph_mxfp8_offload_fix

Conversation


@lhb8125 lhb8125 commented Feb 27, 2026

Description

When CPU offloading is used with bulk-allocated quantized tensors (FP8 blockwise / MXFP8 / NVFP4), sub-tensors share a common buffer via sub-views. This creates a dilemma:

  • If small tensors (e.g. scales) are not offloaded: their sub-view references keep the entire shared buffer's refcount non-zero, preventing the buffer from being freed -- completely defeating the purpose of offloading.
  • If small tensors are offloaded: the large number of small D2H/H2D transfers introduces significant CPU overhead.

This PR resolves both issues by offloading the underlying buffer directly as a single contiguous transfer, bypassing individual sub-tensor handling entirely.
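The shared-storage aliasing and the single-copy offload can be sketched as follows. This is a minimal illustration with made-up names and sizes, not TE's actual API; in the real case the buffer lives on CUDA and the host copy uses pinned memory:

```python
import torch

# Sketch (illustrative names/sizes): several quantized sub-tensors are
# carved out of one bulk-allocated buffer as views, and offloading copies
# the buffer once instead of copying each sub-tensor separately.
def bulk_allocate(sizes):
    buf = torch.zeros(sum(sizes), dtype=torch.uint8)  # "cuda" in the real case
    views, off = [], 0
    for n in sizes:
        views.append(buf[off:off + n])  # sub-views share buf's storage
        off += n
    return buf, views

buf, (data, scales, amax) = bulk_allocate([1024, 32, 8])

# All views alias one storage, so keeping any one of them on-device
# keeps the whole allocation alive.
assert all(v.untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()
           for v in (data, scales, amax))

# Buffer-level offload: one contiguous D2H copy instead of three small ones.
host_buf = torch.empty_like(buf)  # pinned CPU memory in the real case
host_buf.copy_(buf, non_blocking=True)
```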

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • split_quantize now returns buffer list alongside quantized tensors (C++ & Python API change)
  • New helpers get_columnwise_subview_info() / restore_columnwise_subviews() to save/restore sub-tensor positions within the buffer after reload
  • grouped_linear.py: offload the whole buffer in forward; restore columnwise subviews in backward
  • backends.py: mark_not_offload on attention outputs to prevent incorrect offloading
  • cpu_offload.py: remove V1 early-return in mark_not_offload so it works for both V1 and V2
  • Tests: update split_quantize call sites for new return type
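
The save/restore idea behind the new helpers can be sketched like this (function names and the example tensor are illustrative, not the PR's actual implementation): record each sub-tensor's position inside the shared buffer before offload, then rebuild the view on the reloaded buffer with `as_strided`.

```python
import torch

# Hypothetical sketch of subview bookkeeping for buffer-based offload.
def subview_info(view: torch.Tensor, buffer: torch.Tensor):
    # Position of the sub-tensor relative to the shared buffer's storage.
    offset = view.storage_offset() - buffer.storage_offset()
    return offset, tuple(view.shape), tuple(view.stride())

def restore_subview(buffer: torch.Tensor, info):
    # Rebuild the view at the recorded position on the reloaded buffer.
    offset, shape, stride = info
    return buffer.as_strided(shape, stride, offset)

buffer = torch.arange(24, dtype=torch.uint8)
scale_view = buffer[8:20].reshape(3, 4)   # a sub-tensor inside the buffer
info = subview_info(scale_view, buffer)

reloaded = buffer.clone()                 # stands in for the reload from CPU
restored = restore_subview(reloaded, info)
assert torch.equal(restored, scale_view)
```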

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

lhb8125 and others added 24 commits November 30, 2025 21:33
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: root <root@eos0046.eos.clusters.nvidia.com>
@lhb8125 lhb8125 marked this pull request as draft February 27, 2026 11:09

greptile-apps bot commented Feb 27, 2026

Greptile Summary

This PR enhances CPU offload for CUDA graphs with MXFP8 quantization by switching from individual tensor offload to buffer-based offload with subview restoration.

Major changes:

  • Modified split_quantize to return (tensor_list, buffer_list) tuple instead of just tensor list
  • Added get_columnwise_subview_info() and restore_columnwise_subviews() for tracking/restoring tensor views after offload
  • Changed grouped_linear to offload buffers instead of individual tensors when CPU offloading is enabled
  • Added CUDA stream synchronization for graph replay with cuda_graph_stream and cuda_graph_event parameters
  • Commented out assertion in float8_blockwise_tensor_storage.py to allow both usage flags to be False during buffer-based offload

Critical issues found:

  • restore_columnwise_subviews() captures dtype info but doesn't use it - restored tensors will have wrong dtype (uint8 instead of float32 for scales/amax)
  • Missing event.record() calls after graph replay in forward and backward - synchronization won't work correctly
  • API breaking change: split_quantize now returns tuple, but transformer_engine/pytorch/ops/basic/grouped_linear.py (lines 527, 612) wasn't updated and will fail at runtime

Confidence Score: 1/5

  • This PR contains multiple critical bugs that will cause runtime failures
  • Three critical logic bugs: (1) dtype restoration bug causes wrong tensor types after offload, (2) missing CUDA event recording breaks graph synchronization, (3) API change breaks existing code in ops/basic/grouped_linear.py which wasn't updated
  • Critical: transformer_engine/pytorch/quantized_tensor.py (dtype bug), transformer_engine/pytorch/graph.py (sync bug), and transformer_engine/pytorch/ops/basic/grouped_linear.py (not in PR but will be broken)

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Adds subview tracking/restoration functions for CPU offload, but dtype information is captured but not used during restoration
transformer_engine/pytorch/module/grouped_linear.py Changes CPU offload strategy to offload buffers instead of individual tensors, adds subview restoration in backward pass
transformer_engine/pytorch/ops/basic/grouped_linear.py NOT CHANGED IN THIS PR but has critical bug: split_quantize calls not updated to unpack tuple return value (lines 527, 612)
transformer_engine/pytorch/graph.py Adds pre/post warmup hooks and CUDA stream synchronization for graph replay, but missing event.record() after graph replay
transformer_engine/pytorch/csrc/extensions/cast.cpp Updates bulk allocation functions to return buffer list alongside tensor list for CPU offload support
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py Comments out assertion requiring at least one of columnwise/rowwise usage; needed for buffer-based offload where both flags are False

Last reviewed commit: 484b0d5

@greptile-apps greptile-apps bot left a comment

11 files reviewed, 4 comments


Comment on lines +268 to +270
    if col_scale_info is not None:
        byte_offset, shape, stride, _ = col_scale_info
        tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset)

The dtype stored in col_scale_info is never used. as_strided on a uint8 buffer returns a uint8 tensor, but _columnwise_scale_inv should be float32 (for Float8Blockwise) or uint8 (for MXFP8). This causes type mismatches when the restored tensor is used.

Suggested change
  if col_scale_info is not None:
-     byte_offset, shape, stride, _ = col_scale_info
-     tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset)
+     byte_offset, shape, stride, dtype = col_scale_info
+     tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset).view(dtype)
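
The reviewer's point can be demonstrated standalone (tensor names here are hypothetical): a view built with `as_strided` on a uint8 byte buffer is always uint8, so float32 scales come back with the wrong dtype unless the bytes are explicitly reinterpreted. Note that `Tensor.view(dtype)` rescales the last dimension (16 bytes become 4 floats), so the recorded shape must match the representation being viewed:

```python
import torch

# Illustration of the dtype-loss issue flagged in the review.
scales = torch.tensor([0.5, 2.0, 4.0, 8.0], dtype=torch.float32)
byte_buffer = scales.view(torch.uint8)        # the same 16 bytes, as uint8

raw = byte_buffer.as_strided((16,), (1,), 0)  # dtype information is lost
assert raw.dtype == torch.uint8

restored = raw.view(torch.float32)            # reinterpret the raw bytes
assert restored.dtype == torch.float32
assert torch.equal(restored, scales)
```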

Comment on lines +273 to +275
    if col_amax_info is not None:
        byte_offset, shape, stride, _ = col_amax_info
        tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset)

Same dtype issue: _columnwise_amax should be float32 for NVFP4, but it will be restored as uint8.

Suggested change
  if col_amax_info is not None:
-     byte_offset, shape, stride, _ = col_amax_info
-     tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset)
+     byte_offset, shape, stride, dtype = col_amax_info
+     tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset).view(dtype)

Comment on lines +808 to +811
    cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with cuda_graph_stream:
        fwd_graph.replay()
    torch.cuda.current_stream().wait_event(cuda_graph_event)

Missing cuda_graph_event.record(cuda_graph_stream) after the replay. An event that is never recorded has no completion point, so wait_event does not actually synchronize with the replay.

Suggested change
  cuda_graph_stream.wait_stream(torch.cuda.current_stream())
  with cuda_graph_stream:
      fwd_graph.replay()
+     cuda_graph_event.record(cuda_graph_stream)
  torch.cuda.current_stream().wait_event(cuda_graph_event)
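
The handoff pattern the review asks for can be sketched generically (function and variable names here are illustrative, not the PR's code): do the work on a side stream, record the event on that stream after the work, then make the main stream wait on the event.

```python
import torch

# Generic side-stream handoff sketch. Without the record() call, the
# final wait_event has no recorded completion point to wait on.
def run_on_side_stream(work, side_stream, event):
    side_stream.wait_stream(torch.cuda.current_stream())  # side waits for main
    with torch.cuda.stream(side_stream):
        work()
        event.record()            # records on the current (side) stream
    torch.cuda.current_stream().wait_event(event)         # main waits for side

if torch.cuda.is_available():     # the pattern is CUDA-only; guard CPU hosts
    stream, event = torch.cuda.Stream(), torch.cuda.Event()
    out = torch.zeros(4, device="cuda")
    run_on_side_stream(lambda: out.add_(1), stream, event)
    torch.cuda.synchronize()
```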

Comment on lines +828 to +831
    ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with ctx.cuda_graph_stream:
        bwd_graph.replay()
    torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)

Same issue: ctx.cuda_graph_event.record(ctx.cuda_graph_stream) is missing after the backward graph replay.

Suggested change
  ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
  with ctx.cuda_graph_stream:
      bwd_graph.replay()
+     ctx.cuda_graph_event.record(ctx.cuda_graph_stream)
  torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)

Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
@lhb8125 lhb8125 changed the title Hongbinl/offload activation cuda graph mxfp8 offload fix [feat] Return the pointer of whole buffer instead of a list of pointers to experts Mar 10, 2026
@lhb8125 lhb8125 changed the title [feat] Return the pointer of whole buffer instead of a list of pointers to experts [PyTorch] Fix CPU offloading for bulk-allocated quantized tensors Mar 10, 2026