diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu
index a7aabbbcb6..abe92059b5 100644
--- a/tests/cpp/operator/test_grouped_gemm.cu
+++ b/tests/cpp/operator/test_grouped_gemm.cu
@@ -88,16 +88,16 @@ struct TestParams {
 std::vector<std::vector<size_t>> make_shapes(ShapeCase scase) {
   switch (scase) {
     case ShapeCase::kAllSame:
-      return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
+      return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}};
     case ShapeCase::kSameFirst:
       // Same M (first dim), varying N and K
-      return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
+      return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}};
     case ShapeCase::kSameLast:
       // Same N (last dim), varying M and K
-      return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
+      return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}};
     case ShapeCase::kAllDifferent:
     default:
-      return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
+      return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}};
   }
 }
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 8e3b0517ee..43bf4d3ad6 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -4,6 +4,7 @@
 
 import math
 import os
+from torch._tensor import Tensor
 from typing import Dict, List, Tuple, Optional
 import pytest
 import random
@@ -46,7 +47,12 @@
     is_nvfp4_available,
 )
 from transformer_engine.pytorch import checkpoint as te_checkpoint
-from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
+from transformer_engine.pytorch.cpp_extensions import (
+    general_gemm,
+    general_grouped_gemm,
+    general_grouped_gemm_for_grouped_tensor,
+)
+from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from utils import ModelConfig, reset_rng_states
@@ -2792,6 +2798,189 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
         os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
 
 
+def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
+    data = grouped_tensor.rowwise_data
+    if data is None:
+        data = grouped_tensor.columnwise_data
+    if data is None:
+        raise ValueError("GroupedTensor has no data buffers to pack.")
+    offset = 0
+    for tensor in tensors:
+        numel = tensor.numel()
+        data[offset : offset + numel].copy_(tensor.reshape(-1))
+        offset += numel
+
+
+def _make_grouped_tensor_from_splits(
+    m_sizes: List[int],
+    last_dim: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> GroupedTensor:
+    first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64)
+    return GroupedTensor.make_grouped_tensor(
+        num_tensors=len(m_sizes),
+        first_dims=first_dims,
+        last_dims=None,
+        logical_first_dim=sum(m_sizes),
+        logical_last_dim=last_dim,
+        quantizer=None,
+        device=device,
+        dtype=dtype,
+    )
+
+
+def _make_grouped_tensor_uniform(
+    num_tensors: int,
+    first_dim: int,
+    last_dim: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> GroupedTensor:
+    return GroupedTensor.make_grouped_tensor(
+        num_tensors=num_tensors,
+        first_dims=None,
+        last_dims=None,
+        logical_first_dim=num_tensors * first_dim,
+        logical_last_dim=last_dim,
+        quantizer=None,
+        device=device,
+        dtype=dtype,
+    )
+
+
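+# The test below packs per-group tensors into GroupedTensors with the helpers above,
+# runs the grouped GEMM once without and once with a grouped bias, and checks both
+# results against per-group torch.matmul references computed in FP32.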
+@pytest.mark.parametrize(
+    "z, m, n, k",
+    [
+        (4, 256, 256, 256),
+        (4, 512, 256, 512),
+        (4, 512, 512, 256),
+        (8, 512, 256, 512),
+    ],
+)
+@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
+@pytest.mark.parametrize("accumulate", [False, True])
+def test_grouped_gemm_grouped_tensor(z, m, n, k, layout, accumulate) -> None:
+    if tex.get_cublasLt_version() < 130200:
+        pytest.skip("Grouped GEMM requires cuBLAS 13.2+.")
+    if torch.cuda.get_device_capability() < (10, 0):
+        pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
+    if not is_bf16_available():
+        pytest.skip("bfloat16 is required for grouped GEMM test.")
+
+    torch.manual_seed(0)
+
+    dtype = torch.bfloat16
+
+    split_points = torch.randperm(m - 1)[: z - 1] + 1
+    split_points = torch.sort(split_points).values.tolist()
+    m_sizes = [split_points[0]]
+    m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])]
+    m_sizes.append(m - split_points[-1])
+    assert sum(m_sizes) == m and len(m_sizes) == z
+
+    if layout == "NT":
+        A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes]  # input
+        B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes]  # grad_output
+        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
+        out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)]
+    else:
+        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
+        B = [
+            torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda")
+            for ms in m_sizes
+        ]  # TN --> input, NN --> grad_output
+        out = [
+            torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda")
+            for ms in m_sizes
+        ]  # TN --> output, NN --> dgrad
+        if layout == "NN":
+            out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)]
+        else:  # layout == "TN"
+            out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)]
+
+    if accumulate:
+        out_ref = [out[i].float() + o for i, o in enumerate(out_ref)]
+
+    # Bias is applied after GEMM (broadcast along rows).
+    # Match kernel behavior: GEMM output is already in output dtype when bias is added.
+    out_ref = [o.to(dtype) for o in out_ref]
+    if layout == "TN":
+        bias_last_dim = n
+    else:  # layout == "NT" or "NN"
+        bias_last_dim = k
+    bias = [torch.randn(1, bias_last_dim, dtype=dtype, device="cuda") * 0.01 for _ in range(z)]
+    # Bias add in grouped kernel accumulates in FP32 for BF16/FP16.
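+    # For TN the bias spans the output's N dimension; for NN and NT it spans K, i.e. it is
+    # always a [1, last_dim_of_output] row vector broadcast over the rows of each group.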
+    out_ref = [(o.float() + b.float()).to(dtype) for o, b in zip(out_ref, bias)]
+
+    # Create grouped tensors based on case
+    device = A[0].device
+    grouped_A = A
+    grouped_out = out
+    grouped_out_bias = None
+    grouped_out_no_bias = None
+    if layout == "TN":
+        grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype)  # weight
+        grouped_B = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)  # input
+        grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)  # output
+        grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)
+        grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)
+    elif layout == "NN":
+        grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype)  # weight
+        grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)  # grad_output
+        grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)
+        grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)
+        grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)
+    else:  # layout == "NT"
+        grouped_A = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype)  # input
+        grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype)  # grad_output
+        grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype)  # wgrad
+        grouped_out_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype)
+        grouped_out_no_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype)
+    _pack_grouped_tensor(grouped_B, B)
+    _pack_grouped_tensor(grouped_out, out)
+    _pack_grouped_tensor(grouped_out_bias, out)
+    _pack_grouped_tensor(grouped_out_no_bias, out)
+    _pack_grouped_tensor(grouped_A, A)
+
+    grouped_bias = _make_grouped_tensor_uniform(z, 1, bias_last_dim, device, dtype)
+    _pack_grouped_tensor(grouped_bias, bias)
+
+    general_grouped_gemm_for_grouped_tensor(
+        grouped_A,
+        grouped_B,
+        grouped_out_no_bias,
+        layout=layout,
+        accumulate=accumulate,
+        bias=None,
+    )
+    general_grouped_gemm_for_grouped_tensor(
+        grouped_A,
+        grouped_B,
+        grouped_out_bias,
+        layout=layout,
+        accumulate=accumulate,
+        bias=grouped_bias,
+    )
+    out_grouped_no_bias = (
+        grouped_out_no_bias
+        if isinstance(grouped_out_no_bias, list)
+        else grouped_out_no_bias.split_into_quantized_tensors()
+    )
+    out_grouped_bias = (
+        grouped_out_bias
+        if isinstance(grouped_out_bias, list)
+        else grouped_out_bias.split_into_quantized_tensors()
+    )
+    out_grouped_manual_bias = [
+        (o.float() + b.float()).to(dtype) for o, b in zip(out_grouped_no_bias, bias)
+    ]
+    tols = dtype_tols(dtype)
+    for o, o_ref in zip(out_grouped_bias, out_ref):
+        torch.testing.assert_close(o, o_ref, **tols)
+    for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias):
+        torch.testing.assert_close(o, o_ref, **tols)
+
+
 @pytest.mark.parametrize("N", [32])
 @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize(
diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp
index 286fc0cc96..de533909f6 100644
--- a/transformer_engine/common/gemm/config.cpp
+++ b/transformer_engine/common/gemm/config.cpp
@@ -153,6 +153,12 @@ void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
              static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ",
              size_in_bytes, " bytes)");
 
+  // bool size is implementation-dependent, so we explicitly specify
+  // uint8_t in the user-facing API.
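+  // Editorial sketch (assuming the getter mirrors the setter's (config, attr, buf, size)
+  // argument order): callers read the flag into a 1-byte buffer, e.g.
+  //   uint8_t flag;
+  //   nvte_get_grouped_matmul_config_attribute(config, kNVTEGroupedMatmulConfigUseSplitAccumulator,
+  //                                            &flag, sizeof(flag));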
+  auto bool_to_uint8 = [](bool in, void *out) {
+    *reinterpret_cast<uint8_t *>(out) = static_cast<uint8_t>(in);
+  };
+
   // Write to buffer
   NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
   const auto &config_ = *reinterpret_cast<const GroupedMatmulConfig *>(config);
@@ -172,6 +178,9 @@
       std::memcpy(buf, &val, attr_size);
       break;
     }
+    case kNVTEGroupedMatmulConfigUseSplitAccumulator:
+      bool_to_uint8(config_.use_split_accumulator, buf);
+      break;
     case kNVTEGroupedMatmulConfigSMCount:
       std::memcpy(buf, &config_.sm_count, attr_size);
      break;
@@ -194,6 +203,12 @@ void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
              " bytes)");
   NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
 
+  // bool size is implementation-dependent, so we explicitly specify
+  // uint8_t in the user-facing API.
+  auto uint8_to_bool = [](const void *in, bool &out) {
+    out = static_cast<bool>(*reinterpret_cast<const uint8_t *>(in));
+  };
+
   // Read from buffer
   NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
   auto &config_ = *reinterpret_cast<GroupedMatmulConfig *>(config);
@@ -216,6 +231,9 @@ void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
       config_.avg_k = val;
       break;
     }
+    case kNVTEGroupedMatmulConfigUseSplitAccumulator:
+      uint8_to_bool(buf, config_.use_split_accumulator);
+      break;
     case kNVTEGroupedMatmulConfigSMCount:
       std::memcpy(&config_.sm_count, buf, attr_size);
       break;
diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h
index ad38e88334..9ab8201fff 100644
--- a/transformer_engine/common/gemm/config.h
+++ b/transformer_engine/common/gemm/config.h
@@ -44,10 +44,13 @@ struct GroupedMatmulConfig {
   // Number of streaming multiprocessors to use in GEMM kernel
   int sm_count = 0;
 
+  // Whether to use split accumulator for FP8 GEMM
+  bool use_split_accumulator = false;
+
   // Note: API transfers the value type, not std::optional
-  static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type),
-                                          sizeof(decltype(avg_n)::value_type),
-                                          sizeof(decltype(avg_k)::value_type), sizeof(sm_count)};
+  static constexpr size_t attr_sizes[] = {
+      sizeof(decltype(avg_m)::value_type), sizeof(decltype(avg_n)::value_type),
+      sizeof(decltype(avg_k)::value_type), sizeof(uint8_t), sizeof(sm_count)};
 };
 
 }  // namespace transformer_engine
diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
index dc4757ab90..53b5c6a628 100644
--- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
@@ -11,6 +11,7 @@
 #include
 #include
+#include
 
 #include "../common.h"
 #include "../util/cuda_runtime.h"
@@ -364,7 +365,7 @@ inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA,
 }
 
 inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A,
-                             cublasOperation_t op_B) {
+                             cublasOperation_t op_B, bool use_split_accumulator, bool use_fp8) {
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
 
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A,
@@ -383,6 +384,11 @@
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
                                                    CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE,
                                                    &alphabeta_batch_stride, sizeof(int64_t)));
+
+  // Fast accumulation mode: 0 = split accumulator (more accurate), 1 = fast accumulator
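+  // Resulting mode: split accumulator requested -> 0; otherwise 1 for FP8 inputs and
+  // 0 for non-FP8 inputs, for which fast accumulation does not apply.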
+  int8_t fastAccuMode = use_split_accumulator ? 0 : use_fp8;
+  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
+                                                   &fastAccuMode, sizeof(fastAccuMode)));
 }
 
 inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc,
@@ -463,6 +469,36 @@ __forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorSha
   }
 }
 
+template <typename T>
+__global__ void grouped_bias_add_kernel(char *d_base, const char *bias_base, TensorShapeInfo d_meta,
+                                        TensorShapeInfo bias_meta, size_t num_tensors) {
+  const size_t tensor_idx = blockIdx.x;
+  if (tensor_idx >= num_tensors) return;
+
+  const int64_t m = d_meta.first_dims ? d_meta.first_dims[tensor_idx] : d_meta.uniform_first;
+  const int64_t n = d_meta.last_dims ? d_meta.last_dims[tensor_idx] : d_meta.uniform_last;
+  if (m == 0 || n == 0) return;
+
+  const int64_t bias_n =
+      bias_meta.last_dims ? bias_meta.last_dims[tensor_idx] : bias_meta.uniform_last;
+
+  const int64_t d_offset = compute_grouped_tensor_offset(d_meta, tensor_idx);
+  const int64_t bias_offset = compute_grouped_tensor_offset(bias_meta, tensor_idx);
+
+  auto *d_ptr = reinterpret_cast<T *>(d_base + d_offset * sizeof(T));
+  const auto *bias_ptr = reinterpret_cast<const T *>(bias_base + bias_offset * sizeof(T));
+
+  const int64_t elements = m * n;
+  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.y;
+  for (int64_t linear = static_cast<int64_t>(blockIdx.y) * blockDim.x + threadIdx.x;
+       linear < elements; linear += stride) {
+    const int64_t col = linear % n;
+    if (col < bias_n) {
+      d_ptr[linear] = d_ptr[linear] + bias_ptr[col];
+    }
+  }
+}
+
 // Single kernel that sets up all GEMM parameters.
 // Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions,
 // but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes.
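To make that layout mismatch concrete, here is an editorial Python sketch (not part of the patch) of how per-matrix regions sit inside a GroupedTensor's single contiguous buffer; the offsets are plain cumulative element counts, exactly what `_pack_grouped_tensor` in the test above assumes and what the setup kernel materializes as device pointers for cuBLASLt (`matrix_views` is a hypothetical helper name):

```python
# Per-matrix views into a GroupedTensor's flat rowwise buffer.
def matrix_views(data, shapes):
    views, offset = [], 0
    for rows, cols in shapes:  # per-matrix (first_dim, last_dim)
        views.append(data[offset : offset + rows * cols].view(rows, cols))
        offset += rows * cols
    return views
```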
@@ -631,7 +667,8 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
   // Create matmul descriptor
   cublasLtMatmulDescOpaque_t matmulDesc;
-  init_matmul_desc(matmulDesc, op_A, op_B);
+  const bool use_fp8 = is_fp8_dtype(A_sel.dtype) || is_fp8_dtype(B_sel.dtype);
+  init_matmul_desc(matmulDesc, op_A, op_B, config_.use_split_accumulator, use_fp8);
   set_fp8_scale_pointers(matmulDesc, A_sel, B_sel);
 
   // Compute average dimensions for heuristics
@@ -654,6 +691,54 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
                                      kGroupedGemmCublasWorkspaceSize, stream));
 }
 
+void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias,
+                           cudaStream_t stream) {
+  NVTE_API_CALL(nvte_grouped_bias_add);
+  using namespace transformer_engine;
+
+  const GroupedTensor *outputD = convertNVTEGroupedTensorCheck(output);
+  const GroupedTensor *bias_tensor = convertNVTEGroupedTensorCheck(bias);
+
+  NVTE_CHECK(outputD->num_tensors >= 1, "Grouped bias add: number of tensors must be at least 1");
+  NVTE_CHECK(outputD->num_tensors == bias_tensor->num_tensors,
+             "Grouped bias add: output and bias must have the same number of tensors");
+  NVTE_CHECK(outputD->has_data(), "Grouped bias add: output is missing row-wise data");
+  NVTE_CHECK(bias_tensor->has_data(), "Grouped bias add: bias is missing row-wise data");
+  NVTE_CHECK(outputD->dtype() == bias_tensor->dtype(),
+             "Grouped bias add: output and bias must have matching dtypes");
+  NVTE_CHECK(bias_tensor->all_same_first_dim(),
+             "Grouped bias add: bias must have uniform first dim (expected 1)");
+  NVTE_CHECK(bias_tensor->get_common_first_dim() == 1,
+             "Grouped bias add: bias first dim must be 1");
+  if (outputD->all_same_last_dim() && bias_tensor->all_same_last_dim()) {
+    NVTE_CHECK(outputD->get_common_last_dim() == bias_tensor->get_common_last_dim(),
+               "Grouped bias add: output and bias last dims must match");
+  }
+
+  const TensorShapeInfo d_meta = TensorShapeInfo::from_tensor(outputD);
+  const TensorShapeInfo bias_meta = TensorShapeInfo::from_tensor(bias_tensor);
+
+  const DType dtype = outputD->dtype();
+  constexpr int kThreads = 256;
+  constexpr int kMaxBlocksPerTensor = 128;
+  const size_t total_elements = static_cast<size_t>(outputD->logical_shape.data[0]) *
+                                static_cast<size_t>(outputD->logical_shape.data[1]);
+  const size_t avg_elements = total_elements / outputD->num_tensors;
+  int blocks_per_tensor = static_cast<int>((avg_elements + kThreads - 1) / kThreads);
+  if (blocks_per_tensor < 1) blocks_per_tensor = 1;
+  if (blocks_per_tensor > kMaxBlocksPerTensor) blocks_per_tensor = kMaxBlocksPerTensor;
+  const dim3 grid(outputD->num_tensors, blocks_per_tensor);
+  const dim3 block(kThreads);
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, T, {
+    grouped_bias_add_kernel<T><<<grid, block, 0, stream>>>(
+        static_cast<char *>(outputD->data.dptr), static_cast<const char *>(bias_tensor->data.dptr),
+        d_meta, bias_meta, outputD->num_tensors);
+  });
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
 #else // CUBLAS_VERSION < 130200
 
 void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
@@ -665,6 +750,12 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
              CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
 }
 
+void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias,
+                           cudaStream_t stream) {
+  NVTE_ERROR("nvte_grouped_bias_add requires cuBLAS 13.2+, but compile-time cuBLAS version is ",
+             CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
+}
+
 size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) {
   NVTE_ERROR(
       "nvte_get_grouped_gemm_setup_workspace_size requires cuBLAS 13.2+, but compile-time cuBLAS "
diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h
index 0f3b0ebd6b..ce0aff933e 100644
--- a/transformer_engine/common/include/transformer_engine/gemm.h
+++ b/transformer_engine/common/include/transformer_engine/gemm.h
@@ -82,8 +82,10 @@ enum NVTEGroupedMatmulConfigAttribute {
   *  computed automatically from A's logical shape. */
   kNVTEGroupedMatmulConfigAvgK = 2,
+  /*! Whether to use split accumulator for FP8 GEMM. */
+  kNVTEGroupedMatmulConfigUseSplitAccumulator = 3,
   /*! Number of streaming multiprocessors to use in GEMM kernel. */
-  kNVTEGroupedMatmulConfigSMCount = 3,
+  kNVTEGroupedMatmulConfigSMCount = 4,
   kNVTEGroupedMatmulConfigNumAttributes
 };
 
@@ -359,6 +361,8 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
                        const NVTETensor beta, NVTETensor workspace_setup,
                        NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
                        cudaStream_t stream);
+void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias,
+                           cudaStream_t stream);
 
 #ifdef __cplusplus
 }  // extern "C"
@@ -522,6 +526,13 @@ class GroupedMatmulConfigWrapper {
                                              sizeof(int));
   }
 
+  /*! \brief Set whether to use split accumulator for FP8 GEMM. */
+  void set_use_split_accumulator(bool use_split_accumulator) {
+    const auto val = static_cast<uint8_t>(use_split_accumulator);
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigUseSplitAccumulator,
+                                             &val, sizeof(val));
+  }
+
  private:
   /*! \brief Wrapped NVTEGroupedMatmulConfig. */
   NVTEGroupedMatmulConfig config_ = nullptr;
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index a37f1c2d4d..fa58166252 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -5,6 +5,7 @@
 """Python interface for GEMM extensions"""
 from typing import Iterable, Optional, Tuple, Union, List
+import ctypes
 import os
 import functools
 import torch
@@ -22,6 +23,7 @@
 __all__ = [
     "general_gemm",
     "general_grouped_gemm",
+    "general_grouped_gemm_for_grouped_tensor",
 ]
 
 
@@ -80,7 +82,7 @@ def general_gemm(
     layout: str = "TN",
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    use_split_accumulator: bool = False,
+    use_split_accumulator: bool = True,
     grad: bool = False,
     ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None,
     ub_type: tex.CommOverlapType = None,
@@ -284,3 +286,99 @@ def general_grouped_gemm(
     )
 
     return out, bias, gelu_input
+
+
+def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int:
+    """Return workspace size for grouped GEMM pointer setup.
+
+    Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu.
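+
+    Layout assumed by the formula below (mirroring GroupedGemmSetupWorkspace): six per-tensor
+    pointer arrays, each padded to a 16-byte boundary, followed by six per-tensor int arrays,
+    with the total rounded up to a 256-byte boundary.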
+ """ + ptr_bytes = ctypes.sizeof(ctypes.c_void_p) + int_bytes = ctypes.sizeof(ctypes.c_int) + ptr_size = num_tensors * ptr_bytes + int_size = num_tensors * int_bytes + k_ptr_alignment = 16 + aligned_ptr_size = ((ptr_size + k_ptr_alignment - 1) // k_ptr_alignment) * k_ptr_alignment + size = 6 * aligned_ptr_size + 6 * int_size + alignment = 256 + return ((size + alignment - 1) // alignment) * alignment + + +def general_grouped_gemm_for_grouped_tensor( + A, + B, + out, + *, + layout: str = "TN", + accumulate: bool = False, + use_split_accumulator: bool = False, + bias=None, + alpha: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Grouped GEMM using GroupedTensor inputs. + + This uses nvte_grouped_gemm and supports different per-matrix shapes. + + The caller must ensure that GroupedTensor metadata is already compatible with the + underlying GEMM implementation (e.g., aligned offsets and output metadata layout). + """ + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." + transa = layout[0] == "T" + transb = layout[1] == "T" + + num_tensors = A.num_tensors + + if out.rowwise_data is not None: + device = out.data.device + elif out.columnwise_data is not None: + device = out.columnwise_data.device + else: + raise ValueError("Output GroupedTensor must have allocated data.") + if bias is not None: + if bias.rowwise_data is None: + raise ValueError("Bias GroupedTensor must have rowwise_data.") + if bias.num_tensors != num_tensors: + raise ValueError("Bias GroupedTensor must match num_tensors.") + if bias.rowwise_data.device != device: + raise ValueError("Bias GroupedTensor must be on the same device as output.") + + if alpha is None: + alpha = torch.ones(num_tensors, dtype=torch.float32, device=device) + if beta is None: + if accumulate: + beta = torch.ones(num_tensors, dtype=torch.float32, device=device) + else: + beta = torch.zeros(num_tensors, dtype=torch.float32, device=device) + + if not alpha.is_cuda or not beta.is_cuda: + raise ValueError("alpha and beta must be CUDA tensors.") + + workspace_setup = torch.empty( + get_grouped_gemm_setup_workspace_size(num_tensors), + dtype=torch.uint8, + device=device, + ) + workspace_cublas = torch.empty( + get_cublas_workspace_size_bytes(), + dtype=torch.uint8, + device=device, + ) + + sm_count = get_sm_count() + sm_count = sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))) + + return tex.te_general_grouped_gemm_for_grouped_tensor( + A, + transa, + B, + transb, + out, + bias, + alpha, + beta, + workspace_setup, + workspace_cublas, + use_split_accumulator, + sm_count, + ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e4d4e5094c..fc647bf47a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -149,6 +149,11 @@ std::optional> te_general_grouped_gemm( std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +py::object te_general_grouped_gemm_for_grouped_tensor( + py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias, + at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, + bool use_split_accumulator, int math_sm_count); + /*************************************************************************************************** * Transpose 
 **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index d75b0f14c7..87a1abd8ff 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -78,6 +78,46 @@ bool checkGemmShape(const std::vector<size_t>& expected, const NVTEShape& actual
   return true;
 }
 
+struct GroupedGemmConfig {
+  TensorWrapper te_alpha;
+  TensorWrapper te_beta;
+  TensorWrapper te_workspace_setup;
+  TensorWrapper te_workspace_cublas;
+  std::optional<GroupedMatmulConfigWrapper> matmul_config;
+};
+
+GroupedGemmConfig prepare_grouped_gemm_config(at::Tensor alpha, at::Tensor beta,
+                                              at::Tensor workspace_setup,
+                                              at::Tensor workspace_cublas, size_t num_tensors,
+                                              int math_sm_count, bool use_split_accumulator) {
+  NVTE_CHECK(alpha.numel() == static_cast<int64_t>(num_tensors),
+             "Grouped GEMM expects alpha to have num_tensors elements.");
+  NVTE_CHECK(beta.numel() == static_cast<int64_t>(num_tensors),
+             "Grouped GEMM expects beta to have num_tensors elements.");
+
+  GroupedGemmConfig grouped_gemm_config{
+      makeTransformerEngineTensor(alpha),
+      makeTransformerEngineTensor(beta),
+      makeTransformerEngineTensor(workspace_setup.data_ptr(),
+                                  std::vector<size_t>{static_cast<size_t>(workspace_setup.numel())},
+                                  DType::kByte),
+      makeTransformerEngineTensor(
+          workspace_cublas.data_ptr(),
+          std::vector<size_t>{static_cast<size_t>(workspace_cublas.numel())}, DType::kByte),
+      std::nullopt,
+  };
+
+  if (math_sm_count > 0 || use_split_accumulator) {
+    grouped_gemm_config.matmul_config.emplace();
+    if (math_sm_count > 0) {
+      grouped_gemm_config.matmul_config->set_sm_count(math_sm_count);
+    }
+    grouped_gemm_config.matmul_config->set_use_split_accumulator(use_split_accumulator);
+  }
+
+  return grouped_gemm_config;
+}
+
 }  // namespace detail
 
 std::pair<TensorWrapper, py::object> createOutputTensor(const std::vector<size_t>& shape,
@@ -570,4 +610,52 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
   return bias;
 }
 
+py::object te_general_grouped_gemm_for_grouped_tensor(
+    py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias,
+    at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas,
+    bool use_split_accumulator, int math_sm_count) {
+  using namespace transformer_engine::pytorch::detail;
+
+  init_extension();
+
+  // Ensure that cublasLt handle is created on the correct device,
+  // overriding torch.cuda.set_device calls from user side.
+  // Assumes all tensors passed are on the same device.
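+  // (workspace_cublas is allocated by the Python wrapper on the output's device, so it is a
+  // convenient tensor from which to pick the device here.)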
+  at::cuda::CUDAGuard device_guard(workspace_cublas.device());
+
+  auto grouped_A = GroupedTensorFromPyTorchGroupedTensor(A);
+  auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B);
+  auto grouped_D = GroupedTensorFromPyTorchGroupedTensor(D);
+
+  const size_t num_tensors = grouped_A.num_tensors();
+  NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs.");
+  NVTE_CHECK(grouped_B.num_tensors() == num_tensors,
+             "Grouped GEMM requires A and B to have the same num_tensors.");
+  NVTE_CHECK(grouped_D.num_tensors() == num_tensors,
+             "Grouped GEMM requires D to have the same num_tensors as inputs.");
+
+  auto gemm_config = prepare_grouped_gemm_config(alpha, beta, workspace_setup, workspace_cublas,
+                                                 num_tensors, math_sm_count, use_split_accumulator);
+
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_grouped_gemm(grouped_A.data(), transa, grouped_B.data(), transb, grouped_D.data(),
+                      grouped_D.data(), gemm_config.te_alpha.data(), gemm_config.te_beta.data(),
+                      gemm_config.te_workspace_setup.data(), gemm_config.te_workspace_cublas.data(),
+                      gemm_config.matmul_config.has_value()
+                          ? static_cast<NVTEGroupedMatmulConfig>(*gemm_config.matmul_config)
+                          : nullptr,
+                      at::cuda::getCurrentCUDAStream());
+  });
+
+  if (!bias.is_none()) {
+    auto grouped_bias = GroupedTensorFromPyTorchGroupedTensor(bias);
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_grouped_bias_add(grouped_D.data(), grouped_bias.data(),
+                            at::cuda::getCurrentCUDAStream());
+    });
+  }
+
+  return py::reinterpret_borrow<py::object>(D);
+}
+
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 8302a13010..5585d0d93b 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -276,6 +276,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
   m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
         "Grouped GEMM");
+  m.def("te_general_grouped_gemm_for_grouped_tensor",
+        &transformer_engine::pytorch::te_general_grouped_gemm_for_grouped_tensor,
+        "Grouped GEMM for GroupedTensor");
   m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
         py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
         py::call_guard());
diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp
index eda5e8fc54..36c4f39818 100644
--- a/transformer_engine/pytorch/csrc/type_converters.cpp
+++ b/transformer_engine/pytorch/csrc/type_converters.cpp
@@ -216,8 +216,8 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) {
   auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode);
 
   // Rowwise data
-  if (!tensor.attr("data").is_none()) {
-    const auto &data = tensor.attr("data").cast<at::Tensor>();
+  if (!tensor.attr("rowwise_data").is_none()) {
+    const auto &data = tensor.attr("rowwise_data").cast<at::Tensor>();
     DType data_dtype = quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type())
                                            : quantizer_dtype;
     ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data));
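For reviewers, a minimal usage sketch of the new Python entry point, distilled from the `test_grouped_gemm_grouped_tensor` test added above. It is illustrative and not part of the patch; the `GroupedTensor.make_grouped_tensor` arguments and group sizes simply mirror the test helpers for the TN layout:

```python
import torch
from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm_for_grouped_tensor
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor

device, dtype = torch.device("cuda"), torch.bfloat16
m_sizes, n, k = [128, 384], 256, 512  # two groups with different row counts

# TN layout: A holds one [n, k] weight per group (uniform first dim),
# while B and D are row-split along the group sizes.
first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64)
A = GroupedTensor.make_grouped_tensor(
    num_tensors=2, first_dims=None, last_dims=None,
    logical_first_dim=2 * n, logical_last_dim=k,
    quantizer=None, device=device, dtype=dtype,
)
B = GroupedTensor.make_grouped_tensor(
    num_tensors=2, first_dims=first_dims, last_dims=None,
    logical_first_dim=sum(m_sizes), logical_last_dim=k,
    quantizer=None, device=device, dtype=dtype,
)
D = GroupedTensor.make_grouped_tensor(
    num_tensors=2, first_dims=first_dims, last_dims=None,
    logical_first_dim=sum(m_sizes), logical_last_dim=n,
    quantizer=None, device=device, dtype=dtype,
)
# ... fill A and B, e.g. with the test's _pack_grouped_tensor helper ...
general_grouped_gemm_for_grouped_tensor(A, B, D, layout="TN", accumulate=False, bias=None)
per_group_outputs = D.split_into_quantized_tensors()  # one [m_i, n] output per group
```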