Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5fc9857
support cuda graph capture offloading module
lhb8125 Dec 1, 2025
913fbe8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2025
e04bc00
remove reset_hook and init_chunk_handler_hook
lhb8125 Dec 8, 2025
dda34c2
remove reset_hook and init_chunk_handler_hook
lhb8125 Dec 8, 2025
6ed4b91
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
lhb8125 Dec 8, 2025
2f61c00
Merge branch 'main' into hongbinl/offload_activation_cuda_graph
lhb8125 Dec 8, 2025
88295b4
minor fix
Dec 18, 2025
ed2ee6a
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
Jan 13, 2026
09d0801
temp fix overlap-grad-reduce
lhb8125 Jan 19, 2026
8641228
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
lhb8125 Jan 19, 2026
c3e341a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
6cd4af9
reuse mark_not_offload() and do not offload scale_inv
lhb8125 Jan 20, 2026
b54e77c
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
lhb8125 Jan 20, 2026
ba065fc
temp fix for mxfp8
lhb8125 Jan 22, 2026
e00db5e
fix bug for record_stream and from_blob
lhb8125 Feb 2, 2026
f47b543
disable offloading core_attn_out and refine cpu overhead of at::empty
lhb8125 Feb 3, 2026
7ca3618
minor fix
lhb8125 Feb 5, 2026
12cf77b
Merge branch 'main' into hongbinl/offload_activation_cuda_graph
lhb8125 Feb 5, 2026
8c8fe59
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2026
8421cf9
return ptr of whole buffer and offload the whole buffer
lhb8125 Feb 6, 2026
2e47119
Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
lhb8125 Feb 6, 2026
25dbad1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2026
24d22cf
Merge branch 'main' into hongbinl/offload_activation_cuda_graph_mxfp8…
lhb8125 Feb 27, 2026
d65b416
remove mark_not_offload for core_attn_out
lhb8125 Feb 27, 2026
484b0d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2026
281e95c
Merge branch 'main' into hongbinl/offload_activation_cuda_graph_mxfp8…
lhb8125 Mar 10, 2026
07e7cd8
Merge branch 'hongbinl/offload_activation_cuda_graph_mxfp8_offload_fi…
lhb8125 Mar 10, 2026
f1330a9
minor fix
lhb8125 Mar 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def check_group_quantization_nvfp4_versus_reference(
reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose)
)

split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers)
split_quantize_outputs, _ = tex.split_quantize(x, split_sections, quantizers)

if return_identity:
x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs]
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def group_quantize_fp4(
]

if use_tex_split_quantize:
outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers)
outputs, _ = tex.split_quantize(x, split_sections, nvfp4_quantizers)
qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs]
sx_list = [output._rowwise_scale_inv for output in outputs]
qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
start_offload,
mark_activation_offload,
NVTE_CPU_OFFLOAD_V1,
mark_not_offload,
)
from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded

Expand Down Expand Up @@ -1311,6 +1312,9 @@ def forward(
# return appropriate tensors
out_ret = out_fp8 if is_output_fp8 else out

mark_not_offload(out_fp8)
mark_not_offload(out)

# save appropriate tensors
fp8_tensors = (None, None, None, None)
qkvo_tensors = (None, None, None, None)
Expand Down Expand Up @@ -1361,6 +1365,7 @@ def forward(
out = out_
out_ret = out_
fp8_tensors = (None, None, None, None)
mark_not_offload(out)
qkvo_tensors = (q, k, v, out)

nvtx_range_pop(f"{nvtx_label}")
Expand Down
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/cpu_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ def mark_activation_offload(*tensors):

def mark_not_offload(*tensors: torch.Tensor):
"""Marks tensors to prevent them from being offloaded."""
if NVTE_CPU_OFFLOAD_V1:
return

tensors, tensor_obj = prepare_for_saving(*tensors)

Expand Down
7 changes: 3 additions & 4 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,9 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list);

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation = false);
std::tuple<std::vector<py::object>, std::vector<at::Tensor>> split_quantize(
const at::Tensor &tensor, const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list, bool disable_bulk_allocation = false);

/***************************************************************************************************
* Bias gradient fusions
Expand Down
55 changes: 34 additions & 21 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,15 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten

namespace {

std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp8_blockwise_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<Float8BlockQuantizer *> &quantizer_cpp_list) {
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, std::vector<at::Tensor>>
bulk_allocate_fp8_blockwise_tensors(std::vector<std::vector<size_t>> &shape_list,
std::vector<py::handle> &quantizer_py_list,
std::vector<Float8BlockQuantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, std::vector<at::Tensor>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
auto &buffer_list = std::get<2>(retval); // Buffers for offload

// Number of tensors
const size_t num_tensors = shape_list.size();
Expand Down Expand Up @@ -412,6 +414,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -455,6 +458,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -497,13 +501,15 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
return retval;
}

std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mxfp8_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<MXFP8Quantizer *> &quantizer_cpp_list) {
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, std::vector<at::Tensor>>
bulk_allocate_mxfp8_tensors(std::vector<std::vector<size_t>> &shape_list,
std::vector<py::handle> &quantizer_py_list,
std::vector<MXFP8Quantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, std::vector<at::Tensor>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
auto &buffer_list = std::get<2>(retval); // Buffers for offload

// Number of tensors
const size_t num_tensors = shape_list.size();
Expand Down Expand Up @@ -565,6 +571,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -605,6 +612,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -650,14 +658,17 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
// allocate fp4 data, fp8 scalings, and amax values
// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN]
// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_allocate_nvfp4_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<NVFP4Quantizer *> &quantizer_cpp_list) {
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool, std::vector<at::Tensor>>
bulk_allocate_nvfp4_tensors(std::vector<std::vector<size_t>> &shape_list,
std::vector<py::handle> &quantizer_py_list,
std::vector<NVFP4Quantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> retval;
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool, std::vector<at::Tensor>>
retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
auto &contiguous_data_and_scale = std::get<2>(retval);
auto &buffer_list = std::get<3>(retval); // Buffers for offload
contiguous_data_and_scale = true;

// Number of tensors
Expand Down Expand Up @@ -742,6 +753,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -804,6 +816,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
buffer_list.push_back(*buffer); // Save buffer for offload

// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
Expand Down Expand Up @@ -1250,18 +1263,17 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,

} // namespace

std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation) {
std::tuple<std::vector<py::object>, std::vector<at::Tensor>> split_quantize(
const at::Tensor &tensor, const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list, bool disable_bulk_allocation) {
init_extension();

// Check number of tensors
const size_t num_splits = split_sections.size();
NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ",
quantizer_list.size());
if (num_splits == 0) {
return {};
return {{}, {}};
}

// Input tensor properties
Expand Down Expand Up @@ -1328,14 +1340,15 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
// Allocate output tensors
std::vector<TensorWrapper> output_cpp_list;
std::vector<py::object> output_py_list;
std::vector<at::Tensor> buffer_list; // Buffers for offload (can be used for record_stream)
switch (allocation_method) {
case AllocationMethod::BULK_FP8_BLOCKWISE: {
// Bulk allocation for FP8 block-scaling tensors
std::vector<Float8BlockQuantizer *> blockwise_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
blockwise_quantizers.push_back(static_cast<Float8BlockQuantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
std::tie(output_py_list, output_cpp_list, buffer_list) =
bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers);
break;
}
Expand All @@ -1345,7 +1358,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
for (auto &quantizer : quantizer_cpp_list) {
mxfp8_quantizers.push_back(static_cast<MXFP8Quantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
std::tie(output_py_list, output_cpp_list, buffer_list) =
bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
break;
}
Expand All @@ -1356,7 +1369,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
nvfp4_quantizers.push_back(static_cast<NVFP4Quantizer *>(quantizer.get()));
}
bool contiguous_data_and_scale;
std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) =
std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale, buffer_list) =
bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers);
if (!contiguous_data_and_scale) {
// Avoid fused quantize kernel if data is not contiguous
Expand Down Expand Up @@ -1393,7 +1406,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
}

return output_py_list;
return {output_py_list, buffer_list};
}

} // namespace pytorch
Expand Down
51 changes: 40 additions & 11 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
Quantizer,
prepare_for_saving,
restore_from_saved,
get_columnwise_subview_info,
restore_columnwise_subviews,
)
from ...debug.pytorch.debug_quantization import DebugQuantizer
from ...debug.pytorch.debug_state import TEDebugState
Expand Down Expand Up @@ -144,16 +146,24 @@ def forward(
)
inp_view = inp.reshape(-1, in_features)
inputmats: list
offload_buffer: torch.Tensor = None
subview_restore_info: list = []
if fp8 and not debug:
# Offloading skips small tensors (like scales), but bulk allocation shares storage
# across all tensors, so the tensors in the group cannot be offloaded one by one.
# split_quantize therefore also returns the shared buffers, and when CPU offloading
# is active the buffer is offloaded as a whole instead of the individual tensors.
inputmats = tex.split_quantize(
inp_view,
m_splits,
input_quantizers,
disable_bulk_allocation=cpu_offloading,
)
inputmats, buffer_list = tex.split_quantize(inp_view, m_splits, input_quantizers)
if cpu_offloading:
# Exclude the individual inputmats from offloading - the shared buffer is offloaded instead
mark_not_offload(*inputmats)
# buffer_list layout: [rowwise_buffer?, columnwise_buffer?]
# columnwise buffer is always last if present; we only offload it
# since rowwise data is discarded when weight_requires_grad is True
if buffer_list and input_quantizers[0].columnwise_usage:
offload_buffer = buffer_list[-1]
# Get subview boundary info for restoration in backward
if offload_buffer is not None:
subview_restore_info = get_columnwise_subview_info(inputmats, offload_buffer)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype
Expand All @@ -162,7 +172,12 @@ def forward(
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)

if cpu_offloading:
start_offload(*inputmats)
if offload_buffer is not None:
# Offload the buffer instead of individual tensors
# (rowwise data is discarded when weight_requires_grad is True)
start_offload(offload_buffer)
else:
start_offload(*inputmats)

# Initialize weights
weights_fp8: list
Expand Down Expand Up @@ -237,10 +252,15 @@ def forward(
if save_original_input:
inputmats = [None] * num_gemms
inputmats[0] = inp
offload_buffer = None
subview_restore_info = []
else:
for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensorStorage):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if cpu_offloading and offload_buffer is not None:
inputmat.update_usage(rowwise_usage=False, columnwise_usage=False)
else:
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
inputmats = [None] * num_gemms

Expand All @@ -262,6 +282,7 @@ def forward(
*weights_fp8,
*weights,
*biases,
offload_buffer,
)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
Expand Down Expand Up @@ -308,6 +329,7 @@ def forward(
ctx.debug = debug
ctx.save_original_input = save_original_input
ctx.input_quantizers = input_quantizers
ctx.subview_restore_info = subview_restore_info

# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
Expand All @@ -322,8 +344,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
weights = saved_tensors[N : 2 * N]
origin_weights = saved_tensors[2 * N : 3 * N]
biases = saved_tensors[3 * N : 4 * N]
offload_buffer = saved_tensors[4 * N] if len(saved_tensors) > 4 * N else None
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]

# Restore subviews from reloaded buffer
if ctx.cpu_offloading and ctx.subview_restore_info and offload_buffer is not None:
restore_columnwise_subviews(inputmats, offload_buffer, ctx.subview_restore_info)

if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
for i, weight in enumerate(ctx.weight_objects):
Expand Down Expand Up @@ -353,14 +380,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
# Unfused bias grad and multi-tensor quantize
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = tex.split_quantize(
grad_output, _ = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
else:
# Multi-tensor quantize
grad_output = tex.split_quantize(
grad_output, _ = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
Expand Down Expand Up @@ -452,7 +479,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmats: list
if ctx.fp8 and not ctx.debug:
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
inputmats, _ = tex.split_quantize(
inp_view, ctx.m_splits, ctx.input_quantizers
)
elif ctx.debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view,
Expand Down
Loading
Loading