diff --git a/plugboard-schemas/plugboard_schemas/__init__.py b/plugboard-schemas/plugboard_schemas/__init__.py index 33cd6acc..e0e34f39 100644 --- a/plugboard-schemas/plugboard_schemas/__init__.py +++ b/plugboard-schemas/plugboard_schemas/__init__.py @@ -9,6 +9,13 @@ from importlib.metadata import version from ._common import PlugboardBaseModel +from ._graph import simple_cycles +from ._validation import ( + validate_all_inputs_connected, + validate_input_events, + validate_no_unresolved_cycles, + validate_process, +) from .component import ComponentArgsDict, ComponentArgsSpec, ComponentSpec, Resource from .config import ConfigSpec, ProcessConfigSpec from .connector import ( @@ -85,4 +92,9 @@ "TuneArgsDict", "TuneArgsSpec", "TuneSpec", + "simple_cycles", + "validate_all_inputs_connected", + "validate_input_events", + "validate_no_unresolved_cycles", + "validate_process", ] diff --git a/plugboard-schemas/plugboard_schemas/_graph.py b/plugboard-schemas/plugboard_schemas/_graph.py new file mode 100644 index 00000000..50ca823a --- /dev/null +++ b/plugboard-schemas/plugboard_schemas/_graph.py @@ -0,0 +1,126 @@ +"""Graph algorithms for topology validation. + +Implements Johnson's algorithm for finding all simple cycles in a directed graph, +along with helper functions for strongly connected components. + +References: + Donald B Johnson. "Finding all the elementary circuits of a directed graph." + SIAM Journal on Computing. 1975. +""" + +from collections import defaultdict +from collections.abc import Generator + + +def simple_cycles(graph: dict[str, set[str]]) -> Generator[list[str], None, None]: + """Find all simple cycles in a directed graph using Johnson's algorithm. + + Args: + graph: A dictionary mapping each vertex to a set of its neighbours. + + Yields: + Each elementary cycle as a list of vertices. + """ + graph = {v: set(nbrs) for v, nbrs in graph.items()} + sccs = _strongly_connected_components(graph) + while sccs: + scc = sccs.pop() + startnode = scc.pop() + path = [startnode] + blocked: set[str] = set() + closed: set[str] = set() + blocked.add(startnode) + B: dict[str, set[str]] = defaultdict(set) + stack: list[tuple[str, list[str]]] = [(startnode, list(graph[startnode]))] + while stack: + thisnode, nbrs = stack[-1] + if nbrs: + nextnode = nbrs.pop() + if nextnode == startnode: + yield path[:] + closed.update(path) + elif nextnode not in blocked: + path.append(nextnode) + stack.append((nextnode, list(graph[nextnode]))) + closed.discard(nextnode) + blocked.add(nextnode) + continue + if not nbrs: + if thisnode in closed: + _unblock(thisnode, blocked, B) + else: + for nbr in graph[thisnode]: + if thisnode not in B[nbr]: + B[nbr].add(thisnode) + stack.pop() + path.pop() + _remove_node(graph, startnode) + H = _subgraph(graph, set(scc)) + sccs.extend(_strongly_connected_components(H)) + + +def _unblock(thisnode: str, blocked: set[str], B: dict[str, set[str]]) -> None: + """Unblock a node and recursively unblock nodes in its B set.""" + stack = {thisnode} + while stack: + node = stack.pop() + if node in blocked: + blocked.remove(node) + stack.update(B[node]) + B[node].clear() + + +def _strongly_connected_components(graph: dict[str, set[str]]) -> list[set[str]]: + """Find all strongly connected components using Tarjan's algorithm. + + Args: + graph: A dictionary mapping each vertex to a set of its neighbours. + + Returns: + A list of sets, each containing the vertices of a strongly connected component. + """ + index_counter = [0] + stack: list[str] = [] + lowlink: dict[str, int] = {} + index: dict[str, int] = {} + result: list[set[str]] = [] + + def _strong_connect(node: str) -> None: + index[node] = index_counter[0] + lowlink[node] = index_counter[0] + index_counter[0] += 1 + stack.append(node) + + for successor in graph.get(node, set()): + if successor not in index: + _strong_connect(successor) + lowlink[node] = min(lowlink[node], lowlink[successor]) + elif successor in stack: + lowlink[node] = min(lowlink[node], index[successor]) + + if lowlink[node] == index[node]: + connected_component: set[str] = set() + while True: + successor = stack.pop() + connected_component.add(successor) + if successor == node: + break + result.append(connected_component) + + for node in graph: + if node not in index: + _strong_connect(node) + + return result + + +def _remove_node(graph: dict[str, set[str]], target: str) -> None: + """Remove a node and all its edges from the graph.""" + del graph[target] + for nbrs in graph.values(): + nbrs.discard(target) + + +def _subgraph(graph: dict[str, set[str]], vertices: set[str]) -> dict[str, set[str]]: + """Get the subgraph induced by a set of vertices.""" + return {v: graph[v] & vertices for v in vertices} diff --git a/plugboard-schemas/plugboard_schemas/_validation.py b/plugboard-schemas/plugboard_schemas/_validation.py new file mode 100644 index 00000000..467cabd0 --- /dev/null +++ b/plugboard-schemas/plugboard_schemas/_validation.py @@ -0,0 +1,197 @@ +"""Validation utilities for process topology. + +Provides functions to validate process topology including: +- Checking that all component inputs are connected +- Checking that input events have matching output event producers +- Checking for circular connections that require initial values + +All validators accept the output of ``process.dict()`` or the relevant +sub-structures thereof. +""" + +from __future__ import annotations + +from collections import defaultdict +import typing as _t + +from ._graph import simple_cycles + + +def _build_component_graph( + connectors: dict[str, dict[str, _t.Any]], +) -> dict[str, set[str]]: + """Build a directed graph of component connections from connector dicts. + + Args: + connectors: Dictionary mapping connector IDs to connector dicts, + as returned by ``process.dict()["connectors"]``. + + Returns: + A dictionary mapping source component names to sets of target component names. + """ + graph: dict[str, set[str]] = defaultdict(set) + for conn_info in connectors.values(): + spec = conn_info["spec"] + source_entity = spec["source"]["entity"] + target_entity = spec["target"]["entity"] + if source_entity != target_entity: + graph[source_entity].add(target_entity) + if target_entity not in graph: + graph[target_entity] = set() + return dict(graph) + + +def _get_edges_in_cycle( + cycle: list[str], + connectors: dict[str, dict[str, _t.Any]], +) -> list[dict[str, _t.Any]]: + """Get all connector spec dicts that form edges within a cycle. + + Args: + cycle: List of component names forming a cycle. + connectors: Dictionary mapping connector IDs to connector dicts. + + Returns: + List of connector spec dicts that are part of the cycle. + """ + cycle_edges: list[dict[str, _t.Any]] = [] + for i, node in enumerate(cycle): + next_node = cycle[(i + 1) % len(cycle)] + for conn_info in connectors.values(): + spec = conn_info["spec"] + if spec["source"]["entity"] == node and spec["target"]["entity"] == next_node: + cycle_edges.append(spec) + return cycle_edges + + +def validate_all_inputs_connected( + process_dict: dict[str, _t.Any], +) -> list[str]: + """Check that all component inputs are connected. + + Args: + process_dict: The output of ``process.dict()``. Uses the ``"components"`` + and ``"connectors"`` keys. + + Returns: + List of error messages for unconnected inputs. + """ + components: dict[str, dict[str, _t.Any]] = process_dict["components"] + connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"] + + connected_inputs: dict[str, set[str]] = defaultdict(set) + for conn_info in connectors.values(): + spec = conn_info["spec"] + target_name = spec["target"]["entity"] + target_field = spec["target"]["descriptor"] + connected_inputs[target_name].add(target_field) + + errors: list[str] = [] + for comp_name, comp_data in components.items(): + io = comp_data.get("io", {}) + all_inputs = set(io.get("inputs", [])) + connected = connected_inputs.get(comp_name, set()) + unconnected = all_inputs - connected + if unconnected: + errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}") + return errors + + +def validate_input_events( + process_dict: dict[str, _t.Any], +) -> list[str]: + """Check that all components with input events have a matching output event producer. + + Args: + process_dict: The output of ``process.dict()``. Uses the ``"components"`` key. + + Returns: + List of error messages for unmatched input events. + """ + components: dict[str, dict[str, _t.Any]] = process_dict["components"] + + all_output_events: set[str] = set() + for comp_data in components.values(): + io = comp_data.get("io", {}) + all_output_events.update(io.get("output_events", [])) + + errors: list[str] = [] + for comp_name, comp_data in components.items(): + io = comp_data.get("io", {}) + input_events = set(io.get("input_events", [])) + unmatched = input_events - all_output_events + if unmatched: + errors.append( + f"Component '{comp_name}' has input events with no producer: {sorted(unmatched)}" + ) + return errors + + +def validate_no_unresolved_cycles( + process_dict: dict[str, _t.Any], +) -> list[str]: + """Check for circular connections that are not resolved by initial values. + + Circular loops are only valid if there are ``initial_values`` set on an + appropriate component input within the loop. + + Args: + process_dict: The output of ``process.dict()``. Uses the ``"components"`` + and ``"connectors"`` keys. + + Returns: + List of error messages for unresolved circular connections. + """ + components: dict[str, dict[str, _t.Any]] = process_dict["components"] + connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"] + + graph = _build_component_graph(connectors) + if not graph: + return [] + + # Build lookup of component initial_values by name + initial_values_by_comp: dict[str, set[str]] = {} + for comp_name, comp_data in components.items(): + io = comp_data.get("io", {}) + iv = io.get("initial_values", {}) + if iv: + initial_values_by_comp[comp_name] = set(iv.keys()) + + errors: list[str] = [] + for cycle in simple_cycles(graph): + cycle_edges = _get_edges_in_cycle(cycle, connectors) + cycle_resolved = False + for edge in cycle_edges: + target_comp = edge["target"]["entity"] + target_field = edge["target"]["descriptor"] + if target_comp in initial_values_by_comp: + if target_field in initial_values_by_comp[target_comp]: + cycle_resolved = True + break + if not cycle_resolved: + cycle_str = " -> ".join(cycle + [cycle[0]]) + errors.append( + f"Circular connection detected without initial values: {cycle_str}. " + f"Set initial_values on a component input within the loop to resolve." + ) + return errors + + +def validate_process(process_dict: dict[str, _t.Any]) -> list[str]: + """Run all topology validation checks on a process. + + This is the main validation entry point. It accepts the output of + ``process.dict()`` and runs every available check, returning a + combined list of error messages. + + Args: + process_dict: The output of ``process.dict()``. + + Returns: + List of error messages. An empty list indicates a valid topology. + """ + errors: list[str] = [] + errors.extend(validate_all_inputs_connected(process_dict)) + errors.extend(validate_input_events(process_dict)) + errors.extend(validate_no_unresolved_cycles(process_dict)) + return errors diff --git a/plugboard/cli/process/__init__.py b/plugboard/cli/process/__init__.py index 518abd79..b50e14ef 100644 --- a/plugboard/cli/process/__init__.py +++ b/plugboard/cli/process/__init__.py @@ -13,7 +13,7 @@ from plugboard.diagram import MermaidDiagram from plugboard.process import Process, ProcessBuilder -from plugboard.schemas import ConfigSpec +from plugboard.schemas import ConfigSpec, validate_process from plugboard.tune import Tuner from plugboard.utils import add_sys_path, run_coro_sync @@ -164,3 +164,31 @@ def diagram( diagram = MermaidDiagram.from_process(process) md = Markdown(f"```\n{diagram.diagram}\n```\n[Editable diagram]({diagram.url}) (external link)") print(md) + + +@app.command() +def validate( + config: Annotated[ + Path, + typer.Argument( + exists=True, + file_okay=True, + dir_okay=False, + writable=False, + readable=True, + resolve_path=True, + help="Path to the YAML configuration file.", + ), + ], +) -> None: + """Validate a Plugboard process configuration.""" + config_spec = _read_yaml(config) + with add_sys_path(config.parent): + process = _build_process(config_spec) + errors = validate_process(process.dict()) + if errors: + stderr.print("[red]Validation failed:[/red]") + for error in errors: + stderr.print(f" • {error}") + raise typer.Exit(1) + print("[green]Validation passed[/green]") diff --git a/plugboard/process/process.py b/plugboard/process/process.py index ccd45f58..939ae94a 100644 --- a/plugboard/process/process.py +++ b/plugboard/process/process.py @@ -13,8 +13,8 @@ from plugboard.component import Component from plugboard.connector import Connector -from plugboard.exceptions import NotInitialisedError -from plugboard.schemas import ConfigSpec, Status +from plugboard.exceptions import NotInitialisedError, ValidationError +from plugboard.schemas import ConfigSpec, Status, validate_process from plugboard.state import DictStateBackend, StateBackend from plugboard.utils import DI, ExportMixin, gen_rand_str from plugboard.utils.async_utils import run_coro_sync @@ -109,6 +109,10 @@ async def _set_status(self, status: Status, publish: bool = True) -> None: @abstractmethod async def init(self) -> None: """Performs component initialisation actions.""" + errors = validate_process(self.dict()) + if errors: + msg = "Process validation failed:\n" + "\n".join(errors) + raise ValidationError(msg) self._is_initialised = True await self._set_status(Status.INIT) diff --git a/plugboard/schemas/__init__.py b/plugboard/schemas/__init__.py index 427ee410..4ee8fc79 100644 --- a/plugboard/schemas/__init__.py +++ b/plugboard/schemas/__init__.py @@ -46,6 +46,11 @@ TuneArgsDict, TuneArgsSpec, TuneSpec, + simple_cycles, + validate_all_inputs_connected, + validate_input_events, + validate_no_unresolved_cycles, + validate_process, ) @@ -86,4 +91,9 @@ "TuneArgsDict", "TuneArgsSpec", "TuneSpec", + "simple_cycles", + "validate_all_inputs_connected", + "validate_input_events", + "validate_no_unresolved_cycles", + "validate_process", ] diff --git a/tests/conftest.py b/tests/conftest.py index de6612f6..b7688fcb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -143,3 +143,13 @@ def dict(self) -> dict: } ) return data + + +@pytest.fixture +def patch_validate_process() -> _t.Iterator[None]: + """Patch process validation for tests that don't require functional processes.""" + with ( + patch("plugboard.schemas.validate_process", return_value=[]), + patch("plugboard.process.process.validate_process", return_value=[]), + ): + yield diff --git a/tests/integration/test_job_id_wiring.py b/tests/integration/test_job_id_wiring.py index 1d9cb859..c820236b 100644 --- a/tests/integration/test_job_id_wiring.py +++ b/tests/integration/test_job_id_wiring.py @@ -206,7 +206,7 @@ async def test_cli_process_run_with_no_job_id(minimal_config_file: Path) -> None @pytest.mark.asyncio -async def test_direct_process_with_job_id() -> None: +async def test_direct_process_with_job_id(patch_validate_process: None) -> None: """Test building a process directly with a specified job ID.""" # Create a state backend with a specific job ID state: DictStateBackend = DictStateBackend(job_id="Job_direct12345678") @@ -221,7 +221,7 @@ async def test_direct_process_with_job_id() -> None: @pytest.mark.asyncio -async def test_direct_process_with_env_var() -> None: +async def test_direct_process_with_env_var(patch_validate_process: None) -> None: """Test building a process without a job ID while environment variable is set.""" # Set the environment variable to match the job ID to avoid conflict job_id: str = "Job_direct12345678" @@ -240,7 +240,9 @@ async def test_direct_process_with_env_var() -> None: @pytest.mark.asyncio -async def test_direct_process_with_job_id_and_env_var() -> None: +async def test_direct_process_with_job_id_and_env_var( + patch_validate_process: None, +) -> None: """Test building a process with a job ID while environment variable is set with same value.""" # Set the environment variable to match the job ID to avoid conflict job_id: str = "Job_direct12345678" @@ -271,7 +273,7 @@ async def test_direct_process_with_conflicting_job_ids() -> None: @pytest.mark.asyncio -async def test_direct_process_without_job_id() -> None: +async def test_direct_process_without_job_id(patch_validate_process: None) -> None: """Test building a process without specifying a job ID.""" # Create a state backend without a job ID state: DictStateBackend = DictStateBackend() @@ -289,7 +291,9 @@ async def test_direct_process_without_job_id() -> None: @pytest.mark.asyncio -async def test_direct_process_without_job_id_multiple_runs() -> None: +async def test_direct_process_without_job_id_multiple_runs( + patch_validate_process: None, +) -> None: """Test building a process without specifying a job ID multiple times.""" # Create a state backend without a job ID state: DictStateBackend = DictStateBackend() @@ -316,7 +320,9 @@ async def test_direct_process_without_job_id_multiple_runs() -> None: @pytest.mark.asyncio -async def test_direct_process_without_job_id_multiple_runs_multiprocessing() -> None: +async def test_direct_process_without_job_id_multiple_runs_multiprocessing( + patch_validate_process: None, +) -> None: """Test building a process without a job ID in a multiprocessing context.""" # Create a state backend without a specific job ID state: SqliteStateBackend = SqliteStateBackend() diff --git a/tests/integration/test_process_builder.py b/tests/integration/test_process_builder.py index cdfc4c5d..1900a61f 100644 --- a/tests/integration/test_process_builder.py +++ b/tests/integration/test_process_builder.py @@ -50,6 +50,13 @@ async def dummy_event_2_handler(self, event: DummyEvent2) -> None: pass +class E(ComponentTestHelper): + io = IO(output_events=[DummyEvent1, DummyEvent2]) + + async def step(self) -> None: + pass + + @pytest.fixture def process_spec() -> ProcessSpec: """Returns a `ProcessSpec` for testing.""" @@ -72,6 +79,10 @@ def process_spec() -> ProcessSpec: type="tests.integration.test_process_builder.D", args={"name": "D"}, ), + ComponentSpec( + type="tests.integration.test_process_builder.E", + args={"name": "E"}, + ), ], connectors=[ ConnectorSpec( @@ -103,11 +114,11 @@ async def test_process_builder_build(process_spec: ProcessSpec) -> None: # Must build a process with the correct type process.__class__.__name__ == process_spec.args.state.type.split(".")[-1] # Must build a process with the correct components and connectors - assert len(process.components) == 4 + assert len(process.components) == 5 # Number of connectors must be sum of: fields in config; user events; and system events assert len(process.connectors) == 2 + 2 + 1 # Must build a process with the correct component names - assert process.components.keys() == {"A", "B", "C", "D"} + assert process.components.keys() == {"A", "B", "C", "D", "E"} # Must build connectors with the correct channel types assert all( conn.__class__.__name__ == "AsyncioConnector" for conn in process.connectors.values() diff --git a/tests/integration/test_process_validation.py b/tests/integration/test_process_validation.py index f77c8678..07294111 100644 --- a/tests/integration/test_process_validation.py +++ b/tests/integration/test_process_validation.py @@ -15,9 +15,6 @@ from tests.integration.test_process_with_components_run import A, B, C -# TODO: Update these tests when we implement full graph validation - - def filter_logs(logs: list[EventDict], field: str, regex: str) -> list[EventDict]: """Filters the log output by applying regex to a field.""" pattern = re.compile(regex) @@ -26,20 +23,15 @@ def filter_logs(logs: list[EventDict], field: str, regex: str) -> list[EventDict @pytest.mark.asyncio async def test_missing_connections() -> None: - """Tests that missing connections are logged.""" + """Tests that missing input connections raise ValidationError.""" p_missing_input = LocalProcess( components=[A(name="a", iters=10), C(name="c", path="test-out.csv")], # c.in_1 is not connected connectors=[AsyncioConnector(spec=ConnectorSpec(source="a.out_1", target="unknown.x"))], ) - with capture_logs() as logs: + with pytest.raises(exceptions.ValidationError, match="unconnected inputs"): await p_missing_input.init() - # Must contain an error-level log indicating that input is not connected - logs = filter_logs(logs, "log_level", "error") - logs = filter_logs(logs, "event", "Input fields not connected") - assert logs, "Logs do not indicate missing connection" - p_missing_output = LocalProcess( components=[A(name="a", iters=10), B(name="b")], # b.out_1 is not connected diff --git a/tests/integration/test_state_backend.py b/tests/integration/test_state_backend.py index c007c74a..b5516058 100644 --- a/tests/integration/test_state_backend.py +++ b/tests/integration/test_state_backend.py @@ -230,7 +230,10 @@ async def test_state_backend_upsert_connector( @pytest.mark.asyncio async def test_state_backend_process_init( - state_backend: StateBackend, B_components: list[Component], B_connectors: list[Connector] + state_backend: StateBackend, + B_components: list[Component], + B_connectors: list[Connector], + patch_validate_process: None, ) -> None: """Tests `StateBackend` connected up correctly on `Process.init`.""" comp_b1, comp_b2 = B_components @@ -258,7 +261,10 @@ async def test_state_backend_process_init( @pytest.mark.asyncio async def test_state_backend_process_status( - state_backend: StateBackend, B_components: list[Component], B_connectors: list[Connector] + state_backend: StateBackend, + B_components: list[Component], + B_connectors: list[Connector], + patch_validate_process: None, ) -> None: """Tests `StateBackend` process status updates.""" comp_b1, comp_b2 = B_components @@ -299,7 +305,9 @@ async def test_state_backend_process_status( @pytest.mark.asyncio -async def test_state_backend_process_parameters(state_backend: StateBackend) -> None: +async def test_state_backend_process_parameters( + state_backend: StateBackend, patch_validate_process: None +) -> None: """Tests `StateBackend` process parameters storage and retrieval.""" parameters = {"param_1": 10, "param_2": "value"} component_a = A(name="ComponentA") diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index cb3f429c..0291f828 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -87,6 +87,23 @@ def test_cli_process_diagram() -> None: assert "flowchart" in result.stdout +def test_cli_process_validate() -> None: + """Tests the process validate command.""" + result = runner.invoke(app, ["process", "validate", "tests/data/minimal-process.yaml"]) + # CLI must run without error for a valid config + assert result.exit_code == 0 + assert "Validation passed" in result.stdout + + +def test_cli_process_validate_invalid() -> None: + """Tests the process validate command with an invalid process.""" + with patch("plugboard.cli.process.validate_process") as mock_validate: + mock_validate.return_value = ["Component 'x' has unconnected inputs: ['in_1']"] + result = runner.invoke(app, ["process", "validate", "tests/data/minimal-process.yaml"]) + assert result.exit_code == 1 + assert "Validation failed" in result.stderr + + def test_cli_server_discover(test_project_dir: Path) -> None: """Tests the server discover command.""" with respx.mock: diff --git a/tests/unit/test_process_validation.py b/tests/unit/test_process_validation.py new file mode 100644 index 00000000..02e0a4d2 --- /dev/null +++ b/tests/unit/test_process_validation.py @@ -0,0 +1,388 @@ +"""Tests for process topology validation.""" + +import typing as _t + +from plugboard_schemas import ( + validate_all_inputs_connected, + validate_input_events, + validate_no_unresolved_cycles, + validate_process, +) +from plugboard_schemas._graph import simple_cycles + + +# --------------------------------------------------------------------------- +# Tests for simple_cycles (Johnson's algorithm) +# --------------------------------------------------------------------------- + + +class TestSimpleCycles: + """Tests for Johnson's cycle-finding algorithm.""" + + def test_no_cycles(self) -> None: + """A DAG has no cycles.""" + graph: dict[str, set[str]] = {"a": {"b"}, "b": {"c"}, "c": set()} + assert list(simple_cycles(graph)) == [] + + def test_single_self_loop(self) -> None: + """A self-loop is a cycle of length 1.""" + graph: dict[str, set[str]] = {"a": {"a"}} + cycles = list(simple_cycles(graph)) + assert len(cycles) == 1 + assert cycles[0] == ["a"] + + def test_simple_two_node_cycle(self) -> None: + """Two nodes forming a cycle.""" + graph: dict[str, set[str]] = {"a": {"b"}, "b": {"a"}} + cycles = list(simple_cycles(graph)) + assert len(cycles) == 1 + assert set(cycles[0]) == {"a", "b"} + + def test_three_node_cycle(self) -> None: + """Three nodes forming a single cycle.""" + graph: dict[str, set[str]] = {"a": {"b"}, "b": {"c"}, "c": {"a"}} + cycles = list(simple_cycles(graph)) + assert len(cycles) == 1 + assert set(cycles[0]) == {"a", "b", "c"} + + def test_multiple_cycles(self) -> None: + """Graph with multiple distinct cycles.""" + graph: dict[str, set[str]] = { + "a": {"b"}, + "b": {"a", "c"}, + "c": {"d"}, + "d": {"c"}, + } + cycles = list(simple_cycles(graph)) + cycle_sets = [frozenset(c) for c in cycles] + assert frozenset({"a", "b"}) in cycle_sets + assert frozenset({"c", "d"}) in cycle_sets + + def test_empty_graph(self) -> None: + """Empty graph has no cycles.""" + graph: dict[str, set[str]] = {} + assert list(simple_cycles(graph)) == [] + + def test_disconnected_graph(self) -> None: + """Disconnected graph with no cycles.""" + graph: dict[str, set[str]] = {"a": set(), "b": set(), "c": set()} + assert list(simple_cycles(graph)) == [] + + def test_complex_graph(self) -> None: + """Complex graph with overlapping cycles.""" + graph: dict[str, set[str]] = { + "a": {"b"}, + "b": {"c"}, + "c": {"a", "d"}, + "d": {"b"}, + } + cycles = list(simple_cycles(graph)) + # Should find cycles: a->b->c->a and b->c->d->b + assert len(cycles) >= 2 + cycle_sets = [frozenset(c) for c in cycles] + assert frozenset({"a", "b", "c"}) in cycle_sets + assert frozenset({"b", "c", "d"}) in cycle_sets + + +# --------------------------------------------------------------------------- +# Helpers for building process.dict()-style data structures +# --------------------------------------------------------------------------- + + +def _make_component( + name: str, + inputs: list[str] | None = None, + outputs: list[str] | None = None, + input_events: list[str] | None = None, + output_events: list[str] | None = None, + initial_values: dict[str, _t.Any] | None = None, +) -> dict[str, _t.Any]: + """Build a component dict matching process.dict() format.""" + return { + "id": name, + "name": name, + "status": "created", + "io": { + "namespace": name, + "inputs": inputs or [], + "outputs": outputs or [], + "input_events": input_events or [], + "output_events": output_events or [], + "initial_values": initial_values or {}, + }, + } + + +def _make_connector(source: str, target: str) -> dict[str, _t.Any]: + """Build a connector dict matching process.dict() format.""" + src_entity, src_desc = source.split(".") + tgt_entity, tgt_desc = target.split(".") + conn_id = f"{source}..{target}" + return { + conn_id: { + "id": conn_id, + "spec": { + "source": {"entity": src_entity, "descriptor": src_desc}, + "target": {"entity": tgt_entity, "descriptor": tgt_desc}, + "mode": "pipeline", + }, + } + } + + +def _make_process_dict( + components: dict[str, dict[str, _t.Any]], + connectors: dict[str, dict[str, _t.Any]] | None = None, +) -> dict[str, _t.Any]: + """Build a process dict matching process.dict() format.""" + return { + "id": "test_process", + "name": "test_process", + "status": "created", + "components": components, + "connectors": connectors or {}, + "parameters": {}, + } + + +# --------------------------------------------------------------------------- +# Tests for validate_no_unresolved_cycles +# --------------------------------------------------------------------------- + + +class TestValidateNoUnresolvedCycles: + """Tests for circular connection validation.""" + + def test_no_cycles_passes(self) -> None: + """Linear topology passes validation.""" + connectors = {**_make_connector("a.out", "b.in_1"), **_make_connector("b.out", "c.in_1")} + pd = _make_process_dict( + components={ + "a": _make_component("a", outputs=["out"]), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + "c": _make_component("c", inputs=["in_1"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert errors == [] + + def test_cycle_without_initial_values_fails(self) -> None: + """Cycle without initial_values returns errors.""" + connectors = {**_make_connector("a.out", "b.in_1"), **_make_connector("b.out", "a.in_1")} + pd = _make_process_dict( + components={ + "a": _make_component("a", inputs=["in_1"], outputs=["out"]), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert len(errors) == 1 + assert "Circular connection detected" in errors[0] + + def test_cycle_with_initial_values_passes(self) -> None: + """Cycle with initial_values on a target input passes.""" + connectors = {**_make_connector("a.out", "b.in_1"), **_make_connector("b.out", "a.in_1")} + pd = _make_process_dict( + components={ + "a": _make_component( + "a", inputs=["in_1"], outputs=["out"], initial_values={"in_1": [0]} + ), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert errors == [] + + def test_cycle_with_initial_values_on_other_field_fails(self) -> None: + """Cycle with initial_values on an unrelated field still fails.""" + connectors = {**_make_connector("a.out", "b.in_1"), **_make_connector("b.out", "a.in_1")} + pd = _make_process_dict( + components={ + "a": _make_component( + "a", inputs=["in_1"], outputs=["out"], initial_values={"other": [0]} + ), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert len(errors) == 1 + assert "Circular connection detected" in errors[0] + + def test_three_node_cycle_without_initial_values_fails(self) -> None: + """Three-node cycle without initial_values returns errors.""" + connectors = { + **_make_connector("a.out", "b.in_1"), + **_make_connector("b.out", "c.in_1"), + **_make_connector("c.out", "a.in_1"), + } + pd = _make_process_dict( + components={ + "a": _make_component("a", inputs=["in_1"], outputs=["out"]), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + "c": _make_component("c", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert len(errors) == 1 + assert "Circular connection detected" in errors[0] + + def test_three_node_cycle_with_initial_values_passes(self) -> None: + """Three-node cycle with initial_values on any target input passes.""" + connectors = { + **_make_connector("a.out", "b.in_1"), + **_make_connector("b.out", "c.in_1"), + **_make_connector("c.out", "a.in_1"), + } + pd = _make_process_dict( + components={ + "a": _make_component("a", inputs=["in_1"], outputs=["out"]), + "b": _make_component( + "b", inputs=["in_1"], outputs=["out"], initial_values={"in_1": [0]} + ), + "c": _make_component("c", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert errors == [] + + def test_no_connectors_passes(self) -> None: + """Process with no connectors passes validation.""" + pd = _make_process_dict( + components={"a": _make_component("a")}, + ) + errors = validate_no_unresolved_cycles(pd) + assert errors == [] + + +# --------------------------------------------------------------------------- +# Tests for validate_all_inputs_connected +# --------------------------------------------------------------------------- + + +class TestValidateAllInputsConnected: + """Tests for the all-inputs-connected validation utility.""" + + def test_all_connected(self) -> None: + """Test that validation passes when all inputs are connected.""" + connectors = _make_connector("a.out", "b.in_1") + pd = _make_process_dict( + components={ + "a": _make_component("a", outputs=["out"]), + "b": _make_component("b", inputs=["in_1"]), + }, + connectors=connectors, + ) + errors = validate_all_inputs_connected(pd) + assert errors == [] + + def test_missing_input(self) -> None: + """Test that validation fails when an input is not connected.""" + connectors = _make_connector("a.out", "b.in_1") + pd = _make_process_dict( + components={ + "a": _make_component("a", outputs=["out"]), + "b": _make_component("b", inputs=["in_1", "in_2"]), + }, + connectors=connectors, + ) + errors = validate_all_inputs_connected(pd) + assert len(errors) == 1 + assert "in_2" in errors[0] + + def test_no_inputs_no_errors(self) -> None: + """Test that validation passes when components have no inputs.""" + pd = _make_process_dict( + components={"a": _make_component("a")}, + ) + errors = validate_all_inputs_connected(pd) + assert errors == [] + + +# --------------------------------------------------------------------------- +# Tests for validate_input_events +# --------------------------------------------------------------------------- + + +class TestValidateInputEvents: + """Tests for the input-events validation utility.""" + + def test_matched_events(self) -> None: + """Test that validation passes when all input events have producers.""" + pd = _make_process_dict( + components={ + "clock": _make_component("clock", output_events=["tick"]), + "ctrl": _make_component("ctrl", input_events=["tick"]), + }, + ) + errors = validate_input_events(pd) + assert errors == [] + + def test_unmatched_event(self) -> None: + """Test that validation fails when input events have no producer.""" + pd = _make_process_dict( + components={ + "ctrl": _make_component("ctrl", input_events=["tick"]), + }, + ) + errors = validate_input_events(pd) + assert len(errors) == 1 + assert "tick" in errors[0] + + def test_no_events(self) -> None: + """Test that validation passes when components have no events.""" + pd = _make_process_dict( + components={"a": _make_component("a")}, + ) + errors = validate_input_events(pd) + assert errors == [] + + def test_multiple_unmatched(self) -> None: + """Test that validation correctly identifies unmatched events.""" + pd = _make_process_dict( + components={ + "a": _make_component("a", input_events=["evt_x", "evt_y"]), + "b": _make_component("b", output_events=["evt_x"]), + }, + ) + errors = validate_input_events(pd) + assert len(errors) == 1 + assert "evt_y" in errors[0] + + +# --------------------------------------------------------------------------- +# Tests for validate_process (combined validator) +# --------------------------------------------------------------------------- + + +class TestValidateProcess: + """Tests for the combined validate_process utility.""" + + def test_valid_process(self) -> None: + """Test that a valid process returns no errors.""" + connectors = _make_connector("a.out_1", "b.in_1") + pd = _make_process_dict( + components={ + "a": _make_component("a", outputs=["out_1"]), + "b": _make_component("b", inputs=["in_1"]), + }, + connectors=connectors, + ) + errors = validate_process(pd) + assert errors == [] + + def test_multiple_errors(self) -> None: + """Test that multiple validation errors are collected.""" + pd = _make_process_dict( + components={ + "a": _make_component("a", inputs=["in_1"], input_events=["missing_evt"]), + }, + ) + errors = validate_process(pd) + # Should have at least: unconnected input + unmatched event + assert len(errors) >= 2