Skip to content

Commit 5b39ec5

Browse files
committed
add multi column grouping to vampseq #45 although it does feel like this is a bad UI.
1 parent 0ac3916 commit 5b39ec5

1 file changed

Lines changed: 31 additions & 15 deletions

File tree

countess/plugins/vampseq.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
from duckdb import DuckDBPyConnection, DuckDBPyRelation
55

66
from countess import VERSION
7-
from countess.core.parameters import ColumnOrNoneChoiceParam, FloatParam, PerNumericColumnArrayParam, TabularMultiParam
7+
from countess.core.parameters import (
8+
FloatParam,
9+
PerNumericColumnArrayParam,
10+
TabularMultiParam,
11+
PerColumnArrayParam,
12+
BooleanParam,
13+
)
814
from countess.core.plugins import DuckdbSqlPlugin
915
from countess.utils.duckdb import duckdb_escape_identifier, duckdb_escape_literal
1016

@@ -21,7 +27,7 @@ class VampSeqScorePlugin(DuckdbSqlPlugin):
2127
version = VERSION
2228

2329
columns = PerNumericColumnArrayParam("Columns", CountColumnParam("Column"))
24-
group_col = ColumnOrNoneChoiceParam("Group By")
30+
group_by = PerColumnArrayParam("Group By", BooleanParam("Column", False))
2531

2632
def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> None:
2733
super().prepare(ddbc, source)
@@ -32,7 +38,17 @@ def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation])
3238
for n, c in enumerate(count_cols):
3339
c.weight.value = (n + 1) / len(count_cols)
3440

41+
weight_cols = set(n for n, p in self.columns.get_column_params() if p.weight.value is not None)
42+
for n, p in self.group_by.get_column_params():
43+
if n in weight_cols:
44+
p.set_value(False)
45+
3546
def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
47+
group_cols = {
48+
duckdb_escape_identifier(name)
49+
for name, param in self.group_by.get_column_params()
50+
if param
51+
}
3652
weighted_columns = {
3753
duckdb_escape_identifier(name): duckdb_escape_literal(param.weight.value)
3854
for name, param in self.columns.get_column_params()
@@ -42,24 +58,24 @@ def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
4258
if not weighted_columns:
4359
return None
4460

45-
if self.group_col.is_not_none():
46-
group_col_id = "T0." + duckdb_escape_identifier(self.group_col.value)
47-
else:
48-
group_col_id = "1"
49-
50-
sums = ", ".join(f"sum(T0.{k}) as {k}" for k in weighted_columns.keys())
61+
inner_select = ", ".join(
62+
[ f"T0.{k}" for k in group_cols ] +
63+
[ f"sum(T0.{k}) as {k}" for k in weighted_columns.keys() ]
64+
)
5165
weighted_counts = " + ".join(
52-
f"CASE WHEN T1.{k} > 0 THEN T0.{k} * {v} / T1.{k} ELSE 0 END" for k, v in weighted_columns.items()
66+
f"CASE WHEN T2.{k} > 0 THEN T1.{k} * {v} / T2.{k} ELSE 0 END" for k, v in weighted_columns.items()
5367
)
5468
total_counts = " + ".join(
55-
f"CASE WHEN T1.{k} > 0 THEN T0.{k} / T1.{k} ELSE 0 END" for k in weighted_columns.keys()
69+
f"CASE WHEN T2.{k} > 0 THEN T1.{k} / T2.{k} ELSE 0 END" for k in weighted_columns.keys()
5670
)
71+
group_by = ("GROUP BY " + ", ".join("T0." + c for c in group_cols)) if group_cols else ""
72+
join_on = (" AND ".join(f"T1.{c} = T2.{c}" for c in group_cols)) if group_cols else "1=1"
5773

5874
return f"""
59-
select T0.*, ({weighted_counts}) / ({total_counts}) as score
60-
from {table_name} T0 join (
61-
select {group_col_id} as score_group, {sums}
75+
select T1.*, ({weighted_counts}) / ({total_counts}) as score
76+
from {table_name} T1 join (
77+
select {inner_select}
6278
from {table_name} T0
63-
group by score_group
64-
) T1 on ({group_col_id} = T1.score_group)
79+
{group_by}
80+
) T2 on {join_on}
6581
"""

0 commit comments

Comments
 (0)