|
3 | 3 | import re |
4 | 4 | from typing import Dict, Optional |
5 | 5 |
|
| 6 | +import pypeg2 |
6 | 7 | from duckdb import DuckDBPyConnection, DuckDBPyRelation |
7 | 8 |
|
8 | 9 | from countess import VERSION |
|
12 | 13 |
|
13 | 14 | logger = logging.getLogger(__name__) |
14 | 15 |
|
15 | | -UNOPS = {ast.UAdd: "+", ast.USub: "-", ast.Not: "not "} |
16 | | -BINOPS = { |
17 | | - ast.Add: "+", |
18 | | - ast.Mult: "*", |
19 | | - ast.Div: "/", |
20 | | - ast.Sub: "-", |
21 | | - ast.FloorDiv: "//", |
22 | | - ast.Mod: "%", |
23 | | - ast.Pow: "**", |
24 | | -} |
25 | 16 | FUNCOPS = { |
26 | 17 | "abs", |
27 | 18 | "len", |
|
65 | 56 | "var_samp": "list_var_samp", |
66 | 57 | "std_samp": "list_stddev_samp", |
67 | 58 | } |
68 | | -BOOLOPS = { |
69 | | - ast.And: "AND", |
70 | | - ast.Or: "OR", |
71 | | -} |
72 | | - |
73 | | -COMPOPS = {ast.Eq: "=", ast.NotEq: "!=", ast.Lt: "<", ast.LtE: "<=", ast.Gt: ">", ast.GtE: ">="} |
74 | | - |
75 | | - |
76 | | -def _transmogrify(ast_node): |
77 | | - """Transform an AST node back into a string which can be parsed by DuckDB's expression |
78 | | - parser. This is a pretty small subset of all the things you might write but on the |
79 | | - other hand it saved actually writing a parser.""" |
80 | | - # XXX might have to write a parser anyway since the AST parser handles decimal |
81 | | - # literals badly. Worry about that later. |
82 | | - if type(ast_node) is ast.Name: |
83 | | - return duckdb_escape_identifier(ast_node.id) |
84 | | - elif type(ast_node) is ast.Constant: |
85 | | - return duckdb_escape_literal(ast_node.value) |
86 | | - elif type(ast_node) is ast.UnaryOp and type(ast_node.op) in UNOPS: |
87 | | - return "(" + UNOPS[type(ast_node.op)] + _transmogrify(ast_node.operand) + ")" |
88 | | - elif type(ast_node) is ast.BinOp and type(ast_node.op) in BINOPS: |
89 | | - binop = BINOPS[type(ast_node.op)] |
90 | | - left = _transmogrify(ast_node.left) |
91 | | - right = _transmogrify(ast_node.right) |
92 | | - return f"({left} {binop} {right})" |
93 | | - elif type(ast_node) is ast.BoolOp and type(ast_node.op) in BOOLOPS: |
94 | | - boolop = BOOLOPS[type(ast_node.op)] |
95 | | - return "(" + (f" {boolop} ".join(_transmogrify(v) for v in ast_node.values)) + ")" |
96 | | - elif type(ast_node) is ast.Subscript: |
97 | | - value = _transmogrify(ast_node.value) |
98 | | - if type(ast_node.slice) is ast.Slice: |
99 | | - lower = _transmogrify(ast_node.slice.lower) |
100 | | - upper = _transmogrify(ast_node.slice.upper) |
101 | | - return f"({value}[{lower}:{upper}])" |
102 | | - else: |
103 | | - index = _transmogrify(ast_node.slice) |
104 | | - return f"({value}[{index}])" |
105 | | - elif type(ast_node) is ast.Compare and all(type(op) in COMPOPS for op in ast_node.ops): |
106 | | - args = [_transmogrify(x) for x in [ast_node.left] + ast_node.comparators] |
107 | | - comps = [args[num] + COMPOPS[type(op)] + args[num + 1] for num, op in enumerate(ast_node.ops)] |
108 | | - return "(" + (" AND ".join(comps)) + ")" |
109 | | - elif type(ast_node) is ast.IfExp: |
110 | | - expr1 = _transmogrify(ast_node.test) |
111 | | - expr2 = _transmogrify(ast_node.body) |
112 | | - expr3 = _transmogrify(ast_node.orelse) |
113 | | - return f"CASE WHEN {expr1} THEN {expr2} ELSE {expr3} END" |
114 | | - elif type(ast_node) is ast.Call and type(ast_node.func) is ast.Name: |
115 | | - args = ",".join(_transmogrify(x) for x in ast_node.args) |
116 | | - if ast_node.func.id in FUNCOPS: |
117 | | - return f"{ast_node.func.id}({args})" |
118 | | - elif ast_node.func.id in LISTOPS: |
119 | | - func = LISTOPS[ast_node.func.id] |
120 | | - return f"{func}([{args}])" |
121 | | - elif ast_node.func.id in CASTOPS: |
122 | | - type_ = CASTOPS[ast_node.func.id] |
123 | | - return f"TRY_CAST({args} as {type_})" |
124 | | - else: |
125 | | - raise NotImplementedError(f"Unknown Function {ast_node.func.id}") |
126 | | - |
127 | | - else: |
128 | | - raise NotImplementedError(f"Unknown Node {ast_node}") |
129 | 59 |
|
130 | 60 |
|
131 | 61 | class ExpressionPlugin(DuckdbSimplePlugin): |
@@ -192,3 +122,137 @@ def execute( |
192 | 122 | return source.project(projection).filter(filter_) |
193 | 123 | else: |
194 | 124 | return source.project(projection) |
| 125 | + |
| 126 | + |
| 127 | + |
| 128 | +### PyPEG2 parser & expression generator follows |
| 129 | + |
| 130 | + |
| 131 | + |
| 132 | + |
| 133 | +class SqlTemplatingSymbol(pypeg2.Symbol): |
| 134 | + def sql(self): |
| 135 | + return str(self.name) |
| 136 | + |
| 137 | +class IntegerLiteral(SqlTemplatingSymbol): |
| 138 | + regex = re.compile(r'[0-9]+') |
| 139 | + |
| 140 | +class DecimalLiteral(SqlTemplatingSymbol): |
| 141 | + regex = re.compile(r'[0-9]+\.[0-9]+') |
| 142 | + |
| 143 | + def sql(self): |
| 144 | + return "(%s::DECIMAL)" % self.name |
| 145 | + |
| 146 | +class SingleQuotedStringLiteral(SqlTemplatingSymbol): |
| 147 | + regex = re.compile(r"'(?:\\.|[^'\n])*'") |
| 148 | + |
| 149 | + def sql(self): |
| 150 | + return duckdb_escape_literal(self.name[1:-1]) |
| 151 | + |
| 152 | +class DoubleQuotedStringLiteral(SqlTemplatingSymbol): |
| 153 | + regex = re.compile(r'"(?:\\.|[^"\n])*"') |
| 154 | + |
| 155 | + def sql(self): |
| 156 | + return duckdb_escape_literal(self.name[1:-1]) |
| 157 | + |
| 158 | +class Label(SqlTemplatingSymbol): |
| 159 | + regex = re.compile(r"[A-Za-z_][A-Za-z_0-9]*") |
| 160 | + |
| 161 | + def sql(self): |
| 162 | + return duckdb_escape_identifier(self.name) |
| 163 | + |
| 164 | +class SqlTemplatingList(pypeg2.List): |
| 165 | + before = '' |
| 166 | + between = '' |
| 167 | + after = '' |
| 168 | + |
| 169 | + def sql(self): |
| 170 | + return self.before + self.between.join(s.sql() for s in self) + self.after |
| 171 | + |
| 172 | +class ParenExpr(SqlTemplatingList): |
| 173 | + grammar = None # filled in later |
| 174 | + before = "(" |
| 175 | + after = ")" |
| 176 | + |
| 177 | +class FunctionName(SqlTemplatingSymbol): |
| 178 | + regex = re.compile(r'[a-z]+') |
| 179 | + |
| 180 | +class FunctionCall(pypeg2.Concat): |
| 181 | + grammar = None # filled in later |
| 182 | + |
| 183 | + def sql(self): |
| 184 | + function_name, *function_params = self |
| 185 | + if function_name in FUNCOPS: |
| 186 | + return function_name + "(" + ','.join(fp.sql() for fp in function_params) + ")" |
| 187 | + if function_name in LISTOPS: |
| 188 | + return LISTOPS[function_name] + "(" + ','.join(fp.sql() for fp in function_params) + ")" |
| 189 | + if function_name in CASTOPS: |
| 190 | + return f"TRY_CAST({fp[0].sql()} AS {CASTOPS[function_name]})" |
| 191 | + |
| 192 | + raise ValueError("Unknown function %s" % function_name) |
| 193 | + |
| 194 | +class Value(pypeg2.Concat): |
| 195 | + grammar = [FunctionCall, Label, DecimalLiteral, IntegerLiteral, SingleQuotedStringLiteral, DoubleQuotedStringLiteral, ParenExpr] |
| 196 | + |
| 197 | + def sql(self): |
| 198 | + return self[0].sql() |
| 199 | + |
| 200 | +FunctionCall.grammar = FunctionName, "(", Value, ")" |
| 201 | + |
| 202 | +class UnaOp(SqlTemplatingSymbol): |
| 203 | + regex = re.compile(r"[+-]") |
| 204 | + |
| 205 | +class UnaExpr(SqlTemplatingList): |
| 206 | + grammar = pypeg2.optional(UnaOp), Value |
| 207 | + |
| 208 | +class PowOp(SqlTemplatingSymbol): |
| 209 | + regex = re.compile(r"\*\*") |
| 210 | + |
| 211 | +class PowExpr(SqlTemplatingList): |
| 212 | + grammar = UnaExpr, pypeg2.maybe_some((PowOp, UnaExpr)) |
| 213 | + |
| 214 | +class MulOp(SqlTemplatingSymbol): |
| 215 | + regex = re.compile(r"[*/]") |
| 216 | + |
| 217 | +class MulExpr(SqlTemplatingList): |
| 218 | + grammar = PowExpr, pypeg2.maybe_some((MulOp, PowExpr)) |
| 219 | + |
| 220 | +class AddOp(SqlTemplatingSymbol): |
| 221 | + regex = re.compile(r"[+-]") |
| 222 | + |
| 223 | +class AddExpr(SqlTemplatingList): |
| 224 | + grammar = MulExpr, pypeg2.maybe_some((AddOp, MulExpr)) |
| 225 | + |
| 226 | +class NotOp(SqlTemplatingSymbol): |
| 227 | + regex = re.compile(r"not") |
| 228 | + |
| 229 | +class NotExpr(SqlTemplatingList): |
| 230 | + grammar = pypeg2.maybe_some(NotOp), AddExpr |
| 231 | + |
| 232 | +class AndExpr(SqlTemplatingList): |
| 233 | + grammar = NotExpr, pypeg2.maybe_some("and", NotExpr) |
| 234 | + between = ' AND ' |
| 235 | + |
| 236 | +class OrExpr(SqlTemplatingList): |
| 237 | + grammar = AndExpr, pypeg2.maybe_some("or", AndExpr) |
| 238 | + between = ' OR ' |
| 239 | + |
| 240 | +ParenExpr.grammar = "(", OrExpr, ")" |
| 241 | + |
| 242 | +class Statement(pypeg2.Concat): |
| 243 | + grammar = pypeg2.optional(Label, "="), OrExpr |
| 244 | + |
| 245 | + def sql_clause(self): |
| 246 | + return self[-1].sql() |
| 247 | + |
| 248 | + def sql_target(self): |
| 249 | + return self[0].sql() if len(self) > 1 else None |
| 250 | + |
| 251 | + |
| 252 | +def parse_block(block: str): |
| 253 | + for line in block.split('\n'): |
| 254 | + if not line: |
| 255 | + continue |
| 256 | + pp = pypeg2.parse(line, Statement) |
| 257 | + print (pp.sql_clause(), pp.sql_target()) |
| 258 | + |
0 commit comments