|
| 1 | +"""Extract the pipeline DAG from Luigi task introspection. |
| 2 | +
|
| 3 | +Walks the FullPlanPipeline task graph via requires()/output() and produces |
| 4 | +a JSON description of every stage: name, output files, upstream stages, |
| 5 | +and source code files. This replaces the hand-maintained registry with a |
| 6 | +generated artifact that stays in sync with the actual pipeline code. |
| 7 | +
|
| 8 | +Usage: |
| 9 | + cd worker_plan |
| 10 | + python -m worker_plan_internal.extract_dag |
| 11 | + python -m worker_plan_internal.extract_dag --output pipeline_dag.json |
| 12 | +""" |
import argparse
import inspect
import json
import re
import sys
from pathlib import Path
from typing import Any

import luigi
| 21 | + |
# Root of the worker_plan/ package tree (two levels up from this file);
# source-file paths in the generated DAG are reported relative to it.
_WORKER_PLAN_DIR = Path(__file__).resolve().parent.parent  # worker_plan/

# Module prefixes that are infrastructure/utilities, not implementation logic.
# Imports from these are excluded from source_files auto-detection.
_INFRASTRUCTURE_PREFIXES = (
    "worker_plan_internal.plan.stages.",
    "worker_plan_internal.plan.run_plan_pipeline",
    "worker_plan_internal.plan.pipeline_environment",
    "worker_plan_internal.plan.ping_llm",
    "worker_plan_internal.llm_util.",
    "worker_plan_internal.llm_factory",
    "worker_plan_internal.luigi_util.",
    "worker_plan_internal.utils.",
    "worker_plan_internal.format_",
    "worker_plan_api.",
)
| 38 | + |
| 39 | + |
| 40 | +def _class_name_to_stage_name(class_name: str) -> str: |
| 41 | + """Convert CamelCase task class name to snake_case stage name. |
| 42 | +
|
| 43 | + Removes the 'Task' suffix, then converts CamelCase → snake_case. |
| 44 | +
|
| 45 | + Examples: |
| 46 | + PotentialLeversTask → potential_levers |
| 47 | + SWOTAnalysisTask → swot_analysis |
| 48 | + WBSProjectLevel1AndLevel2Task → wbs_project_level1_and_level2 |
| 49 | + GovernancePhase1AuditTask → governance_phase1_audit |
| 50 | + """ |
| 51 | + name = class_name.removesuffix("Task") |
| 52 | + # Insert underscore between lowercase/digit and uppercase |
| 53 | + name = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name) |
| 54 | + # Insert underscore between consecutive uppercase run and uppercase+lowercase |
| 55 | + name = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", name) |
| 56 | + return name.lower() |
| 57 | + |
| 58 | + |
def _extract_output_filenames(task: luigi.Task) -> list[str]:
    """Extract output filenames (basenames) from a task's output() method."""
    try:
        raw = task.output()
    except Exception:
        # output() can fail when the task is instantiated outside a real run.
        return []

    # Normalize the three shapes output() may return into a flat list.
    if isinstance(raw, dict):
        candidates: list[Any] = list(raw.values())
    elif isinstance(raw, (list, tuple)):
        candidates = list(raw)
    else:
        candidates = [raw]

    # Only file-system targets (those exposing .path) contribute filenames.
    return [Path(c.path).name for c in candidates if hasattr(c, "path")]
| 79 | + |
| 80 | + |
def _extract_upstream_tasks(task: luigi.Task) -> list[luigi.Task]:
    """Extract upstream task instances from a task's requires() method."""
    try:
        required = task.requires()
    except Exception:
        # requires() can fail for tasks built outside a real pipeline run.
        return []

    if required is None:
        return []
    if isinstance(required, dict):
        return list(required.values())
    if isinstance(required, (list, tuple)):
        return list(required)
    if isinstance(required, luigi.Task):
        return [required]
    # Unrecognized shape: treat as having no dependencies.
    return []
| 97 | + |
| 98 | + |
def _detect_implementation_files(cls: type) -> list[str]:
    """Auto-detect implementation source files from module-level imports.

    Scans the module that defines *cls* for classes and functions imported
    from ``worker_plan_internal.*`` that are NOT infrastructure (stages,
    LLM utilities, API types, etc.). Returns paths relative to worker_plan/.
    """
    module = inspect.getmodule(cls)
    if module is None:
        return []

    detected: list[str] = []
    handled_modules: set[str] = set()

    # dir() yields names in sorted order, so detection order is stable.
    for name in dir(module):
        candidate = getattr(module, name, None)
        if candidate is None:
            continue
        if not (inspect.isclass(candidate) or inspect.isfunction(candidate)):
            continue

        origin = getattr(candidate, "__module__", "") or ""
        if not origin.startswith("worker_plan_internal."):
            continue
        if any(origin.startswith(prefix) for prefix in _INFRASTRUCTURE_PREFIXES):
            continue
        # Each defining module contributes at most once.
        if origin in handled_modules:
            continue
        handled_modules.add(origin)

        try:
            source = Path(inspect.getfile(candidate)).resolve()
            relative = str(source.relative_to(_WORKER_PLAN_DIR))
        except (TypeError, ValueError, OSError):
            # No retrievable file, or a file outside worker_plan/.
            continue
        if relative not in detected:
            detected.append(relative)

    return detected
| 136 | + |
| 137 | + |
def _extract_source_files(task: luigi.Task) -> list[str]:
    """Collect source files for *task*: its own module file first, then any
    auto-detected implementation files, without duplicates."""
    cls = type(task)
    collected: list[str] = []

    # The file that defines the task class itself.
    try:
        own_file = Path(inspect.getfile(cls)).resolve()
        collected.append(str(own_file.relative_to(_WORKER_PLAN_DIR)))
    except (TypeError, ValueError, OSError):
        # No retrievable file (e.g. dynamically created class) — skip it.
        pass

    # Implementation modules imported by the task's defining module.
    for candidate in _detect_implementation_files(cls):
        if candidate not in collected:
            collected.append(candidate)

    return collected
| 156 | + |
| 157 | + |
| 158 | +def _output_sort_key(stage: dict[str, Any]) -> tuple[int, int, str]: |
| 159 | + """Sort key: numeric prefix from the first output filename, then name.""" |
| 160 | + filename = stage["output_files"][0] if stage.get("output_files") else "" |
| 161 | + match = re.match(r"(\d+)-?(\d+)?", filename) |
| 162 | + if match: |
| 163 | + major = int(match.group(1)) |
| 164 | + minor = int(match.group(2)) if match.group(2) else 0 |
| 165 | + return (major, minor, stage["id"]) |
| 166 | + return (9999, 0, stage["id"]) |
| 167 | + |
| 168 | + |
def extract_dag() -> dict[str, Any]:
    """Walk the FullPlanPipeline task graph and extract DAG info.

    Returns a top-level schema object with stages sorted by pipeline order.
    """
    # Local import: the pipeline module is only needed when extracting.
    from worker_plan_internal.plan.stages.full_plan_pipeline import FullPlanPipeline

    root = FullPlanPipeline(run_id_dir=Path("/tmp/_dag_extract_dummy"))

    stages: list[dict[str, Any]] = []
    visited: set[str] = set()

    def _visit(task: luigi.Task) -> None:
        cls = type(task)
        cls_name = cls.__name__
        if cls_name in visited:
            return
        visited.add(cls_name)

        upstream = _extract_upstream_tasks(task)

        # Depth-first: record dependencies before the task itself.
        for dependency in upstream:
            _visit(dependency)

        # The root orchestrator is not a stage of its own.
        if cls_name == "FullPlanPipeline":
            return

        depends_on = sorted({
            _class_name_to_stage_name(type(dep).__name__) for dep in upstream
        })

        stages.append({
            "id": _class_name_to_stage_name(cls_name),
            "description": cls.description() if hasattr(cls, "description") else "",
            "output_files": _extract_output_filenames(task),
            "depends_on": depends_on,
            "source_files": _extract_source_files(task),
        })

    _visit(root)

    stages.sort(key=_output_sort_key)

    return {
        "schema_version": "1.0",
        "pipeline_name": "planning_pipeline",
        "description": "DAG for PlanExe, an AI-driven project planning system.",
        "stages": stages,
    }
| 225 | + |
| 226 | + |
def main() -> None:
    """CLI entry point: extract the DAG and emit it as JSON.

    With ``--output PATH``, writes the JSON (plus a trailing newline) to
    PATH and reports a summary on stderr; otherwise prints the JSON to
    stdout.
    """
    # argparse replaces the previous hand-rolled sys.argv handling, which
    # silently ignored a bare `--output`, the `--output=FILE` form, and
    # unknown flags; it also provides `--help` for free.
    parser = argparse.ArgumentParser(
        description="Extract the pipeline DAG from Luigi task introspection.",
    )
    parser.add_argument(
        "--output",
        default=None,
        metavar="PATH",
        help="Write the JSON to this file instead of stdout.",
    )
    args = parser.parse_args()

    dag = extract_dag()
    dag_json = json.dumps(dag, indent=2, ensure_ascii=False)

    if args.output:
        Path(args.output).write_text(dag_json + "\n", encoding="utf-8")
        print(f"Wrote {len(dag['stages'])} stages to {args.output}", file=sys.stderr)
    else:
        print(dag_json)
| 241 | + |
| 242 | + |
| 243 | +if __name__ == "__main__": |
| 244 | + main() |
0 commit comments