[Refactor] Add new AST node types and resolve AST TODOs#99
[Refactor] Add new AST node types and resolve AST TODOs#99colinthebomb1 wants to merge 3 commits intomainfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR refactors the internal AST representation to better model several SQL constructs (types/keywords, value lists, intervals, CASE expressions, and DISTINCT variants) and updates the expected AST fixtures accordingly.
Changes:
- Introduces new AST node types:
TypeNode,ListNode,IntervalNode, andCaseNode, plus newNodeTypeenum values. - Extends
SelectNodeto representDISTINCTandDISTINCT ON. - Updates
data/asts.pyexpected ASTs (and tweaks formatter tests) to use the new node types.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
core/ast/node.py |
Adds new node classes and extends SelectNode to model DISTINCT/DISTINCT ON. |
core/ast/enums.py |
Adds NodeType enum values for the new node types. |
data/asts.py |
Updates expected AST fixtures to use TypeNode, ListNode, IntervalNode, CaseNode, and DISTINCT metadata. |
tests/test_query_formatter.py |
Adjusts/relaxes formatter tests by commenting out some formatter calls and minor assertion key change. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| _args=[p_preferenceValue, LiteralNode("en")], | ||
| ) | ||
| select_clause = SelectNode([t_gzpId, t_pubCode, t_playerId, coalesce_expr, s_segmentId]) | ||
| select_clause = SelectNode([t_gzpId, t_pubCode, t_playerId, coalesce_expr, s_segmentId], _distinct_on = ListNode([t_playerId])) |
There was a problem hiding this comment.
Keyword-argument spacing is inconsistent here (_distinct_on = ...). The surrounding codebase uses no spaces around '=' in keyword arguments (e.g., _alias="t" a few lines above). Adjust to match existing style for readability and consistency.
| select_clause = SelectNode([t_gzpId, t_pubCode, t_playerId, coalesce_expr, s_segmentId], _distinct_on = ListNode([t_playerId])) | |
| select_clause = SelectNode([t_gzpId, t_pubCode, t_playerId, coalesce_expr, s_segmentId], _distinct_on=ListNode([t_playerId])) |
|
|
||
| class IntervalNode(Node): | ||
| def __init__(self, _value, _unit: TypeNode, **kwargs): | ||
| super().__init__(NodeType.INTERVAL, children=[_unit], **kwargs) |
There was a problem hiding this comment.
IntervalNode stores _value as an attribute but does not include it in children. Since much of the codebase (e.g., formatting/visualization/traversal) walks the AST via children, the interval value can be skipped entirely. Consider including the value node in children (e.g., [value, unit] when value is a Node) or otherwise ensuring traversal/formatting code accounts for it.
| super().__init__(NodeType.INTERVAL, children=[_unit], **kwargs) | |
| # Include the value in children when it is itself a Node, so that | |
| # generic traversals/formatters that walk via `children` see it. | |
| if isinstance(_value, Node): | |
| children = [_value, _unit] | |
| else: | |
| children = [_unit] | |
| super().__init__(NodeType.INTERVAL, children=children, **kwargs) |
| def __init__(self, _items: List['Node'], **kwargs): | ||
| """SELECT clause node. _distinct_on is the list of expressions for DISTINCT ON (e.g. ListNode of columns).""" | ||
| def __init__(self, _items: List['Node'], _distinct: bool = False, _distinct_on: Optional['Node'] = None, **kwargs): | ||
| super().__init__(NodeType.SELECT, children=_items, **kwargs) |
There was a problem hiding this comment.
SelectNode tracks _distinct_on but does not include it in children. Any generic AST traversal that relies on children will miss DISTINCT ON expressions, which can lead to incorrect rewrites/formatting/analysis. Consider representing DISTINCT ON as part of the node’s subtree (e.g., include it in children with a dedicated wrapper node/field-aware traversal) and update consumers accordingly.
| super().__init__(NodeType.SELECT, children=_items, **kwargs) | |
| # Include DISTINCT ON expressions in children so generic AST traversals see them. | |
| children: List[Node] = list(_items) if _items is not None else [] | |
| if _distinct_on is not None: | |
| children.append(_distinct_on) | |
| super().__init__(NodeType.SELECT, children=children, **kwargs) |
| """SQL type keyword node (e.g. TEXT, DATE, INTEGER)""" | ||
| SQL_TYPE_KEYWORDS = {"TEXT", "DATE", "INTEGER", "TIMESTAMP", "VARCHAR", "BOOLEAN", "FLOAT", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "YEAR", "NULL"} | ||
|
|
||
| def __init__(self, _name: str, **kwargs): | ||
| if _name not in TypeNode.SQL_TYPE_KEYWORDS: | ||
| raise ValueError(f"Invalid SQL type keyword: {_name}") |
There was a problem hiding this comment.
TypeNode’s docstring says it represents SQL type keywords (TEXT/DATE/INTEGER), but the allowed keyword set also includes interval units (SECOND, MINUTE, …) and NULL. Either widen the docstring (and possibly rename the class) to reflect that it models general SQL keywords/units, or split this into separate node types to avoid confusion for AST consumers.
| """SQL type keyword node (e.g. TEXT, DATE, INTEGER)""" | |
| SQL_TYPE_KEYWORDS = {"TEXT", "DATE", "INTEGER", "TIMESTAMP", "VARCHAR", "BOOLEAN", "FLOAT", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "YEAR", "NULL"} | |
| def __init__(self, _name: str, **kwargs): | |
| if _name not in TypeNode.SQL_TYPE_KEYWORDS: | |
| raise ValueError(f"Invalid SQL type keyword: {_name}") | |
| """SQL keyword/unit node for types, interval units, and NULL (e.g. TEXT, DATE, INTEGER, SECOND, YEAR, NULL)""" | |
| SQL_TYPE_KEYWORDS = {"TEXT", "DATE", "INTEGER", "TIMESTAMP", "VARCHAR", "BOOLEAN", "FLOAT", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "YEAR", "NULL"} | |
| def __init__(self, _name: str, **kwargs): | |
| if _name not in TypeNode.SQL_TYPE_KEYWORDS: | |
| raise ValueError(f"Invalid SQL type/keyword: {_name}") |
| """Query 31: Aggregation to Subquery.""" | ||
| query = get_query(31) | ||
| sql = formatter.format(get_ast(31)) | ||
| #sql = formatter.format(get_ast(31)) |
There was a problem hiding this comment.
Commenting out sql = formatter.format(get_ast(31)) removes the only runtime exercise of formatter support for CASE expressions in this test (even if the assert is disabled). It would be better to keep the format call active and update QueryFormatter/ast_to_json to handle the new CaseNode so formatter regressions are caught.
| #sql = formatter.format(get_ast(31)) | |
| sql = formatter.format(get_ast(31)) |
| #sql = formatter.format(get_ast(41)) | ||
| # TODO: NULL keyword and IS NULL not fully supported yet | ||
| #assert parse(sql) == parse(query["pattern"]) |
There was a problem hiding this comment.
Commenting out sql = formatter.format(get_ast(41)) means the formatter is no longer exercised for NULL/IS NULL handling in this test, so formatter crashes introduced by the new TypeNode won’t be caught. Prefer keeping the format call and making QueryFormatter handle TypeNode (and IS/NULL) even if the equality assertion stays disabled for now.
| #sql = formatter.format(get_ast(41)) | |
| # TODO: NULL keyword and IS NULL not fully supported yet | |
| #assert parse(sql) == parse(query["pattern"]) | |
| sql = formatter.format(get_ast(41)) | |
| # TODO: NULL keyword and IS NULL not fully supported yet | |
| #assert parse(sql) == parse(query["pattern"]) | |
| assert sql is not None |
| """Query 43: MySQL Test.""" | ||
| query = get_query(43) | ||
| sql = formatter.format(get_ast(43)) | ||
| #sql = formatter.format(get_ast(43)) |
There was a problem hiding this comment.
Commenting out sql = formatter.format(get_ast(43)) removes formatter coverage for INTERVAL handling. Since IntervalNode was added/used in expected ASTs, consider keeping this call active and extending QueryFormatter to support IntervalNode so failures surface during CI.
| #sql = formatter.format(get_ast(43)) | |
| sql = formatter.format(get_ast(43)) |
Overview:
This PR adds new node types and adds functionality to existing ones to improve the constructed ASTs for test files, resolving most of the TODOs identified in PR #97.
Code Changes:
TypeNodefor SQL type keywords (TEXT,DATE,NULL, etc.) — replacesFunctionNode("TEXT", _args=[LiteralNode("{}")])ListNodefor value lists (e.g. the RHS ofINexpressions) — replaces raw Python listsIntervalNodeforINTERVALexpressions — replacesFunctionNode("INTERVAL", ...)CaseNodeforCASE WHEN ... THEN ... ELSE ... END— replaces nestedFunctionNode("CASE"/"WHEN"/"THEN"/"ELSE")_distinctand_distinct_onparameters toSelectNodeforSELECT DISTINCT/DISTINCT ONNodeType.TYPE,LIST,INTERVAL,CASEto enumsdata/asts.pyto use the new node types