Skip to content

Commit 1465ec7

Browse files
committed
correct vampseq to use frequencies instead of counts
1 parent ef6e89a commit 1465ec7

1 file changed

Lines changed: 29 additions & 6 deletions

File tree

countess/plugins/vampseq.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from duckdb import DuckDBPyConnection, DuckDBPyRelation
55

66
from countess import VERSION
7-
from countess.core.parameters import FloatParam, PerNumericColumnArrayParam, TabularMultiParam
7+
from countess.core.parameters import FloatParam, PerNumericColumnArrayParam, TabularMultiParam, ColumnOrNoneChoiceParam
88
from countess.core.plugins import DuckdbSimplePlugin
99
from countess.utils.duckdb import duckdb_escape_identifier, duckdb_escape_literal
1010

@@ -21,6 +21,16 @@ class VampSeqScorePlugin(DuckdbSimplePlugin):
2121
version = VERSION
2222

2323
columns = PerNumericColumnArrayParam("Columns", CountColumnParam("Column"))
24+
group_col = ColumnOrNoneChoiceParam("Group By")
25+
26+
def prepare(self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation) -> None:
27+
super().prepare(ddbc, source)
28+
29+
# set default values for weights on "count" columns
30+
if all(c.weight.value is None for c in self.columns):
31+
count_cols = [ c for c in self.columns if c.label.startswith('count') ]
32+
for n, c in enumerate(count_cols):
33+
c.weight.value = (n+1)/len(count_cols)
2434

2535
def execute(
2636
self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation, row_limit: Optional[int] = None
@@ -34,10 +44,23 @@ def execute(
3444
if not weighted_columns:
3545
return source
3646

37-
weighted_counts = " + ".join(f"{k} * {v}" for k, v in weighted_columns.items())
38-
total_counts = " + ".join(k for k in weighted_columns.keys())
47+
if self.group_col.is_not_none():
48+
group_col_id = "T0." + duckdb_escape_identifier(self.group_col.value)
49+
else:
50+
group_col_id = "1"
51+
52+
sums = ", ".join(f"sum(T0.{k}) as {k}" for k in weighted_columns.keys())
53+
weighted_counts = " + ".join(f"T0.{k} * {v} / T1.{k}" for k, v in weighted_columns.items())
54+
total_counts = " + ".join(f"T0.{k} / T1.{k}" for k in weighted_columns.keys())
3955

40-
proj = f"*, ({weighted_counts}) / ({total_counts}) as score"
56+
sql = f"""
57+
select T0.*, ({weighted_counts}) / ({total_counts}) as score
58+
from {source.alias} T0 join (
59+
select {group_col_id} as score_group, {sums}
60+
from {source.alias} T0
61+
group by score_group
62+
) T1 on ({group_col_id} = T1.score_group)
63+
"""
4164

42-
logger.debug("VampseqScorePlugin proj %s", proj)
43-
return source.project(proj)
65+
logger.debug("VampseqScorePlugin sql %s", sql)
66+
return ddbc.sql(sql)

0 commit comments

Comments
 (0)