From 66c0be5514878e2345598e211903d51963d1447b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Wed, 4 Mar 2026 21:42:09 +0100 Subject: [PATCH 1/2] Eliminate group by (a-1), (a - 2) --- .../src/eliminate_group_by_constant.rs | 111 ++++++++++++++---- .../sqllogictest/test_files/clickbench.slt | 17 +-- 2 files changed, 99 insertions(+), 29 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index e93edc62403a9..736e79bd66fa9 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause +//! [`EliminateGroupByConstant`] removes constant and functionally redundant +//! expressions from `GROUP BY` clause use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use std::collections::HashSet; + use datafusion_common::Result; use datafusion_common::tree_node::Transformed; use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility}; @@ -47,25 +50,30 @@ impl OptimizerRule for EliminateGroupByConstant { ) -> Result> { match plan { LogicalPlan::Aggregate(aggregate) => { - let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate + // Collect bare column references in GROUP BY + let group_by_columns: HashSet<&datafusion_common::Column> = aggregate .group_expr .iter() - .partition(|expr| is_constant_expression(expr)); - - // If no constant expressions found (nothing to optimize) or - // constant expression is the only expression in aggregate, - // optimization is skipped - if const_group_expr.is_empty() - || (!const_group_expr.is_empty() - && nonconst_group_expr.is_empty() - && aggregate.aggr_expr.is_empty()) + .filter_map(|expr| match expr { + Expr::Column(c) => Some(c), + _ => None, + }) + .collect(); + + let (redundant, required): (Vec<_>, Vec<_>) = aggregate + .group_expr + .iter() + .partition(|expr| is_redundant_group_expr(expr, &group_by_columns)); + + if redundant.is_empty() + || (required.is_empty() && aggregate.aggr_expr.is_empty()) { return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); } let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( aggregate.input, - nonconst_group_expr.into_iter().cloned().collect(), + required.into_iter().cloned().collect(), aggregate.aggr_expr.clone(), )?); @@ -91,23 +99,44 @@ impl OptimizerRule for EliminateGroupByConstant { } } -/// Checks if expression is constant, and can be eliminated from group by. -/// -/// Intended to be used only within this rule, helper function, which heavily -/// relies on `SimplifyExpressions` result. -fn is_constant_expression(expr: &Expr) -> bool { +/// Checks if a GROUP BY expression is redundant (can be removed without +/// changing grouping semantics). An expression is redundant if it is a +/// deterministic function of constants and columns already present as bare +/// column references in the GROUP BY. +fn is_redundant_group_expr( + expr: &Expr, + group_by_columns: &HashSet<&datafusion_common::Column>, +) -> bool { + // Bare column references are never redundant - they define the grouping + if matches!(expr, Expr::Column(_)) { + return false; + } + is_deterministic_of(expr, group_by_columns) +} + +/// Returns true if `expr` is a deterministic expression whose only column +/// references are contained in `known_columns`. +fn is_deterministic_of( + expr: &Expr, + known_columns: &HashSet<&datafusion_common::Column>, +) -> bool { match expr { - Expr::Alias(e) => is_constant_expression(&e.expr), + Expr::Alias(e) => is_deterministic_of(&e.expr, known_columns), + Expr::Column(c) => known_columns.contains(c), + Expr::Literal(_, _) => true, Expr::BinaryExpr(e) => { - is_constant_expression(&e.left) && is_constant_expression(&e.right) + is_deterministic_of(&e.left, known_columns) + && is_deterministic_of(&e.right, known_columns) } - Expr::Literal(_, _) => true, Expr::ScalarFunction(e) => { matches!( e.func.signature().volatility, Volatility::Immutable | Volatility::Stable - ) && e.args.iter().all(is_constant_expression) + ) && e.args.iter().all(|arg| is_deterministic_of(arg, known_columns)) } + Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns), + Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns), + Expr::Negative(e) => is_deterministic_of(e, known_columns), _ => false, } } @@ -268,6 +297,46 @@ mod tests { ") } + #[test] + fn test_eliminate_deterministic_expr_of_group_by_column() -> Result<()> { + let scan = test_table_scan()?; + // GROUP BY a, a - 1, a - 2, a - 3 -> GROUP BY a + let plan = LogicalPlanBuilder::from(scan) + .aggregate( + vec![ + col("a"), + col("a") - lit(1u32), + col("a") - lit(2u32), + col("a") - lit(3u32), + ], + vec![count(col("c"))], + )? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a - UInt32(3), count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") + } + + #[test] + fn test_no_eliminate_independent_columns() -> Result<()> { + // GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by column) + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate( + vec![col("a"), col("b") - lit(1u32)], + vec![count(col("c"))], + )? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b - UInt32(1)]], aggr=[[count(test.c)]] + TableScan: test + ") + } + #[test] fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> { let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility( diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index 10059664adad7..dd558a4f36f91 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -959,19 +959,20 @@ EXPLAIN SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT ---- logical_plan 01)Sort: c DESC NULLS FIRST, fetch=10 -02)--Projection: hits.ClientIP, hits.ClientIP - Int64(1), hits.ClientIP - Int64(2), hits.ClientIP - Int64(3), count(Int64(1)) AS count(*) AS c -03)----Aggregate: groupBy=[[hits.ClientIP, __common_expr_1 AS hits.ClientIP - Int64(1), __common_expr_1 AS hits.ClientIP - Int64(2), __common_expr_1 AS hits.ClientIP - Int64(3)]], aggr=[[count(Int64(1))]] -04)------Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1, hits.ClientIP +02)--Projection: hits.ClientIP, __common_expr_1 - Int64(1) AS hits.ClientIP - Int64(1), __common_expr_1 - Int64(2) AS hits.ClientIP - Int64(2), __common_expr_1 - Int64(3) AS hits.ClientIP - Int64(3), count(Int64(1)) AS c +03)----Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1, hits.ClientIP, count(Int64(1)) +04)------Aggregate: groupBy=[[hits.ClientIP]], aggr=[[count(Int64(1))]] 05)--------SubqueryAlias: hits 06)----------TableScan: hits_raw projection=[ClientIP] physical_plan 01)SortPreservingMergeExec: [c@4 DESC], fetch=10 02)--SortExec: TopK(fetch=10), expr=[c@4 DESC], preserve_partitioning=[true] -03)----ProjectionExec: expr=[ClientIP@0 as ClientIP, hits.ClientIP - Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - Int64(2)@2 as hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as hits.ClientIP - Int64(3), count(Int64(1))@4 as c] -04)------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP, hits.ClientIP - Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - Int64(2)@2 as hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as hits.ClientIP - Int64(3)], aggr=[count(Int64(1))] -05)--------RepartitionExec: partitioning=Hash([ClientIP@0, hits.ClientIP - Int64(1)@1, hits.ClientIP - Int64(2)@2, hits.ClientIP - Int64(3)@3], 4), input_partitions=1 -06)----------AggregateExec: mode=Partial, gby=[ClientIP@1 as ClientIP, __common_expr_1@0 - 1 as hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as hits.ClientIP - Int64(2), __common_expr_1@0 - 3 as hits.ClientIP - Int64(3)], aggr=[count(Int64(1))] -07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[CAST(ClientIP@7 AS Int64) as __common_expr_1, ClientIP], file_type=parquet +03)----ProjectionExec: expr=[ClientIP@1 as ClientIP, __common_expr_1@0 - 1 as hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as hits.ClientIP - Int64(2), __common_expr_1@0 - 3 as hits.ClientIP - Int64(3), count(Int64(1))@2 as c] +04)------ProjectionExec: expr=[CAST(ClientIP@0 AS Int64) as __common_expr_1, ClientIP@0 as ClientIP, count(Int64(1))@1 as count(Int64(1))] +05)--------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP], aggr=[count(Int64(1))] +06)----------RepartitionExec: partitioning=Hash([ClientIP@0], 4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[ClientIP@0 as ClientIP], aggr=[count(Int64(1))] +08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[ClientIP], file_type=parquet query IIIII rowsort SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; From c22396a6fc46752b946120e330d01dd7db8eae4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Wed, 4 Mar 2026 21:46:03 +0100 Subject: [PATCH 2/2] Fmt --- .../optimizer/src/eliminate_group_by_constant.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 736e79bd66fa9..6f5ca59e31113 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -132,7 +132,10 @@ fn is_deterministic_of( matches!( e.func.signature().volatility, Volatility::Immutable | Volatility::Stable - ) && e.args.iter().all(|arg| is_deterministic_of(arg, known_columns)) + ) && e + .args + .iter() + .all(|arg| is_deterministic_of(arg, known_columns)) } Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns), Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns), @@ -325,10 +328,7 @@ mod tests { // GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by column) let scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(scan) - .aggregate( - vec![col("a"), col("b") - lit(1u32)], - vec![count(col("c"))], - )? + .aggregate(vec![col("a"), col("b") - lit(1u32)], vec![count(col("c"))])? .build()?; assert_optimized_plan_equal!(plan, @r"