diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index d285f8b377eca..3fc865a828ef8 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -327,6 +327,14 @@ impl DynamicFilterPhysicalExpr { Arc::strong_count(self) > 1 || Arc::strong_count(&self.inner) > 1 } + /// Returns a unique identifier for the inner shared state. + /// + /// Useful for checking if two [Arc] with the same + /// underlying [DynamicFilterPhysicalExpr] are the same. + pub fn inner_id(&self) -> u64 { + Arc::as_ptr(&self.inner) as *const () as u64 + } + fn render( &self, f: &mut std::fmt::Formatter<'_>, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7c0268867691e..9e8f6f8658bc5 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -860,6 +860,12 @@ message PhysicalExprNode { // across serde roundtrips. optional uint64 expr_id = 30; + // For DynamicFilterPhysicalExpr, this identifies the shared inner state. + // Multiple expressions may have different expr_id values (different outer Arc wrappers) + // but the same dynamic_filter_inner_id (shared inner state). + // Used to reconstruct shared inner state during deserialization. + optional uint64 dynamic_filter_inner_id = 31; + oneof ExprType { // column references PhysicalColumn column = 1; @@ -897,9 +903,16 @@ message PhysicalExprNode { UnknownColumn unknown_column = 20; PhysicalHashExprNode hash_expr = 21; + + PhysicalDynamicFilterNode dynamic_filter = 22; } } +message PhysicalDynamicFilterNode { + repeated PhysicalExprNode children = 1; + PhysicalExprNode initial_expr = 2; +} + message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 5b2b9133ce13a..c2019693b8084 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -16225,6 +16225,115 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { deserializer.deserialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalDynamicFilterNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.children.is_empty() { + len += 1; + } + if self.initial_expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalDynamicFilterNode", len)?; + if !self.children.is_empty() { + struct_ser.serialize_field("children", &self.children)?; + } + if let Some(v) = self.initial_expr.as_ref() { + struct_ser.serialize_field("initialExpr", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalDynamicFilterNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "children", + "initial_expr", + "initialExpr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Children, + InitialExpr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "children" => Ok(GeneratedField::Children), + "initialExpr" | "initial_expr" => Ok(GeneratedField::InitialExpr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalDynamicFilterNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalDynamicFilterNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut children__ = None; + let mut initial_expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Children => { + if children__.is_some() { + return Err(serde::de::Error::duplicate_field("children")); + } + children__ = Some(map_.next_value()?); + } + GeneratedField::InitialExpr => { + if initial_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("initialExpr")); + } + initial_expr__ = map_.next_value()?; + } + } + } + Ok(PhysicalDynamicFilterNode { + children: children__.unwrap_or_default(), + initial_expr: initial_expr__, + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalDynamicFilterNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -16236,6 +16345,9 @@ impl serde::Serialize for PhysicalExprNode { if self.expr_id.is_some() { len += 1; } + if self.dynamic_filter_inner_id.is_some() { + len += 1; + } if self.expr_type.is_some() { len += 1; } @@ -16245,6 +16357,11 @@ impl serde::Serialize for PhysicalExprNode { #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("exprId", ToString::to_string(&v).as_str())?; } + if let Some(v) = self.dynamic_filter_inner_id.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("dynamicFilterInnerId", ToString::to_string(&v).as_str())?; + } if let Some(v) = self.expr_type.as_ref() { match v { physical_expr_node::ExprType::Column(v) => { @@ -16304,6 +16421,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::HashExpr(v) => { struct_ser.serialize_field("hashExpr", v)?; } + physical_expr_node::ExprType::DynamicFilter(v) => { + struct_ser.serialize_field("dynamicFilter", v)?; + } } } struct_ser.end() @@ -16318,6 +16438,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { const FIELDS: &[&str] = &[ "expr_id", "exprId", + "dynamic_filter_inner_id", + "dynamicFilterInnerId", "column", "literal", "binary_expr", @@ -16350,11 +16472,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "unknownColumn", "hash_expr", "hashExpr", + "dynamic_filter", + "dynamicFilter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { ExprId, + DynamicFilterInnerId, Column, Literal, BinaryExpr, @@ -16374,6 +16499,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { Extension, UnknownColumn, HashExpr, + DynamicFilter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16396,6 +16522,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { { match value { "exprId" | "expr_id" => Ok(GeneratedField::ExprId), + "dynamicFilterInnerId" | "dynamic_filter_inner_id" => Ok(GeneratedField::DynamicFilterInnerId), "column" => Ok(GeneratedField::Column), "literal" => Ok(GeneratedField::Literal), "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), @@ -16415,6 +16542,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "extension" => Ok(GeneratedField::Extension), "unknownColumn" | "unknown_column" => Ok(GeneratedField::UnknownColumn), "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16435,6 +16563,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { V: serde::de::MapAccess<'de>, { let mut expr_id__ = None; + let mut dynamic_filter_inner_id__ = None; let mut expr_type__ = None; while let Some(k) = map_.next_key()? { match k { @@ -16446,6 +16575,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::DynamicFilterInnerId => { + if dynamic_filter_inner_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilterInnerId")); + } + dynamic_filter_inner_id__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::Column => { if expr_type__.is_some() { return Err(serde::de::Error::duplicate_field("column")); @@ -16577,12 +16714,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("hashExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::HashExpr) +; + } + GeneratedField::DynamicFilter => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dynamicFilter")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::DynamicFilter) ; } } } Ok(PhysicalExprNode { expr_id: expr_id__, + dynamic_filter_inner_id: dynamic_filter_inner_id__, expr_type: expr_type__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d9602665c284a..adfef2cd61999 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1292,9 +1292,15 @@ pub struct PhysicalExprNode { /// across serde roundtrips. #[prost(uint64, optional, tag = "30")] pub expr_id: ::core::option::Option, + /// For DynamicFilterPhysicalExpr, this identifies the shared inner state. + /// Multiple expressions may have different expr_id values (different outer Arc wrappers) + /// but the same dynamic_filter_inner_id (shared inner state). + /// Used to reconstruct shared inner state during deserialization. + #[prost(uint64, optional, tag = "31")] + pub dynamic_filter_inner_id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21, 22" )] pub expr_type: ::core::option::Option, } @@ -1347,9 +1353,20 @@ pub mod physical_expr_node { UnknownColumn(super::UnknownColumn), #[prost(message, tag = "21")] HashExpr(super::PhysicalHashExprNode), + #[prost(message, tag = "22")] + DynamicFilter(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalDynamicFilterNode { + #[prost(message, repeated, tag = "1")] + pub children: ::prost::alloc::vec::Vec, + #[prost(message, optional, boxed, tag = "2")] + pub initial_expr: ::core::option::Option< + ::prost::alloc::boxed::Box, + >, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalScalarUdfNode { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index e424be162648b..6f63012f97c51 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -58,6 +58,7 @@ use super::{ use crate::logical_plan::{self}; use crate::protobuf::physical_expr_node::ExprType; use crate::{convert_required, protobuf}; +use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -495,6 +496,27 @@ pub fn parse_physical_expr_with_converter( hash_expr.description.clone(), )) } + ExprType::DynamicFilter(dynamic_filter) => { + let children = parse_physical_exprs( + &dynamic_filter.children, + ctx, + input_schema, + codec, + proto_converter, + )?; + + let initial_expr = parse_required_physical_expr( + dynamic_filter.initial_expr.as_deref(), + ctx, + "initial_expr", + input_schema, + codec, + proto_converter, + )?; + + // Constructor signature is: new(children, inner) + Arc::new(DynamicFilterPhysicalExpr::new(children, initial_expr)) + } ExprType::Extension(extension) => { let inputs: Vec> = extension .inputs diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index bfba715b91249..3a196808e228d 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -52,6 +52,7 @@ use datafusion_functions_table::generate_series::{ }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr; use datafusion_physical_expr::{LexOrdering, LexRequirement, PhysicalExprRef}; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, @@ -3064,6 +3065,7 @@ impl protobuf::PhysicalPlanNode { }); Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -3150,6 +3152,7 @@ impl protobuf::PhysicalPlanNode { }); Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -3818,6 +3821,18 @@ impl DeduplicatingSerializer { session_id: rand::random(), } } + + fn hash(&self, ptr: u64) -> u64 { + // Hash session_id, pointer address, and process ID together to create expr_id. + // - session_id: random per serializer, prevents collisions when merging serializations + // - ptr: unique address per Arc within a process + // - pid: prevents collisions if serializer is shared across processes + let mut hasher = DefaultHasher::new(); + self.session_id.hash(&mut hasher); + ptr.hash(&mut hasher); + std::process::id().hash(&mut hasher); + hasher.finish() + } } impl PhysicalProtoConverterExtension for DeduplicatingSerializer { @@ -3864,16 +3879,14 @@ impl PhysicalProtoConverterExtension for DeduplicatingSerializer { codec: &dyn PhysicalExtensionCodec, ) -> Result { let mut proto = serialize_physical_expr_with_converter(expr, codec, self)?; - - // Hash session_id, pointer address, and process ID together to create expr_id. - // - session_id: random per serializer, prevents collisions when merging serializations - // - ptr: unique address per Arc within a process - // - pid: prevents collisions if serializer is shared across processes - let mut hasher = DefaultHasher::new(); - self.session_id.hash(&mut hasher); - (Arc::as_ptr(expr) as *const () as u64).hash(&mut hasher); - std::process::id().hash(&mut hasher); - proto.expr_id = Some(hasher.finish()); + // Special case for dynamic filters. Two expressions may live in separate Arcs but + // point to the same inner dynamic filter state. This inner state must be deduplicated. + if let Some(dynamic_filter) = + expr.as_any().downcast_ref::() + { + proto.dynamic_filter_inner_id = Some(self.hash(dynamic_filter.inner_id())) + } + proto.expr_id = Some(self.hash(Arc::as_ptr(expr) as *const () as u64)); Ok(proto) } @@ -3885,6 +3898,10 @@ impl PhysicalProtoConverterExtension for DeduplicatingSerializer { struct DeduplicatingDeserializer { /// Cache mapping expr_id to deserialized expressions. cache: RefCell>>, + /// Cache mapping dynamic_filter_inner_id to the first deserialized DynamicFilterPhysicalExpr. + /// This ensures that multiple dynamic filters with the same dynamic_filter_inner_id + /// can share the same inner state after deserialization. + dynamic_filter_cache: RefCell>>, } impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { @@ -3918,12 +3935,52 @@ impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { where Self: Sized, { + // First check the regular cache by expr_id (same outer Arc) if let Some(expr_id) = proto.expr_id { - // Check cache first if let Some(cached) = self.cache.borrow().get(&expr_id) { return Ok(Arc::clone(cached)); } - // Deserialize and cache + } + + // Check if we need to share inner state with a cached dynamic filter + if let Some(dynamic_filter_id) = proto.dynamic_filter_inner_id { + if let Some(cached_filter) = + self.dynamic_filter_cache.borrow().get(&dynamic_filter_id) + { + // We have a cached filter with the same dynamic_filter_inner_id + // Deserialize to get the new children, then create a new Arc with shared inner state + let expr = parse_physical_expr_with_converter( + proto, + ctx, + input_schema, + codec, + self, + )?; + + // Get the children from the newly deserialized expression + if let Some(new_df) = + expr.as_any().downcast_ref::() + { + let new_children: Vec> = + new_df.children().into_iter().cloned().collect(); + // Create a new Arc with the cached filter's inner state but new children + let expr_with_shared_inner = + Arc::clone(cached_filter).with_new_children(new_children)?; + + // Cache by expr_id if present + if let Some(expr_id) = proto.expr_id { + self.cache + .borrow_mut() + .insert(expr_id, Arc::clone(&expr_with_shared_inner)); + } + + return Ok(expr_with_shared_inner); + } + } + } + + // Normal deserialization path + let expr = if let Some(expr_id) = proto.expr_id { let expr = parse_physical_expr_with_converter( proto, ctx, @@ -3932,10 +3989,24 @@ impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { self, )?; self.cache.borrow_mut().insert(expr_id, Arc::clone(&expr)); - Ok(expr) + expr } else { - parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) - } + parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self)? + }; + + // If this is a dynamic filter, cache it by dynamic_filter_inner_id + if let Some(dynamic_filter_id) = proto.dynamic_filter_inner_id { + if expr + .as_any() + .downcast_ref::() + .is_some() + { + self.dynamic_filter_cache + .borrow_mut() + .insert(dynamic_filter_id, Arc::clone(&expr)); + } + }; + Ok(expr) } fn physical_expr_to_proto( diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index a38e59acdab26..9c5e335bceff3 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -36,8 +36,9 @@ use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindo use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + BinaryExpr, CaseExpr, CastExpr, Column, DynamicFilterPhysicalExpr, InListExpr, + IsNotNullExpr, IsNullExpr, LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, + UnKnownColumn, }; use datafusion_physical_plan::joins::{HashExpr, HashTableLookupExpr}; use datafusion_physical_plan::udaf::AggregateFunctionExpr; @@ -72,6 +73,7 @@ pub fn serialize_physical_aggr_expr( codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), @@ -256,6 +258,29 @@ pub fn serialize_physical_expr_with_converter( codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { + // Check for DynamicFilterPhysicalExpr before snapshotting + if let Some(df) = value.as_any().downcast_ref::() { + let children = df + .children() + .iter() + .map(|child| proto_converter.physical_expr_to_proto(child, codec)) + .collect::>>()?; + + let current_expr = + Box::new(proto_converter.physical_expr_to_proto(&df.current()?, codec)?); + + return Ok(protobuf::PhysicalExprNode { + expr_id: None, + dynamic_filter_inner_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::DynamicFilter( + Box::new(protobuf::PhysicalDynamicFilterNode { + children, + initial_expr: Some(current_expr), + }), + )), + }); + } + // Snapshot the expr in case it has dynamic predicate state so // it can be serialized let value = snapshot_physical_expr(Arc::clone(value))?; @@ -282,6 +307,7 @@ pub fn serialize_physical_expr_with_converter( }; return Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Literal(value)), }); } @@ -289,6 +315,7 @@ pub fn serialize_physical_expr_with_converter( if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Column( protobuf::PhysicalColumn { name: expr.name().to_string(), @@ -299,6 +326,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::UnknownColumn( protobuf::UnknownColumn { name: expr.name().to_string(), @@ -318,6 +346,7 @@ pub fn serialize_physical_expr_with_converter( Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( binary_expr, )), @@ -325,6 +354,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some( protobuf::physical_expr_node::ExprType::Case( Box::new( @@ -368,6 +398,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( protobuf::PhysicalNot { expr: Some(Box::new( @@ -379,6 +410,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( Box::new(protobuf::PhysicalIsNull { expr: Some(Box::new( @@ -390,6 +422,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( Box::new(protobuf::PhysicalIsNotNull { expr: Some(Box::new( @@ -401,6 +434,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( protobuf::PhysicalInListNode { expr: Some(Box::new( @@ -414,6 +448,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( protobuf::PhysicalNegativeNode { expr: Some(Box::new( @@ -425,6 +460,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(lit) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( lit.value().try_into()?, )), @@ -432,6 +468,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( protobuf::PhysicalCastNode { expr: Some(Box::new( @@ -444,6 +481,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(cast) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( protobuf::PhysicalTryCastNode { expr: Some(Box::new( @@ -458,6 +496,7 @@ pub fn serialize_physical_expr_with_converter( codec.try_encode_udf(expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( protobuf::PhysicalScalarUdfNode { name: expr.name().to_string(), @@ -475,6 +514,7 @@ pub fn serialize_physical_expr_with_converter( } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( protobuf::PhysicalLikeExprNode { negated: expr.negated(), @@ -492,6 +532,7 @@ pub fn serialize_physical_expr_with_converter( let (s0, s1, s2, s3) = expr.seeds(); Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( protobuf::PhysicalHashExprNode { on_columns: serialize_physical_exprs( @@ -518,6 +559,7 @@ pub fn serialize_physical_expr_with_converter( .collect::>()?; Ok(protobuf::PhysicalExprNode { expr_id: None, + dynamic_filter_inner_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::Extension( protobuf::PhysicalExtensionExprNode { expr: buf, inputs }, )), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index bc310150d8982..fcc5cfb55185a 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -121,6 +121,7 @@ use datafusion_proto::physical_plan::{ PhysicalProtoConverterExtension, }; use datafusion_proto::protobuf; +use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType; use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; use prost::Message; @@ -129,6 +130,9 @@ use crate::cases::{ MyRegexUdfNode, }; +use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr; +use datafusion_physical_expr::utils::reassign_expr_columns; + /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is /// lost during serde because the string representation of a plan often only shows a subset of state. @@ -2751,6 +2755,7 @@ fn test_backward_compatibility_no_expr_id() -> Result<()> { // Manually create a proto without expr_id set let proto = PhysicalExprNode { expr_id: None, // Simulating old proto without this field + dynamic_filter_inner_id: None, expr_type: Some( datafusion_proto::protobuf::physical_expr_node::ExprType::Column( datafusion_proto::protobuf::PhysicalColumn { @@ -2949,6 +2954,210 @@ fn test_deduplication_within_expr_deserialization() -> Result<()> { Ok(()) } +#[test] +fn test_dynamic_filters_different_filter_same_inner_state() { + let filter_expr_1 = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + // Column "a" is now at index 1, which creates a new filter. + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int64, false), + Field::new("a", DataType::Int64, false), + ])); + let filter_expr_2 = + reassign_expr_columns(Arc::clone(&filter_expr_1), &schema).unwrap(); + + // Meta-assertion: ensure this test is testing the case where the inner state is the same but + // the exprs are different + let (outer_equal, inner_equal) = + dynamic_filter_outer_inner_equal(&filter_expr_1, &filter_expr_2); + assert!(!outer_equal); + assert!(inner_equal); + test_deduplication_of_dynamic_filter_expression(filter_expr_1, filter_expr_2, schema) + .unwrap(); +} + +#[test] +fn test_dynamic_filters_same_filter() { + let filter_expr_1 = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + let filter_expr_2 = Arc::clone(&filter_expr_1); + + // Ensure this test is testing the case where the inner state is the same and the exprs are the same + let (outer_equal, inner_equal) = + dynamic_filter_outer_inner_equal(&filter_expr_1, &filter_expr_2); + assert!(outer_equal); + assert!(inner_equal); + test_deduplication_of_dynamic_filter_expression(filter_expr_1, filter_expr_2, schema) + .unwrap(); +} + +#[test] +fn test_dynamic_filters_different_filter() { + let filter_expr_1 = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + let filter_expr_2 = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + + // Ensure this test is testing the case where the inner state is the different and the outer exprs are different + let (outer_equal, inner_equal) = + dynamic_filter_outer_inner_equal(&filter_expr_1, &filter_expr_2); + assert!(!outer_equal); + assert!(!inner_equal); + test_deduplication_of_dynamic_filter_expression(filter_expr_1, filter_expr_2, schema) + .unwrap(); +} + +/// Returns (outer_equal, inner_equal) +/// +/// outer_equal is true if the two arcs point to the same data. +/// inner_equal is true if the two dynamic filters have the same inner state +fn dynamic_filter_outer_inner_equal( + filter_expr_1: &Arc, + filter_expr_2: &Arc, +) -> (bool, bool) { + ( + std::ptr::addr_eq(Arc::as_ptr(filter_expr_1), Arc::as_ptr(filter_expr_2)), + filter_expr_1 + .as_any() + .downcast_ref::() + .unwrap() + .inner_id() + == filter_expr_2 + .as_any() + .downcast_ref::() + .unwrap() + .inner_id(), + ) +} + +fn test_deduplication_of_dynamic_filter_expression( + filter_expr_1: Arc, + filter_expr_2: Arc, + schema: Arc, +) -> Result<()> { + let (outer_equal, inner_equal) = + dynamic_filter_outer_inner_equal(&filter_expr_1, &filter_expr_2); + + // Create execution plan: FilterExec(filter2) -> FilterExec(filter1) -> EmptyExec + let empty_exec = Arc::new(EmptyExec::new(schema)) as Arc; + let filter_exec1 = Arc::new(FilterExec::try_new(filter_expr_1, empty_exec)?) + as Arc; + let filter_exec2 = Arc::new(FilterExec::try_new(filter_expr_2, filter_exec1)?) + as Arc; + + // Serialize the plan + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let proto = converter.execution_plan_to_proto(&filter_exec2, &codec)?; + + let outer_filter = match &proto.physical_plan_type { + Some(PhysicalPlanType::Filter(outer_filter)) => outer_filter, + _ => panic!("Expected PhysicalPlanType::Filter"), + }; + + let inner_filter = match &outer_filter.input { + Some(inner_input) => match &inner_input.physical_plan_type { + Some(PhysicalPlanType::Filter(inner_filter)) => inner_filter, + _ => panic!("Expected PhysicalPlanType::Filter"), + }, + _ => panic!("Expected inner input"), + }; + + let filter1_proto = inner_filter + .expr + .as_ref() + .expect("Should have filter expression"); + + let filter2_proto = outer_filter + .expr + .as_ref() + .expect("Should have filter expression"); + + // Both should have dynamic_filter_inner_id set + let filter1_dynamic_id = filter1_proto + .dynamic_filter_inner_id + .expect("Filter1 should have dynamic_filter_inner_id"); + let filter2_dynamic_id = filter2_proto + .dynamic_filter_inner_id + .expect("Filter2 should have dynamic_filter_inner_id"); + + assert_eq!( + inner_equal, + filter1_dynamic_id == filter2_dynamic_id, + "Dynamic filters sharing the same inner state should have the same dynamic_filter_inner_id" + ); + + let filter1_expr_id = filter1_proto.expr_id.expect("Should have expr_id"); + let filter2_expr_id = filter2_proto.expr_id.expect("Should have expr_id"); + assert_eq!( + outer_equal, + filter1_expr_id == filter2_expr_id, + "Different filters have different expr ids" + ); + + // Test deserialization - verify that filters with same dynamic_filter_inner_id share state + let ctx = SessionContext::new(); + let deserialized_plan = + converter.proto_to_execution_plan(ctx.task_ctx().as_ref(), &codec, &proto)?; + + // Extract the two filter expressions from the deserialized plan + let outer_filter = deserialized_plan + .as_any() + .downcast_ref::() + .expect("Should be FilterExec"); + let filter2_deserialized = outer_filter.predicate(); + + let inner_filter = outer_filter.children()[0] + .as_any() + .downcast_ref::() + .expect("Inner should be FilterExec"); + let filter1_deserialized = inner_filter.predicate(); + + // The Arcs should be different (different outer wrappers) + assert_eq!( + outer_equal, + Arc::ptr_eq(filter1_deserialized, filter2_deserialized), + "Deserialized filters should be different Arcs" + ); + + // Check if they're DynamicFilterPhysicalExpr (they might be snapshotted to Literal) + let (df1, df2) = match ( + filter1_deserialized + .as_any() + .downcast_ref::(), + filter2_deserialized + .as_any() + .downcast_ref::(), + ) { + (Some(df1), Some(df2)) => (df1, df2), + _ => panic!("Should be DynamicFilterPhysicalExpr"), + }; + + // But they should have the same inner_id (shared inner state) + assert_eq!( + inner_equal, + df1.inner_id() == df2.inner_id(), + "Deserialized filters should share inner state" + ); + + Ok(()) +} + /// Test that session_id rotates between top-level serialization operations. /// This verifies that each top-level serialization gets a fresh session_id, /// which prevents cross-process collisions when serialized plans are merged.