diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index e86eaf8111b1c..69ba288e42d0d 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -16,16 +16,16 @@ // under the License. use std::any::Any; -use std::sync::Arc; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, + ArrayAccessor, ArrayIter, ArrayRef, AsArray, LargeStringBuilder, StringBuilder, + StringLikeArrayBuilder, StringViewBuilder, }; use arrow::datatypes::DataType; use datafusion_common::HashMap; use unicode_segmentation::UnicodeSegmentation; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::make_scalar_function; use datafusion_common::{Result, exec_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -93,7 +93,7 @@ impl ScalarUDFImpl for TranslateFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "translate") + Ok(arg_types[0].clone()) } fn invoke_with_args( @@ -116,33 +116,42 @@ impl ScalarUDFImpl for TranslateFunc { let ascii_table = build_ascii_translate_table(from_str, to_str); let string_array = args.args[0].to_array_of_size(args.number_rows)?; + let len = string_array.len(); let result = match string_array.data_type() { DataType::Utf8View => { let arr = string_array.as_string_view(); - translate_with_map::( + let builder = StringViewBuilder::with_capacity(len); + translate_with_map( arr, &from_map, &to_graphemes, ascii_table.as_ref(), + builder, ) } DataType::Utf8 => { let arr = string_array.as_string::(); - translate_with_map::( + let builder = + StringBuilder::with_capacity(len, arr.value_data().len()); + translate_with_map( arr, &from_map, &to_graphemes, ascii_table.as_ref(), + builder, ) } DataType::LargeUtf8 => { let arr = string_array.as_string::(); - translate_with_map::( + let builder = + LargeStringBuilder::with_capacity(len, arr.value_data().len()); + translate_with_map( arr, &from_map, &to_graphemes, ascii_table.as_ref(), + builder, ) } other => { @@ -172,24 +181,30 @@ fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> { } fn invoke_translate(args: &[ArrayRef]) -> Result { + let len = args[0].len(); match args[0].data_type() { DataType::Utf8View => { let string_array = args[0].as_string_view(); let from_array = args[1].as_string::(); let to_array = args[2].as_string::(); - translate::(string_array, from_array, to_array) + let builder = StringViewBuilder::with_capacity(len); + translate(string_array, from_array, to_array, builder) } DataType::Utf8 => { let string_array = args[0].as_string::(); let from_array = args[1].as_string::(); let to_array = args[2].as_string::(); - translate::(string_array, from_array, to_array) + let builder = + StringBuilder::with_capacity(len, string_array.value_data().len()); + translate(string_array, from_array, to_array, builder) } DataType::LargeUtf8 => { let string_array = args[0].as_string::(); let from_array = args[1].as_string::(); let to_array = args[2].as_string::(); - translate::(string_array, from_array, to_array) + let builder = + LargeStringBuilder::with_capacity(len, string_array.value_data().len()); + translate(string_array, from_array, to_array, builder) } other => { exec_err!("Unsupported data type {other:?} for function translate") @@ -199,14 +214,16 @@ fn invoke_translate(args: &[ArrayRef]) -> Result { /// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. /// translate('12345', '143', 'ax') = 'a2x5' -fn translate<'a, T: OffsetSizeTrait, V, B>( +fn translate<'a, V, B, O>( string_array: V, from_array: B, to_array: B, + mut builder: O, ) -> Result where V: ArrayAccessor, B: ArrayAccessor, + O: StringLikeArrayBuilder, { let string_array_iter = ArrayIter::new(string_array); let from_array_iter = ArrayIter::new(from_array); @@ -219,10 +236,9 @@ where let mut string_graphemes: Vec<&str> = Vec::new(); let mut result_graphemes: Vec<&str> = Vec::new(); - let result = string_array_iter - .zip(from_array_iter) - .zip(to_array_iter) - .map(|((string, from), to)| match (string, from, to) { + for ((string, from), to) in string_array_iter.zip(from_array_iter).zip(to_array_iter) + { + match (string, from, to) { (Some(string), Some(from), Some(to)) => { // Clear and reuse buffers from_map.clear(); @@ -254,13 +270,13 @@ where } } - Some(result_graphemes.concat()) + builder.append_value(&result_graphemes.concat()); } - _ => None, - }) - .collect::>(); + _ => builder.append_null(), + } + } - Ok(Arc::new(result) as ArrayRef) + Ok(builder.finish()) } /// Sentinel value in the ASCII translate table indicating the character should @@ -300,21 +316,23 @@ fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> { /// translation map instead of rebuilding it for every row. When an ASCII byte /// lookup table is provided, ASCII input rows use the lookup table; non-ASCII /// inputs fallback to using the map. -fn translate_with_map<'a, T: OffsetSizeTrait, V>( +fn translate_with_map<'a, V, O>( string_array: V, from_map: &HashMap<&str, usize>, to_graphemes: &[&str], ascii_table: Option<&[u8; 128]>, + mut builder: O, ) -> Result where V: ArrayAccessor, + O: StringLikeArrayBuilder, { let mut result_graphemes: Vec<&str> = Vec::new(); let mut ascii_buf: Vec = Vec::new(); - let result = ArrayIter::new(string_array) - .map(|string| { - string.map(|s| { + for string in ArrayIter::new(string_array) { + match string { + Some(s) => { // Fast path: byte-level table lookup for ASCII strings if let Some(table) = ascii_table && s.is_ascii() @@ -327,37 +345,38 @@ where } } // SAFETY: all bytes are ASCII, hence valid UTF-8. - return unsafe { - std::str::from_utf8_unchecked(&ascii_buf).to_owned() - }; - } - - // Slow path: grapheme-based translation - result_graphemes.clear(); - - for c in s.graphemes(true) { - match from_map.get(c) { - Some(n) => { - if let Some(replacement) = to_graphemes.get(*n) { - result_graphemes.push(*replacement); + builder.append_value(unsafe { + std::str::from_utf8_unchecked(&ascii_buf) + }); + } else { + // Slow path: grapheme-based translation + result_graphemes.clear(); + + for c in s.graphemes(true) { + match from_map.get(c) { + Some(n) => { + if let Some(replacement) = to_graphemes.get(*n) { + result_graphemes.push(*replacement); + } } + None => result_graphemes.push(c), } - None => result_graphemes.push(c), } - } - result_graphemes.concat() - }) - }) - .collect::>(); + builder.append_value(&result_graphemes.concat()); + } + } + None => builder.append_null(), + } + } - Ok(Arc::new(result) as ArrayRef) + Ok(builder.finish()) } #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -453,6 +472,45 @@ mod tests { Utf8, StringArray ); + // Utf8View input should produce Utf8View output + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("12345".into()))), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(Some("a2x5")), + &str, + Utf8View, + StringViewArray + ); + // Null Utf8View input + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(None), + &str, + Utf8View, + StringViewArray + ); + // Non-ASCII Utf8View input + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("é2íñ5".into()))), + ColumnarValue::Scalar(ScalarValue::from("éñí")), + ColumnarValue::Scalar(ScalarValue::from("óü")) + ], + Ok(Some("ó2ü5")), + &str, + Utf8View, + StringViewArray + ); #[cfg(not(feature = "unicode_expressions"))] test_function!( diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index 6cf02218872df..f548a37cf8335 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -1768,3 +1768,25 @@ SELECT ; ---- 48 176 32 40 + +# translate preserves input string type + +query T +SELECT translate(arrow_cast('12345', 'Utf8View'), '143', 'ax') +---- +a2x5 + +query T +SELECT arrow_typeof(translate('12345', '143', 'ax')) +---- +Utf8 + +query T +SELECT arrow_typeof(translate(arrow_cast('12345', 'LargeUtf8'), '143', 'ax')) +---- +LargeUtf8 + +query T +SELECT arrow_typeof(translate(arrow_cast('12345', 'Utf8View'), '143', 'ax')) +---- +Utf8View