From 2fca716a299826ccf393547d85b913eabf0bdbf5 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 6 Mar 2026 00:05:10 -0800 Subject: [PATCH] [CUDA] Faster compilation and batch support in QMV --- mlx/backend/cuda/quantized/qmm/qmm.cpp | 8 +- mlx/backend/cuda/quantized/qmm/qmv.cu | 102 +++++++++++++---------- mlx/backend/cuda/quantized/quantized.cpp | 5 ++ 3 files changed, 64 insertions(+), 51 deletions(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cpp b/mlx/backend/cuda/quantized/qmm/qmm.cpp index c26e2184c7..dea6795aff 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cpp +++ b/mlx/backend/cuda/quantized/qmm/qmm.cpp @@ -135,14 +135,8 @@ bool supports_qmv( int group_size, QuantizationMode mode, cu::Device& device) { - int m = out.shape(-2); - int n = out.shape(-1); int k = x.shape(-1); - int l = out.size() / (m * n); - if (l > 1) { - return false; - } - if (n % 8 != 0 || k % 8 != 0) { + if (k % 8 != 0) { return false; } if (!x.flags().row_contiguous || !w.flags().row_contiguous || diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index ee42454fd4..b3bf8d5cc8 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -22,8 +22,8 @@ template __device__ __forceinline__ void dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) { // Read x/w into registers. - auto x_vec = *(reinterpret_cast*>(x)); - auto w_vec = *(reinterpret_cast*>(w)); + auto x_vec = *(reinterpret_cast*>(x)); + auto w_vec = *(reinterpret_cast*>(w)); // Output is assumed to be registers. auto* out_vec = reinterpret_cast*>(out); @@ -52,8 +52,8 @@ template < __device__ __forceinline__ void dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) { // Read x/w into registers. - auto x_vec = *(reinterpret_cast*>(x)); - auto w_vec = *(reinterpret_cast*>(w)); + auto x_vec = *(reinterpret_cast*>(x)); + auto w_vec = *(reinterpret_cast*>(w)); // Output is assumed to be registers. auto* out_vec = reinterpret_cast*>(out); @@ -87,7 +87,9 @@ __global__ void qmv_kernel( const T* biases, T* out, int n, - int k) { + int k, + bool broadcast_w) { + auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); @@ -98,8 +100,10 @@ __global__ void qmv_kernel( } // Advance pointers of x/out. - x += block.group_index().y * k; - out += block.group_index().y * n; + int m = grid.dim_blocks().y; + int l = block.group_index().z; + x += block.group_index().y * k + m * k * l; + out += block.group_index().y * n + m * n * l; // For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would // move past 2 elements for 4-bit Q. @@ -110,10 +114,11 @@ __global__ void qmv_kernel( int groups_per_row = k / group_size; // Advance w/scales/biases to current row. - w += static_cast(row) * k / w_step; - scales += static_cast(row) * groups_per_row; + int w_batch = broadcast_w ? 0 : l; + w += (static_cast(row) + n * w_batch) * k / w_step; + scales += (static_cast(row) + n * w_batch) * groups_per_row; if constexpr (has_bias) { - biases += static_cast(row) * groups_per_row; + biases += (static_cast(row) + n * w_batch) * groups_per_row; } // Accumulations of current row. @@ -168,14 +173,17 @@ void qmv( int m, int n, int k, + int l, + bool broadcast_w, F&& launch_kernel) { constexpr int rows_per_block = 8; constexpr int elems_per_thread = (cute::sizeof_bits_v <= 16 && cute::sizeof_bits_v <= 4) ? 16 : 8; - dim3 num_blocks{uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m)}; + dim3 num_blocks{ + uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)}; dim3 block_dims{WARP_SIZE, rows_per_block}; - void* args[] = {&x, &w, &scales, &biases, &out, &n, &k}; + void* args[] = {&x, &w, &scales, &biases, &out, &n, &k, &broadcast_w}; dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) { auto* kernel = &qmv_kernel< @@ -207,34 +215,9 @@ inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { } } -template -inline void -dispatch_quant_types(int bits, QuantizationMode mode, const char* tag, F&& f) { - if (mode == QuantizationMode::Mxfp4) { - f.template operator()(); - } else if (mode == QuantizationMode::Mxfp8) { - f.template operator()(); - } else if (mode == QuantizationMode::Nvfp4) { - f.template operator()(); - } else { - if (bits == 2) { - f.template operator()(); - } else if (bits == 4) { - f.template operator()(); - } else if (bits == 8) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} {}-bit quantization is not supported.", tag, bits)); - } - } -} - template inline void dispatch_groups(int group_size, const char* tag, F&& f) { - if (group_size == 16) { - f.template operator()<16>(); - } else if (group_size == 32) { + if (group_size == 32) { f.template operator()<32>(); } else if (group_size == 64) { f.template operator()<64>(); @@ -246,6 +229,35 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) { } } +template +inline void dispatch_quant_types( + int bits, + int group_size, + QuantizationMode mode, + const char* tag, + F&& f) { + if (mode == QuantizationMode::Mxfp4) { + f.template operator()(); + } else if (mode == QuantizationMode::Mxfp8) { + f.template operator()(); + } else if (mode == QuantizationMode::Nvfp4) { + f.template operator()(); + } else { + dispatch_groups(group_size, tag, [&]() { + if (bits == 2) { + f.template operator()(); + } else if (bits == 4) { + f.template operator()(); + } else if (bits == 8) { + f.template operator()(); + } else { + throw std::invalid_argument( + fmt::format("{} {}-bit quantization is not supported.", tag, bits)); + } + }); + } +} + void qmv( const array& x, const array& w, @@ -260,11 +272,12 @@ void qmv( int m = out.shape(-2); int n = out.shape(-1); int k = x.shape(-1); + int l = out.size() / (m * n); + bool broadcast_w = w.ndim() == 2; dispatch_element_types(out.dtype(), tag, [&]() { - dispatch_bool(biases.has_value(), [&](auto has_bias) { - dispatch_quant_types(bits, mode, tag, [&]() { - dispatch_groups(group_size, tag, [&]() { + dispatch_quant_types( + bits, group_size, mode, tag, [&]() { encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(scales); @@ -272,7 +285,8 @@ void qmv( encoder.set_input_array(*biases); } encoder.set_output_array(out); - cu::qmv( + constexpr bool has_bias = !cutlass::has_negative_zero_v; + cu::qmv( gpu_ptr(x), gpu_ptr(w), gpu_ptr(scales), @@ -281,13 +295,13 @@ void qmv( m, n, k, + l, + broadcast_w, [&](auto* kernel, dim3 num_blocks, dim3 block_dims, void** args) { encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, 0, args); }); }); - }); - }); }); } diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 4127954745..62a067bff8 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -88,7 +88,12 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( fmt::format( "[quantized_matmul] No implementation for " + "problem shape: {}x{}x{}x{} " "activation: {}, bits: {}, group size: {}, mode: \"{}\".", + M, + N, + K, + B, dtype_to_string(x.dtype()), bits_, group_size_,