From a057bd706165b163a0b87255cdea4de75a2d9e2c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 19:54:36 +0000 Subject: [PATCH 1/7] Initial plan From 26c0cb5577ea5a301697c44bd557ffffd2dbc21d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 20:07:00 +0000 Subject: [PATCH 2/7] Add process topology validation: Johnson's algorithm, cycle detection, and validation utilities - Add _graph.py with Johnson's cycle-finding algorithm (no additional dependencies) - Add _validation.py with validation utilities for process topology: - validate_all_inputs_connected: check all component inputs are connected - validate_input_events: check input events have matching output producers - validate_no_unresolved_cycles: check circular connections have initial_values - Add model validator on ProcessSpec for circular connection detection - Export new utilities from plugboard_schemas - Add comprehensive unit tests Co-authored-by: toby-coleman <13170610+toby-coleman@users.noreply.github.com> --- .../plugboard_schemas/__init__.py | 12 + plugboard-schemas/plugboard_schemas/_graph.py | 126 +++++++ .../plugboard_schemas/_validation.py | 178 ++++++++++ .../plugboard_schemas/process.py | 9 + tests/unit/test_process_validation.py | 331 ++++++++++++++++++ 5 files changed, 656 insertions(+) create mode 100644 plugboard-schemas/plugboard_schemas/_graph.py create mode 100644 plugboard-schemas/plugboard_schemas/_validation.py create mode 100644 tests/unit/test_process_validation.py diff --git a/plugboard-schemas/plugboard_schemas/__init__.py b/plugboard-schemas/plugboard_schemas/__init__.py index 33cd6acc..a4eda967 100644 --- a/plugboard-schemas/plugboard_schemas/__init__.py +++ b/plugboard-schemas/plugboard_schemas/__init__.py @@ -9,6 +9,13 @@ from importlib.metadata import version from ._common import PlugboardBaseModel +from ._graph import simple_cycles +from ._validation import ( + ValidationError, + validate_all_inputs_connected, + validate_input_events, + validate_no_unresolved_cycles, +) from .component import ComponentArgsDict, ComponentArgsSpec, ComponentSpec, Resource from .config import ConfigSpec, ProcessConfigSpec from .connector import ( @@ -85,4 +92,9 @@ "TuneArgsDict", "TuneArgsSpec", "TuneSpec", + "ValidationError", + "simple_cycles", + "validate_all_inputs_connected", + "validate_input_events", + "validate_no_unresolved_cycles", ] diff --git a/plugboard-schemas/plugboard_schemas/_graph.py b/plugboard-schemas/plugboard_schemas/_graph.py new file mode 100644 index 00000000..50ca823a --- /dev/null +++ b/plugboard-schemas/plugboard_schemas/_graph.py @@ -0,0 +1,126 @@ +"""Graph algorithms for topology validation. + +Implements Johnson's algorithm for finding all simple cycles in a directed graph, +along with helper functions for strongly connected components. + +References: + Donald B Johnson. "Finding all the elementary circuits of a directed graph." + SIAM Journal on Computing. 1975. +""" + +from collections import defaultdict +from collections.abc import Generator + + +def simple_cycles(graph: dict[str, set[str]]) -> Generator[list[str], None, None]: + """Find all simple cycles in a directed graph using Johnson's algorithm. + + Args: + graph: A dictionary mapping each vertex to a set of its neighbours. + + Yields: + Each elementary cycle as a list of vertices. + """ + graph = {v: set(nbrs) for v, nbrs in graph.items()} + sccs = _strongly_connected_components(graph) + while sccs: + scc = sccs.pop() + startnode = scc.pop() + path = [startnode] + blocked: set[str] = set() + closed: set[str] = set() + blocked.add(startnode) + B: dict[str, set[str]] = defaultdict(set) + stack: list[tuple[str, list[str]]] = [(startnode, list(graph[startnode]))] + while stack: + thisnode, nbrs = stack[-1] + if nbrs: + nextnode = nbrs.pop() + if nextnode == startnode: + yield path[:] + closed.update(path) + elif nextnode not in blocked: + path.append(nextnode) + stack.append((nextnode, list(graph[nextnode]))) + closed.discard(nextnode) + blocked.add(nextnode) + continue + if not nbrs: + if thisnode in closed: + _unblock(thisnode, blocked, B) + else: + for nbr in graph[thisnode]: + if thisnode not in B[nbr]: + B[nbr].add(thisnode) + stack.pop() + path.pop() + _remove_node(graph, startnode) + H = _subgraph(graph, set(scc)) + sccs.extend(_strongly_connected_components(H)) + + +def _unblock(thisnode: str, blocked: set[str], B: dict[str, set[str]]) -> None: + """Unblock a node and recursively unblock nodes in its B set.""" + stack = {thisnode} + while stack: + node = stack.pop() + if node in blocked: + blocked.remove(node) + stack.update(B[node]) + B[node].clear() + + +def _strongly_connected_components(graph: dict[str, set[str]]) -> list[set[str]]: + """Find all strongly connected components using Tarjan's algorithm. + + Args: + graph: A dictionary mapping each vertex to a set of its neighbours. + + Returns: + A list of sets, each containing the vertices of a strongly connected component. + """ + index_counter = [0] + stack: list[str] = [] + lowlink: dict[str, int] = {} + index: dict[str, int] = {} + result: list[set[str]] = [] + + def _strong_connect(node: str) -> None: + index[node] = index_counter[0] + lowlink[node] = index_counter[0] + index_counter[0] += 1 + stack.append(node) + + for successor in graph.get(node, set()): + if successor not in index: + _strong_connect(successor) + lowlink[node] = min(lowlink[node], lowlink[successor]) + elif successor in stack: + lowlink[node] = min(lowlink[node], index[successor]) + + if lowlink[node] == index[node]: + connected_component: set[str] = set() + while True: + successor = stack.pop() + connected_component.add(successor) + if successor == node: + break + result.append(connected_component) + + for node in graph: + if node not in index: + _strong_connect(node) + + return result + + +def _remove_node(graph: dict[str, set[str]], target: str) -> None: + """Remove a node and all its edges from the graph.""" + del graph[target] + for nbrs in graph.values(): + nbrs.discard(target) + + +def _subgraph(graph: dict[str, set[str]], vertices: set[str]) -> dict[str, set[str]]: + """Get the subgraph induced by a set of vertices.""" + return {v: graph[v] & vertices for v in vertices} diff --git a/plugboard-schemas/plugboard_schemas/_validation.py b/plugboard-schemas/plugboard_schemas/_validation.py new file mode 100644 index 00000000..6658e47e --- /dev/null +++ b/plugboard-schemas/plugboard_schemas/_validation.py @@ -0,0 +1,178 @@ +"""Validation utilities for `ProcessSpec` objects. + +Provides functions to validate process topology including: +- Checking that all component inputs are connected +- Checking that input events have matching output event producers +- Checking for circular connections that require initial values +""" + +from __future__ import annotations + +from collections import defaultdict +import typing as _t + +from ._graph import simple_cycles + + +if _t.TYPE_CHECKING: + from .component import ComponentSpec + from .connector import ConnectorSpec + + +class ValidationError(Exception): + """Raised when a process specification fails validation.""" + + pass + + +def _build_component_graph( + connectors: list[ConnectorSpec], +) -> dict[str, set[str]]: + """Build a directed graph of component connections from connector specs. + + Args: + connectors: List of connector specifications. + + Returns: + A dictionary mapping source component names to sets of target component names. + """ + graph: dict[str, set[str]] = defaultdict(set) + for conn in connectors: + source_entity = conn.source.entity + target_entity = conn.target.entity + if source_entity != target_entity: + graph[source_entity].add(target_entity) + # Ensure target is in graph even with no outgoing edges + if target_entity not in graph: + graph[target_entity] = set() + return dict(graph) + + +def _get_edges_in_cycle( + cycle: list[str], + connectors: list[ConnectorSpec], +) -> list[ConnectorSpec]: + """Get all connector specs that form edges within a cycle. + + Args: + cycle: List of component names forming a cycle. + connectors: All connector specifications. + + Returns: + List of connector specs that are part of the cycle. + """ + cycle_edges: list[ConnectorSpec] = [] + cycle_set = set(cycle) + for i, node in enumerate(cycle): + next_node = cycle[(i + 1) % len(cycle)] + for conn in connectors: + if conn.source.entity == node and conn.target.entity == next_node: + cycle_edges.append(conn) + return [c for c in cycle_edges if c.source.entity in cycle_set and c.target.entity in cycle_set] + + +def validate_all_inputs_connected( + components: dict[str, dict[str, _t.Any]], + connectors: list[ConnectorSpec], +) -> list[str]: + """Check that all component inputs are connected. + + Args: + components: Dictionary mapping component names to their IO info. + Each value must have an ``"inputs"`` key with a list of input field names. + connectors: List of connector specifications. + + Returns: + List of error messages for unconnected inputs. + """ + # Build mapping of which component inputs are connected + connected_inputs: dict[str, set[str]] = defaultdict(set) + for conn in connectors: + target_name = conn.target.entity + target_field = conn.target.descriptor + connected_inputs[target_name].add(target_field) + + errors: list[str] = [] + for comp_name, comp_info in components.items(): + all_inputs = set(comp_info.get("inputs", [])) + connected = connected_inputs.get(comp_name, set()) + unconnected = all_inputs - connected + if unconnected: + errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}") + return errors + + +def validate_input_events( + components: dict[str, dict[str, _t.Any]], +) -> list[str]: + """Check that all components with input events have a matching output event producer. + + Args: + components: Dictionary mapping component names to their IO info. + Each value must have ``"input_events"`` and ``"output_events"`` keys + with lists of event type strings. + + Returns: + List of error messages for unmatched input events. + """ + # Collect all output event types across all components + all_output_events: set[str] = set() + for comp_info in components.values(): + all_output_events.update(comp_info.get("output_events", [])) + + errors: list[str] = [] + for comp_name, comp_info in components.items(): + input_events = set(comp_info.get("input_events", [])) + unmatched = input_events - all_output_events + if unmatched: + errors.append( + f"Component '{comp_name}' has input events with no producer: {sorted(unmatched)}" + ) + return errors + + +def validate_no_unresolved_cycles( + components: list[ComponentSpec], + connectors: list[ConnectorSpec], +) -> list[str]: + """Check for circular connections that are not resolved by initial values. + + Circular loops are only valid if there are ``initial_values`` set on an + appropriate component input within the loop. + + Args: + components: List of component specifications. + connectors: List of connector specifications. + + Returns: + List of error messages for unresolved circular connections. + """ + graph = _build_component_graph(connectors) + if not graph: + return [] + + # Build lookup of component initial_values by name + initial_values_by_comp: dict[str, set[str]] = {} + for comp in components: + if comp.args.initial_values: + initial_values_by_comp[comp.args.name] = set(comp.args.initial_values.keys()) + + errors: list[str] = [] + for cycle in simple_cycles(graph): + # Check if any edge in the cycle targets a component input with initial_values + cycle_edges = _get_edges_in_cycle(cycle, connectors) + cycle_resolved = False + for edge in cycle_edges: + target_comp = edge.target.entity + target_field = edge.target.descriptor + if target_comp in initial_values_by_comp: + if target_field in initial_values_by_comp[target_comp]: + cycle_resolved = True + break + if not cycle_resolved: + cycle_str = " -> ".join(cycle + [cycle[0]]) + errors.append( + f"Circular connection detected without initial values: {cycle_str}. " + f"Set initial_values on a component input within the loop to resolve." + ) + return errors diff --git a/plugboard-schemas/plugboard_schemas/process.py b/plugboard-schemas/plugboard_schemas/process.py index e1ac6f6f..a5f02b02 100644 --- a/plugboard-schemas/plugboard_schemas/process.py +++ b/plugboard-schemas/plugboard_schemas/process.py @@ -7,6 +7,7 @@ from typing_extensions import Self from ._common import PlugboardBaseModel +from ._validation import validate_no_unresolved_cycles from .component import ComponentSpec from .connector import DEFAULT_CONNECTOR_CLS_PATH, ConnectorBuilderSpec, ConnectorSpec from .state import DEFAULT_STATE_BACKEND_CLS_PATH, RAY_STATE_BACKEND_CLS_PATH, StateBackendSpec @@ -77,6 +78,14 @@ def _set_default_state_backend(self: Self) -> Self: self.args.state.type = RAY_STATE_BACKEND_CLS_PATH return self + @model_validator(mode="after") + def _validate_no_unresolved_cycles(self: Self) -> Self: + """Validate that circular connections have initial_values set.""" + errors = validate_no_unresolved_cycles(self.args.components, self.args.connectors) + if errors: + raise ValueError("\n".join(errors)) + return self + @field_validator("type", mode="before") @classmethod def _validate_type(cls, value: _t.Any) -> str: diff --git a/tests/unit/test_process_validation.py b/tests/unit/test_process_validation.py new file mode 100644 index 00000000..4c25843f --- /dev/null +++ b/tests/unit/test_process_validation.py @@ -0,0 +1,331 @@ +"""Tests for process topology validation.""" + +from plugboard_schemas import ( + ProcessSpec, + validate_all_inputs_connected, + validate_input_events, +) +from plugboard_schemas._graph import simple_cycles +import pytest + + +# --------------------------------------------------------------------------- +# Tests for simple_cycles (Johnson's algorithm) +# --------------------------------------------------------------------------- + + +class TestSimpleCycles: + """Tests for Johnson's cycle-finding algorithm.""" + + def test_no_cycles(self) -> None: + """A DAG has no cycles.""" + graph: dict[str, set[str]] = {"a": {"b"}, "b": {"c"}, "c": set()} + assert list(simple_cycles(graph)) == [] + + def test_single_self_loop(self) -> None: + """A self-loop is a cycle of length 1.""" + graph: dict[str, set[str]] = {"a": {"a"}} + cycles = list(simple_cycles(graph)) + assert len(cycles) == 1 + assert cycles[0] == ["a"] + + def test_simple_two_node_cycle(self) -> None: + """Two nodes forming a cycle.""" + graph: dict[str, set[str]] = {"a": {"b"}, "b": {"a"}} + cycles = list(simple_cycles(graph)) + assert len(cycles) == 1 + assert set(cycles[0]) == {"a", "b"} + + def test_three_node_cycle(self) -> None: + """Three nodes forming a single cycle.""" + graph: dict[str, set[str]] = {"a": {"b"}, "b": {"c"}, "c": {"a"}} + cycles = list(simple_cycles(graph)) + assert len(cycles) == 1 + assert set(cycles[0]) == {"a", "b", "c"} + + def test_multiple_cycles(self) -> None: + """Graph with multiple distinct cycles.""" + graph: dict[str, set[str]] = { + "a": {"b"}, + "b": {"a", "c"}, + "c": {"d"}, + "d": {"c"}, + } + cycles = list(simple_cycles(graph)) + cycle_sets = [frozenset(c) for c in cycles] + assert frozenset({"a", "b"}) in cycle_sets + assert frozenset({"c", "d"}) in cycle_sets + + def test_empty_graph(self) -> None: + """Empty graph has no cycles.""" + graph: dict[str, set[str]] = {} + assert list(simple_cycles(graph)) == [] + + def test_disconnected_graph(self) -> None: + """Disconnected graph with no cycles.""" + graph: dict[str, set[str]] = {"a": set(), "b": set(), "c": set()} + assert list(simple_cycles(graph)) == [] + + def test_complex_graph(self) -> None: + """Complex graph with overlapping cycles.""" + graph: dict[str, set[str]] = { + "a": {"b"}, + "b": {"c"}, + "c": {"a", "d"}, + "d": {"b"}, + } + cycles = list(simple_cycles(graph)) + # Should find cycles: a->b->c->a and b->c->d->b + assert len(cycles) >= 2 + cycle_sets = [frozenset(c) for c in cycles] + assert frozenset({"a", "b", "c"}) in cycle_sets + assert frozenset({"b", "c", "d"}) in cycle_sets + + +# --------------------------------------------------------------------------- +# Tests for validate_no_unresolved_cycles +# --------------------------------------------------------------------------- + + +class TestValidateNoUnresolvedCycles: + """Tests for circular connection validation.""" + + @staticmethod + def _make_component(name: str, type_: str = "some.Component", **kwargs: object) -> dict: + args: dict = {"name": name} + args.update(kwargs) + return {"type": type_, "args": args} + + @staticmethod + def _make_connector(source: str, target: str) -> dict: + return {"source": source, "target": target} + + def test_no_cycles_passes(self) -> None: + """Linear topology passes validation.""" + spec = ProcessSpec.model_validate( + { + "args": { + "components": [ + self._make_component("a"), + self._make_component("b"), + self._make_component("c"), + ], + "connectors": [ + self._make_connector("a.out", "b.in_1"), + self._make_connector("b.out", "c.in_1"), + ], + } + } + ) + assert spec is not None + + def test_cycle_without_initial_values_fails(self) -> None: + """Cycle without initial_values raises ValueError.""" + with pytest.raises(ValueError, match="Circular connection detected"): + ProcessSpec.model_validate( + { + "args": { + "components": [ + self._make_component("a"), + self._make_component("b"), + ], + "connectors": [ + self._make_connector("a.out", "b.in_1"), + self._make_connector("b.out", "a.in_1"), + ], + } + } + ) + + def test_cycle_with_initial_values_passes(self) -> None: + """Cycle with initial_values on a target input passes.""" + spec = ProcessSpec.model_validate( + { + "args": { + "components": [ + self._make_component("a", initial_values={"in_1": [0]}), + self._make_component("b"), + ], + "connectors": [ + self._make_connector("a.out", "b.in_1"), + self._make_connector("b.out", "a.in_1"), + ], + } + } + ) + assert spec is not None + + def test_cycle_with_initial_values_on_other_field_fails(self) -> None: + """Cycle with initial_values on an unrelated field still fails.""" + with pytest.raises(ValueError, match="Circular connection detected"): + ProcessSpec.model_validate( + { + "args": { + "components": [ + self._make_component("a", initial_values={"other": [0]}), + self._make_component("b"), + ], + "connectors": [ + self._make_connector("a.out", "b.in_1"), + self._make_connector("b.out", "a.in_1"), + ], + } + } + ) + + def test_three_node_cycle_without_initial_values_fails(self) -> None: + """Three-node cycle without initial_values raises ValueError.""" + with pytest.raises(ValueError, match="Circular connection detected"): + ProcessSpec.model_validate( + { + "args": { + "components": [ + self._make_component("a"), + self._make_component("b"), + self._make_component("c"), + ], + "connectors": [ + self._make_connector("a.out", "b.in_1"), + self._make_connector("b.out", "c.in_1"), + self._make_connector("c.out", "a.in_1"), + ], + } + } + ) + + def test_three_node_cycle_with_initial_values_passes(self) -> None: + """Three-node cycle with initial_values on any target input passes.""" + spec = ProcessSpec.model_validate( + { + "args": { + "components": [ + self._make_component("a"), + self._make_component("b", initial_values={"in_1": [0]}), + self._make_component("c"), + ], + "connectors": [ + self._make_connector("a.out", "b.in_1"), + self._make_connector("b.out", "c.in_1"), + self._make_connector("c.out", "a.in_1"), + ], + } + } + ) + assert spec is not None + + def test_no_connectors_passes(self) -> None: + """Process with no connectors passes validation.""" + spec = ProcessSpec.model_validate( + { + "args": { + "components": [self._make_component("a")], + } + } + ) + assert spec is not None + + +# --------------------------------------------------------------------------- +# Tests for validate_all_inputs_connected (runtime utility) +# --------------------------------------------------------------------------- + + +class TestValidateAllInputsConnected: + """Tests for the all-inputs-connected validation utility.""" + + @staticmethod + def _conn(source: str, target: str) -> dict: + return {"source": source, "target": target} + + def test_all_connected(self) -> None: + """Test that validation passes when all inputs are connected.""" + from typing import Any + + from plugboard_schemas import ConnectorSpec + + components: dict[str, dict[str, Any]] = { + "a": {"inputs": []}, + "b": {"inputs": ["in_1"]}, + } + connectors = [ConnectorSpec.model_validate(self._conn("a.out", "b.in_1"))] + errors = validate_all_inputs_connected(components, connectors) + assert errors == [] + + def test_missing_input(self) -> None: + """Test that validation fails when an input is not connected.""" + from typing import Any + + from plugboard_schemas import ConnectorSpec + + components: dict[str, dict[str, Any]] = { + "a": {"inputs": []}, + "b": {"inputs": ["in_1", "in_2"]}, + } + connectors = [ConnectorSpec.model_validate(self._conn("a.out", "b.in_1"))] + errors = validate_all_inputs_connected(components, connectors) + assert len(errors) == 1 + assert "in_2" in errors[0] + + def test_no_inputs_no_errors(self) -> None: + """Test that validation passes when components have no inputs.""" + from typing import Any + + from plugboard_schemas import ConnectorSpec + + components: dict[str, dict[str, Any]] = {"a": {"inputs": []}} + connectors: list[ConnectorSpec] = [] + errors = validate_all_inputs_connected(components, connectors) + assert errors == [] + + +# --------------------------------------------------------------------------- +# Tests for validate_input_events (runtime utility) +# --------------------------------------------------------------------------- + + +class TestValidateInputEvents: + """Tests for the input-events validation utility.""" + + def test_matched_events(self) -> None: + """Test that validation passes when all input events have producers.""" + from typing import Any + + components: dict[str, dict[str, Any]] = { + "clock": {"input_events": [], "output_events": ["tick"]}, + "ctrl": {"input_events": ["tick"], "output_events": []}, + } + errors = validate_input_events(components) + assert errors == [] + + def test_unmatched_event(self) -> None: + """Test that validation fails when input events have no producer.""" + from typing import Any + + components: dict[str, dict[str, Any]] = { + "ctrl": {"input_events": ["tick"], "output_events": []}, + } + errors = validate_input_events(components) + assert len(errors) == 1 + assert "tick" in errors[0] + + def test_no_events(self) -> None: + """Test that validation passes when components have no events.""" + from typing import Any + + components: dict[str, dict[str, Any]] = { + "a": {"input_events": [], "output_events": []}, + } + errors = validate_input_events(components) + assert errors == [] + + def test_multiple_unmatched(self) -> None: + """Test that validation correctly identifies unmatched events.""" + from typing import Any + + components: dict[str, dict[str, Any]] = { + "a": {"input_events": ["evt_x", "evt_y"], "output_events": []}, + "b": {"input_events": [], "output_events": ["evt_x"]}, + } + errors = validate_input_events(components) + assert len(errors) == 1 + assert "evt_y" in errors[0] From fc0bf55a8c68f0de28f7af968c3977842acea762 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 15 Feb 2026 20:08:10 +0000 Subject: [PATCH 3/7] Address code review: remove unused ValidationError class and redundant filter Co-authored-by: toby-coleman <13170610+toby-coleman@users.noreply.github.com> --- plugboard-schemas/plugboard_schemas/__init__.py | 2 -- plugboard-schemas/plugboard_schemas/_validation.py | 9 +-------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/plugboard-schemas/plugboard_schemas/__init__.py b/plugboard-schemas/plugboard_schemas/__init__.py index a4eda967..8f19c242 100644 --- a/plugboard-schemas/plugboard_schemas/__init__.py +++ b/plugboard-schemas/plugboard_schemas/__init__.py @@ -11,7 +11,6 @@ from ._common import PlugboardBaseModel from ._graph import simple_cycles from ._validation import ( - ValidationError, validate_all_inputs_connected, validate_input_events, validate_no_unresolved_cycles, @@ -92,7 +91,6 @@ "TuneArgsDict", "TuneArgsSpec", "TuneSpec", - "ValidationError", "simple_cycles", "validate_all_inputs_connected", "validate_input_events", diff --git a/plugboard-schemas/plugboard_schemas/_validation.py b/plugboard-schemas/plugboard_schemas/_validation.py index 6658e47e..29c48919 100644 --- a/plugboard-schemas/plugboard_schemas/_validation.py +++ b/plugboard-schemas/plugboard_schemas/_validation.py @@ -19,12 +19,6 @@ from .connector import ConnectorSpec -class ValidationError(Exception): - """Raised when a process specification fails validation.""" - - pass - - def _build_component_graph( connectors: list[ConnectorSpec], ) -> dict[str, set[str]]: @@ -62,13 +56,12 @@ def _get_edges_in_cycle( List of connector specs that are part of the cycle. """ cycle_edges: list[ConnectorSpec] = [] - cycle_set = set(cycle) for i, node in enumerate(cycle): next_node = cycle[(i + 1) % len(cycle)] for conn in connectors: if conn.source.entity == node and conn.target.entity == next_node: cycle_edges.append(conn) - return [c for c in cycle_edges if c.source.entity in cycle_set and c.target.entity in cycle_set] + return cycle_edges def validate_all_inputs_connected( From b40300f24b018070c3de1a014157fd27e55d372d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:46:22 +0000 Subject: [PATCH 4/7] Refactor validators to runtime checks accepting process.dict() output - Remove Pydantic model validator from ProcessSpec - Refactor all validators to accept process.dict() format - Add validate_process() entry point combining all checks - Call validate_process() in Process.init() raising ValidationError - Update plugboard.schemas re-exports - Update tests for new runtime validation approach Co-authored-by: toby-coleman <13170610+toby-coleman@users.noreply.github.com> --- .../plugboard_schemas/__init__.py | 2 + .../plugboard_schemas/_validation.py | 130 +++--- .../plugboard_schemas/process.py | 9 - plugboard/process/process.py | 7 +- plugboard/schemas/__init__.py | 10 + tests/integration/test_process_validation.py | 12 +- tests/unit/test_process_validation.py | 401 ++++++++++-------- 7 files changed, 326 insertions(+), 245 deletions(-) diff --git a/plugboard-schemas/plugboard_schemas/__init__.py b/plugboard-schemas/plugboard_schemas/__init__.py index 8f19c242..e0e34f39 100644 --- a/plugboard-schemas/plugboard_schemas/__init__.py +++ b/plugboard-schemas/plugboard_schemas/__init__.py @@ -14,6 +14,7 @@ validate_all_inputs_connected, validate_input_events, validate_no_unresolved_cycles, + validate_process, ) from .component import ComponentArgsDict, ComponentArgsSpec, ComponentSpec, Resource from .config import ConfigSpec, ProcessConfigSpec @@ -95,4 +96,5 @@ "validate_all_inputs_connected", "validate_input_events", "validate_no_unresolved_cycles", + "validate_process", ] diff --git a/plugboard-schemas/plugboard_schemas/_validation.py b/plugboard-schemas/plugboard_schemas/_validation.py index 29c48919..467cabd0 100644 --- a/plugboard-schemas/plugboard_schemas/_validation.py +++ b/plugboard-schemas/plugboard_schemas/_validation.py @@ -1,9 +1,12 @@ -"""Validation utilities for `ProcessSpec` objects. +"""Validation utilities for process topology. Provides functions to validate process topology including: - Checking that all component inputs are connected - Checking that input events have matching output event producers - Checking for circular connections that require initial values + +All validators accept the output of ``process.dict()`` or the relevant +sub-structures thereof. """ from __future__ import annotations @@ -14,29 +17,25 @@ from ._graph import simple_cycles -if _t.TYPE_CHECKING: - from .component import ComponentSpec - from .connector import ConnectorSpec - - def _build_component_graph( - connectors: list[ConnectorSpec], + connectors: dict[str, dict[str, _t.Any]], ) -> dict[str, set[str]]: - """Build a directed graph of component connections from connector specs. + """Build a directed graph of component connections from connector dicts. Args: - connectors: List of connector specifications. + connectors: Dictionary mapping connector IDs to connector dicts, + as returned by ``process.dict()["connectors"]``. Returns: A dictionary mapping source component names to sets of target component names. """ graph: dict[str, set[str]] = defaultdict(set) - for conn in connectors: - source_entity = conn.source.entity - target_entity = conn.target.entity + for conn_info in connectors.values(): + spec = conn_info["spec"] + source_entity = spec["source"]["entity"] + target_entity = spec["target"]["entity"] if source_entity != target_entity: graph[source_entity].add(target_entity) - # Ensure target is in graph even with no outgoing edges if target_entity not in graph: graph[target_entity] = set() return dict(graph) @@ -44,50 +43,53 @@ def _build_component_graph( def _get_edges_in_cycle( cycle: list[str], - connectors: list[ConnectorSpec], -) -> list[ConnectorSpec]: - """Get all connector specs that form edges within a cycle. + connectors: dict[str, dict[str, _t.Any]], +) -> list[dict[str, _t.Any]]: + """Get all connector spec dicts that form edges within a cycle. Args: cycle: List of component names forming a cycle. - connectors: All connector specifications. + connectors: Dictionary mapping connector IDs to connector dicts. Returns: - List of connector specs that are part of the cycle. + List of connector spec dicts that are part of the cycle. """ - cycle_edges: list[ConnectorSpec] = [] + cycle_edges: list[dict[str, _t.Any]] = [] for i, node in enumerate(cycle): next_node = cycle[(i + 1) % len(cycle)] - for conn in connectors: - if conn.source.entity == node and conn.target.entity == next_node: - cycle_edges.append(conn) + for conn_info in connectors.values(): + spec = conn_info["spec"] + if spec["source"]["entity"] == node and spec["target"]["entity"] == next_node: + cycle_edges.append(spec) return cycle_edges def validate_all_inputs_connected( - components: dict[str, dict[str, _t.Any]], - connectors: list[ConnectorSpec], + process_dict: dict[str, _t.Any], ) -> list[str]: """Check that all component inputs are connected. Args: - components: Dictionary mapping component names to their IO info. - Each value must have an ``"inputs"`` key with a list of input field names. - connectors: List of connector specifications. + process_dict: The output of ``process.dict()``. Uses the ``"components"`` + and ``"connectors"`` keys. Returns: List of error messages for unconnected inputs. """ - # Build mapping of which component inputs are connected + components: dict[str, dict[str, _t.Any]] = process_dict["components"] + connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"] + connected_inputs: dict[str, set[str]] = defaultdict(set) - for conn in connectors: - target_name = conn.target.entity - target_field = conn.target.descriptor + for conn_info in connectors.values(): + spec = conn_info["spec"] + target_name = spec["target"]["entity"] + target_field = spec["target"]["descriptor"] connected_inputs[target_name].add(target_field) errors: list[str] = [] - for comp_name, comp_info in components.items(): - all_inputs = set(comp_info.get("inputs", [])) + for comp_name, comp_data in components.items(): + io = comp_data.get("io", {}) + all_inputs = set(io.get("inputs", [])) connected = connected_inputs.get(comp_name, set()) unconnected = all_inputs - connected if unconnected: @@ -96,26 +98,27 @@ def validate_all_inputs_connected( def validate_input_events( - components: dict[str, dict[str, _t.Any]], + process_dict: dict[str, _t.Any], ) -> list[str]: """Check that all components with input events have a matching output event producer. Args: - components: Dictionary mapping component names to their IO info. - Each value must have ``"input_events"`` and ``"output_events"`` keys - with lists of event type strings. + process_dict: The output of ``process.dict()``. Uses the ``"components"`` key. Returns: List of error messages for unmatched input events. """ - # Collect all output event types across all components + components: dict[str, dict[str, _t.Any]] = process_dict["components"] + all_output_events: set[str] = set() - for comp_info in components.values(): - all_output_events.update(comp_info.get("output_events", [])) + for comp_data in components.values(): + io = comp_data.get("io", {}) + all_output_events.update(io.get("output_events", [])) errors: list[str] = [] - for comp_name, comp_info in components.items(): - input_events = set(comp_info.get("input_events", [])) + for comp_name, comp_data in components.items(): + io = comp_data.get("io", {}) + input_events = set(io.get("input_events", [])) unmatched = input_events - all_output_events if unmatched: errors.append( @@ -125,8 +128,7 @@ def validate_input_events( def validate_no_unresolved_cycles( - components: list[ComponentSpec], - connectors: list[ConnectorSpec], + process_dict: dict[str, _t.Any], ) -> list[str]: """Check for circular connections that are not resolved by initial values. @@ -134,30 +136,34 @@ def validate_no_unresolved_cycles( appropriate component input within the loop. Args: - components: List of component specifications. - connectors: List of connector specifications. + process_dict: The output of ``process.dict()``. Uses the ``"components"`` + and ``"connectors"`` keys. Returns: List of error messages for unresolved circular connections. """ + components: dict[str, dict[str, _t.Any]] = process_dict["components"] + connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"] + graph = _build_component_graph(connectors) if not graph: return [] # Build lookup of component initial_values by name initial_values_by_comp: dict[str, set[str]] = {} - for comp in components: - if comp.args.initial_values: - initial_values_by_comp[comp.args.name] = set(comp.args.initial_values.keys()) + for comp_name, comp_data in components.items(): + io = comp_data.get("io", {}) + iv = io.get("initial_values", {}) + if iv: + initial_values_by_comp[comp_name] = set(iv.keys()) errors: list[str] = [] for cycle in simple_cycles(graph): - # Check if any edge in the cycle targets a component input with initial_values cycle_edges = _get_edges_in_cycle(cycle, connectors) cycle_resolved = False for edge in cycle_edges: - target_comp = edge.target.entity - target_field = edge.target.descriptor + target_comp = edge["target"]["entity"] + target_field = edge["target"]["descriptor"] if target_comp in initial_values_by_comp: if target_field in initial_values_by_comp[target_comp]: cycle_resolved = True @@ -169,3 +175,23 @@ def validate_no_unresolved_cycles( f"Set initial_values on a component input within the loop to resolve." ) return errors + + +def validate_process(process_dict: dict[str, _t.Any]) -> list[str]: + """Run all topology validation checks on a process. + + This is the main validation entry point. It accepts the output of + ``process.dict()`` and runs every available check, returning a + combined list of error messages. + + Args: + process_dict: The output of ``process.dict()``. + + Returns: + List of error messages. An empty list indicates a valid topology. + """ + errors: list[str] = [] + errors.extend(validate_all_inputs_connected(process_dict)) + errors.extend(validate_input_events(process_dict)) + errors.extend(validate_no_unresolved_cycles(process_dict)) + return errors diff --git a/plugboard-schemas/plugboard_schemas/process.py b/plugboard-schemas/plugboard_schemas/process.py index a5f02b02..e1ac6f6f 100644 --- a/plugboard-schemas/plugboard_schemas/process.py +++ b/plugboard-schemas/plugboard_schemas/process.py @@ -7,7 +7,6 @@ from typing_extensions import Self from ._common import PlugboardBaseModel -from ._validation import validate_no_unresolved_cycles from .component import ComponentSpec from .connector import DEFAULT_CONNECTOR_CLS_PATH, ConnectorBuilderSpec, ConnectorSpec from .state import DEFAULT_STATE_BACKEND_CLS_PATH, RAY_STATE_BACKEND_CLS_PATH, StateBackendSpec @@ -78,14 +77,6 @@ def _set_default_state_backend(self: Self) -> Self: self.args.state.type = RAY_STATE_BACKEND_CLS_PATH return self - @model_validator(mode="after") - def _validate_no_unresolved_cycles(self: Self) -> Self: - """Validate that circular connections have initial_values set.""" - errors = validate_no_unresolved_cycles(self.args.components, self.args.connectors) - if errors: - raise ValueError("\n".join(errors)) - return self - @field_validator("type", mode="before") @classmethod def _validate_type(cls, value: _t.Any) -> str: diff --git a/plugboard/process/process.py b/plugboard/process/process.py index ccd45f58..0bedb39b 100644 --- a/plugboard/process/process.py +++ b/plugboard/process/process.py @@ -13,8 +13,8 @@ from plugboard.component import Component from plugboard.connector import Connector -from plugboard.exceptions import NotInitialisedError -from plugboard.schemas import ConfigSpec, Status +from plugboard.exceptions import NotInitialisedError, ValidationError +from plugboard.schemas import ConfigSpec, Status, validate_process from plugboard.state import DictStateBackend, StateBackend from plugboard.utils import DI, ExportMixin, gen_rand_str from plugboard.utils.async_utils import run_coro_sync @@ -109,6 +109,9 @@ async def _set_status(self, status: Status, publish: bool = True) -> None: @abstractmethod async def init(self) -> None: """Performs component initialisation actions.""" + errors = validate_process(self.dict()) + if errors: + raise ValidationError("\n".join(errors)) self._is_initialised = True await self._set_status(Status.INIT) diff --git a/plugboard/schemas/__init__.py b/plugboard/schemas/__init__.py index 427ee410..4ee8fc79 100644 --- a/plugboard/schemas/__init__.py +++ b/plugboard/schemas/__init__.py @@ -46,6 +46,11 @@ TuneArgsDict, TuneArgsSpec, TuneSpec, + simple_cycles, + validate_all_inputs_connected, + validate_input_events, + validate_no_unresolved_cycles, + validate_process, ) @@ -86,4 +91,9 @@ "TuneArgsDict", "TuneArgsSpec", "TuneSpec", + "simple_cycles", + "validate_all_inputs_connected", + "validate_input_events", + "validate_no_unresolved_cycles", + "validate_process", ] diff --git a/tests/integration/test_process_validation.py b/tests/integration/test_process_validation.py index f77c8678..07294111 100644 --- a/tests/integration/test_process_validation.py +++ b/tests/integration/test_process_validation.py @@ -15,9 +15,6 @@ from tests.integration.test_process_with_components_run import A, B, C -# TODO: Update these tests when we implement full graph validation - - def filter_logs(logs: list[EventDict], field: str, regex: str) -> list[EventDict]: """Filters the log output by applying regex to a field.""" pattern = re.compile(regex) @@ -26,20 +23,15 @@ def filter_logs(logs: list[EventDict], field: str, regex: str) -> list[EventDict @pytest.mark.asyncio async def test_missing_connections() -> None: - """Tests that missing connections are logged.""" + """Tests that missing input connections raise ValidationError.""" p_missing_input = LocalProcess( components=[A(name="a", iters=10), C(name="c", path="test-out.csv")], # c.in_1 is not connected connectors=[AsyncioConnector(spec=ConnectorSpec(source="a.out_1", target="unknown.x"))], ) - with capture_logs() as logs: + with pytest.raises(exceptions.ValidationError, match="unconnected inputs"): await p_missing_input.init() - # Must contain an error-level log indicating that input is not connected - logs = filter_logs(logs, "log_level", "error") - logs = filter_logs(logs, "event", "Input fields not connected") - assert logs, "Logs do not indicate missing connection" - p_missing_output = LocalProcess( components=[A(name="a", iters=10), B(name="b")], # b.out_1 is not connected diff --git a/tests/unit/test_process_validation.py b/tests/unit/test_process_validation.py index 4c25843f..02e0a4d2 100644 --- a/tests/unit/test_process_validation.py +++ b/tests/unit/test_process_validation.py @@ -1,12 +1,14 @@ """Tests for process topology validation.""" +import typing as _t + from plugboard_schemas import ( - ProcessSpec, validate_all_inputs_connected, validate_input_events, + validate_no_unresolved_cycles, + validate_process, ) from plugboard_schemas._graph import simple_cycles -import pytest # --------------------------------------------------------------------------- @@ -82,6 +84,67 @@ def test_complex_graph(self) -> None: assert frozenset({"b", "c", "d"}) in cycle_sets +# --------------------------------------------------------------------------- +# Helpers for building process.dict()-style data structures +# --------------------------------------------------------------------------- + + +def _make_component( + name: str, + inputs: list[str] | None = None, + outputs: list[str] | None = None, + input_events: list[str] | None = None, + output_events: list[str] | None = None, + initial_values: dict[str, _t.Any] | None = None, +) -> dict[str, _t.Any]: + """Build a component dict matching process.dict() format.""" + return { + "id": name, + "name": name, + "status": "created", + "io": { + "namespace": name, + "inputs": inputs or [], + "outputs": outputs or [], + "input_events": input_events or [], + "output_events": output_events or [], + "initial_values": initial_values or {}, + }, + } + + +def _make_connector(source: str, target: str) -> dict[str, _t.Any]: + """Build a connector dict matching process.dict() format.""" + src_entity, src_desc = source.split(".") + tgt_entity, tgt_desc = target.split(".") + conn_id = f"{source}..{target}" + return { + conn_id: { + "id": conn_id, + "spec": { + "source": {"entity": src_entity, "descriptor": src_desc}, + "target": {"entity": tgt_entity, "descriptor": tgt_desc}, + "mode": "pipeline", + }, + } + } + + +def _make_process_dict( + components: dict[str, dict[str, _t.Any]], + connectors: dict[str, dict[str, _t.Any]] | None = None, +) -> dict[str, _t.Any]: + """Build a process dict matching process.dict() format.""" + return { + "id": "test_process", + "name": "test_process", + "status": "created", + "components": components, + "connectors": connectors or {}, + "parameters": {}, + } + + # --------------------------------------------------------------------------- # Tests for validate_no_unresolved_cycles # --------------------------------------------------------------------------- @@ -90,196 +153,159 @@ def test_complex_graph(self) -> None: class TestValidateNoUnresolvedCycles: """Tests for circular connection validation.""" - @staticmethod - def _make_component(name: str, type_: str = "some.Component", **kwargs: object) -> dict: - args: dict = {"name": name} - args.update(kwargs) - return {"type": type_, "args": args} - - @staticmethod - def _make_connector(source: str, target: str) -> dict: - return {"source": source, "target": target} - def test_no_cycles_passes(self) -> None: """Linear topology passes validation.""" - spec = ProcessSpec.model_validate( - { - "args": { - "components": [ - self._make_component("a"), - self._make_component("b"), - self._make_component("c"), - ], - "connectors": [ - self._make_connector("a.out", "b.in_1"), - self._make_connector("b.out", "c.in_1"), - ], - } - } + connectors = {**_make_connector("a.out", "b.in_1"), **_make_connector("b.out", "c.in_1")} + pd = _make_process_dict( + components={ + "a": _make_component("a", outputs=["out"]), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + "c": _make_component("c", inputs=["in_1"]), + }, + connectors=connectors, ) - assert spec is not None + errors = validate_no_unresolved_cycles(pd) + assert errors == [] def test_cycle_without_initial_values_fails(self) -> None: - """Cycle without initial_values raises ValueError.""" - with pytest.raises(ValueError, match="Circular connection detected"): - ProcessSpec.model_validate( - { - "args": { - "components": [ - self._make_component("a"), - self._make_component("b"), - ], - "connectors": [ - self._make_connector("a.out", "b.in_1"), - self._make_connector("b.out", "a.in_1"), - ], - } - } - ) + """Cycle without initial_values returns errors.""" + connectors = {**_make_connector("a.out", "b.in_1"), **_make_connector("b.out", "a.in_1")} + pd = _make_process_dict( + components={ + "a": _make_component("a", inputs=["in_1"], outputs=["out"]), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert len(errors) == 1 + assert "Circular connection detected" in errors[0] def test_cycle_with_initial_values_passes(self) -> None: """Cycle with initial_values on a target input passes.""" - spec = ProcessSpec.model_validate( - { - "args": { - "components": [ - self._make_component("a", initial_values={"in_1": [0]}), - self._make_component("b"), - ], - "connectors": [ - self._make_connector("a.out", "b.in_1"), - self._make_connector("b.out", "a.in_1"), - ], - } - } + connectors = {**_make_connector("a.out", "b.in_1"), **_make_connector("b.out", "a.in_1")} + pd = _make_process_dict( + components={ + "a": _make_component( + "a", inputs=["in_1"], outputs=["out"], initial_values={"in_1": [0]} + ), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, ) - assert spec is not None + errors = validate_no_unresolved_cycles(pd) + assert errors == [] def test_cycle_with_initial_values_on_other_field_fails(self) -> None: """Cycle with initial_values on an unrelated field still fails.""" - with pytest.raises(ValueError, match="Circular connection detected"): - ProcessSpec.model_validate( - { - "args": { - "components": [ - self._make_component("a", initial_values={"other": [0]}), - self._make_component("b"), - ], - "connectors": [ - self._make_connector("a.out", "b.in_1"), - self._make_connector("b.out", "a.in_1"), - ], - } - } - ) + connectors = {**_make_connector("a.out", "b.in_1"), **_make_connector("b.out", "a.in_1")} + pd = _make_process_dict( + components={ + "a": _make_component( + "a", inputs=["in_1"], outputs=["out"], initial_values={"other": [0]} + ), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert len(errors) == 1 + assert "Circular connection detected" in errors[0] def test_three_node_cycle_without_initial_values_fails(self) -> None: - """Three-node cycle without initial_values raises ValueError.""" - with pytest.raises(ValueError, match="Circular connection detected"): - ProcessSpec.model_validate( - { - "args": { - "components": [ - self._make_component("a"), - self._make_component("b"), - self._make_component("c"), - ], - "connectors": [ - self._make_connector("a.out", "b.in_1"), - self._make_connector("b.out", "c.in_1"), - self._make_connector("c.out", "a.in_1"), - ], - } - } - ) + """Three-node cycle without initial_values returns errors.""" + connectors = { + **_make_connector("a.out", "b.in_1"), + **_make_connector("b.out", "c.in_1"), + **_make_connector("c.out", "a.in_1"), + } + pd = _make_process_dict( + components={ + "a": _make_component("a", inputs=["in_1"], outputs=["out"]), + "b": _make_component("b", inputs=["in_1"], outputs=["out"]), + "c": _make_component("c", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, + ) + errors = validate_no_unresolved_cycles(pd) + assert len(errors) == 1 + assert "Circular connection detected" in errors[0] def test_three_node_cycle_with_initial_values_passes(self) -> None: """Three-node cycle with initial_values on any target input passes.""" - spec = ProcessSpec.model_validate( - { - "args": { - "components": [ - self._make_component("a"), - self._make_component("b", initial_values={"in_1": [0]}), - self._make_component("c"), - ], - "connectors": [ - self._make_connector("a.out", "b.in_1"), - self._make_connector("b.out", "c.in_1"), - self._make_connector("c.out", "a.in_1"), - ], - } - } + connectors = { + **_make_connector("a.out", "b.in_1"), + **_make_connector("b.out", "c.in_1"), + **_make_connector("c.out", "a.in_1"), + } + pd = _make_process_dict( + components={ + "a": _make_component("a", inputs=["in_1"], outputs=["out"]), + "b": _make_component( + "b", inputs=["in_1"], outputs=["out"], initial_values={"in_1": [0]} + ), + "c": _make_component("c", inputs=["in_1"], outputs=["out"]), + }, + connectors=connectors, ) - assert spec is not None + errors = validate_no_unresolved_cycles(pd) + assert errors == [] def test_no_connectors_passes(self) -> None: """Process with no connectors passes validation.""" - spec = ProcessSpec.model_validate( - { - "args": { - "components": [self._make_component("a")], - } - } + pd = _make_process_dict( + components={"a": _make_component("a")}, ) - assert spec is not None + errors = validate_no_unresolved_cycles(pd) + assert errors == [] # --------------------------------------------------------------------------- -# Tests for validate_all_inputs_connected (runtime utility) +# Tests for validate_all_inputs_connected # --------------------------------------------------------------------------- class TestValidateAllInputsConnected: """Tests for the all-inputs-connected validation utility.""" - @staticmethod - def _conn(source: str, target: str) -> dict: - return {"source": source, "target": target} - def test_all_connected(self) -> None: """Test that validation passes when all inputs are connected.""" - from typing import Any - - from plugboard_schemas import ConnectorSpec - - components: dict[str, dict[str, Any]] = { - "a": {"inputs": []}, - "b": {"inputs": ["in_1"]}, - } - connectors = [ConnectorSpec.model_validate(self._conn("a.out", "b.in_1"))] - errors = validate_all_inputs_connected(components, connectors) + connectors = _make_connector("a.out", "b.in_1") + pd = _make_process_dict( + components={ + "a": _make_component("a", outputs=["out"]), + "b": _make_component("b", inputs=["in_1"]), + }, + connectors=connectors, + ) + errors = validate_all_inputs_connected(pd) assert errors == [] def test_missing_input(self) -> None: """Test that validation fails when an input is not connected.""" - from typing import Any - - from plugboard_schemas import ConnectorSpec - - components: dict[str, dict[str, Any]] = { - "a": {"inputs": []}, - "b": {"inputs": ["in_1", "in_2"]}, - } - connectors = [ConnectorSpec.model_validate(self._conn("a.out", "b.in_1"))] - errors = validate_all_inputs_connected(components, connectors) + connectors = _make_connector("a.out", "b.in_1") + pd = _make_process_dict( + components={ + "a": _make_component("a", outputs=["out"]), + "b": _make_component("b", inputs=["in_1", "in_2"]), + }, + connectors=connectors, + ) + errors = validate_all_inputs_connected(pd) assert len(errors) == 1 assert "in_2" in errors[0] def test_no_inputs_no_errors(self) -> None: """Test that validation passes when components have no inputs.""" - from typing import Any - - from plugboard_schemas import ConnectorSpec - - components: dict[str, dict[str, Any]] = {"a": {"inputs": []}} - connectors: list[ConnectorSpec] = [] - errors = validate_all_inputs_connected(components, connectors) + pd = _make_process_dict( + components={"a": _make_component("a")}, + ) + errors = validate_all_inputs_connected(pd) assert errors == [] # --------------------------------------------------------------------------- -# Tests for validate_input_events (runtime utility) +# Tests for validate_input_events # --------------------------------------------------------------------------- @@ -288,44 +314,75 @@ class TestValidateInputEvents: def test_matched_events(self) -> None: """Test that validation passes when all input events have producers.""" - from typing import Any - - components: dict[str, dict[str, Any]] = { - "clock": {"input_events": [], "output_events": ["tick"]}, - "ctrl": {"input_events": ["tick"], "output_events": []}, - } - errors = validate_input_events(components) + pd = _make_process_dict( + components={ + "clock": _make_component("clock", output_events=["tick"]), + "ctrl": _make_component("ctrl", input_events=["tick"]), + }, + ) + errors = validate_input_events(pd) assert errors == [] def test_unmatched_event(self) -> None: """Test that validation fails when input events have no producer.""" - from typing import Any - - components: dict[str, dict[str, Any]] = { - "ctrl": {"input_events": ["tick"], "output_events": []}, - } - errors = validate_input_events(components) + pd = _make_process_dict( + components={ + "ctrl": _make_component("ctrl", input_events=["tick"]), + }, + ) + errors = validate_input_events(pd) assert len(errors) == 1 assert "tick" in errors[0] def test_no_events(self) -> None: """Test that validation passes when components have no events.""" - from typing import Any - - components: dict[str, dict[str, Any]] = { - "a": {"input_events": [], "output_events": []}, - } - errors = validate_input_events(components) + pd = _make_process_dict( + components={"a": _make_component("a")}, + ) + errors = validate_input_events(pd) assert errors == [] def test_multiple_unmatched(self) -> None: """Test that validation correctly identifies unmatched events.""" - from typing import Any - - components: dict[str, dict[str, Any]] = { - "a": {"input_events": ["evt_x", "evt_y"], "output_events": []}, - "b": {"input_events": [], "output_events": ["evt_x"]}, - } - errors = validate_input_events(components) + pd = _make_process_dict( + components={ + "a": _make_component("a", input_events=["evt_x", "evt_y"]), + "b": _make_component("b", output_events=["evt_x"]), + }, + ) + errors = validate_input_events(pd) assert len(errors) == 1 assert "evt_y" in errors[0] + + +# --------------------------------------------------------------------------- +# Tests for validate_process (combined validator) +# --------------------------------------------------------------------------- + + +class TestValidateProcess: + """Tests for the combined validate_process utility.""" + + def test_valid_process(self) -> None: + """Test that a valid process returns no errors.""" + connectors = _make_connector("a.out_1", "b.in_1") + pd = _make_process_dict( + components={ + "a": _make_component("a", outputs=["out_1"]), + "b": _make_component("b", inputs=["in_1"]), + }, + connectors=connectors, + ) + errors = validate_process(pd) + assert errors == [] + + def test_multiple_errors(self) -> None: + """Test that multiple validation errors are collected.""" + pd = _make_process_dict( + components={ + "a": _make_component("a", inputs=["in_1"], input_events=["missing_evt"]), + }, + ) + errors = validate_process(pd) + # Should have at least: unconnected input + unmatched event + assert len(errors) >= 2 From d96d7b21d55c071c6589c7ed5b192071b5ecb2d6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:47:23 +0000 Subject: [PATCH 5/7] Add descriptive header to validation error message Co-authored-by: toby-coleman <13170610+toby-coleman@users.noreply.github.com> --- plugboard/process/process.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugboard/process/process.py b/plugboard/process/process.py index 0bedb39b..939ae94a 100644 --- a/plugboard/process/process.py +++ b/plugboard/process/process.py @@ -111,7 +111,8 @@ async def init(self) -> None: """Performs component initialisation actions.""" errors = validate_process(self.dict()) if errors: - raise ValidationError("\n".join(errors)) + msg = "Process validation failed:\n" + "\n".join(errors) + raise ValidationError(msg) self._is_initialised = True await self._set_status(Status.INIT) From 8b708b5efb4dce4198332e61c20269857aabdb7a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 22:27:24 +0000 Subject: [PATCH 6/7] Add `plugboard process validate` CLI command Adds a new `validate` subcommand to the CLI that checks whether a process defined in a YAML config file is valid. It builds the process and runs all topology validation checks (connected inputs, event matching, cycle detection). Exits with code 0 on success, code 1 with error details on failure. Co-authored-by: toby-coleman <13170610+toby-coleman@users.noreply.github.com> --- plugboard/cli/process/__init__.py | 30 +++++++++++++++++++++++++++++- tests/unit/test_cli.py | 17 +++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/plugboard/cli/process/__init__.py b/plugboard/cli/process/__init__.py index 518abd79..b50e14ef 100644 --- a/plugboard/cli/process/__init__.py +++ b/plugboard/cli/process/__init__.py @@ -13,7 +13,7 @@ from plugboard.diagram import MermaidDiagram from plugboard.process import Process, ProcessBuilder -from plugboard.schemas import ConfigSpec +from plugboard.schemas import ConfigSpec, validate_process from plugboard.tune import Tuner from plugboard.utils import add_sys_path, run_coro_sync @@ -164,3 +164,31 @@ def diagram( diagram = MermaidDiagram.from_process(process) md = Markdown(f"```\n{diagram.diagram}\n```\n[Editable diagram]({diagram.url}) (external link)") print(md) + + +@app.command() +def validate( + config: Annotated[ + Path, + typer.Argument( + exists=True, + file_okay=True, + dir_okay=False, + writable=False, + readable=True, + resolve_path=True, + help="Path to the YAML configuration file.", + ), + ], +) -> None: + """Validate a Plugboard process configuration.""" + config_spec = _read_yaml(config) + with add_sys_path(config.parent): + process = _build_process(config_spec) + errors = validate_process(process.dict()) + if errors: + stderr.print("[red]Validation failed:[/red]") + for error in errors: + stderr.print(f" • {error}") + raise typer.Exit(1) + print("[green]Validation passed[/green]") diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index cb3f429c..0291f828 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -87,6 +87,23 @@ def test_cli_process_diagram() -> None: assert "flowchart" in result.stdout +def test_cli_process_validate() -> None: + """Tests the process validate command.""" + result = runner.invoke(app, ["process", "validate", "tests/data/minimal-process.yaml"]) + # CLI must run without error for a valid config + assert result.exit_code == 0 + assert "Validation passed" in result.stdout + + +def test_cli_process_validate_invalid() -> None: + """Tests the process validate command with an invalid process.""" + with patch("plugboard.cli.process.validate_process") as mock_validate: + mock_validate.return_value = ["Component 'x' has unconnected inputs: ['in_1']"] + result = runner.invoke(app, ["process", "validate", "tests/data/minimal-process.yaml"]) + assert result.exit_code == 1 + assert "Validation failed" in result.stderr + + def test_cli_server_discover(test_project_dir: Path) -> None: """Tests the server discover command.""" with respx.mock: From ad9cd897b4cdf5f90c2e0af8e9376d413d97b0b7 Mon Sep 17 00:00:00 2001 From: Toby Coleman Date: Tue, 10 Mar 2026 21:03:00 +0000 Subject: [PATCH 7/7] Update tests for validation checks --- tests/conftest.py | 10 ++++++++++ tests/integration/test_job_id_wiring.py | 18 ++++++++++++------ tests/integration/test_process_builder.py | 15 +++++++++++++-- tests/integration/test_state_backend.py | 14 +++++++++++--- 4 files changed, 46 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index de6612f6..b7688fcb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -143,3 +143,13 @@ def dict(self) -> dict: } ) return data + + +@pytest.fixture +def patch_validate_process() -> _t.Iterator[None]: + """Patch process validation for tests that don't require functional processes.""" + with ( + patch("plugboard.schemas.validate_process", return_value=[]), + patch("plugboard.process.process.validate_process", return_value=[]), + ): + yield diff --git a/tests/integration/test_job_id_wiring.py b/tests/integration/test_job_id_wiring.py index 1d9cb859..c820236b 100644 --- a/tests/integration/test_job_id_wiring.py +++ b/tests/integration/test_job_id_wiring.py @@ -206,7 +206,7 @@ async def test_cli_process_run_with_no_job_id(minimal_config_file: Path) -> None @pytest.mark.asyncio -async def test_direct_process_with_job_id() -> None: +async def test_direct_process_with_job_id(patch_validate_process: None) -> None: """Test building a process directly with a specified job ID.""" # Create a state backend with a specific job ID state: DictStateBackend = DictStateBackend(job_id="Job_direct12345678") @@ -221,7 +221,7 @@ async def test_direct_process_with_job_id() -> None: @pytest.mark.asyncio -async def test_direct_process_with_env_var() -> None: +async def test_direct_process_with_env_var(patch_validate_process: None) -> None: """Test building a process without a job ID while environment variable is set.""" # Set the environment variable to match the job ID to avoid conflict job_id: str = "Job_direct12345678" @@ -240,7 +240,9 @@ async def test_direct_process_with_env_var() -> None: @pytest.mark.asyncio -async def test_direct_process_with_job_id_and_env_var() -> None: +async def test_direct_process_with_job_id_and_env_var( + patch_validate_process: None, +) -> None: """Test building a process with a job ID while environment variable is set with same value.""" # Set the environment variable to match the job ID to avoid conflict job_id: str = "Job_direct12345678" @@ -271,7 +273,7 @@ async def test_direct_process_with_conflicting_job_ids() -> None: @pytest.mark.asyncio -async def test_direct_process_without_job_id() -> None: +async def test_direct_process_without_job_id(patch_validate_process: None) -> None: """Test building a process without specifying a job ID.""" # Create a state backend without a job ID state: DictStateBackend = DictStateBackend() @@ -289,7 +291,9 @@ async def test_direct_process_without_job_id() -> None: @pytest.mark.asyncio -async def test_direct_process_without_job_id_multiple_runs() -> None: +async def test_direct_process_without_job_id_multiple_runs( + patch_validate_process: None, +) -> None: """Test building a process without specifying a job ID multiple times.""" # Create a state backend without a job ID state: DictStateBackend = DictStateBackend() @@ -316,7 +320,9 @@ async def test_direct_process_without_job_id_multiple_runs() -> None: @pytest.mark.asyncio -async def test_direct_process_without_job_id_multiple_runs_multiprocessing() -> None: +async def test_direct_process_without_job_id_multiple_runs_multiprocessing( + patch_validate_process: None, +) -> None: """Test building a process without a job ID in a multiprocessing context.""" # Create a state backend without a specific job ID state: SqliteStateBackend = SqliteStateBackend() diff --git a/tests/integration/test_process_builder.py b/tests/integration/test_process_builder.py index cdfc4c5d..1900a61f 100644 --- a/tests/integration/test_process_builder.py +++ b/tests/integration/test_process_builder.py @@ -50,6 +50,13 @@ async def dummy_event_2_handler(self, event: DummyEvent2) -> None: pass +class E(ComponentTestHelper): + io = IO(output_events=[DummyEvent1, DummyEvent2]) + + async def step(self) -> None: + pass + + @pytest.fixture def process_spec() -> ProcessSpec: """Returns a `ProcessSpec` for testing.""" @@ -72,6 +79,10 @@ def process_spec() -> ProcessSpec: type="tests.integration.test_process_builder.D", args={"name": "D"}, ), + ComponentSpec( + type="tests.integration.test_process_builder.E", + args={"name": "E"}, + ), ], connectors=[ ConnectorSpec( @@ -103,11 +114,11 @@ async def test_process_builder_build(process_spec: ProcessSpec) -> None: # Must build a process with the correct type process.__class__.__name__ == process_spec.args.state.type.split(".")[-1] # Must build a process with the correct components and connectors - assert len(process.components) == 4 + assert len(process.components) == 5 # Number of connectors must be sum of: fields in config; user events; and system events assert len(process.connectors) == 2 + 2 + 1 # Must build a process with the correct component names - assert process.components.keys() == {"A", "B", "C", "D"} + assert process.components.keys() == {"A", "B", "C", "D", "E"} # Must build connectors with the correct channel types assert all( conn.__class__.__name__ == "AsyncioConnector" for conn in process.connectors.values() diff --git a/tests/integration/test_state_backend.py b/tests/integration/test_state_backend.py index c007c74a..b5516058 100644 --- a/tests/integration/test_state_backend.py +++ b/tests/integration/test_state_backend.py @@ -230,7 +230,10 @@ async def test_state_backend_upsert_connector( @pytest.mark.asyncio async def test_state_backend_process_init( - state_backend: StateBackend, B_components: list[Component], B_connectors: list[Connector] + state_backend: StateBackend, + B_components: list[Component], + B_connectors: list[Connector], + patch_validate_process: None, ) -> None: """Tests `StateBackend` connected up correctly on `Process.init`.""" comp_b1, comp_b2 = B_components @@ -258,7 +261,10 @@ async def test_state_backend_process_init( @pytest.mark.asyncio async def test_state_backend_process_status( - state_backend: StateBackend, B_components: list[Component], B_connectors: list[Connector] + state_backend: StateBackend, + B_components: list[Component], + B_connectors: list[Connector], + patch_validate_process: None, ) -> None: """Tests `StateBackend` process status updates.""" comp_b1, comp_b2 = B_components @@ -299,7 +305,9 @@ async def test_state_backend_process_status( @pytest.mark.asyncio -async def test_state_backend_process_parameters(state_backend: StateBackend) -> None: +async def test_state_backend_process_parameters( + state_backend: StateBackend, patch_validate_process: None +) -> None: """Tests `StateBackend` process parameters storage and retrieval.""" parameters = {"param_1": 10, "param_2": "value"} component_a = A(name="ComponentA")