-
Notifications
You must be signed in to change notification settings - Fork 655
Pytorch binding for cublas gemm #2669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3e7859c
ffeace8
aa86859
98cb4fa
b4c91c5
1d041f8
3b2840e
d5f9569
733a061
d433e5f
38cf811
158d232
63efd1b
70abb18
ea33349
d0604eb
296dbb2
9fd6b8b
b48ef61
6d04d62
7a6a590
1c94649
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
|
|
||
| import math | ||
| import os | ||
| from torch._tensor import Tensor | ||
| from typing import Dict, List, Tuple, Optional | ||
| import pytest | ||
| import random | ||
|
|
@@ -46,7 +47,12 @@ | |
| is_nvfp4_available, | ||
| ) | ||
| from transformer_engine.pytorch import checkpoint as te_checkpoint | ||
| from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm | ||
| from transformer_engine.pytorch.cpp_extensions import ( | ||
| general_gemm, | ||
| general_grouped_gemm, | ||
| general_grouped_gemm_for_grouped_tensor, | ||
| ) | ||
| from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor | ||
| from transformer_engine.common import recipe | ||
| import transformer_engine_torch as tex | ||
| from utils import ModelConfig, reset_rng_states | ||
|
|
@@ -2792,6 +2798,189 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): | |
| os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) | ||
|
|
||
|
|
||
| def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: | ||
| data = grouped_tensor.rowwise_data | ||
| if data is None: | ||
| data = grouped_tensor.columnwise_data | ||
| if data is None: | ||
| raise ValueError("GroupedTensor has no data buffers to pack.") | ||
| offset = 0 | ||
| for tensor in tensors: | ||
| numel = tensor.numel() | ||
| data[offset : offset + numel].copy_(tensor.reshape(-1)) | ||
| offset += numel | ||
|
|
||
|
|
||
| def _make_grouped_tensor_from_splits( | ||
| m_sizes: List[int], | ||
| last_dim: int, | ||
| device: torch.device, | ||
| dtype: torch.dtype, | ||
| ) -> GroupedTensor: | ||
| first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) | ||
| return GroupedTensor.make_grouped_tensor( | ||
| num_tensors=len(m_sizes), | ||
| first_dims=first_dims, | ||
| last_dims=None, | ||
| logical_first_dim=sum(m_sizes), | ||
| logical_last_dim=last_dim, | ||
| quantizer=None, | ||
| device=device, | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
|
|
||
| def _make_grouped_tensor_uniform( | ||
| num_tensors: int, | ||
| first_dim: int, | ||
| last_dim: int, | ||
| device: torch.device, | ||
| dtype: torch.dtype, | ||
| ) -> GroupedTensor: | ||
| return GroupedTensor.make_grouped_tensor( | ||
| num_tensors=num_tensors, | ||
| first_dims=None, | ||
| last_dims=None, | ||
| logical_first_dim=num_tensors * first_dim, | ||
| logical_last_dim=last_dim, | ||
| quantizer=None, | ||
| device=device, | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "z, m, n, k", | ||
| [ | ||
| (4, 256, 256, 256), | ||
| (4, 512, 256, 512), | ||
| (4, 512, 512, 256), | ||
| (8, 512, 256, 512), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) | ||
| @pytest.mark.parametrize("accumulate", [False, True]) | ||
| def test_grouped_gemm_grouped_tensor(z, m, n, k, layout, accumulate) -> None: | ||
| if tex.get_cublasLt_version() < 130200: | ||
| pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") | ||
| if torch.cuda.get_device_capability() < (10, 0): | ||
| pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") | ||
| if not is_bf16_available(): | ||
| pytest.skip("bfloat16 is required for grouped GEMM test.") | ||
|
|
||
| torch.manual_seed(0) | ||
|
|
||
| dtype = torch.bfloat16 | ||
|
|
||
| split_points = torch.randperm(m - 1)[: z - 1] + 1 | ||
| split_points = torch.sort(split_points).values.tolist() | ||
| m_sizes = [split_points[0]] | ||
| m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] | ||
| m_sizes.append(m - split_points[-1]) | ||
| assert sum(m_sizes) == m and len(m_sizes) == z | ||
|
|
||
| if layout == "NT": | ||
| A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input | ||
| B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output | ||
| out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad | ||
| out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)] | ||
| else: | ||
| A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight | ||
| B = [ | ||
| torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda") | ||
| for ms in m_sizes | ||
| ] # TN --> input, NN --> grad_output | ||
| out = [ | ||
| torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda") | ||
| for ms in m_sizes | ||
| ] # TN --> output, NN --> dgrad | ||
| if layout == "NN": | ||
| out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)] | ||
| else: # layout == "TN" | ||
| out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)] | ||
|
|
||
| if accumulate: | ||
| out_ref = [out[i].float() + o for i, o in enumerate(out_ref)] | ||
|
|
||
| # Bias is applied after GEMM (broadcasted along rows) | ||
| # Match kernel behavior: GEMM output is already in output dtype when bias is added. | ||
| out_ref = [o.to(dtype) for o in out_ref] | ||
| if layout == "TN": | ||
| bias_last_dim = n | ||
| else: # layout == "NT" or "NN" | ||
| bias_last_dim = k | ||
| bias = [torch.zeros(1, bias_last_dim, dtype=dtype, device="cuda") * 0.01 for _ in range(z)] | ||
| # Bias add in grouped kernel accumulates in FP32 for BF16/FP16. | ||
| out_ref = [(o.float() + b.float()).to(dtype) for o, b in zip(out_ref, bias)] | ||
| # Create grouped tensors based on case | ||
| device = A[0].device | ||
| grouped_A = A | ||
| grouped_out = out | ||
| grouped_out_bias = None | ||
| grouped_out_no_bias = None | ||
| if layout == "TN": | ||
| grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype) # | ||
| grouped_B = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input | ||
| grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # output | ||
| grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) | ||
| grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) | ||
| elif layout == "NN": | ||
| grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype) # weight | ||
| grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output | ||
| grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) | ||
| grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) | ||
| grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) | ||
| else: # layout == "NT" | ||
| grouped_A = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input | ||
| grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output | ||
| grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) # wgrad | ||
| grouped_out_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) | ||
| grouped_out_no_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) | ||
| _pack_grouped_tensor(grouped_B, B) | ||
| _pack_grouped_tensor(grouped_out, out) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Consider removing this line (and the corresponding |
||
| _pack_grouped_tensor(grouped_out_bias, out) | ||
| _pack_grouped_tensor(grouped_out_no_bias, out) | ||
| _pack_grouped_tensor(grouped_A, A) | ||
|
|
||
| grouped_bias = _make_grouped_tensor_uniform(z, 1, bias_last_dim, device, dtype) | ||
| _pack_grouped_tensor(grouped_bias, bias) | ||
|
|
||
| general_grouped_gemm_for_grouped_tensor( | ||
| grouped_A, | ||
| grouped_B, | ||
| grouped_out_no_bias, | ||
| layout=layout, | ||
| accumulate=accumulate, | ||
| bias=None, | ||
| ) | ||
| general_grouped_gemm_for_grouped_tensor( | ||
| grouped_A, | ||
| grouped_B, | ||
| grouped_out_bias, | ||
| layout=layout, | ||
| accumulate=accumulate, | ||
| bias=grouped_bias, | ||
| ) | ||
| out_grouped_no_bias = ( | ||
| grouped_out_no_bias | ||
| if isinstance(grouped_out_no_bias, list) | ||
| else grouped_out_no_bias.split_into_quantized_tensors() | ||
| ) | ||
| out_grouped_bias = ( | ||
| grouped_out_bias | ||
| if isinstance(grouped_out_bias, list) | ||
| else grouped_out_bias.split_into_quantized_tensors() | ||
| ) | ||
| out_grouped_manual_bias = [ | ||
| (o.float() + b.float()).to(dtype) for o, b in zip(out_grouped_no_bias, bias) | ||
| ] | ||
| tols = dtype_tols(dtype) | ||
| for o, o_ref in zip(out_grouped_bias, out_ref): | ||
| torch.testing.assert_close(o, o_ref, **tols) | ||
| for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias): | ||
| torch.testing.assert_close(o, o_ref, **tols) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("N", [32]) | ||
| @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) | ||
| @pytest.mark.parametrize( | ||
|
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Unused private-module import
`Tensor` is imported from `torch._tensor`, which is an internal/private PyTorch module. This import is never referenced in the new test code (`test_grouped_gemm_grouped_tensor` and its helpers), so it is dead code. Importing from private modules (prefixed with `_`) is fragile — it can break without notice across PyTorch releases.