44from duckdb import DuckDBPyConnection , DuckDBPyRelation
55
66from countess import VERSION
7- from countess .core .parameters import FloatParam , PerNumericColumnArrayParam , TabularMultiParam
7+ from countess .core .parameters import FloatParam , PerNumericColumnArrayParam , TabularMultiParam , ColumnOrNoneChoiceParam
88from countess .core .plugins import DuckdbSimplePlugin
99from countess .utils .duckdb import duckdb_escape_identifier , duckdb_escape_literal
1010
@@ -21,6 +21,16 @@ class VampSeqScorePlugin(DuckdbSimplePlugin):
2121 version = VERSION
2222
2323 columns = PerNumericColumnArrayParam ("Columns" , CountColumnParam ("Column" ))
24+ group_col = ColumnOrNoneChoiceParam ("Group By" )
25+
26+ def prepare (self , ddbc : DuckDBPyConnection , source : DuckDBPyRelation ) -> None :
27+ super ().prepare (ddbc , source )
28+
29+ # set default values for weights on "count" columns
30+ if all (c .weight .value is None for c in self .columns ):
31+ count_cols = [ c for c in self .columns if c .label .startswith ('count' ) ]
32+ for n , c in enumerate (count_cols ):
33+ c .weight .value = (n + 1 )/ len (count_cols )
2434
2535 def execute (
2636 self , ddbc : DuckDBPyConnection , source : DuckDBPyRelation , row_limit : Optional [int ] = None
@@ -34,10 +44,23 @@ def execute(
3444 if not weighted_columns :
3545 return source
3646
37- weighted_counts = " + " .join (f"{ k } * { v } " for k , v in weighted_columns .items ())
38- total_counts = " + " .join (k for k in weighted_columns .keys ())
47+ if self .group_col .is_not_none ():
48+ group_col_id = "T0." + duckdb_escape_identifier (self .group_col .value )
49+ else :
50+ group_col_id = "1"
51+
52+ sums = ", " .join (f"sum(T0.{ k } ) as { k } " for k in weighted_columns .keys ())
53+ weighted_counts = " + " .join (f"T0.{ k } * { v } / T1.{ k } " for k , v in weighted_columns .items ())
54+ total_counts = " + " .join (f"T0.{ k } / T1.{ k } " for k in weighted_columns .keys ())
3955
40- proj = f"*, ({ weighted_counts } ) / ({ total_counts } ) as score"
56+ sql = f"""
57+ select T0.*, ({ weighted_counts } ) / ({ total_counts } ) as score
58+ from { source .alias } T0 join (
59+ select { group_col_id } as score_group, { sums }
60+ from { source .alias } T0
61+ group by score_group
62+ ) T1 on ({ group_col_id } = T1.score_group)
63+ """
4164
42- logger .debug ("VampseqScorePlugin proj %s" , proj )
43- return source . project ( proj )
65+ logger .debug ("VampseqScorePlugin sql %s" , sql )
66+ return ddbc . sql ( sql )
0 commit comments