diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..2f2d5383b2 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1787,13 +1787,19 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm + empty_gs = jnp.empty((0,), jnp.int32) prim_out = jax.jit( tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") )( lhs, rhs, - group_sizes, - contracting_dims, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, + contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1825,8 +1831,18 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + empty_gs = jnp.empty((0,), jnp.int32) prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + lhs, + rhs, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, ) allclose_dtype = jnp.float8_e4m3fn diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4506adf33b..6a41cfc94e 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1331,17 +1331,57 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) +def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int: + """Non-contracting output size M from the 2-D LHS buffer.""" + return lhs_shape_2d[1] if 
lhs_is_trans else lhs_shape_2d[0] + + +def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int: + """Non-contracting output size N from the 2-D RHS buffer.""" + return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1] + + +def _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups: int, +) -> None: + """Assert that all non-empty *_dims arrays have exactly num_groups elements. + + rhs_first_dims / rhs_last_dims describe the ragged contracting K dimension. + K totals need not fill the entire buffer (padding is allowed), so only the + array length is checked, not the per-group sum. + """ + for name, aval in [ + ("lhs_first_dims", lhs_first_dims_aval), + ("lhs_last_dims", lhs_last_dims_aval), + ("out_first_dims", out_first_dims_aval), + ("out_last_dims", out_last_dims_aval), + ("rhs_first_dims", rhs_first_dims_aval), + ("rhs_last_dims", rhs_last_dims_aval), + ]: + if aval.size > 0: + assert ( + aval.size == num_groups + ), f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). 
""" - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, group_offset, unused_placeholder name = "te_grouped_gemm_ffi" - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, alpha, beta + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, + # lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, + # out_first_dims, out_last_dims, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + impl_static_args = (13, 14, 15, 16, 17, 18, 19) inner_primitive = None outer_primitive = None @@ -1352,17 +1392,18 @@ def abstract( rhs_data_aval, rhs_scale_inv_aval, bias_aval, - group_sizes_aval, + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, *additional_args, # group_offset_aval, unused_placeholder OR alpha_aval, beta_aval - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1370,35 +1411,66 @@ def abstract( Grouped GEMM operation. 
Args: - lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_data: Left-hand side input matrix data, 2D array [rows, cols] lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_data: Right-hand side input matrix data, 2D array [rows, cols] rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) - group_sizes: 1D array containing the sizes of each group + lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel + rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel + out_group_sizes: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel additional_args: Either * group_offsets: 1D array containing offsets for each group (not yet implemented) OR * alpha: 1D array of shape (G,) containing alpha values for each group * beta: 1D array of shape (G,) containing beta values for each group - M: Number of rows in the output matrix - N: Number of columns in the output matrix - K: Number of columns in the left-hand side matrix lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed scaling_mode: Scaling mode for the GEMM operations out_dtype: Data type of the output tensors has_bias: Boolean indicating if bias tensors are provided - is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation - where both lhs and rhs are 2D matrices and output is (G, M, N) Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval - del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes + del bias_aval + del has_bias, use_async_d2h_group_sizes + + num_groups = ( + lhs_first_dims_aval.size + or lhs_last_dims_aval.size + or rhs_first_dims_aval.size + or 
rhs_last_dims_aval.size + or out_first_dims_aval.size + or out_last_dims_aval.size + or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 + ) + + _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups, + ) - num_groups = group_sizes_aval.size + # lhs_data_aval and rhs_data_aval are 2D; derive output shape from buffer dims. + # lhs shape: [M, K] (lhs_is_trans=False) or [K, M] (lhs_is_trans=True) + # rhs shape: [G*K, N] or [K, N] (rhs_is_trans=False) or [G*N, K] (rhs_is_trans=True) + M = _grouped_gemm_lhs_M(lhs_data_aval.shape, lhs_is_trans) + # K validation is intentionally skipped: per-group K values may not fill the + # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. + if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: + # Wgrad case: rhs has ragged contracting K dimension with no G-prefix. + # T-layout rhs shape is (N, K_total); N-layout rhs shape is (K_total, N). + N = rhs_data_aval.shape[0] if rhs_is_trans else rhs_data_aval.shape[1] + out_shape = (num_groups, M, N) + else: + # When rhs has a leading group axis, _grouped_gemm_rhs_N divides by num_groups. + N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) + out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( shape=( @@ -1409,9 +1481,6 @@ def abstract( dtype=jnp.uint8, ) - out_shape = (M, N) - if is_grouped_dense_wgrad: - out_shape = (num_groups, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) if use_v2_ffi: @@ -1419,7 +1488,24 @@ def abstract( shape=(get_grouped_gemm_setup_workspace_size(num_groups),), dtype=jnp.uint8 ) # Temporary buffer for int32 -> int64 conversion of group_sizes on device. 
- int64_workspace_size = num_groups * jnp.dtype(jnp.int64).itemsize + # Each non-empty *_dims buffer needs its own slot of num_groups int64 elements so that + # make_grouped_tensor can write to a distinct region per ragged dimension. Allocate + # exactly as many slots as there are non-empty buffers (minimum 1 to avoid zero-size). + num_ragged_dim_buffers = sum( + 1 + for aval in [ + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + ] + if aval.size > 0 + ) + int64_workspace_size = ( + max(num_ragged_dim_buffers, 1) * num_groups * jnp.dtype(jnp.int64).itemsize + ) int64_workspace_aval = jax.core.ShapedArray( shape=(int64_workspace_size,), dtype=jnp.uint8 ) @@ -1484,15 +1570,11 @@ def outer_abstract(*args, **kwargs): def lowering( ctx, *args, - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1502,26 +1584,18 @@ def lowering( return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) ffi_name = GroupedGemmPrimitive.name return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) @@ -1532,18 +1606,19 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, # group_offset (non-graph-safe) OR alpha (graph-safe) additional_arg_1, # unused placeholder (non-graph-safe) OR beta (graph-safe) - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, 
use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1559,17 +1634,18 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, *additional_args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, ) @@ -1875,13 +1951,35 @@ def _can_use_v2_grouped_gemm( if not _v2_grouped_gemm_available: return False + # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). + # Fall back to the v1 path on SM90 (Hopper) and older architectures. + if get_device_compute_capability(0) < 100: + return False + return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias +def _flatten_to_2d(data, flatten_axis): + """Reshape *data* to 2D by splitting at *flatten_axis*. + + Positive flatten_axis: split before that axis index. + Negative flatten_axis: split before (ndim + flatten_axis). 
+ """ + if data.ndim == 2: + return data # Already 2D, no reshape needed + fa = flatten_axis if flatten_axis >= 0 else data.ndim + flatten_axis + return data.reshape(math.prod(data.shape[:fa]), math.prod(data.shape[fa:])) + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - group_sizes: jnp.ndarray, + lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) + lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) + rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) + rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) + out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) + out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -1896,7 +1994,12 @@ def grouped_gemm( Args: lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - group_sizes: 1D array containing the sizes of each group + lhs_first_dims: (G,) int32 if LHS squashed first dim varies per group, else None/(0,) + lhs_last_dims: (G,) int32 if LHS squashed last dim varies per group, else None/(0,) + rhs_first_dims: (G,) int32 if RHS squashed first dim varies per group (wgrad), else None/(0,) + rhs_last_dims: (G,) int32 if RHS squashed last dim varies per group, else None/(0,) + out_first_dims: (G,) int32 if output first dim varies per group, else None/(0,) + out_last_dims: (G,) int32 if output last dim varies per group, else None/(0,) contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape 
(G, N) precision: JAX precision for the GEMM operation @@ -1916,6 +2019,15 @@ def grouped_gemm( # TODO(Phuong): implement the precision del precision + # Replace None sentinels with empty (0,) int32 arrays. + empty_gs = jnp.empty((0,), jnp.int32) + lhs_first_dims = empty_gs if lhs_first_dims is None else lhs_first_dims + lhs_last_dims = empty_gs if lhs_last_dims is None else lhs_last_dims + rhs_first_dims = empty_gs if rhs_first_dims is None else rhs_first_dims + rhs_last_dims = empty_gs if rhs_last_dims is None else rhs_last_dims + out_first_dims = empty_gs if out_first_dims is None else out_first_dims + out_last_dims = empty_gs if out_last_dims is None else out_last_dims + if isinstance(lhs, jnp.ndarray): if not isinstance(rhs, jnp.ndarray): raise TypeError( @@ -1937,8 +2049,10 @@ def grouped_gemm( out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape rhs_shape = rhs.original_shape - lhs_data = lhs.data - rhs_data = rhs.data + lhs_fa = lhs.flatten_axis + rhs_fa = rhs.flatten_axis + lhs_data = lhs.data.reshape(math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:])) + rhs_data = rhs.data.reshape(math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:])) lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv if lhs.scaling_mode != rhs.scaling_mode: @@ -1957,26 +2071,12 @@ def grouped_gemm( lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) - # rhs_shape [G, K, N] - rhs_is_trans = rhs_contract_dim[0] != 1 + # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). + # This formula handles both standard rhs [G, K, N] (G-prefixed) and wgrad + # rhs [K_total, N] (no G prefix) without needing a separate wgrad override. 
+ rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - is_grouped_dense_wgrad = False - if len(rhs_shape) == 2: - rhs_is_trans = rhs_contract_dim[0] != 0 - is_grouped_dense_wgrad = True - - # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this? - if ( - is_grouped_dense_wgrad - and not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - ): - lhs_is_trans = True - rhs_is_trans = False - lhs_flatten_axis = 1 - rhs_flatten_axis = 1 - if ( not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) @@ -2007,16 +2107,37 @@ def grouped_gemm( quantizer_set.kernel.q_layout = ( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) - lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) + active_group_sizes = next( + ( + gs + for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] + if gs.size > 0 + ), + empty_gs, + ) + lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) rhs_q = grouped_quantize( rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) - lhs_data = lhs_q.data - rhs_data = rhs_q.data + # grouped_quantize returns a 1D flat buffer; reshape to 2D using the + # original_shape and flatten_axis stored in each quantized tensor. + lhs_fa = lhs_q.flatten_axis # positive index (adjusted in create_1x) + rhs_fa = rhs_q.flatten_axis + lhs_data = lhs_q.data.reshape( + math.prod(lhs_q.original_shape[:lhs_fa]), + math.prod(lhs_q.original_shape[lhs_fa:]), + ) + rhs_data = rhs_q.data.reshape( + math.prod(rhs_q.original_shape[:rhs_fa]), + math.prod(rhs_q.original_shape[rhs_fa:]), + ) lhs_scale_inv = lhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv lhs_shape = lhs_q.original_shape rhs_shape = rhs_q.original_shape + # Data is already 2D; reset flatten axes so _flatten_to_2d calls below are no-ops. 
+ lhs_flatten_axis = -1 + rhs_flatten_axis = -1 if lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2: raise ValueError("FP8 GEMM does not support E5M2 * E5M2") @@ -2044,38 +2165,41 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if group_sizes.size == rhs_shape[0]: + if ( + lhs_first_dims.size > 0 or lhs_last_dims.size > 0 + ): # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) else: rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - # Calling GroupedGEMM Custom Call - K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) - K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - if K_lhs != K_rhs: + # Reshape inputs to 2D using the already-computed flatten_axes. + lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) + rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) + + num_gemms = ( + lhs_first_dims.size + or lhs_last_dims.size + or rhs_first_dims.size + or rhs_last_dims.size + or out_first_dims.size + or out_last_dims.size + ) + if num_gemms == 0: raise ValueError( - f"Mismatched contracting dimensions: K_lhs={K_lhs}, K_rhs={K_rhs} (from" - f" lhs_shape={lhs_shape}, rhs_shape={rhs_shape})" + "grouped_gemm requires at least one non-empty dimension array " + "(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, " + "out_first_dims, or out_last_dims)." 
) - M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G - - if is_grouped_dense_wgrad: - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) - else: - if group_sizes.size != rhs_shape[0]: - raise ValueError( - "Expected group_sizes.size == rhs_shape[0], but got" - f" group_sizes.size={group_sizes.size}, rhs_shape[0]={rhs_shape[0]}" - ) has_bias = bias is not None - if has_bias and bias.shape != (group_sizes.size, N): - raise ValueError( - f"Expected bias.shape=({group_sizes.size}, {N}), but got bias.shape={bias.shape}" - ) + if has_bias: + N_dim = rhs_data_2d.shape[0] // num_gemms if rhs_is_trans else rhs_data_2d.shape[1] + assert bias.shape == ( + num_gemms, + N_dim, + ), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}" bias = jnp.empty((), jnp.float32) if bias is None else bias if group_offset is not None: @@ -2087,7 +2211,6 @@ def grouped_gemm( use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) if use_v2_ffi: - num_gemms = group_sizes.shape[0] additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta else: @@ -2095,23 +2218,24 @@ def grouped_gemm( additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data, + lhs_data_2d, lhs_scale_inv, - rhs_data, + rhs_data_2d, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, additional_arg_1, - M=M, - N=N, - K=K_lhs, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, ) diff --git 
a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0fe4e99239..bd429a7db6 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -55,6 +55,20 @@ struct GemmConfig { bool use_split_accumulator; }; +struct GroupedGemmV2Config { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; +}; + +struct GroupedGemmConfig { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; + bool has_bias; + bool use_async_d2h_group_sizes; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -192,6 +206,18 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ::xla::ffi::StructMember("rhs_transposed"), ::xla::ffi::StructMember("use_split_accumulator")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmV2Config, ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmConfig, ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("has_bias"), + ::xla::ffi::StructMember("use_async_d2h_group_sizes")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Score_Function); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 737dd65622..2d73390d33 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -617,137 +617,100 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return 
std::move(grouped_tensor_wrapper); } -// This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. -Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, - Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type cublas_workspace, - Result_Type setup_workspace, Result_Type int64_workspace, size_t m, - size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, - JAXX_Scaling_Mode scaling_mode, bool is_grouped_dense_wgrad) { - // Notes on matrix layouts and transpose: - // Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major [m, k] for N - [k, m] for T - // B: row-major [k, n] for N - [n, k] for T - // on exiting this function, JAX expect: - // C: row-major with size [m, n]. - // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m] for T - [m, k] for N - // B: column-major with size [n, k] for T - [k, n] for N - // - // If we call cuBLAS GEMM for A * B, the output will be: - // C: column-major with size [m, n] --> row-major with size [n, m]. - // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. +// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes +// int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. +// int64_offset (in int64 elements) is updated on return to the next available slot so callers can +// thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked +// before each slot is used. Only NO_SCALING is supported.
+JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, + Buffer_Type const &first_dims, + Buffer_Type const &last_dims, + int64_t *int64_workspace_base, + size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream) { + auto dims = data.dimensions(); + NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); + // Flatten all leading dimensions into the first axis to produce a 2D NVTE shape. + // Input buffers (lhs, rhs) are already 2D from the Python side. Output buffers may be ND + // (e.g. [G, K, N] for wgrad), so we collapse dims[0..N-2] β†’ rows and keep dims[N-1] β†’ cols. + NVTEShape dataShape{.data = {product(dims, 0, dims.size() - 1), dims[dims.size() - 1]}, + .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); + wrapper.set_rowwise(data, std::nullopt); + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; + } + return wrapper; +} - // 
Inputs - auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); - bool has_bias = product(bias.dimensions()) > 0; - auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); +// Returns num_gemms from the first non-empty per-tensor group_sizes buffer, +// falling back to the element count of alpha for the uniform-batch case. +size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, + Buffer_Type const &rhs_first_dims, Buffer_Type const &rhs_last_dims, + Buffer_Type const &out_first_dims, Buffer_Type const &out_last_dims, + Buffer_Type const &alpha) { + if (lhs_first_dims.element_count() > 0) { + return lhs_first_dims.dimensions()[0]; + } else if (lhs_last_dims.element_count() > 0) { + return lhs_last_dims.dimensions()[0]; + } else if (rhs_first_dims.element_count() > 0) { + return rhs_first_dims.dimensions()[0]; + } else if (rhs_last_dims.element_count() > 0) { + return rhs_last_dims.dimensions()[0]; + } else if (out_first_dims.element_count() > 0) { + return out_first_dims.dimensions()[0]; + } else if (out_last_dims.element_count() > 0) { + return out_last_dims.dimensions()[0]; + } else { + return alpha.element_count(); // uniform batch: no ragged tensor + } +} + +} // namespace jax +} // namespace transformer_engine - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; 
+namespace transformer_engine { +namespace jax { - // Convert int32 group_sizes to int64 into the dedicated output buffer. - NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); - auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); - nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), - int64_sizes_ptr, num_gemms, stream); +// This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. +Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type alpha, Buffer_Type beta, Result_Type output, + Result_Type cublas_workspace, Result_Type setup_workspace, + Result_Type int64_workspace, GroupedGemmV2Config config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); - // It is weird that TE/Common GEMM only use colwise for MXFP8 - const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); - const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || - scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; - const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; - const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; - const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, + rhs_last_dims, out_first_dims, out_last_dims, alpha); - // Outputs - auto out_ptr = reinterpret_cast(output->untyped_data()); - auto out_dtype = 
convert_ffi_datatype_to_te_dtype(output->element_type()); + // Workspaces. auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); - // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); - auto workspace_total_size = product(cublas_workspace->dimensions()); - - auto lhs_sinv_size = product(lhs_sinv.dimensions()); - auto rhs_sinv_size = product(rhs_sinv.dimensions()); - const size_t workspace_alignment_padding = 256; - const size_t tensor_scaling_sinv_aligment = 16; - const size_t mxfp8_scaling_sinv_alignment_padding = 256; - auto workspace_size = workspace_total_size - workspace_alignment_padding; - if (is_mxfp8_scaling) { - // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. - workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); - } else if (is_tensor_scaling) { - // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned - // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. 
- workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); - } - auto swizzled_lhs_sinv_ptr = cublas_workspace_ptr + workspace_size; - swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); - auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned - auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; - - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); - size_t actual_lhs_size = product(lhs_data.dimensions()); - size_t actual_rhs_size = product(rhs_data.dimensions()); - size_t actual_out_size = product(output->dimensions()); - NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", - expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, - "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, - " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! 
Expect m * n = ", m, - " * ", n, " = ", expected_out_size, ", got ", actual_out_size); - } else { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, - " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, - "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, - " = ", expected_out_size, ", got ", actual_out_size); - } - - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - bool grad = false; - bool accumulate = false; - bool use_split_accumulator = false; - auto bias_shape = std::vector{has_bias ? n : 0}; - const int arch = cuda::sm_arch(); - - if (arch < 100 && is_fp8_gemm) { - NVTE_CHECK(!lhs_is_trans && rhs_is_trans, - "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", - "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); - } - + auto workspace_size = product(cublas_workspace->dimensions()) - 256; TensorWrapper workspace_setup(setup_workspace_ptr, std::vector{product(setup_workspace->dimensions())}, DType::kByte); @@ -761,59 +724,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - if (is_grouped_dense_wgrad) { - NVTE_CHECK(lhs_is_trans && !rhs_is_trans, - "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); - - //// RHS - NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - - //// LHS - NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; - lhs_is_trans = true; - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - - //// 
OUTPUT - NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - - nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, - alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), - workspace_cublas.data(), - nullptr, // config (use defaults) - stream); - - return ffi_with_cuda_error_check(); - } - - // Nominal case for FWD or DGRAD - - //// RHS - NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; - if (rhs_is_trans) { - rhsShape.data[0] = num_gemms * n; - rhsShape.data[1] = k; - } - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - - //// LHS - NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; - if (lhs_is_trans) { - std::swap(lhsShape.data[0], lhsShape.data[1]); - } - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, - lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); - - //// OUTPUT - NVTEShape outShape{.data = {m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + // Build grouped tensors from XLA buffer shapes and group_sizes β€” no m/n/k derivation needed. + // int64_workspace is partitioned into per-ragged-buffer slots of num_gemms int64 elements each. + // int64_offset is threaded through the three make_grouped_tensor calls so each non-empty *_dims + // buffer gets its own non-aliasing slot; bounds are checked inside make_grouped_tensor. 
+ auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); + size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); + size_t int64_offset = 0; + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); + auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), @@ -827,33 +750,34 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data + .Arg() // lhs_data (2D) .Arg() // lhs_sinv - .Arg() // rhs_data + .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes (int32) + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // alpha .Arg() // beta .Ret() // output .Ret() // cublas_workspace .Ret() // setup_workspace .Ret() // int64_workspace - .Attr("M") - .Attr("N") - .Attr("K") - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("is_grouped_dense_wgrad"), + .Attrs(), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, 
size_t k, bool lhs_is_trans, - bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type group_offset, Result_Type output, Result_Type workspace, + GroupedGemmConfig config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -870,6 +794,55 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type int num_streams = nvte_get_num_compute_streams(); + // Determine which group_sizes buffers are active (non-empty = ragged dimension). + bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; + bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; + + size_t num_gemms; + if (is_lhs_first_ragged) + num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) + num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) + num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) + num_gemms = rhs_last_dims.dimensions()[0]; + else + NVTE_CHECK(false, + "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to " + "determine num_gemms."); + + const Buffer_Type *active_gs_ptr = nullptr; + if (is_lhs_first_ragged) + active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) + 
active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) + active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) + active_gs_ptr = &rhs_last_dims; + + // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. + NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); + NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); + size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; + size_t m, n; + if (is_rhs_ragged) { + // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M + m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; + // T-layout rhs: (N, K_total) -> n = dim[0]; N-layout rhs: (K_total, N) -> n = dim[1] + n = rhs_is_trans ? rhs_data.dimensions()[0] : rhs_data.dimensions()[1]; + } else { + m = lhs_is_trans ? lhs_data.dimensions()[1] + : lhs_data.dimensions()[0]; // total M (sum of group sizes) + n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; + } + // Inputs auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); @@ -882,9 +855,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || @@ -951,14 +921,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? 
(k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_rhs_ragged ? (num_gemms * m * n) : (m * n); size_t actual_lhs_size = product(lhs_data.dimensions()); size_t actual_rhs_size = product(rhs_data.dimensions()); size_t actual_out_size = product(output->dimensions()); NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { + if (!is_rhs_ragged) { NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, " = ", expected_rhs_size, ", got ", actual_rhs_size); @@ -974,25 +944,28 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! 
K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); + if (any_ragged) { + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); + auto gs_data_ptr = reinterpret_cast(active_gs_ptr->untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + if (!is_rhs_ragged) { + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); + } else { + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); + } } auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -1040,7 +1013,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto lhs_shape_i = std::vector{m_i, k}; auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { + if (is_rhs_ragged) { size_t k_i = dim_list_host[i]; lhs_shape_i[0] = lhs_is_trans ? k_i : m; lhs_shape_i[1] = lhs_is_trans ? 
m : k_i; @@ -1230,24 +1203,21 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data + .Arg() // lhs_data (2D) .Arg() // lhs_sinv - .Arg() // rhs_data + .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // group_offset .Ret() // output .Ret() // workspace - .Attr("M") - .Attr("N") - .Attr("K") - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("has_bias") - .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes")); + .Attrs()); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index fe02e61fc0..8b397520f2 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -518,15 +518,21 @@ def _grouped_dense_fwd_rule( # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. 
quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout + empty_gs = jnp.empty((0,), jnp.int32) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, + contracting_dims=contracting_dims, + bias=bias, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, ) ctx = ( @@ -610,11 +616,17 @@ def _grouped_dense_bwd_rule( wgrad_x_T = ctx_x wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) + empty_gs = jnp.empty((0,), jnp.int32) dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - group_sizes, - dgrad_contracting_dims, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, + contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, @@ -623,8 +635,13 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - group_sizes, - wgrad_contracting_dims, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=group_sizes, + rhs_last_dims=empty_gs, + out_first_dims=empty_gs, + out_last_dims=empty_gs, + contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset,