[JAX] Grouped GEMM Refactor to use first_dims and last_dims#2749
Draft
jberchtold-nvidia wants to merge 12 commits into NVIDIA:main
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Summary: This PR refactors the JAX Grouped GEMM interface, replacing the single group_sizes parameter with per-tensor first_dims/last_dims arrays. Key issues found:
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["grouped_gemm(lhs, rhs,\nlhs_first_dims, lhs_last_dims,\nrhs_first_dims, rhs_last_dims,\nout_first_dims, out_last_dims, ...)"] --> B{Is rhs ragged?\nrhs_first_dims.size > 0\nor rhs_last_dims.size > 0}
B -->|Yes| C["WGRAD path\nlhs_is_trans=True, rhs_is_trans=False\nlhs_flatten_axis=1, rhs_flatten_axis=1\nout_shape=(num_gemms, M, N)"]
B -->|No| D["FWD / DGRAD path\nDerive lhs_is_trans from contracting_dims\nout_shape=(M_total, N)"]
C --> E{can_use_v2?}
D --> E
E -->|BF16 + NO_SCALING + no bias| F["V2 path\nalpha=ones(G), beta=zeros(G)\nGroupedGemmV2FFI"]
E -->|Otherwise| G["Legacy path\ngroup_offset=zeros(1)\nGroupedGemmFFI\nnvte_multi_tensor_gemm loop"]
F --> H{any_ragged?\nis_lhs_ragged ∥ is_rhs_ragged}
H -->|Yes| I["nvte_convert_int32_to_int64\nactive_gs_ptr from lhs/rhs dims"]
H -->|No| J["⚠ Skip int64 conversion\nactive_gs_ptr = nullptr"]
I --> K{is_rhs_ragged?}
J --> K
K -->|Yes| L["WGRAD branch\nrhs/lhs set_group_sizes_only\nout shape: (num_gemms*M, N)"]
K -->|No| M["FWD/DGRAD branch\nlhs set_group_sizes_only\nout set_group_sizes_only\n⚠ uses int64_sizes_ptr\nout shape: (M_total, N)"]
G --> N["dim_list_host D2H copy\nor async GroupedGemmGetGroupSizes\nnvte_multi_tensor_gemm per-group loop"]
```
Last reviewed commit: ed9c8e4
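The dispatch shown in the flowchart can be sketched in plain Python. This is a simplified stand-in for the C++ logic, not the actual TransformerEngine code; `choose_gemm_path` is a hypothetical helper name, and the dtype/scaling flags are passed as plain values for illustration.

```python
import numpy as np

def choose_gemm_path(rhs_first_dims, rhs_last_dims, dtype, scaling_mode, has_bias):
    """Sketch of the dispatch in the flowchart above (hypothetical helper)."""
    # A non-empty rhs dims array marks the rhs as ragged -> WGRAD path.
    rhs_is_ragged = rhs_first_dims.size > 0 or rhs_last_dims.size > 0
    path = "wgrad" if rhs_is_ragged else "fwd_dgrad"
    # Per the flowchart, the V2 kernel is only taken for plain BF16 GEMMs
    # with no scaling and no bias; everything else falls back to the legacy FFI.
    can_use_v2 = dtype == "bfloat16" and scaling_mode == "NO_SCALING" and not has_bias
    backend = "GroupedGemmV2FFI" if can_use_v2 else "GroupedGemmFFI"
    return path, backend

# Empty (0,) sentinels -> rhs not ragged -> forward/dgrad path.
print(choose_gemm_path(np.empty((0,), np.int32), np.empty((0,), np.int32),
                       "bfloat16", "NO_SCALING", False))
```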
Author comment: /te-ci
Description
This PR refactors the grouped GEMM API in the JAX backend to support fully ragged (variable-size per group)
dimensions across all tensor axes, replacing the previous single group_sizes parameter with six per-tensor
dimension parameters. The motivation is to generalize the interface so that forward and backward (wgrad) passes
can be expressed uniformly without special-casing, and to eliminate the need for callers to manually compute and
pass matrix dimensions (M, N, K) — these are now derived automatically from XLA buffer descriptors in C++.
Addresses issue: #2648
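As a rough illustration of the new calling convention, the per-tensor dimension arrays described above might be built as follows. The array names mirror the description, but this is an assumed sketch (using numpy for self-containment), not the actual refactored API.

```python
import numpy as np

# Hypothetical illustration of the new per-tensor dimension arrays.
# For G groups whose lhs rows are ragged (e.g. tokens per expert),
# the first-dims array holds each group's row count M_i.
group_m = np.array([3, 5, 2], dtype=np.int32)    # ragged first dim of lhs, G=3
lhs_first_dims = group_m

# An empty (0,) int32 array is the sentinel for a uniform (non-ragged) axis,
# so a non-ragged rhs selects the forward/dgrad path.
rhs_first_dims = np.empty((0,), dtype=np.int32)

M_total = int(lhs_first_dims.sum())              # total rows across all groups
print(M_total)
```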
Type of change
Changes
Please list the changes introduced in this PR:
- Replaces the single group_sizes parameter with six new arguments (lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims), each an optional (G,) int32 array describing per-group sizes along that tensor axis; empty (0,) arrays indicate a uniform/non-ragged dimension
- Derives matrix dimensions (M, N, K) and output shapes inside the C++ handler from XLA buffer descriptors, eliminating manual dimension computation in Python
- Selects the wgrad path when the rhs dimension arrays are non-empty (a non-empty rhs_first_dims indicates a ragged K contraction dimension, producing a (num_groups, M, N) output)
- Consolidates the FFI configuration into a single FFI attribute struct, replacing individual attribute bindings
- Converts int32 group-size arrays to int64 in partitioned int64_workspace slots, and returns the updated workspace offset to avoid aliasing
- Updates callers to pass group sizes through the appropriate new per-tensor parameter (lhs_first_dims/out_first_dims for forward; rhs_first_dims for wgrad), with jnp.empty((0,), jnp.int32) sentinels for non-ragged axes
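The int32-to-int64 workspace packing mentioned above can be sketched as follows. This is a hypothetical Python stand-in for the C++ conversion step (the real code uses nvte_convert_int32_to_int64 on device memory); `pack_int64_group_sizes` is an invented name for illustration.

```python
import numpy as np

def pack_int64_group_sizes(workspace, offset, int32_sizes):
    """Sketch: widen int32 group sizes into an int64 workspace slot.

    Writes the converted values starting at `offset` and returns the
    updated offset, so later conversions land in disjoint (non-aliasing)
    partitions of the same workspace buffer.
    """
    n = int32_sizes.size
    workspace[offset:offset + n] = int32_sizes.astype(np.int64)
    return offset + n

# Usage: two arrays packed back to back into one workspace.
ws = np.zeros(8, dtype=np.int64)
off = pack_int64_group_sizes(ws, 0, np.array([3, 5, 2], np.int32))
off = pack_int64_group_sizes(ws, off, np.array([7, 1], np.int32))
print(ws[:off])
```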
Checklist: