Skip to content

[CUDA] Use qmv kernel for fp quantizations#3239

Merged
zcbenz merged 1 commit intoml-explore:mainfrom
zcbenz:remove-fp-qmv
Mar 11, 2026
Merged

[CUDA] Use qmv kernel for fp quantizations#3239
zcbenz merged 1 commit intoml-explore:mainfrom
zcbenz:remove-fp-qmv

Conversation

@zcbenz
Copy link
Collaborator

@zcbenz zcbenz commented Mar 10, 2026

Use the QMV kernel for fp quantizations.

Did a simple benchmarking and it is about 9% faster on A100.

Details
import time
import mlx.core as mx

M,N,K = (1, 16384, 16384)

x = mx.random.normal(shape=(M, K), dtype=mx.float16)
w = mx.random.normal(shape=(N, K), dtype=mx.float16)

w_q, scales = mx.quantize(w, mode='mxfp4')
y = mx.quantized_matmul(x, w_q, scales, transpose=True, mode='mxfp4')
mx.eval(y)

def fun():
    y = mx.quantized_matmul(x, w_q, scales, transpose=True, mode='mxfp4')
    mx.eval(y)

for _ in range(100):
    fun()

iterations = 1000
tic = time.time()
for _ in range(iterations):
    fun()
toc = time.time()

s = toc - tic
gb = iterations * (x.nbytes + w_q.nbytes + scales.nbytes + y.nbytes) / 1e9

print("{:5.2f}".format(gb / s))

Copy link
Member

@angeloskath angeloskath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect! Do you think we should test on Hopper and Blackwell as well?

@zcbenz
Copy link
Collaborator Author

zcbenz commented Mar 11, 2026

Tested on H100 this is also 9% faster, but on DGX the old fp_qmv is at least 20% faster 😲.

@zcbenz zcbenz force-pushed the remove-fp-qmv branch 2 times, most recently from c675806 to 2fc2dfc Compare March 11, 2026 07:03
@zcbenz
Copy link
Collaborator Author

zcbenz commented Mar 11, 2026

For sm120/121 the fp_qmv kernel is faster because it uses less registers, I'm not entirely sure whether this is a Blackwell thing or because the consumer hardwares have less registers and do not optimize vectorized instructions (There is no spare B200 for me to test).

Anyway I'm keeping fp_qmv for now for sm >= 100, and I'm going to work on another version of qmv that is implemented in a similar style using less registers, assuming the affine quantization would also benefit from it.

@zcbenz zcbenz changed the title [CUDA] Merge fp_qmv into qmv [CUDA] Use qmv kernel for fp quantizations Mar 11, 2026
@zcbenz zcbenz merged commit ce45c52 into ml-explore:main Mar 11, 2026
16 checks passed
@zcbenz zcbenz deleted the remove-fp-qmv branch March 11, 2026 22:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants