diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index bf84fcc53e957..c27864c5dd156 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -777,14 +777,24 @@ impl DefaultPhysicalPlanner { input, .. }) => { + let analysis = analyze_dml_input(input, table_name)?; + // TODO: remove this guard once UPDATE ... FROM routing via + // TableProvider::update_from(...) and joined assignment handling land. + // See https://github.com/apache/datafusion/issues/19950. + if analysis.has_joined_input { + return not_impl_err!("UPDATE ... FROM is not supported"); + } + if let Some(provider) = target.as_any().downcast_ref::() { // For UPDATE, the assignments are encoded in the projection of input // We pass the filters and let the provider handle the projection - let filters = extract_dml_filters(input, table_name)?; + let filters = + extract_dml_filters_with_analysis(input, table_name, &analysis)?; // Extract assignments from the projection in input plan - let assignments = extract_update_assignments(input)?; + let assignments = + extract_update_assignments_with_analysis(input, &analysis)?; provider .table_provider .update(session_state, assignments, filters) @@ -2098,27 +2108,24 @@ fn extract_dml_filters( input: &Arc, target: &TableReference, ) -> Result> { - let mut filters = Vec::new(); - let mut allowed_refs = vec![target.clone()]; + let analysis = analyze_dml_input(input, target)?; + extract_dml_filters_with_analysis(input, target, &analysis) +} - // First pass: collect any alias references to the target table - input.apply(|node| { - if let LogicalPlan::SubqueryAlias(alias) = node - // Check if this alias points to the target table - && let LogicalPlan::TableScan(scan) = alias.input.as_ref() - && scan.table_name.resolved_eq(target) - { - allowed_refs.push(TableReference::bare(alias.alias.to_string())); - } - Ok(TreeNodeRecursion::Continue) - })?; +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key +fn extract_dml_filters_with_analysis( + input: &Arc, + target: &TableReference, + analysis: &DmlInputAnalysis, +) -> Result> { + let mut filters = Vec::new(); input.apply(|node| { match node { LogicalPlan::Filter(filter) => { // Split AND predicates into individual expressions for predicate in split_conjunction(&filter.predicate) { - if predicate_is_on_target_multi(predicate, &allowed_refs)? { + if predicate_is_on_target_multi(predicate, &analysis.target_refs)? { filters.push(predicate.clone()); } } @@ -2194,6 +2201,45 @@ fn extract_dml_filters( }) } +#[derive(Debug)] +struct DmlInputAnalysis { + has_joined_input: bool, + target_refs: Vec, +} + +fn analyze_dml_input( + input: &Arc, + target: &TableReference, +) -> Result { + let mut analysis = DmlInputAnalysis { + has_joined_input: false, + target_refs: vec![target.clone()], + }; + analyze_target_branch(input, &mut analysis); + Ok(analysis) +} + +fn analyze_target_branch(input: &Arc, analysis: &mut DmlInputAnalysis) { + let mut current = input; + loop { + match current.as_ref() { + LogicalPlan::Projection(projection) => current = &projection.input, + LogicalPlan::Filter(filter) => current = &filter.input, + LogicalPlan::SubqueryAlias(alias) => { + analysis + .target_refs + .push(TableReference::bare(alias.alias.to_string())); + current = &alias.input; + } + LogicalPlan::Join(join) => { + analysis.has_joined_input = true; + current = &join.left; + } + _ => return, + } + } +} + /// Determine whether a predicate references only columns from the target table /// or its aliases. /// @@ -2240,61 +2286,98 @@ fn strip_column_qualifiers(expr: Expr) -> Result { /// Extract column assignments from an UPDATE input plan. /// For UPDATE statements, the SQL planner encodes assignments as a projection /// over the source table. This function extracts column name and expression pairs -/// from the projection. Column qualifiers are stripped from the expressions. +/// from the projection. Column qualifiers are stripped only for single-table +/// updates so provider-facing expressions remain resolvable. /// -fn extract_update_assignments(input: &Arc) -> Result> { +#[cfg(test)] +fn extract_update_assignments( + input: &Arc, + target: &TableReference, +) -> Result> { // The UPDATE input plan structure is: // Projection(updated columns as expressions with aliases) // Filter(optional WHERE clause) // TableScan // // Each projected expression has an alias matching the column name - let mut assignments = Vec::new(); + let analysis = analyze_dml_input(input, target)?; + extract_update_assignments_with_analysis(input, &analysis) +} - // Find the top-level projection +fn extract_update_assignments_with_analysis( + input: &Arc, + analysis: &DmlInputAnalysis, +) -> Result> { + find_update_projection(input)? + .map(|projection| { + append_update_assignments( + projection, + &analysis.target_refs, + !analysis.has_joined_input, + ) + }) + .transpose() + .map(|assignments| assignments.unwrap_or_default()) +} + +fn find_update_projection(input: &Arc) -> Result> { if let LogicalPlan::Projection(projection) = input.as_ref() { - for expr in &projection.expr { - if let Expr::Alias(alias) = expr { - // The alias name is the column name being updated - // The inner expression is the new value - let column_name = alias.name.clone(); - // Only include if it's not just a column reference to itself - // (those are columns that aren't being updated) - if !is_identity_assignment(&alias.expr, &column_name) { - // Strip qualifiers from the assignment expression - let stripped_expr = strip_column_qualifiers((*alias.expr).clone())?; - assignments.push((column_name, stripped_expr)); - } - } + return Ok(Some(projection)); + } + + let mut found_projection = None; + input.apply(|node| { + if let LogicalPlan::Projection(projection) = node { + found_projection = Some(projection); + return Ok(TreeNodeRecursion::Stop); } - } else { - // Try to find projection deeper in the plan - input.apply(|node| { - if let LogicalPlan::Projection(projection) = node { - for expr in &projection.expr { - if let Expr::Alias(alias) = expr { - let column_name = alias.name.clone(); - if !is_identity_assignment(&alias.expr, &column_name) { - let stripped_expr = - strip_column_qualifiers((*alias.expr).clone())?; - assignments.push((column_name, stripped_expr)); - } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(found_projection) +} + +fn append_update_assignments( + projection: &Projection, + target_refs: &[TableReference], + strip_qualifiers: bool, +) -> Result> { + projection + .expr + .iter() + .filter_map(|expr| match expr { + Expr::Alias(alias) + if !is_identity_assignment(&alias.expr, &alias.name, target_refs) => + { + Some( + if strip_qualifiers { + strip_column_qualifiers((*alias.expr).clone()) + } else { + Ok((*alias.expr).clone()) } - } - return Ok(TreeNodeRecursion::Stop); + .map(|assignment_expr| (alias.name.clone(), assignment_expr)), + ) } - Ok(TreeNodeRecursion::Continue) - })?; - } - - Ok(assignments) + _ => None, + }) + .collect() } /// Check if an assignment is an identity assignment (column = column) /// These are columns that are not being modified in the UPDATE -fn is_identity_assignment(expr: &Expr, column_name: &str) -> bool { +fn is_identity_assignment( + expr: &Expr, + column_name: &str, + target_refs: &[TableReference], +) -> bool { match expr { - Expr::Column(col) => col.name == column_name, + Expr::Column(col) => { + col.name == column_name + && col.relation.as_ref().is_none_or(|relation| { + target_refs + .iter() + .any(|target_ref| relation.resolved_eq(target_ref)) + }) + } _ => false, } } @@ -3114,6 +3197,7 @@ impl<'n> TreeNodeVisitor<'n> for InvariantChecker { mod tests { use std::any::Any; use std::cmp::Ordering; + use std::collections::HashMap; use std::fmt::{self, Debug}; use std::ops::{BitAnd, Not}; @@ -3167,6 +3251,48 @@ mod tests { .await } + fn make_test_mem_table(schema: SchemaRef) -> Arc { + Arc::new(MemTable::try_new(schema, vec![vec![]]).unwrap()) + } + + fn test_update_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])) + } + + async fn update_assignments_for_sql(sql: &str) -> Result> { + let ctx = SessionContext::new(); + let schema = test_update_schema(); + ctx.register_table("t1", make_test_mem_table(Arc::clone(&schema)))?; + ctx.register_table("t2", make_test_mem_table(schema))?; + + let df = ctx.sql(sql).await?; + let (table_name, input) = match df.logical_plan() { + LogicalPlan::Dml(DmlStatement { + op: WriteOp::Update, + table_name, + input, + .. + }) => (table_name.clone(), Arc::clone(input)), + plan => panic!("expected update plan, got {plan:?}"), + }; + + extract_update_assignments(&input, &table_name) + } + + async fn assert_update_assignment(sql: &str, column: &str, expected: &str) { + let assignments: HashMap<_, _> = update_assignments_for_sql(sql) + .await + .unwrap() + .into_iter() + .map(|(name, expr)| (name, expr.to_string())) + .collect(); + assert_eq!(assignments.get(column).map(String::as_str), Some(expected)); + } + #[tokio::test] async fn test_all_operators() -> Result<()> { let logical_plan = test_csv_scan() @@ -4749,4 +4875,56 @@ digraph { assert_eq!(plan.schema(), schema); assert!(plan.is::()); } + + #[tokio::test] + async fn test_extract_update_assignments_preserves_joined_source_qualifiers() { + assert_update_assignment( + "UPDATE t1 SET b = t2.b FROM t2 WHERE t1.id = t2.id", + "b", + "t2.b", + ) + .await; + } + + #[tokio::test] + async fn test_extract_update_assignments_preserves_alias_qualified_sources() { + assert_update_assignment( + "UPDATE t1 AS target SET b = source.b FROM t2 AS source \ + WHERE target.id = source.id", + "b", + "source.b", + ) + .await; + } + + #[tokio::test] + async fn test_extract_update_assignments_distinguishes_same_name_join_columns() { + let assignments: HashMap<_, _> = update_assignments_for_sql( + "UPDATE t1 SET a = t2.a, b = t1.a FROM t2 WHERE t1.id = t2.id", + ) + .await + .unwrap() + .into_iter() + .map(|(name, expr)| (name, expr.to_string())) + .collect(); + assert_eq!(assignments.get("a").map(String::as_str), Some("t2.a")); + assert_eq!(assignments.get("b").map(String::as_str), Some("t1.a")); + } + + #[tokio::test] + async fn test_extract_update_assignments_preserves_self_join_source_alias() { + assert_update_assignment( + "UPDATE t1 AS target SET a = src.a FROM t1 AS src \ + WHERE target.id = src.id", + "a", + "src.a", + ) + .await; + } + + #[tokio::test] + async fn test_extract_update_assignments_strips_single_table_target_qualifiers() { + assert_update_assignment("UPDATE t1 SET b = t1.a WHERE t1.id = 1", "b", "a") + .await; + } } diff --git a/datafusion/core/tests/custom_sources_cases/dml_planning.rs b/datafusion/core/tests/custom_sources_cases/dml_planning.rs index 8c4bae5e98b36..2dfb1051a7450 100644 --- a/datafusion/core/tests/custom_sources_cases/dml_planning.rs +++ b/datafusion/core/tests/custom_sources_cases/dml_planning.rs @@ -725,8 +725,6 @@ async fn test_delete_target_table_scoping() -> Result<()> { #[tokio::test] async fn test_update_from_drops_non_target_predicates() -> Result<()> { - // UPDATE ... FROM is currently not working - // TODO fix https://github.com/apache/datafusion/issues/19950 let target_provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( test_schema(), TableProviderFilterPushDown::Exact, @@ -743,17 +741,19 @@ async fn test_update_from_drops_non_target_predicates() -> Result<()> { let source_table = datafusion::datasource::empty::EmptyTable::new(source_schema); ctx.register_table("t2", Arc::new(source_table))?; - let result = ctx + let df = ctx .sql( "UPDATE t1 SET value = 1 FROM t2 \ WHERE t1.id = t2.id AND t2.src_only = 'active' AND t1.value > 10", ) - .await; + .await?; - // Verify UPDATE ... FROM is rejected with appropriate error - // TODO fix https://github.com/apache/datafusion/issues/19950 - assert!(result.is_err()); - let err = result.unwrap_err(); + // Verify UPDATE ... FROM reaches logical planning but still fails closed + // before the unsafe single-table update runtime path can execute. + let err = df + .collect() + .await + .expect_err("UPDATE ... FROM should fail closed"); assert!( err.to_string().contains("UPDATE ... FROM is not supported"), "Expected 'UPDATE ... FROM is not supported' error, got: {err}" diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index b91e38e53776a..d76e5b44c36ae 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1084,12 +1084,6 @@ impl SqlToRel<'_, S> { } let update_from = from_clauses.and_then(|mut f| f.pop()); - // UPDATE ... FROM is currently not working - // TODO fix https://github.com/apache/datafusion/issues/19950 - if update_from.is_some() { - return not_impl_err!("UPDATE ... FROM is not supported"); - } - if returning.is_some() { plan_err!("Update-returning clause not yet supported")?; } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 2c2c7eac8bfc4..51820f5c81923 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -26,16 +26,18 @@ use std::vec; use arrow::datatypes::{TimeUnit::Nanosecond, *}; use common::MockContextProvider; -use datafusion_common::{DataFusionError, Result, assert_contains}; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::{DataFusionError, Result, TableReference, assert_contains}; use datafusion_expr::{ - ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, col, logical_plan::LogicalPlan, + AggregateUDF, ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, TableSource, Volatility, WindowUDF, col, + logical_plan::LogicalPlan, planner::ExprPlanner, planner::TypePlanner, test::function_stub::sum_udaf, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ parser::DFParser, - planner::{NullOrdering, ParserOptions, SqlToRel}, + planner::{ContextProvider, NullOrdering, ParserOptions, SqlToRel}, }; use crate::common::{CustomExprPlanner, CustomTypePlanner, MockSessionState}; @@ -56,6 +58,115 @@ use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; mod cases; mod common; +#[derive(Debug)] +struct SqlTestTableSource { + schema: SchemaRef, +} + +impl SqlTestTableSource { + fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +impl TableSource for SqlTestTableSource { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +struct UpdatePlanningContextProvider { + inner: MockContextProvider, + update_schema: SchemaRef, +} + +impl UpdatePlanningContextProvider { + fn new(state: MockSessionState) -> Self { + Self { + inner: MockContextProvider { state }, + update_schema: update_test_schema(), + } + } + + fn get_update_table_source(&self) -> Arc { + Arc::new(SqlTestTableSource::new(Arc::clone(&self.update_schema))) + } +} + +impl ContextProvider for UpdatePlanningContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + match name.table() { + "t1" | "t2" => Ok(self.get_update_table_source()), + _ => self.inner.get_table_source(name), + } + } + + fn get_file_type(&self, ext: &str) -> Result> { + self.inner.get_file_type(ext) + } + + fn create_cte_work_table( + &self, + name: &str, + schema: SchemaRef, + ) -> Result> { + self.inner.create_cte_work_table(name, schema) + } + + fn get_expr_planners(&self) -> &[Arc] { + self.inner.get_expr_planners() + } + + fn get_type_planner(&self) -> Option> { + self.inner.get_type_planner() + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.inner.get_function_meta(name) + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.inner.get_aggregate_meta(name) + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.inner.get_window_meta(name) + } + + fn get_variable_type(&self, variable_names: &[String]) -> Option { + self.inner.get_variable_type(variable_names) + } + + fn options(&self) -> &datafusion_common::config::ConfigOptions { + self.inner.options() + } + + fn udf_names(&self) -> Vec { + self.inner.udf_names() + } + + fn udaf_names(&self) -> Vec { + self.inner.udaf_names() + } + + fn udwf_names(&self) -> Vec { + self.inner.udwf_names() + } +} + +fn update_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + Field::new("d", DataType::Int32, false), + ])) +} + #[test] fn parse_decimals_1() { let sql = "SELECT 1"; @@ -738,6 +849,47 @@ fn plan_update() { ); } +#[test] +fn plan_update_from_with_aliases() { + let sql = "UPDATE t1 AS target \ + SET b = source.b, c = source.a, d = 1 \ + FROM t2 AS source \ + WHERE target.a = source.a AND target.b > 'foo' AND source.c > 1.0"; + let plan = update_logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Update] table=[t1] + Projection: target.a AS a, source.b AS b, CAST(source.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d + Filter: target.a = source.a AND target.b > Utf8("foo") AND source.c > Float64(1) + Cross Join: + SubqueryAlias: target + TableScan: t1 + SubqueryAlias: source + TableScan: t2 + "# + ); +} + +#[test] +fn plan_explain_update_from() { + let sql = "EXPLAIN UPDATE t1 SET b = t2.b, c = t2.a, d = 1 \ + FROM t2 WHERE t1.a = t2.a AND t1.b > 'foo' AND t2.c > 1.0"; + let plan = update_logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Explain + Dml: op=[Update] table=[t1] + Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d + Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) + Cross Join: + TableScan: t1 + TableScan: t2 + "# + ); +} + #[rstest] #[case::missing_assignment_target("UPDATE person SET doesnotexist = true")] #[case::missing_assignment_expression("UPDATE person SET age = doesnotexist + 42")] @@ -3458,6 +3610,10 @@ fn logical_plan(sql: &str) -> Result { logical_plan_with_options(sql, ParserOptions::default()) } +fn update_logical_plan(sql: &str) -> Result { + update_logical_plan_with_options(sql, ParserOptions::default()) +} + fn logical_plan_with_options(sql: &str, options: ParserOptions) -> Result { let dialect = &GenericDialect {}; logical_plan_with_dialect_and_options(sql, dialect, options) @@ -3477,7 +3633,34 @@ fn logical_plan_with_dialect_and_options( dialect: &dyn Dialect, options: ParserOptions, ) -> Result { - let state = MockSessionState::default() + let state = sql_test_state(); + let context = MockContextProvider { state }; + plan_sql_with_options(sql, dialect, options, &context) +} + +fn update_logical_plan_with_options( + sql: &str, + options: ParserOptions, +) -> Result { + let dialect = &GenericDialect {}; + let state = sql_test_state(); + let context = UpdatePlanningContextProvider::new(state); + plan_sql_with_options(sql, dialect, options, &context) +} + +fn plan_sql_with_options( + sql: &str, + dialect: &dyn Dialect, + options: ParserOptions, + context: &S, +) -> Result { + let planner = SqlToRel::new_with_options(context, options); + let mut ast = DFParser::parse_sql_with_dialect(sql, dialect)?; + planner.statement_to_plan(ast.pop_front().unwrap()) +} + +fn sql_test_state() -> MockSessionState { + MockSessionState::default() .with_scalar_function(Arc::new(unicode::character_length().as_ref().clone())) .with_scalar_function(Arc::new(string::concat().as_ref().clone())) .with_scalar_function(Arc::new(make_udf( @@ -3514,13 +3697,7 @@ fn logical_plan_with_dialect_and_options( .with_aggregate_function(grouping_udaf()) .with_window_function(rank_udwf()) .with_window_function(row_number_udwf()) - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); - - let context = MockContextProvider { state }; - let planner = SqlToRel::new_with_options(&context, options); - let result = DFParser::parse_sql_with_dialect(sql, dialect); - let mut ast = result?; - planner.statement_to_plan(ast.pop_front().unwrap()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) } fn make_udf(name: &'static str, args: Vec, return_type: DataType) -> ScalarUDF { diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index e8fdab6ab18bb..23a3ffb5e8492 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -67,10 +67,19 @@ logical_plan physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() # set from other table -# UPDATE ... FROM is currently unsupported -# TODO fix https://github.com/apache/datafusion/issues/19950 -query error DataFusion error: This feature is not implemented: UPDATE ... FROM is not supported +# UPDATE ... FROM now reaches logical planning, but the physical planner +# still fails closed until joined update provider support lands. +query TT explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0; +---- +logical_plan +01)Dml: op=[Update] table=[t1] +02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +03)----Filter: t1.a = t2.a AND t1.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) +04)------Cross Join: +05)--------TableScan: t1 +06)--------TableScan: t2 +physical_plan_error This feature is not implemented: UPDATE ... FROM is not supported # test update from other table with actual data statement ok @@ -90,10 +99,20 @@ statement error DataFusion error: This feature is not implemented: Multiple tabl explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; # test table alias -# UPDATE ... FROM is currently unsupported -# TODO fix https://github.com/apache/datafusion/issues/19950 -statement error DataFusion error: This feature is not implemented: UPDATE ... FROM is not supported +# UPDATE ... FROM with aliases also reaches logical planning, but execution +# remains blocked in the physical planner for now. +query TT explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and t.b > 'foo' and t2.c > 1.0; +---- +logical_plan +01)Dml: op=[Update] table=[t1] +02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +03)----Filter: t.a = t2.a AND t.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) +04)------Cross Join: +05)--------SubqueryAlias: t +06)----------TableScan: t1 +07)--------TableScan: t2 +physical_plan_error This feature is not implemented: UPDATE ... FROM is not supported # test update with table alias with actual data statement ok