
HipKittens MXFP8 GEMM Support #566

Open
alextmagro wants to merge 33 commits into dev from hipkittens_mxfp8

Conversation

@alextmagro alextmagro commented Apr 28, 2026

Contributor

Creates an MXFP8 GEMM with HipKittens that outperforms hipBLASLt and offers additional epilogues such as BIAS and GELU AUX.

Requires a workspace sized relative to the model; it is often larger than the hipBLASLt workspace, but comes with significant performance improvements. Only builds for gfx950, and requires M and N to be divisible by 256.
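
The constraints above could be checked before dispatch with something like the following. This is a minimal sketch, not the PR's actual dispatch logic: `can_use_hipkittens_mxfp8`, its parameters, and the reading of "M / 256 and N / 256" as "M and N are multiples of 256" are assumptions made for illustration.

```python
TILE = 256  # assumed granularity, read from "M / 256 and N / 256" in the description
SUPPORTED_ARCHS = ("gfx950",)  # the only architecture the description says it builds for


def can_use_hipkittens_mxfp8(m: int, n: int, arch: str) -> bool:
    """Hypothetical eligibility check; a caller would fall back to hipBLASLt on False."""
    # Architecture strings may carry feature suffixes (e.g. "gfx950:sramecc+"),
    # so match on the prefix only.
    if not arch.startswith(SUPPORTED_ARCHS):
        return False
    # Both output dimensions must be positive multiples of the tile size.
    return m > 0 and n > 0 and m % TILE == 0 and n % TILE == 0
```

Under these assumptions a 512 x 1024 problem on gfx950 qualifies, while a 500 x 1024 problem, or any problem on another architecture, would take the existing hipBLASLt path.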

Adds the HipKittens header library as a submodule.

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/jax/utils.py
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.hip Outdated
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py Outdated
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
@alextmagro alextmagro requested a review from wangye805 May 5, 2026 20:26
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
[](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
return MKN(std::get<0>(info.param)) + "x" +
std::to_string(std::get<1>(info.param)) + "x" +

Collaborator

What is the point? They are only ever set to false.

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.h
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/common/CMakeLists.txt Outdated

return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
key = (device, ub, grouped_gemm)
ws = _workspace_cache.get(key)

Collaborator

Why don't we rely on torch memory caching?

Contributor Author

I have made this change. I will need to do an E2E run to make sure that performance isn't affected, but it should be OK given my understanding of torch.empty().

Collaborator

It doesn't seem changed.
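
For readers following this thread, the two approaches being compared can be sketched as below. This is an illustration only, not the PR's code: `get_workspace`, `alloc`, `size_bytes`, and `use_cache` are hypothetical names, and the allocator is injected so the sketch runs without a GPU. In the real module the allocator would presumably be a call such as `torch.empty(size, dtype=torch.uint8, device=device)`, whose freed memory PyTorch's caching allocator can hand back on the next request.

```python
from typing import Callable, Dict, Hashable, Tuple

# Module-level cache keyed the same way as the snippet under review.
_workspace_cache: Dict[Tuple[Hashable, bool, bool], object] = {}


def get_workspace(device: Hashable, ub: bool, grouped_gemm: bool, size_bytes: int,
                  alloc: Callable[[int], object], use_cache: bool = True):
    """Return a workspace buffer, optionally memoized per (device, ub, grouped_gemm)."""
    if not use_cache:
        # Reviewer's suggestion: allocate on every call and rely on the
        # framework's caching allocator to make repeated allocations cheap.
        return alloc(size_bytes)
    key = (device, ub, grouped_gemm)
    ws = _workspace_cache.get(key)
    if ws is None:
        # First request for this key: allocate once and keep the buffer alive.
        ws = alloc(size_bytes)
        _workspace_cache[key] = ws
    return ws
```

The memoized path keeps each buffer alive for the lifetime of the process and, as written here, ignores later changes in `size_bytes`. The uncached path avoids both issues at the cost of one allocator call per GEMM, which is the trade-off the E2E run mentioned above would measure.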

@alextmagro alextmagro requested review from aris134 and ipanfilo May 12, 2026 13:24
@alextmagro alextmagro requested a review from ipanfilo May 14, 2026 17:18
@alextmagro alextmagro added the ci-level 3 (CI test level 3) label and removed the ci-level 1 (CI test level 1) label May 14, 2026
Comment thread transformer_engine/jax/cpp_extensions/gemm.py
Comment thread transformer_engine/common/gemm/rocm_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/rocm_gemm.cu
Comment thread transformer_engine/common/gemm/rocm_gemm.cu
if (use_hipkittens) {
auto param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k);

hipStream_t s = use_service_stream ? ss_ctl.stream : stream;

Collaborator

Same as with is_mxfp8: there is no point in having it defined for one branch only.

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
@@ -743,12 +786,15 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16)

INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,

Collaborator

If you end up with a separate prefix for MXFP8, it has to be used for this suite as well, for consistency.

Contributor Author

Done.

Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.h
Comment thread transformer_engine/jax/cpp_extensions/gemm.py
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py Outdated
@alextmagro alextmagro requested a review from ipanfilo May 18, 2026 20:43
Comment thread tests/cpp/operator/test_cublaslt_gemm.cu Outdated
@alextmagro alextmagro requested a review from ipanfilo May 29, 2026 19:34
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
@alextmagro alextmagro requested a review from ipanfilo June 17, 2026 21:09
Comment thread transformer_engine/common/gemm/kittens/mxfp8_gemm.cpp
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
Comment thread transformer_engine/jax/cpp_extensions/gemm.py
@alextmagro alextmagro requested a review from ipanfilo June 18, 2026 22:47

@ipanfilo ipanfilo left a comment

Collaborator

Some comments in rocm_gemm and test_cublaslt_gemm are still open.


Labels

ci-level 3 (CI test level 3)


5 participants