From 502e7e73e832c9644142c49fb4e8a633236d7286 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 19 Feb 2026 16:17:42 -0500 Subject: [PATCH 1/2] . --- .../consumer/expr/field_reference.rs | 45 ++++++- .../src/logical_plan/consumer/expr/mod.rs | 5 +- .../logical_plan/consumer/expr/subquery.rs | 44 +++++-- .../src/logical_plan/consumer/plan.rs | 5 +- .../logical_plan/consumer/rel/exchange_rel.rs | 3 +- .../consumer/substrait_consumer.rs | 123 +++++++++++++++++- .../tests/cases/consumer_integration.rs | 12 +- 7 files changed, 205 insertions(+), 32 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs index c17bf9c92edcc..631ad527c0e20 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs @@ -16,21 +16,24 @@ // under the License. use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::common::{Column, DFSchema, not_impl_err}; +use datafusion::common::{Column, DFSchema, not_impl_err, substrait_err}; use datafusion::logical_expr::Expr; +use std::sync::Arc; use substrait::proto::expression::FieldReference; use substrait::proto::expression::field_reference::ReferenceType::DirectReference; +use substrait::proto::expression::field_reference::RootType; use substrait::proto::expression::reference_segment::ReferenceType::StructField; pub async fn from_field_reference( - _consumer: &impl SubstraitConsumer, + consumer: &impl SubstraitConsumer, field_ref: &FieldReference, input_schema: &DFSchema, ) -> datafusion::common::Result { - from_substrait_field_reference(field_ref, input_schema) + from_substrait_field_reference(consumer, field_ref, input_schema) } pub(crate) fn from_substrait_field_reference( + consumer: &impl SubstraitConsumer, field_ref: &FieldReference, input_schema: &DFSchema, ) -> datafusion::common::Result { @@ -40,9 +43,39 @@ pub(crate) fn from_substrait_field_reference( Some(_) => not_impl_err!( "Direct reference StructField with child is not supported" ), - None => Ok(Expr::Column(Column::from( - input_schema.qualified_field(x.field as usize), - ))), + None => { + let field_idx = x.field as usize; + match &field_ref.root_type { + // Normal reference: resolve against the current + // relation's schema. + Some(RootType::RootReference(_)) | None => Ok(Expr::Column( + Column::from(input_schema.qualified_field(field_idx)), + )), + // Correlated reference: resolve against the + // enclosing query's schema. + Some(RootType::OuterReference(outer_ref)) => { + let steps_out = outer_ref.steps_out as usize; + let Some(outer_schema) = consumer.get_outer_schema(steps_out) + else { + return substrait_err!( + "OuterReference with steps_out={steps_out} \ + but no outer schema is available" + ); + }; + let (qualifier, field) = + outer_schema.qualified_field(field_idx); + let col = Column::from((qualifier, field)); + Ok(Expr::OuterReferenceColumn(Arc::clone(field), col)) + } + // The root is an arbitrary expression rather + // than a relation's schema. + Some(RootType::Expression(_)) => { + not_impl_err!( + "Expression root type in field reference is not supported" + ) + } + } + } }, _ => not_impl_err!( "Direct reference with types other than StructField is not supported" diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs index 71e3b9e96e153..5d98850c72cca 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs @@ -117,10 +117,7 @@ pub async fn from_substrait_extended_expr( return not_impl_err!("Type variation extensions are not supported"); } - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; + let consumer = DefaultSubstraitConsumer::new(&extensions, state); let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs index 61a381e9eb407..83cf8400eebfc 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -18,13 +18,31 @@ use crate::logical_plan::consumer::SubstraitConsumer; use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err}; use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; -use datafusion::logical_expr::{Expr, Operator, Subquery}; +use datafusion::logical_expr::{Expr, LogicalPlan, Operator, Subquery}; use std::sync::Arc; +use substrait::proto::Rel; use substrait::proto::expression as substrait_expression; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::subquery::set_predicate::PredicateOp; +/// Consume a subquery relation, making the enclosing query's schema +/// available for resolving correlated column references. +/// +/// Substrait represents correlated references using `OuterReference` +/// field references with a `steps_out` depth. To resolve these, +/// the consumer maintains a stack of outer schemas. +async fn consume_subquery_rel( + consumer: &impl SubstraitConsumer, + rel: &Rel, + outer_schema: &DFSchema, +) -> datafusion::common::Result { + consumer.push_outer_schema(Arc::new(outer_schema.clone())); + let result = consumer.consume_rel(rel).await; + consumer.pop_outer_schema(); + result +} + pub async fn from_subquery( consumer: &impl SubstraitConsumer, subquery: &substrait_expression::Subquery, @@ -41,7 +59,9 @@ pub async fn from_subquery( let needle_expr = &in_predicate.needles[0]; let haystack_expr = &in_predicate.haystack; if let Some(haystack_expr) = haystack_expr { - let haystack_expr = consumer.consume_rel(haystack_expr).await?; + let haystack_expr = + consume_subquery_rel(consumer, haystack_expr, input_schema) + .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); Ok(Expr::InSubquery(InSubquery { expr: Box::new( @@ -64,9 +84,12 @@ pub async fn from_subquery( } } SubqueryType::Scalar(query) => { - let plan = consumer - .consume_rel(&(query.input.clone()).unwrap_or_default()) - .await?; + let plan = consume_subquery_rel( + consumer, + &(query.input.clone()).unwrap_or_default(), + input_schema, + ) + .await?; let outer_ref_columns = plan.all_out_ref_exprs(); Ok(Expr::ScalarSubquery(Subquery { subquery: Arc::new(plan), @@ -79,9 +102,12 @@ pub async fn from_subquery( // exist PredicateOp::Exists => { let relation = &predicate.tuples; - let plan = consumer - .consume_rel(&relation.clone().unwrap_or_default()) - .await?; + let plan = consume_subquery_rel( + consumer, + &relation.clone().unwrap_or_default(), + input_schema, + ) + .await?; let outer_ref_columns = plan.all_out_ref_exprs(); Ok(Expr::Exists(Exists::new( Subquery { @@ -131,7 +157,7 @@ pub async fn from_subquery( }; let left_expr = consumer.consume_expression(left, input_schema).await?; - let plan = consumer.consume_rel(right).await?; + let plan = consume_subquery_rel(consumer, right, input_schema).await?; let outer_ref_columns = plan.all_out_ref_exprs(); Ok(Expr::SetComparison(SetComparison::new( diff --git a/datafusion/substrait/src/logical_plan/consumer/plan.rs b/datafusion/substrait/src/logical_plan/consumer/plan.rs index d5e10fb604017..407980c4a7f4b 100644 --- a/datafusion/substrait/src/logical_plan/consumer/plan.rs +++ b/datafusion/substrait/src/logical_plan/consumer/plan.rs @@ -35,10 +35,7 @@ pub async fn from_substrait_plan( return not_impl_err!("Type variation extensions are not supported"); } - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; + let consumer = DefaultSubstraitConsumer::new(&extensions, state); from_substrait_plan_with_consumer(&consumer, plan).await } diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs index a6132e047f7da..b275e523f5861 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs @@ -42,7 +42,8 @@ pub async fn from_exchange_rel( let mut partition_columns = vec![]; let input_schema = input.schema(); for field_ref in &scatter_fields.fields { - let column = from_substrait_field_reference(field_ref, input_schema)?; + let column = + from_substrait_field_reference(consumer, field_ref, input_schema)?; partition_columns.push(column); } Partitioning::Hash(partition_columns, exchange.partition_count as usize) diff --git a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs index 4c19227a30c75..a23f1faed1eb0 100644 --- a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs @@ -31,7 +31,7 @@ use datafusion::common::{ }; use datafusion::execution::{FunctionRegistry, SessionState}; use datafusion::logical_expr::{Expr, Extension, LogicalPlan}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use substrait::proto; use substrait::proto::expression as substrait_expression; use substrait::proto::expression::{ @@ -364,6 +364,26 @@ pub trait SubstraitConsumer: Send + Sync + Sized { not_impl_err!("Dynamic Parameter expression not supported") } + // Outer Schema Stack + // These methods manage a stack of outer schemas for correlated subquery support. + // When entering a subquery, the enclosing query's schema is pushed onto the stack. + // Field references with OuterReference root_type use these to resolve columns. + + /// Push an outer schema onto the stack when entering a subquery. + fn push_outer_schema(&self, _schema: Arc) {} + + /// Pop an outer schema from the stack when leaving a subquery. + fn pop_outer_schema(&self) {} + + /// Get the outer schema at the given nesting depth. + /// `steps_out = 1` is the immediately enclosing query, `steps_out = 2` + /// is two levels out, etc. Returns `None` if `steps_out` is 0 or + /// exceeds the current nesting depth (the caller should treat this as + /// an error in the Substrait plan). + fn get_outer_schema(&self, _steps_out: usize) -> Option> { + None + } + // User-Defined Functionality // The details of extension relations, and how to handle them, are fully up to users to specify. @@ -437,11 +457,16 @@ pub trait SubstraitConsumer: Send + Sync + Sized { pub struct DefaultSubstraitConsumer<'a> { pub(super) extensions: &'a Extensions, pub(super) state: &'a SessionState, + outer_schemas: RwLock>>, } impl<'a> DefaultSubstraitConsumer<'a> { pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { - DefaultSubstraitConsumer { extensions, state } + DefaultSubstraitConsumer { + extensions, + state, + outer_schemas: RwLock::new(Vec::new()), + } } } @@ -465,6 +490,24 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { self.state } + fn push_outer_schema(&self, schema: Arc) { + self.outer_schemas.write().unwrap().push(schema); + } + + fn pop_outer_schema(&self) { + self.outer_schemas.write().unwrap().pop(); + } + + fn get_outer_schema(&self, steps_out: usize) -> Option> { + let schemas = self.outer_schemas.read().unwrap(); + // steps_out=1 → last element, steps_out=2 → second-to-last, etc. + // Returns None for steps_out=0 or steps_out > stack depth. + schemas + .len() + .checked_sub(steps_out) + .and_then(|idx| schemas.get(idx).cloned()) + } + async fn consume_extension_leaf( &self, rel: &ExtensionLeafRel, @@ -520,3 +563,79 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { Ok(LogicalPlan::Extension(Extension { node: plan })) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::utils::tests::test_consumer; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + + fn make_schema(fields: &[(&str, DataType)]) -> Arc { + let arrow_fields: Vec = fields + .iter() + .map(|(name, dt)| Field::new(*name, dt.clone(), true)) + .collect(); + Arc::new( + DFSchema::try_from(Schema::new(arrow_fields)) + .expect("failed to create schema"), + ) + } + + #[test] + fn test_get_outer_schema_empty_stack() { + let consumer = test_consumer(); + + // No schemas pushed — any steps_out should return None + assert!(consumer.get_outer_schema(0).is_none()); + assert!(consumer.get_outer_schema(1).is_none()); + assert!(consumer.get_outer_schema(2).is_none()); + } + + #[test] + fn test_get_outer_schema_single_level() { + let consumer = test_consumer(); + + let schema_a = make_schema(&[("a", DataType::Int64)]); + consumer.push_outer_schema(Arc::clone(&schema_a)); + + // steps_out=1 returns the one pushed schema + let result = consumer.get_outer_schema(1).unwrap(); + assert_eq!(result.fields().len(), 1); + assert_eq!(result.fields()[0].name(), "a"); + + // steps_out=0 and steps_out=2 are out of range + assert!(consumer.get_outer_schema(0).is_none()); + assert!(consumer.get_outer_schema(2).is_none()); + + consumer.pop_outer_schema(); + assert!(consumer.get_outer_schema(1).is_none()); + } + + #[test] + fn test_get_outer_schema_nested() { + let consumer = test_consumer(); + + let schema_a = make_schema(&[("a", DataType::Int64)]); + let schema_b = make_schema(&[("b", DataType::Utf8)]); + + consumer.push_outer_schema(Arc::clone(&schema_a)); + consumer.push_outer_schema(Arc::clone(&schema_b)); + + // steps_out=1 returns the most recent (schema_b) + let result = consumer.get_outer_schema(1).unwrap(); + assert_eq!(result.fields()[0].name(), "b"); + + // steps_out=2 returns the grandparent (schema_a) + let result = consumer.get_outer_schema(2).unwrap(); + assert_eq!(result.fields()[0].name(), "a"); + + // steps_out=3 exceeds depth + assert!(consumer.get_outer_schema(3).is_none()); + + // Pop one level — now steps_out=1 returns schema_a + consumer.pop_outer_schema(); + let result = consumer.get_outer_schema(1).unwrap(); + assert_eq!(result.fields()[0].name(), "a"); + assert!(consumer.get_outer_schema(2).is_none()); + } +} diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 2d814654ba68c..4f585ba9009ab 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -77,7 +77,7 @@ mod tests { Subquery: Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]] Projection: PARTSUPP.PS_SUPPLYCOST - Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") + Filter: outer_ref(PART.P_PARTKEY) = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE") Cross Join: Cross Join: Cross Join: @@ -134,7 +134,7 @@ mod tests { Projection: ORDERS.O_ORDERPRIORITY Filter: ORDERS.O_ORDERDATE >= CAST(Utf8("1993-07-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1993-10-01") AS Date32) AND EXISTS () Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE + Filter: LINEITEM.L_ORDERKEY = outer_ref(ORDERS.O_ORDERKEY) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE TableScan: LINEITEM TableScan: ORDERS "# @@ -425,7 +425,7 @@ mod tests { Projection: Decimal128(Some(5),2,1) * sum(LINEITEM.L_QUANTITY) Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_QUANTITY)]] Projection: LINEITEM.L_QUANTITY - Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) + Filter: LINEITEM.L_PARTKEY = outer_ref(PARTSUPP.PS_PARTKEY) AND LINEITEM.L_SUPPKEY = outer_ref(PARTSUPP.PS_SUPPKEY) AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32) TableScan: LINEITEM TableScan: PARTSUPP Cross Join: @@ -449,10 +449,10 @@ mod tests { Projection: SUPPLIER.S_NAME Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8("F") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("SAUDI ARABIA") Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS + Filter: LINEITEM.L_ORDERKEY = outer_ref(LINEITEM.L_ORDERKEY) AND LINEITEM.L_SUPPKEY != outer_ref(LINEITEM.L_SUPPKEY) TableScan: LINEITEM Subquery: - Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE + Filter: LINEITEM.L_ORDERKEY = outer_ref(LINEITEM.L_ORDERKEY) AND LINEITEM.L_SUPPKEY != outer_ref(LINEITEM.L_SUPPKEY) AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE TableScan: LINEITEM Cross Join: Cross Join: @@ -483,7 +483,7 @@ mod tests { Filter: CUSTOMER.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8)) TableScan: CUSTOMER Subquery: - Filter: ORDERS.O_CUSTKEY = ORDERS.O_ORDERKEY + Filter: ORDERS.O_CUSTKEY = outer_ref(CUSTOMER.C_CUSTKEY) TableScan: ORDERS TableScan: CUSTOMER "# From d0437e0a8d91dd3ec8bca962c5500353417c0bf6 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 19 Feb 2026 17:15:47 -0500 Subject: [PATCH 2/2] Enable Q17, which now parses --- .../tests/cases/consumer_integration.rs | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 4f585ba9009ab..a35fb6bf48e7a 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -353,11 +353,27 @@ mod tests { Ok(()) } - #[ignore] #[tokio::test] async fn tpch_test_17() -> Result<()> { let plan_str = tpch_plan_to_string(17).await?; - assert_snapshot!(plan_str, "panics due to out of bounds field access"); + assert_snapshot!( + plan_str, + @r#" + Projection: sum(LINEITEM.L_EXTENDEDPRICE) / Decimal128(Some(70),2,1) AS AVG_YEARLY + Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE)]] + Projection: LINEITEM.L_EXTENDEDPRICE + Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#23") AND PART.P_CONTAINER = Utf8("MED BOX") AND LINEITEM.L_QUANTITY < () + Subquery: + Projection: Decimal128(Some(2),2,1) * avg(LINEITEM.L_QUANTITY) + Aggregate: groupBy=[[]], aggr=[[avg(LINEITEM.L_QUANTITY)]] + Projection: LINEITEM.L_QUANTITY + Filter: LINEITEM.L_PARTKEY = outer_ref(PART.P_PARTKEY) + TableScan: LINEITEM + Cross Join: + TableScan: LINEITEM + TableScan: PART + "# + ); Ok(()) }