From 6cbc6dc0f0604f592ca5cce8bd0c8b9f85303820 Mon Sep 17 00:00:00 2001 From: Leif Hedstrom Date: Wed, 11 Mar 2026 23:20:15 -0700 Subject: [PATCH] hrw4u: Add code coverage support for tests - Adds some new tests, for additional coverage (e.g. debug runs) - Eliminates some dead code discovered when there's no way to get the coverage over such code. --- tools/hrw4u/.gitignore | 2 + tools/hrw4u/Makefile | 12 +- tools/hrw4u/pyproject.toml | 36 +++ tools/hrw4u/src/common.py | 13 +- tools/hrw4u/src/debugging.py | 11 - tools/hrw4u/src/errors.py | 8 - tools/hrw4u/src/generators.py | 43 --- tools/hrw4u/src/interning.py | 10 - tools/hrw4u/src/symbols_base.py | 34 +-- tools/hrw4u/src/validation.py | 26 -- tools/hrw4u/src/visitor_base.py | 101 ------- tools/hrw4u/tests/test_common.py | 271 ++++++++++++++++++ tools/hrw4u/tests/test_debug_mode.py | 43 +++ tools/hrw4u/tests/test_tables.py | 149 ++++++++++ tools/hrw4u/tests/test_units.py | 397 ++++++++++++++++++++++++++- tools/hrw4u/tests/utils.py | 8 +- 16 files changed, 913 insertions(+), 251 deletions(-) create mode 100644 tools/hrw4u/tests/test_common.py create mode 100644 tools/hrw4u/tests/test_debug_mode.py create mode 100644 tools/hrw4u/tests/test_tables.py diff --git a/tools/hrw4u/.gitignore b/tools/hrw4u/.gitignore index c61b1049d77..488b8e281a3 100644 --- a/tools/hrw4u/.gitignore +++ b/tools/hrw4u/.gitignore @@ -1,3 +1,5 @@ build/ dist/ uv.lock +htmlcov/ +.coverage diff --git a/tools/hrw4u/Makefile b/tools/hrw4u/Makefile index 33ab62873c2..141555e2f35 100644 --- a/tools/hrw4u/Makefile +++ b/tools/hrw4u/Makefile @@ -102,7 +102,7 @@ INIT_HRW4U=$(PKG_DIR_HRW4U)/__init__.py INIT_U4WRH=$(PKG_DIR_U4WRH)/__init__.py INIT_LSP=$(PKG_DIR_LSP)/__init__.py -.PHONY: all gen gen-fwd gen-inv copy-src test clean build package env setup-deps activate update +.PHONY: all gen gen-fwd gen-inv copy-src test clean build package env setup-deps activate update coverage coverage-open all: gen @@ -167,6 +167,14 @@ $(PKG_DIR_LSP)/%: src/% test: uv run pytest --tb=short tests +coverage: + uv run pytest --cov --cov-report=term-missing --cov-report=html tests + @echo "" + @echo "HTML report: open htmlcov/index.html" + +coverage-open: coverage + uv run python -m webbrowser "file://$(shell pwd)/htmlcov/index.html" + # Build standalone binaries (optional) build: gen uv run pyinstaller --onefile --name hrw4u --strip $(SCRIPT_HRW4U) @@ -180,7 +188,7 @@ package: gen uv run python -m build --wheel --outdir $(DIST_DIR) clean: - rm -rf build dist __pycache__ *.spec *.egg-info .venv + rm -rf build dist __pycache__ *.spec *.egg-info .venv htmlcov .coverage find tests -name '__pycache__' -type d -exec rm -r {} + setup-deps: diff --git a/tools/hrw4u/pyproject.toml b/tools/hrw4u/pyproject.toml index ce607fd5b69..1a38e4145a8 100644 --- a/tools/hrw4u/pyproject.toml +++ b/tools/hrw4u/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest>=7.0,<8.0", + "pytest-cov>=4.1.0", "pyinstaller>=5.0,<7.0", "build>=0.8,<2.0", ] @@ -83,4 +84,39 @@ dev = [ "build>=1.4.0", "pyinstaller>=6.18.0", "pytest>=7.4.4", + "pytest-cov>=4.1.0", ] + +[tool.coverage.run] +source = [ + "hrw4u", + "u4wrh", +] +omit = [ + # ANTLR-generated files (not meaningful to cover) + "*/hrw4uLexer.py", + "*/hrw4uParser.py", + "*/hrw4uVisitor.py", + "*/u4wrhLexer.py", + "*/u4wrhParser.py", + "*/u4wrhVisitor.py", + # Unused/experimental modules + "*/kg_visitor.py", + # Fuzzy-matching suggestion engine (rapidfuzz dependency, hard to test meaningfully) + "*/suggestions.py", + # Package boilerplate + "*/__init__.py", + "*/__main__.py", +] + +[tool.coverage.report] +show_missing = true +skip_empty = true +exclude_lines = [ + "pragma: no cover", + "if __name__ == .__main__.", + "if TYPE_CHECKING:", +] + +[tool.coverage.html] +directory = "htmlcov" diff --git a/tools/hrw4u/src/common.py b/tools/hrw4u/src/common.py index d694f0d58de..5e9bf9ac3a2 100644 --- a/tools/hrw4u/src/common.py +++ b/tools/hrw4u/src/common.py @@ -71,7 +71,7 @@ class HeaderOperations: DESTINATION_OPERATIONS: Final = (MagicStrings.RM_DESTINATION.value, MagicStrings.SET_DESTINATION.value) -class LexerProtocol(Protocol): +class LexerProtocol(Protocol): # pragma: no cover """Protocol for ANTLR lexers.""" def removeErrorListeners(self) -> None: @@ -81,7 +81,7 @@ def addErrorListener(self, listener: Any) -> None: ... -class ParserProtocol(Protocol): +class ParserProtocol(Protocol): # pragma: no cover """Protocol for ANTLR parsers.""" def removeErrorListeners(self) -> None: @@ -96,7 +96,7 @@ def program(self) -> Any: errorHandler: BailErrorStrategy | DefaultErrorStrategy -class VisitorProtocol(Protocol): +class VisitorProtocol(Protocol): # pragma: no cover """Protocol for ANTLR visitors.""" def visit(self, tree: Any) -> list[str]: @@ -112,12 +112,7 @@ def fatal(message: str) -> NoReturn: def create_base_parser(description: str) -> tuple[argparse.ArgumentParser, argparse._MutuallyExclusiveGroup]: """Create base argument parser with common options.""" parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument( - "input_file", - help="The input file to parse (default: stdin)", - nargs="?", - type=argparse.FileType("r", encoding="utf-8"), - default=sys.stdin) + parser.add_argument("input_file", help="Optional input file path (default: reads from stdin)", nargs="?", default=None) output_group = parser.add_mutually_exclusive_group() output_group.add_argument("--ast", action="store_true", help="Produce the ANTLR parse tree only") diff --git a/tools/hrw4u/src/debugging.py b/tools/hrw4u/src/debugging.py index 5a15f39718e..d948ba441e1 100644 --- a/tools/hrw4u/src/debugging.py +++ b/tools/hrw4u/src/debugging.py @@ -18,7 +18,6 @@ from __future__ import annotations import sys -import types from .common import SystemDefaults @@ -36,16 +35,6 @@ def __call__(self, msg: str, *, levels: bool = False, out: bool = False) -> None msg = f"" if out else f"<{msg}>" print(f"{SystemDefaults.DEBUG_PREFIX} {' ' * (self.indent * SystemDefaults.INDENT_SPACES)}{msg}", file=sys.stderr) - def __enter__(self) -> "Dbg": - if self.enabled: - self.indent += 1 - return self - - def __exit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: types.TracebackType | None) -> None: - if self.enabled: - self.indent = max(0, self.indent - 1) - def enter(self, msg: str) -> None: if self.enabled: self(msg, levels=True) diff --git a/tools/hrw4u/src/errors.py b/tools/hrw4u/src/errors.py index a2710fb840f..5c2d38b0acb 100644 --- a/tools/hrw4u/src/errors.py +++ b/tools/hrw4u/src/errors.py @@ -55,14 +55,6 @@ def __init__(self, filename: str, line: int, column: int, message: str, source_l self.column = column self.source_line = source_line - def add_context_note(self, context: str) -> None: - """Add contextual information using Python 3.11+ exception notes.""" - self.add_note(f"Context: {context}") - - def add_resolution_hint(self, hint: str) -> None: - """Add resolution hint using Python 3.11+ exception notes.""" - self.add_note(f"Hint: {hint}") - def _format_error(self, filename: str, line: int, col: int, message: str, source_line: str) -> str: error = f"{filename}:{line}:{col}: error: {message}" diff --git a/tools/hrw4u/src/generators.py b/tools/hrw4u/src/generators.py index 31a547cd71a..20c00ce35b4 100644 --- a/tools/hrw4u/src/generators.py +++ b/tools/hrw4u/src/generators.py @@ -25,7 +25,6 @@ from __future__ import annotations from typing import Any -from functools import cache from hrw4u.states import SectionType @@ -40,33 +39,6 @@ def _clean_tag(tag: str) -> str: """Extract clean tag name from %{TAG:payload} format.""" return tag.strip().removeprefix('%{').removesuffix('}').split(':')[0] - def generate_reverse_condition_map(self, condition_map: tuple[tuple[str, Any], ...]) -> dict[str, str]: - """Generate reverse condition mapping from forward condition map.""" - reverse_map = {} - - for ident_key, params in condition_map: - if not ident_key.endswith('.'): - tag = params.target if params else None - if tag: - clean_tag = self._clean_tag(tag) - reverse_map[clean_tag] = ident_key - - return reverse_map - - def generate_reverse_function_map(self, function_map: tuple[tuple[str, Any], ...]) -> dict[str, str]: - """Generate reverse function mapping from forward function map.""" - return {params.target: func_name for func_name, params in function_map} - - @cache - def generate_section_hook_mapping(self) -> dict[str, str]: - """Generate section name to hook name mapping.""" - return {section.value: section.hook_name for section in SectionType} - - @cache - def generate_hook_section_mapping(self) -> dict[str, str]: - """Generate hook name to section name mapping.""" - return {section.hook_name: section.value for section in SectionType} - def generate_ip_mapping(self) -> dict[str, str]: """Generate IP payload to identifier mapping from CONDITION_MAP.""" from hrw4u.tables import CONDITION_MAP @@ -161,21 +133,6 @@ def generate_complete_reverse_resolution_map(self) -> dict[str, Any]: _table_generator = TableGenerator() -def get_reverse_condition_map(condition_map: dict[str, tuple]) -> dict[str, str]: - """Get reverse condition mapping.""" - return _table_generator.generate_reverse_condition_map(tuple(condition_map.items())) - - -def get_reverse_function_map(function_map: dict[str, Any]) -> dict[str, str]: - """Get reverse function mapping.""" - return _table_generator.generate_reverse_function_map(tuple(function_map.items())) - - -def get_section_mappings() -> tuple[dict[str, str], dict[str, str]]: - """Get both section->hook and hook->section mappings.""" - return (_table_generator.generate_section_hook_mapping(), _table_generator.generate_hook_section_mapping()) - - def get_complete_reverse_resolution_map() -> dict[str, Any]: """Get the complete generated reverse resolution map.""" return _table_generator.generate_complete_reverse_resolution_map() diff --git a/tools/hrw4u/src/interning.py b/tools/hrw4u/src/interning.py index 898e8f2462d..2f6e6e8d90f 100644 --- a/tools/hrw4u/src/interning.py +++ b/tools/hrw4u/src/interning.py @@ -76,11 +76,6 @@ def intern_lsp_string(cls, string: str) -> str: """Intern an LSP-related string, returning the interned version if available.""" return cls.LSP_STRINGS.get(string, sys.intern(string)) - @classmethod - def intern_any(cls, string: str) -> str: - """General-purpose string interning with fallback to sys.intern().""" - return sys.intern(string) - def intern_keyword(keyword: str) -> str: """Intern language keywords.""" @@ -105,8 +100,3 @@ def intern_modifier(modifier: str) -> str: def intern_lsp_string(string: str) -> str: """Intern LSP-related strings.""" return StringInterning.intern_lsp_string(string) - - -def intern_any(string: str) -> str: - """General-purpose string interning.""" - return StringInterning.intern_any(string) diff --git a/tools/hrw4u/src/symbols_base.py b/tools/hrw4u/src/symbols_base.py index 1fb95cd60c0..5d213167246 100644 --- a/tools/hrw4u/src/symbols_base.py +++ b/tools/hrw4u/src/symbols_base.py @@ -17,7 +17,7 @@ from __future__ import annotations from functools import cached_property, lru_cache -from typing import Callable, Any +from typing import Any from hrw4u.debugging import Dbg from hrw4u.states import SectionType from hrw4u.common import SystemDefaults @@ -30,11 +30,6 @@ class SymbolResolverBase: def __init__(self, debug: bool = SystemDefaults.DEFAULT_DEBUG) -> None: self._dbg = Dbg(debug) - # Clear caches when debug status changes to ensure consistency - if hasattr(self, '_condition_cache'): - self._condition_cache.cache_clear() - if hasattr(self, '_operator_cache'): - self._operator_cache.cache_clear() # Cached table access for performance - Python 3.11+ cached_property @cached_property @@ -53,10 +48,6 @@ def _function_map(self) -> dict[str, types.MapParams]: def _statement_function_map(self) -> dict[str, types.MapParams]: return tables.STATEMENT_FUNCTION_MAP - @cached_property - def _reverse_resolution_map(self) -> dict[str, Any]: - return tables.REVERSE_RESOLUTION_MAP - def validate_section_access(self, name: str, section: SectionType | None, allowed_sections: set[SectionType] | None) -> None: if section and allowed_sections and section not in allowed_sections: raise SymbolResolutionError(name, f"{name} is not available in the {section.value} section") @@ -65,10 +56,6 @@ def validate_section_access(self, name: str, section: SectionType | None, allowe def _lookup_condition_cached(self, name: str) -> types.MapParams | None: return self._condition_map.get(name) - @lru_cache(maxsize=256) - def _lookup_operator_cached(self, name: str) -> types.MapParams | None: - return self._operator_map.get(name) - @lru_cache(maxsize=128) def _lookup_function_cached(self, name: str) -> types.MapParams | None: return self._function_map.get(name) @@ -90,18 +77,6 @@ def _debug_exit(self, method_name: str, result: Any = None) -> None: else: self._dbg.exit(method_name) - def _debug_log(self, message: str) -> None: - self._dbg(message) - - def _create_symbol_error(self, symbol_name: str, message: str) -> SymbolResolutionError: - return SymbolResolutionError(symbol_name, message) - - def _handle_unknown_symbol(self, symbol_name: str, symbol_type: str) -> SymbolResolutionError: - return self._create_symbol_error(symbol_name, f"Unknown {symbol_type}: '{symbol_name}'") - - def _handle_validation_error(self, symbol_name: str, validation_message: str) -> SymbolResolutionError: - return self._create_symbol_error(symbol_name, validation_message) - def find_prefix_matches(self, target: str, table: dict[str, Any]) -> list[tuple[str, Any]]: matches = [] for key, value in table.items(): @@ -109,13 +84,6 @@ def find_prefix_matches(self, target: str, table: dict[str, Any]) -> list[tuple[ matches.append((key, value)) return matches - def get_longest_prefix_match(self, target: str, table: dict[str, Any]) -> tuple[str, Any] | None: - matches = self.find_prefix_matches(target, table) - if not matches: - return None - matches.sort(key=lambda x: len(x[0]), reverse=True) - return matches[0] - def debug_context(self, method_name: str, *args: Any): class DebugContext: diff --git a/tools/hrw4u/src/validation.py b/tools/hrw4u/src/validation.py index decd045bcb4..1b6d35e4410 100644 --- a/tools/hrw4u/src/validation.py +++ b/tools/hrw4u/src/validation.py @@ -19,7 +19,6 @@ import re from typing import Callable from hrw4u.errors import SymbolResolutionError -from hrw4u import states import hrw4u.types as types from hrw4u.common import RegexPatterns @@ -70,12 +69,6 @@ def http_token(self) -> 'ValidatorChain': def http_header_name(self) -> 'ValidatorChain': return self._add(self._wrap_args(Validator.http_header_name())) - def simple_token(self) -> 'ValidatorChain': - return self._add(self._wrap_args(Validator.simple_token())) - - def regex_literal(self) -> 'ValidatorChain': - return self._add(self._wrap_args(Validator.regex_literal())) - def nbit_int(self, nbits: int) -> 'ValidatorChain': return self._add(self._wrap_args(Validator.nbit_int(nbits))) @@ -164,16 +157,6 @@ def validator(value: str) -> None: return validator - @staticmethod - def simple_token() -> Callable[[str], None]: - """Validate simple tokens (letters, digits, underscore, dash).""" - return Validator.regex_validator(Validator._SIMPLE_TOKEN_RE, "Must be a simple token (letters, digits, underscore, dash)") - - @staticmethod - def regex_literal() -> Callable[[str], None]: - """Validate regex literals in /pattern/ format.""" - return Validator.regex_validator(Validator._REGEX_LITERAL_RE, "Must be a valid regex literal in /pattern/ format") - @staticmethod def http_token() -> Callable[[str], None]: """Validate HTTP tokens according to RFC 7230.""" @@ -245,15 +228,6 @@ def needs_quotes(value: str) -> bool: def quote_if_needed(value: str) -> str: return f'"{value}"' if Validator.needs_quotes(value) else value - @staticmethod - def logic_modifier() -> Callable[[str], None]: - - def validator(value: str) -> None: - if value.upper() not in states.ALL_MODIFIERS: - raise SymbolResolutionError(value, f"Invalid logic modifier: {value}") - - return validator - @staticmethod def percent_block() -> Callable[[str], None]: diff --git a/tools/hrw4u/src/visitor_base.py b/tools/hrw4u/src/visitor_base.py index e7d1e88d2e0..6af1c1b2a92 100644 --- a/tools/hrw4u/src/visitor_base.py +++ b/tools/hrw4u/src/visitor_base.py @@ -32,37 +32,6 @@ def format_with_indent(self, text: str, indent_level: int) -> str: return " " * (indent_level * SystemDefaults.INDENT_SPACES) + text -class ErrorHandler: - - @staticmethod - def handle_visitor_error(filename: str, ctx: object, exc: Exception, error_collector=None, return_value: str = "") -> str: - """Standard error handling for visitor methods.""" - from hrw4u.errors import hrw4u_error - - error = hrw4u_error(filename, ctx, exc) - if error_collector: - error_collector.add_error(error) - return return_value - else: - raise error - - @staticmethod - def handle_symbol_error(filename: str, ctx: object, symbol_name: str, exc: Exception, error_collector=None) -> str | None: - """Handle symbol resolution errors with context.""" - from hrw4u.errors import hrw4u_error, SymbolResolutionError - - if isinstance(exc, SymbolResolutionError): - error = hrw4u_error(filename, ctx, exc) - else: - error = hrw4u_error(filename, ctx, f"symbol error in '{symbol_name}': {exc}") - - if error_collector: - error_collector.add_error(error) - return f"ERROR({symbol_name})" - else: - raise error - - @dataclass(slots=True) class VisitorState: """Encapsulates visitor state that is commonly tracked across visitor implementations.""" @@ -93,31 +62,6 @@ def _initialize_visitor(self) -> None: """Hook for subclass-specific initialization. Override as needed.""" pass - # Error handling patterns - centralized and consistent - def handle_visitor_error(self, ctx: Any, exc: Exception, return_value: str = "") -> str: - """ - Standard error handling for visitor methods. - """ - return ErrorHandler.handle_visitor_error(self.filename, ctx, exc, self.error_collector, return_value) - - def handle_symbol_error(self, ctx: Any, symbol_name: str, exc: Exception) -> str | None: - """ - Handle symbol resolution errors with context. - """ - return ErrorHandler.handle_symbol_error(self.filename, ctx, symbol_name, exc, self.error_collector) - - def safe_visit_with_error_handling(self, method_name: str, ctx: Any, visit_func, *args, **kwargs): - """ - Generic wrapper for visitor methods with consistent error handling. - """ - try: - self._dbg.enter(method_name) - return visit_func(*args, **kwargs) - except Exception as exc: - return self.handle_visitor_error(ctx, exc) - finally: - self._dbg.exit(method_name) - # State management - common patterns @property def current_section(self) -> SectionType | None: @@ -177,14 +121,6 @@ def decrement_stmt_indent(self) -> None: """Decrement statement indentation level (with bounds checking).""" self._state.stmt_indent = max(0, self._state.stmt_indent - 1) - def increment_cond_indent(self) -> None: - """Increment condition indentation level.""" - self._state.cond_indent += 1 - - def decrement_cond_indent(self) -> None: - """Decrement condition indentation level (with bounds checking).""" - self._state.cond_indent = max(0, self._state.cond_indent - 1) - # Output management patterns def emit_line(self, text: str, indent_level: int | None = None) -> None: """ @@ -227,27 +163,6 @@ def is_debug(self) -> bool: """Check if debug mode is enabled.""" return self._dbg.enabled - # Context managers for common patterns - def indent_context(self, increment: int = 1): - """Context manager for temporary indentation changes.""" - - class IndentContext: - - def __init__(self, visitor: BaseHRWVisitor, inc: int): - self.visitor = visitor - self.increment = inc - - def __enter__(self): - for _ in range(self.increment): - self.visitor.increment_stmt_indent() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - for _ in range(self.increment): - self.visitor.decrement_stmt_indent() - - return IndentContext(self, increment) - def debug_context(self, method_name: str, *args: Any): """Context manager for debug entry/exit around operations.""" @@ -327,22 +242,6 @@ def _finalize_output(self) -> None: """ pass - # Utility methods for common visitor operations - def should_continue_on_error(self) -> bool: - """Determine if processing should continue when errors occur.""" - return self.error_collector is not None - - def get_error_summary(self) -> str | None: - """Get error summary if errors were collected.""" - if self.error_collector and self.error_collector.has_errors(): - return self.error_collector.get_error_summary() - return None - - def reset_state(self) -> None: - """Reset visitor state for reuse.""" - self.output.clear() - self._state = VisitorState() - def debug(self, message: str) -> None: """Alias for debug_log for backward compatibility.""" self.debug_log(message) diff --git a/tools/hrw4u/tests/test_common.py b/tools/hrw4u/tests/test_common.py new file mode 100644 index 00000000000..d17cdf6ad5d --- /dev/null +++ b/tools/hrw4u/tests/test_common.py @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for common.py functions called in-process for coverage visibility. + +The existing test_cli.py tests these paths via subprocess, which coverage +cannot track. These tests call the same functions directly so coverage +can see them. +""" +from __future__ import annotations + +import io +from types import SimpleNamespace + +import pytest +from hrw4u.common import ( + create_base_parser, + create_parse_tree, + fatal, + generate_output, + process_input, + run_main, +) +from hrw4u.errors import ErrorCollector, Hrw4uSyntaxError +from hrw4u.hrw4uLexer import hrw4uLexer +from hrw4u.hrw4uParser import hrw4uParser +from hrw4u.visitor import HRW4UVisitor + +# --------------------------------------------------------------------------- +# Approach A: Unit tests for individual building-block functions +# --------------------------------------------------------------------------- + + +class TestFatal: + """Tests for the fatal() helper.""" + + def test_fatal_exits_with_code_1(self, capsys): + with pytest.raises(SystemExit) as exc_info: + fatal("something broke") + assert exc_info.value.code == 1 + assert "something broke" in capsys.readouterr().err + + +class TestCreateBaseParser: + """Tests for create_base_parser().""" + + def test_returns_parser_and_group(self): + parser, group = create_base_parser("test description") + assert parser is not None + assert group is not None + + def test_parser_has_expected_arguments(self): + parser, _ = create_base_parser("test") + args = parser.parse_args(["--ast", "--debug", "--stop-on-error"]) + assert args.ast is True + assert args.debug is True + assert args.stop_on_error is True + + def test_parser_defaults(self): + parser, _ = create_base_parser("test") + args = parser.parse_args([]) + assert args.ast is False + assert args.debug is False + assert args.stop_on_error is False + + +class TestProcessInput: + """Tests for process_input().""" + + def test_stdin_returns_default_filename(self): + fake_stdin = io.StringIO("hello world") + fake_stdin.name = "" + content, filename = process_input(fake_stdin) + assert content == "hello world" + assert filename == "" + + def test_file_input_returns_real_filename(self, tmp_path): + p = tmp_path / "test.hrw4u" + p.write_text("set-header X-Foo bar") + with open(p, "r", encoding="utf-8") as f: + content, filename = process_input(f) + assert content == "set-header X-Foo bar" + assert str(p) in filename + + +class TestCreateParseTree: + """Tests for create_parse_tree().""" + + def test_valid_input_with_error_collection(self): + tree, parser_obj, errors = create_parse_tree( + 'REMAP { no-op(); }', "", hrw4uLexer, hrw4uParser, "hrw4u", collect_errors=True) + assert tree is not None + assert errors is not None + assert not errors.has_errors() + + def test_valid_input_without_error_collection(self): + tree, parser_obj, errors = create_parse_tree( + 'REMAP { no-op(); }', "", hrw4uLexer, hrw4uParser, "hrw4u", collect_errors=False) + assert tree is not None + assert errors is None + + def test_invalid_input_collects_errors(self): + tree, parser_obj, errors = create_parse_tree( + '{{{{ totally broken syntax !!! }}}}', "", hrw4uLexer, hrw4uParser, "hrw4u", collect_errors=True) + assert errors is not None + assert errors.has_errors() + + +class TestGenerateOutput: + """Tests for generate_output().""" + + def test_normal_output(self, capsys): + tree, parser_obj, errors = create_parse_tree( + 'REMAP { no-op(); }', "", hrw4uLexer, hrw4uParser, "hrw4u", collect_errors=True) + args = SimpleNamespace(ast=False, debug=False, no_comments=False) + generate_output(tree, parser_obj, HRW4UVisitor, "", args, errors) + out = capsys.readouterr().out + assert "no-op" in out + + def test_ast_output(self, capsys): + tree, parser_obj, errors = create_parse_tree( + 'REMAP { no-op(); }', "", hrw4uLexer, hrw4uParser, "hrw4u", collect_errors=True) + args = SimpleNamespace(ast=True, debug=False) + generate_output(tree, parser_obj, HRW4UVisitor, "", args, errors) + out = capsys.readouterr().out + assert "program" in out.lower() or "(" in out + + def test_ast_mode_with_parse_errors(self, capsys): + tree, parser_obj, errors = create_parse_tree( + '{{{{ broken }}}}', "", hrw4uLexer, hrw4uParser, "hrw4u", collect_errors=True) + args = SimpleNamespace(ast=True, debug=False) + generate_output(tree, parser_obj, HRW4UVisitor, "", args, errors) + captured = capsys.readouterr() + assert errors.has_errors() + + def test_error_collector_summary_on_errors(self, capsys): + tree, parser_obj, errors = create_parse_tree( + '{{{{ broken }}}}', "", hrw4uLexer, hrw4uParser, "hrw4u", collect_errors=True) + args = SimpleNamespace(ast=False, debug=False, no_comments=False) + generate_output(tree, parser_obj, HRW4UVisitor, "", args, errors) + err = capsys.readouterr().err + assert "error" in err.lower() or "Error" in err + + def test_ast_mode_tree_none_with_errors(self, capsys): + """When tree is None and errors exist, AST mode prints fallback message.""" + errors = ErrorCollector() + errors.add_error(Hrw4uSyntaxError("", 1, 0, "parse failed", "bad")) + args = SimpleNamespace(ast=True, debug=False) + generate_output(None, None, HRW4UVisitor, "", args, errors) + out = capsys.readouterr().out + assert "Parse tree not available" in out + + def test_error_collector_exits_on_parse_failure(self, capsys): + """When tree is None and errors exist in non-AST mode, should exit(1).""" + errors = ErrorCollector() + errors.add_error(Hrw4uSyntaxError("", 1, 0, "parse failed", "bad")) + args = SimpleNamespace(ast=False, debug=False, no_comments=False) + with pytest.raises(SystemExit) as exc_info: + generate_output(None, None, HRW4UVisitor, "", args, errors) + assert exc_info.value.code == 1 + + def test_visitor_exception_collected(self, capsys): + """When visitor.visit() raises, error is collected and reported.""" + + class BrokenVisitor: + + def __init__(self, **kwargs): + pass + + def visit(self, tree): + exc = RuntimeError("visitor exploded") + exc.add_note("hint: check input") + raise exc + + tree, parser_obj, errors = create_parse_tree( + 'REMAP { no-op(); }', "", hrw4uLexer, hrw4uParser, "hrw4u", collect_errors=True) + args = SimpleNamespace(ast=False, debug=False, no_comments=False) + generate_output(tree, parser_obj, BrokenVisitor, "", args, errors) + err = capsys.readouterr().err + assert "visitor exploded" in err.lower() or "Visitor error" in err + + +# --------------------------------------------------------------------------- +# Approach B: run_main() called in-process via monkeypatch +# --------------------------------------------------------------------------- + + +class TestRunMain: + """Tests for run_main() covering the CLI orchestration code.""" + + def _run(self, monkeypatch, capsys, argv, stdin_text=None): + """Helper to invoke run_main() with patched sys.argv and optional stdin.""" + monkeypatch.setattr("sys.argv", ["hrw4u"] + argv) + if stdin_text is not None: + monkeypatch.setattr("sys.stdin", io.StringIO(stdin_text)) + run_main("HRW4U test", hrw4uLexer, hrw4uParser, HRW4UVisitor, "hrw4u", "hrw", "Produce header_rewrite output") + return capsys.readouterr() + + def test_stdin_to_stdout(self, monkeypatch, capsys): + captured = self._run(monkeypatch, capsys, [], stdin_text='REMAP { no-op(); }') + assert "no-op" in captured.out + + def test_single_file_to_stdout(self, monkeypatch, capsys, tmp_path): + p = tmp_path / "test.hrw4u" + p.write_text('REMAP { inbound.req.X-Hello = "world"; }') + captured = self._run(monkeypatch, capsys, [str(p)]) + assert "X-Hello" in captured.out + + def test_multiple_files_with_separator(self, monkeypatch, capsys, tmp_path): + f1 = tmp_path / "a.hrw4u" + f1.write_text('REMAP { no-op(); }') + f2 = tmp_path / "b.hrw4u" + f2.write_text('REMAP { inbound.req.X-B = "val"; }') + captured = self._run(monkeypatch, capsys, [str(f1), str(f2)]) + assert "# ---" in captured.out + assert "no-op" in captured.out + assert "X-B" in captured.out + + def test_bulk_input_output_pairs(self, monkeypatch, capsys, tmp_path): + inp = tmp_path / "in.hrw4u" + inp.write_text('REMAP { no-op(); }') + out = tmp_path / "out.conf" + self._run(monkeypatch, capsys, [f"{inp}:{out}"]) + assert out.exists() + assert "no-op" in out.read_text() + + def test_bulk_nonexistent_input(self, monkeypatch, capsys, tmp_path): + out = tmp_path / "out.conf" + with pytest.raises(SystemExit) as exc_info: + self._run(monkeypatch, capsys, [f"/no/such/file.hrw4u:{out}"]) + assert exc_info.value.code == 1 + + def test_mixed_format_error(self, monkeypatch, capsys, tmp_path): + f1 = tmp_path / "a.hrw4u" + f1.write_text('REMAP { no-op(); }') + out = tmp_path / "out.conf" + with pytest.raises(SystemExit) as exc_info: + self._run(monkeypatch, capsys, [str(f1), f"{f1}:{out}"]) + assert exc_info.value.code == 1 + + def test_ast_mode(self, monkeypatch, capsys, tmp_path): + p = tmp_path / "test.hrw4u" + p.write_text('REMAP { no-op(); }') + captured = self._run(monkeypatch, capsys, ["--ast", str(p)]) + assert "program" in captured.out.lower() or "(" in captured.out + + def test_no_comments_flag(self, monkeypatch, capsys): + captured = self._run(monkeypatch, capsys, ["--no-comments"], stdin_text='REMAP { no-op(); }') + assert "no-op" in captured.out + + def test_stop_on_error_flag(self, monkeypatch, capsys): + captured = self._run(monkeypatch, capsys, ["--stop-on-error"], stdin_text='REMAP { no-op(); }') + assert "no-op" in captured.out + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tools/hrw4u/tests/test_debug_mode.py b/tools/hrw4u/tests/test_debug_mode.py new file mode 100644 index 00000000000..0977c9939f7 --- /dev/null +++ b/tools/hrw4u/tests/test_debug_mode.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Debug-mode tests: re-run the examples group with debug=True. + +The examples group exercises the most diverse visitor code paths +(conditions, operators, hooks, vars) and is sufficient to reach +100% coverage of the Dbg class. Running all groups in debug mode +is redundant since debug tracing doesn't affect output correctness. +""" +from __future__ import annotations + +from pathlib import Path + +import pytest +import utils + + +@pytest.mark.parametrize("input_file,output_file", utils.collect_output_test_files("examples", "hrw4u")) +def test_examples_debug(input_file: Path, output_file: Path) -> None: + """Test hrw4u examples output matches with debug enabled.""" + utils.run_output_test(input_file, output_file, debug=True) + + +@pytest.mark.reverse +@pytest.mark.parametrize("input_file,output_file", utils.collect_output_test_files("examples", "u4wrh")) +def test_examples_reverse_debug(input_file: Path, output_file: Path) -> None: + """Test u4wrh examples reverse conversion with debug enabled.""" + utils.run_reverse_test(input_file, output_file, debug=True) diff --git a/tools/hrw4u/tests/test_tables.py b/tools/hrw4u/tests/test_tables.py new file mode 100644 index 00000000000..fe3163400d4 --- /dev/null +++ b/tools/hrw4u/tests/test_tables.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for LSPPatternMatcher in tables.py.""" +from __future__ import annotations + +import pytest +from hrw4u.tables import LSPPatternMatcher + + +class TestMatchFieldPattern: + + def test_now_field(self): + match = LSPPatternMatcher.match_field_pattern("now.YEAR") + assert match is not None + assert match.pattern == "now." + assert match.suffix == "YEAR" + assert match.context_type == "Current Date/Time Field" + assert match.maps_to == "%{NOW:YEAR}" + + def test_id_field(self): + match = LSPPatternMatcher.match_field_pattern("id.PROCESS") + assert match is not None + assert match.pattern == "id." + assert match.field_dict_key == "ID_FIELDS" + + def test_geo_field(self): + match = LSPPatternMatcher.match_field_pattern("geo.COUNTRY") + assert match is not None + assert match.pattern == "geo." + assert match.maps_to == "%{GEO:COUNTRY}" + + def test_no_match(self): + assert LSPPatternMatcher.match_field_pattern("inbound.req.X-Foo") is None + + +class TestMatchHeaderPattern: + + def test_inbound_req(self): + match = LSPPatternMatcher.match_header_pattern("inbound.req.X-Foo") + assert match is not None + assert match.context_type == "Header" + assert match.suffix == "X-Foo" + + def test_outbound_resp(self): + match = LSPPatternMatcher.match_header_pattern("outbound.resp.Content-Type") + assert match is not None + assert match.suffix == "Content-Type" + + def test_no_match(self): + assert LSPPatternMatcher.match_header_pattern("now.YEAR") is None + + +class TestMatchCookiePattern: + + def test_inbound_cookie(self): + match = LSPPatternMatcher.match_cookie_pattern("inbound.cookie.session_id") + assert match is not None + assert match.context_type == "Cookie" + assert match.suffix == "session_id" + + def test_outbound_cookie(self): + match = LSPPatternMatcher.match_cookie_pattern("outbound.cookie.token") + assert match is not None + assert match.suffix == "token" + + def test_no_match(self): + assert LSPPatternMatcher.match_cookie_pattern("inbound.req.X-Foo") is None + + +class TestMatchCertificatePattern: + + def test_inbound_client_cert(self): + match = LSPPatternMatcher.match_certificate_pattern("inbound.conn.client-cert.CN") + assert match is not None + assert match.context_type == "Certificate" + assert match.suffix == "CN" + + def test_outbound_server_cert(self): + match = LSPPatternMatcher.match_certificate_pattern("outbound.conn.server-cert.SAN") + assert match is not None + + def test_no_match(self): + assert LSPPatternMatcher.match_certificate_pattern("inbound.req.X-Foo") is None + + +class TestMatchConnectionPattern: + + def test_inbound_conn(self): + match = LSPPatternMatcher.match_connection_pattern("inbound.conn.TLS") + assert match is not None + assert match.context_type == "Connection" + assert match.field_dict_key == "CONN_FIELDS" + assert match.suffix == "TLS" + + def test_outbound_conn(self): + match = LSPPatternMatcher.match_connection_pattern("outbound.conn.H2") + assert match is not None + + def test_no_match(self): + assert LSPPatternMatcher.match_connection_pattern("inbound.req.X-Foo") is None + + +class TestMatchAnyPattern: + + def test_field(self): + match = LSPPatternMatcher.match_any_pattern("now.YEAR") + assert match is not None + assert match.context_type == "Current Date/Time Field" + + def test_header(self): + match = LSPPatternMatcher.match_any_pattern("inbound.req.X-Foo") + assert match is not None + assert match.context_type == "Header" + + def test_cookie(self): + match = LSPPatternMatcher.match_any_pattern("inbound.cookie.sid") + assert match is not None + assert match.context_type == "Cookie" + + def test_certificate(self): + match = LSPPatternMatcher.match_any_pattern("inbound.conn.client-cert.CN") + assert match is not None + assert match.context_type == "Certificate" + + def test_connection(self): + match = LSPPatternMatcher.match_any_pattern("inbound.conn.TLS") + assert match is not None + assert match.context_type == "Connection" + + def test_no_match(self): + assert LSPPatternMatcher.match_any_pattern("completely.unknown.thing") is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tools/hrw4u/tests/test_units.py b/tools/hrw4u/tests/test_units.py index cf5d7d0f2eb..7a5c19b367a 100644 --- a/tools/hrw4u/tests/test_units.py +++ b/tools/hrw4u/tests/test_units.py @@ -22,12 +22,12 @@ for internal implementation details. """ -from hrw4u.errors import ErrorCollector, Hrw4uSyntaxError, SymbolResolutionError +from hrw4u.errors import ErrorCollector, Hrw4uSyntaxError, SymbolResolutionError, ThrowingErrorListener, hrw4u_error, \ + CollectingErrorListener from hrw4u.visitor import HRW4UVisitor -from hrw4u.validation import Validator +from hrw4u.validation import Validator, ValidatorChain +from hrw4u.types import MapParams, VarType import pytest -import sys -import os class TestHRW4UVisitorUnits: @@ -284,5 +284,394 @@ def test_regex_validator_factory(self): validator("TEST123") +class TestMapParamsUnits: + """Unit tests for MapParams dunder methods.""" + + def test_repr_empty(self): + mp = MapParams() + assert repr(mp) == "MapParams()" + + def test_repr_with_flags(self): + mp = MapParams(upper=True, add=True) + r = repr(mp) + assert "upper=True" in r + assert "add=True" in r + + def test_repr_with_sections(self): + from hrw4u.states import SectionType + mp = MapParams(sections={SectionType.REMAP}) + assert "sections=" in repr(mp) + + def test_repr_with_validate(self): + mp = MapParams(validate=lambda x: None) + assert "validate=" in repr(mp) + + def test_repr_with_target(self): + mp = MapParams(target="set-header") + assert "target=..." in repr(mp) + + def test_hash_basic(self): + mp1 = MapParams(upper=True) + mp2 = MapParams(upper=True) + assert hash(mp1) == hash(mp2) + + def test_hash_with_sections(self): + from hrw4u.states import SectionType + mp = MapParams(sections={SectionType.REMAP}) + assert isinstance(hash(mp), int) + + def test_hash_with_rev_dict(self): + mp = MapParams(rev={"a": "b"}) + assert isinstance(hash(mp), int) + + def test_eq_same(self): + mp1 = MapParams(upper=True, prefix=True) + mp2 = MapParams(upper=True, prefix=True) + assert mp1 == mp2 + + def test_eq_different(self): + mp1 = MapParams(upper=True) + mp2 = MapParams(upper=False) + assert mp1 != mp2 + + def test_eq_not_mapparams(self): + mp = MapParams() + assert mp != "not a MapParams" + + def test_immutable(self): + mp = MapParams() + with pytest.raises(AttributeError, match="immutable"): + mp.upper = True + + def test_getattr_unknown(self): + mp = MapParams() + assert mp.unknown_attr is None + + def test_getattr_private(self): + mp = MapParams() + with pytest.raises(AttributeError): + _ = mp._private + + +class TestThrowingErrorListener: + """Unit tests for ThrowingErrorListener.""" + + def test_raises_syntax_error(self): + listener = ThrowingErrorListener("test.hrw4u") + + class FakeLexer: + inputStream = None + + with pytest.raises(Hrw4uSyntaxError) as exc_info: + listener.syntaxError(FakeLexer(), None, 1, 5, "unexpected token", None) + + err = exc_info.value + assert err.filename == "test.hrw4u" + assert err.line == 1 + assert err.column == 5 + + def test_extracts_source_line_from_lexer(self): + listener = ThrowingErrorListener("test.hrw4u") + + class FakeInputStream: + strdata = "first line\nsecond line\nthird line" + + class FakeLexer: + inputStream = FakeInputStream() + + with pytest.raises(Hrw4uSyntaxError) as exc_info: + listener.syntaxError(FakeLexer(), None, 2, 0, "bad token", None) + + assert exc_info.value.source_line == "second line" + + def test_extracts_source_line_from_parser(self): + listener = ThrowingErrorListener("test.hrw4u") + + class FakeInputStream: + strdata = "line one\nline two" + + class FakeTokenSource: + inputStream = FakeInputStream() + + class FakeStream: + tokenSource = FakeTokenSource() + + class FakeParser: + + def getInputStream(self): + return FakeStream() + + with pytest.raises(Hrw4uSyntaxError) as exc_info: + listener.syntaxError(FakeParser(), None, 1, 3, "parse error", None) + + assert exc_info.value.source_line == "line one" + + def test_falls_back_on_broken_recognizer(self): + listener = ThrowingErrorListener("test.hrw4u") + + class BrokenRecognizer: + pass + + with pytest.raises(Hrw4uSyntaxError) as exc_info: + listener.syntaxError(BrokenRecognizer(), None, 1, 0, "error", None) + + assert exc_info.value.source_line == "" + + +class TestHrw4uErrorFunction: + """Unit tests for hrw4u_error helper.""" + + def test_passthrough_syntax_error(self): + original = Hrw4uSyntaxError("f.hrw4u", 1, 0, "msg", "line") + assert hrw4u_error("f.hrw4u", None, original) is original + + def test_no_context(self): + exc = ValueError("something broke") + result = hrw4u_error("f.hrw4u", None, exc) + assert result.line == 0 + assert result.column == 0 + + def test_with_context(self): + + class FakeInputStream: + strdata = "some code here" + + class FakeToken: + line = 1 + column = 5 + + def getInputStream(self): + return FakeInputStream() + + class FakeCtx: + start = FakeToken() + + exc = ValueError("bad value") + result = hrw4u_error("f.hrw4u", FakeCtx(), exc) + assert result.line == 1 + assert result.column == 5 + assert result.source_line == "some code here" + + def test_with_context_broken_input_stream(self): + + class FakeToken: + line = 1 + column = 0 + + def getInputStream(self): + raise RuntimeError("broken") + + class FakeCtx: + start = FakeToken() + + exc = ValueError("oops") + result = hrw4u_error("f.hrw4u", FakeCtx(), exc) + assert result.source_line == "" + + def test_preserves_notes(self): + exc = ValueError("base error") + exc.add_note("hint: try X") + result = hrw4u_error("f.hrw4u", None, exc) + assert hasattr(result, '__notes__') + assert "hint: try X" in result.__notes__ + + +class TestErrorCollectorEdgeCases: + """Additional edge case tests for ErrorCollector.""" + + def test_empty_summary(self): + ec = ErrorCollector() + assert ec.get_error_summary() == "No errors found." + + def test_error_with_notes_in_summary(self): + ec = ErrorCollector() + err = Hrw4uSyntaxError("f.hrw4u", 1, 0, "bad", "code") + err.add_note("hint: fix it") + ec.add_error(err) + summary = ec.get_error_summary() + assert "hint: fix it" in summary + + +class TestCollectingErrorListener: + """Unit tests for CollectingErrorListener.""" + + def test_collects_errors(self): + ec = ErrorCollector() + listener = CollectingErrorListener("test.hrw4u", ec) + + class FakeLexer: + inputStream = None + + listener.syntaxError(FakeLexer(), None, 1, 0, "bad token", None) + assert ec.has_errors() + assert ec.errors[0].line == 1 + + def test_extracts_source_from_lexer(self): + ec = ErrorCollector() + listener = CollectingErrorListener("test.hrw4u", ec) + + class FakeInputStream: + strdata = "the source line" + + class FakeLexer: + inputStream = FakeInputStream() + + listener.syntaxError(FakeLexer(), None, 1, 5, "error", None) + assert ec.errors[0].source_line == "the source line" + + +class TestValidatorChainUnits: + """Unit tests for ValidatorChain convenience methods.""" + + def test_arg_at_valid(self): + chain = ValidatorChain() + validator = Validator.nbit_int(8) + chain.arg_at(1, validator) + chain(["foo", "42"]) + + def test_arg_at_missing_index(self): + chain = ValidatorChain() + validator = Validator.nbit_int(8) + chain.arg_at(5, validator) + with pytest.raises(SymbolResolutionError, match="Missing argument"): + chain(["foo"]) + + def test_nbit_int_valid(self): + v = Validator.nbit_int(8) + v("0") + v("255") + + def test_nbit_int_out_of_range(self): + v = Validator.nbit_int(8) + with pytest.raises(SymbolResolutionError, match="8-bit"): + v("256") + + def test_nbit_int_not_integer(self): + v = Validator.nbit_int(8) + with pytest.raises(SymbolResolutionError, match="Expected an integer"): + v("abc") + + def test_range_valid(self): + v = Validator.range(100, 599) + v("200") + v("100") + v("599") + + def test_range_out_of_range(self): + v = Validator.range(100, 599) + with pytest.raises(SymbolResolutionError, match="range"): + v("600") + + def test_range_not_integer(self): + v = Validator.range(100, 599) + with pytest.raises(SymbolResolutionError, match="Expected an integer"): + v("abc") + + def test_validate_assignment_bool_valid(self): + Validator.validate_assignment(VarType.BOOL, "TRUE", "my_flag") + Validator.validate_assignment(VarType.BOOL, "OFF", "my_flag") + + def test_validate_assignment_bool_invalid(self): + with pytest.raises(SymbolResolutionError, match="Invalid value"): + Validator.validate_assignment(VarType.BOOL, "maybe", "my_flag") + + def test_validate_assignment_int8_valid(self): + Validator.validate_assignment(VarType.INT8, "0", "my_int") + Validator.validate_assignment(VarType.INT8, "255", "my_int") + + def test_validate_assignment_int8_out_of_range(self): + with pytest.raises(SymbolResolutionError, match="Invalid value"): + Validator.validate_assignment(VarType.INT8, "256", "my_int") + + def test_validate_assignment_int8_not_integer(self): + with pytest.raises(SymbolResolutionError, match="Expected integer"): + Validator.validate_assignment(VarType.INT8, "abc", "my_int") + + def test_validate_assignment_int16_valid(self): + Validator.validate_assignment(VarType.INT16, "0", "my_int") + Validator.validate_assignment(VarType.INT16, "65535", "my_int") + + def test_validate_assignment_int16_out_of_range(self): + with pytest.raises(SymbolResolutionError, match="Invalid value"): + Validator.validate_assignment(VarType.INT16, "65536", "my_int") + + def test_validate_assignment_int16_not_integer(self): + with pytest.raises(SymbolResolutionError, match="Expected integer"): + Validator.validate_assignment(VarType.INT16, "abc", "my_int") + + def test_set_format_valid(self): + v = Validator.set_format() + v("[a, b, c]") + v("(x, y, z)") + + def test_set_format_invalid(self): + v = Validator.set_format() + with pytest.raises(SymbolResolutionError, match="Set must be enclosed"): + v("{a, b}") + + def test_iprange_format_valid(self): + v = Validator.iprange_format() + v("{10.0.0.0/8}") + + def test_iprange_format_invalid(self): + v = Validator.iprange_format() + with pytest.raises(SymbolResolutionError, match="IP range must be enclosed"): + v("[10.0.0.0/8]") + + def test_regex_pattern_valid(self): + v = Validator.regex_pattern() + v("/^hello$/") + v("simple") + v("/[a-z]+/") + + def test_regex_pattern_invalid(self): + v = Validator.regex_pattern() + with pytest.raises(SymbolResolutionError, match="Invalid regex"): + v("/[unclosed/") + + def test_regex_pattern_empty(self): + v = Validator.regex_pattern() + with pytest.raises(SymbolResolutionError, match="Empty regex"): + v("") + + def test_conditional_arg_validation_valid(self): + mapping = {"LOGGING": frozenset({"ON", "OFF"})} + v = Validator.conditional_arg_validation(mapping) + v(["LOGGING", "ON"]) + + def test_conditional_arg_validation_invalid_value(self): + mapping = {"LOGGING": frozenset({"ON", "OFF"})} + v = Validator.conditional_arg_validation(mapping) + with pytest.raises(SymbolResolutionError, match="Invalid value"): + v(["LOGGING", "MAYBE"]) + + def test_conditional_arg_validation_unknown_field(self): + mapping = {"LOGGING": frozenset({"ON", "OFF"})} + v = Validator.conditional_arg_validation(mapping) + with pytest.raises(SymbolResolutionError, match="Unknown field"): + v(["UNKNOWN", "ON"]) + + def test_percent_block_valid(self): + v = Validator.percent_block() + v("%{SOME_VAR}") + + def test_percent_block_invalid(self): + v = Validator.percent_block() + with pytest.raises(SymbolResolutionError, match="Invalid percent block"): + v("SOME_VAR") + + def test_needs_quotes(self): + assert Validator.needs_quotes("") is True + assert Validator.needs_quotes("simple") is False + assert Validator.needs_quotes('"already quoted"') is False + assert Validator.needs_quotes("/regex/") is False + assert Validator.needs_quotes("42") is False + assert Validator.needs_quotes("has space") is True + + def test_quote_if_needed(self): + assert Validator.quote_if_needed("simple") == "simple" + assert Validator.quote_if_needed("has space") == '"has space"' + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tools/hrw4u/tests/utils.py b/tools/hrw4u/tests/utils.py index fac320fbcd0..b86c5562b1f 100644 --- a/tools/hrw4u/tests/utils.py +++ b/tools/hrw4u/tests/utils.py @@ -144,11 +144,11 @@ def collect_failing_inputs(group: str) -> Iterator[pytest.param]: yield pytest.param(input_file, id=test_id) -def run_output_test(input_file: Path, output_file: Path) -> None: +def run_output_test(input_file: Path, output_file: Path, debug: bool = False) -> None: """Run output validation test comparing generated output with expected.""" input_text = input_file.read_text() parser, tree = parse_input_text(input_text) - visitor = HRW4UVisitor() + visitor = HRW4UVisitor(debug=debug) actual_output = "\n".join(visitor.visit(tree)).strip() expected_output = output_file.read_text().strip() assert actual_output == expected_output, f"Output mismatch in {input_file}" @@ -257,14 +257,14 @@ def _assert_structured_error_fields( f"Actual full error:\n{actual_full_error}") -def run_reverse_test(input_file: Path, output_file: Path) -> None: +def run_reverse_test(input_file: Path, output_file: Path, debug: bool = False) -> None: """Run u4wrh on output.txt and compare with input.txt (round-trip test).""" output_text = output_file.read_text() lexer = u4wrhLexer(InputStream(output_text)) stream = CommonTokenStream(lexer) parser = u4wrhParser(stream) tree = parser.program() - visitor = HRWInverseVisitor(filename=str(output_file)) + visitor = HRWInverseVisitor(filename=str(output_file), debug=debug) actual_hrw4u = "\n".join(visitor.visit(tree)).strip() expected_hrw4u = input_file.read_text().strip() assert actual_hrw4u == expected_hrw4u, f"Reverse conversion mismatch for {output_file}"