diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs
index e8737612a1dcf..3ea2d4abcb069 100644
--- a/datafusion/functions/src/core/mod.rs
+++ b/datafusion/functions/src/core/mod.rs
@@ -40,6 +40,7 @@ pub mod r#struct;
 pub mod union_extract;
 pub mod union_tag;
 pub mod version;
+pub mod with_metadata;
 
 // create UDFs
 make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast);
@@ -59,6 +60,7 @@ make_udf_function!(union_extract::UnionExtractFun, union_extract);
 make_udf_function!(union_tag::UnionTagFunc, union_tag);
 make_udf_function!(version::VersionFunc, version);
 make_udf_function!(arrow_metadata::ArrowMetadataFunc, arrow_metadata);
+make_udf_function!(with_metadata::WithMetadataFunc, with_metadata);
 
 pub mod expr_fn {
     use datafusion_expr::{Expr, Literal};
@@ -95,6 +97,10 @@ pub mod expr_fn {
         arrow_metadata,
         "Returns the metadata of the input expression",
         args,
+    ),(
+        with_metadata,
+        "Attaches Arrow field metadata (key/value pairs) to the input expression",
+        args,
     ),(
         r#struct,
         "Returns a struct with the given arguments",
@@ -148,6 +154,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         arrow_cast(),
         arrow_try_cast(),
         arrow_metadata(),
+        with_metadata(),
         nvl(),
         nvl2(),
         overlay(),
diff --git a/datafusion/functions/src/core/with_metadata.rs b/datafusion/functions/src/core/with_metadata.rs
new file mode 100644
index 0000000000000..481ed713ed7ad
--- /dev/null
+++ b/datafusion/functions/src/core/with_metadata.rs
@@ -0,0 +1,390 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::{Result, exec_datafusion_err, exec_err, internal_err};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
+    Signature, Volatility,
+};
+use datafusion_macros::user_doc;
+
+/// Implements the `with_metadata` scalar function, which attaches constant
+/// key/value pairs to the Arrow field metadata of its first argument.
+///
+/// The metadata is computed once, at planning time, in
+/// [`ScalarUDFImpl::return_field_from_args`]; at execution time the input
+/// values are returned unchanged.
+#[user_doc(
+    doc_section(label = "Other Functions"),
+    description = "Attaches Arrow field metadata (key/value pairs) to the input expression. Keys must be non-empty constant strings and values must be constant strings (empty values are allowed). Existing metadata on the input field is preserved; new keys overwrite on collision. This is the inverse of `arrow_metadata`.",
+    syntax_example = "with_metadata(expression, key1, value1[, key2, value2, ...])",
+    sql_example = r#"```sql
+> select arrow_metadata(with_metadata(column1, 'unit', 'ms'), 'unit') from (values (1));
++-----------------------------------------------------------------------------+
+| arrow_metadata(with_metadata(column1,Utf8("unit"),Utf8("ms")),Utf8("unit")) |
++-----------------------------------------------------------------------------+
+| ms                                                                          |
++-----------------------------------------------------------------------------+
+> select arrow_metadata(with_metadata(column1, 'unit', 'ms', 'source', 'sensor')) from (values (1));
++----------------------------------------------------------------------------------------------+
+| arrow_metadata(with_metadata(column1,Utf8("unit"),Utf8("ms"),Utf8("source"),Utf8("sensor"))) |
++----------------------------------------------------------------------------------------------+
+| {source: sensor, unit: ms}                                                                   |
++----------------------------------------------------------------------------------------------+
+```"#,
+    argument(
+        name = "expression",
+        description = "The expression whose output Arrow field should be annotated. Values flow through unchanged."
+    ),
+    argument(
+        name = "key",
+        description = "Metadata key. Must be a non-empty constant string literal."
+    ),
+    argument(
+        name = "value",
+        description = "Metadata value. Must be a constant string literal (may be empty)."
+    )
+)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct WithMetadataFunc {
+    signature: Signature,
+}
+
+impl Default for WithMetadataFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl WithMetadataFunc {
+    /// Creates the function with a variadic, immutable signature; arity and
+    /// literal-ness of the arguments are validated in `return_field_from_args`.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for WithMetadataFunc {
+    fn name(&self) -> &str {
+        "with_metadata"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!(
+            "with_metadata: return_type called instead of return_field_from_args"
+        )
+    }
+
+    /// Builds the output field: the input field's name, type and nullability
+    /// with the literal key/value pairs merged into its metadata. Fails if the
+    /// arity is wrong or a key/value is not a constant string.
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let arg_count = args.arg_fields.len();
+
+        // Require at least the value expression plus one (key, value) pair,
+        // and an odd total (1 + 2*N).
+        if arg_count < 3 {
+            return exec_err!(
+                "with_metadata requires the input expression plus at least one (key, value) pair (minimum 3 arguments), got {}",
+                arg_count
+            );
+        }
+        if arg_count.is_multiple_of(2) {
+            return exec_err!(
+                "with_metadata requires an odd number of arguments (expression followed by key/value pairs), got {}",
+                arg_count
+            );
+        }
+        // The arity checks above inspect `arg_fields`, but the literals below are
+        // read from `scalar_arguments`. Both slices describe the same call, so a
+        // length mismatch is an internal bug; report it instead of panicking on
+        // an out-of-bounds index or silently skipping pairs.
+        if args.scalar_arguments.len() != arg_count {
+            return internal_err!(
+                "with_metadata: {} scalar arguments for {} argument fields",
+                args.scalar_arguments.len(),
+                arg_count
+            );
+        }
+
+        let input_field = &args.arg_fields[0];
+        let mut metadata = input_field.metadata().clone();
+
+        // After the input expression the arguments come in (key, value) pairs:
+        // keys are at indices 1, 3, 5, ...; values at 2, 4, 6, ...
+        let pairs = args.scalar_arguments[1..].chunks_exact(2);
+        for (pair_idx, pair) in pairs.enumerate() {
+            let key_idx = 1 + pair_idx * 2;
+            let value_idx = key_idx + 1;
+
+            let key = pair[0]
+                .and_then(|sv| sv.try_as_str().flatten())
+                .filter(|s| !s.is_empty())
+                .ok_or_else(|| {
+                    exec_datafusion_err!(
+                        "with_metadata requires argument {key_idx} (key) to be a non-empty constant string"
+                    )
+                })?;
+
+            let value = pair[1]
+                .and_then(|sv| sv.try_as_str().flatten())
+                .ok_or_else(|| {
+                    exec_datafusion_err!(
+                        "with_metadata requires argument {value_idx} (value) to be a constant string"
+                    )
+                })?;
+
+            metadata.insert(key.to_string(), value.to_string());
+        }
+
+        // Preserve the input field's name, data type, and nullability; only the
+        // metadata changes. This makes `with_metadata(col, ...)` a true
+        // pass-through annotation from a schema perspective.
+        let field = Field::new(
+            input_field.name(),
+            input_field.data_type().clone(),
+            input_field.is_nullable(),
+        )
+        .with_metadata(metadata);
+
+        Ok(field.into())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        // Pure value pass-through. The metadata was attached to the return
+        // field during planning and flows through record batch schemas; the
+        // physical operator does not need to rebuild arrays. Move the first
+        // argument out instead of cloning it, and report an empty argument list
+        // as an internal error rather than panicking on `args[0]`.
+        match args.args.into_iter().next() {
+            Some(value) => Ok(value),
+            None => internal_err!("with_metadata: missing input expression"),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::Field;
+    use datafusion_common::ScalarValue;
+    use std::sync::Arc;
+
+    fn field(name: &str, dt: DataType, nullable: bool) -> FieldRef {
+        Arc::new(Field::new(name, dt, nullable))
+    }
+
+    fn str_lit(s: &str) -> ScalarValue {
+        ScalarValue::Utf8(Some(s.to_string()))
+    }
+
+    #[test]
+    fn attaches_single_key() {
+        let udf = WithMetadataFunc::new();
+        let input = field("my_col", DataType::Int32, true);
+        let k = str_lit("unit");
+        let v = str_lit("ms");
+        let fields = [
+            Arc::clone(&input),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+        ];
+        let scalars = [None, Some(&k), Some(&v)];
+        let ret = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &fields,
+                scalar_arguments: &scalars,
+            })
+            .unwrap();
+        assert_eq!(ret.name(), "my_col");
+        assert_eq!(ret.data_type(), &DataType::Int32);
+        assert!(ret.is_nullable());
+        assert_eq!(ret.metadata().get("unit").map(String::as_str), Some("ms"));
+    }
+
+    #[test]
+    fn merges_existing_metadata_and_overwrites_on_collision() {
+        let udf = WithMetadataFunc::new();
+        let mut existing = Field::new("x", DataType::Float64, false);
+        existing.set_metadata(
+            [
+                ("keep".to_string(), "yes".to_string()),
+                ("unit".to_string(), "old".to_string()),
+            ]
+            .into_iter()
+            .collect(),
+        );
+        let input: FieldRef = Arc::new(existing);
+        let k = str_lit("unit");
+        let v = str_lit("new");
+        let fields = [
+            Arc::clone(&input),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+        ];
+        let scalars = [None, Some(&k), Some(&v)];
+        let ret = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &fields,
+                scalar_arguments: &scalars,
+            })
+            .unwrap();
+        assert_eq!(ret.name(), "x");
+        assert!(!ret.is_nullable());
+        assert_eq!(ret.metadata().get("keep").map(String::as_str), Some("yes"));
+        assert_eq!(ret.metadata().get("unit").map(String::as_str), Some("new"));
+    }
+
+    #[test]
+    fn multiple_pairs() {
+        let udf = WithMetadataFunc::new();
+        let input = field("c", DataType::Utf8, true);
+        let k1 = str_lit("a");
+        let v1 = str_lit("1");
+        let k2 = str_lit("b");
+        let v2 = str_lit("2");
+        let fields = [
+            Arc::clone(&input),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+        ];
+        let scalars = [None, Some(&k1), Some(&v1), Some(&k2), Some(&v2)];
+        let ret = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &fields,
+                scalar_arguments: &scalars,
+            })
+            .unwrap();
+        assert_eq!(ret.metadata().get("a").map(String::as_str), Some("1"));
+        assert_eq!(ret.metadata().get("b").map(String::as_str), Some("2"));
+    }
+
+    #[test]
+    fn rejects_even_arity() {
+        let udf = WithMetadataFunc::new();
+        let input = field("c", DataType::Int32, true);
+        let a = str_lit("a");
+        let b = str_lit("b");
+        let c = str_lit("c");
+        // 4 args total: input + 3 literals, so the last key has no value
+        let fields = [
+            Arc::clone(&input),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+        ];
+        let scalars = [None, Some(&a), Some(&b), Some(&c)];
+        let err = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &fields,
+                scalar_arguments: &scalars,
+            })
+            .unwrap_err();
+        assert!(err.to_string().contains("odd number"));
+    }
+
+    #[test]
+    fn rejects_too_few_args() {
+        let udf = WithMetadataFunc::new();
+        let input = field("c", DataType::Int32, true);
+        let k = str_lit("a");
+        let fields = [Arc::clone(&input), field("", DataType::Utf8, false)];
+        let scalars = [None, Some(&k)];
+        let err = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &fields,
+                scalar_arguments: &scalars,
+            })
+            .unwrap_err();
+        assert!(err.to_string().contains("at least one"));
+    }
+
+    #[test]
+    fn allows_empty_value() {
+        let udf = WithMetadataFunc::new();
+        let input = field("c", DataType::Int32, true);
+        let k = str_lit("unit");
+        let v = str_lit("");
+        let fields = [
+            Arc::clone(&input),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+        ];
+        let scalars = [None, Some(&k), Some(&v)];
+        let ret = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &fields,
+                scalar_arguments: &scalars,
+            })
+            .unwrap();
+        assert_eq!(ret.metadata().get("unit").map(String::as_str), Some(""));
+    }
+
+    #[test]
+    fn rejects_non_literal_key() {
+        let udf = WithMetadataFunc::new();
+        let input = field("c", DataType::Int32, true);
+        let v = str_lit("v");
+        let fields = [
+            Arc::clone(&input),
+            field("", DataType::Utf8, true),
+            field("", DataType::Utf8, false),
+        ];
+        let scalars = [None, None, Some(&v)];
+        let err = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &fields,
+                scalar_arguments: &scalars,
+            })
+            .unwrap_err();
+        assert!(err.to_string().contains("non-empty constant string"));
+    }
+
+    #[test]
+    fn rejects_mismatched_scalar_arguments() {
+        let udf = WithMetadataFunc::new();
+        let input = field("c", DataType::Int32, true);
+        let fields = [
+            Arc::clone(&input),
+            field("", DataType::Utf8, false),
+            field("", DataType::Utf8, false),
+        ];
+        // Three argument fields but only one scalar slot: this must surface as
+        // an error rather than being silently accepted or panicking.
+        let scalars: [Option<&ScalarValue>; 1] = [None];
+        let err = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &fields,
+                scalar_arguments: &scalars,
+            })
+            .unwrap_err();
+        assert!(err.to_string().contains("scalar arguments"));
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt
index 11ed4fc632e2a..42589481f909e 100644
--- a/datafusion/sqllogictest/test_files/metadata.slt
+++ b/datafusion/sqllogictest/test_files/metadata.slt
@@ -414,5 +414,65 @@ select arrow_metadata(TRY_CAST(id AS BIGINT)) from table_with_metadata limit 1;
 ----
 {metadata_key: the id field}
 
+# with_metadata: attach a single key and read it back
+query T
+select arrow_metadata(with_metadata(id, 'unit', 'ms'), 'unit') from table_with_metadata limit 1;
+----
+ms
+
+# with_metadata: attach multiple keys in one call
+query ?
+select arrow_metadata(with_metadata(id, 'unit', 'ms', 'source', 'sensor')) from table_with_metadata limit 1;
+----
+{metadata_key: the id field, source: sensor, unit: ms}
+
+# with_metadata: merge with existing field metadata (preserves upstream keys)
+query T
+select arrow_metadata(with_metadata(id, 'unit', 'ms'), 'metadata_key') from table_with_metadata limit 1;
+----
+the id field
+
+# with_metadata: new keys overwrite existing on collision
+query T
+select arrow_metadata(with_metadata(id, 'metadata_key', 'overridden'), 'metadata_key') from table_with_metadata limit 1;
+----
+overridden
+
+# with_metadata: nesting composes (inner keys + outer keys, outer wins on collision)
+query ?
+select arrow_metadata(with_metadata(with_metadata(id, 'a', '1'), 'b', '2')) from table_with_metadata limit 1;
+----
+{a: 1, b: 2, metadata_key: the id field}
+
+# with_metadata: values pass through unchanged
+query I
+select with_metadata(id, 'unit', 'ms') from table_with_metadata order by id nulls last;
+----
+1
+3
+NULL
+
+# with_metadata: error on even arity (missing value for last key)
+statement error with_metadata requires an odd number of arguments
+select with_metadata(id, 'a', '1', 'b') from table_with_metadata;
+
+# with_metadata: error on too few args
+statement error with_metadata requires the input expression plus at least one
+select with_metadata(id) from table_with_metadata;
+
+# with_metadata: error on non-literal key
+statement error with_metadata requires argument 1 \(key\) to be a non-empty constant string
+select with_metadata(id, name, 'v') from table_with_metadata;
+
+# with_metadata: error on empty key
+statement error with_metadata requires argument 1 \(key\) to be a non-empty constant string
+select with_metadata(id, '', 'v') from table_with_metadata;
+
+# with_metadata: empty values are allowed
+query T
+select arrow_metadata(with_metadata(id, 'unit', ''), 'unit') from table_with_metadata limit 1;
+----
+(empty)
+
 statement ok
 drop table table_with_metadata;
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index c303b43fc8844..1d807438f5f32 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -5287,6 +5287,7 @@ union_tag(union_expression)
 - [arrow_typeof](#arrow_typeof)
 - [get_field](#get_field)
 - [version](#version)
+- [with_metadata](#with_metadata)
 
 ### `arrow_cast`
 
@@ -5475,3 +5476,34 @@ version()
 | Apache DataFusion 42.0.0, aarch64 on macos |
 +--------------------------------------------+
 ```
+
+### `with_metadata`
+
+Attaches Arrow field metadata (key/value pairs) to the input expression. Keys must be non-empty constant strings and values must be constant strings (empty values are allowed). Existing metadata on the input field is preserved; new keys overwrite on collision. This is the inverse of `arrow_metadata`.
+
+```sql
+with_metadata(expression, key1, value1[, key2, value2, ...])
+```
+
+#### Arguments
+
+- **expression**: The expression whose output Arrow field should be annotated. Values flow through unchanged.
+- **key**: Metadata key. Must be a non-empty constant string literal.
+- **value**: Metadata value. Must be a constant string literal (may be empty).
+
+#### Example
+
+```sql
+> select arrow_metadata(with_metadata(column1, 'unit', 'ms'), 'unit') from (values (1));
++-----------------------------------------------------------------------------+
+| arrow_metadata(with_metadata(column1,Utf8("unit"),Utf8("ms")),Utf8("unit")) |
++-----------------------------------------------------------------------------+
+| ms                                                                          |
++-----------------------------------------------------------------------------+
+> select arrow_metadata(with_metadata(column1, 'unit', 'ms', 'source', 'sensor')) from (values (1));
++----------------------------------------------------------------------------------------------+
+| arrow_metadata(with_metadata(column1,Utf8("unit"),Utf8("ms"),Utf8("source"),Utf8("sensor"))) |
++----------------------------------------------------------------------------------------------+
+| {source: sensor, unit: ms}                                                                   |
++----------------------------------------------------------------------------------------------+
+```