diff --git a/Cargo.lock b/Cargo.lock index 66aef04c9239..011155b77e8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2360,6 +2360,7 @@ dependencies = [ "datafusion-functions", "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", + "datafusion-proto-common", "datafusion-proto-models", "half", "hashbrown 0.17.1", diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 65ef2a3ceb21..7c4a83ba6288 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -46,6 +46,7 @@ recursive_protection = ["dep:recursive"] # `PhysicalExpr::to_proto` and letting expressions in this crate implement it. proto = [ "dep:datafusion-proto-models", + "dep:datafusion-proto-common", "datafusion-physical-expr-common/proto", ] @@ -56,6 +57,7 @@ datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } +datafusion-proto-common = { workspace = true, optional = true } datafusion-proto-models = { workspace = true, optional = true } hashbrown = { workspace = true } indexmap = { workspace = true } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index ba59d113acaa..94fe153f979e 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -119,6 +119,67 @@ impl PhysicalExpr for TryCastExpr { self.expr.fmt_sql(f)?; write!(f, " AS {:?})", self.cast_type) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + let arrow_type: datafusion_proto_common::protobuf_common::ArrowType = + self.cast_type().try_into()?; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr: Some(Box::new(ctx.encode_child(&self.expr)?)), + arrow_type: Some(arrow_type), + }, + ))), + })) + } +} + +#[cfg(feature = "proto")] +impl TryCastExpr { + /// Reconstruct a [`TryCastExpr`] from its protobuf representation. + /// + /// Takes the whole [`PhysicalExprNode`] so the decode signature matches + /// other migrated expressions and can inspect outer-node metadata if + /// needed in the future. + /// + /// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_common::{internal_datafusion_err, internal_err}; + use datafusion_proto_models::protobuf; + + let try_cast = match &node.expr_type { + Some(protobuf::physical_expr_node::ExprType::TryCast(try_cast)) => { + try_cast.as_ref() + } + _ => return internal_err!("PhysicalExprNode is not a TryCastExpr"), + }; + + let expr = ctx.decode_required_expression( + try_cast.expr.as_deref(), + "TryCastExpr", + "expr", + )?; + let arrow_type: &datafusion_proto_common::protobuf_common::ArrowType = + try_cast.arrow_type.as_ref().ok_or_else(|| { + internal_datafusion_err!( + "TryCastExpr is missing required field 'arrow_type'" + ) + })?; + let cast_type: DataType = arrow_type.try_into()?; + + Ok(Arc::new(TryCastExpr::new(expr, cast_type))) + } } /// Return a PhysicalExpression representing `expr` casted to @@ -143,6 +204,8 @@ pub fn try_cast( #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "proto")] + use crate::expressions::Column; use crate::expressions::col; use arrow::array::{ Decimal128Array, Decimal128Builder, StringArray, Time64NanosecondArray, @@ -154,7 +217,20 @@ mod tests { }, datatypes::*, }; + #[cfg(feature = "proto")] + use datafusion_common::DataFusionError; use datafusion_physical_expr_common::physical_expr::fmt_sql; + #[cfg(feature = "proto")] + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + #[cfg(feature = "proto")] + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + #[cfg(feature = "proto")] + use datafusion_proto_models::protobuf::{self, physical_expr_node}; + + #[cfg(feature = "proto")] + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; // runs an end-to-end test of physical type cast // 1. construct a record batch with a column "a" of type A @@ -592,4 +668,107 @@ mod tests { Ok(()) } + + #[cfg(feature = "proto")] + fn try_cast_node( + expr: Option>, + cast_type: Option, + ) -> protobuf::PhysicalExprNode { + protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr, + arrow_type: cast_type.map(|cast_type| { + let arrow_type: datafusion_proto_common::protobuf_common::ArrowType = + (&cast_type).try_into().unwrap(); + arrow_type + }), + }, + ))), + } + } + + #[cfg(feature = "proto")] + #[test] + fn try_to_proto_encodes_try_cast_expr() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let expr = TryCastExpr::new(col("a", &schema)?, DataType::Int32); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = expr.try_to_proto(&ctx)?.expect("TryCastExpr proto"); + let try_cast = match node.expr_type { + Some(physical_expr_node::ExprType::TryCast(try_cast)) => try_cast, + other => panic!("expected TryCast proto, got {other:?}"), + }; + + assert!(try_cast.expr.is_some()); + let cast_type: DataType = try_cast.arrow_type.as_ref().unwrap().try_into()?; + assert_eq!(cast_type, DataType::Int32); + + Ok(()) + } + + #[cfg(feature = "proto")] + #[test] + fn try_from_proto_decodes_try_cast_expr() { + let node = try_cast_node(Some(Box::new(column_node("a"))), Some(DataType::Int64)); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = TryCastExpr::try_from_proto(&node, &ctx).unwrap(); + let try_cast = decoded + .downcast_ref::() + .expect("decoded expr should be a TryCastExpr"); + + assert!(try_cast.expr().downcast_ref::().is_some()); + assert_eq!(try_cast.cast_type(), &DataType::Int64); + } + + #[cfg(feature = "proto")] + #[test] + fn try_from_proto_rejects_non_try_cast_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a TryCastExpr") + )); + } + + #[cfg(feature = "proto")] + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = try_cast_node(None, Some(DataType::Int32)); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'expr'") + )); + } + + #[cfg(feature = "proto")] + #[test] + fn try_from_proto_rejects_missing_arrow_type() { + let node = try_cast_node(Some(Box::new(column_node("a"))), None); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'arrow_type'") + )); + } } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 7d2e68d81095..1ac94236e2d7 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -371,16 +371,7 @@ pub fn parse_physical_expr_with_converter( convert_required!(e.arrow_type)?, None, )), - ExprType::TryCast(e) => Arc::new(TryCastExpr::new( - parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - proto_converter, - )?, - convert_required!(e.arrow_type)?, - )), + ExprType::TryCast(_) => TryCastExpr::try_from_proto(proto, &decode_ctx)?, ExprType::ScalarUdf(e) => { let udf = match &e.fun_definition { Some(buf) => ctx.codec().try_decode_udf(&e.name, buf)?, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 9cb9e897605b..6930335b0aa4 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -37,7 +37,7 @@ use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindo use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::expressions::{ CaseExpr, CastExpr, DynamicFilterPhysicalExpr, IsNotNullExpr, IsNullExpr, Literal, - NotExpr, TryCastExpr, UnKnownColumn, + NotExpr, UnKnownColumn, }; use datafusion_physical_plan::joins::HashExpr; use datafusion_physical_plan::udaf::AggregateFunctionExpr; @@ -408,18 +408,6 @@ pub fn serialize_physical_expr_with_converter( }, ))), }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_id, - expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( - protobuf::PhysicalTryCastNode { - expr: Some(Box::new( - proto_converter.physical_expr_to_proto(cast.expr(), codec)?, - )), - arrow_type: Some(cast.cast_type().try_into()?), - }, - ))), - }) } else if let Some(expr) = expr.downcast_ref::() { let mut buf = Vec::new(); codec.try_encode_udf(expr.fun(), &mut buf)?;