
Single parameter for GroupedLinear module#2727

Closed
ksivaman wants to merge 2 commits into NVIDIA:main from ksivaman:quantized_fused_group_tensor

Conversation


@ksivaman ksivaman commented Mar 3, 2026

Description

Follow ups and miscellaneous fixes from #2600, #2654, and #2678.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman ksivaman marked this pull request as draft March 3, 2026 05:12

greptile-apps bot commented Mar 3, 2026

Greptile Summary

This PR follows up on prior grouped-linear work (#2600, #2654, #2678) with three main changes: (1) introduces a single_grouped_parameter constructor flag on GroupedLinear that replaces the old NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS environment variable, exposing a single weight GroupedTensor parameter instead of num_gemms individual weight{i} parameters; (2) refactors GroupedTensor (now GroupedTensorStorage) into a storage-only base class and adds a new GroupedTensor wrapper subclass of torch.Tensor to enable direct use as an nn.Parameter; and (3) adds a custom nvte_cumsum CUDA kernel (single-block Kogge-Stone prefix scan) to replace the two-kernel at::cumsum + at::cat pattern in build_grouped_tensor_offsets.

Key changes and issues found:

  • __torch_dispatch__ returns None for in-place ops (grouped_tensor.py:1231): In-place operations should return args[0] per PyTorch convention; returning None can silently break callers that chain on the result (e.g., torch.nn.init.kaiming_uniform_ which returns tensor.uniform_(...)).
  • Single fp8_meta_index for all GEMM sub-weights (grouped_linear.py:790): The grouped weight parameter is registered with fp8_meta_index=self._offsets["weight"], but each sub-weight previously had its own distinct index. This may cause incorrect FP8 metadata (amax/scale) for MXFP8 and Float8BlockScaling recipes.
  • API renames shape→shapes and quantizer→quantizers are applied consistently across Python and C++ layers, including test updates.
  • The nvte_cumsum kernel is well-tested (sizes 1–1024, negative values) and correctly handles multi-chunk inputs via a chunk_carry accumulator within a single block.
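The chunked, single-block scan with a carry accumulator described above can be modeled in plain Python as a reference sketch (this is not the CUDA implementation; the function and parameter names are illustrative):

```python
def cumsum_with_leading_zero(values, block_size=256):
    """Reference model of a single-block, chunked inclusive scan.

    Produces [0, v0, v0+v1, ...] (length n+1), mirroring the
    leading-zero offsets layout described in the review summary.
    """
    out = [0] * (len(values) + 1)
    carry = 0  # plays the role of the kernel's chunk_carry accumulator
    for start in range(0, len(values), block_size):
        chunk = values[start:start + block_size]
        # Kogge-Stone scan within the chunk: after step d, each lane
        # holds the sum of itself and the 2^d - 1 lanes before it.
        step = 1
        while step < len(chunk):
            chunk = [
                chunk[i] + (chunk[i - step] if i >= step else 0)
                for i in range(len(chunk))
            ]
            step *= 2
        for i, s in enumerate(chunk):
            out[start + i + 1] = carry + s
        carry += chunk[-1]  # carry the chunk total into the next chunk
    return out
```

For example, `cumsum_with_leading_zero([3, 1, 4])` yields `[0, 3, 4, 8]` — the leading-zero offsets layout the summary attributes to `build_grouped_tensor_offsets`.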

Confidence Score: 3/5

  • Safe for experimental use but has correctness concerns for FP8 recipes and in-place op return values that should be resolved before production use.
  • The single_grouped_parameter path is marked EXPERIMENTAL in the docstring, which is appropriate. The single fp8_meta_index issue means MXFP8 and Float8BlockScaling recipes may silently use stale metadata for GEMMs beyond the first. The None return from in-place __torch_dispatch__ deviates from PyTorch convention and could break parameter initialization paths that use the return value.
  • transformer_engine/pytorch/tensor/grouped_tensor.py (in-place dispatch return value) and transformer_engine/pytorch/module/grouped_linear.py (FP8 meta index assignment in make_grouped_weights)

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/grouped_tensor.py New GroupedTensor class that wraps GroupedTensorStorage and torch.Tensor as a wrapper subclass. Implements __torch_dispatch__ for transparent dequantize/requantize on ops, but returns None for in-place ops instead of the expected args[0], deviating from PyTorch convention.
transformer_engine/pytorch/module/grouped_linear.py Introduces single_grouped_parameter flag replacing the old NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS env var. make_grouped_weights now registers a single weight parameter; however, it assigns a single fp8_meta_index for all GEMM sub-weights, which may be incorrect for MXFP8 and Float8BlockScaling recipes.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Renamed from grouped_tensor.py; class renamed from GroupedTensor to GroupedTensorStorage. API changed from a single quantizer to a per-tensor quantizers list, with shape→shapes and data→rowwise_data. The make_grouped_tensor factory now dispatches to GroupedTensor vs GroupedTensorStorage based on quantizer.internal. Scale shape calculations use only reference_quantizer, which is safe given same-recipe enforcement.
transformer_engine/common/common.cu Adds cumsum_with_leading_zero_kernel, a single-block Kogge-Stone parallel prefix scan for int64. Launched as <<<1, 256>>> so it processes large inputs in 256-element chunks via a chunk_carry accumulator. Correctness verified by the accompanying test. Appropriate for the small num_gemms use case.
transformer_engine/pytorch/csrc/quantizer.cpp Replaces at::cumsum + at::cat with nvte_cumsum in build_grouped_tensor_offsets, reducing it to a single kernel launch. All six create_grouped_tensor implementations uniformly switch from GroupedTensorStoragePythonClass to the new grouped_tensor_python_class(this->internal) helper, correctly routing to wrapper vs. storage class.
tests/cpp/operator/test_cumsum.cu New C++ test for nvte_cumsum. Tests known values and parameterized sizes (1, 2, 17, 256, 257, 513, 1024) covering single-chunk, boundary, and multi-chunk cases. Includes negative values to exercise signed arithmetic.
tests/pytorch/test_sanity.py Replaces the verbose check_grouped_tensor_pointers helper (which checked raw memory layout) with the simpler check_grouped_weight, which verifies the module exposes exactly one weight parameter of shape (num_gemms, out_features, in_features). Env var setup/teardown removed cleanly.
transformer_engine/pytorch/csrc/extensions/pybind.cpp Adds GroupedTensorPythonClass global and updates init_grouped_tensor_extension to import both GroupedTensor (from the new grouped_tensor module) and GroupedTensorStorage (from storage.grouped_tensor_storage). Adds Python binding for the new cumsum extension function.

Class Diagram

%%{init: {'theme': 'neutral'}}%%
classDiagram
    class `torch.Tensor` {
        <<built-in>>
    }
    class GroupedTensorStorage {
        +num_tensors: int
        +quantizers: List~Quantizer~
        +tensor_shapes: List~Tuple~
        +rowwise_data: Tensor
        +columnwise_data: Tensor
        +logical_shape: Tuple
        +fake_dtype: dtype
        +quantized_tensors: List
        +make_grouped_tensor()$
        +make_grouped_tensor_with_shapes()$
        +split_into_quantized_tensors()
        +quantize()
        +has_data()
        +all_same_shape()
    }
    class GroupedTensor {
        +__new__()
        +__torch_dispatch__()$
        +__torch_function__()$
    }
    class GroupedLinear {
        +single_grouped_parameter: bool
        +make_grouped_weights()
        +reset_parameters()
        +_get_weight_tensors()
        +set_tensor_parallel_attributes()
    }
    class Quantizer {
        <<abstract>>
        +internal: bool
        +rowwise_usage: bool
        +columnwise_usage: bool
        +create_grouped_tensor()
    }

    GroupedTensorStorage <|-- GroupedTensor
    `torch.Tensor` <|-- GroupedTensor
    GroupedLinear --> GroupedTensor : weight (single_grouped_parameter=True)
    GroupedLinear --> GroupedTensorStorage : internal weights
    Quantizer --> GroupedTensorStorage : creates via create_grouped_tensor()
    GroupedTensorStorage o-- Quantizer : quantizers[]

Last reviewed commit: f413a93

@greptile-apps greptile-apps bot left a comment

17 files reviewed, 4 comments

Comment on lines 786 to +795
else:
grouped_weights.quantized_tensors[i].copy_(weights[i])

# Re-register the grouped weights as parameters.
# Re-register as a single grouped weight parameter.
self.register_parameter(
"weight",
torch.nn.Parameter(grouped_weights),
init_fn=self.init_method,
get_rng_state_tracker=self.get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"],

Single fp8_meta_index for all GEMM sub-weights

The old per-GEMM registration assigned each weight{i} its own offset:

fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]

The new registration uses a single offset for all sub-weights:

fp8_meta_index=self._offsets["weight"],

For FP8 recipes that use per-tensor metadata (e.g., MXFP8, Float8BlockScaling), this means all num_gemms sub-weights share the same fp8_meta_index. During FP8 metadata updates (e.g., amax tracking), only the first GEMM's metadata will be updated correctly; the remaining GEMMs will silently use stale or zero metadata. Note that make_grouped_weights does guard against delayed() and float8_current_scaling() recipes, but not against mxfp8() or float8_block_scaling(), so those paths could be affected.
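The indexing difference can be made concrete with a small sketch (function and argument names here are illustrative, not Transformer Engine API):

```python
def per_gemm_meta_indices(weight_offset, num_gemms, tensors_per_gemm):
    # Old per-GEMM registration: each weight{i} gets its own FP8 meta slot.
    return [weight_offset + i * tensors_per_gemm for i in range(num_gemms)]

def single_meta_index(weight_offset, num_gemms):
    # New registration: every sub-weight shares one slot.
    return [weight_offset] * num_gemms
```

With `weight_offset=1`, `tensors_per_gemm=3`, and `num_gemms=4`, the old scheme yields `[1, 4, 7, 10]` while the new one yields `[1, 1, 1, 1]`, so metadata updates for GEMMs 1–3 all land on GEMM 0's slot and their own slots go stale.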

quantizer=None,
shapes=shape,
quantizers=None,
device="cuda",

Each per-tensor quantizer constructed with full-group num_tensors

Each list entry calls make_quantizer(quantization, num_tensors, shape) with num_tensors=3, meaning each quantizer's internal buffers (e.g., FP8 amax/scale tensors) are sized for the entire group of 3 tensors, not for a single tensor. While this doesn't break correctness today (only index 0 of the per-quantizer buffers is used), it inflates memory usage and diverges from production use, where each per-tensor quantizer should be sized for one tensor.

Consider constructing each quantizer for num_tensors=1:

quantizers = [make_quantizer(quantization, 1, shape) for _ in range(num_tensors)]

Comment on lines 438 to 450
@@ -450,7 +445,7 @@ def make_grouped_tensor(
total_scale_elements = 0
scale_inv_offsets = [0]
for i, s in enumerate(shape):
scale_inv_shape = quantizer.get_scale_shape(s, False)
scale_inv_shape = reference_quantizer.get_scale_shape(s, False)
scale_elements = math.prod(scale_inv_shape)
total_scale_elements += scale_elements

reference_quantizer used for all per-tensor scale shape calculations

Scale buffer sizes for each tensor are computed using reference_quantizer.get_scale_shape(s, ...) for every index i. If the quantizers in the list share the same recipe type but differ in a parameter that affects block size (e.g., different block_dim in a block-quantizer), scale buffers could be mis-sized for non-reference tensors.

The existing check only validates recipe type equality:

if any(type(q._get_compatible_recipe()) is not type(reference_quantizer._get_compatible_recipe()) ...):

It would be safer to also validate that any shape-determining attributes (e.g., block scaling dim) match across all quantizers.
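One way to express that stricter check (a sketch under assumptions: `block_scaling_dim` stands in for whichever attributes actually determine scale shape in a given recipe):

```python
def validate_quantizers(quantizers, shape_attrs=("block_scaling_dim",)):
    """Reject quantizer lists that mix recipe types or differ in
    shape-determining parameters, so the reference quantizer's
    get_scale_shape result is valid for every tensor in the group."""
    ref = quantizers[0]
    for q in quantizers[1:]:
        if type(q) is not type(ref):
            raise ValueError("mixed quantizer/recipe types in grouped tensor")
        for attr in shape_attrs:
            if getattr(q, attr, None) != getattr(ref, attr, None):
                raise ValueError(
                    f"quantizers differ in {attr}; scale buffers would be mis-sized"
                )
    return ref  # safe to use as the reference quantizer
```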


greptile-apps bot commented Mar 3, 2026

Additional Comments (1)

transformer_engine/pytorch/tensor/grouped_tensor.py
In-place dispatch returns None instead of self

For in-place (is_mutable) ops, PyTorch convention requires __torch_dispatch__ to return the modified tensor (i.e., args[0]), not None. Returning None means callers that use the return value of an in-place op — for example result = tensor.fill_(0), or frameworks like kaiming_uniform_ which does return tensor.uniform_(a, b) — will receive None instead of the GroupedTensor.

This can silently break parameter initialization when register_parameter internally calls init_fn on the new grouped weight, since torch.nn.init.kaiming_uniform_ returns the result of .uniform_(). If the caller chains on that result, it will see None.

            return args[0]
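The convention being asked for can be illustrated with a torch-free stand-in (`dispatch_inplace` is a schematic helper, not the actual `__torch_dispatch__` signature):

```python
def dispatch_inplace(func, args):
    """Schematic of the fix: an in-place op must return the mutated
    first argument, not None, so chained calls such as
    `return tensor.uniform_(a, b)` keep working for the caller."""
    func(args[0])    # mutate the first argument in place
    return args[0]   # PyTorch convention: hand back args[0]
```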

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman force-pushed the quantized_fused_group_tensor branch from a32a640 to 5213efb Compare March 3, 2026 05:52

ksivaman commented Mar 3, 2026

/te-ci L0

constexpr size_t kCumsumThreadsPerBlock = 256;

__global__ void __launch_bounds__(kCumsumThreadsPerBlock)
cumsum_with_leading_zero_kernel(const int64_t *__restrict__ input, int64_t *__restrict__ output,

we shouldn't need so many syncthreads in the kernel


I might prefer a single-thread kernel, but if we're targeting a performant cumsum kernel, this isn't done properly


const int64_t logical_last_dim_i64 = static_cast<int64_t>(logical_last_dim);
auto scaled_first_dims = first_dims_tensor * logical_last_dim_i64;
auto scaled_first_dims = (first_dims_tensor * logical_last_dim_i64).contiguous();

this should be fused with the cumsum

the cumsum kernel can be renamed to something like nvte_compute_tensor_offsets

return tensor ? py::cast(*tensor) : py::none();
}

py::object make_grouped_quantizers(const py::object& quantizer, const size_t num_tensors) {

why?

"first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(),
"last_dims"_a = py::none(),
"tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(),
"logical_shape"_a = std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),

why do we delete it?

weight_tensors = grouped_weight.quantized_tensors
if weight_tensors is None:
# TODO(ksivaman): Remove this after GEMM integration.
weight_tensors = grouped_weight.split_into_quantized_tensors()

why?

rowwise_usage = quantizer.rowwise_usage if not no_quantization else True
columnwise_usage = quantizer.columnwise_usage if not no_quantization else False
no_quantization = quantizers is None or all(q is None for q in quantizers)
reference_quantizer = None

what is reference_quantizer


size_t get_cudnn_version() { return cudnnGetVersion(); }

at::Tensor cumsum(at::Tensor input, std::optional<at::Tensor> out) {

need to clean it up and replace it with a nvte_compute_tensor_offsets call

@ksivaman ksivaman closed this Mar 4, 2026