Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
3e7859c
PyTorch-Python GroupedTensor
ksivaman Feb 6, 2026
ffeace8
grouped gemm support for bf16, bias support missing
vthumbe1503 Feb 4, 2026
aa86859
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
98cb4fa
remove changes not needed for bf16
vthumbe1503 Feb 11, 2026
b4c91c5
Merge branch 'users/vthumbe/pytorch_binding_for_cublas_gemm' of githu…
vthumbe1503 Feb 11, 2026
1d041f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2026
3b2840e
Merge branch 'main' into users/vthumbe/pytorch_binding_for_cublas_gemm
vthumbe1503 Feb 11, 2026
d5f9569
resolve merge conflicts with main
vthumbe1503 Mar 6, 2026
733a061
resolve merge conflicts against main
vthumbe1503 Mar 6, 2026
d433e5f
merge conflicts
vthumbe1503 Mar 6, 2026
38cf811
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
158d232
keep only pytorch binding for now
vthumbe1503 Mar 9, 2026
63efd1b
Merge branch 'users/vthumbe/pytorch_binding_for_cublas_gemm' of githu…
vthumbe1503 Mar 9, 2026
70abb18
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2026
ea33349
Merge branch 'main' into users/vthumbe/pytorch_binding_for_cublas_gemm
vthumbe1503 Mar 9, 2026
d0604eb
linting error
vthumbe1503 Mar 9, 2026
296dbb2
Merge branch 'users/vthumbe/pytorch_binding_for_cublas_gemm' of githu…
vthumbe1503 Mar 9, 2026
9fd6b8b
add fast accumulator support
vthumbe1503 Mar 9, 2026
b48ef61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2026
6d04d62
fix the test
vthumbe1503 Mar 10, 2026
7a6a590
fix merge conflict
vthumbe1503 Mar 10, 2026
1c94649
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/cpp/operator/test_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ struct TestParams {
// Build the per-group (M, N, K) problem shapes for a grouped-GEMM test case.
// NOTE: the diff render had left both the pre- and post-change return
// statements in each case; only the updated (larger, 128-multiple) shapes
// are kept here so every case has exactly one return.
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
  switch (scase) {
    case ShapeCase::kAllSame:
      // Identical shape for every group.
      return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}};
    case ShapeCase::kSameFirst:
      // Same M (first dim), varying N and K
      return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}};
    case ShapeCase::kSameLast:
      // Same N (last dim), varying M and K
      return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}};
    case ShapeCase::kAllDifferent:
    default:
      // Every group gets distinct dimensions.
      return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}};
  }
}

Expand Down
191 changes: 190 additions & 1 deletion tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import math
import os
from torch._tensor import Tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused private-module import

Tensor is imported from torch._tensor, which is an internal/private PyTorch module. This import is never referenced in the new test code (test_grouped_gemm_grouped_tensor and its helpers), so it is dead code.

Importing from private modules (prefixed with _) is fragile — it can break without notice across PyTorch releases.

Suggested change
from torch._tensor import Tensor
from typing import Dict, List, Tuple, Optional

from typing import Dict, List, Tuple, Optional
import pytest
import random
Expand Down Expand Up @@ -46,7 +47,12 @@
is_nvfp4_available,
)
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions import (
general_gemm,
general_grouped_gemm,
general_grouped_gemm_for_grouped_tensor,
)
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states
Expand Down Expand Up @@ -2792,6 +2798,189 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
    """Copy a list of tensors back-to-back into the grouped tensor's flat buffer.

    The rowwise data buffer is preferred; the columnwise buffer is used as a
    fallback when the rowwise one is absent.

    Raises:
        ValueError: if the grouped tensor has neither data buffer.
    """
    buffer = next(
        (
            buf
            for buf in (grouped_tensor.rowwise_data, grouped_tensor.columnwise_data)
            if buf is not None
        ),
        None,
    )
    if buffer is None:
        raise ValueError("GroupedTensor has no data buffers to pack.")
    cursor = 0
    for src in tensors:
        count = src.numel()
        buffer[cursor : cursor + count].copy_(src.reshape(-1))
        cursor += count


def _make_grouped_tensor_from_splits(
    m_sizes: List[int],
    last_dim: int,
    device: torch.device,
    dtype: torch.dtype,
) -> GroupedTensor:
    """Build a GroupedTensor with per-group first dims and a shared last dim.

    Each entry of ``m_sizes`` becomes the first dimension of one group; every
    group shares ``last_dim`` as its last dimension. No quantizer is attached.
    """
    split_sizes = torch.tensor(m_sizes, dtype=torch.int64, device=device)
    total_rows = sum(m_sizes)
    return GroupedTensor.make_grouped_tensor(
        num_tensors=len(m_sizes),
        first_dims=split_sizes,
        last_dims=None,
        logical_first_dim=total_rows,
        logical_last_dim=last_dim,
        quantizer=None,
        device=device,
        dtype=dtype,
    )


def _make_grouped_tensor_uniform(
    num_tensors: int,
    first_dim: int,
    last_dim: int,
    device: torch.device,
    dtype: torch.dtype,
) -> GroupedTensor:
    """Build a GroupedTensor whose groups all share the same (first, last) shape.

    Passing ``first_dims=None`` lets the factory derive a uniform split from
    ``logical_first_dim = num_tensors * first_dim``. No quantizer is attached.
    """
    logical_rows = num_tensors * first_dim
    return GroupedTensor.make_grouped_tensor(
        num_tensors=num_tensors,
        first_dims=None,
        last_dims=None,
        logical_first_dim=logical_rows,
        logical_last_dim=last_dim,
        quantizer=None,
        device=device,
        dtype=dtype,
    )


@pytest.mark.parametrize(
    "z, m, n, k",
    [
        (4, 256, 256, 256),
        (4, 512, 256, 512),
        (4, 512, 512, 256),
        (8, 512, 256, 512),
    ],
)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm_grouped_tensor(z, m, n, k, layout, accumulate) -> None:
    """Validate general_grouped_gemm_for_grouped_tensor against per-group matmul.

    The M dimension is randomly split into ``z`` groups. The grouped GEMM is
    run twice — without and with bias — and the biased result is checked both
    against a float32 torch.matmul reference and against the unbiased output
    with the bias added manually.
    """
    if tex.get_cublasLt_version() < 130200:
        pytest.skip("Grouped GEMM requires cuBLAS 13.2+.")
    if torch.cuda.get_device_capability() < (10, 0):
        pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
    if not is_bf16_available():
        pytest.skip("bfloat16 is required for grouped GEMM test.")

    torch.manual_seed(0)

    dtype = torch.bfloat16

    # Randomly split M into z positive chunk sizes.
    split_points = torch.randperm(m - 1)[: z - 1] + 1
    split_points = torch.sort(split_points).values.tolist()
    m_sizes = [split_points[0]]
    m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])]
    m_sizes.append(m - split_points[-1])
    assert sum(m_sizes) == m and len(m_sizes) == z

    if layout == "NT":
        A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes]  # input
        B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes]  # grad_output
        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
        out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)]
    else:
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = [
            torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda")
            for ms in m_sizes
        ]  # TN --> input, NN --> grad_output
        out = [
            torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda")
            for ms in m_sizes
        ]  # TN --> output, NN --> dgrad
        if layout == "NN":
            out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)]
        else:  # layout == "TN"
            out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)]

    if accumulate:
        out_ref = [out[i].float() + o for i, o in enumerate(out_ref)]

    # Bias is applied after GEMM (broadcasted along rows).
    # Match kernel behavior: GEMM output is already in output dtype when bias is added.
    out_ref = [o.to(dtype) for o in out_ref]
    bias_last_dim = n if layout == "TN" else k  # NT / NN outputs have last dim k
    # Small but non-zero bias so the bias path is actually exercised
    # (the previous torch.zeros(...) * 0.01 was identically zero).
    bias = [torch.randn(1, bias_last_dim, dtype=dtype, device="cuda") * 0.01 for _ in range(z)]
    # Bias add in grouped kernel accumulates in FP32 for BF16/FP16.
    out_ref = [(o.float() + b.float()).to(dtype) for o, b in zip(out_ref, bias)]

    # Grouped operands: uniform groups for the fixed-shape operand, split
    # groups for operands whose first dim varies per group.
    device = A[0].device
    if layout == "TN":
        grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype)  # weight
        grouped_B = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)  # input
        grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)
        grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)
    elif layout == "NN":
        grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype)  # weight
        grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)  # grad_output
        grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)
        grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)
    else:  # layout == "NT"
        grouped_A = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)  # input
        grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)  # grad_output
        grouped_out_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype)  # wgrad
        grouped_out_no_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype)
    _pack_grouped_tensor(grouped_A, A)
    _pack_grouped_tensor(grouped_B, B)
    # Both outputs start from identical contents so `accumulate` is comparable.
    _pack_grouped_tensor(grouped_out_bias, out)
    _pack_grouped_tensor(grouped_out_no_bias, out)

    grouped_bias = _make_grouped_tensor_uniform(z, 1, bias_last_dim, device, dtype)
    _pack_grouped_tensor(grouped_bias, bias)

    general_grouped_gemm_for_grouped_tensor(
        grouped_A,
        grouped_B,
        grouped_out_no_bias,
        layout=layout,
        accumulate=accumulate,
        bias=None,
    )
    general_grouped_gemm_for_grouped_tensor(
        grouped_A,
        grouped_B,
        grouped_out_bias,
        layout=layout,
        accumulate=accumulate,
        bias=grouped_bias,
    )
    # Outputs are always GroupedTensors here; split them back into per-group tensors.
    out_grouped_no_bias = grouped_out_no_bias.split_into_quantized_tensors()
    out_grouped_bias = grouped_out_bias.split_into_quantized_tensors()
    out_grouped_manual_bias = [
        (o.float() + b.float()).to(dtype) for o, b in zip(out_grouped_no_bias, bias)
    ]
    tols = dtype_tols(dtype)
    for o, o_ref in zip(out_grouped_bias, out_ref):
        torch.testing.assert_close(o, o_ref, **tols)
    for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias):
        torch.testing.assert_close(o, o_ref, **tols)


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
18 changes: 18 additions & 0 deletions transformer_engine/common/gemm/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");

// bool size is implementation-dependent, so we explicitly specify
// uint8_t in the user-facing API.
auto bool_to_uint8 = [](bool in, void *out) {
*reinterpret_cast<uint8_t *>(out) = static_cast<uint8_t>(in);
};

// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::GroupedMatmulConfig *>(config);
Expand All @@ -172,6 +178,9 @@ void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigUseSplitAccumulator:
bool_to_uint8(config_.use_split_accumulator, buf);
break;
case kNVTEGroupedMatmulConfigSMCount:
std::memcpy(buf, &config_.sm_count, attr_size);
break;
Expand All @@ -194,6 +203,12 @@ void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

// bool size is implementation-dependent, so we explicitly specify
// uint8_t in the user-facing API.
auto uint8_to_bool = [](const void *in, bool &out) {
out = static_cast<bool>(*reinterpret_cast<const uint8_t *>(in));
};

// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
Expand All @@ -216,6 +231,9 @@ void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
config_.avg_k = val;
break;
}
case kNVTEGroupedMatmulConfigUseSplitAccumulator:
uint8_to_bool(buf, config_.use_split_accumulator);
break;
case kNVTEGroupedMatmulConfigSMCount:
std::memcpy(&config_.sm_count, buf, attr_size);
break;
Expand Down
9 changes: 6 additions & 3 deletions transformer_engine/common/gemm/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ struct GroupedMatmulConfig {
// Number of streaming multiprocessors to use in GEMM kernel
int sm_count = 0;

// Whether to use split accumulator for FP8 GEMM
bool use_split_accumulator = false;

// Note: API transfers the value type, not std::optional
static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type),
sizeof(decltype(avg_n)::value_type),
sizeof(decltype(avg_k)::value_type), sizeof(sm_count)};
static constexpr size_t attr_sizes[] = {
sizeof(decltype(avg_m)::value_type), sizeof(decltype(avg_n)::value_type),
sizeof(decltype(avg_k)::value_type), sizeof(uint8_t), sizeof(sm_count)};
};

} // namespace transformer_engine
Expand Down
Loading