Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 232 additions & 54 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,14 +777,24 @@ impl DefaultPhysicalPlanner {
input,
..
}) => {
let analysis = analyze_dml_input(input, table_name)?;
// TODO: remove this guard once UPDATE ... FROM routing via
// TableProvider::update_from(...) and joined assignment handling land.
// See https://github.com/apache/datafusion/issues/19950.
if analysis.has_joined_input {
return not_impl_err!("UPDATE ... FROM is not supported");
}

if let Some(provider) =
target.as_any().downcast_ref::<DefaultTableSource>()
{
// For UPDATE, the assignments are encoded in the projection of input
// We pass the filters and let the provider handle the projection
let filters = extract_dml_filters(input, table_name)?;
let filters =
extract_dml_filters_with_analysis(input, table_name, &analysis)?;
// Extract assignments from the projection in input plan
let assignments = extract_update_assignments(input)?;
let assignments =
extract_update_assignments_with_analysis(input, &analysis)?;
provider
.table_provider
.update(session_state, assignments, filters)
Expand Down Expand Up @@ -2098,27 +2108,24 @@ fn extract_dml_filters(
input: &Arc<LogicalPlan>,
target: &TableReference,
) -> Result<Vec<Expr>> {
let mut filters = Vec::new();
let mut allowed_refs = vec![target.clone()];
let analysis = analyze_dml_input(input, target)?;
extract_dml_filters_with_analysis(input, target, &analysis)
}

// First pass: collect any alias references to the target table
input.apply(|node| {
if let LogicalPlan::SubqueryAlias(alias) = node
// Check if this alias points to the target table
&& let LogicalPlan::TableScan(scan) = alias.input.as_ref()
&& scan.table_name.resolved_eq(target)
{
allowed_refs.push(TableReference::bare(alias.alias.to_string()));
}
Ok(TreeNodeRecursion::Continue)
})?;
#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key
fn extract_dml_filters_with_analysis(
input: &Arc<LogicalPlan>,
target: &TableReference,
analysis: &DmlInputAnalysis,
) -> Result<Vec<Expr>> {
let mut filters = Vec::new();

input.apply(|node| {
match node {
LogicalPlan::Filter(filter) => {
// Split AND predicates into individual expressions
for predicate in split_conjunction(&filter.predicate) {
if predicate_is_on_target_multi(predicate, &allowed_refs)? {
if predicate_is_on_target_multi(predicate, &analysis.target_refs)? {
filters.push(predicate.clone());
}
}
Expand Down Expand Up @@ -2194,6 +2201,45 @@ fn extract_dml_filters(
})
}

#[derive(Debug)]
struct DmlInputAnalysis {
has_joined_input: bool,
target_refs: Vec<TableReference>,
}

fn analyze_dml_input(
input: &Arc<LogicalPlan>,
target: &TableReference,
) -> Result<DmlInputAnalysis> {
let mut analysis = DmlInputAnalysis {
has_joined_input: false,
target_refs: vec![target.clone()],
};
analyze_target_branch(input, &mut analysis);
Ok(analysis)
}

fn analyze_target_branch(input: &Arc<LogicalPlan>, analysis: &mut DmlInputAnalysis) {
let mut current = input;
loop {
match current.as_ref() {
LogicalPlan::Projection(projection) => current = &projection.input,
LogicalPlan::Filter(filter) => current = &filter.input,
LogicalPlan::SubqueryAlias(alias) => {
analysis
.target_refs
.push(TableReference::bare(alias.alias.to_string()));
current = &alias.input;
}
LogicalPlan::Join(join) => {
analysis.has_joined_input = true;
current = &join.left;
}
_ => return,
}
}
}

/// Determine whether a predicate references only columns from the target table
/// or its aliases.
///
Expand Down Expand Up @@ -2240,61 +2286,98 @@ fn strip_column_qualifiers(expr: Expr) -> Result<Expr> {
/// Extract column assignments from an UPDATE input plan.
/// For UPDATE statements, the SQL planner encodes assignments as a projection
/// over the source table. This function extracts column name and expression pairs
/// from the projection. Column qualifiers are stripped from the expressions.
/// from the projection. Column qualifiers are stripped only for single-table
/// updates so provider-facing expressions remain resolvable.
///
fn extract_update_assignments(input: &Arc<LogicalPlan>) -> Result<Vec<(String, Expr)>> {
#[cfg(test)]
fn extract_update_assignments(
input: &Arc<LogicalPlan>,
target: &TableReference,
) -> Result<Vec<(String, Expr)>> {
// The UPDATE input plan structure is:
// Projection(updated columns as expressions with aliases)
// Filter(optional WHERE clause)
// TableScan
//
// Each projected expression has an alias matching the column name
let mut assignments = Vec::new();
let analysis = analyze_dml_input(input, target)?;
extract_update_assignments_with_analysis(input, &analysis)
}

// Find the top-level projection
fn extract_update_assignments_with_analysis(
input: &Arc<LogicalPlan>,
analysis: &DmlInputAnalysis,
) -> Result<Vec<(String, Expr)>> {
find_update_projection(input)?
.map(|projection| {
append_update_assignments(
projection,
&analysis.target_refs,
!analysis.has_joined_input,
)
})
.transpose()
.map(|assignments| assignments.unwrap_or_default())
}

fn find_update_projection(input: &Arc<LogicalPlan>) -> Result<Option<&Projection>> {
if let LogicalPlan::Projection(projection) = input.as_ref() {
for expr in &projection.expr {
if let Expr::Alias(alias) = expr {
// The alias name is the column name being updated
// The inner expression is the new value
let column_name = alias.name.clone();
// Only include if it's not just a column reference to itself
// (those are columns that aren't being updated)
if !is_identity_assignment(&alias.expr, &column_name) {
// Strip qualifiers from the assignment expression
let stripped_expr = strip_column_qualifiers((*alias.expr).clone())?;
assignments.push((column_name, stripped_expr));
}
}
return Ok(Some(projection));
}

let mut found_projection = None;
input.apply(|node| {
if let LogicalPlan::Projection(projection) = node {
found_projection = Some(projection);
return Ok(TreeNodeRecursion::Stop);
}
} else {
// Try to find projection deeper in the plan
input.apply(|node| {
if let LogicalPlan::Projection(projection) = node {
for expr in &projection.expr {
if let Expr::Alias(alias) = expr {
let column_name = alias.name.clone();
if !is_identity_assignment(&alias.expr, &column_name) {
let stripped_expr =
strip_column_qualifiers((*alias.expr).clone())?;
assignments.push((column_name, stripped_expr));
}
Ok(TreeNodeRecursion::Continue)
})?;
Ok(found_projection)
}

fn append_update_assignments(
projection: &Projection,
target_refs: &[TableReference],
strip_qualifiers: bool,
) -> Result<Vec<(String, Expr)>> {
projection
.expr
.iter()
.filter_map(|expr| match expr {
Expr::Alias(alias)
if !is_identity_assignment(&alias.expr, &alias.name, target_refs) =>
{
Some(
if strip_qualifiers {
strip_column_qualifiers((*alias.expr).clone())
} else {
Ok((*alias.expr).clone())
}
}
return Ok(TreeNodeRecursion::Stop);
.map(|assignment_expr| (alias.name.clone(), assignment_expr)),
)
}
Ok(TreeNodeRecursion::Continue)
})?;
}

Ok(assignments)
_ => None,
})
.collect()
}

/// Check if an assignment is an identity assignment (column = column)
/// These are columns that are not being modified in the UPDATE
fn is_identity_assignment(expr: &Expr, column_name: &str) -> bool {
fn is_identity_assignment(
expr: &Expr,
column_name: &str,
target_refs: &[TableReference],
) -> bool {
match expr {
Expr::Column(col) => col.name == column_name,
Expr::Column(col) => {
col.name == column_name
&& col.relation.as_ref().is_none_or(|relation| {
target_refs
.iter()
.any(|target_ref| relation.resolved_eq(target_ref))
})
}
_ => false,
}
}
Expand Down Expand Up @@ -3114,6 +3197,7 @@ impl<'n> TreeNodeVisitor<'n> for InvariantChecker {
mod tests {
use std::any::Any;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::ops::{BitAnd, Not};

Expand Down Expand Up @@ -3167,6 +3251,48 @@ mod tests {
.await
}

fn make_test_mem_table(schema: SchemaRef) -> Arc<MemTable> {
Arc::new(MemTable::try_new(schema, vec![vec![]]).unwrap())
}

fn test_update_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]))
}

async fn update_assignments_for_sql(sql: &str) -> Result<Vec<(String, Expr)>> {
let ctx = SessionContext::new();
let schema = test_update_schema();
ctx.register_table("t1", make_test_mem_table(Arc::clone(&schema)))?;
ctx.register_table("t2", make_test_mem_table(schema))?;

let df = ctx.sql(sql).await?;
let (table_name, input) = match df.logical_plan() {
LogicalPlan::Dml(DmlStatement {
op: WriteOp::Update,
table_name,
input,
..
}) => (table_name.clone(), Arc::clone(input)),
plan => panic!("expected update plan, got {plan:?}"),
};

extract_update_assignments(&input, &table_name)
}

async fn assert_update_assignment(sql: &str, column: &str, expected: &str) {
let assignments: HashMap<_, _> = update_assignments_for_sql(sql)
.await
.unwrap()
.into_iter()
.map(|(name, expr)| (name, expr.to_string()))
.collect();
assert_eq!(assignments.get(column).map(String::as_str), Some(expected));
}

#[tokio::test]
async fn test_all_operators() -> Result<()> {
let logical_plan = test_csv_scan()
Expand Down Expand Up @@ -4749,4 +4875,56 @@ digraph {
assert_eq!(plan.schema(), schema);
assert!(plan.is::<EmptyExec>());
}

#[tokio::test]
async fn test_extract_update_assignments_preserves_joined_source_qualifiers() {
assert_update_assignment(
"UPDATE t1 SET b = t2.b FROM t2 WHERE t1.id = t2.id",
"b",
"t2.b",
)
.await;
}

#[tokio::test]
async fn test_extract_update_assignments_preserves_alias_qualified_sources() {
assert_update_assignment(
"UPDATE t1 AS target SET b = source.b FROM t2 AS source \
WHERE target.id = source.id",
"b",
"source.b",
)
.await;
}

#[tokio::test]
async fn test_extract_update_assignments_distinguishes_same_name_join_columns() {
let assignments: HashMap<_, _> = update_assignments_for_sql(
"UPDATE t1 SET a = t2.a, b = t1.a FROM t2 WHERE t1.id = t2.id",
)
.await
.unwrap()
.into_iter()
.map(|(name, expr)| (name, expr.to_string()))
.collect();
assert_eq!(assignments.get("a").map(String::as_str), Some("t2.a"));
assert_eq!(assignments.get("b").map(String::as_str), Some("t1.a"));
}

#[tokio::test]
async fn test_extract_update_assignments_preserves_self_join_source_alias() {
assert_update_assignment(
"UPDATE t1 AS target SET a = src.a FROM t1 AS src \
WHERE target.id = src.id",
"a",
"src.a",
)
.await;
}

#[tokio::test]
async fn test_extract_update_assignments_strips_single_table_target_qualifiers() {
assert_update_assignment("UPDATE t1 SET b = t1.a WHERE t1.id = 1", "b", "a")
.await;
}
}
16 changes: 8 additions & 8 deletions datafusion/core/tests/custom_sources_cases/dml_planning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,6 @@ async fn test_delete_target_table_scoping() -> Result<()> {

#[tokio::test]
async fn test_update_from_drops_non_target_predicates() -> Result<()> {
// UPDATE ... FROM is currently not working
// TODO fix https://github.com/apache/datafusion/issues/19950
let target_provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown(
test_schema(),
TableProviderFilterPushDown::Exact,
Expand All @@ -743,17 +741,19 @@ async fn test_update_from_drops_non_target_predicates() -> Result<()> {
let source_table = datafusion::datasource::empty::EmptyTable::new(source_schema);
ctx.register_table("t2", Arc::new(source_table))?;

let result = ctx
let df = ctx
.sql(
"UPDATE t1 SET value = 1 FROM t2 \
WHERE t1.id = t2.id AND t2.src_only = 'active' AND t1.value > 10",
)
.await;
.await?;

// Verify UPDATE ... FROM is rejected with appropriate error
// TODO fix https://github.com/apache/datafusion/issues/19950
assert!(result.is_err());
let err = result.unwrap_err();
// Verify UPDATE ... FROM reaches logical planning but still fails closed
// before the unsafe single-table update runtime path can execute.
let err = df
.collect()
.await
.expect_err("UPDATE ... FROM should fail closed");
assert!(
err.to_string().contains("UPDATE ... FROM is not supported"),
"Expected 'UPDATE ... FROM is not supported' error, got: {err}"
Expand Down
Loading
Loading