Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions mlx/backend/cuda/quantized/qmm/qmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,8 @@ bool supports_qmv(
int group_size,
QuantizationMode mode,
cu::Device& device) {
int m = out.shape(-2);
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
if (l > 1) {
return false;
}
if (n % 8 != 0 || k % 8 != 0) {
if (k % 8 != 0) {
return false;
}
if (!x.flags().row_contiguous || !w.flags().row_contiguous ||
Expand Down
102 changes: 58 additions & 44 deletions mlx/backend/cuda/quantized/qmm/qmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ template <int N, typename T, typename Q>
__device__ __forceinline__ void
dequant_fma(const T* x, const Q* w, T scale, T bias, T* out) {
// Read x/w into registers.
auto x_vec = *(reinterpret_cast<const cutlass::AlignedArray<T, N>*>(x));
auto w_vec = *(reinterpret_cast<const cutlass::AlignedArray<Q, N>*>(w));
auto x_vec = *(reinterpret_cast<const cutlass::Array<T, N>*>(x));
auto w_vec = *(reinterpret_cast<const cutlass::Array<Q, N>*>(w));
// Output is assumed to be registers.
auto* out_vec = reinterpret_cast<cutlass::Array<T, N>*>(out);

Expand Down Expand Up @@ -52,8 +52,8 @@ template <
__device__ __forceinline__ void
dequant_fma(const T* x, const Q* w, T scale, T bias, float* out) {
// Read x/w into registers.
auto x_vec = *(reinterpret_cast<const cutlass::AlignedArray<T, N>*>(x));
auto w_vec = *(reinterpret_cast<const cutlass::AlignedArray<Q, N>*>(w));
auto x_vec = *(reinterpret_cast<const cutlass::Array<T, N>*>(x));
auto w_vec = *(reinterpret_cast<const cutlass::Array<Q, N>*>(w));
// Output is assumed to be registers.
auto* out_vec = reinterpret_cast<cutlass::Array<float, N>*>(out);

Expand Down Expand Up @@ -87,7 +87,9 @@ __global__ void qmv_kernel(
const T* biases,
T* out,
int n,
int k) {
int k,
bool broadcast_w) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

Expand All @@ -98,8 +100,10 @@ __global__ void qmv_kernel(
}

// Advance pointers of x/out.
x += block.group_index().y * k;
out += block.group_index().y * n;
int m = grid.dim_blocks().y;
int l = block.group_index().z;
x += block.group_index().y * k + m * k * l;
out += block.group_index().y * n + m * n * l;

// For sub-byte Q, pointer moves by 8bits for each advance, e.g. w += 1 would
// move past 2 elements for 4-bit Q.
Expand All @@ -110,10 +114,11 @@ __global__ void qmv_kernel(
int groups_per_row = k / group_size;

// Advance w/scales/biases to current row.
w += static_cast<int64_t>(row) * k / w_step;
scales += static_cast<int64_t>(row) * groups_per_row;
int w_batch = broadcast_w ? 0 : l;
w += (static_cast<int64_t>(row) + n * w_batch) * k / w_step;
scales += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
if constexpr (has_bias) {
biases += static_cast<int64_t>(row) * groups_per_row;
biases += (static_cast<int64_t>(row) + n * w_batch) * groups_per_row;
}

// Accumulations of current row.
Expand Down Expand Up @@ -168,14 +173,17 @@ void qmv(
int m,
int n,
int k,
int l,
bool broadcast_w,
F&& launch_kernel) {
constexpr int rows_per_block = 8;
constexpr int elems_per_thread =
(cute::sizeof_bits_v<T> <= 16 && cute::sizeof_bits_v<Q> <= 4) ? 16 : 8;

dim3 num_blocks{uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m)};
dim3 num_blocks{
uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)};
dim3 block_dims{WARP_SIZE, rows_per_block};
void* args[] = {&x, &w, &scales, &biases, &out, &n, &k};
void* args[] = {&x, &w, &scales, &biases, &out, &n, &k, &broadcast_w};

dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) {
auto* kernel = &qmv_kernel<
Expand Down Expand Up @@ -207,34 +215,9 @@ inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) {
}
}

// Invoke `f` templated on the storage type that matches the quantization
// mode (and, for affine modes, the bit width). `tag` is only used to build
// the error message for unsupported affine bit widths.
template <typename F>
inline void
dispatch_quant_types(int bits, QuantizationMode mode, const char* tag, F&& f) {
  // FP micro-scaling modes fix the storage type; `bits` is ignored for them.
  // Both mxfp4 and nvfp4 store elements as e2m1.
  if (mode == QuantizationMode::Mxfp4 || mode == QuantizationMode::Nvfp4) {
    f.template operator()<cutlass::float_e2m1_t>();
    return;
  }
  if (mode == QuantizationMode::Mxfp8) {
    f.template operator()<cutlass::float_e4m3_t>();
    return;
  }
  // Affine quantization: pick an unsigned integer storage type by bit width.
  switch (bits) {
    case 2:
      f.template operator()<cutlass::uint2b_t>();
      break;
    case 4:
      f.template operator()<cutlass::uint4b_t>();
      break;
    case 8:
      f.template operator()<uint8_t>();
      break;
    default:
      throw std::invalid_argument(
          fmt::format("{} {}-bit quantization is not supported.", tag, bits));
  }
}

template <typename F>
inline void dispatch_groups(int group_size, const char* tag, F&& f) {
if (group_size == 16) {
f.template operator()<16>();
} else if (group_size == 32) {
if (group_size == 32) {
f.template operator()<32>();
} else if (group_size == 64) {
f.template operator()<64>();
Expand All @@ -246,6 +229,35 @@ inline void dispatch_groups(int group_size, const char* tag, F&& f) {
}
}

// Invoke `f` templated on <storage type, compile-time group size> for the
// given quantization mode. FP micro-scaling modes have fixed group sizes
// defined by their formats (mxfp4/mxfp8: 32-element blocks per the OCP MX
// spec; nvfp4: 16-element blocks), so the runtime `group_size` and `bits`
// arguments are ignored for them. Affine modes dispatch on both the runtime
// `group_size` (via dispatch_groups) and `bits`.
//
// Throws std::invalid_argument (with `tag` in the message) for unsupported
// affine bit widths; dispatch_groups handles unsupported group sizes.
template <typename F>
inline void dispatch_quant_types(
    int bits,
    int group_size,
    QuantizationMode mode,
    const char* tag,
    F&& f) {
  if (mode == QuantizationMode::Mxfp4) {
    // OCP MX formats scale in 32-element blocks.
    f.template operator()<cutlass::float_e2m1_t, 32>();
  } else if (mode == QuantizationMode::Mxfp8) {
    f.template operator()<cutlass::float_e4m3_t, 32>();
  } else if (mode == QuantizationMode::Nvfp4) {
    // NVFP4 scales in 16-element blocks (e4m3 scale per 16 values).
    f.template operator()<cutlass::float_e2m1_t, 16>();
  } else {
    dispatch_groups(group_size, tag, [&]<int group_size>() {
      if (bits == 2) {
        f.template operator()<cutlass::uint2b_t, group_size>();
      } else if (bits == 4) {
        f.template operator()<cutlass::uint4b_t, group_size>();
      } else if (bits == 8) {
        f.template operator()<uint8_t, group_size>();
      } else {
        throw std::invalid_argument(
            fmt::format("{} {}-bit quantization is not supported.", tag, bits));
      }
    });
  }
}

void qmv(
const array& x,
const array& w,
Expand All @@ -260,19 +272,21 @@ void qmv(
int m = out.shape(-2);
int n = out.shape(-1);
int k = x.shape(-1);
int l = out.size() / (m * n);
bool broadcast_w = w.ndim() == 2;

dispatch_element_types(out.dtype(), tag, [&]<typename T>() {
dispatch_bool(biases.has_value(), [&](auto has_bias) {
dispatch_quant_types(bits, mode, tag, [&]<typename Q>() {
dispatch_groups(group_size, tag, [&]<int group_size>() {
dispatch_quant_types(
bits, group_size, mode, tag, [&]<typename Q, int group_size>() {
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
if (biases) {
encoder.set_input_array(*biases);
}
encoder.set_output_array(out);
cu::qmv<group_size, has_bias.value>(
constexpr bool has_bias = !cutlass::has_negative_zero_v<Q>;
cu::qmv<group_size, has_bias>(
gpu_ptr<T>(x),
gpu_ptr<Q>(w),
gpu_ptr<T>(scales),
Expand All @@ -281,13 +295,13 @@ void qmv(
m,
n,
k,
l,
broadcast_w,
[&](auto* kernel, dim3 num_blocks, dim3 block_dims, void** args) {
encoder.add_kernel_node_raw(
kernel, num_blocks, block_dims, {}, 0, args);
});
});
});
});
});
}

Expand Down
5 changes: 5 additions & 0 deletions mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
fmt::format(
"[quantized_matmul] No implementation for "
"problem shape: {}x{}x{}x{} "
"activation: {}, bits: {}, group size: {}, mode: \"{}\".",
M,
N,
K,
B,
dtype_to_string(x.dtype()),
bits_,
group_size_,
Expand Down
Loading