diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 1a3e743784d..ae39977aaea 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -217,15 +217,15 @@ struct ReplaceMaskImpl> { static Result ExecScalarMask(KernelContext* ctx, const ArraySpan& array, const BooleanScalar& mask, ExecValue replacements, int64_t replacements_offset, ExecResult* out) { - out->value = array; - return Status::OK(); + out->value = array.ToArrayData(); + return replacements_offset; } static Result ExecArrayMask(KernelContext* ctx, const ArraySpan& array, const ArraySpan& mask, int64_t mask_offset, ExecValue replacements, int64_t replacements_offset, ExecResult* out) { - out->value = array; - return Status::OK(); + out->value = array.ToArrayData(); + return replacements_offset; } }; diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc index 587b9f2a60e..9dc8e70ab6a 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc @@ -199,6 +199,13 @@ class TestReplaceBoolean : public TestReplaceKernel { } }; +class TestReplaceNull : public TestReplaceKernel { + protected: + std::shared_ptr type() override { + return TypeTraits::type_singleton(); + } +}; + class TestReplaceFixedSizeBinary : public TestReplaceKernel { protected: std::shared_ptr type() override { return fixed_size_binary(3); } @@ -538,6 +545,35 @@ TEST_F(TestReplaceBoolean, ReplaceWithMask) { } } +TEST_F(TestReplaceNull, ReplaceWithMask) { + std::vector cases = { + {this->array("[]"), this->mask_scalar(false), this->array("[]"), this->array("[]")}, + {this->array("[]"), this->mask_scalar(true), this->array("[]"), this->array("[]")}, + {this->array("[]"), this->null_mask_scalar(), this->array("[]"), this->array("[]")}, + + {this->array("[null]"), this->mask_scalar(false), this->array("[]"), + this->array("[null]")}, + + {this->array("[null]"), this->mask_scalar(true), this->array("[null]"), + this->array("[null]")}, + + {this->array("[null]"), this->null_mask_scalar(), this->array("[]"), + this->array("[null]")}, + + {this->array("[null, null]"), this->mask("[false, false]"), this->array("[]"), + this->array("[null, null]")}, + {this->array("[null, null]"), this->mask("[true, true]"), + this->array("[null, null]"), this->array("[null, null]")}, + {this->array("[null, null]"), this->mask("[null, null]"), this->array("[]"), + this->array("[null, null]")}, + }; + + for (auto test_case : cases) { + this->Assert(ReplaceWithMask, test_case.input, test_case.mask, test_case.replacements, + test_case.expected); + } +} + TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) { EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 4e44a912d96..38489c2e522 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -1987,6 +1987,34 @@ def test_fill_null_array(arrow_type): assert result.equals(expected) +def test_replace_with_mask_null_type(): + # GH-47447: replace_with_mask crashed for null type arrays + a = pa.array([None], pa.null()) + b = pa.array([None], pa.null()) + + result = pc.replace_with_mask(a, True, b) + assert result.type == pa.null() + result.validate(full=True) + assert result.to_pylist() == [None] + + result = pc.replace_with_mask(a, False, b) + assert result.type == pa.null() + result.validate(full=True) + assert result.to_pylist() == [None] + + mask = pa.array([True]) + result = pc.replace_with_mask(a, mask, b) + assert result.type == pa.null() + result.validate(full=True) + assert result.to_pylist() == [None] + + mask = pa.array([False]) + result = pc.replace_with_mask(a, mask, b) + assert result.type == pa.null() + result.validate(full=True) + assert result.to_pylist() == [None] + + @pytest.mark.parametrize('arrow_type', numerical_arrow_types) def test_fill_null_chunked_array(arrow_type): fill_value = pa.scalar(5, type=arrow_type)