From aecf6275b95f03c6681307f7c6e6532e997ee965 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Sat, 21 Feb 2026 02:05:25 +0800 Subject: [PATCH 1/9] feat(cuda): implement masked scatter kernel --- mlx/backend/cuda/CMakeLists.txt | 2 +- .../cuda/{indexing.cpp => indexing.cu} | 131 ++++++++++++++++++ mlx/backend/cuda/primitives.cpp | 1 - mlx/backend/cuda/scan.cu | 75 +++++----- mlx/backend/cuda/scan.h | 20 +++ 5 files changed, 196 insertions(+), 33 deletions(-) rename mlx/backend/cuda/{indexing.cpp => indexing.cu} (75%) create mode 100644 mlx/backend/cuda/scan.h diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 1b95116e0e..35fbd33e56 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -31,7 +31,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cu ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cu similarity index 75% rename from mlx/backend/cuda/indexing.cpp rename to mlx/backend/cuda/indexing.cu index 424566d258..7091ed3ec6 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cu @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/scan.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -54,6 +55,53 @@ void append_indices_arg( } // namespace +namespace cu { + +template +__global__ void masked_assign( + const bool* mask, + const int32_t* scatter_offsets, + const T* src, + T* out, + IdxT total, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim, + IdxT src_batch_size, + IdxT mask_batch_size) { + IdxT block_id = static_cast(blockIdx.x) + + static_cast(gridDim.x) * + (static_cast(blockIdx.y) + + static_cast(gridDim.y) * static_cast(blockIdx.z)); + IdxT thread_id = block_id * blockDim.x + threadIdx.x; + IdxT stride = + static_cast(blockDim.x) * gridDim.x * gridDim.y * gridDim.z; + + for (IdxT idx = thread_id; idx < total; idx += stride) { + if (!mask[idx]) { + continue; + } + + IdxT src_index = static_cast(scatter_offsets[idx]); + if (src_index >= src_batch_size) { + // Match Metal backend behavior by skipping out-of-range source reads. + continue; + } + + IdxT batch_idx = idx / mask_batch_size; + if constexpr (src_contiguous) { + out[idx] = src[batch_idx * src_batch_size + src_index]; + } else { + IdxT src_elem = batch_idx * src_batch_size + src_index; + IdxT src_loc = + elem_to_loc(src_elem, src_shape.data(), src_strides.data(), src_ndim); + out[idx] = src[src_loc]; + } + } +} + +} // namespace cu + void Gather::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Gather::eval_gpu"); assert(inputs.size() > 0); @@ -435,4 +483,87 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { kernel, num_blocks, block_dims, {}, 0, args.args()); } +void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("MaskedScatter::eval_gpu"); + assert(inputs.size() == 3); + + const array& dst = inputs[0]; + const array& mask = inputs[1]; + const array& src = inputs[2]; + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + const size_t total = mask.size(); + const CopyType copy_type = (total == 1) + ? CopyType::Scalar + : (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General); + copy_gpu(dst, out, copy_type, s); + if (total == 0) { + return; + } + + array mask_flat = flatten_in_eval(mask, 1, -1, s); + if (mask_flat.data() != mask.data()) { + encoder.add_temporary(mask_flat); + } + if (!mask_flat.flags().row_contiguous) { + mask_flat = contiguous_copy_gpu(mask_flat, s); + encoder.add_temporary(mask_flat); + } + + array scatter_offsets(mask_flat.shape(), int32, nullptr, {}); + scatter_offsets.set_data(cu::malloc_async(scatter_offsets.nbytes(), encoder)); + encoder.add_temporary(scatter_offsets); + + scan_gpu_inplace( + mask_flat, + scatter_offsets, + Scan::Sum, + /* axis= */ 1, + /* reverse= */ false, + /* inclusive= */ false, + s); + + const size_t batch_count = mask.shape(0); + const size_t mask_batch_size = mask_flat.size() / batch_count; + const size_t src_batch_size = src.size() / src.shape(0); + + encoder.set_input_array(mask_flat); + encoder.set_input_array(scatter_offsets); + encoder.set_input_array(src); + encoder.set_output_array(out); + + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using T = cuda_type_t; + dispatch_bool(src.flags().row_contiguous, [&](auto src_contiguous) { + dispatch_bool( + total > INT32_MAX || src.size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto [num_blocks, block_dims] = get_launch_args( + mask_flat.size(), + mask_flat.shape(), + mask_flat.strides(), + large()); + auto kernel = cu::masked_assign; + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(mask_flat), + gpu_ptr(scatter_offsets), + gpu_ptr(src), + gpu_ptr(out), + static_cast(mask_flat.size()), + const_param(src.shape()), + const_param(src.strides()), + static_cast(src.ndim()), + static_cast(src_batch_size), + static_cast(mask_batch_size)); + }); + }); + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index 0caac8de6e..cb9b69fbb0 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -36,7 +36,6 @@ NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) -NO_GPU(MaskedScatter) namespace distributed { NO_GPU_MULTI(Send) diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index bd25084c1f..a07d056d89 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/device/binary_ops.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include "mlx/backend/cuda/scan.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -362,51 +363,38 @@ constexpr bool supports_scan_op() { } } -void Scan::eval_gpu(const std::vector& inputs, array& out) { - nvtx3::scoped_range r("Scan::eval_gpu"); - assert(inputs.size() == 1); - auto in = inputs[0]; - auto& s = stream(); +void scan_gpu_inplace( + array in, + array& out, + Scan::ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive, + const Stream& s) { auto& encoder = cu::get_command_encoder(s); - - if (in.flags().contiguous && in.strides()[axis_] != 0) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { - out.copy_shared_buffer(in); - } else { - out.set_data( - cu::malloc_async(in.data_size() * out.itemsize(), encoder), - in.data_size(), - in.strides(), - in.flags()); - } - } else { - in = contiguous_copy_gpu(in, s); - out.copy_shared_buffer(in); - } - constexpr int N_READS = 4; - int32_t axis_size = in.shape(axis_); - bool contiguous = in.strides()[axis_] == 1; + int32_t axis_size = in.shape(axis); + bool contiguous = in.strides()[axis] == 1; encoder.set_input_array(in); encoder.set_output_array(out); dispatch_all_types(in.dtype(), [&](auto type_tag) { using T = cuda_type_t; - dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) { + dispatch_scan_ops(reduce_type, [&](auto scan_op_tag) { using Op = MLX_GET_TYPE(scan_op_tag); if constexpr (supports_scan_op()) { using U = typename cu::ScanResult::type; - dispatch_bool(inclusive_, [&](auto inclusive) { - dispatch_bool(reverse_, [&](auto reverse) { + dispatch_bool(inclusive, [&](auto inclusive_tag) { + dispatch_bool(reverse, [&](auto reverse_tag) { if (contiguous) { auto kernel = cu::contiguous_scan< T, U, Op, N_READS, - inclusive.value, - reverse.value>; + inclusive_tag.value, + reverse_tag.value>; int block_dim = cuda::ceil_div(axis_size, N_READS); block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE; block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE); @@ -427,9 +415,9 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { N_READS, BM, BN, - inclusive.value, - reverse.value>; - int64_t stride = in.strides()[axis_]; + inclusive_tag.value, + reverse_tag.value>; + int64_t stride = in.strides()[axis]; int64_t stride_blocks = cuda::ceil_div(stride, BN); dim3 num_blocks = get_2d_grid_dims( in.shape(), in.strides(), axis_size * stride); @@ -463,4 +451,29 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { }); } +void Scan::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Scan::eval_gpu"); + assert(inputs.size() == 1); + auto in = inputs[0]; + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + if (in.flags().contiguous && in.strides()[axis_] != 0) { + if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.copy_shared_buffer(in); + } else { + out.set_data( + cu::malloc_async(in.data_size() * out.itemsize(), encoder), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + in = contiguous_copy_gpu(in, s); + out.copy_shared_buffer(in); + } + + scan_gpu_inplace(in, out, reduce_type_, axis_, reverse_, inclusive_, s); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/scan.h b/mlx/backend/cuda/scan.h new file mode 100644 index 0000000000..ea233edfb1 --- /dev/null +++ b/mlx/backend/cuda/scan.h @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/primitives.h" +#include "mlx/stream.h" + +namespace mlx::core { + +void scan_gpu_inplace( + array in, + array& out, + Scan::ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive, + const Stream& s); + +} // namespace mlx::core From 71cb5afc41fd6e0b016f360ea842b04b91ac70ab Mon Sep 17 00:00:00 2001 From: Lyxot Date: Sat, 21 Feb 2026 02:06:25 +0800 Subject: [PATCH 2/9] test(cuda): enable masked scatter test coverage --- python/tests/cuda_skip.py | 4 ---- tests/autograd_tests.cpp | 5 ----- tests/ops_tests.cpp | 5 ----- 3 files changed, 14 deletions(-) diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 20793d5c91..fd2924a411 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -55,8 +55,4 @@ "TestQuantized.test_throw", "TestQuantized.test_vjp_scales_biases", "TestExportImport.test_export_quantized_model", - # Masked scatter - "TestOps.test_masked_scatter", - "TestVmap.test_vmap_masked_scatter", - "TestArray.test_setitem_with_boolean_mask", } diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index ff8d986bd9..25c871cdf9 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1357,11 +1357,6 @@ TEST_CASE("test grad dynamic slices") { } TEST_CASE("test masked_scatter autograd") { - if (cu::is_available()) { - INFO("Skipping masked_scatter cuda autograd tests"); - return; - } - // Test jvp { auto self = array({10.f, 20.f, 30.f, 40.f}, {4}); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 62fd8c5923..23740b7004 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2437,11 +2437,6 @@ TEST_CASE("test scatter") { } TEST_CASE("test masked_scatter") { - if (cu::is_available()) { - INFO("Skipping masked_scatter cuda ops tests"); - return; - } - // Wrong mask dtype CHECK_THROWS(masked_scatter(array({1, 2}), array({1, 2}), array({1, 2}))); From 50c0da5a699a591e5707b32aefbae5749c0faf4a Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 25 Feb 2026 04:07:58 +0800 Subject: [PATCH 3/9] refactor(cuda): align masked scatter jit with scatter kernels --- mlx/backend/cuda/CMakeLists.txt | 2 +- mlx/backend/cuda/device/scatter.cuh | 38 ++++++ .../cuda/{indexing.cu => indexing.cpp} | 121 +++++++----------- 3 files changed, 83 insertions(+), 78 deletions(-) rename mlx/backend/cuda/{indexing.cu => indexing.cpp} (84%) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 35fbd33e56..1b95116e0e 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -31,7 +31,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cu + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh index b2f6403505..9a124d5426 100644 --- a/mlx/backend/cuda/device/scatter.cuh +++ b/mlx/backend/cuda/device/scatter.cuh @@ -65,4 +65,42 @@ __global__ void scatter( Op{}(out + out_idx, upd[upd_loc]); } +template +__global__ void masked_scatter_assign( + const bool* mask, + const int32_t* scatter_offsets, + const T* src, + T* out, + IdxT size, + IdxT src_batch_size, + IdxT mask_batch_size, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index >= size) { + return; + } + + if (!mask[index]) { + return; + } + + IdxT src_index = static_cast(scatter_offsets[index]); + if (src_index >= src_batch_size) { + // Match Metal backend behavior by skipping out-of-range source reads. + return; + } + + IdxT batch_idx = index / mask_batch_size; + if constexpr (SrcContiguous) { + out[index] = src[batch_idx * src_batch_size + src_index]; + } else { + IdxT src_elem = batch_idx * src_batch_size + src_index; + IdxT src_loc = + elem_to_loc(src_elem, src_shape.data(), src_strides.data(), src_ndim); + out[index] = src[src_loc]; + } +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cu b/mlx/backend/cuda/indexing.cpp similarity index 84% rename from mlx/backend/cuda/indexing.cu rename to mlx/backend/cuda/indexing.cpp index 7091ed3ec6..a3b5332270 100644 --- a/mlx/backend/cuda/indexing.cu +++ b/mlx/backend/cuda/indexing.cpp @@ -55,53 +55,6 @@ void append_indices_arg( } // namespace -namespace cu { - -template -__global__ void masked_assign( - const bool* mask, - const int32_t* scatter_offsets, - const T* src, - T* out, - IdxT total, - const __grid_constant__ Shape src_shape, - const __grid_constant__ Strides src_strides, - int32_t src_ndim, - IdxT src_batch_size, - IdxT mask_batch_size) { - IdxT block_id = static_cast(blockIdx.x) + - static_cast(gridDim.x) * - (static_cast(blockIdx.y) + - static_cast(gridDim.y) * static_cast(blockIdx.z)); - IdxT thread_id = block_id * blockDim.x + threadIdx.x; - IdxT stride = - static_cast(blockDim.x) * gridDim.x * gridDim.y * gridDim.z; - - for (IdxT idx = thread_id; idx < total; idx += stride) { - if (!mask[idx]) { - continue; - } - - IdxT src_index = static_cast(scatter_offsets[idx]); - if (src_index >= src_batch_size) { - // Match Metal backend behavior by skipping out-of-range source reads. - continue; - } - - IdxT batch_idx = idx / mask_batch_size; - if constexpr (src_contiguous) { - out[idx] = src[batch_idx * src_batch_size + src_index]; - } else { - IdxT src_elem = batch_idx * src_batch_size + src_index; - IdxT src_loc = - elem_to_loc(src_elem, src_shape.data(), src_strides.data(), src_ndim); - out[idx] = src[src_loc]; - } - } -} - -} // namespace cu - void Gather::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Gather::eval_gpu"); assert(inputs.size() > 0); @@ -528,42 +481,56 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { const size_t batch_count = mask.shape(0); const size_t mask_batch_size = mask_flat.size() / batch_count; const size_t src_batch_size = src.size() / src.shape(0); + bool large = total > INT32_MAX || src.size() > INT32_MAX; + + std::string module_name = + fmt::format("masked_scatter_assign_{}", dtype_to_string(out.dtype())); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) { + for (int use_large = 0; use_large <= 1; ++use_large) { + kernel_names.push_back( + fmt::format( + "mlx::core::cu::masked_scatter_assign<{}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + src_contiguous ? "true" : "false", + use_large ? "int64_t" : "int32_t")); + } + } + return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); + }); + + cu::KernelArgs args; + args.append(mask_flat); + args.append(scatter_offsets); + args.append(src); + args.append(out); + if (large) { + args.append(mask_flat.size()); + args.append(src_batch_size); + args.append(mask_batch_size); + } else { + args.append(mask_flat.size()); + args.append(src_batch_size); + args.append(mask_batch_size); + } + args.append_ndim(src.shape()); + args.append_ndim(src.strides()); + args.append(src.ndim()); encoder.set_input_array(mask_flat); encoder.set_input_array(scatter_offsets); encoder.set_input_array(src); encoder.set_output_array(out); - dispatch_all_types(out.dtype(), [&](auto type_tag) { - using T = cuda_type_t; - dispatch_bool(src.flags().row_contiguous, [&](auto src_contiguous) { - dispatch_bool( - total > INT32_MAX || src.size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto [num_blocks, block_dims] = get_launch_args( - mask_flat.size(), - mask_flat.shape(), - mask_flat.strides(), - large()); - auto kernel = cu::masked_assign; - encoder.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(mask_flat), - gpu_ptr(scatter_offsets), - gpu_ptr(src), - gpu_ptr(out), - static_cast(mask_flat.size()), - const_param(src.shape()), - const_param(src.strides()), - static_cast(src.ndim()), - static_cast(src_batch_size), - static_cast(mask_batch_size)); - }); - }); - }); + std::string kernel_name = fmt::format( + "mlx::core::cu::masked_scatter_assign<{}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + src.flags().row_contiguous ? "true" : "false", + large ? "int64_t" : "int32_t"); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(mask_flat, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } } // namespace mlx::core From f5693f7e891359e6845ed4d13c63a90c5a1eb394 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 27 Feb 2026 18:23:21 +0800 Subject: [PATCH 4/9] refactor(cuda): use add_kernel_node_raw for masked scatter launch --- mlx/backend/cuda/indexing.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index a3b5332270..2ccd420ae1 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -530,7 +530,8 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { large ? "int64_t" : "int32_t"); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(mask_flat, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); + encoder.add_kernel_node_raw( + kernel, num_blocks, block_dims, {}, 0, args.args()); } } // namespace mlx::core From 4c318ef71c25b5d04911e83a4bf24fce35b18c39 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Sat, 28 Feb 2026 15:52:33 +0800 Subject: [PATCH 5/9] test: update bench script --- benchmarks/python/masked_scatter.py | 34 ++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/benchmarks/python/masked_scatter.py b/benchmarks/python/masked_scatter.py index 71857c5436..e1c84ee6bb 100644 --- a/benchmarks/python/masked_scatter.py +++ b/benchmarks/python/masked_scatter.py @@ -1,5 +1,6 @@ import math import os +import platform import subprocess import time from copy import copy @@ -17,9 +18,6 @@ if not os.path.isdir(RESULTS_DIR): os.mkdir(RESULTS_DIR) -DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) -DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n") - TORCH_DEVICE = torch.device( "mps" if torch.backends.mps.is_available() @@ -27,11 +25,36 @@ ) +def get_device_name(): + if TORCH_DEVICE.type == "cuda": + try: + out = subprocess.check_output( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + stderr=subprocess.DEVNULL, + ) + return out.decode("utf-8").splitlines()[0].strip() + except Exception: + return "CUDA_GPU" + if TORCH_DEVICE.type == "mps": + try: + out = subprocess.check_output( + ["sysctl", "-n", "machdep.cpu.brand_string"], + stderr=subprocess.DEVNULL, + ) + return out.decode("utf-8").strip() + except Exception: + return "Apple_Silicon" + return platform.processor() or platform.machine() or "CPU" + + +DEVICE_NAME = get_device_name() + + N_WARMUP = 5 N_ITER_BENCH = 50 N_ITER_FUNC = 20 -VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)] +VECTOR_LENGTHS = [4096 * (2**i) for i in range(12)] MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5] D_TYPES = ("float32", "float16") @@ -202,9 +225,10 @@ def main(): ) output_path = os.path.join( RESULTS_DIR, - f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf", + f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.png", ) fig.savefig(output_path) + print(f"Saved benchmark image: {output_path}") plt.close(fig) From 608cfdde01683a33ef305baf687eb404c2bb7f4d Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 4 Mar 2026 20:11:13 +0800 Subject: [PATCH 6/9] perf: replace mask prefix scan with CUB segmented exclusive scan --- mlx/backend/cuda/indexing.cpp | 12 +-- mlx/backend/cuda/scan.cu | 163 +++++++++++++++++++++++++--------- mlx/backend/cuda/scan.h | 9 +- 3 files changed, 128 insertions(+), 56 deletions(-) diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 2ccd420ae1..afa99bd72b 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -469,20 +469,14 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { scatter_offsets.set_data(cu::malloc_async(scatter_offsets.nbytes(), encoder)); encoder.add_temporary(scatter_offsets); - scan_gpu_inplace( - mask_flat, - scatter_offsets, - Scan::Sum, - /* axis= */ 1, - /* reverse= */ false, - /* inclusive= */ false, - s); - const size_t batch_count = mask.shape(0); const size_t mask_batch_size = mask_flat.size() / batch_count; const size_t src_batch_size = src.size() / src.shape(0); bool large = total > INT32_MAX || src.size() > INT32_MAX; + segmented_exclusive_mask_scan_gpu( + mask_flat, scatter_offsets, static_cast(mask_batch_size), s); + std::string module_name = fmt::format("masked_scatter_assign_{}", dtype_to_string(out.dtype())); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index a07d056d89..1bfea55d54 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -5,6 +5,7 @@ #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include "mlx/backend/cuda/scan.h" +#include "mlx/backend/cuda/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -12,6 +13,10 @@ #include #include #include +#include +#include +#include +#include #include @@ -363,38 +368,139 @@ constexpr bool supports_scan_op() { } } -void scan_gpu_inplace( - array in, +namespace { + +struct BoolToInt32 { + __host__ __device__ int32_t operator()(bool v) const { + return static_cast(v); + } +}; + +template +struct MaskSegmentKey { + IdxT segment_size; + + __host__ __device__ IdxT operator()(IdxT i) const { + return i / segment_size; + } +}; + +} // namespace + +void segmented_exclusive_mask_scan_gpu( + const array& in, array& out, - Scan::ReduceType reduce_type, - int axis, - bool reverse, - bool inclusive, + int64_t segment_size, const Stream& s) { + if (segment_size <= 0) { + throw std::runtime_error("segment_size must be positive."); + } + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + using CubIdx = int64_t; + auto count_iter = thrust::counting_iterator(0); + auto key_iter = thrust::make_transform_iterator( + count_iter, MaskSegmentKey{static_cast(segment_size)}); + auto value_iter = + thrust::make_transform_iterator(gpu_ptr(in), BoolToInt32{}); + + size_t workspace_size = 0; + if (segment_size == static_cast(in.size())) { + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSum( + nullptr, + workspace_size, + value_iter, + gpu_ptr(out), + static_cast(in.size()), + encoder.stream())); + + void* workspace = allocate_workspace(encoder, workspace_size); + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSum( + workspace, + workspace_size, + value_iter, + gpu_ptr(out), + static_cast(in.size()), + encoder.stream())); + } else { + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSumByKey( + nullptr, + workspace_size, + key_iter, + value_iter, + gpu_ptr(out), + static_cast(in.size()), + cuda::std::equal_to<>{}, + encoder.stream())); + + void* workspace = allocate_workspace(encoder, workspace_size); + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSumByKey( + workspace, + workspace_size, + key_iter, + value_iter, + gpu_ptr(out), + static_cast(in.size()), + cuda::std::equal_to<>{}, + encoder.stream())); + } + return; +} + +void Scan::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Scan::eval_gpu"); + assert(inputs.size() == 1); + auto in = inputs[0]; + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + if (in.flags().contiguous && in.strides()[axis_] != 0) { + if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.copy_shared_buffer(in); + } else { + out.set_data( + cu::malloc_async(in.data_size() * out.itemsize(), encoder), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + in = contiguous_copy_gpu(in, s); + out.copy_shared_buffer(in); + } + constexpr int N_READS = 4; - int32_t axis_size = in.shape(axis); - bool contiguous = in.strides()[axis] == 1; + int32_t axis_size = in.shape(axis_); + bool contiguous = in.strides()[axis_] == 1; encoder.set_input_array(in); encoder.set_output_array(out); dispatch_all_types(in.dtype(), [&](auto type_tag) { using T = cuda_type_t; - dispatch_scan_ops(reduce_type, [&](auto scan_op_tag) { + dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) { using Op = MLX_GET_TYPE(scan_op_tag); if constexpr (supports_scan_op()) { using U = typename cu::ScanResult::type; - dispatch_bool(inclusive, [&](auto inclusive_tag) { - dispatch_bool(reverse, [&](auto reverse_tag) { + dispatch_bool(inclusive_, [&](auto inclusive) { + dispatch_bool(reverse_, [&](auto reverse) { if (contiguous) { auto kernel = cu::contiguous_scan< T, U, Op, N_READS, - inclusive_tag.value, - reverse_tag.value>; + inclusive.value, + reverse.value>; int block_dim = cuda::ceil_div(axis_size, N_READS); block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE; block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE); @@ -415,9 +521,9 @@ void scan_gpu_inplace( N_READS, BM, BN, - inclusive_tag.value, - reverse_tag.value>; - int64_t stride = in.strides()[axis]; + inclusive.value, + reverse.value>; + int64_t stride = in.strides()[axis_]; int64_t stride_blocks = cuda::ceil_div(stride, BN); dim3 num_blocks = get_2d_grid_dims( in.shape(), in.strides(), axis_size * stride); @@ -451,29 +557,4 @@ void scan_gpu_inplace( }); } -void Scan::eval_gpu(const std::vector& inputs, array& out) { - nvtx3::scoped_range r("Scan::eval_gpu"); - assert(inputs.size() == 1); - auto in = inputs[0]; - auto& s = stream(); - auto& encoder = cu::get_command_encoder(s); - - if (in.flags().contiguous && in.strides()[axis_] != 0) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { - out.copy_shared_buffer(in); - } else { - out.set_data( - cu::malloc_async(in.data_size() * out.itemsize(), encoder), - in.data_size(), - in.strides(), - in.flags()); - } - } else { - in = contiguous_copy_gpu(in, s); - out.copy_shared_buffer(in); - } - - scan_gpu_inplace(in, out, reduce_type_, axis_, reverse_, inclusive_, s); -} - } // namespace mlx::core diff --git a/mlx/backend/cuda/scan.h b/mlx/backend/cuda/scan.h index ea233edfb1..9700ea645a 100644 --- a/mlx/backend/cuda/scan.h +++ b/mlx/backend/cuda/scan.h @@ -8,13 +8,10 @@ namespace mlx::core { -void scan_gpu_inplace( - array in, +void segmented_exclusive_mask_scan_gpu( + const array& in, array& out, - Scan::ReduceType reduce_type, - int axis, - bool reverse, - bool inclusive, + int64_t segment_size, const Stream& s); } // namespace mlx::core From e70c1fc9fa124a09e2139b836838a1011e90120e Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 4 Mar 2026 21:14:29 +0800 Subject: [PATCH 7/9] perf: fuse masked scatter copy and assign --- mlx/backend/cuda/device/scatter.cuh | 45 ++++++++++++++++++----------- mlx/backend/cuda/indexing.cpp | 32 +++++++++++--------- 2 files changed, 47 insertions(+), 30 deletions(-) diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh index 9a124d5426..5225650c8d 100644 --- a/mlx/backend/cuda/device/scatter.cuh +++ b/mlx/backend/cuda/device/scatter.cuh @@ -65,8 +65,9 @@ __global__ void scatter( Op{}(out + out_idx, upd[upd_loc]); } -template -__global__ void masked_scatter_assign( +template +__global__ void masked_scatter_fused( + const T* dst, const bool* mask, const int32_t* scatter_offsets, const T* src, @@ -74,6 +75,9 @@ __global__ void masked_scatter_assign( IdxT size, IdxT src_batch_size, IdxT mask_batch_size, + const __grid_constant__ Shape dst_shape, + const __grid_constant__ Strides dst_strides, + int32_t dst_ndim, const __grid_constant__ Shape src_shape, const __grid_constant__ Strides src_strides, int32_t src_ndim) { @@ -82,25 +86,32 @@ __global__ void masked_scatter_assign( return; } - if (!mask[index]) { - return; + T dst_val; + if constexpr (DstContiguous) { + dst_val = dst[index]; + } else { + IdxT dst_loc = + elem_to_loc(index, dst_shape.data(), dst_strides.data(), dst_ndim); + dst_val = dst[dst_loc]; } - IdxT src_index = static_cast(scatter_offsets[index]); - if (src_index >= src_batch_size) { - // Match Metal backend behavior by skipping out-of-range source reads. - return; + if (mask[index]) { + IdxT src_index = static_cast(scatter_offsets[index]); + if (src_index < src_batch_size) { + IdxT batch_idx = index / mask_batch_size; + if constexpr (SrcContiguous) { + out[index] = src[batch_idx * src_batch_size + src_index]; + } else { + IdxT src_elem = batch_idx * src_batch_size + src_index; + IdxT src_loc = elem_to_loc( + src_elem, src_shape.data(), src_strides.data(), src_ndim); + out[index] = src[src_loc]; + } + return; + } } - IdxT batch_idx = index / mask_batch_size; - if constexpr (SrcContiguous) { - out[index] = src[batch_idx * src_batch_size + src_index]; - } else { - IdxT src_elem = batch_idx * src_batch_size + src_index; - IdxT src_loc = - elem_to_loc(src_elem, src_shape.data(), src_strides.data(), src_ndim); - out[index] = src[src_loc]; - } + out[index] = dst_val; } } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index afa99bd72b..f93277c7be 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -448,10 +448,7 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); const size_t total = mask.size(); - const CopyType copy_type = (total == 1) - ? CopyType::Scalar - : (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General); - copy_gpu(dst, out, copy_type, s); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); if (total == 0) { return; } @@ -478,23 +475,27 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { mask_flat, scatter_offsets, static_cast(mask_batch_size), s); std::string module_name = - fmt::format("masked_scatter_assign_{}", dtype_to_string(out.dtype())); + fmt::format("masked_scatter_fused_{}", dtype_to_string(out.dtype())); cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { std::vector kernel_names; for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) { - for (int use_large = 0; use_large <= 1; ++use_large) { - kernel_names.push_back( - fmt::format( - "mlx::core::cu::masked_scatter_assign<{}, {}, {}>", - dtype_to_cuda_type(out.dtype()), - src_contiguous ? "true" : "false", - use_large ? "int64_t" : "int32_t")); + for (int dst_contiguous = 0; dst_contiguous <= 1; ++dst_contiguous) { + for (int use_large = 0; use_large <= 1; ++use_large) { + kernel_names.push_back( + fmt::format( + "mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + src_contiguous ? "true" : "false", + dst_contiguous ? "true" : "false", + use_large ? "int64_t" : "int32_t")); + } } } return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); }); cu::KernelArgs args; + args.append(dst); args.append(mask_flat); args.append(scatter_offsets); args.append(src); @@ -508,19 +509,24 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { args.append(src_batch_size); args.append(mask_batch_size); } + args.append_ndim(dst.shape()); + args.append_ndim(dst.strides()); + args.append(dst.ndim()); args.append_ndim(src.shape()); args.append_ndim(src.strides()); args.append(src.ndim()); + encoder.set_input_array(dst); encoder.set_input_array(mask_flat); encoder.set_input_array(scatter_offsets); encoder.set_input_array(src); encoder.set_output_array(out); std::string kernel_name = fmt::format( - "mlx::core::cu::masked_scatter_assign<{}, {}, {}>", + "mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), src.flags().row_contiguous ? "true" : "false", + dst.flags().row_contiguous ? "true" : "false", large ? "int64_t" : "int32_t"); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(mask_flat, large); From f1c2a0b8332c464c264b83ac4290a806d416c797 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 5 Mar 2026 20:35:06 +0800 Subject: [PATCH 8/9] perf: add contiguous vectorized masked scatter kernel --- mlx/backend/cuda/device/scatter.cuh | 44 +++++++++++++++++++++++++++++ mlx/backend/cuda/indexing.cpp | 31 +++++++++++++++----- 2 files changed, 68 insertions(+), 7 deletions(-) diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh index 5225650c8d..84e64e1e5e 100644 --- a/mlx/backend/cuda/device/scatter.cuh +++ b/mlx/backend/cuda/device/scatter.cuh @@ -114,4 +114,48 @@ __global__ void masked_scatter_fused( out[index] = dst_val; } +template +__global__ void masked_scatter_fused_vec_contiguous( + const T* dst, + const bool* mask, + const int32_t* scatter_offsets, + const T* src, + T* out, + IdxT size, + IdxT src_batch_size, + IdxT mask_batch_size, + const __grid_constant__ Shape, + const __grid_constant__ Strides, + int32_t, + const __grid_constant__ Shape, + const __grid_constant__ Strides, + int32_t) { + IdxT vec_index = cg::this_grid().thread_rank(); + IdxT base = vec_index * N_READS; + if (base >= size) { + return; + } + + auto out_vec = load_vector(dst, vec_index, size, static_cast(0)); + auto mask_vec = load_vector(mask, vec_index, size, false); + auto offset_vec = load_vector(scatter_offsets, vec_index, size, 0); + +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + IdxT index = base + i; + if (index >= size) { + break; + } + if (mask_vec[i]) { + IdxT src_index = static_cast(offset_vec[i]); + if (src_index < src_batch_size) { + IdxT batch_idx = index / mask_batch_size; + out_vec[i] = src[batch_idx * src_batch_size + src_index]; + } + } + } + + store_vector(out, vec_index, out_vec, size); +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index f93277c7be..c442b34c07 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -491,6 +491,14 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { } } } + for (int use_large = 0; use_large <= 1; ++use_large) { + kernel_names.push_back( + fmt::format( + "mlx::core::cu::masked_scatter_fused_vec_contiguous<{}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + use_large ? "int64_t" : "int32_t", + 16)); + } return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); }); @@ -522,14 +530,23 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(src); encoder.set_output_array(out); - std::string kernel_name = fmt::format( - "mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>", - dtype_to_cuda_type(out.dtype()), - src.flags().row_contiguous ? "true" : "false", - dst.flags().row_contiguous ? "true" : "false", - large ? "int64_t" : "int32_t"); + bool vectorized = src.flags().row_contiguous && dst.flags().row_contiguous; + std::string kernel_name = vectorized + ? fmt::format( + "mlx::core::cu::masked_scatter_fused_vec_contiguous<{}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + large ? "int64_t" : "int32_t", + 16) + : fmt::format( + "mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + src.flags().row_contiguous ? "true" : "false", + dst.flags().row_contiguous ? "true" : "false", + large ? "int64_t" : "int32_t"); auto kernel = mod.get_kernel(kernel_name); - auto [num_blocks, block_dims] = get_launch_args(mask_flat, large); + auto [num_blocks, block_dims] = vectorized + ? get_launch_args(mask_flat, large, 16, 256) + : get_launch_args(mask_flat, large); encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, 0, args.args()); } From fead885eb000f8a2006a7f4c4ea64be41260b496 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 5 Mar 2026 23:30:27 +0800 Subject: [PATCH 9/9] perf: speed up contiguous masked scatter with tile prefix sums Add a contiguous masked-scatter fast path that replaces per-element offset reads with a tile-based two-pass flow: - kernel 1 counts selected mask elements per tile\n- CUB segmented exclusive scan builds tile offsets - kernel 2 performs fused copy+scatter and computes intra-tile ranks in-register This removes the large scatter-offset buffer from the contiguous route and reduces global-memory traffic on large vectors. Also rename the mask scan helper for clarity and add an int32 segmented exclusive prefix-sum helper used by the new path. --- mlx/backend/cuda/device/scatter.cuh | 142 +++++++++++++++++----- mlx/backend/cuda/indexing.cpp | 175 +++++++++++++++++++++------- mlx/backend/cuda/scan.cu | 73 +++++++++++- mlx/backend/cuda/scan.h | 8 +- 4 files changed, 324 insertions(+), 74 deletions(-) diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh index 84e64e1e5e..a807740795 100644 --- a/mlx/backend/cuda/device/scatter.cuh +++ b/mlx/backend/cuda/device/scatter.cuh @@ -114,48 +114,138 @@ __global__ void masked_scatter_fused( out[index] = dst_val; } -template +template +__global__ void masked_scatter_tile_count( + const bool* mask, + int32_t* tile_counts, + IdxT mask_batch_size, + int32_t num_tiles_per_batch) { + IdxT tile = cg::this_grid().block_rank(); + IdxT batch_idx = tile / num_tiles_per_batch; + IdxT tile_in_batch = tile - batch_idx * num_tiles_per_batch; + IdxT tile_items = static_cast(blockDim.x) * ITEMS_PER_THREAD; + IdxT tile_start = batch_idx * mask_batch_size + tile_in_batch * tile_items; + IdxT batch_end = (batch_idx + 1) * mask_batch_size; + IdxT tile_end = tile_start + tile_items; + if (tile_end > batch_end) { + tile_end = batch_end; + } + + int32_t local_count = 0; + IdxT index = tile_start + threadIdx.x; +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + if (index < tile_end) { + local_count += static_cast(mask[index]); + } + index += blockDim.x; + } + + int lane = threadIdx.x & (WARP_SIZE - 1); + int warp = threadIdx.x / WARP_SIZE; + int nwarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; + + unsigned int active = __activemask(); + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + local_count += __shfl_down_sync(active, local_count, offset); + } + + __shared__ int32_t warp_sums[WARP_SIZE]; + if (lane == 0) { + warp_sums[warp] = local_count; + } + __syncthreads(); + + if (warp == 0) { + int32_t block_sum = (lane < nwarps) ? warp_sums[lane] : 0; + unsigned int warp0_active = __activemask(); + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + block_sum += __shfl_down_sync(warp0_active, block_sum, offset); + } + if (lane == 0) { + tile_counts[tile] = block_sum; + } + } +} + +template __global__ void masked_scatter_fused_vec_contiguous( const T* dst, const bool* mask, - const int32_t* scatter_offsets, + const int32_t* tile_offsets, const T* src, T* out, - IdxT size, IdxT src_batch_size, IdxT mask_batch_size, - const __grid_constant__ Shape, - const __grid_constant__ Strides, - int32_t, - const __grid_constant__ Shape, - const __grid_constant__ Strides, - int32_t) { - IdxT vec_index = cg::this_grid().thread_rank(); - IdxT base = vec_index * N_READS; - if (base >= size) { - return; + int32_t num_tiles_per_batch) { + IdxT tile = cg::this_grid().block_rank(); + IdxT batch_idx = tile / num_tiles_per_batch; + IdxT tile_in_batch = tile - batch_idx * num_tiles_per_batch; + IdxT tile_items = static_cast(blockDim.x) * ITEMS_PER_THREAD; + IdxT tile_start = batch_idx * mask_batch_size + tile_in_batch * tile_items; + IdxT batch_end = (batch_idx + 1) * mask_batch_size; + IdxT tile_end = tile_start + tile_items; + if (tile_end > batch_end) { + tile_end = batch_end; } - auto out_vec = load_vector(dst, vec_index, size, static_cast(0)); - auto mask_vec = load_vector(mask, vec_index, size, false); - auto offset_vec = load_vector(scatter_offsets, vec_index, size, 0); + IdxT src_base = batch_idx * src_batch_size; + IdxT tile_prefix = static_cast(tile_offsets[tile]); + IdxT iter_prefix = 0; + + int lane = threadIdx.x & (WARP_SIZE - 1); + int warp = threadIdx.x / WARP_SIZE; + int nwarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; + + __shared__ int32_t warp_counts[WARP_SIZE]; + __shared__ int32_t warp_offsets[WARP_SIZE]; + __shared__ int32_t iter_count; + + IdxT index = tile_start + threadIdx.x; #pragma unroll - for (int i = 0; i < N_READS; ++i) { - IdxT index = base + i; - if (index >= size) { - break; + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + bool active = index < tile_end; + bool mask_value = active ? mask[index] : false; + T out_value = active ? dst[index] : static_cast(0); + + unsigned int active_mask = __activemask(); + unsigned int ballots = __ballot_sync(active_mask, mask_value); + unsigned int lane_mask = (lane == 0) ? 0u : ((1u << lane) - 1u); + int32_t warp_exclusive = __popc(ballots & lane_mask); + int32_t warp_count = __popc(ballots); + + if (lane == 0) { + warp_counts[warp] = warp_count; } - if (mask_vec[i]) { - IdxT src_index = static_cast(offset_vec[i]); + __syncthreads(); + + if (threadIdx.x == 0) { + int32_t offset = 0; + for (int w = 0; w < nwarps; ++w) { + warp_offsets[w] = offset; + offset += warp_counts[w]; + } + iter_count = offset; + } + __syncthreads(); + + if (active && mask_value) { + IdxT src_index = tile_prefix + iter_prefix + + static_cast(warp_offsets[warp] + warp_exclusive); if (src_index < src_batch_size) { - IdxT batch_idx = index / mask_batch_size; - out_vec[i] = src[batch_idx * src_batch_size + src_index]; + out_value = src[src_base + src_index]; } } - } - store_vector(out, vec_index, out_vec, size); + if (active) { + out[index] = out_value; + } + + iter_prefix += static_cast(iter_count); + index += blockDim.x; + __syncthreads(); + } } } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index c442b34c07..2d9dcdddc9 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -462,17 +462,13 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { encoder.add_temporary(mask_flat); } - array scatter_offsets(mask_flat.shape(), int32, nullptr, {}); - scatter_offsets.set_data(cu::malloc_async(scatter_offsets.nbytes(), encoder)); - encoder.add_temporary(scatter_offsets); - const size_t batch_count = mask.shape(0); const size_t mask_batch_size = mask_flat.size() / batch_count; const size_t src_batch_size = src.size() / src.shape(0); bool large = total > INT32_MAX || src.size() > INT32_MAX; - - segmented_exclusive_mask_scan_gpu( - mask_flat, scatter_offsets, static_cast(mask_batch_size), s); + constexpr int kTileItemsPerThread = 4; + constexpr uint32_t kTileBlockDim = 256; + constexpr size_t kTileItems = kTileItemsPerThread * kTileBlockDim; std::string module_name = fmt::format("masked_scatter_fused_{}", dtype_to_string(out.dtype())); @@ -492,63 +488,152 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { } } for (int use_large = 0; use_large <= 1; ++use_large) { + kernel_names.push_back( + fmt::format( + "mlx::core::cu::masked_scatter_tile_count<{}, {}>", + use_large ? "int64_t" : "int32_t", + kTileItemsPerThread)); kernel_names.push_back( fmt::format( "mlx::core::cu::masked_scatter_fused_vec_contiguous<{}, {}, {}>", dtype_to_cuda_type(out.dtype()), use_large ? "int64_t" : "int32_t", - 16)); + kTileItemsPerThread)); } return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); }); - cu::KernelArgs args; - args.append(dst); - args.append(mask_flat); - args.append(scatter_offsets); - args.append(src); - args.append(out); - if (large) { - args.append(mask_flat.size()); - args.append(src_batch_size); - args.append(mask_batch_size); - } else { - args.append(mask_flat.size()); - args.append(src_batch_size); - args.append(mask_batch_size); + bool contiguous_fast_path = + src.flags().row_contiguous && dst.flags().row_contiguous; + size_t num_tiles_per_batch = 0; + size_t total_tiles = 0; + dim3 tile_num_blocks; + if (contiguous_fast_path) { + num_tiles_per_batch = cuda::ceil_div(mask_batch_size, kTileItems); + total_tiles = batch_count * num_tiles_per_batch; + if (total_tiles > UINT32_MAX) { + contiguous_fast_path = false; + } else { + tile_num_blocks = dim3(static_cast(total_tiles), 1, 1); + } } - args.append_ndim(dst.shape()); - args.append_ndim(dst.strides()); - args.append(dst.ndim()); - args.append_ndim(src.shape()); - args.append_ndim(src.strides()); - args.append(src.ndim()); - encoder.set_input_array(dst); - encoder.set_input_array(mask_flat); - encoder.set_input_array(scatter_offsets); - encoder.set_input_array(src); - encoder.set_output_array(out); + if (contiguous_fast_path) { + array tile_counts( + {static_cast(batch_count), static_cast(num_tiles_per_batch)}, + int32, + nullptr, + {}); + tile_counts.set_data(cu::malloc_async(tile_counts.nbytes(), encoder)); + encoder.add_temporary(tile_counts); + + array tile_offsets(tile_counts.shape(), int32, nullptr, {}); + tile_offsets.set_data(cu::malloc_async(tile_offsets.nbytes(), encoder)); + encoder.add_temporary(tile_offsets); + + cu::KernelArgs count_args; + count_args.append(mask_flat); + count_args.append(tile_counts); + if (large) { + count_args.append(mask_batch_size); + count_args.append(num_tiles_per_batch); + } else { + count_args.append(mask_batch_size); + count_args.append(num_tiles_per_batch); + } + + encoder.set_input_array(mask_flat); + encoder.set_output_array(tile_counts); + auto count_kernel = mod.get_kernel( + fmt::format( + "mlx::core::cu::masked_scatter_tile_count<{}, {}>", + large ? "int64_t" : "int32_t", + kTileItemsPerThread)); + encoder.add_kernel_node_raw( + count_kernel, tile_num_blocks, kTileBlockDim, {}, 0, count_args.args()); + + segmented_exclusive_int32_prefix_sum_gpu( + tile_counts, + tile_offsets, + static_cast(num_tiles_per_batch), + s); + + cu::KernelArgs fused_args; + fused_args.append(dst); + fused_args.append(mask_flat); + fused_args.append(tile_offsets); + fused_args.append(src); + fused_args.append(out); + if (large) { + fused_args.append(src_batch_size); + fused_args.append(mask_batch_size); + fused_args.append(num_tiles_per_batch); + } else { + fused_args.append(src_batch_size); + fused_args.append(mask_batch_size); + fused_args.append(num_tiles_per_batch); + } - bool vectorized = src.flags().row_contiguous && dst.flags().row_contiguous; - std::string kernel_name = vectorized - ? fmt::format( + encoder.set_input_array(dst); + encoder.set_input_array(mask_flat); + encoder.set_input_array(tile_offsets); + encoder.set_input_array(src); + encoder.set_output_array(out); + auto fused_kernel = mod.get_kernel( + fmt::format( "mlx::core::cu::masked_scatter_fused_vec_contiguous<{}, {}, {}>", dtype_to_cuda_type(out.dtype()), large ? "int64_t" : "int32_t", - 16) - : fmt::format( + kTileItemsPerThread)); + encoder.add_kernel_node_raw( + fused_kernel, tile_num_blocks, kTileBlockDim, {}, 0, fused_args.args()); + } else { + array scatter_offsets(mask_flat.shape(), int32, nullptr, {}); + scatter_offsets.set_data( + cu::malloc_async(scatter_offsets.nbytes(), encoder)); + encoder.add_temporary(scatter_offsets); + segmented_exclusive_mask_prefix_sum_gpu( + mask_flat, scatter_offsets, static_cast(mask_batch_size), s); + + cu::KernelArgs args; + args.append(dst); + args.append(mask_flat); + args.append(scatter_offsets); + args.append(src); + args.append(out); + if (large) { + args.append(mask_flat.size()); + args.append(src_batch_size); + args.append(mask_batch_size); + } else { + args.append(mask_flat.size()); + args.append(src_batch_size); + args.append(mask_batch_size); + } + args.append_ndim(dst.shape()); + args.append_ndim(dst.strides()); + args.append(dst.ndim()); + args.append_ndim(src.shape()); + args.append_ndim(src.strides()); + args.append(src.ndim()); + + encoder.set_input_array(dst); + encoder.set_input_array(mask_flat); + encoder.set_input_array(scatter_offsets); + encoder.set_input_array(src); + encoder.set_output_array(out); + + auto kernel = mod.get_kernel( + fmt::format( "mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>", dtype_to_cuda_type(out.dtype()), src.flags().row_contiguous ? "true" : "false", dst.flags().row_contiguous ? "true" : "false", - large ? "int64_t" : "int32_t"); - auto kernel = mod.get_kernel(kernel_name); - auto [num_blocks, block_dims] = vectorized - ? get_launch_args(mask_flat, large, 16, 256) - : get_launch_args(mask_flat, large); - encoder.add_kernel_node_raw( - kernel, num_blocks, block_dims, {}, 0, args.args()); + large ? "int64_t" : "int32_t")); + auto [num_blocks, block_dims] = get_launch_args(mask_flat, large); + encoder.add_kernel_node_raw( + kernel, num_blocks, block_dims, {}, 0, args.args()); + } } } // namespace mlx::core diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index 1bfea55d54..3c2af5d0bf 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -387,7 +387,7 @@ struct MaskSegmentKey { } // namespace -void segmented_exclusive_mask_scan_gpu( +void segmented_exclusive_mask_prefix_sum_gpu( const array& in, array& out, int64_t segment_size, @@ -453,7 +453,76 @@ void segmented_exclusive_mask_scan_gpu( cuda::std::equal_to<>{}, encoder.stream())); } - return; +} + +void segmented_exclusive_int32_prefix_sum_gpu( + const array& in, + array& out, + int64_t segment_size, + const Stream& s) { + if (segment_size <= 0) { + throw std::runtime_error("segment_size must be positive."); + } + if (in.dtype() != int32 || out.dtype() != int32) { + throw std::runtime_error( + "segmented_exclusive_int32_prefix_sum_gpu expects int32."); + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + using CubIdx = int64_t; + auto count_iter = thrust::counting_iterator(0); + auto key_iter = thrust::make_transform_iterator( + count_iter, MaskSegmentKey{static_cast(segment_size)}); + + size_t workspace_size = 0; + if (segment_size == static_cast(in.size())) { + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSum( + nullptr, + workspace_size, + gpu_ptr(in), + gpu_ptr(out), + static_cast(in.size()), + encoder.stream())); + + void* workspace = allocate_workspace(encoder, workspace_size); + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSum( + workspace, + workspace_size, + gpu_ptr(in), + gpu_ptr(out), + static_cast(in.size()), + encoder.stream())); + } else { + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSumByKey( + nullptr, + workspace_size, + key_iter, + gpu_ptr(in), + gpu_ptr(out), + static_cast(in.size()), + cuda::std::equal_to<>{}, + encoder.stream())); + + void* workspace = allocate_workspace(encoder, workspace_size); + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSumByKey( + workspace, + workspace_size, + key_iter, + gpu_ptr(in), + gpu_ptr(out), + static_cast(in.size()), + cuda::std::equal_to<>{}, + encoder.stream())); + } } void Scan::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/cuda/scan.h b/mlx/backend/cuda/scan.h index 9700ea645a..f2074cc6e6 100644 --- a/mlx/backend/cuda/scan.h +++ b/mlx/backend/cuda/scan.h @@ -8,7 +8,13 @@ namespace mlx::core { -void segmented_exclusive_mask_scan_gpu( +void segmented_exclusive_mask_prefix_sum_gpu( + const array& in, + array& out, + int64_t segment_size, + const Stream& s); + +void segmented_exclusive_int32_prefix_sum_gpu( const array& in, array& out, int64_t segment_size,