diff --git a/mlx/backend/cpu/copy.cpp b/mlx/backend/cpu/copy.cpp index f9ff22677a..d14736c15e 100644 --- a/mlx/backend/cpu/copy.cpp +++ b/mlx/backend/cpu/copy.cpp @@ -70,7 +70,6 @@ void copy_general_general( dynamic_i_offset ? dynamic_i_offset->data() : nullptr; auto o_offset_ptr = dynamic_o_offset ? dynamic_o_offset->data() : nullptr; - auto size = src.size(); if (data_shape.empty()) { auto val = static_cast(*src_ptr); *dst_ptr = val; @@ -107,6 +106,8 @@ void copy_general_general( dst_ptr += o_offset_ptr[0]; } + auto size = std::accumulate( + shape.begin(), shape.end(), int64_t{1}, std::multiplies()); ContiguousIterator in(shape, strides[0], ndim - 3); ContiguousIterator out(shape, strides[1], ndim - 3); auto stride = std::accumulate( diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 660ea76c8d..6cd0c49497 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3209,6 +3209,12 @@ def test_dynamic_slicing(self): out = mx.slice(x, mx.array([1, 2, 3]), (0, 1, 2), (3, 2, 1)) self.assertTrue(mx.array_equal(expected, out)) + with mx.stream(mx.cpu): + x = mx.arange(5 * 6 * 7 * 8).reshape(5, 6, 7, 8) + expected = x[1:3, 2:4, 3:5, 4:6] + out = mx.slice(x, mx.array([1, 2, 3, 4]), (0, 1, 2, 3), (2, 2, 2, 2)) + self.assertTrue(mx.array_equal(expected, out)) + x = mx.zeros(shape=(4, 4, 4)) update = mx.random.randint(0, 100, shape=(3, 2, 1)) out = mx.slice_update(x, update, mx.array([1, 2, 3]), (0, 1, 2))