diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index f31d4d52ce88b..24d41ead8af0f 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -489,6 +489,12 @@ pub trait TableProviderFactory: Debug + Sync + Send { pub trait TableFunctionImpl: Debug + Sync + Send { /// Create a table provider fn call(&self, args: &[Expr]) -> Result>; + + /// Returns true if the arguments should be coerced and simplified. + /// Defaults to true for backward compatibility. + fn coerce_arguments(&self) -> bool { + true + } } /// A table that uses a function to generate data @@ -520,4 +526,9 @@ impl TableFunction { pub fn create_table_provider(&self, args: &[Expr]) -> Result> { self.fun.call(args) } + + /// Returns true if the arguments should be coerced and simplified + pub fn coerce_arguments(&self) -> bool { + self.fun.coerce_arguments() + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 9560616c1b6da..4b27514022967 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -1842,14 +1842,17 @@ impl ContextProvider for SessionContextProvider<'_> { ); let simplifier = ExprSimplifier::new(simplify_context); let schema = DFSchema::empty(); - let args = args - .into_iter() - .map(|arg| { - simplifier - .coerce(arg, &schema) - .and_then(|e| simplifier.simplify(e)) - }) - .collect::>>()?; + let args = if tbl_func.coerce_arguments() { + args.into_iter() + .map(|arg| { + simplifier + .coerce(arg, &schema) + .and_then(|e| simplifier.simplify(e)) + }) + .collect::>>()? + } else { + args + }; let provider = tbl_func.create_table_provider(&args)?; Ok(provider_as_source(provider)) @@ -2509,3 +2512,167 @@ mod tests { } } } + +#[cfg(test)] +mod udtf_tests { + use super::*; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use async_trait::async_trait; + use datafusion_catalog::Session; + use datafusion_catalog::{TableFunction, TableFunctionImpl, TableProvider}; + use datafusion_common::{Result, plan_err}; + use datafusion_expr::{Expr, TableType}; + use datafusion_physical_plan::ExecutionPlan; + use std::any::Any; + use std::sync::Arc; + + #[derive(Debug)] + struct MockTableProvider { + schema: SchemaRef, + } + + #[async_trait] + impl TableProvider for MockTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + fn table_type(&self) -> TableType { + TableType::Base + } + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + let schema = self.schema.clone(); + Ok(Arc::new(datafusion_physical_plan::empty::EmptyExec::new( + schema, + ))) + } + } + + #[derive(Debug)] + struct NoCoerceUDTF {} + + impl TableFunctionImpl for NoCoerceUDTF { + fn call(&self, args: &[Expr]) -> Result> { + // Verify that the argument 'index' (which is technically a column reference in SQL) + // survives as an identifier instead of failing coercion because it's missing from the empty schema. + match &args[0] { + Expr::BinaryExpr(be) => { + match be.left.as_ref() { + Expr::Column(c) if c.name == "index" => { + // Success! + } + _ => { + return plan_err!( + "Expected Column('index') on left side, got {:?}", + be.left + ); + } + } + } + _ => return plan_err!("Expected BinaryExpr, got {:?}", args[0]), + } + + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + Ok(Arc::new(MockTableProvider { schema })) + } + + fn coerce_arguments(&self) -> bool { + false + } + } + + #[test] + fn test_udtf_no_coercion() -> Result<()> { + let udtf = Arc::new(TableFunction::new( + "scan_with".to_string(), + Arc::new(NoCoerceUDTF {}), + )); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_table_function_list(vec![udtf]) + .build(); + + let provider = SessionContextProvider { + state: &state, + tables: HashMap::new(), + }; + + // SQL: SELECT * FROM scan_with(index=1) + let args = vec![Expr::BinaryExpr(datafusion_expr::BinaryExpr { + left: Box::new(Expr::Column(datafusion_common::Column::from_name("index"))), + op: datafusion_expr::Operator::Eq, + right: Box::new(Expr::Literal( + datafusion_common::ScalarValue::Int32(Some(1)), + None, + )), + })]; + + let source = provider.get_table_function_source("scan_with", args)?; + assert_eq!(source.schema().fields().len(), 1); + assert_eq!(source.schema().field(0).name(), "a"); + + Ok(()) + } + + #[test] + fn test_udtf_default_coercion() -> Result<()> { + #[derive(Debug)] + struct CoerceUDTF {} + impl TableFunctionImpl for CoerceUDTF { + fn call(&self, _args: &[Expr]) -> Result> { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + Ok(Arc::new(MockTableProvider { schema })) + } + } + + let udtf = Arc::new(TableFunction::new( + "scan_with".to_string(), + Arc::new(CoerceUDTF {}), + )); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_table_function_list(vec![udtf]) + .build(); + + let provider = SessionContextProvider { + state: &state, + tables: HashMap::new(), + }; + + // In SQL: SELECT * FROM scan_with(unknown_col=1) + let args = vec![Expr::BinaryExpr(datafusion_expr::BinaryExpr { + left: Box::new(Expr::Column(datafusion_common::Column::from_name( + "unknown_col", + ))), + op: datafusion_expr::Operator::Eq, + right: Box::new(Expr::Literal( + datafusion_common::ScalarValue::Int32(Some(1)), + None, + )), + })]; + + // Should fail because coercion is ON and "unknown_col" is not in the empty schema. + let res = provider.get_table_function_source("scan_with", args); + match res { + Ok(_) => panic!("Expected error, but got success"), + Err(e) => assert!( + e.to_string() + .contains("Schema error: No field named unknown_col") + ), + } + + Ok(()) + } +}