diff --git a/src/fromager/commands/graph.py b/src/fromager/commands/graph.py index cf3fab98..a53b53d7 100644 --- a/src/fromager/commands/graph.py +++ b/src/fromager/commands/graph.py @@ -7,9 +7,12 @@ import typing import click +import rich +import rich.box from packaging.requirements import Requirement -from packaging.utils import canonicalize_name +from packaging.utils import NormalizedName, canonicalize_name from packaging.version import Version +from rich.table import Table from fromager import clickext, context from fromager.commands import bootstrap @@ -784,3 +787,252 @@ def n2s(nodes: typing.Iterable[DependencyNode]) -> str: topo.done(*nodes_to_build) print(f"\nBuilding {len(graph)} packages in {rounds} rounds.") + + +def get_dependency_closure(node: DependencyNode) -> set[NormalizedName]: + """Compute the full dependency closure for a node. + + Traverses all edge types and returns the set of canonical package names reachable from node, + including node itself. + + Args: + node: The starting node to compute the closure for. + + Returns: + Set of canonicalized package names in the transitive closure. + """ + dependency_names: set[NormalizedName] = set() + if node.canonicalized_name != ROOT: + dependency_names.add(node.canonicalized_name) + for dependency in node.iter_all_dependencies(): + if dependency.canonicalized_name != ROOT: + dependency_names.add(dependency.canonicalized_name) + return dependency_names + + +def get_package_names(graph: DependencyGraph) -> set[NormalizedName]: + """Extract all unique canonical package names from a graph. + + Args: + graph: The dependency graph to extract names from. + + Returns: + Set of canonicalized package names, excluding the ROOT node. + """ + return { + node.canonicalized_name for node in graph.get_all_nodes() if node.key != ROOT + } + + +def extract_collection_name(graph_path: str) -> str: + """Derive a collection name from a graph file path. + + Returns the filename without the extension as a string. + + Args: + graph_path: Filesystem path to a graph JSON file. + + Returns: + The filename without the extension. + """ + return pathlib.PurePath(graph_path).stem + + +class _CollectionScore(typing.NamedTuple): + """Overlap score between a package's dependency closure and a collection.""" + + collection: str + new_packages: int + existing_packages: int + coverage_percentage: float + + +def _analyze_suggestions( + toplevel_nodes: list[DependencyNode], + collection_packages: dict[str, set[NormalizedName]], +) -> list[dict[str, typing.Any]]: + """Score each onboarding top-level package against every collection. + + Args: + toplevel_nodes: Top-level nodes from the onboarding graph. + collection_packages: Mapping of collection name to its package name set. + + Returns: + List of result dicts, one per top-level package, sorted by package name. + """ + results: list[dict[str, typing.Any]] = [] + + for node in sorted(toplevel_nodes, key=lambda n: n.canonicalized_name): + dependency_names = get_dependency_closure(node) + total_dependency_count = len(dependency_names) + + scores: list[_CollectionScore] = [] + for collection_name, packages in collection_packages.items(): + existing_count = len(dependency_names & packages) + new_count = total_dependency_count - existing_count + coverage_percentage = ( + (existing_count / total_dependency_count * 100) + if total_dependency_count + else 0.0 + ) + scores.append( + _CollectionScore( + collection_name, new_count, existing_count, coverage_percentage + ) + ) + + # Rank: fewest new packages, then highest coverage, then name for determinism + scores.sort( + key=lambda score: ( + score.new_packages, + -score.coverage_percentage, + score.collection, + ) + ) + best_score = scores[0] if scores else None + + logger.debug( + "%s: %d deps, best fit '%s' (%d new, %.1f%% coverage)", + node.canonicalized_name, + total_dependency_count, + best_score.collection if best_score else "none", + best_score.new_packages if best_score else 0, + best_score.coverage_percentage if best_score else 0.0, + ) + + results.append( + { + "package": str(node.canonicalized_name), + "version": str(node.version), + "total_dependencies": total_dependency_count, + "best_fit": best_score.collection if best_score else "none", + "new_packages": best_score.new_packages if best_score else 0, + "existing_packages": best_score.existing_packages if best_score else 0, + "coverage_percentage": ( + round(best_score.coverage_percentage, 1) if best_score else 0.0 + ), + "all_collections": [ + { + "collection": score.collection, + "new_packages": score.new_packages, + "existing_packages": score.existing_packages, + "coverage_percentage": round(score.coverage_percentage, 1), + } + for score in scores + ], + } + ) + + return results + + +def _print_suggest_collection_table( + results: list[dict[str, typing.Any]], +) -> None: + """Render suggest-collection results as a Rich table.""" + table = Table( + title="Collection Suggestions for Onboarding Packages", + box=rich.box.MARKDOWN, + title_justify="left", + ) + table.add_column("Package", justify="left", no_wrap=True) + table.add_column("Version", justify="left", no_wrap=True) + table.add_column("Total Deps", justify="right", no_wrap=True) + table.add_column("Best Fit", justify="left", no_wrap=True) + table.add_column("New Pkgs", justify="right", no_wrap=True) + table.add_column("Existing", justify="right", no_wrap=True) + table.add_column("Coverage", justify="right", no_wrap=True) + + for result in results: + table.add_row( + result["package"], + result["version"], + str(result["total_dependencies"]), + result["best_fit"], + str(result["new_packages"]), + str(result["existing_packages"]), + f"{result['coverage_percentage']:.1f}%", + ) + + rich.get_console().print(table) + + +@graph.command(name="suggest-collection") +@click.option( + "--format", + "output_format", + type=click.Choice(["table", "json"], case_sensitive=False), + default="table", + help="Output format (default: table)", +) +@click.argument("onboarding-graph", type=str) +@click.argument("collection-graphs", nargs=-1, required=True, type=str) +def suggest_collection( + output_format: str, + onboarding_graph: str, + collection_graphs: tuple[str, ...], +) -> None: + """Suggest the best-fit collection for each onboarding package. + + Analyzes dependency overlap between top-level packages in ONBOARDING_GRAPH + and the existing COLLECTION_GRAPHS to recommend where each onboarding + package should be placed. + + For each top-level package in the onboarding graph, computes the full + transitive dependency closure and compares it against every collection. + Collections are ranked by fewest new packages required, then by highest + dependency coverage. + + \b + ONBOARDING_GRAPH Path to the onboarding collection graph.json. + COLLECTION_GRAPHS One or more paths to existing collection graph.json files. + """ + try: + onboarding = DependencyGraph.from_file(onboarding_graph) + except Exception as err: + raise click.ClickException( + f"Failed to load onboarding graph {onboarding_graph}: {err}" + ) from err + + root = onboarding.get_root_node() + + toplevel_nodes: list[DependencyNode] = [ + edge.destination_node + for edge in root.children + if edge.req_type == RequirementType.TOP_LEVEL + ] + + if not toplevel_nodes: + click.echo("No top-level packages found in onboarding graph.", err=True) + + logger.info( + "Loaded onboarding graph with %d top-level packages", len(toplevel_nodes) + ) + + collection_packages: dict[str, set[NormalizedName]] = {} + for graph_path in collection_graphs: + collection_name = extract_collection_name(graph_path) + if collection_name in collection_packages: + raise click.ClickException( + f"Duplicate collection name '{collection_name}' from {graph_path}. " + "Rename one of the graph files to avoid ambiguity." + ) + try: + collection_graph = DependencyGraph.from_file(graph_path) + except Exception as err: + raise click.ClickException( + f"Failed to load collection graph {graph_path}: {err}" + ) from err + collection_packages[collection_name] = get_package_names(collection_graph) + logger.debug( + "Collection '%s': %d packages", + collection_name, + len(collection_packages[collection_name]), + ) + + results = _analyze_suggestions(toplevel_nodes, collection_packages) + + if output_format == "json": + click.echo(json.dumps(results, indent=2)) + else: + _print_suggest_collection_table(results) diff --git a/src/fromager/dependency_graph.py b/src/fromager/dependency_graph.py index 7f1201dc..9811c887 100644 --- a/src/fromager/dependency_graph.py +++ b/src/fromager/dependency_graph.py @@ -158,6 +158,24 @@ def iter_build_requirements(self) -> typing.Iterable[DependencyNode]: ): yield install_edge.destination_node + def iter_all_dependencies(self) -> typing.Iterable[DependencyNode]: + """Get all unique, recursive dependencies following every edge type. + + Yields every reachable node exactly once using iterative DFS. + Follows install, build, and toplevel edges. + """ + visited: set[str] = {self.key} + stack: list[DependencyNode] = [self] + while stack: + current = stack.pop() + for edge in current.children: + child_node = edge.destination_node + if child_node.key in visited: + continue + visited.add(child_node.key) + yield child_node + stack.append(child_node) + def iter_install_requirements(self) -> typing.Iterable[DependencyNode]: """Get all unique, recursive install requirements""" visited: set[str] = set() diff --git a/tests/test_suggest_collection.py b/tests/test_suggest_collection.py new file mode 100644 index 00000000..8e08c5a5 --- /dev/null +++ b/tests/test_suggest_collection.py @@ -0,0 +1,454 @@ +"""Tests for the graph suggest-collection command and its helper functions.""" + +import json +import pathlib +import re + +import pytest +from click.testing import CliRunner +from packaging.requirements import Requirement +from packaging.utils import canonicalize_name +from packaging.version import Version + +from fromager.__main__ import main as fromager +from fromager.commands.graph import ( + _analyze_suggestions, + extract_collection_name, + get_dependency_closure, + get_package_names, +) +from fromager.dependency_graph import DependencyGraph +from fromager.requirements_file import RequirementType + + +def _extract_json_from_output(output: str) -> str: + """Extract JSON array from output that may contain leading log lines.""" + json_match = re.search(r"\[.*\]", output, re.DOTALL) + if json_match: + return json_match.group(0) + return "[]" + + +def _build_graph( + toplevel: dict[str, str], + dependencies: dict[str, list[tuple[str, str, str]]], +) -> DependencyGraph: + """Build a synthetic DependencyGraph for testing. + + Args: + toplevel: Mapping of package name to version for top-level packages. + dependencies: Mapping of ``"name==version"`` to a list of + ``(dep_name, dep_version, req_type)`` tuples. + + Returns: + A populated DependencyGraph. + """ + graph = DependencyGraph() + for name, version in toplevel.items(): + graph.add_dependency( + parent_name=None, + parent_version=None, + req_type=RequirementType.TOP_LEVEL, + req=Requirement(name), + req_version=Version(version), + ) + + for parent_key, deps in dependencies.items(): + pname, _, pver = parent_key.partition("==") + for dep_name, dep_version, req_type_str in deps: + graph.add_dependency( + parent_name=canonicalize_name(pname), + parent_version=Version(pver), + req_type=RequirementType(req_type_str), + req=Requirement(f"{dep_name}>={dep_version}"), + req_version=Version(dep_version), + ) + return graph + + +def _write_graph(graph: DependencyGraph, path: pathlib.Path) -> None: + """Serialize a DependencyGraph to a JSON file.""" + with open(path, "w") as f: + graph.serialize(f) + + +# --------------------------------------------------------------------------- +# Unit tests for helper functions +# --------------------------------------------------------------------------- + + +class TestGetDependencyClosure: + """Tests for get_dependency_closure.""" + + def test_single_package_no_deps(self) -> None: + """A top-level package with no dependencies has a closure of itself.""" + graph = _build_graph({"alpha": "1.0"}, {}) + node = graph.nodes["alpha==1.0"] + closure = get_dependency_closure(node) + assert closure == {canonicalize_name("alpha")} + + def test_transitive_install_deps(self) -> None: + """Closure includes transitive install dependencies.""" + graph = _build_graph( + {"alpha": "1.0"}, + { + "alpha==1.0": [("bravo", "2.0", "install")], + "bravo==2.0": [("charlie", "3.0", "install")], + }, + ) + node = graph.nodes["alpha==1.0"] + closure = get_dependency_closure(node) + assert closure == { + canonicalize_name("alpha"), + canonicalize_name("bravo"), + canonicalize_name("charlie"), + } + + def test_includes_build_deps(self) -> None: + """Default closure includes build-system dependencies.""" + graph = _build_graph( + {"alpha": "1.0"}, + { + "alpha==1.0": [ + ("bravo", "2.0", "install"), + ("setuptools", "70.0", "build-system"), + ], + }, + ) + node = graph.nodes["alpha==1.0"] + closure = get_dependency_closure(node) + assert canonicalize_name("setuptools") in closure + assert canonicalize_name("bravo") in closure + + def test_cycle_does_not_hang(self) -> None: + """Circular dependencies terminate without hanging.""" + graph = _build_graph( + {"alpha": "1.0"}, + { + "alpha==1.0": [("bravo", "1.0", "install")], + "bravo==1.0": [("alpha", "1.0", "install")], + }, + ) + node = graph.nodes["alpha==1.0"] + closure = get_dependency_closure(node) + assert closure == { + canonicalize_name("alpha"), + canonicalize_name("bravo"), + } + + def test_diamond_dependency(self) -> None: + """Diamond-shaped deps are counted once.""" + graph = _build_graph( + {"alpha": "1.0"}, + { + "alpha==1.0": [ + ("bravo", "1.0", "install"), + ("charlie", "1.0", "install"), + ], + "bravo==1.0": [("delta", "1.0", "install")], + "charlie==1.0": [("delta", "1.0", "install")], + }, + ) + node = graph.nodes["alpha==1.0"] + closure = get_dependency_closure(node) + assert len(closure) == 4 + assert canonicalize_name("delta") in closure + + +class TestGetPackageNames: + """Tests for get_package_names.""" + + def test_excludes_root(self) -> None: + """ROOT node is never in the returned set.""" + graph = _build_graph({"alpha": "1.0"}, {}) + names = get_package_names(graph) + assert "" not in names + assert canonicalize_name("alpha") in names + + def test_includes_all_nodes(self) -> None: + """All non-root nodes contribute their canonical name.""" + graph = _build_graph( + {"alpha": "1.0"}, + {"alpha==1.0": [("bravo", "2.0", "install")]}, + ) + names = get_package_names(graph) + assert names == { + canonicalize_name("alpha"), + canonicalize_name("bravo"), + } + + def test_empty_graph(self) -> None: + """An empty graph (only ROOT) returns an empty set.""" + graph = DependencyGraph() + names = get_package_names(graph) + assert names == set() + + +class TestExtractCollectionName: + """Tests for extract_collection_name.""" + + def test_simple_filename(self) -> None: + assert extract_collection_name("notebook.json") == "notebook" + + def test_preserves_full_stem(self) -> None: + assert extract_collection_name("notebook-graph.json") == "notebook-graph" + + def test_preserves_hyphens(self) -> None: + assert extract_collection_name("rhai-innovation.json") == "rhai-innovation" + + def test_full_path(self) -> None: + assert extract_collection_name("/tmp/graphs/notebook.json") == "notebook" + + def test_stem_only(self) -> None: + assert extract_collection_name("my-collection.json") == "my-collection" + + +# --------------------------------------------------------------------------- +# CLI integration tests for suggest-collection +# --------------------------------------------------------------------------- + + +class TestSuggestCollectionCLI: + """Integration tests for ``fromager graph suggest-collection``.""" + + @pytest.fixture() + def graph_dir(self, tmp_path: pathlib.Path) -> pathlib.Path: + """Create a temporary directory with onboarding and collection graphs.""" + # Onboarding graph: two top-level packages + # pkg-x depends on numpy, pandas + # pkg-y depends on numpy, torch + onboard = _build_graph( + {"pkg-x": "1.0", "pkg-y": "1.0"}, + { + "pkg-x==1.0": [ + ("numpy", "1.26", "install"), + ("pandas", "2.0", "install"), + ], + "pkg-y==1.0": [ + ("numpy", "1.26", "install"), + ("torch", "2.0", "install"), + ], + }, + ) + _write_graph(onboard, tmp_path / "onboarding.json") + + # Collection "data-science": has numpy, pandas, scipy + ds = _build_graph( + {"numpy": "1.26", "pandas": "2.0", "scipy": "1.12"}, + {}, + ) + _write_graph(ds, tmp_path / "data-science.json") + + # Collection "ml": has numpy, torch, triton + ml = _build_graph( + {"numpy": "1.26", "torch": "2.0", "triton": "3.0"}, + {}, + ) + _write_graph(ml, tmp_path / "ml.json") + + return tmp_path + + def test_table_output( + self, + cli_runner: CliRunner, + graph_dir: pathlib.Path, + ) -> None: + """Table output contains expected package names and collection fits.""" + result = cli_runner.invoke( + fromager, + [ + "graph", + "suggest-collection", + str(graph_dir / "onboarding.json"), + str(graph_dir / "data-science.json"), + str(graph_dir / "ml.json"), + ], + ) + assert result.exit_code == 0, result.output + assert "pkg-x" in result.output + assert "pkg-y" in result.output + assert "data-science" in result.output + assert "ml" in result.output + + def test_json_output( + self, + cli_runner: CliRunner, + graph_dir: pathlib.Path, + ) -> None: + """JSON output is parseable and contains expected fields.""" + result = cli_runner.invoke( + fromager, + [ + "graph", + "suggest-collection", + "--format", + "json", + str(graph_dir / "onboarding.json"), + str(graph_dir / "data-science.json"), + str(graph_dir / "ml.json"), + ], + ) + assert result.exit_code == 0, result.output + data = json.loads(_extract_json_from_output(result.output)) + assert isinstance(data, list) + assert len(data) == 2 + + packages = {r["package"] for r in data} + assert packages == {"pkg-x", "pkg-y"} + + for entry in data: + assert "best_fit" in entry + assert "total_dependencies" in entry + assert "coverage_percentage" in entry + assert "all_collections" in entry + assert len(entry["all_collections"]) == 2 + + def test_best_fit_ranking( + self, + cli_runner: CliRunner, + graph_dir: pathlib.Path, + ) -> None: + """pkg-x should prefer data-science, pkg-y should prefer ml.""" + result = cli_runner.invoke( + fromager, + [ + "graph", + "suggest-collection", + "--format", + "json", + str(graph_dir / "onboarding.json"), + str(graph_dir / "data-science.json"), + str(graph_dir / "ml.json"), + ], + ) + assert result.exit_code == 0, result.output + data = json.loads(_extract_json_from_output(result.output)) + by_pkg = {r["package"]: r for r in data} + + assert by_pkg["pkg-x"]["best_fit"] == "data-science" + assert by_pkg["pkg-y"]["best_fit"] == "ml" + + def test_empty_onboarding_graph( + self, + cli_runner: CliRunner, + graph_dir: pathlib.Path, + ) -> None: + """Empty onboarding graph warns on stderr and outputs empty results.""" + empty = DependencyGraph() + _write_graph(empty, graph_dir / "empty.json") + + result = cli_runner.invoke( + fromager, + [ + "graph", + "suggest-collection", + "--format", + "json", + str(graph_dir / "empty.json"), + str(graph_dir / "data-science.json"), + ], + ) + assert result.exit_code == 0 + assert "No top-level packages" in result.output + data = json.loads(_extract_json_from_output(result.output)) + assert data == [] + + def test_no_collection_graphs_fails( + self, + cli_runner: CliRunner, + graph_dir: pathlib.Path, + ) -> None: + """Missing collection-graphs argument causes a usage error.""" + result = cli_runner.invoke( + fromager, + [ + "graph", + "suggest-collection", + str(graph_dir / "onboarding.json"), + ], + ) + assert result.exit_code != 0 + + +# --------------------------------------------------------------------------- +# Unit tests for _analyze_suggestions +# --------------------------------------------------------------------------- + + +class TestAnalyzeSuggestions: + """Tests for the core scoring logic in _analyze_suggestions.""" + + def test_scores_single_package_against_two_collections(self) -> None: + """Best-fit collection has fewest new packages required.""" + graph = _build_graph( + {"alpha": "1.0"}, + { + "alpha==1.0": [ + ("numpy", "1.26", "install"), + ("pandas", "2.0", "install"), + ], + }, + ) + toplevel = [graph.nodes["alpha==1.0"]] + # Both collections include alpha itself; "ds" covers all three deps. + collection_packages = { + "ds": { + canonicalize_name("alpha"), + canonicalize_name("numpy"), + canonicalize_name("pandas"), + }, + "ml": {canonicalize_name("alpha"), canonicalize_name("numpy")}, + } + results = _analyze_suggestions(toplevel, collection_packages) + + assert len(results) == 1 + entry = results[0] + assert entry["best_fit"] == "ds" + assert entry["new_packages"] == 0 + assert entry["existing_packages"] == 3 + assert entry["coverage_percentage"] == 100.0 + assert len(entry["all_collections"]) == 2 + + def test_version_difference_does_not_affect_fit(self) -> None: + """A package at a different version still counts as existing in the collection.""" + graph = _build_graph( + {"alpha": "1.0"}, + {"alpha==1.0": [("numpy", "2.0", "install")]}, + ) + toplevel = [graph.nodes["alpha==1.0"]] + # Collection has both packages, but numpy at an older version. + collection_packages = { + "ds": {canonicalize_name("alpha"), canonicalize_name("numpy")}, + } + results = _analyze_suggestions(toplevel, collection_packages) + + # numpy at version 2.0 (onboarding) vs 1.26 (collection) should still match. + assert results[0]["new_packages"] == 0 + assert results[0]["existing_packages"] == 2 + assert results[0]["coverage_percentage"] == 100.0 + + def test_result_sorted_by_package_name(self) -> None: + """Results are returned alphabetically by package name.""" + graph = _build_graph({"zebra": "1.0", "alpha": "1.0"}, {}) + toplevel = [graph.nodes["zebra==1.0"], graph.nodes["alpha==1.0"]] + results = _analyze_suggestions(toplevel, {"empty": set()}) + assert [r["package"] for r in results] == ["alpha", "zebra"] + + def test_all_collections_ranked_ascending_new_packages(self) -> None: + """all_collections list is ordered fewest-new-packages first.""" + graph = _build_graph( + {"alpha": "1.0"}, + {"alpha==1.0": [("numpy", "1.0", "install"), ("pandas", "1.0", "install")]}, + ) + toplevel = [graph.nodes["alpha==1.0"]] + collection_packages = { + "best": {canonicalize_name("numpy"), canonicalize_name("pandas")}, + "mid": {canonicalize_name("numpy")}, + "worst": set(), + } + results = _analyze_suggestions(toplevel, collection_packages) + ranked = [c["collection"] for c in results[0]["all_collections"]] + new_pkg_counts = [c["new_packages"] for c in results[0]["all_collections"]] + assert new_pkg_counts == sorted(new_pkg_counts), ( + f"all_collections not sorted by new_packages: {ranked}" + )