diff --git a/datafusion/functions-nested/src/array_subtract.rs b/datafusion/functions-nested/src/array_subtract.rs new file mode 100644 index 000000000000..24600da04f74 --- /dev/null +++ b/datafusion/functions-nested/src/array_subtract.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_subtract function. + +use crate::utils::{ + array_math_binary_op, coerce_array_math_arg_types, make_scalar_function, +}; +use arrow::array::ArrayRef; +use arrow::datatypes::{ + DataType, + DataType::{LargeList, List}, +}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; + +make_udf_expr_and_func!( + ArraySubtract, + array_subtract, + array1 array2, + "returns the element-wise difference of two numeric arrays.", + array_subtract_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the element-wise difference of two numeric arrays of equal length, computed as `array1[i] - array2[i]` per position. 
NULL is propagated per element: if either input element at position `i` is NULL, the corresponding output element is NULL (positions are preserved). Returns NULL if either entire input array is NULL. Errors if the per-row lengths differ. Returns an empty array if both inputs are empty.", + syntax_example = "array_subtract(array1, array2)", + sql_example = r#"```sql +> select array_subtract([10.0, 20.0, 30.0], [1.0, 2.0, 3.0]); ++--------------------------------------------------------------+ +| array_subtract(List([10.0,20.0,30.0]),List([1.0,2.0,3.0])) | ++--------------------------------------------------------------+ +| [9.0, 18.0, 27.0] | ++--------------------------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArraySubtract { + signature: Signature, + aliases: Vec<String>, +} + +impl Default for ArraySubtract { + fn default() -> Self { + Self::new() + } +} + +impl ArraySubtract { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_subtract".to_string()], + } + } +} + +impl ScalarUDFImpl for ArraySubtract { + fn name(&self) -> &str { + "array_subtract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { + let [_, _] = take_function_args(self.name(), arg_types)?; + coerce_array_math_arg_types(self.name(), arg_types) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + make_scalar_function(array_subtract_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_subtract_inner(args: &[ArrayRef]) -> Result<ArrayRef> { + let [array1, array2] = take_function_args("array_subtract", args)?; + let sub = |a: f64, b: f64| a - b; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => { + array_math_binary_op::<i32, _>("array_subtract", array1, array2, sub) + } + (LargeList(_), LargeList(_)) => { + array_math_binary_op::<i64, _>("array_subtract", array1, array2, sub) + } + (arg_type1, arg_type2) => exec_err!( + "array_subtract received unexpected types after coercion: {arg_type1} and {arg_type2}" + ), + } +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index acb797845277..13c79c050465 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -47,6 +47,7 @@ pub mod array_filter; pub mod array_has; pub mod array_normalize; pub mod array_scale; +pub mod array_subtract; pub mod array_transform; 
pub mod arrays_zip; pub mod cardinality; @@ -99,6 +100,7 @@ pub mod expr_fn { pub use super::array_has::array_has_any; pub use super::array_normalize::array_normalize; pub use super::array_scale::array_scale; + pub use super::array_subtract::array_subtract; pub use super::array_transform::array_transform; pub use super::arrays_zip::arrays_zip; pub use super::cardinality::cardinality; @@ -176,6 +178,7 @@ pub fn all_default_nested_functions() -> Vec> { array_normalize::array_normalize_udf(), array_add::array_add_udf(), array_scale::array_scale_udf(), + array_subtract::array_subtract_udf(), cosine_distance::cosine_distance_udf(), inner_product::inner_product_udf(), distance::array_distance_udf(), diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index bdd71f2ff8f2..1b2bf428ff2d 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -22,12 +22,13 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, + Array, ArrayRef, BooleanArray, Float64Array, GenericListArray, NullBufferBuilder, + OffsetBufferBuilder, OffsetSizeTrait, Scalar, }; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use datafusion_common::cast::{ - as_fixed_size_list_array, as_large_list_array, as_large_list_view_array, - as_list_array, as_list_view_array, + as_fixed_size_list_array, as_float64_array, as_generic_list_array, + as_large_list_array, as_large_list_view_array, as_list_array, as_list_view_array, }; use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err}; @@ -327,6 +328,90 @@ pub(crate) fn coerce_array_math_arg_types( Ok(coerced) } +/// Element-wise binary operation kernel for two `Float64` lists of equal per-row +/// length. The caller is responsible for type-dispatching on `O` (`i32` for +/// `List`, `i64` for `LargeList`). 
+/// +/// Semantics: +/// - whole-row NULL on either side → NULL output row, length 0 +/// - per-element NULL on either side → NULL at that output position +/// - per-row length mismatch → exec error tagged with `op_name` +/// +/// `op_name` flows into the error message; `op` is the per-element scalar op +/// (e.g. `|a, b| a + b` for `array_add`, `|a, b| a - b` for `array_subtract`). +pub(crate) fn array_math_binary_op<O, F>( + op_name: &str, + lhs: &ArrayRef, + rhs: &ArrayRef, + op: F, +) -> Result<ArrayRef> +where + O: OffsetSizeTrait, + F: Fn(f64, f64) -> f64, +{ + let lhs = as_generic_list_array::<O>(lhs)?; + let rhs = as_generic_list_array::<O>(rhs)?; + + let lhs_values = as_float64_array(lhs.values())?; + let rhs_values = as_float64_array(rhs.values())?; + let lhs_offsets = lhs.value_offsets(); + let rhs_offsets = rhs.value_offsets(); + + let row_nulls = NullBuffer::union(lhs.nulls(), rhs.nulls()); + + let mut out_values: Vec<f64> = Vec::with_capacity(lhs_values.len()); + let mut out_inner_nulls = NullBufferBuilder::new(lhs_values.len()); + let mut out_offsets = OffsetBufferBuilder::<O>::new(lhs.len()); + + for row in 0..lhs.len() { + if row_nulls.as_ref().is_some_and(|nb| nb.is_null(row)) { + out_offsets.push_length(0); + continue; + } + + let start1 = lhs_offsets[row].as_usize(); + let len1 = lhs.value_length(row).as_usize(); + let start2 = rhs_offsets[row].as_usize(); + let len2 = rhs.value_length(row).as_usize(); + + if len1 != len2 { + return exec_err!( + "{op_name} requires both list inputs to have the same length per row, got {len1} and {len2} at row {row}" + ); + } + + let l_slice = lhs_values.slice(start1, len1); + let r_slice = rhs_values.slice(start2, len2); + + let l_vals = l_slice.values(); + let r_vals = r_slice.values(); + + for i in 0..len1 { + out_values.push(op(l_vals[i], r_vals[i])); + } + + match NullBuffer::union(l_slice.nulls(), r_slice.nulls()) { + Some(nb) => out_inner_nulls.append_buffer(&nb), + None => out_inner_nulls.append_n_non_nulls(len1), + } + + 
out_offsets.push_length(len1); + } + + let values_array = Arc::new(Float64Array::new( + out_values.into(), + out_inner_nulls.finish(), + )); + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + + Ok(Arc::new(GenericListArray::<O>::try_new( + field, + out_offsets.finish(), + values_array, + row_nulls, + )?)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sqllogictest/test_files/array_subtract.slt b/datafusion/sqllogictest/test_files/array_subtract.slt new file mode 100644 index 000000000000..4a680c93aae9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/array_subtract.slt @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +## array_subtract + +# Basic element-wise difference +query ? +select array_subtract([10.0, 20.0, 30.0], [1.0, 2.0, 3.0]); +---- +[9.0, 18.0, 27.0] + +# Negative components +query ? +select array_subtract([1.0, -2.0, 3.0], [-1.0, 2.0, -3.0]); +---- +[2.0, -4.0, 6.0] + +# Single-element arrays +query ? +select array_subtract([7.0], [5.0]); +---- +[2.0] + +# Bare NULL on left -> NULL row +query ? +select array_subtract(NULL, [1.0, 2.0]); +---- +NULL + +# Bare NULL on right -> NULL row +query ? 
+select array_subtract([1.0, 2.0], NULL); +---- +NULL + +# Both bare NULL -> NULL row +query ? +select array_subtract(NULL, NULL); +---- +NULL + +# NULL element on left propagates to that position only +query ? +select array_subtract([10.0, NULL, 30.0], [1.0, 2.0, 3.0]); +---- +[9.0, NULL, 27.0] + +# NULL element on right propagates to that position only +query ? +select array_subtract([10.0, 20.0, 30.0], [1.0, NULL, 3.0]); +---- +[9.0, NULL, 27.0] + +# NULL element on both sides at the same position +query ? +select array_subtract([10.0, NULL, 30.0], [1.0, NULL, 3.0]); +---- +[9.0, NULL, 27.0] + +# NULL elements at different positions both propagate +query ? +select array_subtract([10.0, NULL, 30.0], [NULL, 2.0, 3.0]); +---- +[NULL, NULL, 27.0] + +# Length mismatch is an exec error +query error array_subtract requires both list inputs to have the same length per row +select array_subtract([1.0, 2.0], [10.0, 20.0, 30.0]); + +# Empty arrays on both sides return empty array +query ? +select array_subtract(arrow_cast(make_array(), 'List(Float64)'), arrow_cast(make_array(), 'List(Float64)')); +---- +[] + +# Integer literals coerced to Float64 +query ? +select array_subtract([10, 20, 30], [1, 2, 3]); +---- +[9.0, 18.0, 27.0] + +# Mixed int + float literals coerced to Float64 +query ? +select array_subtract([1, 2, 3], [0.5, 0.5, 0.5]); +---- +[0.5, 1.5, 2.5] + +# LargeList input on both sides +query ? +select array_subtract( + arrow_cast([10.0, 20.0, 30.0], 'LargeList(Float64)'), + arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)') +); +---- +[9.0, 18.0, 27.0] + +# Mixed List + LargeList -> both widened to LargeList +query ? +select array_subtract( + [10.0, 20.0, 30.0], + arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)') +); +---- +[9.0, 18.0, 27.0] + +# FixedSizeList input (coerced to List) +query ? 
+select array_subtract( + arrow_cast([10.0, 20.0, 30.0], 'FixedSizeList(3, Float64)'), + arrow_cast([1.0, 2.0, 3.0], 'FixedSizeList(3, Float64)') +); +---- +[9.0, 18.0, 27.0] + +# Float32 inner type on one side +query ? +select array_subtract( + arrow_cast([10.0, 20.0, 30.0], 'List(Float32)'), + [1.0, 2.0, 3.0] +); +---- +[9.0, 18.0, 27.0] + +# Int64 inner type +query ? +select array_subtract( + arrow_cast([10, 20, 30], 'List(Int64)'), + arrow_cast([1, 2, 3], 'List(Int64)') +); +---- +[9.0, 18.0, 27.0] + +# Unsupported non-list input (plan error) +query error array_subtract does not support type +select array_subtract(1, [1.0, 2.0]); + +# Wrong arg count +query error array_subtract function requires 2 arguments, got 0 +select array_subtract(); + +query error array_subtract function requires 2 arguments, got 1 +select array_subtract([1.0, 2.0]); + +# Return type matches input variant +query ?T +select array_subtract([1.0, 2.0], [3.0, 4.0]), arrow_typeof(array_subtract([1.0, 2.0], [3.0, 4.0])); +---- +[-2.0, -2.0] List(Float64) + +# Multi-row query: normal row, NULL-left row, NULL-right row, element-NULL row +query ? +select array_subtract(a, b) from (values + (make_array(10.0, 20.0, 30.0), make_array(1.0, 2.0, 3.0)), + (NULL, make_array(1.0, 2.0, 3.0)), + (make_array(1.0, 2.0, 3.0), NULL), + (make_array(10.0, NULL, 30.0), make_array(1.0, 2.0, 3.0)) +) as t(a, b); +---- +[9.0, 18.0, 27.0] +NULL +NULL +[9.0, NULL, 27.0] + +# list_subtract alias +query ? +select list_subtract([3.0, 4.0], [1.0, 2.0]); +---- +[2.0, 2.0] + +# list_subtract alias multi-row +query ? +select list_subtract(a, b) from (values + (make_array(10.0, 20.0), make_array(1.0, 2.0)), + (NULL, make_array(1.0, 2.0)) +) as t(a, b); +---- +[9.0, 18.0] +NULL + +# Decimal element types are coerced to Float64 (lossy) like other array-math UDFs +query ? 
+select array_subtract( + arrow_cast([10, 20, 30], 'List(Decimal128(10, 2))'), + arrow_cast([1, 2, 3], 'List(Decimal128(10, 2))') +); +---- +[9.0, 18.0, 27.0] + +# Explicitly casting the Decimal list to List(Float64) first gives the same result +query ? +select array_subtract( + arrow_cast(arrow_cast([10, 20, 30], 'List(Decimal128(10, 2))'), 'List(Float64)'), + [1.0, 2.0, 3.0] +); +---- +[9.0, 18.0, 27.0] + +# Chained array_subtract: result of inner call feeds the outer call +query ? +select array_subtract(array_subtract([100.0, 200.0, 300.0], [10.0, 20.0, 30.0]), [1.0, 2.0, 3.0]); +---- +[89.0, 178.0, 267.0] + +# Chained array_subtract propagates element-level NULLs through both layers +query ? +select array_subtract( + array_subtract([100.0, NULL, 300.0], [10.0, 20.0, 30.0]), + [1.0, 2.0, NULL] +); +---- +[89.0, NULL, NULL] + +# Chained array_subtract over multiple rows +query ? +select array_subtract(array_subtract(a, b), c) from (values + (make_array(100.0, 200.0), make_array(10.0, 20.0), make_array(1.0, 2.0)), + (NULL, make_array(1.0, 2.0), make_array(3.0, 4.0)), + (make_array(100.0, 200.0), make_array(10.0, NULL), make_array(1.0, 2.0)) +) as t(a, b, c); +---- +[89.0, 178.0] +NULL +[89.0, NULL] \ No newline at end of file diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ccb171b4f57e..5074ddb24bbc 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3290,6 +3290,7 @@ _Alias of [current_date](#current_date)._ - [array_scale](#array_scale) - [array_slice](#array_slice) - [array_sort](#array_sort) +- [array_subtract](#array_subtract) - [array_to_string](#array_to_string) - [array_transform](#array_transform) - [array_union](#array_union) @@ -3347,6 +3348,7 @@ _Alias of [current_date](#current_date)._ - [list_scale](#list_scale) - [list_slice](#list_slice) - [list_sort](#list_sort) +- [list_subtract](#list_subtract) - [list_to_string](#list_to_string) - 
[list_transform](#list_transform) - [list_union](#list_union) @@ -4513,6 +4515,34 @@ array_sort(array, desc, nulls_first) - list_sort +### `array_subtract` + +Returns the element-wise difference of two numeric arrays of equal length, computed as `array1[i] - array2[i]` per position. NULL is propagated per element: if either input element at position `i` is NULL, the corresponding output element is NULL (positions are preserved). Returns NULL if either entire input array is NULL. Errors if the per-row lengths differ. Returns an empty array if both inputs are empty. + +```sql +array_subtract(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_subtract([10.0, 20.0, 30.0], [1.0, 2.0, 3.0]); ++--------------------------------------------------------------+ +| array_subtract(List([10.0,20.0,30.0]),List([1.0,2.0,3.0])) | ++--------------------------------------------------------------+ +| [9.0, 18.0, 27.0] | ++--------------------------------------------------------------+ +``` + +#### Aliases + +- list_subtract + ### `array_to_string` Converts each element to its text representation. @@ -4985,6 +5015,10 @@ _Alias of [array_slice](#array_slice)._ _Alias of [array_sort](#array_sort)._ +### `list_subtract` + +_Alias of [array_subtract](#array_subtract)._ + ### `list_to_string` _Alias of [array_to_string](#array_to_string)._