Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 90 additions & 21 deletions datafusion/optimizer/src/eliminate_group_by_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
// specific language governing permissions and limitations
// under the License.

//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause
//! [`EliminateGroupByConstant`] removes constant and functionally redundant
//! expressions from `GROUP BY` clause
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};

use std::collections::HashSet;

use datafusion_common::Result;
use datafusion_common::tree_node::Transformed;
use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility};
Expand Down Expand Up @@ -47,25 +50,30 @@ impl OptimizerRule for EliminateGroupByConstant {
) -> Result<Transformed<LogicalPlan>> {
match plan {
LogicalPlan::Aggregate(aggregate) => {
let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate
// Collect bare column references in GROUP BY
let group_by_columns: HashSet<&datafusion_common::Column> = aggregate
.group_expr
.iter()
.partition(|expr| is_constant_expression(expr));

// If no constant expressions found (nothing to optimize) or
// constant expression is the only expression in aggregate,
// optimization is skipped
if const_group_expr.is_empty()
|| (!const_group_expr.is_empty()
&& nonconst_group_expr.is_empty()
&& aggregate.aggr_expr.is_empty())
.filter_map(|expr| match expr {
Expr::Column(c) => Some(c),
_ => None,
})
.collect();

let (redundant, required): (Vec<_>, Vec<_>) = aggregate
.group_expr
.iter()
.partition(|expr| is_redundant_group_expr(expr, &group_by_columns));

if redundant.is_empty()
|| (required.is_empty() && aggregate.aggr_expr.is_empty())
{
return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
}

let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
aggregate.input,
nonconst_group_expr.into_iter().cloned().collect(),
required.into_iter().cloned().collect(),
aggregate.aggr_expr.clone(),
)?);

Expand All @@ -91,23 +99,47 @@ impl OptimizerRule for EliminateGroupByConstant {
}
}

/// Checks if expression is constant, and can be eliminated from group by.
///
/// Intended to be used only within this rule, helper function, which heavily
/// relies on `SimplifyExpressions` result.
fn is_constant_expression(expr: &Expr) -> bool {
/// Checks if a GROUP BY expression is redundant (can be removed without
/// changing grouping semantics). An expression is redundant if it is a
/// deterministic function of constants and columns already present as bare
/// column references in the GROUP BY.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a more common term for this — and something that might be more technically precise than "deterministic function" — would be "functionally dependent".

So instead of

An expression is redundant if it is a
deterministic function of constants and columns already present as bare
column references in the GROUP BY."

Maybe something like

An expression is redundant if it is functionally dependent (e.g. a function of constants and columns already present as bare
column references in the GROUP BY)."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deterministic function in this code also seems quite precise for this application?
In the code we check for deterministic functions (those in BinaryExpr), we cannot detect all functional dependencies...

/// Decides whether a single GROUP BY expression can be dropped without
/// changing the grouping produced by the remaining keys.
fn is_redundant_group_expr(
    expr: &Expr,
    group_by_columns: &HashSet<&datafusion_common::Column>,
) -> bool {
    match expr {
        // A bare column reference defines the grouping itself, so it is
        // never removable.
        Expr::Column(_) => false,
        // Everything else is removable exactly when it is a deterministic
        // function of the retained grouping columns.
        other => is_deterministic_of(other, group_by_columns),
    }
}

/// Returns true if `expr` is a deterministic expression whose only column
/// references are contained in `known_columns`.
fn is_deterministic_of(
expr: &Expr,
known_columns: &HashSet<&datafusion_common::Column>,
) -> bool {
match expr {
Expr::Alias(e) => is_constant_expression(&e.expr),
Expr::Alias(e) => is_deterministic_of(&e.expr, known_columns),
Expr::Column(c) => known_columns.contains(c),
Expr::Literal(_, _) => true,
Expr::BinaryExpr(e) => {
is_constant_expression(&e.left) && is_constant_expression(&e.right)
is_deterministic_of(&e.left, known_columns)
&& is_deterministic_of(&e.right, known_columns)
}
Expr::Literal(_, _) => true,
Expr::ScalarFunction(e) => {
matches!(
e.func.signature().volatility,
Volatility::Immutable | Volatility::Stable
) && e.args.iter().all(is_constant_expression)
) && e
.args
.iter()
.all(|arg| is_deterministic_of(arg, known_columns))
}
Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns),
Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns),
Expr::Negative(e) => is_deterministic_of(e, known_columns),
_ => false,
}
}
Expand Down Expand Up @@ -268,6 +300,43 @@ mod tests {
")
}

#[test]
fn test_eliminate_deterministic_expr_of_group_by_column() -> Result<()> {
let scan = test_table_scan()?;
// GROUP BY a, a - 1, a - 2, a - 3 -> GROUP BY a
// Each `a - k` key is a deterministic function of the retained grouping
// column `a`, so the rule drops it from the Aggregate and re-adds it as a
// Projection above (preserving the output schema for callers).
let plan = LogicalPlanBuilder::from(scan)
.aggregate(
vec![
col("a"),
col("a") - lit(1u32),
col("a") - lit(2u32),
col("a") - lit(3u32),
],
vec![count(col("c"))],
)?
.build()?;

// Expected: Aggregate groups only by `a`; the removed expressions appear
// in the Projection so the plan's output columns are unchanged.
assert_optimized_plan_equal!(plan, @r"
Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a - UInt32(3), count(test.c)
Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
TableScan: test
")
}

#[test]
fn test_no_eliminate_independent_columns() -> Result<()> {
// GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by column)
// `b - 1` references a column that is not itself a grouping key, so it is
// not functionally determined by the remaining keys and must be kept.
let scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(scan)
.aggregate(vec![col("a"), col("b") - lit(1u32)], vec![count(col("c"))])?
.build()?;

// Expected: the plan is left untouched (no Projection inserted, both
// grouping keys remain in the Aggregate).
assert_optimized_plan_equal!(plan, @r"
Aggregate: groupBy=[[test.a, test.b - UInt32(1)]], aggr=[[count(test.c)]]
TableScan: test
")
}

#[test]
fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> {
let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(
Expand Down
17 changes: 9 additions & 8 deletions datafusion/sqllogictest/test_files/clickbench.slt
Original file line number Diff line number Diff line change
Expand Up @@ -959,19 +959,20 @@ EXPLAIN SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT
----
logical_plan
01)Sort: c DESC NULLS FIRST, fetch=10
02)--Projection: hits.ClientIP, hits.ClientIP - Int64(1), hits.ClientIP - Int64(2), hits.ClientIP - Int64(3), count(Int64(1)) AS count(*) AS c
03)----Aggregate: groupBy=[[hits.ClientIP, __common_expr_1 AS hits.ClientIP - Int64(1), __common_expr_1 AS hits.ClientIP - Int64(2), __common_expr_1 AS hits.ClientIP - Int64(3)]], aggr=[[count(Int64(1))]]
04)------Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1, hits.ClientIP
02)--Projection: hits.ClientIP, __common_expr_1 - Int64(1) AS hits.ClientIP - Int64(1), __common_expr_1 - Int64(2) AS hits.ClientIP - Int64(2), __common_expr_1 - Int64(3) AS hits.ClientIP - Int64(3), count(Int64(1)) AS c
03)----Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1, hits.ClientIP, count(Int64(1))
04)------Aggregate: groupBy=[[hits.ClientIP]], aggr=[[count(Int64(1))]]
05)--------SubqueryAlias: hits
06)----------TableScan: hits_raw projection=[ClientIP]
physical_plan
01)SortPreservingMergeExec: [c@4 DESC], fetch=10
02)--SortExec: TopK(fetch=10), expr=[c@4 DESC], preserve_partitioning=[true]
03)----ProjectionExec: expr=[ClientIP@0 as ClientIP, hits.ClientIP - Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - Int64(2)@2 as hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as hits.ClientIP - Int64(3), count(Int64(1))@4 as c]
04)------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP, hits.ClientIP - Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - Int64(2)@2 as hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as hits.ClientIP - Int64(3)], aggr=[count(Int64(1))]
05)--------RepartitionExec: partitioning=Hash([ClientIP@0, hits.ClientIP - Int64(1)@1, hits.ClientIP - Int64(2)@2, hits.ClientIP - Int64(3)@3], 4), input_partitions=1
06)----------AggregateExec: mode=Partial, gby=[ClientIP@1 as ClientIP, __common_expr_1@0 - 1 as hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as hits.ClientIP - Int64(2), __common_expr_1@0 - 3 as hits.ClientIP - Int64(3)], aggr=[count(Int64(1))]
07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[CAST(ClientIP@7 AS Int64) as __common_expr_1, ClientIP], file_type=parquet
03)----ProjectionExec: expr=[ClientIP@1 as ClientIP, __common_expr_1@0 - 1 as hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as hits.ClientIP - Int64(2), __common_expr_1@0 - 3 as hits.ClientIP - Int64(3), count(Int64(1))@2 as c]
04)------ProjectionExec: expr=[CAST(ClientIP@0 AS Int64) as __common_expr_1, ClientIP@0 as ClientIP, count(Int64(1))@1 as count(Int64(1))]
05)--------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP], aggr=[count(Int64(1))]
06)----------RepartitionExec: partitioning=Hash([ClientIP@0], 4), input_partitions=1
07)------------AggregateExec: mode=Partial, gby=[ClientIP@0 as ClientIP], aggr=[count(Int64(1))]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A nice thing is that it not only removes the GROUP BY keys but also moves the projections above, minimizing the data usage/shuffle.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this is pretty clever

08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[ClientIP], file_type=parquet

query IIIII rowsort
SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10;
Expand Down
Loading