diff --git a/datafusion/spark/src/function/array/slice.rs b/datafusion/spark/src/function/array/slice.rs
index 6c168a4f491b5..08ae5f32bf4ab 100644
--- a/datafusion/spark/src/function/array/slice.rs
+++ b/datafusion/spark/src/function/array/slice.rs
@@ -19,7 +19,9 @@ use arrow::array::{Array, ArrayRef, Int64Builder};
 use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion_common::cast::{as_int64_array, as_list_array};
 use datafusion_common::utils::ListCoercion;
-use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
+use datafusion_common::{
+    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
+};
 use datafusion_expr::{
     ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
     ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -94,6 +96,12 @@ impl ScalarUDFImpl for SparkSlice {
         &self,
         mut func_args: ScalarFunctionArgs,
     ) -> Result<ColumnarValue> {
+        if func_args.args[0].data_type() == DataType::Null
+            && let Some(result) = check_null_types(&func_args.args[0])
+        {
+            return Ok(result);
+        }
+
         let array_len = func_args
             .args
             .iter()
@@ -128,6 +136,16 @@
     }
 }
 
+fn check_null_types(cv: &ColumnarValue) -> Option<ColumnarValue> {
+    match cv {
+        ColumnarValue::Scalar(ScalarValue::Null) => {
+            Some(ColumnarValue::create_null_array(1))
+        }
+        ColumnarValue::Array(_) => Some(cv.clone()),
+        _ => None,
+    }
+}
+
 fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
     let [values, start, length] = take_function_args("slice", args)?;
 
@@ -170,3 +188,40 @@ fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
 
     Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::NullArray;
+    use arrow::datatypes::DataType::List;
+    use arrow::datatypes::Field;
+    use datafusion_common::ScalarValue;
+
+    #[test]
+    fn test_spark_slice_function_when_input_array_is_null() {
+        let input_args = vec![
+            ColumnarValue::Array(Arc::new(NullArray::new(1))),
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
+        ];
+
+        let args = ScalarFunctionArgs {
+            args: input_args,
+            arg_fields: vec![Arc::new(Field::new(
+                "item",
+                List(FieldRef::new(Field::new("f", DataType::Int64, true))),
+                false,
+            ))],
+            number_rows: 1,
+            return_field: Arc::new(Field::new(
+                "item",
+                List(FieldRef::new(Field::new_list_field(DataType::Int64, true))),
+                false,
+            )),
+            config_options: Arc::new(Default::default()),
+        };
+        let slice = SparkSlice::new();
+        let result = slice.invoke_with_args(args).unwrap();
+        assert_eq!(*result.to_array(1).unwrap(), *Arc::new(NullArray::new(1)));
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/spark/array/slice.slt b/datafusion/sqllogictest/test_files/spark/array/slice.slt
index 4aba076aba6ba..f7986885ba26e 100644
--- a/datafusion/sqllogictest/test_files/spark/array/slice.slt
+++ b/datafusion/sqllogictest/test_files/spark/array/slice.slt
@@ -114,3 +114,21 @@ query ?
 SELECT slice([1, 2, 3, 4], CAST('2' AS INT), 4);
 ----
 [2, 3, 4]
+
+query ?
+SELECT slice(column1, column2, column3)
+FROM VALUES
+(NULL, 1, 2),
+(NULL, 1, -2),
+(NULL, -1, 2),
+(NULL, 0, 2);
+----
+NULL
+NULL
+NULL
+NULL
+
+query ?
+SELECT slice(slice(NULL, 1, 2), 1, 2)
+----
+NULL