Skip to content
5 changes: 5 additions & 0 deletions .changeset/support-nested-aggregates.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@tanstack/db': patch
---

fix: support aggregates nested inside expressions (e.g. `coalesce(count(...), 0)`)
145 changes: 131 additions & 14 deletions packages/db/src/query/compiler/group-by.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
map,
serializeValue,
} from '@tanstack/db-ivm'
import { Func, PropRef, getHavingExpression } from '../ir.js'
import { Func, PropRef, getHavingExpression, isExpressionLike } from '../ir.js'
import {
AggregateFunctionNotInSelectError,
NonAggregateExpressionNotInGroupByError,
Expand Down Expand Up @@ -49,8 +49,8 @@ function validateAndCreateMapping(

// Validate each SELECT expression
for (const [alias, expr] of Object.entries(selectClause)) {
if (expr.type === `agg`) {
// Aggregate expressions are allowed and don't need to be in GROUP BY
if (expr.type === `agg` || containsAggregate(expr)) {
// Aggregate expressions (plain or wrapped) are allowed and don't need to be in GROUP BY
continue
}

Expand Down Expand Up @@ -86,12 +86,26 @@ export function processGroupBy(
// For single-group aggregation, create a single group with all data
const aggregates: Record<string, any> = {}

// Expressions that wrap aggregates (e.g. coalesce(count(...), 0)).
// Keys are the original SELECT aliases; values are pre-compiled evaluators
// over the transformed (aggregate-free) expression.
const wrappedAggExprs: Record<string, (data: any) => any> = {}
const aggCounter = { value: 0 }

if (selectClause) {
// Scan the SELECT clause for aggregate functions
for (const [alias, expr] of Object.entries(selectClause)) {
if (expr.type === `agg`) {
const aggExpr = expr
aggregates[alias] = getAggregateFunction(aggExpr)
aggregates[alias] = getAggregateFunction(expr)
} else if (containsAggregate(expr)) {
const { transformed, extracted } = extractAndReplaceAggregates(
expr as BasicExpression | Aggregate,
aggCounter,
)
for (const [syntheticAlias, aggExpr] of Object.entries(extracted)) {
aggregates[syntheticAlias] = getAggregateFunction(aggExpr)
}
wrappedAggExprs[alias] = compileExpression(transformed)
}
}
}
Expand All @@ -112,13 +126,17 @@ export function processGroupBy(
const finalResults: Record<string, any> = { ...selectResults }

if (selectClause) {
// Update with aggregate results
// First pass: populate plain aggregate results and synthetic aliases
for (const [alias, expr] of Object.entries(selectClause)) {
if (expr.type === `agg`) {
finalResults[alias] = aggregatedRow[alias]
}
// Non-aggregates keep their original values from early SELECT processing
}
evaluateWrappedAggregates(
finalResults,
aggregatedRow as Record<string, any>,
wrappedAggExprs,
)
}

// Use a single key for the result and update $selected
Expand Down Expand Up @@ -201,13 +219,23 @@ export function processGroupBy(

// Create aggregate functions for any aggregated columns in the SELECT clause
const aggregates: Record<string, any> = {}
const wrappedAggExprs: Record<string, (data: any) => any> = {}
const aggCounter = { value: 0 }

if (selectClause) {
// Scan the SELECT clause for aggregate functions
for (const [alias, expr] of Object.entries(selectClause)) {
if (expr.type === `agg`) {
const aggExpr = expr
aggregates[alias] = getAggregateFunction(aggExpr)
aggregates[alias] = getAggregateFunction(expr)
} else if (containsAggregate(expr)) {
const { transformed, extracted } = extractAndReplaceAggregates(
expr as BasicExpression | Aggregate,
aggCounter,
)
for (const [syntheticAlias, aggExpr] of Object.entries(extracted)) {
aggregates[syntheticAlias] = getAggregateFunction(aggExpr)
}
wrappedAggExprs[alias] = compileExpression(transformed)
}
}
}
Expand All @@ -223,9 +251,11 @@ export function processGroupBy(
const finalResults: Record<string, any> = {}

if (selectClause) {
// Process each SELECT expression
// First pass: populate group keys, plain aggregates, and synthetic aliases
for (const [alias, expr] of Object.entries(selectClause)) {
if (expr.type !== `agg`) {
if (expr.type === `agg`) {
finalResults[alias] = aggregatedRow[alias]
} else if (!wrappedAggExprs[alias]) {
// Use cached mapping to get the corresponding __key_X for non-aggregates
const groupIndex = mapping.selectToGroupByIndex.get(alias)
if (groupIndex !== undefined) {
Expand All @@ -234,11 +264,13 @@ export function processGroupBy(
// Fallback to original SELECT results
finalResults[alias] = selectResults[alias]
}
} else {
// Use aggregate results
finalResults[alias] = aggregatedRow[alias]
}
}
evaluateWrappedAggregates(
finalResults,
aggregatedRow as Record<string, any>,
wrappedAggExprs,
)
} else {
// No SELECT clause - just use the group keys
for (let i = 0; i < groupByClause.length; i++) {
Expand Down Expand Up @@ -457,6 +489,91 @@ export function replaceAggregatesByRefs(
}
}

/**
* Evaluates wrapped-aggregate expressions against the aggregated row.
* Copies synthetic __agg_N values into finalResults so the compiled wrapper
* expressions can reference them, evaluates each wrapper, then removes the
* synthetic keys so they don't leak onto user-visible result rows.
*/
function evaluateWrappedAggregates(
finalResults: Record<string, any>,
aggregatedRow: Record<string, any>,
wrappedAggExprs: Record<string, (data: any) => any>,
): void {
for (const key of Object.keys(aggregatedRow)) {
if (key.startsWith(`__agg_`)) {
finalResults[key] = aggregatedRow[key]
}
}
for (const [alias, evaluator] of Object.entries(wrappedAggExprs)) {
finalResults[alias] = evaluator({ $selected: finalResults })
}
for (const key of Object.keys(finalResults)) {
if (key.startsWith(`__agg_`)) delete finalResults[key]
}
}

/**
* Checks whether an expression contains an aggregate anywhere in its tree.
* Returns true for a top-level Aggregate, or a Func whose args (recursively)
* contain an Aggregate. Safely returns false for nested Select objects.
*/
export function containsAggregate(
expr: BasicExpression | Aggregate | Select,
): boolean {
if (!isExpressionLike(expr)) {
return false
}
if (expr.type === `agg`) {
return true
}
if (expr.type === `func`) {
return expr.args.some((arg: BasicExpression | Aggregate) =>
containsAggregate(arg),
)
}
return false
}

/**
* Walks an expression tree containing nested aggregates.
* Each Aggregate node is extracted, assigned a synthetic alias (__agg_N),
* and replaced with PropRef(["$selected", "__agg_N"]) so the wrapper
* expression can be compiled as a pure BasicExpression after groupBy
* populates the synthetic values.
*/
function extractAndReplaceAggregates(
expr: BasicExpression | Aggregate,
counter: { value: number },
): {
transformed: BasicExpression
extracted: Record<string, Aggregate>
} {
if (expr.type === `agg`) {
const alias = `__agg_${counter.value++}`
return {
transformed: new PropRef([`$selected`, alias]),
extracted: { [alias]: expr },
}
}

if (expr.type === `func`) {
const allExtracted: Record<string, Aggregate> = {}
const newArgs = expr.args.map((arg: BasicExpression | Aggregate) => {
const result = extractAndReplaceAggregates(arg, counter)
Object.assign(allExtracted, result.extracted)
return result.transformed
})
return {
transformed: new Func(expr.name, newArgs),
extracted: allExtracted,
}
}

// ref / val – pass through unchanged
return { transformed: expr as BasicExpression, extracted: {} }
}

/**
* Checks if two aggregate expressions are equal
*/
Expand Down
4 changes: 2 additions & 2 deletions packages/db/src/query/compiler/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
import { PropRef, Value as ValClass, getWhereExpression } from '../ir.js'
import { compileExpression, toBooleanPredicate } from './evaluators.js'
import { processJoins } from './joins.js'
import { processGroupBy } from './group-by.js'
import { containsAggregate, processGroupBy } from './group-by.js'
import { processOrderBy } from './order-by.js'
import { processSelect } from './select.js'
import type { CollectionSubscription } from '../../collection/subscription.js'
Expand Down Expand Up @@ -268,7 +268,7 @@ export function compileQuery(
} else if (query.select) {
// Check if SELECT contains aggregates but no GROUP BY (implicit single-group aggregation)
const hasAggregates = Object.values(query.select).some(
(expr) => expr.type === `agg`,
(expr) => expr.type === `agg` || containsAggregate(expr),
)
if (hasAggregates) {
// Handle implicit single-group aggregation
Expand Down
7 changes: 5 additions & 2 deletions packages/db/src/query/compiler/select.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { map } from '@tanstack/db-ivm'
import { PropRef, Value as ValClass, isExpressionLike } from '../ir.js'
import { AggregateNotSupportedError } from '../../errors.js'
import { compileExpression } from './evaluators.js'
import { containsAggregate } from './group-by.js'
import type { Aggregate, BasicExpression, Select } from '../ir.js'
import type {
KeyedStream,
Expand Down Expand Up @@ -226,8 +227,10 @@ function addFromObject(
continue
}

if (isAggregateExpression(expression)) {
// Placeholder for group-by processing later
if (isAggregateExpression(expression) || containsAggregate(expression)) {
// Placeholder for group-by processing later.
// Both plain aggregates (count(...)) and expressions wrapping
// aggregates (coalesce(count(...), 0)) are deferred to processGroupBy.
ops.push({
kind: `field`,
alias: [...prefixPath, key].join(`.`),
Expand Down
Loading
Loading