diff --git a/benchmarks/python/masked_scatter.py b/benchmarks/python/masked_scatter.py index 71857c5436..e1c84ee6bb 100644 --- a/benchmarks/python/masked_scatter.py +++ b/benchmarks/python/masked_scatter.py @@ -1,5 +1,6 @@ import math import os +import platform import subprocess import time from copy import copy @@ -17,9 +18,6 @@ if not os.path.isdir(RESULTS_DIR): os.mkdir(RESULTS_DIR) -DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) -DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n") - TORCH_DEVICE = torch.device( "mps" if torch.backends.mps.is_available() @@ -27,11 +25,36 @@ ) +def get_device_name(): + if TORCH_DEVICE.type == "cuda": + try: + out = subprocess.check_output( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + stderr=subprocess.DEVNULL, + ) + return out.decode("utf-8").splitlines()[0].strip() + except Exception: + return "CUDA_GPU" + if TORCH_DEVICE.type == "mps": + try: + out = subprocess.check_output( + ["sysctl", "-n", "machdep.cpu.brand_string"], + stderr=subprocess.DEVNULL, + ) + return out.decode("utf-8").strip() + except Exception: + return "Apple_Silicon" + return platform.processor() or platform.machine() or "CPU" + + +DEVICE_NAME = get_device_name() + + N_WARMUP = 5 N_ITER_BENCH = 50 N_ITER_FUNC = 20 -VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)] +VECTOR_LENGTHS = [4096 * (2**i) for i in range(12)] MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5] D_TYPES = ("float32", "float16") @@ -202,9 +225,10 @@ def main(): ) output_path = os.path.join( RESULTS_DIR, - f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf", + f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.png", ) fig.savefig(output_path) + print(f"Saved benchmark image: {output_path}") plt.close(fig) diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh index b2f6403505..a807740795 100644 --- a/mlx/backend/cuda/device/scatter.cuh +++ b/mlx/backend/cuda/device/scatter.cuh @@ -65,4 +65,187 @@ __global__ void scatter( Op{}(out + out_idx, upd[upd_loc]); } +template +__global__ void masked_scatter_fused( + const T* dst, + const bool* mask, + const int32_t* scatter_offsets, + const T* src, + T* out, + IdxT size, + IdxT src_batch_size, + IdxT mask_batch_size, + const __grid_constant__ Shape dst_shape, + const __grid_constant__ Strides dst_strides, + int32_t dst_ndim, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index >= size) { + return; + } + + T dst_val; + if constexpr (DstContiguous) { + dst_val = dst[index]; + } else { + IdxT dst_loc = + elem_to_loc(index, dst_shape.data(), dst_strides.data(), dst_ndim); + dst_val = dst[dst_loc]; + } + + if (mask[index]) { + IdxT src_index = static_cast(scatter_offsets[index]); + if (src_index < src_batch_size) { + IdxT batch_idx = index / mask_batch_size; + if constexpr (SrcContiguous) { + out[index] = src[batch_idx * src_batch_size + src_index]; + } else { + IdxT src_elem = batch_idx * src_batch_size + src_index; + IdxT src_loc = elem_to_loc( + src_elem, src_shape.data(), src_strides.data(), src_ndim); + out[index] = src[src_loc]; + } + return; + } + } + + out[index] = dst_val; +} + +template +__global__ void masked_scatter_tile_count( + const bool* mask, + int32_t* tile_counts, + IdxT mask_batch_size, + int32_t num_tiles_per_batch) { + IdxT tile = cg::this_grid().block_rank(); + IdxT batch_idx = tile / num_tiles_per_batch; + IdxT tile_in_batch = tile - batch_idx * num_tiles_per_batch; + IdxT tile_items = static_cast(blockDim.x) * ITEMS_PER_THREAD; + IdxT tile_start = batch_idx * mask_batch_size + tile_in_batch * tile_items; + IdxT batch_end = (batch_idx + 1) * mask_batch_size; + IdxT tile_end = tile_start + tile_items; + if (tile_end > batch_end) { + tile_end = batch_end; + } + + int32_t local_count = 0; + IdxT index = tile_start + threadIdx.x; +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + if (index < tile_end) { + local_count += static_cast(mask[index]); + } + index += blockDim.x; + } + + int lane = threadIdx.x & (WARP_SIZE - 1); + int warp = threadIdx.x / WARP_SIZE; + int nwarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; + + unsigned int active = __activemask(); + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + local_count += __shfl_down_sync(active, local_count, offset); + } + + __shared__ int32_t warp_sums[WARP_SIZE]; + if (lane == 0) { + warp_sums[warp] = local_count; + } + __syncthreads(); + + if (warp == 0) { + int32_t block_sum = (lane < nwarps) ? warp_sums[lane] : 0; + unsigned int warp0_active = __activemask(); + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + block_sum += __shfl_down_sync(warp0_active, block_sum, offset); + } + if (lane == 0) { + tile_counts[tile] = block_sum; + } + } +} + +template +__global__ void masked_scatter_fused_vec_contiguous( + const T* dst, + const bool* mask, + const int32_t* tile_offsets, + const T* src, + T* out, + IdxT src_batch_size, + IdxT mask_batch_size, + int32_t num_tiles_per_batch) { + IdxT tile = cg::this_grid().block_rank(); + IdxT batch_idx = tile / num_tiles_per_batch; + IdxT tile_in_batch = tile - batch_idx * num_tiles_per_batch; + IdxT tile_items = static_cast(blockDim.x) * ITEMS_PER_THREAD; + IdxT tile_start = batch_idx * mask_batch_size + tile_in_batch * tile_items; + IdxT batch_end = (batch_idx + 1) * mask_batch_size; + IdxT tile_end = tile_start + tile_items; + if (tile_end > batch_end) { + tile_end = batch_end; + } + + IdxT src_base = batch_idx * src_batch_size; + IdxT tile_prefix = static_cast(tile_offsets[tile]); + IdxT iter_prefix = 0; + + int lane = threadIdx.x & (WARP_SIZE - 1); + int warp = threadIdx.x / WARP_SIZE; + int nwarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; + + __shared__ int32_t warp_counts[WARP_SIZE]; + __shared__ int32_t warp_offsets[WARP_SIZE]; + __shared__ int32_t iter_count; + + IdxT index = tile_start + threadIdx.x; + +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + bool active = index < tile_end; + bool mask_value = active ? mask[index] : false; + T out_value = active ? dst[index] : static_cast(0); + + unsigned int active_mask = __activemask(); + unsigned int ballots = __ballot_sync(active_mask, mask_value); + unsigned int lane_mask = (lane == 0) ? 0u : ((1u << lane) - 1u); + int32_t warp_exclusive = __popc(ballots & lane_mask); + int32_t warp_count = __popc(ballots); + + if (lane == 0) { + warp_counts[warp] = warp_count; + } + __syncthreads(); + + if (threadIdx.x == 0) { + int32_t offset = 0; + for (int w = 0; w < nwarps; ++w) { + warp_offsets[w] = offset; + offset += warp_counts[w]; + } + iter_count = offset; + } + __syncthreads(); + + if (active && mask_value) { + IdxT src_index = tile_prefix + iter_prefix + + static_cast(warp_offsets[warp] + warp_exclusive); + if (src_index < src_batch_size) { + out_value = src[src_base + src_index]; + } + } + + if (active) { + out[index] = out_value; + } + + iter_prefix += static_cast(iter_count); + index += blockDim.x; + __syncthreads(); + } +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 424566d258..2d9dcdddc9 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/scan.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -435,4 +436,204 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { kernel, num_blocks, block_dims, {}, 0, args.args()); } +void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("MaskedScatter::eval_gpu"); + assert(inputs.size() == 3); + + const array& dst = inputs[0]; + const array& mask = inputs[1]; + const array& src = inputs[2]; + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + const size_t total = mask.size(); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + if (total == 0) { + return; + } + + array mask_flat = flatten_in_eval(mask, 1, -1, s); + if (mask_flat.data() != mask.data()) { + encoder.add_temporary(mask_flat); + } + if (!mask_flat.flags().row_contiguous) { + mask_flat = contiguous_copy_gpu(mask_flat, s); + encoder.add_temporary(mask_flat); + } + + const size_t batch_count = mask.shape(0); + const size_t mask_batch_size = mask_flat.size() / batch_count; + const size_t src_batch_size = src.size() / src.shape(0); + bool large = total > INT32_MAX || src.size() > INT32_MAX; + constexpr int kTileItemsPerThread = 4; + constexpr uint32_t kTileBlockDim = 256; + constexpr size_t kTileItems = kTileItemsPerThread * kTileBlockDim; + + std::string module_name = + fmt::format("masked_scatter_fused_{}", dtype_to_string(out.dtype())); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) { + for (int dst_contiguous = 0; dst_contiguous <= 1; ++dst_contiguous) { + for (int use_large = 0; use_large <= 1; ++use_large) { + kernel_names.push_back( + fmt::format( + "mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + src_contiguous ? "true" : "false", + dst_contiguous ? "true" : "false", + use_large ? "int64_t" : "int32_t")); + } + } + } + for (int use_large = 0; use_large <= 1; ++use_large) { + kernel_names.push_back( + fmt::format( + "mlx::core::cu::masked_scatter_tile_count<{}, {}>", + use_large ? "int64_t" : "int32_t", + kTileItemsPerThread)); + kernel_names.push_back( + fmt::format( + "mlx::core::cu::masked_scatter_fused_vec_contiguous<{}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + use_large ? "int64_t" : "int32_t", + kTileItemsPerThread)); + } + return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); + }); + + bool contiguous_fast_path = + src.flags().row_contiguous && dst.flags().row_contiguous; + size_t num_tiles_per_batch = 0; + size_t total_tiles = 0; + dim3 tile_num_blocks; + if (contiguous_fast_path) { + num_tiles_per_batch = cuda::ceil_div(mask_batch_size, kTileItems); + total_tiles = batch_count * num_tiles_per_batch; + if (total_tiles > UINT32_MAX) { + contiguous_fast_path = false; + } else { + tile_num_blocks = dim3(static_cast(total_tiles), 1, 1); + } + } + + if (contiguous_fast_path) { + array tile_counts( + {static_cast(batch_count), static_cast(num_tiles_per_batch)}, + int32, + nullptr, + {}); + tile_counts.set_data(cu::malloc_async(tile_counts.nbytes(), encoder)); + encoder.add_temporary(tile_counts); + + array tile_offsets(tile_counts.shape(), int32, nullptr, {}); + tile_offsets.set_data(cu::malloc_async(tile_offsets.nbytes(), encoder)); + encoder.add_temporary(tile_offsets); + + cu::KernelArgs count_args; + count_args.append(mask_flat); + count_args.append(tile_counts); + if (large) { + count_args.append(mask_batch_size); + count_args.append(num_tiles_per_batch); + } else { + count_args.append(mask_batch_size); + count_args.append(num_tiles_per_batch); + } + + encoder.set_input_array(mask_flat); + encoder.set_output_array(tile_counts); + auto count_kernel = mod.get_kernel( + fmt::format( + "mlx::core::cu::masked_scatter_tile_count<{}, {}>", + large ? "int64_t" : "int32_t", + kTileItemsPerThread)); + encoder.add_kernel_node_raw( + count_kernel, tile_num_blocks, kTileBlockDim, {}, 0, count_args.args()); + + segmented_exclusive_int32_prefix_sum_gpu( + tile_counts, + tile_offsets, + static_cast(num_tiles_per_batch), + s); + + cu::KernelArgs fused_args; + fused_args.append(dst); + fused_args.append(mask_flat); + fused_args.append(tile_offsets); + fused_args.append(src); + fused_args.append(out); + if (large) { + fused_args.append(src_batch_size); + fused_args.append(mask_batch_size); + fused_args.append(num_tiles_per_batch); + } else { + fused_args.append(src_batch_size); + fused_args.append(mask_batch_size); + fused_args.append(num_tiles_per_batch); + } + + encoder.set_input_array(dst); + encoder.set_input_array(mask_flat); + encoder.set_input_array(tile_offsets); + encoder.set_input_array(src); + encoder.set_output_array(out); + auto fused_kernel = mod.get_kernel( + fmt::format( + "mlx::core::cu::masked_scatter_fused_vec_contiguous<{}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + large ? "int64_t" : "int32_t", + kTileItemsPerThread)); + encoder.add_kernel_node_raw( + fused_kernel, tile_num_blocks, kTileBlockDim, {}, 0, fused_args.args()); + } else { + array scatter_offsets(mask_flat.shape(), int32, nullptr, {}); + scatter_offsets.set_data( + cu::malloc_async(scatter_offsets.nbytes(), encoder)); + encoder.add_temporary(scatter_offsets); + segmented_exclusive_mask_prefix_sum_gpu( + mask_flat, scatter_offsets, static_cast(mask_batch_size), s); + + cu::KernelArgs args; + args.append(dst); + args.append(mask_flat); + args.append(scatter_offsets); + args.append(src); + args.append(out); + if (large) { + args.append(mask_flat.size()); + args.append(src_batch_size); + args.append(mask_batch_size); + } else { + args.append(mask_flat.size()); + args.append(src_batch_size); + args.append(mask_batch_size); + } + args.append_ndim(dst.shape()); + args.append_ndim(dst.strides()); + args.append(dst.ndim()); + args.append_ndim(src.shape()); + args.append_ndim(src.strides()); + args.append(src.ndim()); + + encoder.set_input_array(dst); + encoder.set_input_array(mask_flat); + encoder.set_input_array(scatter_offsets); + encoder.set_input_array(src); + encoder.set_output_array(out); + + auto kernel = mod.get_kernel( + fmt::format( + "mlx::core::cu::masked_scatter_fused<{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + src.flags().row_contiguous ? "true" : "false", + dst.flags().row_contiguous ? "true" : "false", + large ? "int64_t" : "int32_t")); + auto [num_blocks, block_dims] = get_launch_args(mask_flat, large); + encoder.add_kernel_node_raw( + kernel, num_blocks, block_dims, {}, 0, args.args()); + } +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index 0caac8de6e..cb9b69fbb0 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -36,7 +36,6 @@ NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) -NO_GPU(MaskedScatter) namespace distributed { NO_GPU_MULTI(Send) diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index bd25084c1f..3c2af5d0bf 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -4,6 +4,8 @@ #include "mlx/backend/cuda/device/binary_ops.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include "mlx/backend/cuda/scan.h" +#include "mlx/backend/cuda/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -11,6 +13,10 @@ #include #include #include +#include +#include +#include +#include #include @@ -362,6 +368,163 @@ constexpr bool supports_scan_op() { } } +namespace { + +struct BoolToInt32 { + __host__ __device__ int32_t operator()(bool v) const { + return static_cast(v); + } +}; + +template +struct MaskSegmentKey { + IdxT segment_size; + + __host__ __device__ IdxT operator()(IdxT i) const { + return i / segment_size; + } +}; + +} // namespace + +void segmented_exclusive_mask_prefix_sum_gpu( + const array& in, + array& out, + int64_t segment_size, + const Stream& s) { + if (segment_size <= 0) { + throw std::runtime_error("segment_size must be positive."); + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + using CubIdx = int64_t; + auto count_iter = thrust::counting_iterator(0); + auto key_iter = thrust::make_transform_iterator( + count_iter, MaskSegmentKey{static_cast(segment_size)}); + auto value_iter = + thrust::make_transform_iterator(gpu_ptr(in), BoolToInt32{}); + + size_t workspace_size = 0; + if (segment_size == static_cast(in.size())) { + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSum( + nullptr, + workspace_size, + value_iter, + gpu_ptr(out), + static_cast(in.size()), + encoder.stream())); + + void* workspace = allocate_workspace(encoder, workspace_size); + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSum( + workspace, + workspace_size, + value_iter, + gpu_ptr(out), + static_cast(in.size()), + encoder.stream())); + } else { + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSumByKey( + nullptr, + workspace_size, + key_iter, + value_iter, + gpu_ptr(out), + static_cast(in.size()), + cuda::std::equal_to<>{}, + encoder.stream())); + + void* workspace = allocate_workspace(encoder, workspace_size); + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSumByKey( + workspace, + workspace_size, + key_iter, + value_iter, + gpu_ptr(out), + static_cast(in.size()), + cuda::std::equal_to<>{}, + encoder.stream())); + } +} + +void segmented_exclusive_int32_prefix_sum_gpu( + const array& in, + array& out, + int64_t segment_size, + const Stream& s) { + if (segment_size <= 0) { + throw std::runtime_error("segment_size must be positive."); + } + if (in.dtype() != int32 || out.dtype() != int32) { + throw std::runtime_error( + "segmented_exclusive_int32_prefix_sum_gpu expects int32."); + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + using CubIdx = int64_t; + auto count_iter = thrust::counting_iterator(0); + auto key_iter = thrust::make_transform_iterator( + count_iter, MaskSegmentKey{static_cast(segment_size)}); + + size_t workspace_size = 0; + if (segment_size == static_cast(in.size())) { + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSum( + nullptr, + workspace_size, + gpu_ptr(in), + gpu_ptr(out), + static_cast(in.size()), + encoder.stream())); + + void* workspace = allocate_workspace(encoder, workspace_size); + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSum( + workspace, + workspace_size, + gpu_ptr(in), + gpu_ptr(out), + static_cast(in.size()), + encoder.stream())); + } else { + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSumByKey( + nullptr, + workspace_size, + key_iter, + gpu_ptr(in), + gpu_ptr(out), + static_cast(in.size()), + cuda::std::equal_to<>{}, + encoder.stream())); + + void* workspace = allocate_workspace(encoder, workspace_size); + auto capture = encoder.capture_context(); + CHECK_CUDA_ERROR( + cub::DeviceScan::ExclusiveSumByKey( + workspace, + workspace_size, + key_iter, + gpu_ptr(in), + gpu_ptr(out), + static_cast(in.size()), + cuda::std::equal_to<>{}, + encoder.stream())); + } +} + void Scan::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Scan::eval_gpu"); assert(inputs.size() == 1); diff --git a/mlx/backend/cuda/scan.h b/mlx/backend/cuda/scan.h new file mode 100644 index 0000000000..f2074cc6e6 --- /dev/null +++ b/mlx/backend/cuda/scan.h @@ -0,0 +1,23 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/primitives.h" +#include "mlx/stream.h" + +namespace mlx::core { + +void segmented_exclusive_mask_prefix_sum_gpu( + const array& in, + array& out, + int64_t segment_size, + const Stream& s); + +void segmented_exclusive_int32_prefix_sum_gpu( + const array& in, + array& out, + int64_t segment_size, + const Stream& s); + +} // namespace mlx::core diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 20793d5c91..fd2924a411 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -55,8 +55,4 @@ "TestQuantized.test_throw", "TestQuantized.test_vjp_scales_biases", "TestExportImport.test_export_quantized_model", - # Masked scatter - "TestOps.test_masked_scatter", - "TestVmap.test_vmap_masked_scatter", - "TestArray.test_setitem_with_boolean_mask", } diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index ff8d986bd9..25c871cdf9 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1357,11 +1357,6 @@ TEST_CASE("test grad dynamic slices") { } TEST_CASE("test masked_scatter autograd") { - if (cu::is_available()) { - INFO("Skipping masked_scatter cuda autograd tests"); - return; - } - // Test jvp { auto self = array({10.f, 20.f, 30.f, 40.f}, {4}); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 62fd8c5923..23740b7004 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2437,11 +2437,6 @@ TEST_CASE("test scatter") { } TEST_CASE("test masked_scatter") { - if (cu::is_available()) { - INFO("Skipping masked_scatter cuda ops tests"); - return; - } - // Wrong mask dtype CHECK_THROWS(masked_scatter(array({1, 2}), array({1, 2}), array({1, 2})));