From fd0bf5533e5db9a2a4ddadd1a7a19041bc219d68 Mon Sep 17 00:00:00 2001 From: Elijah Ben Izzy Date: Wed, 13 May 2026 21:55:09 -0700 Subject: [PATCH 1/3] fix(parallelism): salt sub-app IDs with sequence_id to prevent stale replay (#761) Sub-application IDs in `TaskBasedParallelAction` (and its subclasses `MapStates`, `MapActions`, `MapActionsAndStates`) were deterministic in `(parent_app_id, i, j)` only, with no per-invocation discriminator. When the parent application was built with `initialize_from(...)`, the cascaded state initializer would, on the second invocation of the same parallel action, see a sub-app ID matching the first invocation's persisted state and silently hydrate it instead of re-running the action. The result: every invocation after the first replayed the first invocation's outputs. Fresh runs (no `initialize_from`) were unaffected because the cascaded initializer was None, so sub-apps never tried to load existing state even though their IDs still collided. Fix: in `TaskBasedParallelAction.run_and_update`, salt each task's `application_id` with the parent context's `sequence_id` (which increments per parent action step, so it uniquely identifies a given invocation). Applied via a small `_salt_task_app_id` helper in both the sync and async generator paths so all subclasses inherit the fix without needing to override `tasks()`. Also adds a regression test that exercises the full failure mode: builds an app with `initialize_from(...)` and a `MapStates`, invokes it three times, and asserts the per-invocation outputs differ. Note: this changes sub-app IDs across versions for any `TaskBasedParallelAction` subclass; persisted sub-app state from the old scheme will not be matched by the new IDs. --- burr/core/parallelism.py | 30 +++++++++++++ tests/core/test_parallelism.py | 77 +++++++++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py index 857fed333..df93b4f6b 100644 --- a/burr/core/parallelism.py +++ b/burr/core/parallelism.py @@ -183,6 +183,28 @@ def _stable_app_id_hash(app_id: str, child_key: str) -> str: return hashlib.sha256(f"{app_id}:{child_key}".encode()).hexdigest() +def _salt_task_app_id(task: "SubGraphTask", sequence_id: Optional[int]) -> "SubGraphTask": + """Salts the sub-application ID with the parent's sequence_id so that repeated + invocations of the same parallel action within a parent application yield distinct + sub-application IDs. + + Without this, sub-app IDs collide across invocations and a cascaded + ``state_initializer`` (e.g. from ``initialize_from(...)`` on the parent) will + silently hydrate the prior call's persisted state instead of running the action. + See https://github.com/apache/burr/issues/761. + + ``sequence_id`` is the parent application's per-step counter, which is incremented + on every action execution -- making it the right discriminator for "which + invocation of this parallel action are we in". + """ + if sequence_id is None: + return task + task.application_id = hashlib.sha256( + f"{task.application_id}:{sequence_id}".encode() + ).hexdigest() + return task + + class TaskBasedParallelAction(SingleStepAction): """The base class for actions that run a set of tasks in parallel and reduce the results. This is more power-user mode -- if you need fine-grained control over the set of tasks @@ -269,6 +291,11 @@ def _run_and_update(): delete=[item for item in state.keys() if item.startswith("__")] ) task_generator = self.tasks(state_without_internals, context, run_kwargs) + # Salt sub-app IDs with the parent sequence_id so repeated invocations + # don't collide and silently replay prior persisted state (#761). + task_generator = ( + _salt_task_app_id(task, context.sequence_id) for task in task_generator + ) def execute_task(task): return task.run(run_kwargs["__context"]) @@ -296,6 +323,9 @@ async def state_generator(): This way we run through all of the task generators. These correspond to the task generation capabilities above (the map*/task generation stuff) """ all_tasks = await async_utils.arealize(task_generator) + # Salt sub-app IDs with the parent sequence_id so repeated invocations + # don't collide and silently replay prior persisted state (#761). + all_tasks = [_salt_task_app_id(task, context.sequence_id) for task in all_tasks] coroutines = [item.arun(context) for item in all_tasks] results = await asyncio.gather(*coroutines) # TODO -- yield in order... diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py index 25d37cc24..18d684d70 100644 --- a/tests/core/test_parallelism.py +++ b/tests/core/test_parallelism.py @@ -40,12 +40,18 @@ MapActionsAndStates, MapStates, RunnableGraph, + SubgraphType, SubGraphTask, TaskBasedParallelAction, _cascade_adapter, map_reduce_action, ) -from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData +from burr.core.persistence import ( + BaseStateLoader, + BaseStateSaver, + InMemoryPersister, + PersistedStateData, +) from burr.tracking.base import SyncTrackingClient from burr.visibility import ActionSpan @@ -1227,3 +1233,72 @@ def reads(self) -> list[str]: assert task.state_initializer is not None assert task.tracker is not None assert task.state_persister is task.state_initializer # This ensures they're the same + + +def test_map_states_reexecutes_on_repeated_invocations_with_initializer(): + """Regression test for https://github.com/apache/burr/issues/761. + + When a parent application is built with ``initialize_from(...)``, the cascaded + initializer used to hydrate sub-applications by ID. Sub-app IDs were + deterministic in ``(parent_app_id, i, j)`` only, so a second invocation of the + same parallel action collided with the first and silently replayed the prior + persisted state instead of re-running the action. + + This asserts that repeated invocations now produce fresh outputs. + """ + counter = {"n": 0} + + @action(reads=[], writes=["x"]) + def pick(state: State) -> State: + counter["n"] += 1 + return state.update(x=counter["n"]) + + @action(reads=[], writes=[]) + def back(state: State) -> State: + return state + + class Fan(MapStates): + def action(self, state: State, inputs: Dict[str, Any]) -> SubgraphType: + return pick + + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + for _ in range(3): + yield state + + def reduce(self, state: State, results: Generator[State, None, None]) -> State: + return state.update(xs=[s["x"] for s in results]) + + @property + def reads(self) -> list[str]: + return [] + + @property + def writes(self) -> list[str]: + return ["xs"] + + persister = InMemoryPersister() + app = ( + ApplicationBuilder() + .with_actions(fan=Fan(), back=back) + .with_transitions(("fan", "back"), ("back", "fan")) + .with_state_persister(persister) + .initialize_from( + persister, + resume_at_next_action=True, + default_state={}, + default_entrypoint="fan", + ) + .build() + ) + invocations = [] + for _ in range(3): + app.run(halt_after=["fan"]) + invocations.append(list(app.state["xs"])) + # Each invocation should run the action 3 times, producing strictly increasing + # counter values across invocations. If the bug regresses the same xs would + # appear in every invocation. + assert invocations[0] != invocations[1] + assert invocations[1] != invocations[2] + assert counter["n"] == 9 From 29c6c797acf9e6fdbd24ca7c22bb53bfb0c3a4f8 Mon Sep 17 00:00:00 2001 From: Elijah Ben Izzy Date: Thu, 14 May 2026 21:03:06 -0700 Subject: [PATCH 2/3] style: isort fix from pre-commit on tests/core/test_parallelism.py --- tests/core/test_parallelism.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py index 18d684d70..ecf799bc1 100644 --- a/tests/core/test_parallelism.py +++ b/tests/core/test_parallelism.py @@ -40,8 +40,8 @@ MapActionsAndStates, MapStates, RunnableGraph, - SubgraphType, SubGraphTask, + SubgraphType, TaskBasedParallelAction, _cascade_adapter, map_reduce_action, From b2340258b40c0644c6f171ab4ff08aa5ce73d239 Mon Sep 17 00:00:00 2001 From: Elijah Ben Izzy Date: Thu, 14 May 2026 21:16:31 -0700 Subject: [PATCH 3/3] docs: note breaking sub-app ID scheme change in _salt_task_app_id --- burr/core/parallelism.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py index df93b4f6b..6aef68e7c 100644 --- a/burr/core/parallelism.py +++ b/burr/core/parallelism.py @@ -196,6 +196,13 @@ def _salt_task_app_id(task: "SubGraphTask", sequence_id: Optional[int]) -> "SubG ``sequence_id`` is the parent application's per-step counter, which is incremented on every action execution -- making it the right discriminator for "which invocation of this parallel action are we in". + + BREAKING (vs versions without this salting): sub-application IDs for any + ``TaskBasedParallelAction`` (``MapStates``, ``MapActions``, ``MapActionsAndStates``) + have changed. Sub-app state persisted under the old ID scheme is orphaned -- + on the first resume after upgrade, the sub-actions re-execute fresh rather + than load the old persisted result. This is the *fix* for #761; the old + scheme would have silently returned stale data instead. """ if sequence_id is None: return task