diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 5f35e9ad10..6df6dabd29 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -77,7 +77,7 @@ def check_group_quantization_nvfp4_versus_reference( reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose) ) - split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers) + split_quantize_outputs, _ = tex.split_quantize(x, split_sections, quantizers) if return_identity: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index b14eeb815b..92efd9e42a 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -194,7 +194,7 @@ def group_quantize_fp4( ] if use_tex_split_quantize: - outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers) + outputs, _ = tex.split_quantize(x, split_sections, nvfp4_quantizers) qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs] sx_list = [output._rowwise_scale_inv for output in outputs] qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs] diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a6a8b0b26a..e214df1dfa 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -60,6 +60,7 @@ start_offload, mark_activation_offload, NVTE_CPU_OFFLOAD_V1, + mark_not_offload, ) from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded @@ -1311,6 +1312,9 @@ def forward( # return appropriate tensors out_ret = out_fp8 if 
is_output_fp8 else out + mark_not_offload(out_fp8) + mark_not_offload(out) + # save appropriate tensors fp8_tensors = (None, None, None, None) qkvo_tensors = (None, None, None, None) @@ -1361,6 +1365,7 @@ def forward( out = out_ out_ret = out_ fp8_tensors = (None, None, None, None) + mark_not_offload(out) qkvo_tensors = (q, k, v, out) nvtx_range_pop(f"{nvtx_label}") diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 05219b7b18..e1eaa9fa46 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -45,8 +45,6 @@ def mark_activation_offload(*tensors): def mark_not_offload(*tensors: torch.Tensor): """Marks tensors to prevent them from being offloaded.""" - if NVTE_CPU_OFFLOAD_V1: - return tensors, tensor_obj = prepare_for_saving(*tensors) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e4d4e5094c..57e37518c1 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -293,10 +293,9 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation = false); +std::tuple, std::vector> split_quantize( + const at::Tensor &tensor, const std::vector &split_sections, + std::vector quantizer_list, bool disable_bulk_allocation = false); /*************************************************************************************************** * Bias gradient fusions diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f8f793f036..c38075b9ed 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -345,13 +345,15 @@ std::vector multi_tensor_quantize(const std::vector &ten namespace { -std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +std::tuple, std::vector, std::vector> +bulk_allocate_fp8_blockwise_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector> retval; + std::tuple, std::vector, std::vector> retval; auto &tensor_py_list = std::get<0>(retval); auto &tensor_cpp_list = std::get<1>(retval); + auto &buffer_list = std::get<2>(retval); // Buffers for offload // Number of tensors const size_t num_tensors = shape_list.size(); @@ -412,6 +414,7 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -455,6 +458,7 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -497,13 +501,15 @@ std::tuple, std::vector> bulk_allocate_fp return retval; } -std::tuple, std::vector> bulk_allocate_mxfp8_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +std::tuple, std::vector, std::vector> +bulk_allocate_mxfp8_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector> retval; + std::tuple, std::vector, std::vector> retval; auto &tensor_py_list = std::get<0>(retval); 
auto &tensor_cpp_list = std::get<1>(retval); + auto &buffer_list = std::get<2>(retval); // Buffers for offload // Number of tensors const size_t num_tensors = shape_list.size(); @@ -565,6 +571,7 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -605,6 +612,7 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -650,14 +658,17 @@ std::tuple, std::vector> bulk_allocate_mx // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate -std::tuple, std::vector, bool> bulk_allocate_nvfp4_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, - std::vector &quantizer_cpp_list) { +std::tuple, std::vector, bool, std::vector> +bulk_allocate_nvfp4_tensors(std::vector> &shape_list, + std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { init_extension(); - std::tuple, std::vector, bool> retval; + std::tuple, std::vector, bool, std::vector> + retval; auto &tensor_py_list = std::get<0>(retval); auto &tensor_cpp_list = std::get<1>(retval); auto &contiguous_data_and_scale = std::get<2>(retval); + auto &buffer_list = std::get<3>(retval); // Buffers for offload contiguous_data_and_scale = true; // Number of tensors @@ -742,6 +753,7 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, 
at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -804,6 +816,7 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + buffer_list.push_back(*buffer); // Save buffer for offload // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -1250,10 +1263,9 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, } // namespace -std::vector split_quantize(const at::Tensor &tensor, - const std::vector &split_sections, - std::vector quantizer_list, - bool disable_bulk_allocation) { +std::tuple, std::vector> split_quantize( + const at::Tensor &tensor, const std::vector &split_sections, + std::vector quantizer_list, bool disable_bulk_allocation) { init_extension(); // Check number of tensors @@ -1261,7 +1273,7 @@ std::vector split_quantize(const at::Tensor &tensor, NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ", quantizer_list.size()); if (num_splits == 0) { - return {}; + return {{}, {}}; } // Input tensor properties @@ -1328,6 +1340,7 @@ std::vector split_quantize(const at::Tensor &tensor, // Allocate output tensors std::vector output_cpp_list; std::vector output_py_list; + std::vector buffer_list; // Buffers for offload (can be used for record_stream) switch (allocation_method) { case AllocationMethod::BULK_FP8_BLOCKWISE: { // Bulk allocation for FP8 block-scaling tensors @@ -1335,7 +1348,7 @@ std::vector split_quantize(const at::Tensor &tensor, for (auto &quantizer : quantizer_cpp_list) { blockwise_quantizers.push_back(static_cast(quantizer.get())); } - std::tie(output_py_list, output_cpp_list) = + std::tie(output_py_list, output_cpp_list, buffer_list) = bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers); break; } 
@@ -1345,7 +1358,7 @@ std::vector split_quantize(const at::Tensor &tensor, for (auto &quantizer : quantizer_cpp_list) { mxfp8_quantizers.push_back(static_cast(quantizer.get())); } - std::tie(output_py_list, output_cpp_list) = + std::tie(output_py_list, output_cpp_list, buffer_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); break; } @@ -1356,7 +1369,7 @@ std::vector split_quantize(const at::Tensor &tensor, nvfp4_quantizers.push_back(static_cast(quantizer.get())); } bool contiguous_data_and_scale; - std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = + std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale, buffer_list) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); if (!contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous @@ -1393,7 +1406,7 @@ std::vector split_quantize(const at::Tensor &tensor, multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); } - return output_py_list; + return {output_py_list, buffer_list}; } } // namespace pytorch diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f3e7b57cf1..9d5f55ae28 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -50,6 +50,8 @@ Quantizer, prepare_for_saving, restore_from_saved, + get_columnwise_subview_info, + restore_columnwise_subviews, ) from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_state import TEDebugState @@ -144,16 +146,24 @@ def forward( ) inp_view = inp.reshape(-1, in_features) inputmats: list + offload_buffer: torch.Tensor = None + subview_restore_info: list = [] if fp8 and not debug: # Disable bulk allocation when CPU offloading is active: offloading skips small # tensors (like scales), but bulk allocation shares storage across all tensors, # so 
if scales can't be offloaded, nothing in the group can be offloaded. - inputmats = tex.split_quantize( - inp_view, - m_splits, - input_quantizers, - disable_bulk_allocation=cpu_offloading, - ) + inputmats, buffer_list = tex.split_quantize(inp_view, m_splits, input_quantizers) + if cpu_offloading: + # Mark inputmats as not offload - we offload the buffer instead + mark_not_offload(*inputmats) + # buffer_list layout: [rowwise_buffer?, columnwise_buffer?] + # columnwise buffer is always last if present; we only offload it + # since rowwise data is discarded when weight_requires_grad is True + if buffer_list and input_quantizers[0].columnwise_usage: + offload_buffer = buffer_list[-1] + # Get subview boundary info for restoration in backward + if offload_buffer is not None: + subview_restore_info = get_columnwise_subview_info(inputmats, offload_buffer) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, input_quantizers, m_splits, activation_dtype @@ -162,7 +172,12 @@ def forward( inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) if cpu_offloading: - start_offload(*inputmats) + if offload_buffer is not None: + # Offload the buffer instead of individual tensors + # (rowwise data is discarded when weight_requires_grad is True) + start_offload(offload_buffer) + else: + start_offload(*inputmats) # Initialize weights weights_fp8: list @@ -237,10 +252,15 @@ def forward( if save_original_input: inputmats = [None] * num_gemms inputmats[0] = inp + offload_buffer = None + subview_restore_info = [] else: for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if cpu_offloading and offload_buffer is not None: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=False) + else: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms @@ -262,6 +282,7 @@ def forward( *weights_fp8, *weights, 
*biases, + offload_buffer, ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects @@ -308,6 +329,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.subview_restore_info = subview_restore_info # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -322,8 +344,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weights = saved_tensors[N : 2 * N] origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] + offload_buffer = saved_tensors[4 * N] if len(saved_tensors) > 4 * N else None main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + # Restore subviews from reloaded buffer + if ctx.cpu_offloading and ctx.subview_restore_info and offload_buffer is not None: + restore_columnwise_subviews(inputmats, offload_buffer, ctx.subview_restore_info) + if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: for i, weight in enumerate(ctx.weight_objects): @@ -353,14 +380,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Unfused bias grad and multi-tensor quantize for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) - grad_output = tex.split_quantize( + grad_output, _ = tex.split_quantize( grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, ) else: # Multi-tensor quantize - grad_output = tex.split_quantize( + grad_output, _ = tex.split_quantize( grad_output_view, ctx.m_splits, ctx.grad_output_quantizers, @@ -452,7 +479,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) + inputmats, _ = tex.split_quantize( + inp_view, ctx.m_splits, 
# Internal tensors tracked for each quantized tensor, in the order in which they
# appear in every per-tensor info tuple.
# NOTE(review): confirm these attribute names against every storage class that
# goes through bulk allocation (in particular the NVFP4 amax attribute); an
# attribute that does not exist is silently treated as "not present" and is
# therefore never restored.
_COLUMNWISE_SUBVIEW_ATTRS = ("_columnwise_data", "_columnwise_scale_inv", "_columnwise_amax")


def _strided_span(shape, stride) -> int:
    """Return how many elements a strided view touches, from first to last (0 if empty)."""
    if any(dim == 0 for dim in shape):
        return 0
    return 1 + sum((dim - 1) * step for dim, step in zip(shape, stride))


def _dtype_itemsize(dtype: torch.dtype) -> int:
    """Return the size in bytes of one element of ``dtype``."""
    # ``torch.dtype.itemsize`` only exists in newer PyTorch releases; fall back to
    # asking a scalar tensor so that older releases keep working.
    itemsize = getattr(dtype, "itemsize", None)
    if itemsize is None:
        itemsize = torch.empty((), dtype=dtype).element_size()
    return itemsize


def _describe_subview(subview, buffer_ptr: int, buffer_nbytes: int):
    """Return ``(byte_offset, shape, stride, dtype)`` for one subview, or None if absent."""
    if subview is None:
        return None
    shape, stride = subview.shape, subview.stride()
    span_nbytes = _strided_span(shape, stride) * subview.element_size()
    if span_nbytes == 0:
        # An empty tensor (e.g. a zero-row split) has no meaningful address, so the
        # pointer difference would be garbage (possibly negative). Any in-range
        # offset reproduces an empty view, so pin it to 0.
        return (0, shape, stride, subview.dtype)
    byte_offset = subview.data_ptr() - buffer_ptr
    if byte_offset < 0 or byte_offset + span_nbytes > buffer_nbytes:
        # Fail here, where the offending tensor is known, instead of producing a
        # view over unrelated memory when the buffer is reloaded.
        raise ValueError(
            "Columnwise subview does not lie inside the offload buffer "
            f"(byte_offset={byte_offset}, nbytes={span_nbytes}, buffer_nbytes={buffer_nbytes})"
        )
    return (byte_offset, shape, stride, subview.dtype)


def get_columnwise_subview_info(inputmats: list, columnwise_buffer: torch.Tensor) -> list:
    """
    Get boundary information for columnwise internal tensors in inputmats.

    Records where the columnwise internal tensors (_columnwise_data,
    _columnwise_scale_inv, _columnwise_amax) of each entry live inside a shared
    buffer, so that restore_columnwise_subviews() can rebuild them after the
    buffer has been offloaded and reloaded at a different address.

    Only columnwise tensors are inspected. Tuples are used instead of dicts to
    keep the CPU overhead low.

    Args:
        inputmats: List of quantized tensor objects whose columnwise internal
            tensors alias ``columnwise_buffer``. Entries may be None.
        columnwise_buffer: Contiguous buffer holding the columnwise data.

    Returns:
        List of tuples: [(columnwise_data_info, columnwise_scale_info, columnwise_amax_info), ...]
        Each info is (byte_offset, shape, stride, dtype), or None if the tensor is
        not present. byte_offset is measured in bytes from the start of the
        buffer and is 0 for empty tensors. Returns an empty list when
        ``columnwise_buffer`` is None.

    Raises:
        ValueError: If a non-empty internal tensor is not fully contained in
            ``columnwise_buffer``.
    """
    if columnwise_buffer is None:
        return []

    buffer_ptr = columnwise_buffer.data_ptr()
    buffer_nbytes = columnwise_buffer.numel() * columnwise_buffer.element_size()
    return [
        tuple(
            _describe_subview(getattr(tensor, attr, None), buffer_ptr, buffer_nbytes)
            for attr in _COLUMNWISE_SUBVIEW_ATTRS
        )
        for tensor in inputmats
    ]


def _rebuild_subview(byte_buffer: torch.Tensor, info: tuple) -> torch.Tensor:
    """Rebuild one subview of the flat uint8 ``byte_buffer`` from its recorded info."""
    byte_offset, shape, stride, dtype = info
    span_nbytes = _strided_span(shape, stride) * _dtype_itemsize(dtype)
    if byte_offset < 0 or byte_offset + span_nbytes > byte_buffer.numel():
        raise ValueError(
            "Recorded columnwise subview does not fit in the reloaded buffer "
            f"(byte_offset={byte_offset}, nbytes={span_nbytes}, "
            f"buffer_nbytes={byte_buffer.numel()})"
        )
    chunk = byte_buffer[byte_offset : byte_offset + span_nbytes]
    if dtype != torch.uint8:
        # Shape and stride were recorded in units of the original dtype, so the
        # bytes must be reinterpreted before they are applied. PyTorch rejects the
        # reinterpretation if the offset is not aligned to the element size.
        chunk = chunk.view(dtype)
    # With no explicit storage offset, as_strided keeps the offset of ``chunk``.
    return chunk.as_strided(shape, stride)


def restore_columnwise_subviews(
    inputmats: list, columnwise_buffer: torch.Tensor, info_list: list
) -> None:
    """
    Restore columnwise internal tensors from reloaded buffer.

    After CPU offload and reload, the columnwise buffer may be at a new memory
    location. This function re-points the columnwise internal tensors of each
    entry of inputmats at the matching region of the reloaded buffer, with the
    dtype, shape and stride that were recorded by get_columnwise_subview_info().
    The restored tensors alias the buffer; no data is copied.

    Args:
        inputmats: List of quantized tensor objects to restore, in the same order
            as when the info was recorded. None entries are skipped.
        columnwise_buffer: The reloaded, contiguous columnwise buffer.
        info_list: Boundary info returned by get_columnwise_subview_info().

    Raises:
        ValueError: If ``inputmats`` and ``info_list`` differ in length, or if a
            recorded subview does not fit in ``columnwise_buffer``.
    """
    if columnwise_buffer is None or not info_list:
        return
    if len(inputmats) != len(info_list):
        raise ValueError(
            f"Got {len(inputmats)} tensors but subview info for {len(info_list)} tensors"
        )

    # Work on a flat byte view so that recorded byte offsets can be used directly
    # as indices, independent of the buffer's own storage offset.
    byte_buffer = columnwise_buffer.view(-1)
    if byte_buffer.dtype != torch.uint8:
        byte_buffer = byte_buffer.view(torch.uint8)

    for tensor, info in zip(inputmats, info_list):
        if tensor is None:
            # The caller may have replaced inputs that are not needed for backward
            # with None placeholders; there is nothing to restore for those.
            continue
        for attr, subview_info in zip(_COLUMNWISE_SUBVIEW_ATTRS, info):
            if subview_info is not None:
                setattr(tensor, attr, _rebuild_subview(byte_buffer, subview_info))