
[JAX] Add bias support for v2 grouped GEMM path #2744

Open
jberchtold-nvidia wants to merge 1 commit into NVIDIA:main from jberchtold-nvidia:jberchtold/gmm-bias

Conversation

@jberchtold-nvidia
Collaborator

Description

The cuda-graphable (v2) grouped GEMM FFI does not natively support bias. This change applies bias in pure JAX after the GEMM in GroupedGemmPrimitive.impl, using a per-token expert index built from group_sizes to gather the correct bias row for each token.

A dedicated unit test (test_grouped_gemm_fp16_with_bias) is added to directly exercise the v2 path with a non-None bfloat16 bias.
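For illustration, the per-token gather described above can be sketched in standalone JAX. The shapes and names here (`M`, `N`, `group_sizes`, `bias`) are invented for the example, not taken verbatim from `GroupedGemmPrimitive.impl`:

```python
import jax.numpy as jnp

M, N = 6, 4                                          # total tokens, output features
group_sizes = jnp.array([3, 1, 2], dtype=jnp.int32)  # tokens routed to each expert
num_groups = group_sizes.shape[0]

# One bias row per expert (hypothetical values).
bias = jnp.arange(num_groups * N, dtype=jnp.float32).reshape(num_groups, N)

# Groups are contiguous, so repeating each expert id by its group size yields
# the expert index for every token. total_repeat_length gives the result a
# static shape, which is required under jit / cuda graphs.
segment_ids = jnp.repeat(
    jnp.arange(num_groups, dtype=jnp.int32),
    group_sizes,
    total_repeat_length=M,
)
# segment_ids == [0, 0, 0, 1, 2, 2]

out = jnp.zeros((M, N), dtype=jnp.float32)
out = out + bias[segment_ids]  # each token receives its expert's bias row
```

Because `total_repeat_length` is a static argument, this stays compatible with the cuda-graphable v2 path, which cannot tolerate data-dependent output shapes.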

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Update gemm.py to support pure-JAX bias computation
  • Add a new dedicated test for grouped GEMM with bias (the GroupedDense tests already covered bias)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

@greptile-apps
Contributor

greptile-apps bot commented Mar 6, 2026

Greptile Summary

This PR adds bias support to the cuda-graphable v2 grouped GEMM path in JAX by applying bias in pure JAX after the GEMM kernel (since the v2 C++ FFI does not natively support it). The _can_use_v2_grouped_gemm guard that previously blocked the v2 path when has_bias=True is removed, and GroupedGemmPrimitive.impl now builds a per-token expert index via jnp.repeat over group_sizes to gather and add the correct bias row for each token.

Key changes:

  • _can_use_v2_grouped_gemm no longer requires not has_bias; v2 is now used for all BF16 no-scaling inputs regardless of bias presence.
  • GroupedGemmPrimitive.impl applies the gathered bias (bias[segment_ids]) after the GEMM when use_v2_ffi and has_bias.
  • A new test test_grouped_gemm_fp16_with_bias (despite the name, uses bfloat16) is added to TestGroupedDense to directly exercise the v2+bias path.

Issues found:

  • The bias addition in impl does not guard against is_grouped_dense_wgrad=True, where the output shape is (num_groups, M, N). Broadcasting bias[segment_ids] of shape (M, N) onto that tensor would silently produce semantically incorrect results.
  • The new test method name test_grouped_gemm_fp16_with_bias is misleading — it exclusively tests bfloat16, not float16.
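The wgrad-shape hazard in the first issue can be reproduced in a few lines. `G`, `M`, `N` and the array names here are illustrative only:

```python
import jax.numpy as jnp

G, M, N = 3, 4, 2
wgrad_out = jnp.zeros((G, M, N))  # wgrad-mode output: one (M, N) matrix per group
token_bias = jnp.ones((M, N))     # shape of bias[segment_ids] in the bias-add code

# NumPy-style broadcasting aligns trailing dims, so (M, N) broadcasts against
# (G, M, N) without any error -- the same token-mapped bias is added to every
# group's matrix, which is semantically wrong for a weight gradient.
result = wgrad_out + token_bias
assert result.shape == (G, M, N)  # succeeds silently
```

This is why a guard or assertion on `is_grouped_dense_wgrad` is safer than relying on a shape error to catch the misuse.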

Confidence Score: 3/5

  • Safe to merge for the common forward-pass use case, but contains an unguarded code path that would silently produce incorrect results if bias is ever passed alongside a 2D rhs (wgrad mode) with BF16 inputs.
  • The core bias-gathering logic (jnp.repeat + gather + add) is correct for the standard (M, N) output shape. However, the impl does not assert or guard against the is_grouped_dense_wgrad=True case where output is (G, M, N); the broadcast would succeed silently but produce wrong values. Additionally, the test name mismatch (fp16 vs bfloat16) reduces confidence in test coverage clarity.
  • transformer_engine/jax/cpp_extensions/gemm.py — specifically the bias addition block in GroupedGemmPrimitive.impl (lines 1688–1698) needs a guard for is_grouped_dense_wgrad.

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Removes has_bias guard from _can_use_v2_grouped_gemm and adds pure-JAX bias application in GroupedGemmPrimitive.impl using jnp.repeat/gather. The logic is correct for the common (M, N) output shape but doesn't guard against the is_grouped_dense_wgrad=True case where output is (G, M, N), leading to semantically incorrect broadcasting if both flags are ever set simultaneously.
tests/jax/test_custom_call_compute.py Adds test_grouped_gemm_fp16_with_bias to TestGroupedDense, exercising the v2 bfloat16 GEMM path with a non-None bias. The test logic is correct (uses existing helpers and reference function), but the method name says "fp16" while the dtype is hardcoded to bfloat16.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["grouped_gemm()"] --> B{"_can_use_v2_grouped_gemm?\nBF16 + NO_SCALING"}
    B -- "No (FP8, MXFP8, fp16...)" --> C["Legacy v1 FFI\nte_grouped_gemm_ffi\n(native bias support)"]
    B -- "Yes" --> D["v2 FFI\nte_grouped_gemm_v2_ffi\n(no native bias)"]
    D --> E["GroupedGemmPrimitive.impl\ninner_primitive.bind(...)"]
    E --> F{"use_v2_ffi\nand has_bias?"}
    F -- "No" --> G["Return GEMM output"]
    F -- "Yes" --> H["Build segment_ids\njnp.repeat(arange(G), group_sizes, M)"]
    H --> I["Gather bias rows\nbias[segment_ids] → shape (M, N)"]
    I --> J["out = out + bias[segment_ids].astype(out.dtype)"]
    J --> G
    C --> G
```

Last reviewed commit: 1d5faa2

```python
        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

    @pytest_parametrize_wrapper("layout", ["NN"])
    def test_grouped_gemm_fp16_with_bias(self, input_shape, layout):
```

Misleading test name: "fp16" actually uses bfloat16

The method is named test_grouped_gemm_fp16_with_bias, but the body hardcodes dtype = jnp.bfloat16 and the docstring explicitly states "bfloat16 only, since v2 is only active for bfloat16 no-scaling inputs." This makes the name actively confusing for anyone trying to find tests for fp16 or bfloat16 coverage.

Suggested change

```diff
-def test_grouped_gemm_fp16_with_bias(self, input_shape, layout):
+def test_grouped_gemm_bf16_with_bias(self, input_shape, layout):
```

Comment on lines +1688 to +1698
```python
if use_v2_ffi and has_bias:
    # The C++ FFI for v2 grouped GEMM does not support bias, so we apply it here in
    # pure JAX. Groups are contiguous, so we build a per-token expert index via
    # jnp.repeat and gather the corresponding bias row for each token.
    num_groups = group_sizes.shape[0]
    segment_ids = jnp.repeat(
        jnp.arange(num_groups, dtype=jnp.int32),
        group_sizes,
        total_repeat_length=M,
    )
    out = out + bias[segment_ids].astype(out.dtype)
```

Bias application is incorrect when is_grouped_dense_wgrad=True

When is_grouped_dense_wgrad=True (triggered when rhs is 2D — len(rhs_shape) == 2), the GEMM output shape is (num_groups, M, N) (see abstract: out_shape = (num_groups, M, N)). However, the bias addition code computes bias[segment_ids] with segment_ids.shape == (M,), yielding a (M, N) tensor. Broadcasting (M, N) onto (num_groups, M, N) does not raise an error at runtime, but it adds the same token-mapped bias to every group's output matrix rather than leaving the output unchanged for each group's individual result — semantically incorrect.

While this combination (is_grouped_dense_wgrad=True and has_bias=True and use_v2_ffi=True) is unlikely in the existing call sites (weight-gradient paths don't normally carry a bias), the code as written has no guard against it, and a future caller could hit this silently wrong behavior. Consider adding an assertion:

```python
if use_v2_ffi and has_bias:
    assert not is_grouped_dense_wgrad, (
        "Bias is not supported for the grouped dense wgrad path with v2 FFI."
    )
    # ... rest of bias application
```
