Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub mod r#struct;
pub mod union_extract;
pub mod union_tag;
pub mod version;
pub mod with_metadata;

// create UDFs
make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast);
Expand All @@ -59,6 +60,7 @@ make_udf_function!(union_extract::UnionExtractFun, union_extract);
make_udf_function!(union_tag::UnionTagFunc, union_tag);
make_udf_function!(version::VersionFunc, version);
make_udf_function!(arrow_metadata::ArrowMetadataFunc, arrow_metadata);
make_udf_function!(with_metadata::WithMetadataFunc, with_metadata);

pub mod expr_fn {
use datafusion_expr::{Expr, Literal};
Expand Down Expand Up @@ -95,6 +97,10 @@ pub mod expr_fn {
arrow_metadata,
"Returns the metadata of the input expression",
args,
),(
with_metadata,
"Attaches Arrow field metadata (key/value pairs) to the input expression",
args,
),(
r#struct,
"Returns a struct with the given arguments",
Expand Down Expand Up @@ -148,6 +154,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
arrow_cast(),
arrow_try_cast(),
arrow_metadata(),
with_metadata(),
nvl(),
nvl2(),
overlay(),
Expand Down
314 changes: 314 additions & 0 deletions datafusion/functions/src/core/with_metadata.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{Result, exec_err, internal_err};
use datafusion_expr::{
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
Signature, Volatility,
};
use datafusion_macros::user_doc;

#[user_doc(
doc_section(label = "Other Functions"),
description = "Attaches Arrow field metadata (key/value pairs) to the input expression. Keys and values must be non-empty constant strings. Existing metadata on the input field is preserved; new keys overwrite on collision. This is the inverse of `arrow_metadata`.",
syntax_example = "with_metadata(expression, key1, value1[, key2, value2, ...])",
sql_example = r#"```sql
> select arrow_metadata(with_metadata(column1, 'unit', 'ms'), 'unit') from (values (1));
+---------------------------------------------------------------+
| arrow_metadata(with_metadata(column1,Utf8("unit"),Utf8("ms")),Utf8("unit")) |
+---------------------------------------------------------------+
| ms |
+---------------------------------------------------------------+
> select arrow_metadata(with_metadata(column1, 'unit', 'ms', 'source', 'sensor')) from (values (1));
+--------------------------+
| {source: sensor, unit: ms} |
+--------------------------+
```"#,
argument(
name = "expression",
description = "The expression whose output Arrow field should be annotated. Values flow through unchanged."
),
argument(
name = "key",
description = "Metadata key. Must be a non-empty constant string literal."
),
argument(
name = "value",
description = "Metadata value. Must be a non-empty constant string literal."
)
)]
/// Scalar UDF registered under the name `with_metadata`.
///
/// Planning (`return_field_from_args`) merges literal key/value pairs into the
/// Arrow field metadata of the first argument; execution returns that first
/// argument's values unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WithMetadataFunc {
    // Built once in `new()` as `Signature::variadic_any(Volatility::Immutable)`.
    signature: Signature,
}

impl Default for WithMetadataFunc {
fn default() -> Self {
Self::new()
}
}

impl WithMetadataFunc {
    /// Creates the function with a variadic, any-typed, immutable signature.
    ///
    /// The signature itself places no constraint on argument count or types;
    /// those are validated when the return field is computed.
    pub fn new() -> Self {
        let signature = Signature::variadic_any(Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for WithMetadataFunc {
fn name(&self) -> &str {
"with_metadata"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!(
"with_metadata: return_type called instead of return_field_from_args"
)
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
// Require at least the value expression plus one (key, value) pair,
// and an odd total (1 + 2*N).
if args.arg_fields.len() < 3 {
return exec_err!(
"with_metadata requires the input expression plus at least one (key, value) pair (minimum 3 arguments), got {}",
args.arg_fields.len()
);
}
if args.arg_fields.len().is_multiple_of(2) {
return exec_err!(
"with_metadata requires an odd number of arguments (expression followed by key/value pairs), got {}",
args.arg_fields.len()
);
}

let input_field = &args.arg_fields[0];
let mut metadata = input_field.metadata().clone();

// Keys are at indices 1, 3, 5, ...; values at 2, 4, 6, ...
for pair_idx in 0..((args.scalar_arguments.len() - 1) / 2) {
let key_idx = 1 + pair_idx * 2;
let value_idx = key_idx + 1;

let key = args.scalar_arguments[key_idx]
.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
.ok_or_else(|| {
datafusion_common::DataFusionError::Execution(format!(
"with_metadata requires argument {key_idx} (key) to be a non-empty constant string"
))
})?;

let value = args.scalar_arguments[value_idx]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed the public contract says both keys and values must be non empty constant strings, but right now only the keys are enforced.

Values can still be empty strings and get stored without any issue. Would you prefer to add the same non empty check for values here, or relax the docs and tests so the behavior is consistent?

Either way works, it would just be nice for callers to have one clear rule.

.and_then(|sv| sv.try_as_str().flatten())
.ok_or_else(|| {
datafusion_common::DataFusionError::Execution(format!(
"with_metadata requires argument {value_idx} (value) to be a constant string"
))
})?;

metadata.insert(key.to_string(), value.to_string());
}

// Preserve the input field's name, data type, and nullability; only the
// metadata changes. This makes `with_metadata(col, ...)` a true
// pass-through annotation from a schema perspective.
let field = Field::new(
input_field.name(),
input_field.data_type().clone(),
input_field.is_nullable(),
)
.with_metadata(metadata);

Ok(field.into())
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// Pure value pass-through. The metadata was attached to the return
// field during planning and flows through record batch schemas; the
// physical operator does not need to rebuild arrays.
Ok(args.args[0].clone())
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::ScalarValue;
    use std::sync::Arc;

    /// Builds a metadata-free field wrapped in an `Arc`.
    fn field(name: &str, dt: DataType, nullable: bool) -> FieldRef {
        Arc::new(Field::new(name, dt, nullable))
    }

    /// Builds a non-null `Utf8` scalar literal.
    fn str_lit(s: &str) -> ScalarValue {
        ScalarValue::Utf8(Some(s.to_string()))
    }

    /// Placeholder field used for every key/value literal argument.
    fn literal_field() -> FieldRef {
        field("", DataType::Utf8, false)
    }

    /// Runs planning-time field resolution on a fresh `WithMetadataFunc`.
    fn plan(
        fields: &[FieldRef],
        scalars: &[Option<&ScalarValue>],
    ) -> Result<FieldRef> {
        WithMetadataFunc::new().return_field_from_args(ReturnFieldArgs {
            arg_fields: fields,
            scalar_arguments: scalars,
        })
    }

    /// Looks up one metadata entry on a resolved field.
    fn meta<'a>(resolved: &'a FieldRef, key: &str) -> Option<&'a str> {
        resolved.metadata().get(key).map(String::as_str)
    }

    #[test]
    fn attaches_single_key() {
        let key = str_lit("unit");
        let value = str_lit("ms");
        let fields = [
            field("my_col", DataType::Int32, true),
            literal_field(),
            literal_field(),
        ];

        let resolved = plan(&fields, &[None, Some(&key), Some(&value)]).unwrap();

        assert_eq!(resolved.name(), "my_col");
        assert_eq!(resolved.data_type(), &DataType::Int32);
        assert!(resolved.is_nullable());
        assert_eq!(meta(&resolved, "unit"), Some("ms"));
    }

    #[test]
    fn merges_existing_metadata_and_overwrites_on_collision() {
        // Input field already carries two entries; "unit" collides with the
        // pair supplied below, "keep" does not.
        let annotated = Field::new("x", DataType::Float64, false).with_metadata(
            [
                ("keep".to_string(), "yes".to_string()),
                ("unit".to_string(), "old".to_string()),
            ]
            .into_iter()
            .collect(),
        );
        let key = str_lit("unit");
        let value = str_lit("new");
        let fields = [Arc::new(annotated), literal_field(), literal_field()];

        let resolved = plan(&fields, &[None, Some(&key), Some(&value)]).unwrap();

        assert_eq!(resolved.name(), "x");
        assert!(!resolved.is_nullable());
        assert_eq!(meta(&resolved, "keep"), Some("yes"));
        assert_eq!(meta(&resolved, "unit"), Some("new"));
    }

    #[test]
    fn multiple_pairs() {
        let (first_key, first_value) = (str_lit("a"), str_lit("1"));
        let (second_key, second_value) = (str_lit("b"), str_lit("2"));
        let fields = [
            field("c", DataType::Utf8, true),
            literal_field(),
            literal_field(),
            literal_field(),
            literal_field(),
        ];
        let scalars = [
            None,
            Some(&first_key),
            Some(&first_value),
            Some(&second_key),
            Some(&second_value),
        ];

        let resolved = plan(&fields, &scalars).unwrap();

        assert_eq!(meta(&resolved, "a"), Some("1"));
        assert_eq!(meta(&resolved, "b"), Some("2"));
    }

    #[test]
    fn rejects_even_arity() {
        let a = str_lit("a");
        let b = str_lit("b");
        let c = str_lit("c");
        // Input plus three literals is four arguments: the last key has no
        // matching value.
        let fields = [
            field("c", DataType::Int32, true),
            literal_field(),
            literal_field(),
            literal_field(),
        ];

        let err = plan(&fields, &[None, Some(&a), Some(&b), Some(&c)]).unwrap_err();

        assert!(err.to_string().contains("odd number"));
    }

    #[test]
    fn rejects_too_few_args() {
        let key = str_lit("a");
        let fields = [field("c", DataType::Int32, true), literal_field()];

        let err = plan(&fields, &[None, Some(&key)]).unwrap_err();

        assert!(err.to_string().contains("at least one"));
    }

    #[test]
    fn rejects_non_literal_key() {
        let value = str_lit("v");
        // The key slot has no scalar (`None`), i.e. it is not a literal.
        let fields = [
            field("c", DataType::Int32, true),
            field("", DataType::Utf8, true),
            literal_field(),
        ];

        let err = plan(&fields, &[None, None, Some(&value)]).unwrap_err();

        assert!(err.to_string().contains("non-empty constant string"));
    }
}
Loading
Loading