Pytorch binding for cublas gemm + Grouped Linear integration #2669

vthumbe1503 wants to merge 11 commits into NVIDIA:main from vthumbe1503:users/vthumbe/pytorch_binding_for_cublas_gemm
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…b.com:vthumbe1503/TransformerEngine into users/vthumbe/pytorch_binding_for_cublas_gemm
for more information, see https://pre-commit.ci
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Greptile Summary

This PR introduces a PyTorch binding for the new cuBLASLt grouped GEMM (`nvte_grouped_gemm`) and integrates it into `GroupedLinear`.

Key changes:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant PY as GroupedLinear (Python)
    participant GemmPY as gemm.py
    participant CPP as gemm.cpp (C++)
    participant TC as type_converters.cpp
    participant CUDA as cublaslt_grouped_gemm.cu
    PY->>GemmPY: general_grouped_gemm_for_grouped_tensor(A, B, out, m_splits_tensor)
    GemmPY->>GemmPY: allocate workspace_setup (get_grouped_gemm_setup_workspace_size)
    GemmPY->>GemmPY: allocate workspace_cublas (32 MiB)
    GemmPY->>CPP: tex.te_general_grouped_gemm_for_grouped_tensor(A, B, C=out, D=out, alpha, beta, ...)
    CPP->>TC: GroupedTensorFromPyTorchGroupedTensor(A)
    TC-->>CPP: GroupedTensorWrapper (rowwise/columnwise data, shape metadata)
    CPP->>TC: GroupedTensorFromPyTorchGroupedTensor(B)
    TC-->>CPP: GroupedTensorWrapper
    CPP->>TC: GroupedTensorFromPyTorchGroupedTensor(D/C)
    TC-->>CPP: GroupedTensorWrapper
    CPP->>CUDA: nvte_grouped_gemm(A, B, C, D, alpha, beta, ws_setup, ws_cublas, config, stream)
    CUDA->>CUDA: validate_grouped_gemm_inputs()
    CUDA->>CUDA: select_grouped_operand() — rowwise vs columnwise, FP8 TN-only logic
    CUDA->>CUDA: launch setup_grouped_gemm_kernel() — fills A/B/C/D pointer arrays + dimensions
    CUDA->>CUDA: cudaDeviceSynchronize (implicit via stream ordering)
    CUDA->>CUDA: init_matrix_layouts() + init_matmul_desc() + set_fp8_scale_pointers()
    CUDA->>CUDA: select_grouped_gemm_algo() (cuBLASLt heuristics)
    CUDA->>CUDA: cublasLtMatmul() — batched grouped GEMM
    CUDA-->>CPP: result in D (GroupedTensor data buffer)
    CPP-->>GemmPY: D (same Python GroupedTensor object)
    GemmPY-->>PY: output GroupedTensor
```
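The setup step in the diagram (setup_grouped_gemm_kernel filling the per-group pointer arrays and dimensions) can be illustrated with a pure-Python sketch. Everything here is hypothetical for illustration — `build_grouped_gemm_args`, the base addresses, and the row-wise packing layout are not the actual CUDA kernel:

```python
# Illustrative sketch (NOT the real setup_grouped_gemm_kernel): for a grouped GEMM
# D_i = A_i @ B_i where the A and D groups are packed contiguously along the m
# dimension, the setup step materializes one base pointer and one (m, n, k)
# entry per group from m_splits.

def build_grouped_gemm_args(m_splits, n, k, elem_size=2,
                            a_base=0x1000, b_base=0x2000, d_base=0x3000):
    """Compute per-group base pointers and (m, n, k) dims from m_splits."""
    a_ptrs, b_ptrs, d_ptrs, dims = [], [], [], []
    a_off = d_off = 0  # running element offsets into the packed A and D buffers
    for i, m in enumerate(m_splits):
        a_ptrs.append(a_base + a_off * elem_size)      # A_i starts after the previous groups' m*k elements
        b_ptrs.append(b_base + i * k * n * elem_size)  # each B_i is a full k x n matrix
        d_ptrs.append(d_base + d_off * elem_size)
        dims.append((m, n, k))
        a_off += m * k
        d_off += m * n
    return a_ptrs, b_ptrs, d_ptrs, dims

a_ptrs, b_ptrs, d_ptrs, dims = build_grouped_gemm_args([4, 2, 6], n=8, k=16)
```

In the real kernel these arrays live in the setup workspace on device and are consumed directly by cublasLtMatmul; the sketch only shows the bookkeeping.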
```python
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
```
AttributeError in backward_dw when single_weight=True
In backward_dw, weight_params is constructed using range(self.num_gemms), but when single_weight=True, only weight0 is registered as a parameter (self.num_weight_params = 1). Accessing weight1, weight2, etc. will raise AttributeError.
```python
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]  # BUG: should use num_weight_params
bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]  # BUG: same issue
```

This bug surfaces whenever single_weight=True AND delay_wgrad_compute=True. The fix should use self.num_weight_params:
```diff
-weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
-bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
+weight_params = [getattr(self, f"weight{i}") for i in range(self.num_weight_params)]
+bias_params = [getattr(self, f"bias{i}") for i in range(self.num_weight_params)]
```
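The failure mode is easy to reproduce standalone. The toy class below is hypothetical (not the real GroupedLinear), but mirrors the relevant attribute registration:

```python
# Toy repro of the AttributeError: when single_weight=True, only weight0 is
# registered, so iterating over num_gemms reaches attributes that don't exist.
class ToyGroupedLinear:
    def __init__(self, num_gemms, single_weight):
        self.num_gemms = num_gemms
        self.num_weight_params = 1 if single_weight else num_gemms
        for i in range(self.num_weight_params):
            setattr(self, f"weight{i}", f"param_{i}")

m = ToyGroupedLinear(num_gemms=3, single_weight=True)

try:
    # buggy loop bound: range(self.num_gemms) touches weight1, weight2
    [getattr(m, f"weight{i}") for i in range(m.num_gemms)]
    raised = False
except AttributeError:
    raised = True

# fixed loop bound: range(self.num_weight_params) only touches weight0
fixed = [getattr(m, f"weight{i}") for i in range(m.num_weight_params)]
```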
```python
if single_weight:
    bias = False
    return_bias = False
```
Silent bias override with no warning
When single_weight=True, bias and return_bias are silently overridden to False without any warning to the caller. A user who passes bias=True, single_weight=True will get no bias but also no error message or warning. Add a warnings.warn(...) call to notify callers of the implicit override:
```diff
-if single_weight:
-    bias = False
-    return_bias = False
+if single_weight:
+    if bias:
+        warnings.warn(
+            "bias=True is not supported with single_weight=True; bias has been disabled.",
+            UserWarning,
+        )
+    bias = False
+    return_bias = False
```
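The override-with-warning pattern suggested above can be checked in isolation. This is a sketch with a hypothetical `resolve_flags` helper, not the TE source:

```python
import warnings

def resolve_flags(bias, single_weight):
    """Disable bias under single_weight, but tell the caller about it."""
    if single_weight:
        if bias:
            warnings.warn(
                "bias=True is not supported with single_weight=True; bias has been disabled.",
                UserWarning,
            )
        bias = False
    return bias

# Capture the warning to confirm the caller is actually notified.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    bias = resolve_flags(bias=True, single_weight=True)
```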
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
```python
def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int:
    """Return workspace size for grouped GEMM pointer setup.

    Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu.
    """
    ptr_bytes = ctypes.sizeof(ctypes.c_void_p)
    int_bytes = ctypes.sizeof(ctypes.c_int)
    ptr_size = num_tensors * ptr_bytes
    int_size = num_tensors * int_bytes
    k_ptr_alignment = 16
    aligned_ptr_size = ((ptr_size + k_ptr_alignment - 1) // k_ptr_alignment) * k_ptr_alignment
    size = 8 * aligned_ptr_size + 6 * int_size
    alignment = 256
    return ((size + alignment - 1) // alignment) * alignment
```
Workspace size formula doesn't match the C++ implementation
The comment explicitly states this must match GroupedGemmSetupWorkspace::required_setup_size, but the two formulas disagree in both the number of pointer arrays and the intermediate alignment.
C++ (cublaslt_grouped_gemm.cu):

```cpp
// Layout: 6 ptr arrays, then 6 int arrays
size_t size = 6 * ptr_size + 6 * int_size;
```

Python (here):

```python
size = 8 * aligned_ptr_size + 6 * int_size  # 8 × (16-byte-aligned ptr arrays)
```

The workspace has exactly 6 pointer arrays (A_ptrs, B_ptrs, C_ptrs, D_ptrs, alpha_ptrs, beta_ptrs) as confirmed by from_buffers(), so the coefficient should be 6, not 8. The extra per-array 16-byte alignment (k_ptr_alignment) also has no counterpart in the C++ code.
Because Python allocates more than C++ requires, validate_and_get_workspace_ptr (which checks provided_size >= required_size) will always pass and there is no buffer overflow. However the over-allocation grows proportionally with num_tensors and the comment is factually wrong. The formula should be corrected to match C++:
```python
def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int:
    ptr_bytes = ctypes.sizeof(ctypes.c_void_p)
    int_bytes = ctypes.sizeof(ctypes.c_int)
    ptr_size = num_tensors * ptr_bytes
    int_size = num_tensors * int_bytes
    # Layout: 6 ptr arrays, then 6 int arrays — must match
    # GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu
    size = 6 * ptr_size + 6 * int_size
    alignment = 256
    return ((size + alignment - 1) // alignment) * alignment
```
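The over-allocation is easy to quantify by evaluating both formulas side by side. A minimal comparison, assuming a 64-bit platform (8-byte pointers, 4-byte ints):

```python
import ctypes

def setup_ws_size_current(num_tensors: int) -> int:
    """Size as computed by the current Python code: 8 x 16-byte-aligned ptr arrays."""
    ptr_size = num_tensors * ctypes.sizeof(ctypes.c_void_p)
    int_size = num_tensors * ctypes.sizeof(ctypes.c_int)
    aligned_ptr_size = ((ptr_size + 15) // 16) * 16
    size = 8 * aligned_ptr_size + 6 * int_size
    return ((size + 255) // 256) * 256

def setup_ws_size_cpp(num_tensors: int) -> int:
    """Size as required by the C++ side: 6 ptr arrays + 6 int arrays."""
    ptr_size = num_tensors * ctypes.sizeof(ctypes.c_void_p)
    int_size = num_tensors * ctypes.sizeof(ctypes.c_int)
    size = 6 * ptr_size + 6 * int_size
    return ((size + 255) // 256) * 256

# On a 64-bit platform, for 32 tensors: current = 2816 bytes vs required = 2304 bytes,
# i.e. a 512-byte over-allocation that grows with num_tensors.
```

For small group counts the final 256-byte rounding can hide the difference, which is likely why the mismatch went unnoticed; it only becomes visible once `8 * aligned_ptr_size` and `6 * ptr_size` land in different 256-byte buckets.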
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: