Skip to content
12 changes: 12 additions & 0 deletions plugboard-schemas/plugboard_schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from importlib.metadata import version

from ._common import PlugboardBaseModel
from ._graph import simple_cycles
from ._validation import (
validate_all_inputs_connected,
validate_input_events,
validate_no_unresolved_cycles,
validate_process,
)
from .component import ComponentArgsDict, ComponentArgsSpec, ComponentSpec, Resource
from .config import ConfigSpec, ProcessConfigSpec
from .connector import (
Expand Down Expand Up @@ -85,4 +92,9 @@
"TuneArgsDict",
"TuneArgsSpec",
"TuneSpec",
"simple_cycles",
"validate_all_inputs_connected",
"validate_input_events",
"validate_no_unresolved_cycles",
"validate_process",
]
126 changes: 126 additions & 0 deletions plugboard-schemas/plugboard_schemas/_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Graph algorithms for topology validation.

Implements Johnson's algorithm for finding all simple cycles in a directed graph,
along with helper functions for strongly connected components.

References:
Donald B Johnson. "Finding all the elementary circuits of a directed graph."
SIAM Journal on Computing. 1975.
"""

from collections import defaultdict
from collections.abc import Generator


def simple_cycles(graph: dict[str, set[str]]) -> Generator[list[str], None, None]:
"""Find all simple cycles in a directed graph using Johnson's algorithm.

Args:
graph: A dictionary mapping each vertex to a set of its neighbours.

Yields:
Each elementary cycle as a list of vertices.
"""
graph = {v: set(nbrs) for v, nbrs in graph.items()}
sccs = _strongly_connected_components(graph)
while sccs:
scc = sccs.pop()
startnode = scc.pop()
path = [startnode]
blocked: set[str] = set()
closed: set[str] = set()
blocked.add(startnode)
B: dict[str, set[str]] = defaultdict(set)
stack: list[tuple[str, list[str]]] = [(startnode, list(graph[startnode]))]
while stack:
thisnode, nbrs = stack[-1]
if nbrs:
nextnode = nbrs.pop()
if nextnode == startnode:
yield path[:]
closed.update(path)
elif nextnode not in blocked:
path.append(nextnode)
stack.append((nextnode, list(graph[nextnode])))
closed.discard(nextnode)
blocked.add(nextnode)
continue
if not nbrs:
if thisnode in closed:
_unblock(thisnode, blocked, B)
else:
for nbr in graph[thisnode]:
if thisnode not in B[nbr]:
B[nbr].add(thisnode)
stack.pop()
path.pop()
_remove_node(graph, startnode)
H = _subgraph(graph, set(scc))
sccs.extend(_strongly_connected_components(H))


def _unblock(thisnode: str, blocked: set[str], B: dict[str, set[str]]) -> None:
"""Unblock a node and recursively unblock nodes in its B set."""
stack = {thisnode}
while stack:
node = stack.pop()
if node in blocked:
blocked.remove(node)
stack.update(B[node])
B[node].clear()


def _strongly_connected_components(graph: dict[str, set[str]]) -> list[set[str]]:
"""Find all strongly connected components using Tarjan's algorithm.

Args:
graph: A dictionary mapping each vertex to a set of its neighbours.

Returns:
A list of sets, each containing the vertices of a strongly connected component.
"""
index_counter = [0]
stack: list[str] = []
lowlink: dict[str, int] = {}
index: dict[str, int] = {}
result: list[set[str]] = []

def _strong_connect(node: str) -> None:
index[node] = index_counter[0]
lowlink[node] = index_counter[0]
index_counter[0] += 1
stack.append(node)

for successor in graph.get(node, set()):
if successor not in index:
_strong_connect(successor)
lowlink[node] = min(lowlink[node], lowlink[successor])
elif successor in stack:
lowlink[node] = min(lowlink[node], index[successor])

if lowlink[node] == index[node]:
connected_component: set[str] = set()
while True:
successor = stack.pop()
connected_component.add(successor)
if successor == node:
break
result.append(connected_component)

for node in graph:
if node not in index:
_strong_connect(node)

return result


def _remove_node(graph: dict[str, set[str]], target: str) -> None:
"""Remove a node and all its edges from the graph."""
del graph[target]
for nbrs in graph.values():
nbrs.discard(target)


def _subgraph(graph: dict[str, set[str]], vertices: set[str]) -> dict[str, set[str]]:
"""Get the subgraph induced by a set of vertices."""
return {v: graph[v] & vertices for v in vertices}
197 changes: 197 additions & 0 deletions plugboard-schemas/plugboard_schemas/_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""Validation utilities for process topology.

Provides functions to validate process topology including:
- Checking that all component inputs are connected
- Checking that input events have matching output event producers
- Checking for circular connections that require initial values

All validators accept the output of ``process.dict()`` or the relevant
sub-structures thereof.
"""

from __future__ import annotations

from collections import defaultdict
import typing as _t

from ._graph import simple_cycles


def _build_component_graph(
connectors: dict[str, dict[str, _t.Any]],
) -> dict[str, set[str]]:
"""Build a directed graph of component connections from connector dicts.

Args:
connectors: Dictionary mapping connector IDs to connector dicts,
as returned by ``process.dict()["connectors"]``.

Returns:
A dictionary mapping source component names to sets of target component names.
"""
graph: dict[str, set[str]] = defaultdict(set)
for conn_info in connectors.values():
spec = conn_info["spec"]
source_entity = spec["source"]["entity"]
target_entity = spec["target"]["entity"]
if source_entity != target_entity:
graph[source_entity].add(target_entity)
if target_entity not in graph:
graph[target_entity] = set()
return dict(graph)


def _get_edges_in_cycle(
cycle: list[str],
connectors: dict[str, dict[str, _t.Any]],
) -> list[dict[str, _t.Any]]:
"""Get all connector spec dicts that form edges within a cycle.

Args:
cycle: List of component names forming a cycle.
connectors: Dictionary mapping connector IDs to connector dicts.

Returns:
List of connector spec dicts that are part of the cycle.
"""
cycle_edges: list[dict[str, _t.Any]] = []
for i, node in enumerate(cycle):
next_node = cycle[(i + 1) % len(cycle)]
for conn_info in connectors.values():
spec = conn_info["spec"]
if spec["source"]["entity"] == node and spec["target"]["entity"] == next_node:
cycle_edges.append(spec)
return cycle_edges


def validate_all_inputs_connected(
process_dict: dict[str, _t.Any],
) -> list[str]:
"""Check that all component inputs are connected.

Args:
process_dict: The output of ``process.dict()``. Uses the ``"components"``
and ``"connectors"`` keys.

Returns:
List of error messages for unconnected inputs.
"""
components: dict[str, dict[str, _t.Any]] = process_dict["components"]
connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"]

connected_inputs: dict[str, set[str]] = defaultdict(set)
for conn_info in connectors.values():
spec = conn_info["spec"]
target_name = spec["target"]["entity"]
target_field = spec["target"]["descriptor"]
connected_inputs[target_name].add(target_field)

errors: list[str] = []
for comp_name, comp_data in components.items():
io = comp_data.get("io", {})
all_inputs = set(io.get("inputs", []))
connected = connected_inputs.get(comp_name, set())
unconnected = all_inputs - connected
if unconnected:
errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}")
return errors


def validate_input_events(
process_dict: dict[str, _t.Any],
) -> list[str]:
"""Check that all components with input events have a matching output event producer.

Args:
process_dict: The output of ``process.dict()``. Uses the ``"components"`` key.

Returns:
List of error messages for unmatched input events.
"""
components: dict[str, dict[str, _t.Any]] = process_dict["components"]

all_output_events: set[str] = set()
for comp_data in components.values():
io = comp_data.get("io", {})
all_output_events.update(io.get("output_events", []))

errors: list[str] = []
for comp_name, comp_data in components.items():
io = comp_data.get("io", {})
input_events = set(io.get("input_events", []))
unmatched = input_events - all_output_events
if unmatched:
errors.append(
f"Component '{comp_name}' has input events with no producer: {sorted(unmatched)}"
)
return errors


def validate_no_unresolved_cycles(
process_dict: dict[str, _t.Any],
) -> list[str]:
"""Check for circular connections that are not resolved by initial values.

Circular loops are only valid if there are ``initial_values`` set on an
appropriate component input within the loop.

Args:
process_dict: The output of ``process.dict()``. Uses the ``"components"``
and ``"connectors"`` keys.

Returns:
List of error messages for unresolved circular connections.
"""
components: dict[str, dict[str, _t.Any]] = process_dict["components"]
connectors: dict[str, dict[str, _t.Any]] = process_dict["connectors"]

graph = _build_component_graph(connectors)
if not graph:
return []

# Build lookup of component initial_values by name
initial_values_by_comp: dict[str, set[str]] = {}
for comp_name, comp_data in components.items():
io = comp_data.get("io", {})
iv = io.get("initial_values", {})
if iv:
initial_values_by_comp[comp_name] = set(iv.keys())

errors: list[str] = []
for cycle in simple_cycles(graph):
cycle_edges = _get_edges_in_cycle(cycle, connectors)
cycle_resolved = False
for edge in cycle_edges:
target_comp = edge["target"]["entity"]
target_field = edge["target"]["descriptor"]
if target_comp in initial_values_by_comp:
if target_field in initial_values_by_comp[target_comp]:
cycle_resolved = True
break
if not cycle_resolved:
cycle_str = " -> ".join(cycle + [cycle[0]])
errors.append(
f"Circular connection detected without initial values: {cycle_str}. "
f"Set initial_values on a component input within the loop to resolve."
)
return errors


def validate_process(process_dict: dict[str, _t.Any]) -> list[str]:
"""Run all topology validation checks on a process.

This is the main validation entry point. It accepts the output of
``process.dict()`` and runs every available check, returning a
combined list of error messages.

Args:
process_dict: The output of ``process.dict()``.

Returns:
List of error messages. An empty list indicates a valid topology.
"""
errors: list[str] = []
errors.extend(validate_all_inputs_connected(process_dict))
errors.extend(validate_input_events(process_dict))
errors.extend(validate_no_unresolved_cycles(process_dict))
return errors
30 changes: 29 additions & 1 deletion plugboard/cli/process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from plugboard.diagram import MermaidDiagram
from plugboard.process import Process, ProcessBuilder
from plugboard.schemas import ConfigSpec
from plugboard.schemas import ConfigSpec, validate_process
from plugboard.tune import Tuner
from plugboard.utils import add_sys_path, run_coro_sync

Expand Down Expand Up @@ -164,3 +164,31 @@ def diagram(
diagram = MermaidDiagram.from_process(process)
md = Markdown(f"```\n{diagram.diagram}\n```\n[Editable diagram]({diagram.url}) (external link)")
print(md)


@app.command()
def validate(
config: Annotated[
Path,
typer.Argument(
exists=True,
file_okay=True,
dir_okay=False,
writable=False,
readable=True,
resolve_path=True,
help="Path to the YAML configuration file.",
),
],
) -> None:
"""Validate a Plugboard process configuration."""
config_spec = _read_yaml(config)
with add_sys_path(config.parent):
process = _build_process(config_spec)
errors = validate_process(process.dict())
if errors:
stderr.print("[red]Validation failed:[/red]")
for error in errors:
stderr.print(f" • {error}")
raise typer.Exit(1)
print("[green]Validation passed[/green]")
Loading
Loading