diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py
index 857fed333..6aef68e7c 100644
--- a/burr/core/parallelism.py
+++ b/burr/core/parallelism.py
@@ -183,6 +183,48 @@ def _stable_app_id_hash(app_id: str, child_key: str) -> str:
     return hashlib.sha256(f"{app_id}:{child_key}".encode()).hexdigest()
 
 
+def _salt_task_app_id(task: "SubGraphTask", sequence_id: Optional[int]) -> "SubGraphTask":
+    """Salts the sub-application ID with the parent's sequence_id so that repeated
+    invocations of the same parallel action within a parent application yield distinct
+    sub-application IDs.
+
+    Without this, sub-app IDs collide across invocations and a cascaded
+    ``state_initializer`` (e.g. from ``initialize_from(...)`` on the parent) will
+    silently hydrate the prior call's persisted state instead of running the action.
+    See https://github.com/apache/burr/issues/761.
+
+    ``sequence_id`` is the parent application's per-step counter, which is incremented
+    on every action execution -- making it the right discriminator for "which
+    invocation of this parallel action are we in".
+
+    The incoming task is never mutated: a shallow copy carrying the salted ID is
+    returned, so the result depends only on ``(task.application_id, sequence_id)``.
+    Salting in place would compound hashes if a ``tasks()`` implementation ever
+    yielded the same task object on more than one invocation.
+
+    BREAKING (vs versions without this salting): sub-application IDs for any
+    ``TaskBasedParallelAction`` (``MapStates``, ``MapActions``, ``MapActionsAndStates``)
+    have changed. Sub-app state persisted under the old ID scheme is orphaned --
+    on the first resume after upgrade, the sub-actions re-execute fresh rather
+    than load the old persisted result. This is the *fix* for #761; the old
+    scheme would have silently returned stale data instead.
+
+    :param task: The task whose ``application_id`` is salted.
+    :param sequence_id: The parent application's sequence ID; ``None`` disables salting.
+    :return: ``task`` itself when ``sequence_id`` is ``None``, otherwise a shallow copy
+        of it with the salted ``application_id``.
+    """
+    import copy  # local import: this helper does not rely on the module's import block
+
+    if sequence_id is None:
+        return task
+    salted_task = copy.copy(task)
+    salted_task.application_id = hashlib.sha256(
+        f"{task.application_id}:{sequence_id}".encode()
+    ).hexdigest()
+    return salted_task
+
+
 class TaskBasedParallelAction(SingleStepAction):
     """The base class for actions that run a set of tasks in parallel and reduce the results.
     This is more power-user mode -- if you need fine-grained control over the set of tasks
@@ -269,6 +311,11 @@ def _run_and_update():
                 delete=[item for item in state.keys() if item.startswith("__")]
             )
             task_generator = self.tasks(state_without_internals, context, run_kwargs)
+            # Salt sub-app IDs with the parent sequence_id so repeated invocations
+            # don't collide and silently replay prior persisted state (#761).
+            task_generator = (
+                _salt_task_app_id(task, context.sequence_id) for task in task_generator
+            )
 
             def execute_task(task):
                 return task.run(run_kwargs["__context"])
@@ -296,6 +343,9 @@ async def state_generator():
                 This way we run through all of the task generators. These correspond to the task generation capabilities above (the map*/task generation stuff)
                 """
                 all_tasks = await async_utils.arealize(task_generator)
+                # Salt sub-app IDs with the parent sequence_id so repeated invocations
+                # don't collide and silently replay prior persisted state (#761).
+                all_tasks = [_salt_task_app_id(task, context.sequence_id) for task in all_tasks]
                 coroutines = [item.arun(context) for item in all_tasks]
                 results = await asyncio.gather(*coroutines)
                 # TODO -- yield in order...
diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py
index 25d37cc24..ecf799bc1 100644
--- a/tests/core/test_parallelism.py
+++ b/tests/core/test_parallelism.py
@@ -41,11 +41,17 @@
     MapStates,
     RunnableGraph,
     SubGraphTask,
+    SubgraphType,
     TaskBasedParallelAction,
     _cascade_adapter,
     map_reduce_action,
 )
-from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData
+from burr.core.persistence import (
+    BaseStateLoader,
+    BaseStateSaver,
+    InMemoryPersister,
+    PersistedStateData,
+)
 from burr.tracking.base import SyncTrackingClient
 from burr.visibility import ActionSpan
 
@@ -1227,3 +1233,81 @@ def reads(self) -> list[str]:
     assert task.state_initializer is not None
     assert task.tracker is not None
     assert task.state_persister is task.state_initializer  # This ensures they're the same
+
+
+def test_map_states_reexecutes_on_repeated_invocations_with_initializer():
+    """Regression test for https://github.com/apache/burr/issues/761.
+
+    When a parent application is built with ``initialize_from(...)``, the cascaded
+    initializer used to hydrate sub-applications by ID. Sub-app IDs were
+    deterministic in ``(parent_app_id, i, j)`` only, so a second invocation of the
+    same parallel action collided with the first and silently replayed the prior
+    persisted state instead of re-running the action.
+
+    This asserts that repeated invocations now produce fresh outputs.
+    """
+    import threading
+
+    counter = {"n": 0}
+    counter_lock = threading.Lock()
+
+    @action(reads=[], writes=["x"])
+    def pick(state: State) -> State:
+        # Sub-actions may execute concurrently, so the increment and the read of the
+        # new value must happen atomically -- otherwise updates can be lost.
+        with counter_lock:
+            counter["n"] += 1
+            value = counter["n"]
+        return state.update(x=value)
+
+    @action(reads=[], writes=[])
+    def back(state: State) -> State:
+        return state
+
+    class Fan(MapStates):
+        def action(self, state: State, inputs: Dict[str, Any]) -> SubgraphType:
+            return pick
+
+        def states(
+            self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
+        ) -> Generator[State, None, None]:
+            for _ in range(3):
+                yield state
+
+        def reduce(self, state: State, results: Generator[State, None, None]) -> State:
+            return state.update(xs=[s["x"] for s in results])
+
+        @property
+        def reads(self) -> list[str]:
+            return []
+
+        @property
+        def writes(self) -> list[str]:
+            return ["xs"]
+
+    persister = InMemoryPersister()
+    app = (
+        ApplicationBuilder()
+        .with_actions(fan=Fan(), back=back)
+        .with_transitions(("fan", "back"), ("back", "fan"))
+        .with_state_persister(persister)
+        .initialize_from(
+            persister,
+            resume_at_next_action=True,
+            default_state={},
+            default_entrypoint="fan",
+        )
+        .build()
+    )
+    invocations = []
+    for _ in range(3):
+        app.run(halt_after=["fan"])
+        invocations.append(list(app.state["xs"]))
+    # Each invocation should run the action 3 times, producing strictly increasing
+    # counter values across invocations. If the bug regresses the same xs would
+    # appear in every invocation.
+    assert invocations[0] != invocations[1]
+    assert invocations[1] != invocations[2]
+    assert counter["n"] == 9
+    # Every one of the 9 executions must contribute exactly one fresh value.
+    assert sorted(x for xs in invocations for x in xs) == list(range(1, 10))