Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@
// under the License.

use crate::logical_plan::consumer::SubstraitConsumer;
use datafusion::common::{Column, DFSchema, not_impl_err};
use datafusion::common::{Column, DFSchema, not_impl_err, substrait_err};
use datafusion::logical_expr::Expr;
use std::sync::Arc;
use substrait::proto::expression::FieldReference;
use substrait::proto::expression::field_reference::ReferenceType::DirectReference;
use substrait::proto::expression::field_reference::RootType;
use substrait::proto::expression::reference_segment::ReferenceType::StructField;

/// Converts a Substrait [`FieldReference`] into a DataFusion [`Expr`],
/// resolving it against `input_schema`.
///
/// This is a thin async wrapper that forwards to
/// [`from_substrait_field_reference`] and returns its result unchanged.
// NOTE(review): this listing comes from a diff view with the +/- markers
// stripped, so the superseded `_consumer` parameter and the two-argument call
// appear next to their replacements — confirm against the real file.
pub async fn from_field_reference(
_consumer: &impl SubstraitConsumer,
consumer: &impl SubstraitConsumer,
field_ref: &FieldReference,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
from_substrait_field_reference(field_ref, input_schema)
// The consumer is passed through so the callee can look up outer schemas.
from_substrait_field_reference(consumer, field_ref, input_schema)
}

pub(crate) fn from_substrait_field_reference(
consumer: &impl SubstraitConsumer,
field_ref: &FieldReference,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
Expand All @@ -40,9 +43,39 @@ pub(crate) fn from_substrait_field_reference(
Some(_) => not_impl_err!(
"Direct reference StructField with child is not supported"
),
None => Ok(Expr::Column(Column::from(
input_schema.qualified_field(x.field as usize),
))),
None => {
let field_idx = x.field as usize;
match &field_ref.root_type {
// Normal reference: resolve against the current
// relation's schema.
Some(RootType::RootReference(_)) | None => Ok(Expr::Column(
Column::from(input_schema.qualified_field(field_idx)),
)),
// Correlated reference: resolve against the
// enclosing query's schema.
Some(RootType::OuterReference(outer_ref)) => {
let steps_out = outer_ref.steps_out as usize;
let Some(outer_schema) = consumer.get_outer_schema(steps_out)
else {
return substrait_err!(
"OuterReference with steps_out={steps_out} \
but no outer schema is available"
);
};
let (qualifier, field) =
outer_schema.qualified_field(field_idx);
let col = Column::from((qualifier, field));
Ok(Expr::OuterReferenceColumn(Arc::clone(field), col))
}
// The root is an arbitrary expression rather
// than a relation's schema.
Some(RootType::Expression(_)) => {
not_impl_err!(
"Expression root type in field reference is not supported"
)
}
}
}
},
_ => not_impl_err!(
"Direct reference with types other than StructField is not supported"
Expand Down
5 changes: 1 addition & 4 deletions datafusion/substrait/src/logical_plan/consumer/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,7 @@ pub async fn from_substrait_extended_expr(
return not_impl_err!("Type variation extensions are not supported");
}

let consumer = DefaultSubstraitConsumer {
extensions: &extensions,
state,
};
let consumer = DefaultSubstraitConsumer::new(&extensions, state);

let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
Expand Down
44 changes: 35 additions & 9 deletions datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,31 @@
use crate::logical_plan::consumer::SubstraitConsumer;
use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err};
use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier};
use datafusion::logical_expr::{Expr, Operator, Subquery};
use datafusion::logical_expr::{Expr, LogicalPlan, Operator, Subquery};
use std::sync::Arc;
use substrait::proto::Rel;
use substrait::proto::expression as substrait_expression;
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp};
use substrait::proto::expression::subquery::set_predicate::PredicateOp;

/// Consume a subquery relation, making the enclosing query's schema
/// available for resolving correlated column references.
///
/// Substrait represents correlated references using `OuterReference`
/// field references with a `steps_out` depth. To resolve these,
/// the consumer maintains a stack of outer schemas.
///
/// `outer_schema` is pushed onto the consumer's stack before `rel` is
/// consumed and popped again when this function finishes. The pop is done
/// by a drop guard, so the stack stays balanced even if the returned future
/// is dropped while suspended or `consume_rel` unwinds.
///
/// # Errors
///
/// Returns whatever error `consumer.consume_rel(rel)` produces.
async fn consume_subquery_rel(
    consumer: &impl SubstraitConsumer,
    rel: &Rel,
    outer_schema: &DFSchema,
) -> datafusion::common::Result<LogicalPlan> {
    /// Pops the consumer's outer-schema stack when dropped.
    // The bound lives on the struct because a `Drop` impl may not add
    // bounds the type itself does not declare.
    struct OuterSchemaGuard<'a, C: SubstraitConsumer>(&'a C);

    impl<C: SubstraitConsumer> Drop for OuterSchemaGuard<'_, C> {
        fn drop(&mut self) {
            self.0.pop_outer_schema();
        }
    }

    consumer.push_outer_schema(Arc::new(outer_schema.clone()));
    // Created right after the push so every exit path performs the pop.
    let _pop_on_exit = OuterSchemaGuard(consumer);
    consumer.consume_rel(rel).await
}

pub async fn from_subquery(
consumer: &impl SubstraitConsumer,
subquery: &substrait_expression::Subquery,
Expand All @@ -41,7 +59,9 @@ pub async fn from_subquery(
let needle_expr = &in_predicate.needles[0];
let haystack_expr = &in_predicate.haystack;
if let Some(haystack_expr) = haystack_expr {
let haystack_expr = consumer.consume_rel(haystack_expr).await?;
let haystack_expr =
consume_subquery_rel(consumer, haystack_expr, input_schema)
.await?;
let outer_refs = haystack_expr.all_out_ref_exprs();
Ok(Expr::InSubquery(InSubquery {
expr: Box::new(
Expand All @@ -64,9 +84,12 @@ pub async fn from_subquery(
}
}
SubqueryType::Scalar(query) => {
let plan = consumer
.consume_rel(&(query.input.clone()).unwrap_or_default())
.await?;
let plan = consume_subquery_rel(
consumer,
&(query.input.clone()).unwrap_or_default(),
input_schema,
)
.await?;
let outer_ref_columns = plan.all_out_ref_exprs();
Ok(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(plan),
Expand All @@ -79,9 +102,12 @@ pub async fn from_subquery(
// exist
PredicateOp::Exists => {
let relation = &predicate.tuples;
let plan = consumer
.consume_rel(&relation.clone().unwrap_or_default())
.await?;
let plan = consume_subquery_rel(
consumer,
&relation.clone().unwrap_or_default(),
input_schema,
)
.await?;
let outer_ref_columns = plan.all_out_ref_exprs();
Ok(Expr::Exists(Exists::new(
Subquery {
Expand Down Expand Up @@ -131,7 +157,7 @@ pub async fn from_subquery(
};

let left_expr = consumer.consume_expression(left, input_schema).await?;
let plan = consumer.consume_rel(right).await?;
let plan = consume_subquery_rel(consumer, right, input_schema).await?;
let outer_ref_columns = plan.all_out_ref_exprs();

Ok(Expr::SetComparison(SetComparison::new(
Expand Down
5 changes: 1 addition & 4 deletions datafusion/substrait/src/logical_plan/consumer/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ pub async fn from_substrait_plan(
return not_impl_err!("Type variation extensions are not supported");
}

let consumer = DefaultSubstraitConsumer {
extensions: &extensions,
state,
};
let consumer = DefaultSubstraitConsumer::new(&extensions, state);
from_substrait_plan_with_consumer(&consumer, plan).await
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ pub async fn from_exchange_rel(
let mut partition_columns = vec![];
let input_schema = input.schema();
for field_ref in &scatter_fields.fields {
let column = from_substrait_field_reference(field_ref, input_schema)?;
let column =
from_substrait_field_reference(consumer, field_ref, input_schema)?;
partition_columns.push(column);
}
Partitioning::Hash(partition_columns, exchange.partition_count as usize)
Expand Down
123 changes: 121 additions & 2 deletions datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion::common::{
};
use datafusion::execution::{FunctionRegistry, SessionState};
use datafusion::logical_expr::{Expr, Extension, LogicalPlan};
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use substrait::proto;
use substrait::proto::expression as substrait_expression;
use substrait::proto::expression::{
Expand Down Expand Up @@ -364,6 +364,26 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
not_impl_err!("Dynamic Parameter expression not supported")
}

// Outer Schema Stack
// These methods manage a stack of outer schemas for correlated subquery support.
// When entering a subquery, the enclosing query's schema is pushed onto the stack.
// Field references with OuterReference root_type use these to resolve columns.

/// Push an outer schema onto the stack when entering a subquery.
///
/// The default implementation is a no-op that ignores `_schema`.
fn push_outer_schema(&self, _schema: Arc<DFSchema>) {}

/// Pop an outer schema from the stack when leaving a subquery.
///
/// The default implementation is a no-op.
fn pop_outer_schema(&self) {}

/// Get the outer schema at the given nesting depth.
/// `steps_out = 1` is the immediately enclosing query, `steps_out = 2`
/// is two levels out, etc. Returns `None` if `steps_out` is 0 or
/// exceeds the current nesting depth (the caller should treat this as
/// an error in the Substrait plan).
///
/// The default implementation keeps no stack and always returns `None`.
fn get_outer_schema(&self, _steps_out: usize) -> Option<Arc<DFSchema>> {
None
}

// User-Defined Functionality

// The details of extension relations, and how to handle them, are fully up to users to specify.
Expand Down Expand Up @@ -437,11 +457,16 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
pub struct DefaultSubstraitConsumer<'a> {
pub(super) extensions: &'a Extensions,
pub(super) state: &'a SessionState,
/// Stack of enclosing-query schemas backing `push_outer_schema`,
/// `pop_outer_schema` and `get_outer_schema`; the most recently pushed
/// schema is the last element. It sits behind a lock because those
/// methods take `&self`.
outer_schemas: RwLock<Vec<Arc<DFSchema>>>,
}

impl<'a> DefaultSubstraitConsumer<'a> {
/// Creates a consumer that borrows `extensions` and `state` and starts
/// with an empty outer-schema stack.
// NOTE(review): this listing comes from a diff view with the +/- markers
// stripped, so the superseded two-field struct expression appears directly
// above its three-field replacement — confirm against the real file.
pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self {
DefaultSubstraitConsumer { extensions, state }
DefaultSubstraitConsumer {
extensions,
state,
outer_schemas: RwLock::new(Vec::new()),
}
}
}

Expand All @@ -465,6 +490,24 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
self.state
}

/// Places `schema` on top of the outer-schema stack.
///
/// # Panics
///
/// Panics if the lock guarding the stack is poisoned.
fn push_outer_schema(&self, schema: Arc<DFSchema>) {
    let mut stack = self.outer_schemas.write().unwrap();
    stack.push(schema);
}

/// Removes the top entry of the outer-schema stack, if there is one.
///
/// # Panics
///
/// Panics if the lock guarding the stack is poisoned.
fn pop_outer_schema(&self) {
    let mut stack = self.outer_schemas.write().unwrap();
    // The removed schema is not needed; an empty stack is left untouched.
    let _ = stack.pop();
}

/// Looks up the schema `steps_out` levels above the current subquery.
///
/// `steps_out = 1` yields the most recently pushed schema, `steps_out = 2`
/// the one pushed before it, and so on. Returns `None` when `steps_out` is 0
/// or larger than the number of schemas currently on the stack.
///
/// # Panics
///
/// Panics if the lock guarding the stack is poisoned.
fn get_outer_schema(&self, steps_out: usize) -> Option<Arc<DFSchema>> {
    let stack = self.outer_schemas.read().unwrap();
    // Convert the 1-based depth into a 0-based offset from the top of the
    // stack; `steps_out == 0` has no such offset.
    let offset_from_top = steps_out.checked_sub(1)?;
    stack.iter().rev().nth(offset_from_top).cloned()
}

async fn consume_extension_leaf(
&self,
rel: &ExtensionLeafRel,
Expand Down Expand Up @@ -520,3 +563,79 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
Ok(LogicalPlan::Extension(Extension { node: plan }))
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::consumer::utils::tests::test_consumer;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    /// Builds a shared [`DFSchema`] of nullable fields from
    /// `(name, data type)` pairs.
    fn make_schema(fields: &[(&str, DataType)]) -> Arc<DFSchema> {
        let mut arrow_fields = Vec::with_capacity(fields.len());
        for (name, data_type) in fields {
            arrow_fields.push(Field::new(*name, data_type.clone(), true));
        }
        let df_schema = DFSchema::try_from(Schema::new(arrow_fields))
            .expect("failed to create schema");
        Arc::new(df_schema)
    }

    /// With nothing pushed, every depth is out of range.
    #[test]
    fn test_get_outer_schema_empty_stack() {
        let consumer = test_consumer();

        for steps_out in 0..=2 {
            assert!(consumer.get_outer_schema(steps_out).is_none());
        }
    }

    /// A single pushed schema is visible only at depth 1, and disappears
    /// again after it is popped.
    #[test]
    fn test_get_outer_schema_single_level() {
        let consumer = test_consumer();

        let only_schema = make_schema(&[("a", DataType::Int64)]);
        consumer.push_outer_schema(Arc::clone(&only_schema));

        // Depth 1 is the schema that was just pushed.
        let parent = consumer.get_outer_schema(1).unwrap();
        assert_eq!(parent.fields().len(), 1);
        assert_eq!(parent.fields()[0].name(), "a");

        // Depths 0 and 2 fall outside the one-element stack.
        assert!(consumer.get_outer_schema(0).is_none());
        assert!(consumer.get_outer_schema(2).is_none());

        consumer.pop_outer_schema();
        assert!(consumer.get_outer_schema(1).is_none());
    }

    /// With two schemas pushed, depth counts outward from the newest one.
    #[test]
    fn test_get_outer_schema_nested() {
        let consumer = test_consumer();

        let outermost = make_schema(&[("a", DataType::Int64)]);
        let innermost = make_schema(&[("b", DataType::Utf8)]);

        consumer.push_outer_schema(Arc::clone(&outermost));
        consumer.push_outer_schema(Arc::clone(&innermost));

        // Depth 1 is the most recently pushed schema.
        let parent = consumer.get_outer_schema(1).unwrap();
        assert_eq!(parent.fields()[0].name(), "b");

        // Depth 2 is the schema pushed before it.
        let grandparent = consumer.get_outer_schema(2).unwrap();
        assert_eq!(grandparent.fields()[0].name(), "a");

        // Depth 3 is deeper than the stack.
        assert!(consumer.get_outer_schema(3).is_none());

        // After one pop, the older schema becomes depth 1.
        consumer.pop_outer_schema();
        let parent = consumer.get_outer_schema(1).unwrap();
        assert_eq!(parent.fields()[0].name(), "a");
        assert!(consumer.get_outer_schema(2).is_none());
    }
}
Loading