From 890742126059c9e140cf01486927b28a2ffcea5d Mon Sep 17 00:00:00 2001 From: puddingfjz <2811443837@qq.com> Date: Thu, 21 May 2026 17:47:41 +0800 Subject: [PATCH 1/6] Add: document Python callable dynamic registration --- docs/python-callable-serialization.md | 503 ++++++++++++++++++++++++++ docs/python-packaging.md | 13 +- pyproject.toml | 1 + 3 files changed, 514 insertions(+), 3 deletions(-) create mode 100644 docs/python-callable-serialization.md diff --git a/docs/python-callable-serialization.md b/docs/python-callable-serialization.md new file mode 100644 index 000000000..85a6a0bd1 --- /dev/null +++ b/docs/python-callable-serialization.md @@ -0,0 +1,503 @@ +# Python Callable Serialization for L3+ Register + +This document specifies a design for registering Python callables after an +L3+ `Worker` has already initialized, and in the common case after child +processes have already started. + +The design is separate from +[callable-ipc-dynamic-register.md](callable-ipc-dynamic-register.md). That +document covers `ChipCallable` binary registration for chip children. This +document covers Python callables consumed by SUB workers and by higher-level +Worker-child dispatch loops. + +It is a design document, not an implementation. + +--- + +## 1. Context + +Every task submitted through the hierarchical runtime carries a `callable_id`. +For L3+ Python execution paths, that id is resolved in a Python registry: + +| Submit path | Recipient | Registry entry | +| ----------- | --------- | -------------- | +| `orch.submit_sub(cid, ...)` | SUB child | Python sub callable | +| L4+ `submit_next_level(cid, ...)` | Worker child | Python orch callable | + +Today, these entries must be registered before fork. The child process sees +the parent's `_callable_registry` only through fork-time copy-on-write. Any +parent-side mutation after fork is invisible to the already-running child. 
+ +`ChipCallable` post-init registration already uses a control-plane plus +side-band shm payload because binary callables can be copied and prepared in +chip children. Python callables need the same high-level shape, but the +payload is serialized Python code/data and the recipients are Python-capable +children, not chip children. + +### Goals + +- Allow `Worker.register(py_callable)` after `Worker.init()` at level >= 3. +- Make the returned `cid` usable when `register()` returns. +- Preserve the current registration behavior before children start. +- Reuse the existing mailbox control-plane and per-mailbox serialization + against in-flight dispatch. +- Support unregister and cid reuse. +- Keep the API synchronous and deterministic from the caller's perspective. + +### Non-goals + +- Dynamic registration of `ChipCallable`; that is covered by the binary + callable design. +- Cross-host or cross-Python-version serialization. +- Detecting or recovering from child process crashes while a mailbox control + request is in flight. This is a shared control-plane liveness limitation, + not specific to Python callable registration. +- Loading untrusted serialized bytes safely. This feature unpickles code from + the same user process and is not a security boundary. +- Automatically registering callables inside arbitrary descendant Workers. + A `Worker.register()` call updates the registry owned by that Worker and + the already-started children that consume that registry. +- Changing `MAX_REGISTERED_CALLABLE_IDS`. + +--- + +## 2. Public Contract + +`Worker.register(target)` keeps one cid space for both `ChipCallable` and +Python callables. The target type selects the dynamic-register route. + +- L2 `ChipCallable`: existing prepare path. +- L2 Python callable: invalid target. +- L3+ before this Worker has started child processes: store the target in the + parent registry; future children will inherit the registry when they start. 
+ This preserves the pre-`init()` behavior and extends it to the post-`init()`, + before-first-`run()` window, where no child process has been forked yet. +- L3+ after this Worker has started child processes: existing binary IPC for + `ChipCallable`, and the new serialized Python IPC path for Python callables. + +The post-start Python path is synchronous: + +1. The parent allocates a cid and stores `target` in its registry. +2. The parent serializes `target`. +3. The parent broadcasts the payload to every Python-capable child that may + resolve this Worker's registry. +4. Each child deserializes the payload and updates its local registry. +5. `register()` returns only after every required child has acknowledged. + +The parent must not submit a newly registered cid until `register()` returns. +The runtime does not attempt to make a cid visible before the synchronous +broadcast completes. + +### Recipients + +The parent routes Python callable registration to Python-capable children: + +- SUB child processes of the same Worker. +- L4+ next-level Worker-child dispatch loops, because they resolve the + parent's registered orch functions before calling `inner_worker.run(...)`. + +L3 chip children are not recipients for Python callable payloads. They can +only consume prepared `ChipCallable` ids. + +Because `Worker.register()` does not currently take a "sub" versus +"next-level orch" kind, the simplest compatible policy is to broadcast to all +Python-capable child groups owned by this Worker. Extra registry entries are +inert if a cid is never submitted to that worker type. + +This preserves the current public API: `Worker.register(target)` does not gain +an explicit target-kind parameter. Submit-time APIs continue to decide how the +cid is interpreted. + +If no Python-capable child exists after children start, registering a Python +callable should fail with a clear `RuntimeError`. Keeping a cid that no child +can ever resolve is more confusing than rejecting it. 
+ +### Callable Shape + +The runtime does not validate function signatures at register time. Existing +dispatch-time behavior remains: + +- SUB callables are invoked as `fn(args)`. +- Worker-child orchestration callables are invoked through + `inner_worker.run(orch_fn, args, cfg)`, so they must match the usual + orchestration shape. + +Signature errors surface from the child execution path and are reported +through the mailbox error field, as they are today. + +--- + +## 3. Serialization + +The payload must fit outside the 4 KB mailbox, so Python callables use a +side-band POSIX shm exactly like dynamic `ChipCallable` registration. The +mailbox carries only a shm name and cid. + +### Serializer Policy + +Dynamic Python callable registration uses `cloudpickle`. + +`cloudpickle` is a runtime dependency of the `simpler` package, not only a test +dependency, because child processes deserialize user callables during normal +`Worker.register()` operation. + +Registration before children start already allows lambdas and closures because +the startup path copies the registry directly. A dynamic feature that rejects +these common shapes would be surprising and would make several existing L3/L4 +test patterns impossible to move to dynamic registration. + +Stdlib `pickle` is not used for this path because it serializes most +functions by module/name reference and is therefore limited to importable +top-level functions and callable classes. It is a useful mental model for the +trust boundary, but it is not the runtime format. + +### Callable Shape and Closure Semantics + +Post-start registration supports callable shapes that `cloudpickle` can +serialize and the child can deserialize in the same Python environment: + +- importable top-level functions; +- lambdas and nested functions whose captured values are serializable; +- callable class instances whose instance state is serializable. + +This is not identical to registration before children start. 
Startup children +inherit a snapshot of the parent's address space, so a closure may appear to +work because the child inherited the captured object at startup. Post-start +registration sends serialized bytes to an already-running child, so captured +objects are copied or reconstructed through `cloudpickle`. + +Callables should not rely on captured process-local resources being equivalent +to fork inheritance. Examples include locks, events, open files, sockets, +`SharedMemory.buf` memoryviews, mmap views, `Worker` or `ChipWorker` instances, +nanobind/C++ handles, and device-pointer wrappers. Prefer capturing stable +identifiers that the child can reopen or reconstruct, such as a shared-memory +name instead of a live `SharedMemory.buf` object. + +### Payload Format + +The parent serializes the callable into an in-memory byte blob. The C++ +broadcast binding creates the side-band POSIX shm, copies that blob into it, +fans out the shm name to children, and unlinks the shm after all child +round-trips have completed. Python does not create or unlink the broadcast shm. + +The Python binding must accept a Python buffer object, preferably `bytes`, not +only a raw integer pointer. The binding copies the buffer into the staging shm +before releasing any reference to the Python object. This avoids depending on a +temporary `cloudpickle.dumps(...)` result staying alive while nanobind has +released the GIL and C++ worker threads are fanning out the control command. + +The shm bytes are exactly the result of `cloudpickle.dumps(target)`. There is no +custom payload header in the first implementation. The control command already +identifies the operation as Python callable registration, and this feature is a +trusted local transport rather than a cross-version or untrusted wire protocol. +Malformed or incompatible bytes fail through `cloudpickle.loads(...)` and are +reported through the normal mailbox error field. + +### Child Deserialization + +Each recipient child: + +1. 
Opens the shm by name. +2. Copies the shm contents into `bytes`. +3. Verifies that `cid` is in `[0, MAX_REGISTERED_CALLABLE_IDS)`. +4. Deserializes the callable with `cloudpickle`. +5. Verifies that the result is callable. +6. Installs it into the child's local registry under the requested cid. +7. Closes the shm and acknowledges `CONTROL_DONE`. + +For cid reuse after partial unregister failures, Python registration should +overwrite `registry[cid]` in the child. The parent only allocates free cids +from its own registry, so an existing child entry at the same cid is residue +from a prior best-effort failure and should be replaced. + +Because Python callables and `ChipCallable` objects share one cid space, the +same cleanup rule also applies when a cid is reused across target types. A +post-start `ChipCallable` registration must clear any stale Python dispatch +entry for the same cid from Python-capable children before the cid is reported +usable. Otherwise a failed Python unregister followed by `ChipCallable` reuse +could leave a Worker-child dispatch loop resolving the old Python callable. + +--- + +## 4. Control Plane + +Add new control subcommands rather than overloading the existing +`CTRL_REGISTER` used for `ChipCallable`: + +```text +CTRL_PY_REGISTER = 10 +CTRL_PY_UNREGISTER = 11 +``` + +The mailbox layout for `CTRL_PY_REGISTER` mirrors binary register: + +| Offset | Field | Notes | +| ------ | ----- | ----- | +| `OFF_CALLABLE` | sub_cmd = `CTRL_PY_REGISTER` | uint64 | +| `CTRL_OFF_ARG0` | cid | low 32 bits | +| `OFF_ARGS[0..]` | NUL-terminated shm name | fixed-width slot | + +`CTRL_PY_UNREGISTER` carries only the cid in `CTRL_OFF_ARG0`. + +### Parent-Side Flow + +`Worker.register(target)` gains a Python-callable dynamic route: + +1. Hold `_registry_lock`. +2. Reject non-callable Python targets. +3. If `_py_register_active_run` is true, raise `RuntimeError`. +4. 
If this Worker has not started child processes, allocate the smallest free + cid, insert `self._callable_registry[cid] = target`, and return the cid; + future children will inherit the registry when they start. +5. If no configured Python-capable child group exists, raise `RuntimeError`. +6. Allocate the smallest free cid. +7. Insert `self._callable_registry[cid] = target`. +8. Serialize the target into a bytes blob with `cloudpickle.dumps(...)`. +9. Broadcast `CTRL_PY_REGISTER` to required Python-capable worker groups. +10. On any failure, pop the parent registry entry and raise. +11. Return cid on success. + +The "configured Python-capable child group" check uses the Worker's own +configuration, not child-process state: + +- `num_sub_workers > 0` means SUB children will consume this registry. +- `len(_next_level_workers) > 0` means Worker children will consume this + registry. + +This check applies only after child processes have started. Before children +start, including after `init()` but before the first `run()`, registration uses +the parent-registry path and does not reject unused Python callables. + +If no free cid exists in `[0, MAX_REGISTERED_CALLABLE_IDS)`, register raises +`RuntimeError` before mutating the parent registry or broadcasting to children. +The caller can recover by unregistering unused callables and retrying. + +`_py_register_active_run` is initialized to false in `Worker.__init__`. +`Worker.run()` and Python callable register/unregister use `_registry_lock` as +their gate: + +1. 
Before starting L3+ children or entering the orchestration body, + `Worker.run()` performs the active-state transition as one critical section: + + ```python + with self._registry_lock: + if self._py_register_active_run: + raise RuntimeError("Worker.run() is already active") + self._py_register_active_run = True + ``` + + On the first run, this happens before `_start_hierarchical()` takes the + startup registry snapshot, so a concurrent `register()` cannot return + through the startup path after children have already missed that registry + entry. +2. Python callable register/unregister hold `_registry_lock` for the full + parent-side mutation and broadcast. A concurrent `Worker.run()` blocks until + the broadcast has completed, then marks the run active. +3. A Python callable register/unregister that starts after `Worker.run()` has + marked the run active acquires `_registry_lock`, observes the active flag, + raises `RuntimeError`, and leaves the registry unchanged. +4. `Worker.run()` clears `_py_register_active_run` under `_registry_lock` in a + `finally` block after drain and cleanup. + +This makes dynamic Python registration deterministic: it is supported after +children start between `run()` calls, but rejected while a run is actively +submitting or draining tasks. + +This requires a generic C++ binding that can broadcast a control command to a +selected worker pool: + +```python +_Worker.broadcast_control_all(worker_type, sub_cmd, cid, payload=None) +``` + +`worker_type` selects `SUB` versus `NEXT_LEVEL`; `sub_cmd` is +`CTRL_PY_REGISTER` or `CTRL_PY_UNREGISTER`. For register, `payload` is the +`cloudpickle`-serialized callable, passed as a Python buffer object. For +unregister, `payload` is absent. The binding owns shm creation, copying, +fan-out, and unlink when a payload is present, matching +`broadcast_register_all` for binary callables while avoiding four +near-identical Python-specific bindings. 
+ +The binding's error contract is subcommand-specific: + +- `CTRL_PY_REGISTER` raises on any child error so the parent can pop the newly + allocated cid before reporting failure. +- `CTRL_PY_UNREGISTER` returns per-child error messages for best-effort + cleanup; the parent warns and releases its cid slot even if some children + failed. + +The existing `mailbox_mu_` must be held for each child round trip, just like +binary register. This serializes Python register/unregister against +`TASK_READY` dispatch on the same child. + +Every child `CONTROL_REQUEST` handler, including existing chip-child handlers, +must reject unknown subcommands by writing `OFF_ERROR` and publishing +`CONTROL_DONE`. A misrouted Python control command must fail visibly, not ACK +as a successful no-op. + +### Parent-Side Unregister + +`Worker.unregister(cid)` uses the registered target type to select the +unregister route: + +1. Hold `_registry_lock`. +2. Raise `KeyError` if `cid` is absent from the parent registry. +3. If the target is a Python callable and `_py_register_active_run` is true, + raise `RuntimeError` before mutation. +4. If the Worker has not started child processes yet, pop the parent entry and + return. Future children will inherit the already-removed registry. +5. For a post-start `ChipCallable`, keep the existing binary unregister path. +6. If the target is a Python callable and this Worker has started child + processes, broadcast `CTRL_PY_UNREGISTER` to every Python-capable child + group configured for this Worker, regardless of when the callable was + originally registered. +7. Warn on per-child unregister errors, but pop the parent registry entry + unconditionally so the cid slot becomes reusable. + +Python callable unregister never cascades into `inner_worker.unregister(...)`. +For L4+ Worker children it removes only the parent-owned dispatch registry entry +inside `_child_worker_loop`, matching the `CTRL_PY_REGISTER` ownership rule. 
+ +Unregister is still best-effort, but reuse must self-heal. Before any +post-start `ChipCallable` registration for a cid that may have previously held a +Python callable, the parent must clear that cid from all Python-capable child +registries owned by the same Worker. This can reuse `CTRL_PY_UNREGISTER` as an +idempotent "clear Python dispatch entry" command. If the clear step fails during +registration, the new registration fails, the parent pops the newly allocated +cid, and no reverse rollback is attempted. + +### SUB Child Handler + +`_sub_worker_loop` currently handles `TASK_READY` and `SHUTDOWN`. It gains a +`CONTROL_REQUEST` branch: + +- `CTRL_PY_REGISTER`: deserialize the callable and store `registry[cid] = fn`. +- `CTRL_PY_UNREGISTER`: `registry.pop(cid, None)`. +- Any unknown control subcommand: write `OFF_ERROR`, publish `CONTROL_DONE`, + and leave the registry unchanged. + +The loop is single-threaded, and parent-side `mailbox_mu_` serializes control +commands against task dispatch, so no child-side lock is required. + +### Worker-Child Handler + +`_child_worker_loop` already has a `CONTROL_REQUEST` branch for binary +callable cascade. It gains Python subcommands with different semantics: + +- `CTRL_PY_REGISTER`: deserialize and store into the `registry` dict passed + to `_child_worker_loop`. +- `CTRL_PY_UNREGISTER`: remove from that same `registry`. +- Existing binary `CTRL_REGISTER`: before cascading the `ChipCallable` into + `inner_worker._register_at(...)`, remove `registry[cid]` from the + Worker-child dispatch registry. This self-heals stale Python callable residue + when a cid is reused as a `ChipCallable`. +- Any unknown control subcommand: write `OFF_ERROR`, publish `CONTROL_DONE`, + and leave both the parent-owned dispatch registry and `inner_worker` + unchanged. + +This registry is the dispatch registry used when the parent submits a cid to +the Worker child. It is distinct from `inner_worker._callable_registry`. 
+Updating it makes a dynamically registered parent orch function visible to +the already-started Worker child. + +The Python callable is not automatically cascaded into +`inner_worker._callable_registry`. Registering callables owned by an inner +Worker remains a separate operation on that Worker. This keeps cid ownership +local and avoids unexpected collisions with entries the inner Worker already +owns. + +--- + +## 5. Failure Modes and Tests + +### Failure Semantics + +| Trigger | Handling | +| ------- | -------- | +| `cloudpickle` unavailable | Import fails at parent register time | +| Serializer cannot encode target | Parent pops cid and raises before IPC | +| Post-start no Python child group | Parent raises before cid allocation | +| cid space exhausted | Parent raises before parent mutation | +| Active `Worker.run()` register | Parent raises before cid allocation | +| Active `Worker.run()` unregister | Parent raises before parent mutation | +| Child cannot open shm | Child writes `OFF_ERROR`; parent raises | +| Child receives invalid cid | Child writes `OFF_ERROR`; parent raises | +| Child deserialization fails | Child writes `OFF_ERROR`; parent raises | +| Result is not callable | Child writes `OFF_ERROR`; parent raises | +| Unknown control subcommand | Child writes `OFF_ERROR`; parent raises | +| Some children succeed before another fails | Parent raises; no rollback | +| Unregister fails on some children | Parent warns and pops its registry | +| Cross-type cid reuse | New register clears or overwrites child residue | +| Child crashes during control | Parent may hang waiting for `CONTROL_DONE` | + +No reverse rollback is attempted after partial register success. A successful +child may retain a registry entry for a cid the parent reports as failed. +Future cid reuse must overwrite it for Python registration, or clear it before +`ChipCallable` registration, matching the best-effort unregister contract. 
+ +If a child process crashes or stops polling its mailbox during +`CONTROL_REQUEST`, the parent may wait indefinitely for `CONTROL_DONE`. This is +the same liveness failure mode as existing mailbox control operations such as +`CTRL_PREPARE`, `CTRL_MALLOC`, and binary dynamic register. Adding timeout, +child liveness detection, and hierarchical recovery is out of scope for this +feature and should be handled as a broader control-plane reliability change. + +### Concurrency + +- Parent registry mutation stays under `_registry_lock`. +- The first `Worker.run()` marks `_py_register_active_run` before children + startup, preventing a callable from being inserted after the startup registry + snapshot but before the caller observes children as started. +- Each child mailbox round trip stays under `mailbox_mu_`. +- `register()` is synchronous. A caller that races `register()` and + `Worker.run()` from different Python threads must still wait for + `register()` to return before submitting the new cid. +- Child registry mutation is serialized by the mailbox state machine. +- The first implementation requires a quiescent Worker: dynamic Python + callable registration while `Worker.run()` is actively executing is rejected + with a clear error. Post-start registration between `run()` calls is the + supported target. + +### Test Plan + +Keep the first implementation's tests focused on behavior and ownership, not on +format evolution: + +- Unit test `cloudpickle` round trip for the supported callable shapes. +- Unit test that closures over serializable Python values work, and that + specific known-unpickleable captures fail before cid visibility. +- Unit test that child-side deserialize and execute failures are reported + through the normal mailbox error path. +- Unit test that Python register before children start uses the startup + registry path and performs no control broadcast. 
+- Unit test that first-run startup is serialized against Python register, so a + racing register cannot miss the startup registry snapshot. +- Unit test that post-start Python register rejects Workers with no SUB workers + and no next-level Worker children. +- Unit test selected-pool routing: `worker_type=SUB` reaches only + `sub_threads_`, and `worker_type=NEXT_LEVEL` reaches only + `next_level_threads_`. +- L3 integration test: start an L3 Worker with SUB workers, run once to start + children, dynamically register a Python sub callable, then + `submit_sub(cid, ...)`. +- L4 integration test: start an L4 Worker with an L3 child, run once to start + children, dynamically register an L3 orchestration callable on the L4 parent, + then `submit_next_level(cid, ...)`. +- Unregister test: once children have started, Python callable unregister + broadcasts `CTRL_PY_UNREGISTER`, pops the parent registry, and allows cid + reuse regardless of whether the callable was registered before or after + children started. +- Cross-type reuse test: stale Python dispatch residue from a failed + best-effort unregister is cleared when the same cid is reused for a + `ChipCallable`. +- Failure test: unsupported or non-serializable callable raises and releases the + parent cid slot. + +## Related + +- [task-flow.md](task-flow.md) explains how `Callable`, `TaskArgs`, and + `CallConfig` move through L3+ dispatch. +- [worker-manager.md](worker-manager.md) explains WorkerThread mailbox + dispatch and forked Python child loops. +- [callable-ipc-dynamic-register.md](callable-ipc-dynamic-register.md) + covers dynamic binary `ChipCallable` registration. 
diff --git a/docs/python-packaging.md b/docs/python-packaging.md index b7e936dd3..8e9bddff3 100644 --- a/docs/python-packaging.md +++ b/docs/python-packaging.md @@ -51,12 +51,19 @@ Internal coupling: `simpler_setup.toolchain`, `simpler_setup.kernel_compiler`, a | Category | Packages | | -------- | -------- | -| `simpler` runtime | No third-party Python deps. Requires platform backend: simulation (`a*sim`) or NPU hardware (`a2a3`/`a5` with CANN toolkit) | -| `simpler_setup` runtime | `torch` (tensor operations in golden scripts, test comparison) | +| `simpler` runtime | `cloudpickle`; platform backend | +| `simpler_setup` runtime | `torch` for golden/test tensor operations | | Build | `scikit-build-core`, `nanobind`, `cmake` | | Test | `pytest` (ut-py, st), `googletest` + `ctest` (ut-cpp) | -`pyproject.toml` declares no `[project.dependencies]` — both `torch` and `pytest` are environment prerequisites, not pip-installed transitively. This is intentional: torch's index URL (`--index-url https://download.pytorch.org/whl/cpu`) and hardware-specific builds make automatic resolution impractical. +`pyproject.toml` declares `cloudpickle` as a `[project.dependencies]` runtime +dependency. `torch` and `pytest` remain environment prerequisites, not +pip-installed transitively. This is intentional: torch's index URL +(`--index-url https://download.pytorch.org/whl/cpu`) and hardware-specific +builds make automatic resolution impractical. + +The `simpler` runtime also requires a platform backend: simulation (`a*sim`) or +NPU hardware (`a2a3`/`a5` with CANN toolkit). 
### `PROJECT_ROOT` resolution diff --git a/pyproject.toml b/pyproject.toml index 20296b40b..fb819b591 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ build-backend = "scikit_build_core.build" name = "simpler" version = "0.1.0" requires-python = ">=3.9" +dependencies = ["cloudpickle>=2.2"] [project.optional-dependencies] # ``torch>=2.3`` is required by ``simpler_setup.torch_interop`` (uses From 8601557a0cfbe96b3ee8ced80c52b8d4767cf9ed Mon Sep 17 00:00:00 2001 From: puddingfjz <2811443837@qq.com> Date: Fri, 22 May 2026 12:42:05 +0800 Subject: [PATCH 2/6] docs: revise python callable serialization v2 --- docs/python-callable-serialization.md | 324 +++++++++++++++++--------- 1 file changed, 217 insertions(+), 107 deletions(-) diff --git a/docs/python-callable-serialization.md b/docs/python-callable-serialization.md index 85a6a0bd1..41d301d97 100644 --- a/docs/python-callable-serialization.md +++ b/docs/python-callable-serialization.md @@ -46,12 +46,15 @@ children, not chip children. ### Non-goals -- Dynamic registration of `ChipCallable`; that is covered by the binary - callable design. +- Dynamic registration of `ChipCallable`; that protocol is covered by the + binary callable design. This document only adds the Python-residue cleanup + hook that the existing `ChipCallable` register implementation needs when a + shared cid is reused across target types. - Cross-host or cross-Python-version serialization. -- Detecting or recovering from child process crashes while a mailbox control - request is in flight. This is a shared control-plane liveness limitation, - not specific to Python callable registration. +- Recovering a child process that crashes or wedges while a mailbox control + request is in flight. This design specifies timeout reporting for Python + callable broadcasts, but rebuilding the child process tree is broader + control-plane reliability work. - Loading untrusted serialized bytes safely. 
This feature unpickles code from the same user process and is not a security boundary. - Automatically registering callables inside arbitrary descendant Workers. @@ -77,8 +80,8 @@ Python callables. The target type selects the dynamic-register route. The post-start Python path is synchronous: -1. The parent allocates a cid and stores `target` in its registry. -2. The parent serializes `target`. +1. The parent serializes `target`. +2. The parent allocates a cid and stores `target` in its registry. 3. The parent broadcasts the payload to every Python-capable child that may resolve this Worker's registry. 4. Each child deserializes the payload and updates its local registry. @@ -151,6 +154,13 @@ functions by module/name reference and is therefore limited to importable top-level functions and callable classes. It is a useful mental model for the trust boundary, but it is not the runtime format. +This design assumes child processes are forked from the same Python process and +therefore share the same Python major/minor version, installed package set, and +`cloudpickle` runtime. If a future startup mode uses `spawn` or independently +provisioned interpreters, dynamic Python callable registration is supported only +when the child environment is version-compatible with the parent and can import +the callable's dependencies. + ### Callable Shape and Closure Semantics Post-start registration supports callable shapes that `cloudpickle` can @@ -182,28 +192,42 @@ round-trips have completed. Python does not create or unlink the broadcast shm. The Python binding must accept a Python buffer object, preferably `bytes`, not only a raw integer pointer. The binding copies the buffer into the staging shm -before releasing any reference to the Python object. This avoids depending on a -temporary `cloudpickle.dumps(...)` result staying alive while nanobind has -released the GIL and C++ worker threads are fanning out the control command. 
- -The shm bytes are exactly the result of `cloudpickle.dumps(target)`. There is no -custom payload header in the first implementation. The control command already -identifies the operation as Python callable registration, and this feature is a -trusted local transport rather than a cross-version or untrusted wire protocol. -Malformed or incompatible bytes fail through `cloudpickle.loads(...)` and are -reported through the normal mailbox error field. +before it releases the Python object reference or fans out worker threads. The +binding must not retain a raw pointer into the Python buffer after returning or +after releasing the GIL for control fan-out. + +The shm starts with a minimal Python-callable payload header followed by the +exact bytes returned by `cloudpickle.dumps(target)`: + +| Field | Size | Value | +| ----- | ---- | ----- | +| magic | 4 bytes | `SPYC` | +| version | 1 byte | `1` | +| serializer | 1 byte | `1` for `cloudpickle` | +| flags | 2 bytes | reserved, must be zero | +| payload_size | 8 bytes | little-endian unsigned byte count | + +The first implementation accepts only `(magic="SPYC", version=1, +serializer=1, flags=0)`. Unknown magic, version, serializer, non-zero flags, +size mismatch, malformed bytes, or incompatible pickle data fail through the +normal mailbox error field. ### Child Deserialization Each recipient child: 1. Opens the shm by name. -2. Copies the shm contents into `bytes`. -3. Verifies that `cid` is in `[0, MAX_REGISTERED_CALLABLE_IDS)`. -4. Deserializes the callable with `cloudpickle`. -5. Verifies that the result is callable. -6. Installs it into the child's local registry under the requested cid. -7. Closes the shm and acknowledges `CONTROL_DONE`. +2. Validates the payload header. +3. Copies the payload region into `bytes`. +4. Verifies that `cid` is in `[0, MAX_REGISTERED_CALLABLE_IDS)`. +5. Deserializes the callable with `cloudpickle.loads(payload_bytes)`. +6. Verifies that the result is callable. +7. 
Installs it into the child's local registry under the requested cid. +8. Closes the shm and acknowledges `CONTROL_DONE`. + +The child intentionally copies the payload region before deserializing it. +This avoids coupling `cloudpickle.loads(...)` to the lifetime rules of an +active `SharedMemory.buf` memoryview and keeps shm close/unlink behavior simple. For cid reuse after partial unregister failures, Python registration should overwrite `registry[cid]` in the child. The parent only allocates free cids @@ -212,10 +236,13 @@ from a prior best-effort failure and should be replaced. Because Python callables and `ChipCallable` objects share one cid space, the same cleanup rule also applies when a cid is reused across target types. A -post-start `ChipCallable` registration must clear any stale Python dispatch -entry for the same cid from Python-capable children before the cid is reported -usable. Otherwise a failed Python unregister followed by `ChipCallable` reuse -could leave a Worker-child dispatch loop resolving the old Python callable. +post-start `ChipCallable` registration performed after this feature lands must +clear any stale Python dispatch entry for the same cid from Python-capable +children owned by the same Worker before the cid is reported usable. This is a +v2 integration hook on the existing `ChipCallable` register implementation, not +a new binary payload protocol. Otherwise a failed Python unregister followed by +`ChipCallable` reuse could leave a Worker-child dispatch loop resolving the old +Python callable. --- @@ -243,19 +270,23 @@ The mailbox layout for `CTRL_PY_REGISTER` mirrors binary register: `Worker.register(target)` gains a Python-callable dynamic route: -1. Hold `_registry_lock`. -2. Reject non-callable Python targets. -3. If `_py_register_active_run` is true, raise `RuntimeError`. -4. 
If this Worker has not started child processes, allocate the smallest free - cid, insert `self._callable_registry[cid] = target`, and return the cid; - future children will inherit the registry when they start. -5. If no configured Python-capable child group exists, raise `RuntimeError`. -6. Allocate the smallest free cid. -7. Insert `self._callable_registry[cid] = target`. -8. Serialize the target into a bytes blob with `cloudpickle.dumps(...)`. -9. Broadcast `CTRL_PY_REGISTER` to required Python-capable worker groups. -10. On any failure, pop the parent registry entry and raise. -11. Return cid on success. +1. Reject non-callable Python targets. +2. If the first hierarchical startup is in progress, wait for that startup to + either complete or fail without holding `_registry_lock`. A registration + must not return through the startup path after the fork-time registry + snapshot has already been taken. +3. If this Worker has not started child processes, hold `_registry_lock`, + allocate the smallest free cid, insert + `self._callable_registry[cid] = target`, and return the cid; future children + will inherit the registry when they start. +4. If no configured Python-capable child group exists, raise `RuntimeError`. +5. Serialize the target into a bytes blob with `cloudpickle.dumps(...)`. +6. Hold `_registry_lock`, allocate the smallest free cid, insert + `self._callable_registry[cid] = target`, and release `_registry_lock`. +7. Broadcast `CTRL_PY_REGISTER` to required Python-capable worker groups. +8. On any failure, reacquire `_registry_lock`, remove the parent registry entry + if it still points at this target, and raise. +9. Return cid on success. The "configured Python-capable child group" check uses the Worker's own configuration, not child-process state: @@ -272,42 +303,39 @@ If no free cid exists in `[0, MAX_REGISTERED_CALLABLE_IDS)`, register raises `RuntimeError` before mutating the parent registry or broadcasting to children. 
The caller can recover by unregistering unused callables and retrying. -`_py_register_active_run` is initialized to false in `Worker.__init__`. -`Worker.run()` and Python callable register/unregister use `_registry_lock` as -their gate: - -1. Before starting L3+ children or entering the orchestration body, - `Worker.run()` performs the active-state transition as one critical section: - - ```python - with self._registry_lock: - if self._py_register_active_run: - raise RuntimeError("Worker.run() is already active") - self._py_register_active_run = True - ``` - - On the first run, this happens before `_start_hierarchical()` takes the - startup registry snapshot, so a concurrent `register()` cannot return - through the startup path after children have already missed that registry - entry. -2. Python callable register/unregister hold `_registry_lock` for the full - parent-side mutation and broadcast. A concurrent `Worker.run()` blocks until - the broadcast has completed, then marks the run active. -3. A Python callable register/unregister that starts after `Worker.run()` has - marked the run active acquires `_registry_lock`, observes the active flag, - raises `RuntimeError`, and leaves the registry unchanged. -4. `Worker.run()` clears `_py_register_active_run` under `_registry_lock` in a - `finally` block after drain and cleanup. - -This makes dynamic Python registration deterministic: it is supported after -children start between `run()` calls, but rejected while a run is actively -submitting or draining tasks. +The startup race is handled by a one-time hierarchical startup state, not by a +run-wide quiescent guard: + +- `_hierarchical_start_state` is protected by a dedicated + `_hierarchical_start_mu` / `_hierarchical_start_cv`, separate from + `_registry_lock`. +- Startup begins as `not_started`, moves to `starting` before + `_start_hierarchical()` takes the registry snapshot, and moves to `started` + only after child mailboxes are registered with the C++ Worker. 
+- A Python callable register/unregister that observes `starting` waits on a + condition variable without holding `_registry_lock`. After startup succeeds, + it uses the post-start control path; after startup fails, it raises. +- `_start_hierarchical()` snapshots `self._callable_registry` while holding + `_registry_lock`, then forks children from that immutable snapshot. It must + not hold `_registry_lock` across `os.fork()`. + +Once children have started, dynamic Python registration is allowed while +`Worker.run()` is actively submitting or draining tasks. The operation is still +synchronous: the caller must wait for `register()` to return before submitting +the new cid. Per-child `mailbox_mu_` serialization orders each +`CTRL_PY_REGISTER` / `CTRL_PY_UNREGISTER` round trip against any in-flight +`TASK_READY` on that same child mailbox. + +`_registry_lock` protects parent-side cid allocation and registry mutation +only. It is not held while waiting for child ACKs from +`broadcast_control_all`. This requires a generic C++ binding that can broadcast a control command to a selected worker pool: ```python -_Worker.broadcast_control_all(worker_type, sub_cmd, cid, payload=None) +_Worker.broadcast_control_all(worker_type, sub_cmd, cid, payload=None, + timeout_s=None) ``` `worker_type` selects `SUB` versus `NEXT_LEVEL`; `sub_cmd` is @@ -318,13 +346,34 @@ fan-out, and unlink when a payload is present, matching `broadcast_register_all` for binary callables while avoiding four near-identical Python-specific bindings. -The binding's error contract is subcommand-specific: +For a selected worker pool, fan-out is parallel: C++ starts one worker thread +per target child, each round trip holds that child's `mailbox_mu_`, and the +binding waits for every child to publish `CONTROL_DONE` before returning the +per-child results. Latency is bounded by the slowest child round trip, not by +the sum of all child round trips. + +`timeout_s` is optional. 
When set, each child round trip that does not publish +`CONTROL_DONE` before the deadline returns a failed result with a timeout error +message. The timeout does not repair the wedged child or reclaim a mailbox +that is still owned by a stuck control command; it only bounds the caller's +wait and makes the partial failure visible to Python policy code. + +The binding always returns structured per-child results. It does not switch +between "raise" and "return errors" based on `sub_cmd`. Python decides whether +those results are strict or best-effort: + +```text +ControlResult(worker_type, worker_index, ok, error_message) +``` -- `CTRL_PY_REGISTER` raises on any child error so the parent can pop the newly - allocated cid before reporting failure. -- `CTRL_PY_UNREGISTER` returns per-child error messages for best-effort - cleanup; the parent warns and releases its cid slot even if some children - failed. +- `Worker.register()` treats any failed `CTRL_PY_REGISTER` result as strict: + it removes the new parent registry entry and raises. +- `Worker.unregister()` treats failed `CTRL_PY_UNREGISTER` results as + best-effort: it warns, then releases its parent cid slot after the broadcast + has returned. +- The cross-type reuse hook treats failed Python-residue cleanup as strict: it + fails the `ChipCallable` registration before starting binary + `CTRL_REGISTER`. The existing `mailbox_mu_` must be held for each child round trip, just like binary register. This serializes Python register/unregister against @@ -335,24 +384,47 @@ must reject unknown subcommands by writing `OFF_ERROR` and publishing `CONTROL_DONE`. A misrouted Python control command must fail visibly, not ACK as a successful no-op. +### Cross-Type Reuse Hook + +The existing post-start `ChipCallable` register path keeps its binary payload +protocol, but gains one v2 hook before reporting a reused cid as usable: + +1. 
After allocating a cid for a `ChipCallable`, check whether that cid may have + held a Python callable in this Worker lifetime. +2. If so, broadcast `CTRL_PY_UNREGISTER` to every Python-capable child group + owned by this Worker as an idempotent clear operation. +3. If that clear operation reports any child error, fail the `ChipCallable` + registration, remove the new parent registry entry, and do not start the + binary `CTRL_REGISTER` broadcast. +4. If the clear succeeds, continue through the existing binary + `broadcast_register_all` path. + +This hook is needed only for Python-capable child registries. Chip children +continue to rely on the existing binary self-heal before +`prepare_callable_from_blob`. + ### Parent-Side Unregister `Worker.unregister(cid)` uses the registered target type to select the unregister route: -1. Hold `_registry_lock`. -2. Raise `KeyError` if `cid` is absent from the parent registry. -3. If the target is a Python callable and `_py_register_active_run` is true, - raise `RuntimeError` before mutation. +1. If the first hierarchical startup is in progress, wait for it to complete + without holding `_registry_lock`. +2. Hold `_registry_lock`. +3. Raise `KeyError` if `cid` is absent from the parent registry or already has + an unregister in progress. 4. If the Worker has not started child processes yet, pop the parent entry and return. Future children will inherit the already-removed registry. -5. For a post-start `ChipCallable`, keep the existing binary unregister path. -6. If the target is a Python callable and this Worker has started child +5. Mark `cid` as pending unregister, then release `_registry_lock`. A pending + cid remains unavailable for reuse until the broadcast finishes. +6. For a post-start `ChipCallable`, keep the existing binary unregister path. +7. 
If the target is a Python callable and this Worker has started child processes, broadcast `CTRL_PY_UNREGISTER` to every Python-capable child group configured for this Worker, regardless of when the callable was originally registered. -7. Warn on per-child unregister errors, but pop the parent registry entry - unconditionally so the cid slot becomes reusable. +8. Warn on per-child unregister errors. Reacquire `_registry_lock`, pop the + parent registry entry unconditionally, clear the pending marker, and make + the cid slot reusable. Python callable unregister never cascades into `inner_worker.unregister(...)`. For L4+ Worker children it removes only the parent-owned dispatch registry entry @@ -406,6 +478,28 @@ Worker remains a separate operation on that Worker. This keeps cid ownership local and avoids unexpected collisions with entries the inner Worker already owns. +Registry ownership in a Worker-child process is: + +- Parent `CTRL_PY_REGISTER`: mutates the parent dispatch `registry`, read by + `_child_worker_loop`; does not cascade. +- Parent `CTRL_PY_UNREGISTER`: removes from the parent dispatch `registry`; + does not cascade. +- Parent binary `CTRL_REGISTER`: mutates `inner_worker._callable_registry` and + cascades through the inner Worker's own register route. +- Parent binary `CTRL_UNREGISTER`: mutates `inner_worker._callable_registry` + and cascades through the inner Worker's own unregister route. +- Inner Worker register/unregister: mutates `inner_worker._callable_registry` + and is owned by the inner Worker. + +The parent dispatch registry and `inner_worker._callable_registry` may contain +the same numeric cid for different owners. A parent Python unregister must not +call `inner_worker.unregister(cid)`, because that could delete a callable that +belongs to the inner Worker. Cross-type cleanup before parent `ChipCallable` +reuse clears stale Python entries from SUB registries and from Worker-child +parent dispatch registries. 
It does not clear +`inner_worker._callable_registry`; the binary register then cascades into +`inner_worker` through the normal binary route. + --- ## 5. Failure Modes and Tests @@ -415,11 +509,11 @@ owns. | Trigger | Handling | | ------- | -------- | | `cloudpickle` unavailable | Import fails at parent register time | -| Serializer cannot encode target | Parent pops cid and raises before IPC | +| Serializer cannot encode target | Parent raises before cid allocation | | Post-start no Python child group | Parent raises before cid allocation | | cid space exhausted | Parent raises before parent mutation | -| Active `Worker.run()` register | Parent raises before cid allocation | -| Active `Worker.run()` unregister | Parent raises before parent mutation | +| Startup race | Wait, then use post-start route | +| Duplicate unregister for same cid | Raise before second broadcast | | Child cannot open shm | Child writes `OFF_ERROR`; parent raises | | Child receives invalid cid | Child writes `OFF_ERROR`; parent raises | | Child deserialization fails | Child writes `OFF_ERROR`; parent raises | @@ -428,35 +522,42 @@ owns. | Some children succeed before another fails | Parent raises; no rollback | | Unregister fails on some children | Parent warns and pops its registry | | Cross-type cid reuse | New register clears or overwrites child residue | -| Child crashes during control | Parent may hang waiting for `CONTROL_DONE` | +| Child `cloudpickle.loads` times out | Failed child result | +| Child crashes during control | Timeout result, or hang if unset | No reverse rollback is attempted after partial register success. A successful child may retain a registry entry for a cid the parent reports as failed. Future cid reuse must overwrite it for Python registration, or clear it before `ChipCallable` registration, matching the best-effort unregister contract. 
-If a child process crashes or stops polling its mailbox during -`CONTROL_REQUEST`, the parent may wait indefinitely for `CONTROL_DONE`. This is -the same liveness failure mode as existing mailbox control operations such as -`CTRL_PREPARE`, `CTRL_MALLOC`, and binary dynamic register. Adding timeout, -child liveness detection, and hierarchical recovery is out of scope for this -feature and should be handled as a broader control-plane reliability change. +Python deserialization has a larger liveness surface than binary callable +prepare. `cloudpickle.loads(...)` may import modules or run user-defined object +reconstruction hooks, and that code can block, spin, or wedge the child before +it writes `CONTROL_DONE`. For Python callable broadcasts, callers should pass a +finite `timeout_s` so `broadcast_control_all` can return a failed per-child +result instead of waiting forever. Timeout does not make the child healthy; it +only lets `Worker.register()` fail visibly and lets best-effort cleanup report +which child did not respond. Child liveness detection, process replacement, and +hierarchical recovery remain out of scope for this feature. ### Concurrency - Parent registry mutation stays under `_registry_lock`. -- The first `Worker.run()` marks `_py_register_active_run` before children - startup, preventing a callable from being inserted after the startup registry - snapshot but before the caller observes children as started. -- Each child mailbox round trip stays under `mailbox_mu_`. +- The first `Worker.run()` marks hierarchical startup as `starting` before + taking the startup registry snapshot. Concurrent register/unregister callers + wait for startup to finish, then use the correct post-start route. +- `_registry_lock` is released before any broadcast waits for child ACKs. The + parent registry entry, plus the pending-unregister marker for unregister, + keeps the cid unavailable for reuse while the IPC operation is in flight. 
+- Each child mailbox round trip stays under `mailbox_mu_`, so post-start Python + register/unregister can run during `Worker.run()` and will serialize against + `TASK_READY` on each recipient mailbox. - `register()` is synchronous. A caller that races `register()` and `Worker.run()` from different Python threads must still wait for `register()` to return before submitting the new cid. - Child registry mutation is serialized by the mailbox state machine. -- The first implementation requires a quiescent Worker: dynamic Python - callable registration while `Worker.run()` is actively executing is rejected - with a clear error. Post-start registration between `run()` calls is the - supported target. +- `unregister()` is synchronous from the caller's perspective. The user remains + responsible for not unregistering a cid with outstanding submitted work. ### Test Plan @@ -472,11 +573,20 @@ format evolution: registry path and performs no control broadcast. - Unit test that first-run startup is serialized against Python register, so a racing register cannot miss the startup registry snapshot. +- Unit test that post-start Python register during an active `Worker.run()` + succeeds after the relevant child mailbox reaches a safe control point. +- Unit test that unregister keeps the cid unavailable for reuse until its + broadcast has completed, even though `_registry_lock` is not held across the + broadcast. - Unit test that post-start Python register rejects Workers with no SUB workers and no next-level Worker children. - Unit test selected-pool routing: `worker_type=SUB` reaches only `sub_threads_`, and `worker_type=NEXT_LEVEL` reaches only `next_level_threads_`. +- Unit test that `broadcast_control_all` returns the same structured + per-child result shape for register and unregister commands. +- Unit test that `broadcast_control_all(timeout_s=...)` reports a timed-out + child as a failed per-child result without blocking indefinitely. 
- L3 integration test: start an L3 Worker with SUB workers, run once to start children, dynamically register a Python sub callable, then `submit_sub(cid, ...)`. @@ -485,13 +595,13 @@ format evolution: then `submit_next_level(cid, ...)`. - Unregister test: once children have started, Python callable unregister broadcasts `CTRL_PY_UNREGISTER`, pops the parent registry, and allows cid - reuse regardless of whether the callable was registered before or after - children started. + reuse only after `unregister()` returns, regardless of whether the callable + was registered before or after children started. - Cross-type reuse test: stale Python dispatch residue from a failed best-effort unregister is cleared when the same cid is reused for a `ChipCallable`. -- Failure test: unsupported or non-serializable callable raises and releases the - parent cid slot. +- Failure test: unsupported or non-serializable callable raises without + consuming a parent cid slot. ## Related From f905ad4d11407944db35dedff713967b6fbbc327 Mon Sep 17 00:00:00 2001 From: puddingfjz <2811443837@qq.com> Date: Fri, 22 May 2026 17:49:43 +0800 Subject: [PATCH 3/6] Add dynamic Python callable registration --- docs/python-callable-serialization.md | 53 +- python/bindings/worker_bind.h | 33 + python/simpler/worker.py | 487 +++++++++--- src/common/hierarchical/worker.h | 5 + src/common/hierarchical/worker_manager.cpp | 78 +- src/common/hierarchical/worker_manager.h | 16 +- tests/ut/py/test_worker/test_host_worker.py | 769 ++++++++++++++++++- tests/ut/py/test_worker/test_l4_recursive.py | 43 ++ 8 files changed, 1330 insertions(+), 154 deletions(-) diff --git a/docs/python-callable-serialization.md b/docs/python-callable-serialization.md index 41d301d97..aaeca5cc2 100644 --- a/docs/python-callable-serialization.md +++ b/docs/python-callable-serialization.md @@ -10,7 +10,8 @@ document covers `ChipCallable` binary registration for chip children. 
This document covers Python callables consumed by SUB workers and by higher-level Worker-child dispatch loops. -It is a design document, not an implementation. +It is the design and contract for the implementation in +`python/simpler/worker.py` and `src/common/hierarchical/worker_manager.*`. --- @@ -102,6 +103,10 @@ The parent routes Python callable registration to Python-capable children: L3 chip children are not recipients for Python callable payloads. They can only consume prepared `ChipCallable` ids. +To keep the `NEXT_LEVEL` control pool unambiguous, L4+ Workers must use +`add_worker(...)` for next-level children and must not also configure direct +`device_ids`. Direct chip children remain an L3-only configuration. + Because `Worker.register()` does not currently take a "sub" versus "next-level orch" kind, the simplest compatible policy is to broadcast to all Python-capable child groups owned by this Worker. Extra registry entries are @@ -185,10 +190,12 @@ name instead of a live `SharedMemory.buf` object. ### Payload Format -The parent serializes the callable into an in-memory byte blob. The C++ -broadcast binding creates the side-band POSIX shm, copies that blob into it, -fan-outs the shm name to children, and unlinks the shm after all child -round-trips have completed. Python does not create or unlink the broadcast shm. +The parent serializes the callable with `cloudpickle`, then wraps those bytes in +the Python-callable wire header described below. The resulting complete payload +is an in-memory byte blob. The C++ broadcast binding creates the side-band POSIX +shm, copies that complete payload into it, fan-outs the shm name to children, +and unlinks the shm after all child round-trips have completed. Python does not +create or unlink the broadcast shm. The Python binding must accept a Python buffer object, preferably `bytes`, not only a raw integer pointer. 
The binding copies the buffer into the staging shm @@ -280,7 +287,8 @@ The mailbox layout for `CTRL_PY_REGISTER` mirrors binary register: `self._callable_registry[cid] = target`, and return the cid; future children will inherit the registry when they start. 4. If no configured Python-capable child group exists, raise `RuntimeError`. -5. Serialize the target into a bytes blob with `cloudpickle.dumps(...)`. +5. Serialize the target with `cloudpickle.dumps(...)` and wrap it in the + complete Python-callable wire payload. 6. Hold `_registry_lock`, allocate the smallest free cid, insert `self._callable_registry[cid] = target`, and release `_registry_lock`. 7. Broadcast `CTRL_PY_REGISTER` to required Python-capable worker groups. @@ -340,11 +348,15 @@ _Worker.broadcast_control_all(worker_type, sub_cmd, cid, payload=None, `worker_type` selects `SUB` versus `NEXT_LEVEL`; `sub_cmd` is `CTRL_PY_REGISTER` or `CTRL_PY_UNREGISTER`. For register, `payload` is the -`cloudpickle`-serialized callable, passed as a Python buffer object. For -unregister, `payload` is absent. The binding owns shm creation, copying, -fan-out, and unlink when a payload is present, matching -`broadcast_register_all` for binary callables while avoiding four -near-identical Python-specific bindings. +complete Python-callable wire payload, passed as a Python buffer object. That +means the bytes start with the `SPYC` header and the header's payload region is +the exact `cloudpickle.dumps(target)` result. Passing raw `cloudpickle` bytes +directly to `_Worker.broadcast_control_all(..., CTRL_PY_REGISTER, ...)` is +invalid because the C++ binding is a generic staging layer and does not add or +interpret Python-callable headers. For unregister, `payload` is absent. The +binding owns shm creation, copying, fan-out, and unlink when a payload is +present, matching `broadcast_register_all` for binary callables while avoiding +four near-identical Python-specific bindings. 
For a selected worker pool, fan-out is parallel: C++ starts one worker thread per target child, each round trip holds that child's `mailbox_mu_`, and the @@ -358,9 +370,14 @@ message. The timeout does not repair the wedged child or reclaim a mailbox that is still owned by a stuck control command; it only bounds the caller's wait and makes the partial failure visible to Python policy code. -The binding always returns structured per-child results. It does not switch -between "raise" and "return errors" based on `sub_cmd`. Python decides whether -those results are strict or best-effort: +The Python `Worker` facade uses a finite default timeout for its own dynamic +Python callable register/unregister broadcasts. The default is 30 seconds and +can be overridden per Worker with `py_control_timeout_s`. + +Once a control request is staged and fan-out begins, the binding returns +structured per-child results. It does not switch between "raise" and "return +errors" based on `sub_cmd`. Python decides whether those results are strict or +best-effort: ```text ControlResult(worker_type, worker_index, ok, error_message) @@ -375,6 +392,12 @@ ControlResult(worker_type, worker_index, ok, error_message) fails the `ChipCallable` registration before starting binary `CTRL_REGISTER`. +Argument conversion and setup failures that happen before a selected worker +pool can be contacted, such as a non-buffer `payload` object, an empty payload +buffer for a register command, or shm creation failure, may still raise +directly from the binding. Once fan-out begins, child-side failures and +timeouts are reported through `ControlResult`. + The existing `mailbox_mu_` must be held for each child round trip, just like binary register. This serializes Python register/unregister against `TASK_READY` dispatch on the same child. @@ -508,7 +531,7 @@ parent dispatch registries. 
It does not clear | Trigger | Handling | | ------- | -------- | -| `cloudpickle` unavailable | Import fails at parent register time | +| `cloudpickle` unavailable | `simpler.worker` import fails. | | Serializer cannot encode target | Parent raises before cid allocation | | Post-start no Python child group | Parent raises before cid allocation | | cid space exhausted | Parent raises before parent mutation | diff --git a/python/bindings/worker_bind.h b/python/bindings/worker_bind.h index b950054e3..19c6598dd 100644 --- a/python/bindings/worker_bind.h +++ b/python/bindings/worker_bind.h @@ -77,6 +77,12 @@ inline void bind_worker(nb::module_ &m) { // --- WorkerType --- nb::enum_<WorkerType>(m, "WorkerType").value("NEXT_LEVEL", WorkerType::NEXT_LEVEL).value("SUB", WorkerType::SUB); + nb::class_<ControlResult>(m, "ControlResult") + .def_ro("worker_type", &ControlResult::worker_type) + .def_ro("worker_index", &ControlResult::worker_index) + .def_ro("ok", &ControlResult::ok) + .def_ro("error_message", &ControlResult::error_message); + // --- TaskState --- nb::enum_<TaskState>(m, "TaskState") .value("FREE", TaskState::FREE) @@ -246,6 +252,33 @@ inline void bind_worker(nb::module_ &m) { "Best-effort broadcast of CTRL_UNREGISTER to every NEXT_LEVEL child in parallel. " "Returns a list of per-child error strings (empty on full success)." ) + .def( + "broadcast_control_all", + [](Worker &self, WorkerType worker_type, uint64_t sub_cmd, int32_t cid, nb::object payload, + nb::object timeout_s) { + std::string payload_bytes; + const void *payload_ptr = nullptr; + size_t payload_size = 0; + if (!payload.is_none()) { + Py_buffer view; + if (PyObject_GetBuffer(payload.ptr(), &view, PyBUF_CONTIG_RO) != 0) { + throw nb::python_error(); + } + payload_bytes.assign(static_cast<const char *>(view.buf), static_cast<size_t>(view.len)); + PyBuffer_Release(&view); + payload_ptr = payload_bytes.data(); + payload_size = payload_bytes.size(); + } + double timeout_val = timeout_s.is_none() ?
-1.0 : nb::cast<double>(timeout_s); + nb::gil_scoped_release release; + return self.broadcast_control_all(worker_type, sub_cmd, cid, payload_ptr, payload_size, timeout_val); + }, + nb::arg("worker_type"), nb::arg("sub_cmd"), nb::arg("cid"), nb::arg("payload") = nb::none(), + nb::arg("timeout_s") = nb::none(), + "Broadcast an arbitrary CONTROL_REQUEST to the selected worker pool. " + "If payload is a Python buffer, C++ stages it in POSIX shm and writes the shm name " + "into the mailbox. Returns per-child ControlResult entries." + ) .def( "control_alloc_domain", &Worker::control_alloc_domain, nb::arg("worker_id"), nb::arg("request_shm_name"), nb::arg("reply_shm_name"), nb::call_guard<nb::gil_scoped_release>(), diff --git a/python/simpler/worker.py b/python/simpler/worker.py index ae816c0c0..06a29f803 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -11,8 +11,9 @@ Callable identity is a ``cid`` (int), allocated exclusively by ``Worker.register(callable)``. ``Worker.run`` and the orchestrator's ``submit_next_level`` / ``submit_sub`` all take this cid — never the raw -``ChipCallable`` / Python function. L≥3 ``register()`` must run **before** -``init()`` so forked chip / sub children inherit the registry via COW. +``ChipCallable`` / Python function. L≥3 Python callables registered before +child startup are inherited through the fork-time snapshot; later +registrations are serialized and sent through the mailbox control plane. Usage:: @@ -64,9 +65,12 @@ def my_l4_orch(orch, args, config): from multiprocessing.shared_memory import SharedMemory from typing import Any, Optional +import cloudpickle + from _task_interface import ( # pyright: ignore[reportMissingImports] MAX_REGISTERED_CALLABLE_IDS, RunTiming, + WorkerType, _mailbox_load_i32, _mailbox_store_i32, read_args_from_blob, @@ -93,6 +97,7 @@ def my_l4_orch(orch, args, config): # that a hung child fails the suite instead of the CI job timing out.
_BOOTSTRAP_WAIT_TIMEOUT_S = 120.0 _BOOTSTRAP_POLL_INTERVAL_S = 0.001 +_PY_CONTROL_TIMEOUT_S = 30.0 # --------------------------------------------------------------------------- @@ -163,11 +168,18 @@ def my_l4_orch(orch, args, config): # rootinfo_path) and caches the handle on the ChipWorker so subsequent # CTRL_ALLOC_DOMAIN calls can find it. _CTRL_COMM_INIT = 9 +_CTRL_PY_REGISTER = 10 +_CTRL_PY_UNREGISTER = 11 # Layout of the CTRL_COMM_INIT request shm. _COMM_INIT_HEADER = struct.Struct(" bytes: + payload = cloudpickle.dumps(target) + return _PY_CALLABLE_HEADER.pack( + _PY_CALLABLE_MAGIC, + _PY_CALLABLE_VERSION, + _PY_CALLABLE_SERIALIZER_CLOUDPICKLE, + 0, + len(payload), + ) + payload + + +def _load_py_callable_from_shm(shm_name: str): + shm = SharedMemory(name=shm_name) + try: + shm_buf = shm.buf + assert shm_buf is not None + if shm.size < _PY_CALLABLE_HEADER.size: + raise RuntimeError(f"python callable payload too small: {shm.size} bytes") + magic, version, serializer, flags, payload_size = _PY_CALLABLE_HEADER.unpack_from(shm_buf, 0) + if magic != _PY_CALLABLE_MAGIC: + raise RuntimeError(f"invalid python callable payload magic: {magic!r}") + if version != _PY_CALLABLE_VERSION: + raise RuntimeError(f"unsupported python callable payload version: {version}") + if serializer != _PY_CALLABLE_SERIALIZER_CLOUDPICKLE: + raise RuntimeError(f"unsupported python callable serializer: {serializer}") + if flags != 0: + raise RuntimeError(f"unsupported python callable payload flags: {flags}") + expected_size = _PY_CALLABLE_HEADER.size + int(payload_size) + if expected_size != shm.size: + raise RuntimeError(f"python callable payload size mismatch: header={payload_size}, shm={shm.size}") + payload = bytes(shm_buf[_PY_CALLABLE_HEADER.size : expected_size]) + finally: + shm.close() + + fn = cloudpickle.loads(payload) + if not callable(fn): + raise RuntimeError(f"python callable payload decoded to non-callable {type(fn).__name__}") + return fn + + +def 
_handle_py_callable_control(buf, registry: dict, sub_cmd: int, *, context: str) -> None: + cid = int(struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0]) & 0xFFFFFFFF + if cid >= MAX_REGISTERED_CALLABLE_IDS: + raise RuntimeError(f"{context}: cid {cid} out of range") + if sub_cmd == _CTRL_PY_REGISTER: + shm_name = _read_shm_name(buf, _OFF_ARGS) + registry[cid] = _load_py_callable_from_shm(shm_name) + elif sub_cmd == _CTRL_PY_UNREGISTER: + registry.pop(cid, None) + else: + raise RuntimeError(f"{context}: unknown control sub-command {int(sub_cmd)}") + + def _mailbox_addr(shm: SharedMemory) -> int: buf = shm.buf assert buf is not None @@ -299,6 +364,17 @@ def _sub_worker_loop(buf, registry: dict) -> None: msg = _format_exc("sub_worker", e) _write_error(buf, code, msg) _mailbox_store_i32(state_addr, _TASK_DONE) + elif state == _CONTROL_REQUEST: + sub_cmd = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0] + code = 0 + msg = "" + try: + _handle_py_callable_control(buf, registry, int(sub_cmd), context="sub_worker") + except Exception as e: # noqa: BLE001 + code = 1 + msg = _format_exc("sub_worker control", e) + _write_error(buf, code, msg) + _mailbox_store_i32(state_addr, _CONTROL_DONE) elif state == _SHUTDOWN: break @@ -575,6 +651,8 @@ def _run_chip_main_loop( # noqa: PLR0912 -- TASK_READY + 6 control sub-commands _handle_ctrl_release_domain(cw, buf) elif sub_cmd == _CTRL_COMM_INIT: _handle_ctrl_comm_init(cw, buf) + else: + raise RuntimeError(f"unknown control sub-command {int(sub_cmd)}") except Exception as e: # noqa: BLE001 code = 1 if sub_cmd in (_CTRL_REGISTER, _CTRL_UNREGISTER): @@ -709,16 +787,34 @@ def _child_worker_loop( # cid_val onto the registry slot keeps the inner-side # cid identical to the outer-side cid — both the L4 # scheduler and the L3 children index by the same int. 
+ registry.pop(int(cid_val), None) inner_worker._register_at(int(cid_val), callable_obj) elif sub_cmd == _CTRL_UNREGISTER: cid_val = int(struct.unpack_from("Q", buf, _CTRL_OFF_ARG0)[0]) & 0xFFFFFFFF inner_worker.unregister(int(cid_val)) + elif sub_cmd in (_CTRL_PY_REGISTER, _CTRL_PY_UNREGISTER): + _handle_py_callable_control( + buf, + registry, + int(sub_cmd), + context=f"child_worker level={inner_worker.level}", + ) + else: + raise RuntimeError(f"unknown control sub-command {int(sub_cmd)}") except Exception as e: # noqa: BLE001 code = 1 op = ( "register" if sub_cmd == _CTRL_REGISTER - else ("unregister" if sub_cmd == _CTRL_UNREGISTER else f"ctrl={int(sub_cmd)}") + else ( + "unregister" + if sub_cmd == _CTRL_UNREGISTER + else ( + "py_register" + if sub_cmd == _CTRL_PY_REGISTER + else ("py_unregister" if sub_cmd == _CTRL_PY_UNREGISTER else f"ctrl={int(sub_cmd)}") + ) + ) ) msg = _format_exc(f"child_worker level={inner_worker.level} {op}", e) _write_error(buf, code, msg) @@ -760,6 +856,14 @@ def __init__( # dispatch) is now handled at the C++ boundary via mailbox_mu_, so # no quiescent-state guard is needed. self._registry_lock = threading.Lock() + self._pending_unregister_cids: set[int] = set() + self._py_callable_cids_seen: set[int] = set() + self._py_control_timeout_s = float( + config.get("py_control_timeout_s", _PY_CONTROL_TIMEOUT_S) + ) + self._hierarchical_start_state = "not_started" + self._hierarchical_start_mu = threading.Lock() + self._hierarchical_start_cv = threading.Condition(self._hierarchical_start_mu) # Level-2 internals self._chip_worker: Optional[ChipWorker] = None @@ -820,51 +924,123 @@ def register(self, target) -> int: ``orch.submit_sub(cid, …)``. Timing constraints: - - L3+: Python callables (sub fn / orch fn) must be registered - **before** ``init()`` so the COW-inherited registry is visible to - forked chip / sub children. 
ChipCallables may be registered either - before init (pre-warmed via ``_CTRL_PREPARE`` during ``init()``) - or after init (broadcast to chip children via - ``_Worker.broadcast_register_all``; see - docs/callable-ipc-dynamic-register.md). Post-init register at - L3+ is ChipCallable-only. + - L3+: registrations before child processes start are inherited + by forked children through the startup registry snapshot. + Registrations after child processes start use the mailbox + control plane: ChipCallables keep the binary path, while Python + callables are serialized with cloudpickle and broadcast to + Python-capable child groups. - L2: may be called either before or after ``init()`` (no fork, no COW constraint). When called post-init, ChipCallables are prepared on the device immediately; pre-init registrations are batched and prepared at the end of ``init()``. """ + if self.level == 2 and not isinstance(target, ChipCallable): + raise TypeError("Worker.register: level 2 only supports ChipCallable targets") + if self.level >= 3: + self._wait_hierarchical_start_if_needed() + if not isinstance(target, ChipCallable): + if not callable(target): + raise TypeError("Worker.register: non-ChipCallable target must be callable") + if self._initialized and getattr(self, "_hierarchical_started", False): + return self._post_start_register_python(target) + with self._registry_lock: - if self.level >= 3 and self._initialized and not isinstance(target, ChipCallable): - # L3+ post-init: only ChipCallable can cross the process - # boundary. Python callables (sub fn / orch fn) must be - # registered before init() so forked children inherit them. 
- raise NotImplementedError( - "Worker.register() at level >= 3 must be called before init() " - "for non-ChipCallable targets; only ChipCallable is supported " - "post-init (see docs/callable-ipc-dynamic-register.md)" - ) cid = self._allocate_cid() self._callable_registry[cid] = target + if self.level >= 3 and not isinstance(target, ChipCallable): + self._py_callable_cids_seen.add(cid) + + # L3+ post-init ChipCallable: broadcast to chip / next-level children + # via C++ after parent-side cid allocation is complete. The registry + # entry keeps the cid reserved while mailbox_mu_ serializes the wire + # round trip against dispatch. + if self.level >= 3 and self._initialized and isinstance(target, ChipCallable): + try: + self._post_init_register(cid, target) + except Exception: + with self._registry_lock: + if self._callable_registry.get(cid) is target: + self._callable_registry.pop(cid, None) + raise + return cid - # L3+ post-init ChipCallable: broadcast to chip / next-level - # children via C++. Done inside the registry lock so a concurrent - # register cannot allocate the same cid we are about to pop on - # failure. Per-WorkerThread mailbox_mu_ already provides the C++ - # serialisation against in-flight dispatch. - if self.level >= 3 and self._initialized and isinstance(target, ChipCallable): - try: - self._post_init_register(cid, target) - except Exception: + # L2 post-init: pre-warm immediately so the very first + # `Worker.run(cid, …)` is a clean cache hit. 
+ if self.level == 2 and self._initialized and isinstance(target, ChipCallable): + assert self._chip_worker is not None + self._chip_worker.prepare_callable(cid, target) + return cid + + def _wait_hierarchical_start_if_needed(self) -> None: + if self.level < 3: + return + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + + def _python_worker_types(self) -> list[WorkerType]: + worker_types: list[WorkerType] = [] + if self._config.get("num_sub_workers", 0) > 0: + worker_types.append(WorkerType.SUB) + if self._next_level_workers: + worker_types.append(WorkerType.NEXT_LEVEL) + return worker_types + + def _post_start_register_python(self, target) -> int: + worker_types = self._python_worker_types() + if not worker_types: + raise RuntimeError( + "Worker.register: no Python-capable child workers are configured " + "for dynamic Python callable registration" + ) + payload = _pack_py_callable_payload(target) + with self._registry_lock: + cid = self._allocate_cid() + self._callable_registry[cid] = target + self._py_callable_cids_seen.add(cid) + try: + self._broadcast_py_control(worker_types, _CTRL_PY_REGISTER, cid, payload=payload, strict=True) + except Exception: + with self._registry_lock: + if self._callable_registry.get(cid) is target: self._callable_registry.pop(cid, None) - raise - return cid + raise + return cid - # L2 post-init: pre-warm immediately so the very first - # `Worker.run(cid, …)` is a clean cache hit. 
- if self.level == 2 and self._initialized and isinstance(target, ChipCallable): - assert self._chip_worker is not None - self._chip_worker.prepare_callable(cid, target) - return cid + def _broadcast_py_control( + self, + worker_types: list[WorkerType], + sub_cmd: int, + cid: int, + *, + payload: Optional[bytes] = None, + strict: bool, + ) -> list[str]: + if not worker_types: + return [] + assert self._worker is not None + errors: list[str] = [] + payload_bytes = payload if payload is not None else None + for worker_type in worker_types: + results = self._worker.broadcast_control_all( + worker_type, + int(sub_cmd), + int(cid), + payload_bytes, + timeout_s=self._py_control_timeout_s, + ) + for result in results: + if not result.ok: + errors.append(f"{result.worker_type}[{result.worker_index}]: {result.error_message}") + if errors and strict: + raise RuntimeError( + f"Worker control broadcast cid={cid} sub_cmd={sub_cmd} failed on " + f"{len(errors)} child workers; first error: {errors[0]}" + ) + return errors def _allocate_cid(self) -> int: """Return the smallest unused cid in [0, MAX_REGISTERED_CALLABLE_IDS). @@ -875,7 +1051,7 @@ def _allocate_cid(self) -> int: would silently overwrite the next gap-after-the-hole. """ for i in range(MAX_REGISTERED_CALLABLE_IDS): - if i not in self._callable_registry: + if i not in self._callable_registry and i not in self._pending_unregister_cids: return i # The AICPU side keeps a fixed-size orch_so_table_ keyed by cid; # raise here so the failure surfaces at register-time with a @@ -897,18 +1073,21 @@ def _register_at(self, cid: int, target: ChipCallable) -> None: on a single integer key. Plain ``register`` allocates the next free slot and is therefore unsuitable here. 
""" + if not isinstance(target, ChipCallable): + raise TypeError("_register_at: target must be a ChipCallable") with self._registry_lock: if cid in self._callable_registry: raise RuntimeError(f"_register_at: cid={cid} already occupied") - if not isinstance(target, ChipCallable): - raise TypeError("_register_at: target must be a ChipCallable") self._callable_registry[cid] = target - if self.level >= 3 and self._initialized: - try: - self._post_init_register(cid, target) - except Exception: - self._callable_registry.pop(cid, None) - raise + + if self.level >= 3 and self._initialized: + try: + self._post_init_register(cid, target) + except Exception: + with self._registry_lock: + if self._callable_registry.get(cid) is target: + self._callable_registry.pop(cid, None) + raise def _post_init_register(self, cid: int, target: ChipCallable) -> None: """Broadcast a new ChipCallable to every NEXT_LEVEL child via C++. @@ -927,6 +1106,8 @@ def _post_init_register(self, cid: int, target: ChipCallable) -> None: if not getattr(self, "_hierarchical_started", False): return assert self._worker is not None + if cid in self._py_callable_cids_seen: + self._broadcast_py_control(self._python_worker_types(), _CTRL_PY_UNREGISTER, cid, strict=True) self._worker.broadcast_register_all(int(cid), int(target.buffer_ptr()), int(target.buffer_size())) def unregister(self, cid: int) -> None: @@ -947,17 +1128,45 @@ def unregister(self, cid: int) -> None: Raises: KeyError: cid was never registered. 
""" + self._wait_hierarchical_start_if_needed() + target = None with self._registry_lock: if cid not in self._callable_registry: raise KeyError(f"Worker.unregister: cid={cid} not registered") + if cid in self._pending_unregister_cids: + raise KeyError(f"Worker.unregister: cid={cid} already pending unregister") + target = self._callable_registry[cid] if self.level >= 3 and self._initialized and getattr(self, "_hierarchical_started", False): - self._broadcast_unregister(cid) + self._pending_unregister_cids.add(cid) elif self.level == 2 and self._initialized: assert self._chip_worker is not None self._chip_worker.unregister_callable(cid) - # Drop the registry entry unconditionally — even if a chip child - # reported an error, holding the slot would just waste cid budget. - self._callable_registry.pop(cid, None) + self._callable_registry.pop(cid, None) + return + else: + self._callable_registry.pop(cid, None) + return + + try: + if isinstance(target, ChipCallable): + self._broadcast_unregister(cid) + else: + errors = self._broadcast_py_control( + self._python_worker_types(), + _CTRL_PY_UNREGISTER, + cid, + strict=False, + ) + if errors: + sys.stderr.write( + f"Worker.unregister(cid={cid}): {len(errors)} Python children reported errors " + f"(continuing best-effort). First error: {errors[0]}\n" + ) + sys.stderr.flush() + finally: + with self._registry_lock: + self._callable_registry.pop(cid, None) + self._pending_unregister_cids.discard(cid) def _broadcast_unregister(self, cid: int) -> None: """Broadcast _CTRL_UNREGISTER via C++ to every NEXT_LEVEL child. 
@@ -983,6 +1192,8 @@ def add_worker(self, worker: "Worker") -> None: """ if self.level < 4: raise RuntimeError("Worker.add_worker() requires level >= 4") + if self._config.get("device_ids", []): + raise RuntimeError("Worker.add_worker() cannot be combined with device_ids on the same Worker") if self._initialized: raise RuntimeError("Worker.add_worker() must be called before init()") if worker._initialized: @@ -1030,6 +1241,8 @@ def _init_hierarchical(self) -> None: device_ids = self._config.get("device_ids", []) n_sub = self._config.get("num_sub_workers", 0) heap_ring_size = self._config.get("heap_ring_size", None) + if self.level >= 4 and device_ids: + raise RuntimeError("Worker level >= 4 must use add_worker(); device_ids are only supported on L3 Workers") # 1. Allocate sub-worker mailboxes (unified layout, MAILBOX_SIZE each). for _ in range(n_sub): @@ -1081,97 +1294,117 @@ def _init_hierarchical(self) -> None: def _start_hierarchical(self) -> None: # noqa: PLR0912 -- three parallel fork loops (sub/chip/next) + bootstrap wait + scheduler register/init; branches track the fork order documented in the body """Fork child processes and start C++ scheduler. 
Called on first run().""" - if self._hierarchical_started: - return - self._hierarchical_started = True + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "started": + return + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + self._hierarchical_start_state = "starting" device_ids = self._config.get("device_ids", []) n_sub = self._config.get("num_sub_workers", 0) - # Fork SubWorker processes (MUST be before any C++ threads) - registry = self._callable_registry - for i in range(n_sub): - pid = os.fork() - if pid == 0: - buf = self._sub_shms[i].buf - assert buf is not None - _sub_worker_loop(buf, registry) - os._exit(0) - else: - self._sub_pids.append(pid) + try: + # Fork children from an immutable snapshot. Dynamic register callers + # that race this startup wait and then use the post-start path. + with self._registry_lock: + registry = dict(self._callable_registry) - # Fork ChipWorker processes (L3 with device_ids). Always use the - # plain task-loop variant; the base communicator is established - # lazily on first ``orch.allocate_domain`` via CTRL_COMM_INIT. - chip_log_level, chip_log_info_v = _simpler_log.get_current_config() - if device_ids: - for idx, dev_id in enumerate(device_ids): + # Fork SubWorker processes (MUST be before any C++ threads) + for i in range(n_sub): pid = os.fork() if pid == 0: - buf = self._chip_shms[idx].buf + buf = self._sub_shms[i].buf assert buf is not None - _chip_process_loop( - buf, - self._l3_bins, - dev_id, - registry, - chip_log_level, - chip_log_info_v, - ) + _sub_worker_loop(buf, registry) os._exit(0) else: - self._chip_pids.append(pid) - - # Fork next-level Worker children (L4+ with Worker children). 
- # Each child process: init the inner Worker (which mmaps its own - # HeapRing and allocates its own child mailboxes), then enter - # _child_worker_loop. The inner Worker's own children are forked - # lazily on first run() inside _child_worker_loop, so the process - # tree nests correctly: L4 → L3 child → L3's chip/sub children. - for idx, inner_worker in enumerate(self._next_level_workers): - pid = os.fork() - if pid == 0: - buf = self._next_level_shms[idx].buf - assert buf is not None - inner_worker.init() - _child_worker_loop(buf, registry, inner_worker) - os._exit(0) - else: - self._next_level_pids.append(pid) - - # _Worker was constructed in _init_hierarchical (pre-fork) so - # children inherit the HeapRing MAP_SHARED mmap. Register PROCESS-mode - # workers via the unified mailbox. - dw = self._worker - assert dw is not None - - # Register chip workers as NEXT_LEVEL (L3) - if device_ids: - for shm in self._chip_shms: - dw.add_next_level_worker(_mailbox_addr(shm)) - - # Register Worker children as NEXT_LEVEL (L4+) - for shm in self._next_level_shms: - dw.add_next_level_worker(_mailbox_addr(shm)) + self._sub_pids.append(pid) + + # Fork ChipWorker processes (L3 with device_ids). Always use the + # plain task-loop variant; the base communicator is established + # lazily on first ``orch.allocate_domain`` via CTRL_COMM_INIT. + chip_log_level, chip_log_info_v = _simpler_log.get_current_config() + if device_ids: + for idx, dev_id in enumerate(device_ids): + pid = os.fork() + if pid == 0: + buf = self._chip_shms[idx].buf + assert buf is not None + _chip_process_loop( + buf, + self._l3_bins, + dev_id, + registry, + chip_log_level, + chip_log_info_v, + ) + os._exit(0) + else: + self._chip_pids.append(pid) + + # Fork next-level Worker children (L4+ with Worker children). + # Each child process: init the inner Worker (which mmaps its own + # HeapRing and allocates its own child mailboxes), then enter + # _child_worker_loop. 
The inner Worker's own children are forked + # lazily on first run() inside _child_worker_loop, so the process + # tree nests correctly: L4 → L3 child → L3's chip/sub children. + for idx, inner_worker in enumerate(self._next_level_workers): + pid = os.fork() + if pid == 0: + buf = self._next_level_shms[idx].buf + assert buf is not None + inner_worker.init() + _child_worker_loop(buf, registry, inner_worker) + os._exit(0) + else: + self._next_level_pids.append(pid) - for shm in self._sub_shms: - dw.add_sub_worker(_mailbox_addr(shm)) + # _Worker was constructed in _init_hierarchical (pre-fork) so + # children inherit the HeapRing MAP_SHARED mmap. Register PROCESS-mode + # workers via the unified mailbox. + dw = self._worker + assert dw is not None - # Start Scheduler + WorkerThreads (C++ threads start here, after fork) - dw.init() + # Register chip workers as NEXT_LEVEL (L3) + if device_ids: + for shm in self._chip_shms: + dw.add_next_level_worker(_mailbox_addr(shm)) - self._orch = Orchestrator(dw.get_orchestrator(), self) + # Register Worker children as NEXT_LEVEL (L4+) + for shm in self._next_level_shms: + dw.add_next_level_worker(_mailbox_addr(shm)) - # Pre-warm every chip child: for each registered ChipCallable cid, - # send `_CTRL_PREPARE` to all chip children so the first - # `submit_next_level` does not pay the H2D upload cost. Sub fns / - # orch fns do not need pre-warming — the registry is already - # COW-inherited. 
- if device_ids: - for cid, target in self._callable_registry.items(): - if isinstance(target, ChipCallable): - for worker_id in range(len(self._chip_shms)): - dw.control_prepare(worker_id, int(cid)) + for shm in self._sub_shms: + dw.add_sub_worker(_mailbox_addr(shm)) + + # Start Scheduler + WorkerThreads (C++ threads start here, after fork) + dw.init() + + self._orch = Orchestrator(dw.get_orchestrator(), self) + + # Pre-warm every chip child: for each registered ChipCallable cid, + # send `_CTRL_PREPARE` to all chip children so the first + # `submit_next_level` does not pay the H2D upload cost. Sub fns / + # orch fns do not need pre-warming — the registry is already + # COW-inherited. + if device_ids: + for cid, target in registry.items(): + if isinstance(target, ChipCallable): + for worker_id in range(len(self._chip_shms)): + dw.control_prepare(worker_id, int(cid)) + + self._hierarchical_started = True + with self._hierarchical_start_cv: + self._hierarchical_start_state = "started" + self._hierarchical_start_cv.notify_all() + except Exception: + with self._hierarchical_start_cv: + self._hierarchical_start_state = "failed" + self._hierarchical_start_cv.notify_all() + raise # ------------------------------------------------------------------ # Hierarchical abort diff --git a/src/common/hierarchical/worker.h b/src/common/hierarchical/worker.h index c90f05af6..3ff7ec1be 100644 --- a/src/common/hierarchical/worker.h +++ b/src/common/hierarchical/worker.h @@ -115,6 +115,11 @@ class Worker { manager_.broadcast_register_all(cid, reinterpret_cast(blob_ptr), static_cast(blob_size)); } std::vector broadcast_unregister_all(int32_t cid) { return manager_.broadcast_unregister_all(cid); } + std::vector broadcast_control_all( + WorkerType type, uint64_t sub_cmd, int32_t cid, const void *payload, size_t payload_size, double timeout_s + ) { + return manager_.broadcast_control_all(type, sub_cmd, cid, payload, payload_size, timeout_s); + } private: int32_t level_; diff --git 
a/src/common/hierarchical/worker_manager.cpp b/src/common/hierarchical/worker_manager.cpp index f50e4bc2d..c26f3e2fe 100644 --- a/src/common/hierarchical/worker_manager.cpp +++ b/src/common/hierarchical/worker_manager.cpp @@ -157,6 +157,9 @@ void WorkerThread::dispatch_process(TaskSlotState &s, int32_t group_index) { // orch thread waits for the dispatch to finish before claiming the // mailbox; without this they would race on MAILBOX_OFF_STATE. std::lock_guard lk(mailbox_mu_); + if (mailbox_control_timed_out_) { + throw std::runtime_error("WorkerThread::dispatch_process: mailbox has an unresolved timed-out control command"); + } // Clear the child-writable error fields so stale bytes from a prior // dispatch cannot masquerade as a fresh failure. @@ -338,12 +341,26 @@ static uint64_t read_control_result(const char *mbox) { // from the child, throws and leaves the mailbox in IDLE before unwinding // (so the next claim starts from a clean state). The `op_name` is used // only for the exception message. 
-void WorkerThread::run_control_command(const char *op_name) { +void WorkerThread::run_control_command(const char *op_name, double timeout_s) { + if (mailbox_control_timed_out_) { + throw std::runtime_error(std::string(op_name) + " failed: mailbox has an unresolved timed-out control command"); + } int32_t zero_err = 0; std::memcpy(mbox() + MAILBOX_OFF_ERROR, &zero_err, sizeof(int32_t)); std::memset(mbox() + MAILBOX_OFF_ERROR_MSG, 0, MAILBOX_ERROR_MSG_SIZE); write_mailbox_state(MailboxState::CONTROL_REQUEST); - while (read_mailbox_state() != MailboxState::CONTROL_DONE) {} + auto deadline = std::chrono::steady_clock::time_point::max(); + if (timeout_s >= 0.0) { + deadline = + std::chrono::steady_clock::now() + + std::chrono::duration_cast(std::chrono::duration(timeout_s)); + } + while (read_mailbox_state() != MailboxState::CONTROL_DONE) { + if (std::chrono::steady_clock::now() >= deadline) { + mailbox_control_timed_out_ = true; + throw std::runtime_error(std::string(op_name) + " timed out waiting for CONTROL_DONE"); + } + } int32_t err = 0; std::memcpy(&err, mbox() + MAILBOX_OFF_ERROR, sizeof(int32_t)); if (err != 0) { @@ -392,6 +409,21 @@ void WorkerThread::control_unregister(int32_t cid) { run_control_command("control_unregister"); } +void WorkerThread::control_generic(uint64_t sub_cmd, int32_t cid, const char *shm_name, double timeout_s) { + std::lock_guard lk(mailbox_mu_); + std::memcpy(mbox() + MAILBOX_OFF_CALLABLE, &sub_cmd, sizeof(uint64_t)); + uint64_t cid_v = static_cast(cid); + std::memcpy(mbox() + CTRL_OFF_ARG0, &cid_v, sizeof(uint64_t)); + const char *name = shm_name ? 
shm_name : ""; + size_t name_len = std::strlen(name); + if (name_len + 1 > CTRL_SHM_NAME_BYTES) { + throw std::runtime_error(std::string("control_generic: shm name too long: ") + name); + } + if (name_len > 0) std::memcpy(mbox() + MAILBOX_OFF_ARGS, name, name_len); + std::memset(mbox() + MAILBOX_OFF_ARGS + name_len, 0, CTRL_SHM_NAME_BYTES - name_len); + run_control_command("control_generic", timeout_s); +} + void WorkerThread::control_free(uint64_t ptr) { std::lock_guard lk(mailbox_mu_); write_control_args(mbox(), CTRL_FREE, ptr); @@ -661,3 +693,45 @@ std::vector WorkerManager::broadcast_unregister_all(int32_t cid) { } return errors; } + +std::vector WorkerManager::broadcast_control_all( + WorkerType type, uint64_t sub_cmd, int32_t cid, const void *payload, size_t payload_size, double timeout_s +) { + auto &threads = (type == WorkerType::NEXT_LEVEL) ? next_level_threads_ : sub_threads_; + const char *type_name = (type == WorkerType::NEXT_LEVEL) ? "NEXT_LEVEL" : "SUB"; + + std::vector results; + results.reserve(threads.size()); + for (size_t i = 0; i < threads.size(); ++i) { + results.push_back(ControlResult{type_name, static_cast(i), true, ""}); + } + if (threads.empty()) return results; + + std::unique_ptr shm; + std::string shm_name; + if (payload != nullptr || payload_size != 0) { + if (payload == nullptr || payload_size == 0) { + throw std::runtime_error("broadcast_control_all: payload pointer and size must both be set"); + } + shm_name = make_shm_name(cid); + shm = std::make_unique(shm_name, payload_size); + std::memcpy(shm->addr(), payload, payload_size); + } + + std::vector workers; + workers.reserve(threads.size()); + for (size_t i = 0; i < threads.size(); ++i) { + workers.emplace_back([&, i]() { + try { + threads[i]->control_generic(sub_cmd, cid, shm_name.empty() ? 
nullptr : shm_name.c_str(), timeout_s); + } catch (const std::exception &e) { + results[i].ok = false; + results[i].error_message = strip_control_prefix(e.what(), "control_generic"); + } + }); + } + for (auto &t : workers) + t.join(); + + return results; +} diff --git a/src/common/hierarchical/worker_manager.h b/src/common/hierarchical/worker_manager.h index f7b00b4ff..76a4bf2c7 100644 --- a/src/common/hierarchical/worker_manager.h +++ b/src/common/hierarchical/worker_manager.h @@ -121,6 +121,8 @@ static constexpr uint64_t CTRL_RELEASE_DOMAIN = 8; // Caches the comm handle on the chip's ChipWorker so subsequent // CTRL_ALLOC_DOMAIN calls can find it. static constexpr uint64_t CTRL_COMM_INIT = 9; +static constexpr uint64_t CTRL_PY_REGISTER = 10; +static constexpr uint64_t CTRL_PY_UNREGISTER = 11; // Control args reuse the task mailbox region (mutually exclusive with task dispatch): // offset 16: uint64 arg0 (size for malloc; ptr for free; dst for copy; cid for register) @@ -137,6 +139,13 @@ static constexpr ptrdiff_t CTRL_OFF_RESULT = 40; // of "simpler-cb---" with pid < 32-bit max. static constexpr size_t CTRL_SHM_NAME_BYTES = 32; +struct ControlResult { + std::string worker_type; + int32_t worker_index{0}; + bool ok{false}; + std::string error_message; +}; + // ============================================================================= // WorkerDispatch — per-dispatch handle handed to a WorkerThread. // ============================================================================= @@ -213,6 +222,7 @@ class WorkerThread { // for the in-flight TASK_DONE before claiming the mailbox. void control_register(int32_t cid, const char *shm_name); void control_unregister(int32_t cid); + void control_generic(uint64_t sub_cmd, int32_t cid, const char *shm_name, double timeout_s); // Dynamic CommDomain allocate / release. 
`request_shm_name` carries the // request payload (header + rank_ids + buffer_nbytes); for alloc the child @@ -244,6 +254,7 @@ class WorkerThread { // dispatch loop and the orch-thread control_* path. Per-WorkerThread, // so different workers can dispatch in parallel. std::mutex mailbox_mu_; + bool mailbox_control_timed_out_{false}; void loop(); void dispatch_process(TaskSlotState &s, int32_t group_index); @@ -251,7 +262,7 @@ class WorkerThread { // Common tail for the four control_* methods. Caller writes the args // region and holds `mailbox_mu_`; this helper signals the child, // spin-polls CONTROL_DONE, and throws on a non-zero child error code. - void run_control_command(const char *op_name); + void run_control_command(const char *op_name, double timeout_s = -1.0); char *mbox() const { return static_cast(mailbox_); } MailboxState read_mailbox_state() const; @@ -312,6 +323,9 @@ class WorkerManager { // worker in parallel. Returns a vector of per-worker error strings // (empty on full success). Caller decides whether to log / surface. std::vector broadcast_unregister_all(int32_t cid); + std::vector broadcast_control_all( + WorkerType type, uint64_t sub_cmd, int32_t cid, const void *payload, size_t payload_size, double timeout_s + ); // Write SHUTDOWN to every registered mailbox. 
void shutdown_children(); diff --git a/tests/ut/py/test_worker/test_host_worker.py b/tests/ut/py/test_worker/test_host_worker.py index 4a5d11079..1e13d54b9 100644 --- a/tests/ut/py/test_worker/test_host_worker.py +++ b/tests/ut/py/test_worker/test_host_worker.py @@ -18,7 +18,15 @@ import pytest from _task_interface import MAX_REGISTERED_CALLABLE_IDS # pyright: ignore[reportMissingImports] -from simpler.task_interface import ChipCallable, DataType, TaskArgs, TensorArgType +from simpler.task_interface import ( + MAILBOX_SIZE, + ChipCallable, + DataType, + TaskArgs, + TensorArgType, + WorkerType, + _Worker, +) from simpler.worker import Worker # --------------------------------------------------------------------------- @@ -44,6 +52,33 @@ def _increment_counter(buf) -> None: struct.pack_into("i", buf, 0, v + 1) +def _add_counter(buf, delta: int) -> None: + v = struct.unpack_from("i", buf, 0)[0] + struct.pack_into("i", buf, 0, v + delta) + + +def _set_flag(buf, offset: int, value: int) -> None: + struct.pack_into("i", buf, offset, value) + + +def _get_flag(buf, offset: int) -> int: + return struct.unpack_from("i", buf, offset)[0] + + +def _roundtrip_py_callable_payload(target): + from simpler.worker import _load_py_callable_from_shm, _pack_py_callable_payload # noqa: PLC0415 + + payload = _pack_py_callable_payload(target) + shm = SharedMemory(create=True, size=len(payload)) + try: + assert shm.buf is not None + shm.buf[: len(payload)] = payload + return _load_py_callable_from_shm(shm.name) + finally: + shm.close() + shm.unlink() + + # --------------------------------------------------------------------------- # Test: lifecycle (init / close without submitting any tasks) # --------------------------------------------------------------------------- @@ -65,15 +100,100 @@ def test_context_manager(self): hw.register(lambda args: None) # close() called by __exit__, no exception - def test_register_python_fn_after_init_raises(self): - # Post-init register of a 
non-ChipCallable (lambda / sub fn) is - # rejected because Python callables cannot cross the fork boundary. - # ChipCallable is the only post-init target — see the next test. + def test_l2_rejects_python_callable(self): + hw = Worker(level=2, device_id=0, platform="a2a3sim", runtime="tensormap_and_ringbuffer") + with pytest.raises(TypeError, match="level 2 only supports ChipCallable"): + hw.register(lambda args: None) + + def test_register_python_fn_after_init_before_start_succeeds(self): + # init() allocates mailboxes but does not fork children. Python + # callables registered in this window still land in the startup + # snapshot consumed by the first run(). hw = Worker(level=3, num_sub_workers=0) hw.init() - with pytest.raises(NotImplementedError, match="only ChipCallable is supported post-init"): - hw.register(lambda args: None) - hw.close() + try: + cid = hw.register(lambda args: None) + assert cid in hw._callable_registry + finally: + hw.close() + + def test_register_python_fn_after_init_before_start_does_not_broadcast(self): + class BroadcastTrap: + def broadcast_control_all(self, *args, **kwargs): + raise AssertionError("pre-start Python register must not broadcast") + + hw = Worker(level=3, num_sub_workers=1) + hw.init() + real_worker = hw._worker + try: + hw._worker = BroadcastTrap() + cid = hw.register(lambda args: None) + assert cid in hw._callable_registry + finally: + hw._worker = real_worker + hw.close() + + def test_register_python_fn_after_start_no_python_children_raises(self): + hw = Worker(level=3, num_sub_workers=0) + hw.init() + try: + hw.run(lambda orch, args, cfg: None) + with pytest.raises(RuntimeError, match="no Python-capable child"): + hw.register(lambda args: None) + finally: + hw.close() + + def test_register_waits_for_first_startup_then_uses_post_start_path(self): + hw = Worker(level=3, num_sub_workers=1) + hw.init() + try: + with hw._hierarchical_start_cv: + hw._hierarchical_start_state = "starting" + + observed = {} + + def 
fake_post_start_register(target): + observed["target"] = target + observed["state"] = hw._hierarchical_start_state + observed["hierarchical_started"] = hw._hierarchical_started + return 7 + + hw._post_start_register_python = fake_post_start_register + result: list[int] = [] + errors: list[BaseException] = [] + wait_entered = threading.Event() + original_wait = hw._hierarchical_start_cv.wait + + def wait_with_signal(timeout=None): + wait_entered.set() + return original_wait(timeout) + + hw._hierarchical_start_cv.wait = wait_with_signal + + def do_register(): + try: + result.append(hw.register(lambda args: None)) + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + t = threading.Thread(target=do_register) + t.start() + assert wait_entered.wait(timeout=2.0) + with hw._hierarchical_start_cv: + hw._hierarchical_started = True + hw._hierarchical_start_state = "started" + hw._hierarchical_start_cv.notify_all() + t.join(timeout=2.0) + + assert not t.is_alive() + assert errors == [] + assert result == [7] + assert observed["state"] == "started" + assert observed["hierarchical_started"] is True + finally: + if "original_wait" in locals(): + hw._hierarchical_start_cv.wait = original_wait + hw.close() def test_register_chip_callable_after_init_no_chips_succeeds(self): # With no chip children (device_ids unset), the C++ broadcast is a @@ -83,7 +203,9 @@ def test_register_chip_callable_after_init_no_chips_succeeds(self): hw = Worker(level=3, num_sub_workers=0) hw.init() try: - callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) + callable_obj = ChipCallable.build( + signature=[], func_name="x", binary=b"\x00", children=[] + ) cid = hw.register(callable_obj) assert isinstance(cid, int) assert cid >= 0 @@ -136,6 +258,635 @@ def test_unregister_chip_callable_after_init_no_chips_succeeds(self): finally: hw.close() + def test_register_chip_callable_broadcast_runs_without_registry_lock(self): + hw = Worker(level=3, 
num_sub_workers=0) + hw._initialized = True + hw._hierarchical_started = True + callable_obj = ChipCallable.build( + signature=[], func_name="x", binary=b"\x00", children=[] + ) + observed = {} + + def fake_post_init_register(cid, target): + observed["cid"] = cid + observed["target"] = target + observed["locked"] = hw._registry_lock.locked() + + hw._post_init_register = fake_post_init_register + + cid = hw.register(callable_obj) + + assert observed == {"cid": cid, "target": callable_obj, "locked": False} + assert hw._callable_registry[cid] is callable_obj + + def test_register_at_broadcast_runs_without_registry_lock(self): + hw = Worker(level=3, num_sub_workers=0) + hw._initialized = True + callable_obj = ChipCallable.build( + signature=[], func_name="x", binary=b"\x00", children=[] + ) + observed = {} + + def fake_post_init_register(cid, target): + observed["cid"] = cid + observed["target"] = target + observed["locked"] = hw._registry_lock.locked() + + hw._post_init_register = fake_post_init_register + + hw._register_at(7, callable_obj) + + assert observed == {"cid": 7, "target": callable_obj, "locked": False} + assert hw._callable_registry[7] is callable_obj + + def test_python_control_broadcast_passes_default_timeout(self): + from simpler.worker import _CTRL_PY_UNREGISTER, _PY_CONTROL_TIMEOUT_S # noqa: PLC0415 + + class FakeControlWorker: + def __init__(self): + self.calls = [] + + def broadcast_control_all(self, worker_type, sub_cmd, cid, payload, timeout_s=None): + self.calls.append((worker_type, sub_cmd, cid, payload, timeout_s)) + return [] + + fake = FakeControlWorker() + hw = Worker(level=3, num_sub_workers=1) + hw._worker = fake + + errors = hw._broadcast_py_control([WorkerType.SUB], _CTRL_PY_UNREGISTER, 3, strict=False) + + assert errors == [] + assert fake.calls == [(WorkerType.SUB, _CTRL_PY_UNREGISTER, 3, None, _PY_CONTROL_TIMEOUT_S)] + + def test_cloudpickle_payload_roundtrip_supported_callable_shapes(self): + class AddValue: + def __init__(self, 
value): + self.value = value + + def __call__(self, arg): + return arg + self.value + + scale = 3 + + def nested(arg): + return arg * scale + + cases = [ + (lambda arg: arg + 1, 4, 5), + (nested, 4, 12), + (AddValue(7), 4, 11), + ] + for target, arg, expected in cases: + loaded = _roundtrip_py_callable_payload(target) + assert callable(loaded) + assert loaded(arg) == expected + + def test_python_unregister_child_failure_warns_pops_and_allows_reuse(self, capsys): + from simpler.worker import _CTRL_PY_REGISTER, _CTRL_PY_UNREGISTER # noqa: PLC0415 + + hw = Worker(level=3, num_sub_workers=1) + cid = hw.register(lambda args: None) + hw._initialized = True + hw._hierarchical_started = True + calls = [] + + def fake_broadcast(worker_types, sub_cmd, broadcast_cid, *, payload=None, strict): + calls.append((list(worker_types), sub_cmd, broadcast_cid, strict)) + if sub_cmd == _CTRL_PY_UNREGISTER: + return ["SUB[0]: injected unregister failure"] + if sub_cmd == _CTRL_PY_REGISTER: + return [] + raise AssertionError(f"unexpected sub_cmd={sub_cmd}") + + hw._broadcast_py_control = fake_broadcast + + hw.unregister(cid) + + captured = capsys.readouterr() + assert "Python children reported errors" in captured.err + assert "injected unregister failure" in captured.err + assert cid not in hw._callable_registry + assert cid not in hw._pending_unregister_cids + + reused = hw.register(lambda args: None) + assert reused == cid + assert calls[0] == ([WorkerType.SUB], _CTRL_PY_UNREGISTER, cid, False) + assert calls[1] == ([WorkerType.SUB], _CTRL_PY_REGISTER, cid, True) + + def test_pending_unregister_cid_is_not_reused_until_broadcast_returns(self): + from simpler.worker import _CTRL_PY_REGISTER, _CTRL_PY_UNREGISTER # noqa: PLC0415 + + hw = Worker(level=3, num_sub_workers=1) + cid = hw.register(lambda args: None) + hw._initialized = True + hw._hierarchical_started = True + + broadcast_started = threading.Event() + release_broadcast = threading.Event() + errors: list[BaseException] = [] + + 
def fake_broadcast(worker_types, sub_cmd, broadcast_cid, *, payload=None, strict): + if sub_cmd == _CTRL_PY_UNREGISTER: + broadcast_started.set() + assert release_broadcast.wait(timeout=2.0) + elif sub_cmd == _CTRL_PY_REGISTER: + return [] + else: + raise AssertionError(f"unexpected sub_cmd={sub_cmd}") + return [] + + hw._broadcast_py_control = fake_broadcast + + def do_unregister(): + try: + hw.unregister(cid) + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + t = threading.Thread(target=do_unregister) + t.start() + assert broadcast_started.wait(timeout=2.0) + + cid_during_unregister = hw.register(lambda args: None) + assert cid_during_unregister != cid + assert cid in hw._pending_unregister_cids + + release_broadcast.set() + t.join(timeout=2.0) + assert not t.is_alive() + assert errors == [] + + cid_after_unregister = hw.register(lambda args: None) + assert cid_after_unregister == cid + + def test_register_python_sub_callable_after_start_succeeds(self): + counter_shm, counter_buf = _make_shared_counter() + try: + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + + def bootstrap(orch, args, cfg): + orch.submit_sub(bootstrap_cid) + + hw.run(bootstrap) + counter_name = counter_shm.name + + def dynamic_sub(args): + shm = SharedMemory(name=counter_name) + try: + _increment_counter(shm.buf) + finally: + shm.close() + + dynamic_cid = hw.register(dynamic_sub) + + def run_dynamic(orch, args, cfg): + orch.submit_sub(dynamic_cid) + + hw.run(run_dynamic) + hw.close() + + assert _read_counter(counter_buf) == 1 + finally: + counter_shm.close() + counter_shm.unlink() + + def test_post_start_python_register_waits_for_active_sub_mailbox(self): + import time # noqa: PLC0415 + + control_shm = SharedMemory(create=True, size=8) + counter_shm, counter_buf = _make_shared_counter() + hw = Worker(level=3, num_sub_workers=1) + run_errors: list[BaseException] = [] + register_errors: list[BaseException] = [] + 
dynamic_cids: list[int] = [] + run_thread = None + register_thread = None + try: + assert control_shm.buf is not None + _set_flag(control_shm.buf, 0, 0) # started + _set_flag(control_shm.buf, 4, 0) # release + control_name = control_shm.name + counter_name = counter_shm.name + + def blocking_sub(args): + import time as child_time # noqa: PLC0415 + + shm = SharedMemory(name=control_name) + try: + _set_flag(shm.buf, 0, 1) + while _get_flag(shm.buf, 4) == 0: + child_time.sleep(0.001) + finally: + shm.close() + + blocking_cid = hw.register(blocking_sub) + hw.init() + + def run_blocking(): + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(blocking_cid)) + except BaseException as exc: # noqa: BLE001 + run_errors.append(exc) + + run_thread = threading.Thread(target=run_blocking) + run_thread.start() + + deadline = time.monotonic() + 2.0 + while _get_flag(control_shm.buf, 0) == 0 and time.monotonic() < deadline: + time.sleep(0.001) + assert _get_flag(control_shm.buf, 0) == 1 + + def dynamic_sub(args): + shm = SharedMemory(name=counter_name) + try: + _increment_counter(shm.buf) + finally: + shm.close() + + def do_register(): + try: + dynamic_cids.append(hw.register(dynamic_sub)) + except BaseException as exc: # noqa: BLE001 + register_errors.append(exc) + + register_thread = threading.Thread(target=do_register) + register_thread.start() + register_thread.join(timeout=0.05) + assert register_thread.is_alive() + + _set_flag(control_shm.buf, 4, 1) + run_thread.join(timeout=2.0) + register_thread.join(timeout=2.0) + + assert not run_thread.is_alive() + assert not register_thread.is_alive() + assert run_errors == [] + assert register_errors == [] + assert len(dynamic_cids) == 1 + + hw.run(lambda orch, args, cfg: orch.submit_sub(dynamic_cids[0])) + assert _read_counter(counter_buf) == 1 + finally: + if control_shm.buf is not None: + _set_flag(control_shm.buf, 4, 1) + if run_thread is not None: + run_thread.join(timeout=2.0) + if register_thread is not None: + 
register_thread.join(timeout=2.0) + hw.close() + control_shm.close() + control_shm.unlink() + counter_shm.close() + counter_shm.unlink() + + def test_post_start_unregister_pre_start_python_callable_removes_child_entry(self): + counter_shm, counter_buf = _make_shared_counter() + try: + hw = Worker(level=3, num_sub_workers=1) + cid = hw.register(lambda args: _increment_counter(counter_buf)) + hw.init() + + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + assert _read_counter(counter_buf) == 1 + + hw.unregister(cid) + assert cid not in hw._callable_registry + with pytest.raises(RuntimeError, match="not registered"): + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + + counter_name = counter_shm.name + + def replacement(args): + shm = SharedMemory(name=counter_name) + try: + _add_counter(shm.buf, 10) + finally: + shm.close() + + reused = hw.register(replacement) + assert reused == cid + hw.run(lambda orch, args, cfg: orch.submit_sub(reused)) + hw.close() + + assert _read_counter(counter_buf) == 11 + finally: + counter_shm.close() + counter_shm.unlink() + + def test_post_start_unregister_post_start_python_callable_removes_child_entry(self): + counter_shm, counter_buf = _make_shared_counter() + try: + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + + counter_name = counter_shm.name + + def dynamic(args): + shm = SharedMemory(name=counter_name) + try: + _increment_counter(shm.buf) + finally: + shm.close() + + cid = hw.register(dynamic) + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + assert _read_counter(counter_buf) == 1 + + hw.unregister(cid) + assert cid not in hw._callable_registry + with pytest.raises(RuntimeError, match="not registered"): + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + + reused = hw.register(dynamic) + assert reused == cid + hw.run(lambda orch, args, cfg: orch.submit_sub(reused)) + hw.close() + + assert 
_read_counter(counter_buf) == 2 + finally: + counter_shm.close() + counter_shm.unlink() + + def test_post_start_dynamic_python_callable_execute_failure_propagates(self): + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + + def boom(args): + raise RuntimeError("dynamic callable boom") + + cid = hw.register(boom) + with pytest.raises(RuntimeError, match="dynamic callable boom"): + hw.run(lambda orch, args, cfg: orch.submit_sub(cid)) + finally: + hw.close() + + def test_broadcast_control_all_accepts_memoryview_payload(self): + from simpler.worker import _CTRL_PY_REGISTER, _pack_py_callable_payload # noqa: PLC0415 + + counter_shm, counter_buf = _make_shared_counter() + try: + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + + def bootstrap(orch, args, cfg): + orch.submit_sub(bootstrap_cid) + + hw.run(bootstrap) + counter_name = counter_shm.name + + def dynamic_sub(args): + shm = SharedMemory(name=counter_name) + try: + _increment_counter(shm.buf) + finally: + shm.close() + + cid = 5 + results = hw._worker.broadcast_control_all( + WorkerType.SUB, + _CTRL_PY_REGISTER, + cid, + memoryview(_pack_py_callable_payload(dynamic_sub)), + ) + assert len(results) == 1 + assert results[0].ok + + def run_dynamic(orch, args, cfg): + orch.submit_sub(cid) + + hw.run(run_dynamic) + hw.close() + + assert _read_counter(counter_buf) == 1 + finally: + counter_shm.close() + counter_shm.unlink() + + def test_broadcast_control_all_reports_malformed_payload(self): + from simpler.worker import _CTRL_PY_REGISTER # noqa: PLC0415 + + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + results = hw._worker.broadcast_control_all(WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"bad") + assert len(results) == 1 + 
assert not results[0].ok + assert "payload" in results[0].error_message + finally: + hw.close() + + def test_broadcast_control_all_empty_payload_raises_before_fanout(self): + from simpler.worker import _CTRL_PY_REGISTER # noqa: PLC0415 + + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + with pytest.raises(RuntimeError, match="payload pointer and size"): + hw._worker.broadcast_control_all(WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"") + finally: + hw.close() + + def test_broadcast_control_all_timeout_reports_failed_child(self): + from simpler.worker import ( + _CTRL_PY_UNREGISTER, + _IDLE, + _OFF_STATE, + _buffer_field_addr, + _mailbox_addr, + ) + from simpler.worker import _mailbox_store_i32 # noqa: PLC0415 + + shm = SharedMemory(create=True, size=MAILBOX_SIZE) + dw = _Worker(3) + try: + assert shm.buf is not None + _mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE) + dw.add_sub_worker(_mailbox_addr(shm)) + dw.init() + results = dw.broadcast_control_all( + WorkerType.SUB, + _CTRL_PY_UNREGISTER, + 0, + None, + timeout_s=0.001, + ) + assert len(results) == 1 + assert not results[0].ok + assert "timed out" in results[0].error_message + finally: + dw.close() + shm.close() + shm.unlink() + + def test_broadcast_control_all_selected_pool_routing(self): + from simpler.worker import ( + _CTRL_PY_UNREGISTER, + _CONTROL_REQUEST, + _IDLE, + _OFF_STATE, + _buffer_field_addr, + _mailbox_addr, + _mailbox_load_i32, + _mailbox_store_i32, + ) + + def make_mailbox(): + shm = SharedMemory(create=True, size=MAILBOX_SIZE) + assert shm.buf is not None + _mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE) + return shm + + for selected_type, selected_kind in ( + (WorkerType.SUB, "SUB"), + (WorkerType.NEXT_LEVEL, "NEXT_LEVEL"), + ): + sub_shm = make_mailbox() + next_shm = make_mailbox() + dw = _Worker(3) + try: + 
dw.add_sub_worker(_mailbox_addr(sub_shm)) + dw.add_next_level_worker(_mailbox_addr(next_shm)) + dw.init() + results = dw.broadcast_control_all( + selected_type, + _CTRL_PY_UNREGISTER, + 0, + None, + timeout_s=0.001, + ) + assert len(results) == 1 + assert results[0].worker_type == selected_kind + sub_state = _mailbox_load_i32(_buffer_field_addr(sub_shm.buf, _OFF_STATE)) + next_state = _mailbox_load_i32(_buffer_field_addr(next_shm.buf, _OFF_STATE)) + if selected_type == WorkerType.SUB: + assert sub_state == _CONTROL_REQUEST + assert next_state == _IDLE + else: + assert sub_state == _IDLE + assert next_state == _CONTROL_REQUEST + finally: + dw.close() + sub_shm.close() + sub_shm.unlink() + next_shm.close() + next_shm.unlink() + + def test_broadcast_control_all_result_shape_for_register_and_unregister(self): + from simpler.worker import _CTRL_PY_REGISTER, _CTRL_PY_UNREGISTER # noqa: PLC0415 + + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + register_results = hw._worker.broadcast_control_all( + WorkerType.SUB, + _CTRL_PY_REGISTER, + 5, + b"bad", + ) + unregister_results = hw._worker.broadcast_control_all( + WorkerType.SUB, + _CTRL_PY_UNREGISTER, + bootstrap_cid, + None, + ) + + for result in (register_results[0], unregister_results[0]): + assert isinstance(result.worker_type, str) + assert isinstance(result.worker_index, int) + assert isinstance(result.ok, bool) + assert isinstance(result.error_message, str) + assert not register_results[0].ok + assert unregister_results[0].ok + finally: + hw.close() + + def test_nonserializable_dynamic_python_callable_does_not_consume_cid(self): + lock = threading.Lock() + hw = Worker(level=3, num_sub_workers=1) + bootstrap_cid = hw.register(lambda args: None) + hw.init() + try: + hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + before = dict(hw._callable_registry) + + def 
captures_lock(args): + lock.acquire(False) + + with pytest.raises(TypeError, match="lock"): + hw.register(captures_lock) + assert hw._callable_registry == before + finally: + hw.close() + + def test_chip_register_reuse_clears_seen_python_cid_before_binary_register(self): + from simpler.worker import _CTRL_PY_UNREGISTER # noqa: PLC0415 + + calls = [] + + class FakeWorker: + def broadcast_register_all(self, cid, blob_ptr, blob_size): + calls.append(("binary_register", cid, blob_size)) + + hw = Worker(level=3, num_sub_workers=1) + hw._initialized = True + hw._hierarchical_started = True + hw._worker = FakeWorker() + hw._py_callable_cids_seen.add(0) + + def fake_py_control(worker_types, sub_cmd, cid, *, payload=None, strict): + calls.append(("py_clear", list(worker_types), sub_cmd, cid, strict)) + return [] + + hw._broadcast_py_control = fake_py_control + callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) + + cid = hw.register(callable_obj) + + assert cid == 0 + assert calls[0] == ("py_clear", [WorkerType.SUB], _CTRL_PY_UNREGISTER, 0, True) + assert calls[1][0] == "binary_register" + + def test_chip_register_reuse_fails_before_binary_register_when_python_clear_fails(self): + calls = [] + + class FakeWorker: + def broadcast_register_all(self, cid, blob_ptr, blob_size): + calls.append(("binary_register", cid)) + + hw = Worker(level=3, num_sub_workers=1) + hw._initialized = True + hw._hierarchical_started = True + hw._worker = FakeWorker() + hw._py_callable_cids_seen.add(0) + + def fake_py_control(worker_types, sub_cmd, cid, *, payload=None, strict): + calls.append(("py_clear", cid, strict)) + raise RuntimeError("clear failed") + + hw._broadcast_py_control = fake_py_control + callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) + + with pytest.raises(RuntimeError, match="clear failed"): + hw.register(callable_obj) + + assert calls == [("py_clear", 0, True)] + assert hw._callable_registry == {} 
+ def test_unregister_middle_cid_reuses_hole(self): # `_allocate_cid` must fill the smallest hole, not append at # len(registry). The bug it guards against: register 0/1/2, diff --git a/tests/ut/py/test_worker/test_l4_recursive.py b/tests/ut/py/test_worker/test_l4_recursive.py index d52b019b6..4f1ed290c 100644 --- a/tests/ut/py/test_worker/test_l4_recursive.py +++ b/tests/ut/py/test_worker/test_l4_recursive.py @@ -104,6 +104,17 @@ def test_add_initialized_child_raises(self): child.close() w4.close() + def test_l4_device_ids_rejected(self): + w4 = Worker(level=4, device_ids=[0], num_sub_workers=0) + with pytest.raises(RuntimeError, match="device_ids are only supported on L3"): + w4.init() + + def test_add_worker_with_device_ids_rejected(self): + w4 = Worker(level=4, device_ids=[0], num_sub_workers=0) + child = Worker(level=3, num_sub_workers=0) + with pytest.raises(RuntimeError, match="cannot be combined with device_ids"): + w4.add_worker(child) + def test_malloc_on_l4_raises_index_error(self): # L4 has no chip mailboxes — `Worker.malloc` must surface IndexError # rather than silently dispatch CTRL_MALLOC to a next_level (L3 worker) @@ -175,6 +186,38 @@ def test_l4_register_then_unregister_recycles_cid(self): finally: w4.close() + def test_l4_register_python_orch_after_start_succeeds(self): + counter_shm, counter_buf = _make_shared_counter() + try: + l3 = Worker(level=3, num_sub_workers=1) + l3_sub_cid = l3.register(lambda args: _increment_counter(counter_buf)) + + w4 = Worker(level=4, num_sub_workers=0) + bootstrap_cid = w4.register(lambda orch, args, config: None) + w4.add_worker(l3) + w4.init() + + def bootstrap(orch, args, config): + orch.submit_next_level(bootstrap_cid, TaskArgs(), CallConfig()) + + w4.run(bootstrap) + + def dynamic_l3_orch(orch, args, config): + orch.submit_sub(l3_sub_cid) + + dynamic_cid = w4.register(dynamic_l3_orch) + + def l4_orch(orch, args, config): + orch.submit_next_level(dynamic_cid, TaskArgs(), CallConfig()) + + w4.run(l4_orch) + 
w4.close() + + assert _read_counter(counter_buf) == 1 + finally: + counter_shm.close() + counter_shm.unlink() + # --------------------------------------------------------------------------- # Test: L4 → L3 PROCESS mode — single dispatch From 1ac74636a99ea5d5221bd5c52e58e6421c46abfc Mon Sep 17 00:00:00 2001 From: puddingfjz <2811443837@qq.com> Date: Fri, 22 May 2026 18:01:27 +0800 Subject: [PATCH 4/6] Fix Python callable shm payload validation --- docs/python-callable-serialization.md | 7 +- python/simpler/worker.py | 24 +++---- tests/ut/py/test_worker/test_host_worker.py | 72 +++++++++------------ 3 files changed, 46 insertions(+), 57 deletions(-) diff --git a/docs/python-callable-serialization.md b/docs/python-callable-serialization.md index aaeca5cc2..8d4aed314 100644 --- a/docs/python-callable-serialization.md +++ b/docs/python-callable-serialization.md @@ -216,8 +216,11 @@ exact bytes returned by `cloudpickle.dumps(target)`: The first implementation accepts only `(magic="SPYC", version=1, serializer=1, flags=0)`. Unknown magic, version, serializer, non-zero flags, -size mismatch, malformed bytes, or incompatible pickle data fail through the -normal mailbox error field. +payload size larger than the mapped shm, malformed bytes, or incompatible +pickle data fail through the normal mailbox error field. The child treats +`payload_size` as the authoritative byte count and ignores any trailing bytes +in the shm object, because some platforms expose POSIX shm at a page-rounded +size even when the parent requested the exact payload length. 
### Child Deserialization diff --git a/python/simpler/worker.py b/python/simpler/worker.py index 06a29f803..82cb0b933 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -66,7 +66,6 @@ def my_l4_orch(orch, args, config): from typing import Any, Optional import cloudpickle - from _task_interface import ( # pyright: ignore[reportMissingImports] MAX_REGISTERED_CALLABLE_IDS, RunTiming, @@ -218,13 +217,16 @@ def my_l4_orch(orch, args, config): def _pack_py_callable_payload(target) -> bytes: payload = cloudpickle.dumps(target) - return _PY_CALLABLE_HEADER.pack( - _PY_CALLABLE_MAGIC, - _PY_CALLABLE_VERSION, - _PY_CALLABLE_SERIALIZER_CLOUDPICKLE, - 0, - len(payload), - ) + payload + return ( + _PY_CALLABLE_HEADER.pack( + _PY_CALLABLE_MAGIC, + _PY_CALLABLE_VERSION, + _PY_CALLABLE_SERIALIZER_CLOUDPICKLE, + 0, + len(payload), + ) + + payload + ) def _load_py_callable_from_shm(shm_name: str): @@ -244,7 +246,7 @@ def _load_py_callable_from_shm(shm_name: str): if flags != 0: raise RuntimeError(f"unsupported python callable payload flags: {flags}") expected_size = _PY_CALLABLE_HEADER.size + int(payload_size) - if expected_size != shm.size: + if expected_size > shm.size: raise RuntimeError(f"python callable payload size mismatch: header={payload_size}, shm={shm.size}") payload = bytes(shm_buf[_PY_CALLABLE_HEADER.size : expected_size]) finally: @@ -858,9 +860,7 @@ def __init__( self._registry_lock = threading.Lock() self._pending_unregister_cids: set[int] = set() self._py_callable_cids_seen: set[int] = set() - self._py_control_timeout_s = float( - config.get("py_control_timeout_s", _PY_CONTROL_TIMEOUT_S) - ) + self._py_control_timeout_s = float(config.get("py_control_timeout_s", _PY_CONTROL_TIMEOUT_S)) self._hierarchical_start_state = "not_started" self._hierarchical_start_mu = threading.Lock() self._hierarchical_start_cv = threading.Condition(self._hierarchical_start_mu) diff --git a/tests/ut/py/test_worker/test_host_worker.py 
b/tests/ut/py/test_worker/test_host_worker.py index 1e13d54b9..07c28ebaf 100644 --- a/tests/ut/py/test_worker/test_host_worker.py +++ b/tests/ut/py/test_worker/test_host_worker.py @@ -27,7 +27,19 @@ WorkerType, _Worker, ) -from simpler.worker import Worker +from simpler.worker import ( + _CONTROL_REQUEST, + _CTRL_PY_REGISTER, + _CTRL_PY_UNREGISTER, + _IDLE, + _OFF_STATE, + Worker, + _buffer_field_addr, + _mailbox_addr, + _mailbox_load_i32, + _mailbox_store_i32, + _pack_py_callable_payload, +) # --------------------------------------------------------------------------- # Helpers @@ -203,9 +215,7 @@ def test_register_chip_callable_after_init_no_chips_succeeds(self): hw = Worker(level=3, num_sub_workers=0) hw.init() try: - callable_obj = ChipCallable.build( - signature=[], func_name="x", binary=b"\x00", children=[] - ) + callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) cid = hw.register(callable_obj) assert isinstance(cid, int) assert cid >= 0 @@ -262,9 +272,7 @@ def test_register_chip_callable_broadcast_runs_without_registry_lock(self): hw = Worker(level=3, num_sub_workers=0) hw._initialized = True hw._hierarchical_started = True - callable_obj = ChipCallable.build( - signature=[], func_name="x", binary=b"\x00", children=[] - ) + callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) observed = {} def fake_post_init_register(cid, target): @@ -282,9 +290,7 @@ def fake_post_init_register(cid, target): def test_register_at_broadcast_runs_without_registry_lock(self): hw = Worker(level=3, num_sub_workers=0) hw._initialized = True - callable_obj = ChipCallable.build( - signature=[], func_name="x", binary=b"\x00", children=[] - ) + callable_obj = ChipCallable.build(signature=[], func_name="x", binary=b"\x00", children=[]) observed = {} def fake_post_init_register(cid, target): @@ -629,8 +635,6 @@ def boom(args): hw.close() def test_broadcast_control_all_accepts_memoryview_payload(self): - from 
simpler.worker import _CTRL_PY_REGISTER, _pack_py_callable_payload # noqa: PLC0415 - counter_shm, counter_buf = _make_shared_counter() try: hw = Worker(level=3, num_sub_workers=1) @@ -651,7 +655,9 @@ def dynamic_sub(args): shm.close() cid = 5 - results = hw._worker.broadcast_control_all( + worker_impl = hw._worker + assert worker_impl is not None + results = worker_impl.broadcast_control_all( WorkerType.SUB, _CTRL_PY_REGISTER, cid, @@ -672,14 +678,14 @@ def run_dynamic(orch, args, cfg): counter_shm.unlink() def test_broadcast_control_all_reports_malformed_payload(self): - from simpler.worker import _CTRL_PY_REGISTER # noqa: PLC0415 - hw = Worker(level=3, num_sub_workers=1) bootstrap_cid = hw.register(lambda args: None) hw.init() try: hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) - results = hw._worker.broadcast_control_all(WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"bad") + worker_impl = hw._worker + assert worker_impl is not None + results = worker_impl.broadcast_control_all(WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"bad") assert len(results) == 1 assert not results[0].ok assert "payload" in results[0].error_message @@ -687,28 +693,19 @@ def test_broadcast_control_all_reports_malformed_payload(self): hw.close() def test_broadcast_control_all_empty_payload_raises_before_fanout(self): - from simpler.worker import _CTRL_PY_REGISTER # noqa: PLC0415 - hw = Worker(level=3, num_sub_workers=1) bootstrap_cid = hw.register(lambda args: None) hw.init() try: hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) + worker_impl = hw._worker + assert worker_impl is not None with pytest.raises(RuntimeError, match="payload pointer and size"): - hw._worker.broadcast_control_all(WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"") + worker_impl.broadcast_control_all(WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"") finally: hw.close() def test_broadcast_control_all_timeout_reports_failed_child(self): - from simpler.worker import ( - _CTRL_PY_UNREGISTER, - _IDLE, - _OFF_STATE, - 
_buffer_field_addr, - _mailbox_addr, - ) - from simpler.worker import _mailbox_store_i32 # noqa: PLC0415 - shm = SharedMemory(create=True, size=MAILBOX_SIZE) dw = _Worker(3) try: @@ -732,17 +729,6 @@ def test_broadcast_control_all_timeout_reports_failed_child(self): shm.unlink() def test_broadcast_control_all_selected_pool_routing(self): - from simpler.worker import ( - _CTRL_PY_UNREGISTER, - _CONTROL_REQUEST, - _IDLE, - _OFF_STATE, - _buffer_field_addr, - _mailbox_addr, - _mailbox_load_i32, - _mailbox_store_i32, - ) - def make_mailbox(): shm = SharedMemory(create=True, size=MAILBOX_SIZE) assert shm.buf is not None @@ -785,20 +771,20 @@ def make_mailbox(): next_shm.unlink() def test_broadcast_control_all_result_shape_for_register_and_unregister(self): - from simpler.worker import _CTRL_PY_REGISTER, _CTRL_PY_UNREGISTER # noqa: PLC0415 - hw = Worker(level=3, num_sub_workers=1) bootstrap_cid = hw.register(lambda args: None) hw.init() try: hw.run(lambda orch, args, cfg: orch.submit_sub(bootstrap_cid)) - register_results = hw._worker.broadcast_control_all( + worker_impl = hw._worker + assert worker_impl is not None + register_results = worker_impl.broadcast_control_all( WorkerType.SUB, _CTRL_PY_REGISTER, 5, b"bad", ) - unregister_results = hw._worker.broadcast_control_all( + unregister_results = worker_impl.broadcast_control_all( WorkerType.SUB, _CTRL_PY_UNREGISTER, bootstrap_cid, From be65253bb45ce4aa9f2f1cb8fcb16a9e1c1d9324 Mon Sep 17 00:00:00 2001 From: puddingfjz <2811443837@qq.com> Date: Tue, 26 May 2026 15:09:19 +0800 Subject: [PATCH 5/6] Fix: close callable startup registration race - Gate pre-start register/unregister with hierarchical startup snapshot - Drop redundant Python control payload alias and clear seen Python cids after successful residue cleanup - Add regression coverage for startup/register interleaving and repeated cid reuse --- docs/python-callable-serialization.md | 12 +-- python/simpler/worker.py | 82 ++++++++++++++------- 
tests/ut/py/test_worker/test_host_worker.py | 78 ++++++++++++++++++++ 3 files changed, 139 insertions(+), 33 deletions(-) diff --git a/docs/python-callable-serialization.md b/docs/python-callable-serialization.md index 8d4aed314..186357a92 100644 --- a/docs/python-callable-serialization.md +++ b/docs/python-callable-serialization.md @@ -318,11 +318,13 @@ The startup race is handled by a one-time hierarchical startup state, not by a run-wide quiescent guard: - `_hierarchical_start_state` is protected by a dedicated - `_hierarchical_start_mu` / `_hierarchical_start_cv`, separate from - `_registry_lock`. -- Startup begins as `not_started`, moves to `starting` before - `_start_hierarchical()` takes the registry snapshot, and moves to `started` - only after child mailboxes are registered with the C++ Worker. + `_hierarchical_start_mu` / `_hierarchical_start_cv`; `_registry_lock` + protects only the registry contents. +- Startup begins as `not_started`. `_start_hierarchical()` holds + `_hierarchical_start_cv`, then `_registry_lock`, while it moves to + `starting` and takes the registry snapshot. It releases `_registry_lock` + before any `os.fork()`, and moves to `started` only after child mailboxes + are registered with the C++ Worker. - A Python callable register/unregister that observes `starting` waits on a condition variable without holding `_registry_lock`. After startup succeeds, it uses the post-start control path; after startup fails, it raises. diff --git a/python/simpler/worker.py b/python/simpler/worker.py index 82cb0b933..70fc9a089 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -934,16 +934,33 @@ def register(self, target) -> int: no COW constraint). When called post-init, ChipCallables are prepared on the device immediately; pre-init registrations are batched and prepared at the end of ``init()``. 
+ + See docs/python-callable-serialization.md for the Python dynamic + register path and docs/callable-ipc-dynamic-register.md for the + ChipCallable binary path. """ if self.level == 2 and not isinstance(target, ChipCallable): raise TypeError("Worker.register: level 2 only supports ChipCallable targets") if self.level >= 3: - self._wait_hierarchical_start_if_needed() if not isinstance(target, ChipCallable): if not callable(target): raise TypeError("Worker.register: non-ChipCallable target must be callable") - if self._initialized and getattr(self, "_hierarchical_started", False): - return self._post_start_register_python(target) + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + if self._hierarchical_start_state != "started" and not getattr( + self, "_hierarchical_started", False + ): + with self._registry_lock: + cid = self._allocate_cid() + self._callable_registry[cid] = target + if not isinstance(target, ChipCallable): + self._py_callable_cids_seen.add(cid) + return cid + if not isinstance(target, ChipCallable): + return self._post_start_register_python(target) with self._registry_lock: cid = self._allocate_cid() @@ -972,15 +989,6 @@ def register(self, target) -> int: self._chip_worker.prepare_callable(cid, target) return cid - def _wait_hierarchical_start_if_needed(self) -> None: - if self.level < 3: - return - with self._hierarchical_start_cv: - while self._hierarchical_start_state == "starting": - self._hierarchical_start_cv.wait() - if self._hierarchical_start_state == "failed": - raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") - def _python_worker_types(self) -> list[WorkerType]: worker_types: list[WorkerType] = [] if self._config.get("num_sub_workers", 0) > 0: @@ -1023,13 +1031,12 @@ 
def _broadcast_py_control( return [] assert self._worker is not None errors: list[str] = [] - payload_bytes = payload if payload is not None else None for worker_type in worker_types: results = self._worker.broadcast_control_all( worker_type, int(sub_cmd), int(cid), - payload_bytes, + payload, timeout_s=self._py_control_timeout_s, ) for result in results: @@ -1108,6 +1115,7 @@ def _post_init_register(self, cid: int, target: ChipCallable) -> None: assert self._worker is not None if cid in self._py_callable_cids_seen: self._broadcast_py_control(self._python_worker_types(), _CTRL_PY_UNREGISTER, cid, strict=True) + self._py_callable_cids_seen.discard(cid) self._worker.broadcast_register_all(int(cid), int(target.buffer_ptr()), int(target.buffer_size())) def unregister(self, cid: int) -> None: @@ -1128,7 +1136,24 @@ def unregister(self, cid: int) -> None: Raises: KeyError: cid was never registered. """ - self._wait_hierarchical_start_if_needed() + if self.level >= 3: + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + if self._hierarchical_start_state != "started" and not getattr( + self, "_hierarchical_started", False + ): + with self._registry_lock: + if cid not in self._callable_registry: + raise KeyError(f"Worker.unregister: cid={cid} not registered") + if cid in self._pending_unregister_cids: + raise KeyError(f"Worker.unregister: cid={cid} already pending unregister") + target = self._callable_registry.pop(cid) + if not isinstance(target, ChipCallable): + self._py_callable_cids_seen.discard(cid) + return target = None with self._registry_lock: if cid not in self._callable_registry: @@ -1294,23 +1319,24 @@ def _init_hierarchical(self) -> None: def _start_hierarchical(self) -> None: # noqa: PLR0912 -- three parallel fork loops (sub/chip/next) + 
bootstrap wait + scheduler register/init; branches track the fork order documented in the body """Fork child processes and start C++ scheduler. Called on first run().""" - with self._hierarchical_start_cv: - while self._hierarchical_start_state == "starting": - self._hierarchical_start_cv.wait() - if self._hierarchical_start_state == "started": - return - if self._hierarchical_start_state == "failed": - raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") - self._hierarchical_start_state = "starting" - device_ids = self._config.get("device_ids", []) n_sub = self._config.get("num_sub_workers", 0) try: - # Fork children from an immutable snapshot. Dynamic register callers - # that race this startup wait and then use the post-start path. - with self._registry_lock: - registry = dict(self._callable_registry) + # Fork children from an immutable snapshot. The state transition + # and snapshot are one gate, so dynamic register/unregister callers + # cannot return through the pre-start path after this point. 
+ with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "started": + return + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + self._hierarchical_start_state = "starting" + with self._registry_lock: + registry = dict(self._callable_registry) + self._hierarchical_start_cv.notify_all() # Fork SubWorker processes (MUST be before any C++ threads) for i in range(n_sub): diff --git a/tests/ut/py/test_worker/test_host_worker.py b/tests/ut/py/test_worker/test_host_worker.py index 07c28ebaf..51db62f19 100644 --- a/tests/ut/py/test_worker/test_host_worker.py +++ b/tests/ut/py/test_worker/test_host_worker.py @@ -207,6 +207,74 @@ def do_register(): hw._hierarchical_start_cv.wait = original_wait hw.close() + def test_register_blocks_startup_snapshot_from_not_started_window(self): + hw = Worker(level=3, num_sub_workers=0) + hw.init() + + real_registry_lock = hw._registry_lock + register_waiting = threading.Event() + release_register = threading.Event() + startup_snapshot_attempted = threading.Event() + result: list[int] = [] + errors: list[BaseException] = [] + + class BlockingRegistryLock: + def __enter__(self): + thread_name = threading.current_thread().name + if thread_name == "register-thread": + register_waiting.set() + if not release_register.wait(timeout=2.0): + raise TimeoutError("test timed out waiting to release register") + elif thread_name == "startup-thread": + startup_snapshot_attempted.set() + return real_registry_lock.__enter__() + + def __exit__(self, exc_type, exc, tb): + return real_registry_lock.__exit__(exc_type, exc, tb) + + def locked(self): + return real_registry_lock.locked() + + hw._registry_lock = BlockingRegistryLock() + + def do_register(): + try: + result.append(hw.register(lambda args: None)) + except BaseException as exc: # noqa: BLE001 + 
errors.append(exc) + + def do_startup(): + try: + hw._start_hierarchical() + except BaseException as exc: # noqa: BLE001 + errors.append(exc) + + register_thread = threading.Thread(target=do_register, name="register-thread") + startup_thread = threading.Thread(target=do_startup, name="startup-thread") + try: + register_thread.start() + assert register_waiting.wait(timeout=2.0) + + startup_thread.start() + assert not startup_snapshot_attempted.wait(timeout=0.2) + + release_register.set() + register_thread.join(timeout=2.0) + startup_thread.join(timeout=2.0) + + assert not register_thread.is_alive() + assert not startup_thread.is_alive() + assert errors == [] + assert result == [0] + assert startup_snapshot_attempted.is_set() + assert hw._hierarchical_start_state == "started" + finally: + release_register.set() + register_thread.join(timeout=2.0) + startup_thread.join(timeout=2.0) + hw._registry_lock = real_registry_lock + hw.close() + def test_register_chip_callable_after_init_no_chips_succeeds(self): # With no chip children (device_ids unset), the C++ broadcast is a # no-op (next_level_threads_ is empty) — exercises the facade path @@ -846,6 +914,16 @@ def fake_py_control(worker_types, sub_cmd, cid, *, payload=None, strict): assert cid == 0 assert calls[0] == ("py_clear", [WorkerType.SUB], _CTRL_PY_UNREGISTER, 0, True) assert calls[1][0] == "binary_register" + assert 0 not in hw._py_callable_cids_seen + + hw._callable_registry.pop(0) + calls.clear() + + cid = hw.register(ChipCallable.build(signature=[], func_name="y", binary=b"\x00", children=[])) + + assert cid == 0 + assert len(calls) == 1 + assert calls[0][0:2] == ("binary_register", 0) def test_chip_register_reuse_fails_before_binary_register_when_python_clear_fails(self): calls = [] From 825f0fd4a053e6864424ac537f206ad6fd9176b5 Mon Sep 17 00:00:00 2001 From: puddingfjz <2811443837@qq.com> Date: Tue, 26 May 2026 15:38:05 +0800 Subject: [PATCH 6/6] Fix: satisfy worker unregister lint Extract the pre-start 
unregister gate into a helper so ruff's branch-count check passes while preserving the startup snapshot locking semantics. --- python/simpler/worker.py | 44 +++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/python/simpler/worker.py b/python/simpler/worker.py index 70fc9a089..e4956e708 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -950,9 +950,7 @@ def register(self, target) -> int: self._hierarchical_start_cv.wait() if self._hierarchical_start_state == "failed": raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") - if self._hierarchical_start_state != "started" and not getattr( - self, "_hierarchical_started", False - ): + if self._hierarchical_start_state != "started" and not getattr(self, "_hierarchical_started", False): with self._registry_lock: cid = self._allocate_cid() self._callable_registry[cid] = target @@ -1118,6 +1116,26 @@ def _post_init_register(self, cid: int, target: ChipCallable) -> None: self._py_callable_cids_seen.discard(cid) self._worker.broadcast_register_all(int(cid), int(target.buffer_ptr()), int(target.buffer_size())) + def _pre_start_unregister_if_needed(self, cid: int) -> bool: + if self.level < 3: + return False + with self._hierarchical_start_cv: + while self._hierarchical_start_state == "starting": + self._hierarchical_start_cv.wait() + if self._hierarchical_start_state == "failed": + raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") + if self._hierarchical_start_state == "started" or getattr(self, "_hierarchical_started", False): + return False + with self._registry_lock: + if cid not in self._callable_registry: + raise KeyError(f"Worker.unregister: cid={cid} not registered") + if cid in self._pending_unregister_cids: + raise KeyError(f"Worker.unregister: cid={cid} already pending unregister") + target = self._callable_registry.pop(cid) + if not isinstance(target, 
ChipCallable): + self._py_callable_cids_seen.discard(cid) + return True + def unregister(self, cid: int) -> None: """Drop *cid* from the registry and propagate to chip children. @@ -1136,24 +1154,8 @@ def unregister(self, cid: int) -> None: Raises: KeyError: cid was never registered. """ - if self.level >= 3: - with self._hierarchical_start_cv: - while self._hierarchical_start_state == "starting": - self._hierarchical_start_cv.wait() - if self._hierarchical_start_state == "failed": - raise RuntimeError("Worker hierarchical startup failed; close this Worker and create a new one") - if self._hierarchical_start_state != "started" and not getattr( - self, "_hierarchical_started", False - ): - with self._registry_lock: - if cid not in self._callable_registry: - raise KeyError(f"Worker.unregister: cid={cid} not registered") - if cid in self._pending_unregister_cids: - raise KeyError(f"Worker.unregister: cid={cid} already pending unregister") - target = self._callable_registry.pop(cid) - if not isinstance(target, ChipCallable): - self._py_callable_cids_seen.discard(cid) - return + if self._pre_start_unregister_if_needed(cid): + return target = None with self._registry_lock: if cid not in self._callable_registry: