[PyTorch] Fix CPU offloading for bulk-allocated quantized tensors #2716
lhb8125 wants to merge 28 commits into NVIDIA:main
Conversation
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Greptile Summary

This PR enhances CPU offload for CUDA graphs with MXFP8 quantization by switching from per-tensor offload to buffer-based offload with subview restoration.

Critical issues found:
Confidence Score: 1/5
Important Files Changed
Last reviewed commit: 484b0d5
```python
if col_scale_info is not None:
    byte_offset, shape, stride, _ = col_scale_info
    tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset)
```
The dtype is stored in `col_scale_info` but never used. `as_strided` on a uint8 buffer returns a uint8 tensor, but `_columnwise_scale_inv` should be float32 (for Float8Blockwise) or uint8 (for MXFP8). This will cause type mismatches when the restored tensor is used.
Suggested change:

```python
if col_scale_info is not None:
    byte_offset, shape, stride, dtype = col_scale_info
    tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset).view(dtype)
```
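To illustrate why the `.view(dtype)` reinterpretation matters (a standalone sketch, not TE code): `as_strided` on a uint8 buffer always yields a uint8 view, while `Tensor.view(dtype)` reinterprets the raw bytes, shrinking the last dimension by the element size.

```python
import torch

# A byte buffer standing in for the bulk-allocated columnwise buffer.
buf = torch.arange(16, dtype=torch.uint8)

# A sub-view of 8 bytes starting at byte offset 8.
sub = buf.as_strided((8,), (1,), 8)
assert sub.dtype == torch.uint8  # as_strided preserves the buffer's uint8 dtype

# Reinterpret those 8 bytes as float32: 8 bytes -> 2 float32 elements.
scale_inv = sub.view(torch.float32)
```

Without the final `.view(dtype)`, downstream kernels would receive a uint8 tensor where a float32 scale is expected.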
```python
if col_amax_info is not None:
    byte_offset, shape, stride, _ = col_amax_info
    tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset)
```
Same dtype issue: `_columnwise_amax` should be float32 for NVFP4, but will be restored as uint8.
Suggested change:

```python
if col_amax_info is not None:
    byte_offset, shape, stride, dtype = col_amax_info
    tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset).view(dtype)
```
transformer_engine/pytorch/graph.py
```python
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
    fwd_graph.replay()
torch.cuda.current_stream().wait_event(cuda_graph_event)
```
Missing `cuda_graph_event.record(cuda_graph_stream)` after the replay. Without recording the event, `wait_event` waits on the wrong completion point.
Suggested change:

```python
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
    fwd_graph.replay()
cuda_graph_event.record(cuda_graph_stream)
torch.cuda.current_stream().wait_event(cuda_graph_event)
```
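The corrected ordering can be captured in a small helper (an illustrative sketch with hypothetical names, using `torch.cuda.stream(...)` as the context manager; it requires a CUDA device to actually run):

```python
import torch

def replay_on_side_stream(graph, side_stream, done_event):
    """Replay a captured CUDA graph on a side stream, then make the
    current stream wait for its completion via an explicitly recorded event."""
    # The side stream must not start before work already queued on the current stream.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        graph.replay()
    # Record completion on the side stream. Without this record(), the
    # wait_event() below would wait on whatever the event last captured,
    # not on the replay that just ran.
    done_event.record(side_stream)
    torch.cuda.current_stream().wait_event(done_event)
```

The same wait -> replay -> record -> wait_event sequence applies to the backward graph replay below.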
transformer_engine/pytorch/graph.py
```python
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
    bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
```
Same issue: missing `ctx.cuda_graph_event.record(ctx.cuda_graph_stream)` after the backward graph replay.
Suggested change:

```python
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
    bwd_graph.replay()
ctx.cuda_graph_event.record(ctx.cuda_graph_stream)
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
```
Description
When CPU offloading is used with bulk-allocated quantized tensors (FP8 blockwise / MXFP8 / NVFP4), sub-tensors share a common buffer via sub-views, which makes per-tensor offloading problematic in two ways.
This PR resolves both issues by offloading the underlying buffer directly as a single contiguous transfer, bypassing individual sub-tensor handling entirely.
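The buffer-level approach can be sketched device-agnostically (illustrative only, not TE code): since every quantized sub-tensor is a view into one storage, copying that storage once moves all of them.

```python
import torch

# Stand-in for the bulk-allocated buffer backing several quantized tensors.
buffer = torch.arange(16, dtype=torch.uint8)
sub_a = buffer.as_strided((8,), (1,), 0)   # first sub-tensor's bytes
sub_b = buffer.as_strided((8,), (1,), 8)   # second sub-tensor's bytes

# One contiguous copy offloads every sub-view at once; no per-tensor walk.
offloaded = buffer.clone()
```

In the real path the destination would be pinned host memory and the copy asynchronous, but the principle is the same: one transfer covers all sub-views.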
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- `split_quantize` now returns a buffer list alongside the quantized tensors (C++ & Python API change)
- `get_columnwise_subview_info()` / `restore_columnwise_subviews()` to save/restore sub-tensor positions within the buffer after reload
- `grouped_linear.py`: offload the whole buffer in forward; restore columnwise subviews in backward
- `backends.py`: `mark_not_offload` on attention outputs to prevent incorrect offloading
- `cpu_offload.py`: remove the V1 early-return in `mark_not_offload` so it works for both V1 and V2
- Updated `split_quantize` call sites for the new return type

Checklist:
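A minimal sketch of the save/restore pattern such subview helpers follow (the function names here are illustrative, not TE's actual API; for a uint8 buffer, storage offsets in elements coincide with byte offsets):

```python
import torch

def save_subview_info(sub):
    # Record where the sub-view lives inside its backing buffer:
    # offset, logical shape, strides, and the dtype it should be viewed as.
    return sub.storage_offset(), tuple(sub.shape), sub.stride(), sub.dtype

def restore_subview(buffer, info):
    # Rebuild the sub-view on the (reloaded) buffer and reinterpret its dtype.
    offset, shape, stride, dtype = info
    return buffer.as_strided(shape, stride, offset).view(dtype)

buffer = torch.zeros(64, dtype=torch.uint8)
sub = buffer.as_strided((4, 8), (8, 1), 16)   # a 4x8 tile at byte offset 16
info = save_subview_info(sub)
restored = restore_subview(buffer, info)
```

Saving the dtype alongside offset/shape/stride is exactly what guards against the uint8-reinterpretation bug flagged in the review comments above.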