44from duckdb import DuckDBPyConnection , DuckDBPyRelation
55
66from countess import VERSION
7- from countess .core .parameters import ColumnOrNoneChoiceParam , FloatParam , PerNumericColumnArrayParam , TabularMultiParam
7+ from countess .core .parameters import (
8+ FloatParam ,
9+ PerNumericColumnArrayParam ,
10+ TabularMultiParam ,
11+ PerColumnArrayParam ,
12+ BooleanParam ,
13+ )
814from countess .core .plugins import DuckdbSqlPlugin
915from countess .utils .duckdb import duckdb_escape_identifier , duckdb_escape_literal
1016
@@ -21,7 +27,7 @@ class VampSeqScorePlugin(DuckdbSqlPlugin):
2127 version = VERSION
2228
2329 columns = PerNumericColumnArrayParam ("Columns" , CountColumnParam ("Column" ))
24- group_col = ColumnOrNoneChoiceParam ("Group By" )
30+ group_by = PerColumnArrayParam ("Group By" , BooleanParam ( "Column" , False ) )
2531
2632 def prepare (self , ddbc : DuckDBPyConnection , source : Optional [DuckDBPyRelation ]) -> None :
2733 super ().prepare (ddbc , source )
@@ -32,7 +38,17 @@ def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation])
3238 for n , c in enumerate (count_cols ):
3339 c .weight .value = (n + 1 ) / len (count_cols )
3440
41+ weight_cols = set (n for n , p in self .columns .get_column_params () if p .weight .value is not None )
42+ for n , p in self .group_by .get_column_params ():
43+ if n in weight_cols :
44+ p .set_value (False )
45+
3546 def sql (self , table_name : str , columns : Iterable [str ]) -> Optional [str ]:
47+ group_cols = {
48+ duckdb_escape_identifier (name )
49+ for name , param in self .group_by .get_column_params ()
50+ if param
51+ }
3652 weighted_columns = {
3753 duckdb_escape_identifier (name ): duckdb_escape_literal (param .weight .value )
3854 for name , param in self .columns .get_column_params ()
@@ -42,24 +58,24 @@ def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
4258 if not weighted_columns :
4359 return None
4460
45- if self .group_col .is_not_none ():
46- group_col_id = "T0." + duckdb_escape_identifier (self .group_col .value )
47- else :
48- group_col_id = "1"
49-
50- sums = ", " .join (f"sum(T0.{ k } ) as { k } " for k in weighted_columns .keys ())
61+ inner_select = ", " .join (
62+ [ f"T0.{ k } " for k in group_cols ] +
63+ [ f"sum(T0.{ k } ) as { k } " for k in weighted_columns .keys () ]
64+ )
5165 weighted_counts = " + " .join (
52- f"CASE WHEN T1 .{ k } > 0 THEN T0 .{ k } * { v } / T1 .{ k } ELSE 0 END" for k , v in weighted_columns .items ()
66+ f"CASE WHEN T2 .{ k } > 0 THEN T1 .{ k } * { v } / T2 .{ k } ELSE 0 END" for k , v in weighted_columns .items ()
5367 )
5468 total_counts = " + " .join (
55- f"CASE WHEN T1 .{ k } > 0 THEN T0 .{ k } / T1 .{ k } ELSE 0 END" for k in weighted_columns .keys ()
69+ f"CASE WHEN T2 .{ k } > 0 THEN T1 .{ k } / T2 .{ k } ELSE 0 END" for k in weighted_columns .keys ()
5670 )
71+ group_by = ("GROUP BY " + ", " .join ("T0." + c for c in group_cols )) if group_cols else ""
72+ join_on = (" AND " .join (f"T1.{ c } = T2.{ c } " for c in group_cols )) if group_cols else "1=1"
5773
5874 return f"""
59- select T0 .*, ({ weighted_counts } ) / ({ total_counts } ) as score
60- from { table_name } T0 join (
61- select { group_col_id } as score_group, { sums }
75+ select T1 .*, ({ weighted_counts } ) / ({ total_counts } ) as score
76+ from { table_name } T1 join (
77+ select { inner_select }
6278 from { table_name } T0
63- group by score_group
64- ) T1 on ( { group_col_id } = T1.score_group)
79+ { group_by }
80+ ) T2 on { join_on }
6581 """
0 commit comments