diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index ba3513f..adddd45 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -5,33 +5,32 @@ name: Test on: push: - branches: [ "main" ] + branches: ["main"] pull_request: - branches: [ "main" ] + branches: ["main"] permissions: contents: read jobs: build: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - name: Set up Python 3.12 - uses: actions/setup-python@v3 - with: - python-version: "3.12" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install ruff pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with ruff - run: | - # stop the build if there are Python syntax errors or undefined names - ruff check --output-format=github . - - name: Test with pytest - run: | - python -m pytest tests + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v3 + with: + python-version: "3.12" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Install pre-commit hooks + run: pre-commit install + - name: Run pre-commit hooks for linting and other checks + run: pre-commit run --all-files + - name: Test with pytest + run: | + python -m pytest tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d06639a..4e0d9b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,6 +3,13 @@ default_stages: [pre-commit] repos: + - repo: https://github.com/rbubley/mirrors-prettier + rev: "v3.8.1" # Use the sha / tag you want to point at + hooks: + - id: prettier + additional_dependencies: + - prettier@2.1.2 + - "@prettier/plugin-xml@0.12.0" - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: @@ -19,6 +26,6 @@ repos: hooks: # Run the linter. - id: ruff-check - args: [ --fix ] + args: [--fix] # Run the formatter. - id: ruff-format diff --git a/README.md b/README.md index 4967187..19654d1 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,3 @@ # Rule-Parser + DRAM LARK based Rule Parser for misciellaneous rules parsing tasks such as traits, summarize, and product. diff --git a/src/rules.lark b/src/rules.lark index 9df4cbb..e40cad2 100644 --- a/src/rules.lark +++ b/src/rules.lark @@ -36,8 +36,11 @@ list_lit: "[" expr ("," expr)+ "]" -> step_ name_ref: IDENT ":" IDENT -> qualified_name | IDENT -> simple_name + | BTICK_NAME -> quoted_name + | IDENT ":" BTICK_NAME -> quoted_name IDENT: /[A-Za-z0-9_\-\.]+/ +BTICK_NAME: /`([^`\\]|\\.)*`/ NUMBER: /[0-9]+(\.[0-9]+)?/ STRING: ESCAPED_STRING diff --git a/src/rules.py b/src/rules.py index db8bcbc..86b2e37 100644 --- a/src/rules.py +++ b/src/rules.py @@ -1,16 +1,16 @@ from __future__ import annotations -import functools +from dataclasses import dataclass +from typing import Dict, Tuple, List, Optional, Set, Iterable import operator import os -from dataclasses import dataclass from pathlib import Path -from typing import Dict, Iterable, List, Optional, Set, Tuple +import functools import networkx as nx import numpy as np import polars as pl -from lark import Lark, LarkError, Transformer +from lark import Lark, Transformer, LarkError OP_TO_EXPR = { "gt": operator.gt, @@ -71,6 +71,7 @@ class RuleError(Exception): .str.split(",") .list.eval(pl.element().str.strip_chars().str.split(" ").list.first()) ), + "DEFAULT": lambda col: pl.col(col).cast(pl.Utf8).cast(pl.List(pl.Utf8)), } CALL_FUNCTIONS = { @@ -264,6 +265,15 @@ def qualified_name(self, items): db, value = items return Name(value=str(value), db=str(db)) + def quoted_name(self, items): + val_index = 0 if (len(items) == 1) else 1 + items[val_index] = ( + items[val_index][1:-1].replace(r"\`", "`").replace(r"\\", "\\") + ) # remove backticks + if len(items) == 2: + return self.qualified_name(items) + return self.simple_name(items) + def number(self, items): return Number(float(str(items[0]))) @@ -323,7 +333,7 @@ def from_rules(cls, *args, **kwargs) -> CompiledRules: } # Expand rules using expanded defs - # we need to hit again in case defs is empty (no alias col) + # we need to hit again in case defs is empty (no parent col) # and we still need to add needed features from rules features_by_rules = {k: set() for k in rules.keys()} trees_by_rules = {k: nx.DiGraph() for k in rules.keys()} @@ -351,15 +361,15 @@ def load_rules( rules_path: str = None, rules: pl.LazyFrame = None, label_col: str = "name", - alias_col: str = "alias", + parent_col: str = "alias", rules_col: str = "rule", allow_visualize_functions: bool = False, ) -> Tuple[Dict[str, Expr], Dict[str, Expr]]: """ - Assumes TSV has columns at least: name, rule + Assumes TSV has columns at least: name, parent, child Convention used here: - - if `name` is non-empty: this is an OUTPUT RULE whose expression is in `rule` - - if `name` is empty and `alias` is non-empty: this is a DEFINITION macro: alias := rule + - if `name` is non-empty: this is an OUTPUT RULE whose expression is in `child` + - if `name` is empty and `parent` is non-empty: this is a DEFINITION macro: parent := child If your file uses a slightly different convention, adjust this function only. """ assert (rules_path is not None) != (rules is not None), ( @@ -380,7 +390,7 @@ def load_rules( f"rules TSV missing required columns {required}. Found: {lf.columns}" ) - has_alias_col = alias_col and (alias_col in cols) + has_parent_col = parent_col and (parent_col in cols) with open(Path(__file__).parent.absolute() / "rules.lark") as f: parser = Lark( @@ -431,11 +441,11 @@ def parse_rule_expr(expr_str: str) -> Expr: } definitions = {} - if has_alias_col: + if has_parent_col: definitions = { a: b - for a, b in lf.filter(~pl.col(alias_col).is_null()) - .select([pl.col(alias_col), pl.col(rules_col)]) + for a, b in lf.filter(~pl.col(parent_col).is_null()) + .select([pl.col(parent_col), pl.col(rules_col)]) .collect() .iter_rows() } @@ -567,7 +577,7 @@ def recurse(expr: Expr, add_name_to_needed: bool = True) -> Expr: out = PipeChain(calls=tuple(calls)) else: raise TypeError(expr) - # Some nodes need validation and have to be done after rules are expanded + # Some nodes need validation and have to be done after children are expanded try: out.validate() except Exception: @@ -593,8 +603,19 @@ def build_present_map( """Build present_map of needed gene_ids from annotations DataFrame""" additional_cols = additional_cols or [] besthit_cols = [col for col in besthit_cols if col in lf.columns] + besthit_cols.extend( + [ + col + for col in lf.columns + if (col.endswith("_id") and col not in besthit_cols and col != "query_id") + ] + ) for col in besthit_cols: - lf = lf.with_columns(ID_EXPR_DICT[col].alias(col)).explode(col) + if col not in ID_EXPR_DICT: + explode_col = ID_EXPR_DICT["DEFAULT"](col).alias(col) + else: + explode_col = ID_EXPR_DICT[col].alias(col) + lf = lf.with_columns(explode_col).explode(col) lf = lf.select([sample_col] + besthit_cols + additional_cols) # unpivot to long (sample, hit) @@ -833,7 +854,7 @@ def not_(x: np.ndarray = None) -> np.ndarray | pl.Expr: if isinstance(x, pl.Expr): x = [x] masks = x - if len(masks) == 0: + if len(masks) == 1: mask = masks[0] else: mask = masks[0].and_(*[mask for mask in masks[1:]])