From 67d3617e60e4cef01674bf5daecc7f399c8ee38d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 10 Apr 2026 16:57:25 +0800 Subject: [PATCH 1/7] Implement logical planning for UPDATE ... FROM Restore support for single-source UPDATE ... FROM in the planner by removing the rejection of early joined update plans. Move the safety block to the physical planner to ensure joined updates are safeguarded. Add test coverage for logical shapes and mock schemas, and update execution regression tests to confirm successful planning while maintaining fail-closed behavior. --- datafusion/core/src/physical_planner.rs | 19 ++++++++++++++++ .../custom_sources_cases/dml_planning.rs | 16 +++++++------- datafusion/sql/src/statement.rs | 6 ----- datafusion/sql/tests/common/mod.rs | 12 ++++++++++ datafusion/sql/tests/sql_integration.rs | 22 +++++++++++++++++++ 5 files changed, 61 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index bf84fcc53e957..5599931a93e4a 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -777,6 +777,13 @@ impl DefaultPhysicalPlanner { input, .. }) => { + // TODO: remove this guard once UPDATE ... FROM routing via + // TableProvider::update_from(...) and joined assignment handling land. + // See https://github.com/apache/datafusion/issues/19950. + if update_uses_joined_input(input)? { + return not_impl_err!("UPDATE ... FROM is not supported"); + } + if let Some(provider) = target.as_any().downcast_ref::() { @@ -2194,6 +2201,18 @@ fn extract_dml_filters( }) } +fn update_uses_joined_input(input: &Arc) -> Result { + let mut has_join = false; + input.apply(|node| { + if matches!(node, LogicalPlan::Join(_)) { + has_join = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(has_join) +} + /// Determine whether a predicate references only columns from the target table /// or its aliases. /// diff --git a/datafusion/core/tests/custom_sources_cases/dml_planning.rs b/datafusion/core/tests/custom_sources_cases/dml_planning.rs index 8c4bae5e98b36..2dfb1051a7450 100644 --- a/datafusion/core/tests/custom_sources_cases/dml_planning.rs +++ b/datafusion/core/tests/custom_sources_cases/dml_planning.rs @@ -725,8 +725,6 @@ async fn test_delete_target_table_scoping() -> Result<()> { #[tokio::test] async fn test_update_from_drops_non_target_predicates() -> Result<()> { - // UPDATE ... FROM is currently not working - // TODO fix https://github.com/apache/datafusion/issues/19950 let target_provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown( test_schema(), TableProviderFilterPushDown::Exact, @@ -743,17 +741,19 @@ async fn test_update_from_drops_non_target_predicates() -> Result<()> { let source_table = datafusion::datasource::empty::EmptyTable::new(source_schema); ctx.register_table("t2", Arc::new(source_table))?; - let result = ctx + let df = ctx .sql( "UPDATE t1 SET value = 1 FROM t2 \ WHERE t1.id = t2.id AND t2.src_only = 'active' AND t1.value > 10", ) - .await; + .await?; - // Verify UPDATE ... FROM is rejected with appropriate error - // TODO fix https://github.com/apache/datafusion/issues/19950 - assert!(result.is_err()); - let err = result.unwrap_err(); + // Verify UPDATE ... FROM reaches logical planning but still fails closed + // before the unsafe single-table update runtime path can execute. + let err = df + .collect() + .await + .expect_err("UPDATE ... FROM should fail closed"); assert!( err.to_string().contains("UPDATE ... FROM is not supported"), "Expected 'UPDATE ... FROM is not supported' error, got: {err}" diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index b91e38e53776a..d76e5b44c36ae 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1084,12 +1084,6 @@ impl SqlToRel<'_, S> { } let update_from = from_clauses.and_then(|mut f| f.pop()); - // UPDATE ... FROM is currently not working - // TODO fix https://github.com/apache/datafusion/issues/19950 - if update_from.is_some() { - return not_impl_err!("UPDATE ... FROM is not supported"); - } - if returning.is_some() { plan_err!("Update-returning clause not yet supported")?; } diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 5caade300290f..7999db78dda39 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -124,6 +124,18 @@ impl ContextProvider for MockContextProvider { Field::new("id", DataType::Int32, false), Field::new("price", DataType::Decimal128(10, 2), false), ])), + "t1" => Ok(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + Field::new("d", DataType::Int32, false), + ])), + "t2" => Ok(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + Field::new("d", DataType::Int32, false), + ])), "person" => Ok(Schema::new(vec![ Field::new("id", DataType::UInt32, false), Field::new("first_name", DataType::Utf8, false), diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 2c2c7eac8bfc4..04c77304c2e00 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -738,6 +738,28 @@ fn plan_update() { ); } +#[test] +fn plan_update_from_with_aliases() { + let sql = "UPDATE t1 AS target \ + SET b = source.b, c = source.a, d = 1 \ + FROM t2 AS source \ + WHERE target.a = source.a AND target.b > 'foo' AND source.c > 1.0"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Update] table=[t1] + Projection: target.a AS a, source.b AS b, CAST(source.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d + Filter: target.a = source.a AND target.b > Utf8("foo") AND source.c > Float64(1) + Cross Join: + SubqueryAlias: target + TableScan: t1 + SubqueryAlias: source + TableScan: t2 + "# + ); +} + #[rstest] #[case::missing_assignment_target("UPDATE person SET doesnotexist = true")] #[case::missing_assignment_expression("UPDATE person SET age = doesnotexist + 42")] From 85145d8146bae713befabed8cdbfc06c61c77c95 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 10 Apr 2026 17:14:47 +0800 Subject: [PATCH 2/7] Fix EXPLAIN UPDATE handling and joined assignment logic Ensure EXPLAIN UPDATE ... FROM fails during SQL planning, instead of misleadingly passing to physical_plan_error. Maintain the physical-planner guard for direct execution failures. Update joined assignment extraction to preserve source references and avoid misclassifying columns in single-table updates. Add regression tests in sql_integration.rs and unit tests in physical_planner.rs. --- datafusion/core/src/physical_planner.rs | 171 ++++++++++++++++++++++-- datafusion/sql/src/statement.rs | 28 ++++ datafusion/sql/tests/sql_integration.rs | 11 ++ 3 files changed, 198 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 5599931a93e4a..3aad3bf0dbbfd 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -791,7 +791,7 @@ impl DefaultPhysicalPlanner { // We pass the filters and let the provider handle the projection let filters = extract_dml_filters(input, table_name)?; // Extract assignments from the projection in input plan - let assignments = extract_update_assignments(input)?; + let assignments = extract_update_assignments(input, table_name)?; provider .table_provider .update(session_state, assignments, filters) @@ -2261,7 +2261,10 @@ fn strip_column_qualifiers(expr: Expr) -> Result { /// over the source table. This function extracts column name and expression pairs /// from the projection. Column qualifiers are stripped from the expressions. /// -fn extract_update_assignments(input: &Arc) -> Result> { +fn extract_update_assignments( + input: &Arc, + target: &TableReference, +) -> Result> { // The UPDATE input plan structure is: // Projection(updated columns as expressions with aliases) // Filter(optional WHERE clause) @@ -2269,6 +2272,8 @@ fn extract_update_assignments(input: &Arc) -> Result) -> Result) -> Result) -> Result, + target: &TableReference, +) -> Result> { + let mut target_refs = vec![target.clone()]; + input.apply(|node| { + if let LogicalPlan::SubqueryAlias(alias) = node + && let LogicalPlan::TableScan(scan) = alias.input.as_ref() + && scan.table_name.resolved_eq(target) + { + target_refs.push(TableReference::bare(alias.alias.to_string())); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(target_refs) +} + +fn normalize_update_assignment_expr(expr: Expr, strip_qualifiers: bool) -> Result { + if strip_qualifiers { + strip_column_qualifiers(expr) + } else { + Ok(expr) + } +} + /// Check if an assignment is an identity assignment (column = column) /// These are columns that are not being modified in the UPDATE -fn is_identity_assignment(expr: &Expr, column_name: &str) -> bool { +fn is_identity_assignment( + expr: &Expr, + column_name: &str, + target_refs: &[TableReference], +) -> bool { match expr { - Expr::Column(col) => col.name == column_name, + Expr::Column(col) => { + col.name == column_name + && col.relation.as_ref().is_none_or(|relation| { + target_refs + .iter() + .any(|target_ref| relation.resolved_eq(target_ref)) + }) + } _ => false, } } @@ -3133,6 +3182,7 @@ impl<'n> TreeNodeVisitor<'n> for InvariantChecker { mod tests { use std::any::Any; use std::cmp::Ordering; + use std::collections::HashMap; use std::fmt::{self, Debug}; use std::ops::{BitAnd, Not}; @@ -3186,6 +3236,39 @@ mod tests { .await } + fn make_test_mem_table(schema: SchemaRef) -> Arc { + Arc::new(MemTable::try_new(schema, vec![vec![]]).unwrap()) + } + + async fn update_assignments_for_sql(sql: &str) -> Result> { + let ctx = SessionContext::new(); + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + ctx.register_table("t1", make_test_mem_table(t1_schema))?; + ctx.register_table("t2", make_test_mem_table(t2_schema))?; + + let df = ctx.sql(sql).await?; + let (table_name, input) = match df.logical_plan() { + LogicalPlan::Dml(DmlStatement { + op: WriteOp::Update, + table_name, + input, + .. + }) => (table_name.clone(), Arc::clone(input)), + plan => panic!("expected update plan, got {plan:?}"), + }; + + extract_update_assignments(&input, &table_name) + } + #[tokio::test] async fn test_all_operators() -> Result<()> { let logical_plan = test_csv_scan() @@ -4768,4 +4851,68 @@ digraph { assert_eq!(plan.schema(), schema); assert!(plan.is::()); } + + #[tokio::test] + async fn test_extract_update_assignments_preserves_joined_source_qualifiers() { + let assignments = update_assignments_for_sql( + "UPDATE t1 SET b = t2.b FROM t2 WHERE t1.id = t2.id", + ) + .await + .unwrap(); + + let assignments: HashMap<_, _> = assignments.into_iter().collect(); + assert_eq!( + assignments.get("b").map(ToString::to_string).as_deref(), + Some("t2.b") + ); + } + + #[tokio::test] + async fn test_extract_update_assignments_preserves_alias_qualified_sources() { + let assignments = update_assignments_for_sql( + "UPDATE t1 AS target SET b = source.b FROM t2 AS source \ + WHERE target.id = source.id", + ) + .await + .unwrap(); + + let assignments: HashMap<_, _> = assignments.into_iter().collect(); + assert_eq!( + assignments.get("b").map(ToString::to_string).as_deref(), + Some("source.b") + ); + } + + #[tokio::test] + async fn test_extract_update_assignments_distinguishes_same_name_join_columns() { + let assignments = update_assignments_for_sql( + "UPDATE t1 SET a = t2.a, b = t1.a FROM t2 WHERE t1.id = t2.id", + ) + .await + .unwrap(); + + let assignments: HashMap<_, _> = assignments.into_iter().collect(); + assert_eq!( + assignments.get("a").map(ToString::to_string).as_deref(), + Some("t2.a") + ); + assert_eq!( + assignments.get("b").map(ToString::to_string).as_deref(), + Some("t1.a") + ); + } + + #[tokio::test] + async fn test_extract_update_assignments_strips_single_table_target_qualifiers() { + let assignments = + update_assignments_for_sql("UPDATE t1 SET b = t1.a WHERE t1.id = 1") + .await + .unwrap(); + + let assignments: HashMap<_, _> = assignments.into_iter().collect(); + assert_eq!( + assignments.get("b").map(ToString::to_string).as_deref(), + Some("a") + ); + } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index d76e5b44c36ae..0a43a65eceb05 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -32,6 +32,7 @@ use crate::utils::normalize_ident; use arrow::datatypes::{Field, FieldRef, Fields}; use datafusion_common::error::_plan_err; use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{ Column, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, ToDFSchema, exec_err, @@ -1912,6 +1913,12 @@ impl SqlToRel<'_, S> { statement: DFStatement, ) -> Result { let plan = self.statement_to_plan(statement)?; + // TODO: remove this guard once UPDATE ... FROM routing via + // TableProvider::update_from(...) and joined assignment handling land. + // See https://github.com/apache/datafusion/issues/19950. + if update_uses_joined_input(&plan)? { + return not_impl_err!("UPDATE ... FROM is not supported"); + } if matches!(plan, LogicalPlan::Explain(_)) { return plan_err!("Nested EXPLAINs are not supported"); } @@ -2566,3 +2573,24 @@ ON p.function_name = r.routine_name } } } + +fn update_uses_joined_input(plan: &LogicalPlan) -> Result { + let LogicalPlan::Dml(DmlStatement { + op: WriteOp::Update, + input, + .. + }) = plan + else { + return Ok(false); + }; + + let mut has_join = false; + input.apply(|node| { + if matches!(node, LogicalPlan::Join(_)) { + has_join = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(has_join) +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 04c77304c2e00..669672a99fafc 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -760,6 +760,17 @@ fn plan_update_from_with_aliases() { ); } +#[test] +fn explain_update_from_is_rejected() { + let sql = "EXPLAIN UPDATE t1 SET b = t2.b, c = t2.a, d = 1 \ + FROM t2 WHERE t1.a = t2.a AND t1.b > 'foo' AND t2.c > 1.0"; + let err = logical_plan(sql).expect_err("EXPLAIN UPDATE ... FROM should fail"); + assert_snapshot!( + err.strip_backtrace(), + @r#"This feature is not implemented: UPDATE ... FROM is not supported"# + ); +} + #[rstest] #[case::missing_assignment_target("UPDATE person SET doesnotexist = true")] #[case::missing_assignment_expression("UPDATE person SET age = doesnotexist + 42")] From ac34b8b05ef80b7e134354b4191d0662e996bf88 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 10 Apr 2026 17:28:04 +0800 Subject: [PATCH 3/7] Fix blocking issue and clean up projection logic Remove the EXPLAIN-time UPDATE ... FROM rejection in statement.rs to allow the SQL planner to expose the joined logical plan. Adjust regression test in sql_integration.rs to assert the Explain -> Dml(Update) plan shape. Consolidate duplicated projection-walking logic in physical_planner.rs by using a shared helper function for extract_update_assignments(). This simplifies identity-check and qualifier-normalization rules. --- datafusion/core/src/physical_planner.rs | 69 +++++++++++++------------ datafusion/sql/src/statement.rs | 28 ---------- datafusion/sql/tests/sql_integration.rs | 16 ++++-- 3 files changed, 49 insertions(+), 64 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 3aad3bf0dbbfd..6fe0906e004fb 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2277,42 +2277,22 @@ fn extract_update_assignments( // Find the top-level projection if let LogicalPlan::Projection(projection) = input.as_ref() { - for expr in &projection.expr { - if let Expr::Alias(alias) = expr { - // The alias name is the column name being updated - // The inner expression is the new value - let column_name = alias.name.clone(); - // Only include if it's not just a column reference to itself - // (those are columns that aren't being updated) - if !is_identity_assignment(&alias.expr, &column_name, &target_refs) { - let assignment_expr = normalize_update_assignment_expr( - (*alias.expr).clone(), - strip_qualifiers, - )?; - assignments.push((column_name, assignment_expr)); - } - } - } + append_update_assignments( + &mut assignments, + projection, + &target_refs, + strip_qualifiers, + )?; } else { // Try to find projection deeper in the plan input.apply(|node| { if let LogicalPlan::Projection(projection) = node { - for expr in &projection.expr { - if let Expr::Alias(alias) = expr { - let column_name = alias.name.clone(); - if !is_identity_assignment( - &alias.expr, - &column_name, - &target_refs, - ) { - let assignment_expr = normalize_update_assignment_expr( - (*alias.expr).clone(), - strip_qualifiers, - )?; - assignments.push((column_name, assignment_expr)); - } - } - } + append_update_assignments( + &mut assignments, + projection, + &target_refs, + strip_qualifiers, + )?; return Ok(TreeNodeRecursion::Stop); } Ok(TreeNodeRecursion::Continue) @@ -2322,6 +2302,31 @@ fn extract_update_assignments( Ok(assignments) } +fn append_update_assignments( + assignments: &mut Vec<(String, Expr)>, + projection: &Projection, + target_refs: &[TableReference], + strip_qualifiers: bool, +) -> Result<()> { + for expr in &projection.expr { + if let Expr::Alias(alias) = expr { + // The alias name is the column name being updated + // The inner expression is the new value + let column_name = alias.name.clone(); + // Only include if it's not just a column reference to itself + // (those are columns that aren't being updated) + if !is_identity_assignment(&alias.expr, &column_name, target_refs) { + let assignment_expr = normalize_update_assignment_expr( + (*alias.expr).clone(), + strip_qualifiers, + )?; + assignments.push((column_name, assignment_expr)); + } + } + } + Ok(()) +} + fn collect_target_refs( input: &Arc, target: &TableReference, diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 0a43a65eceb05..d76e5b44c36ae 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -32,7 +32,6 @@ use crate::utils::normalize_ident; use arrow::datatypes::{Field, FieldRef, Fields}; use datafusion_common::error::_plan_err; use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{ Column, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, ToDFSchema, exec_err, @@ -1913,12 +1912,6 @@ impl SqlToRel<'_, S> { statement: DFStatement, ) -> Result { let plan = self.statement_to_plan(statement)?; - // TODO: remove this guard once UPDATE ... FROM routing via - // TableProvider::update_from(...) and joined assignment handling land. - // See https://github.com/apache/datafusion/issues/19950. - if update_uses_joined_input(&plan)? { - return not_impl_err!("UPDATE ... FROM is not supported"); - } if matches!(plan, LogicalPlan::Explain(_)) { return plan_err!("Nested EXPLAINs are not supported"); } @@ -2573,24 +2566,3 @@ ON p.function_name = r.routine_name } } } - -fn update_uses_joined_input(plan: &LogicalPlan) -> Result { - let LogicalPlan::Dml(DmlStatement { - op: WriteOp::Update, - input, - .. - }) = plan - else { - return Ok(false); - }; - - let mut has_join = false; - input.apply(|node| { - if matches!(node, LogicalPlan::Join(_)) { - has_join = true; - return Ok(TreeNodeRecursion::Stop); - } - Ok(TreeNodeRecursion::Continue) - })?; - Ok(has_join) -} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 669672a99fafc..5999d635510bf 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -761,13 +761,21 @@ fn plan_update_from_with_aliases() { } #[test] -fn explain_update_from_is_rejected() { +fn plan_explain_update_from() { let sql = "EXPLAIN UPDATE t1 SET b = t2.b, c = t2.a, d = 1 \ FROM t2 WHERE t1.a = t2.a AND t1.b > 'foo' AND t2.c > 1.0"; - let err = logical_plan(sql).expect_err("EXPLAIN UPDATE ... FROM should fail"); + let plan = logical_plan(sql).unwrap(); assert_snapshot!( - err.strip_backtrace(), - @r#"This feature is not implemented: UPDATE ... FROM is not supported"# + plan, + @r#" + Explain + Dml: op=[Update] table=[t1] + Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d + Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) + Cross Join: + TableScan: t1 + TableScan: t2 + "# ); } From 2b5aa22bb6aaceb60d2fa65c46f55745f8c3641a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 10 Apr 2026 17:31:55 +0800 Subject: [PATCH 4/7] Update EXPLAIN for logical and physical planner checks Adjust update.slt to ensure both EXPLAIN UPDATE ... FROM cases account for successful logical planning in addition to the existing physical-planner rejection. Align Utf8View cast with reports from sqllogictest in the filter for better consistency. --- datafusion/sqllogictest/test_files/update.slt | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index e8fdab6ab18bb..23a3ffb5e8492 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -67,10 +67,19 @@ logical_plan physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() # set from other table -# UPDATE ... FROM is currently unsupported -# TODO fix https://github.com/apache/datafusion/issues/19950 -query error DataFusion error: This feature is not implemented: UPDATE ... FROM is not supported +# UPDATE ... FROM now reaches logical planning, but the physical planner +# still fails closed until joined update provider support lands. +query TT explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0; +---- +logical_plan +01)Dml: op=[Update] table=[t1] +02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +03)----Filter: t1.a = t2.a AND t1.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) +04)------Cross Join: +05)--------TableScan: t1 +06)--------TableScan: t2 +physical_plan_error This feature is not implemented: UPDATE ... FROM is not supported # test update from other table with actual data statement ok @@ -90,10 +99,20 @@ statement error DataFusion error: This feature is not implemented: Multiple tabl explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; # test table alias -# UPDATE ... FROM is currently unsupported -# TODO fix https://github.com/apache/datafusion/issues/19950 -statement error DataFusion error: This feature is not implemented: UPDATE ... FROM is not supported +# UPDATE ... FROM with aliases also reaches logical planning, but execution +# remains blocked in the physical planner for now. +query TT explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and t.b > 'foo' and t2.c > 1.0; +---- +logical_plan +01)Dml: op=[Update] table=[t1] +02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +03)----Filter: t.a = t2.a AND t.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) +04)------Cross Join: +05)--------SubqueryAlias: t +06)----------TableScan: t1 +07)--------TableScan: t2 +physical_plan_error This feature is not implemented: UPDATE ... FROM is not supported # test update with table alias with actual data statement ok From d76d918fd39e44b029e86f06dd9d2d051834e3a6 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 10 Apr 2026 17:37:59 +0800 Subject: [PATCH 5/7] Fix target alias collection in physical planner Update the alias collection logic to only traverse the update target branch, preventing self-join source aliases from being confused with target aliases. Add a regression test ensuring the correct assignment of src.a in the UPDATE statement for improved accuracy in query execution. --- datafusion/core/src/physical_planner.rs | 46 ++++++++++++++++++++----- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 6fe0906e004fb..674ed7070a9e3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2332,16 +2332,30 @@ fn collect_target_refs( target: &TableReference, ) -> Result> { let mut target_refs = vec![target.clone()]; - input.apply(|node| { - if let LogicalPlan::SubqueryAlias(alias) = node - && let LogicalPlan::TableScan(scan) = alias.input.as_ref() - && scan.table_name.resolved_eq(target) - { + collect_target_refs_from_target_branch(input, &mut target_refs)?; + Ok(target_refs) +} + +fn collect_target_refs_from_target_branch( + input: &Arc, + target_refs: &mut Vec, +) -> Result<()> { + match input.as_ref() { + LogicalPlan::Projection(projection) => { + collect_target_refs_from_target_branch(&projection.input, target_refs) + } + LogicalPlan::Filter(filter) => { + collect_target_refs_from_target_branch(&filter.input, target_refs) + } + LogicalPlan::Join(join) => { + collect_target_refs_from_target_branch(&join.left, target_refs) + } + LogicalPlan::SubqueryAlias(alias) => { target_refs.push(TableReference::bare(alias.alias.to_string())); + collect_target_refs_from_target_branch(&alias.input, target_refs) } - Ok(TreeNodeRecursion::Continue) - })?; - Ok(target_refs) + _ => Ok(()), + } } fn normalize_update_assignment_expr(expr: Expr, strip_qualifiers: bool) -> Result { @@ -4907,6 +4921,22 @@ digraph { ); } + #[tokio::test] + async fn test_extract_update_assignments_preserves_self_join_source_alias() { + let assignments = update_assignments_for_sql( + "UPDATE t1 AS target SET a = src.a FROM t1 AS src \ + WHERE target.id = src.id", + ) + .await + .unwrap(); + + let assignments: HashMap<_, _> = assignments.into_iter().collect(); + assert_eq!( + assignments.get("a").map(ToString::to_string).as_deref(), + Some("src.a") + ); + } + #[tokio::test] async fn test_extract_update_assignments_strips_single_table_target_qualifiers() { let assignments = From 0ce5ffedec1df6b729517d459a62f7db2cee12bc Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 10 Apr 2026 17:49:45 +0800 Subject: [PATCH 6/7] Refactor DML analysis and update planner tests Consolidate duplicated joined-update and target-alias walks in physical_planner.rs by implementing a shared analyze_dml_input(...) helper. Update the filter and assignment extraction to utilize this common metadata. In sql_integration.rs, encapsulate the t1/t2 setup within a local UpdatePlanningContextProvider for new joined-update planner tests, eliminating unnecessary table names from the shared mock catalog in common/mod.rs. --- datafusion/core/src/physical_planner.rs | 108 ++++++++--------- datafusion/sql/tests/common/mod.rs | 12 -- datafusion/sql/tests/sql_integration.rs | 151 ++++++++++++++++++++++-- 3 files changed, 186 insertions(+), 85 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 674ed7070a9e3..7d7bab455f1c8 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -777,10 +777,11 @@ impl DefaultPhysicalPlanner { input, .. }) => { + let analysis = analyze_dml_input(input, table_name)?; // TODO: remove this guard once UPDATE ... FROM routing via // TableProvider::update_from(...) and joined assignment handling land. // See https://github.com/apache/datafusion/issues/19950. - if update_uses_joined_input(input)? { + if analysis.has_joined_input { return not_impl_err!("UPDATE ... FROM is not supported"); } @@ -2105,27 +2106,15 @@ fn extract_dml_filters( input: &Arc, target: &TableReference, ) -> Result> { + let analysis = analyze_dml_input(input, target)?; let mut filters = Vec::new(); - let mut allowed_refs = vec![target.clone()]; - - // First pass: collect any alias references to the target table - input.apply(|node| { - if let LogicalPlan::SubqueryAlias(alias) = node - // Check if this alias points to the target table - && let LogicalPlan::TableScan(scan) = alias.input.as_ref() - && scan.table_name.resolved_eq(target) - { - allowed_refs.push(TableReference::bare(alias.alias.to_string())); - } - Ok(TreeNodeRecursion::Continue) - })?; input.apply(|node| { match node { LogicalPlan::Filter(filter) => { // Split AND predicates into individual expressions for predicate in split_conjunction(&filter.predicate) { - if predicate_is_on_target_multi(predicate, &allowed_refs)? { + if predicate_is_on_target_multi(predicate, &analysis.target_refs)? { filters.push(predicate.clone()); } } @@ -2201,16 +2190,45 @@ fn extract_dml_filters( }) } -fn update_uses_joined_input(input: &Arc) -> Result { - let mut has_join = false; - input.apply(|node| { - if matches!(node, LogicalPlan::Join(_)) { - has_join = true; - return Ok(TreeNodeRecursion::Stop); +#[derive(Debug)] +struct DmlInputAnalysis { + has_joined_input: bool, + target_refs: Vec, +} + +fn analyze_dml_input( + input: &Arc, + target: &TableReference, +) -> Result { + let mut analysis = DmlInputAnalysis { + has_joined_input: false, + target_refs: vec![target.clone()], + }; + analyze_target_branch(input, &mut analysis)?; + Ok(analysis) +} + +fn analyze_target_branch( + input: &Arc, + analysis: &mut DmlInputAnalysis, +) -> Result<()> { + match input.as_ref() { + LogicalPlan::Projection(projection) => { + analyze_target_branch(&projection.input, analysis) } - Ok(TreeNodeRecursion::Continue) - })?; - Ok(has_join) + LogicalPlan::Filter(filter) => analyze_target_branch(&filter.input, analysis), + LogicalPlan::SubqueryAlias(alias) => { + analysis + .target_refs + .push(TableReference::bare(alias.alias.to_string())); + analyze_target_branch(&alias.input, analysis) + } + LogicalPlan::Join(join) => { + analysis.has_joined_input = true; + analyze_target_branch(&join.left, analysis) + } + _ => Ok(()), + } } /// Determine whether a predicate references only columns from the target table @@ -2259,7 +2277,8 @@ fn strip_column_qualifiers(expr: Expr) -> Result { /// Extract column assignments from an UPDATE input plan. /// For UPDATE statements, the SQL planner encodes assignments as a projection /// over the source table. This function extracts column name and expression pairs -/// from the projection. Column qualifiers are stripped from the expressions. +/// from the projection. Column qualifiers are stripped only for single-table +/// updates so provider-facing expressions remain resolvable. /// fn extract_update_assignments( input: &Arc, @@ -2271,16 +2290,16 @@ fn extract_update_assignments( // TableScan // // Each projected expression has an alias matching the column name + let analysis = analyze_dml_input(input, target)?; let mut assignments = Vec::new(); - let strip_qualifiers = !update_uses_joined_input(input)?; - let target_refs = collect_target_refs(input, target)?; + let strip_qualifiers = !analysis.has_joined_input; // Find the top-level projection if let LogicalPlan::Projection(projection) = input.as_ref() { append_update_assignments( &mut assignments, projection, - &target_refs, + &analysis.target_refs, strip_qualifiers, )?; } else { @@ -2290,7 +2309,7 @@ fn extract_update_assignments( append_update_assignments( &mut assignments, projection, - &target_refs, + &analysis.target_refs, strip_qualifiers, )?; return Ok(TreeNodeRecursion::Stop); @@ -2327,37 +2346,6 @@ fn append_update_assignments( Ok(()) } -fn collect_target_refs( - input: &Arc, - target: &TableReference, -) -> Result> { - let mut target_refs = vec![target.clone()]; - collect_target_refs_from_target_branch(input, &mut target_refs)?; - Ok(target_refs) -} - -fn collect_target_refs_from_target_branch( - input: &Arc, - target_refs: &mut Vec, -) -> Result<()> { - match input.as_ref() { - LogicalPlan::Projection(projection) => { - collect_target_refs_from_target_branch(&projection.input, target_refs) - } - LogicalPlan::Filter(filter) => { - collect_target_refs_from_target_branch(&filter.input, target_refs) - } - LogicalPlan::Join(join) => { - collect_target_refs_from_target_branch(&join.left, target_refs) - } - LogicalPlan::SubqueryAlias(alias) => { - target_refs.push(TableReference::bare(alias.alias.to_string())); - collect_target_refs_from_target_branch(&alias.input, target_refs) - } - _ => Ok(()), - } -} - fn normalize_update_assignment_expr(expr: Expr, strip_qualifiers: bool) -> Result { if strip_qualifiers { strip_column_qualifiers(expr) diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 7999db78dda39..5caade300290f 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -124,18 +124,6 @@ impl ContextProvider for MockContextProvider { Field::new("id", DataType::Int32, false), Field::new("price", DataType::Decimal128(10, 2), false), ])), - "t1" => Ok(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - Field::new("d", DataType::Int32, false), - ])), - "t2" => Ok(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - Field::new("c", DataType::Float64, false), - Field::new("d", DataType::Int32, false), - ])), "person" => Ok(Schema::new(vec![ Field::new("id", DataType::UInt32, false), Field::new("first_name", DataType::Utf8, false), diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 5999d635510bf..17ff44ae54f66 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -26,16 +26,18 @@ use std::vec; use arrow::datatypes::{TimeUnit::Nanosecond, *}; use common::MockContextProvider; -use datafusion_common::{DataFusionError, Result, assert_contains}; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::{DataFusionError, Result, TableReference, assert_contains}; use datafusion_expr::{ ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, col, logical_plan::LogicalPlan, + ScalarUDFImpl, Signature, TableSource, Volatility, col, + logical_plan::LogicalPlan, planner::ExprPlanner, planner::TypePlanner, test::function_stub::sum_udaf, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ parser::DFParser, - planner::{NullOrdering, ParserOptions, SqlToRel}, + planner::{ContextProvider, NullOrdering, ParserOptions, SqlToRel}, }; use crate::common::{CustomExprPlanner, CustomTypePlanner, MockSessionState}; @@ -56,6 +58,109 @@ use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; mod cases; mod common; +#[derive(Debug)] +struct SqlTestTableSource { + schema: SchemaRef, +} + +impl SqlTestTableSource { + fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +impl TableSource for SqlTestTableSource { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +struct UpdatePlanningContextProvider { + inner: MockContextProvider, +} + +impl UpdatePlanningContextProvider { + fn new(state: MockSessionState) -> Self { + Self { + inner: MockContextProvider { state }, + } + } +} + +impl ContextProvider for UpdatePlanningContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + match name.table() { + "t1" | "t2" => Ok(Arc::new(SqlTestTableSource::new(update_test_schema()))), + _ => self.inner.get_table_source(name), + } + } + + fn get_file_type(&self, ext: &str) -> Result> { + self.inner.get_file_type(ext) + } + + fn create_cte_work_table( + &self, + name: &str, + schema: SchemaRef, + ) -> Result> { + self.inner.create_cte_work_table(name, schema) + } + + fn get_expr_planners(&self) -> &[Arc] { + self.inner.get_expr_planners() + } + + fn get_type_planner(&self) -> Option> { + self.inner.get_type_planner() + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.inner.get_function_meta(name) + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.inner.get_aggregate_meta(name) + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.inner.get_window_meta(name) + } + + fn get_variable_type(&self, variable_names: &[String]) -> Option { + self.inner.get_variable_type(variable_names) + } + + fn options(&self) -> &datafusion_common::config::ConfigOptions { + self.inner.options() + } + + fn udf_names(&self) -> Vec { + self.inner.udf_names() + } + + fn udaf_names(&self) -> Vec { + self.inner.udaf_names() + } + + fn udwf_names(&self) -> Vec { + self.inner.udwf_names() + } +} + +fn update_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + Field::new("d", DataType::Int32, false), + ])) +} + #[test] fn parse_decimals_1() { let sql = "SELECT 1"; @@ -744,7 +849,7 @@ fn plan_update_from_with_aliases() { SET b = source.b, c = source.a, d = 1 \ FROM t2 AS source \ WHERE target.a = source.a AND target.b > 'foo' AND source.c > 1.0"; - let plan = logical_plan(sql).unwrap(); + let plan = update_logical_plan(sql).unwrap(); assert_snapshot!( plan, @r#" @@ -764,7 +869,7 @@ fn plan_update_from_with_aliases() { fn plan_explain_update_from() { let sql = "EXPLAIN UPDATE t1 SET b = t2.b, c = t2.a, d = 1 \ FROM t2 WHERE t1.a = t2.a AND t1.b > 'foo' AND t2.c > 1.0"; - let plan = logical_plan(sql).unwrap(); + let plan = update_logical_plan(sql).unwrap(); assert_snapshot!( plan, @r#" @@ -3499,6 +3604,10 @@ fn logical_plan(sql: &str) -> Result { logical_plan_with_options(sql, ParserOptions::default()) } +fn update_logical_plan(sql: &str) -> Result { + update_logical_plan_with_options(sql, ParserOptions::default()) +} + fn logical_plan_with_options(sql: &str, options: ParserOptions) -> Result { let dialect = &GenericDialect {}; logical_plan_with_dialect_and_options(sql, dialect, options) @@ -3518,7 +3627,29 @@ fn logical_plan_with_dialect_and_options( dialect: &dyn Dialect, options: ParserOptions, ) -> Result { - let state = MockSessionState::default() + let state = sql_test_state(); + let context = MockContextProvider { state }; + let planner = SqlToRel::new_with_options(&context, options); + let result = DFParser::parse_sql_with_dialect(sql, dialect); + let mut ast = result?; + planner.statement_to_plan(ast.pop_front().unwrap()) +} + +fn update_logical_plan_with_options( + sql: &str, + options: ParserOptions, +) -> Result { + let dialect = &GenericDialect {}; + let state = sql_test_state(); + let context = UpdatePlanningContextProvider::new(state); + let planner = SqlToRel::new_with_options(&context, options); + let result = DFParser::parse_sql_with_dialect(sql, dialect); + let mut ast = result?; + planner.statement_to_plan(ast.pop_front().unwrap()) +} + +fn sql_test_state() -> MockSessionState { + MockSessionState::default() .with_scalar_function(Arc::new(unicode::character_length().as_ref().clone())) .with_scalar_function(Arc::new(string::concat().as_ref().clone())) .with_scalar_function(Arc::new(make_udf( @@ -3555,13 +3686,7 @@ fn logical_plan_with_dialect_and_options( .with_aggregate_function(grouping_udaf()) .with_window_function(rank_udwf()) .with_window_function(row_number_udwf()) - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); - - let context = MockContextProvider { state }; - let planner = SqlToRel::new_with_options(&context, options); - let result = DFParser::parse_sql_with_dialect(sql, dialect); - let mut ast = result?; - planner.statement_to_plan(ast.pop_front().unwrap()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) } fn make_udf(name: &'static str, args: Vec, return_type: DataType) -> ScalarUDF { From 2b8b462a93d0f63d468f6ba904e01545215102b2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 10 Apr 2026 17:57:08 +0800 Subject: [PATCH 7/7] Refactor planning and testing for updates and SQL Simplify physical planner by reusing DmlInputAnalysis and centralizing projection lookup. Streamline assignment extraction with iterators. Reduce duplication in SQL planning setup by introducing a shared helper and improve context provider to reuse stored schemas for efficiency. Enhance test scaffolding with shared update schema and new assertion utilities. --- datafusion/core/src/physical_planner.rs | 253 ++++++++++++------------ datafusion/sql/tests/sql_integration.rs | 35 ++-- 2 files changed, 144 insertions(+), 144 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 7d7bab455f1c8..c27864c5dd156 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -790,9 +790,11 @@ impl DefaultPhysicalPlanner { { // For UPDATE, the assignments are encoded in the projection of input // We pass the filters and let the provider handle the projection - let filters = extract_dml_filters(input, table_name)?; + let filters = + extract_dml_filters_with_analysis(input, table_name, &analysis)?; // Extract assignments from the projection in input plan - let assignments = extract_update_assignments(input, table_name)?; + let assignments = + extract_update_assignments_with_analysis(input, &analysis)?; provider .table_provider .update(session_state, assignments, filters) @@ -2107,6 +2109,15 @@ fn extract_dml_filters( target: &TableReference, ) -> Result> { let analysis = analyze_dml_input(input, target)?; + extract_dml_filters_with_analysis(input, target, &analysis) +} + +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key +fn extract_dml_filters_with_analysis( + input: &Arc, + target: &TableReference, + analysis: &DmlInputAnalysis, +) -> Result> { let mut filters = Vec::new(); input.apply(|node| { @@ -2204,30 +2215,28 @@ fn analyze_dml_input( has_joined_input: false, target_refs: vec![target.clone()], }; - analyze_target_branch(input, &mut analysis)?; + analyze_target_branch(input, &mut analysis); Ok(analysis) } -fn analyze_target_branch( - input: &Arc, - analysis: &mut DmlInputAnalysis, -) -> Result<()> { - match input.as_ref() { - LogicalPlan::Projection(projection) => { - analyze_target_branch(&projection.input, analysis) - } - LogicalPlan::Filter(filter) => analyze_target_branch(&filter.input, analysis), - LogicalPlan::SubqueryAlias(alias) => { - analysis - .target_refs - .push(TableReference::bare(alias.alias.to_string())); - analyze_target_branch(&alias.input, analysis) - } - LogicalPlan::Join(join) => { - analysis.has_joined_input = true; - analyze_target_branch(&join.left, analysis) +fn analyze_target_branch(input: &Arc, analysis: &mut DmlInputAnalysis) { + let mut current = input; + loop { + match current.as_ref() { + LogicalPlan::Projection(projection) => current = &projection.input, + LogicalPlan::Filter(filter) => current = &filter.input, + LogicalPlan::SubqueryAlias(alias) => { + analysis + .target_refs + .push(TableReference::bare(alias.alias.to_string())); + current = &alias.input; + } + LogicalPlan::Join(join) => { + analysis.has_joined_input = true; + current = &join.left; + } + _ => return, } - _ => Ok(()), } } @@ -2280,6 +2289,7 @@ fn strip_column_qualifiers(expr: Expr) -> Result { /// from the projection. Column qualifiers are stripped only for single-table /// updates so provider-facing expressions remain resolvable. /// +#[cfg(test)] fn extract_update_assignments( input: &Arc, target: &TableReference, @@ -2291,67 +2301,65 @@ fn extract_update_assignments( // // Each projected expression has an alias matching the column name let analysis = analyze_dml_input(input, target)?; - let mut assignments = Vec::new(); - let strip_qualifiers = !analysis.has_joined_input; + extract_update_assignments_with_analysis(input, &analysis) +} - // Find the top-level projection +fn extract_update_assignments_with_analysis( + input: &Arc, + analysis: &DmlInputAnalysis, +) -> Result> { + find_update_projection(input)? + .map(|projection| { + append_update_assignments( + projection, + &analysis.target_refs, + !analysis.has_joined_input, + ) + }) + .transpose() + .map(|assignments| assignments.unwrap_or_default()) +} + +fn find_update_projection(input: &Arc) -> Result> { if let LogicalPlan::Projection(projection) = input.as_ref() { - append_update_assignments( - &mut assignments, - projection, - &analysis.target_refs, - strip_qualifiers, - )?; - } else { - // Try to find projection deeper in the plan - input.apply(|node| { - if let LogicalPlan::Projection(projection) = node { - append_update_assignments( - &mut assignments, - projection, - &analysis.target_refs, - strip_qualifiers, - )?; - return Ok(TreeNodeRecursion::Stop); - } - Ok(TreeNodeRecursion::Continue) - })?; + return Ok(Some(projection)); } - Ok(assignments) + let mut found_projection = None; + input.apply(|node| { + if let LogicalPlan::Projection(projection) = node { + found_projection = Some(projection); + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(found_projection) } fn append_update_assignments( - assignments: &mut Vec<(String, Expr)>, projection: &Projection, target_refs: &[TableReference], strip_qualifiers: bool, -) -> Result<()> { - for expr in &projection.expr { - if let Expr::Alias(alias) = expr { - // The alias name is the column name being updated - // The inner expression is the new value - let column_name = alias.name.clone(); - // Only include if it's not just a column reference to itself - // (those are columns that aren't being updated) - if !is_identity_assignment(&alias.expr, &column_name, target_refs) { - let assignment_expr = normalize_update_assignment_expr( - (*alias.expr).clone(), - strip_qualifiers, - )?; - assignments.push((column_name, assignment_expr)); +) -> Result> { + projection + .expr + .iter() + .filter_map(|expr| match expr { + Expr::Alias(alias) + if !is_identity_assignment(&alias.expr, &alias.name, target_refs) => + { + Some( + if strip_qualifiers { + strip_column_qualifiers((*alias.expr).clone()) + } else { + Ok((*alias.expr).clone()) + } + .map(|assignment_expr| (alias.name.clone(), assignment_expr)), + ) } - } - } - Ok(()) -} - -fn normalize_update_assignment_expr(expr: Expr, strip_qualifiers: bool) -> Result { - if strip_qualifiers { - strip_column_qualifiers(expr) - } else { - Ok(expr) - } + _ => None, + }) + .collect() } /// Check if an assignment is an identity assignment (column = column) @@ -3247,20 +3255,19 @@ mod tests { Arc::new(MemTable::try_new(schema, vec![vec![]]).unwrap()) } - async fn update_assignments_for_sql(sql: &str) -> Result> { - let ctx = SessionContext::new(); - let t1_schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ])); - let t2_schema = Arc::new(Schema::new(vec![ + fn test_update_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), - ])); - ctx.register_table("t1", make_test_mem_table(t1_schema))?; - ctx.register_table("t2", make_test_mem_table(t2_schema))?; + ])) + } + + async fn update_assignments_for_sql(sql: &str) -> Result> { + let ctx = SessionContext::new(); + let schema = test_update_schema(); + ctx.register_table("t1", make_test_mem_table(Arc::clone(&schema)))?; + ctx.register_table("t2", make_test_mem_table(schema))?; let df = ctx.sql(sql).await?; let (table_name, input) = match df.logical_plan() { @@ -3276,6 +3283,16 @@ mod tests { extract_update_assignments(&input, &table_name) } + async fn assert_update_assignment(sql: &str, column: &str, expected: &str) { + let assignments: HashMap<_, _> = update_assignments_for_sql(sql) + .await + .unwrap() + .into_iter() + .map(|(name, expr)| (name, expr.to_string())) + .collect(); + assert_eq!(assignments.get(column).map(String::as_str), Some(expected)); + } + #[tokio::test] async fn test_all_operators() -> Result<()> { let logical_plan = test_csv_scan() @@ -4861,81 +4878,53 @@ digraph { #[tokio::test] async fn test_extract_update_assignments_preserves_joined_source_qualifiers() { - let assignments = update_assignments_for_sql( + assert_update_assignment( "UPDATE t1 SET b = t2.b FROM t2 WHERE t1.id = t2.id", + "b", + "t2.b", ) - .await - .unwrap(); - - let assignments: HashMap<_, _> = assignments.into_iter().collect(); - assert_eq!( - assignments.get("b").map(ToString::to_string).as_deref(), - Some("t2.b") - ); + .await; } #[tokio::test] async fn test_extract_update_assignments_preserves_alias_qualified_sources() { - let assignments = update_assignments_for_sql( + assert_update_assignment( "UPDATE t1 AS target SET b = source.b FROM t2 AS source \ WHERE target.id = source.id", + "b", + "source.b", ) - .await - .unwrap(); - - let assignments: HashMap<_, _> = assignments.into_iter().collect(); - assert_eq!( - assignments.get("b").map(ToString::to_string).as_deref(), - Some("source.b") - ); + .await; } #[tokio::test] async fn test_extract_update_assignments_distinguishes_same_name_join_columns() { - let assignments = update_assignments_for_sql( + let assignments: HashMap<_, _> = update_assignments_for_sql( "UPDATE t1 SET a = t2.a, b = t1.a FROM t2 WHERE t1.id = t2.id", ) .await - .unwrap(); - - let assignments: HashMap<_, _> = assignments.into_iter().collect(); - assert_eq!( - assignments.get("a").map(ToString::to_string).as_deref(), - Some("t2.a") - ); - assert_eq!( - assignments.get("b").map(ToString::to_string).as_deref(), - Some("t1.a") - ); + .unwrap() + .into_iter() + .map(|(name, expr)| (name, expr.to_string())) + .collect(); + assert_eq!(assignments.get("a").map(String::as_str), Some("t2.a")); + assert_eq!(assignments.get("b").map(String::as_str), Some("t1.a")); } #[tokio::test] async fn test_extract_update_assignments_preserves_self_join_source_alias() { - let assignments = update_assignments_for_sql( + assert_update_assignment( "UPDATE t1 AS target SET a = src.a FROM t1 AS src \ WHERE target.id = src.id", + "a", + "src.a", ) - .await - .unwrap(); - - let assignments: HashMap<_, _> = assignments.into_iter().collect(); - assert_eq!( - assignments.get("a").map(ToString::to_string).as_deref(), - Some("src.a") - ); + .await; } #[tokio::test] async fn test_extract_update_assignments_strips_single_table_target_qualifiers() { - let assignments = - update_assignments_for_sql("UPDATE t1 SET b = t1.a WHERE t1.id = 1") - .await - .unwrap(); - - let assignments: HashMap<_, _> = assignments.into_iter().collect(); - assert_eq!( - assignments.get("b").map(ToString::to_string).as_deref(), - Some("a") - ); + assert_update_assignment("UPDATE t1 SET b = t1.a WHERE t1.id = 1", "b", "a") + .await; } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 17ff44ae54f66..51820f5c81923 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -29,8 +29,8 @@ use common::MockContextProvider; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DataFusionError, Result, TableReference, assert_contains}; use datafusion_expr::{ - ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, TableSource, Volatility, col, + AggregateUDF, ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, TableSource, Volatility, WindowUDF, col, logical_plan::LogicalPlan, planner::ExprPlanner, planner::TypePlanner, test::function_stub::sum_udaf, }; @@ -81,20 +81,26 @@ impl TableSource for SqlTestTableSource { struct UpdatePlanningContextProvider { inner: MockContextProvider, + update_schema: SchemaRef, } impl UpdatePlanningContextProvider { fn new(state: MockSessionState) -> Self { Self { inner: MockContextProvider { state }, + update_schema: update_test_schema(), } } + + fn get_update_table_source(&self) -> Arc { + Arc::new(SqlTestTableSource::new(Arc::clone(&self.update_schema))) + } } impl ContextProvider for UpdatePlanningContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { match name.table() { - "t1" | "t2" => Ok(Arc::new(SqlTestTableSource::new(update_test_schema()))), + "t1" | "t2" => Ok(self.get_update_table_source()), _ => self.inner.get_table_source(name), } } @@ -123,11 +129,11 @@ impl ContextProvider for UpdatePlanningContextProvider { self.inner.get_function_meta(name) } - fn get_aggregate_meta(&self, name: &str) -> Option> { + fn get_aggregate_meta(&self, name: &str) -> Option> { self.inner.get_aggregate_meta(name) } - fn get_window_meta(&self, name: &str) -> Option> { + fn get_window_meta(&self, name: &str) -> Option> { self.inner.get_window_meta(name) } @@ -3629,10 +3635,7 @@ fn logical_plan_with_dialect_and_options( ) -> Result { let state = sql_test_state(); let context = MockContextProvider { state }; - let planner = SqlToRel::new_with_options(&context, options); - let result = DFParser::parse_sql_with_dialect(sql, dialect); - let mut ast = result?; - planner.statement_to_plan(ast.pop_front().unwrap()) + plan_sql_with_options(sql, dialect, options, &context) } fn update_logical_plan_with_options( @@ -3642,9 +3645,17 @@ fn update_logical_plan_with_options( let dialect = &GenericDialect {}; let state = sql_test_state(); let context = UpdatePlanningContextProvider::new(state); - let planner = SqlToRel::new_with_options(&context, options); - let result = DFParser::parse_sql_with_dialect(sql, dialect); - let mut ast = result?; + plan_sql_with_options(sql, dialect, options, &context) +} + +fn plan_sql_with_options( + sql: &str, + dialect: &dyn Dialect, + options: ParserOptions, + context: &S, +) -> Result { + let planner = SqlToRel::new_with_options(context, options); + let mut ast = DFParser::parse_sql_with_dialect(sql, dialect)?; planner.statement_to_plan(ast.pop_front().unwrap()) }