Summary
On the GB10 (sm_121, CUDA 13), reusing arena memory after ResetPool() makes the next forward+backward pass compute a wrong gradient that collapses toward 0. Save-for-backward (Saver-pinned) activations written by a step's forward are overwritten by that same step's later forward ops when those ops land on reused arena addresses. This is the same overwrite-on-reuse class as zerfoo#842, which was partially worked around for specific nodes by recomputing from live inputs, but here it affects the general training loop whenever ResetPool runs between steps.
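To make the failure class concrete, here is a minimal toy model of overwrite-on-reuse. It is not zerfoo's allocator: `Arena`, `Alloc`, `Reset`, and `Demo` are hypothetical names, and the sketch only shows how a slice that is still referenced can alias a later allocation once the bump pointer is rewound. It does not reproduce the exact trigger seen on the GB10.

```go
package main

import "fmt"

// Arena is a toy bump allocator over one float32 backing buffer.
// Hypothetical stand-in for illustration, not zerfoo's arena.
type Arena struct {
	buf []float32
	off int
}

func NewArena(n int) *Arena { return &Arena{buf: make([]float32, n)} }

// Alloc hands out the next n elements of the backing buffer.
func (a *Arena) Alloc(n int) []float32 {
	if a.off+n > len(a.buf) {
		panic("arena exhausted")
	}
	s := a.buf[a.off : a.off+n : a.off+n]
	a.off += n
	return s
}

// Reset rewinds the bump pointer. Slices handed out earlier remain
// reachable and now alias whatever is allocated next.
func (a *Arena) Reset() { a.off = 0 }

// Demo saves an activation, resets, lets a later op write scratch,
// and returns the saved value as a backward pass would read it.
func Demo() float32 {
	a := NewArena(8)
	saved := a.Alloc(4) // "pinned" for backward, but only by convention
	for i := range saved {
		saved[i] = 1.5
	}
	a.Reset()
	scratch := a.Alloc(4) // same addresses as saved
	for i := range scratch {
		scratch[i] = 0
	}
	return saved[0]
}

func main() { fmt.Println(Demo()) }
```

In this toy, the value read back is the scratch value rather than the saved activation, which is the same shape as a backward pass that sees zeros where it expected activations.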
Repro / evidence (Wolf CrossAsset, f32, batched B=256, qk-norm)
Wolf's batched training loop called ResetPool() once per mini-batch (to reclaim per-step scratch, issue-#118 style). Per-batch L2 norm of the head-weight gradient, GPU vs CPU at identical seeded weights:
| batch | CPU grad | GPU grad |
| --- | --- | --- |
| 0 | 4.315349 | 4.315349 (identical) |
| 1 | 3.737152 | 2.5e-06 |
| 2 | 2.698082 | 1.98e-18 |
| 3 | 1.859067 | 0 |
- Batch 0 (no preceding ResetPool) is bit-identical CPU/GPU — forward, backward, and on-device AdamW step all correct.
- Every batch whose forward runs after a ResetPool produces a near-zero gradient, even though the forward LOSS is correct. So the forward is fine; the backward reads corrupted (overwritten) saved tensors.
- Disabling ResetPool entirely → GPU trains bit-identical to CPU (loss 0.747481, acc 0.6804). Confirms the corruption is the reset+reuse, not the kernels.
- It is NOT the optimizer state: re-marking the arena floor after the first step (so AdamW m/v sit below the floor) had zero effect; the gradient itself is wrong before opt.Step.
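The per-batch comparison above can be sketched as a small probe. This is an illustration, not the env-gated probe branch mentioned in this report: `L2`, `RelDiff`, and the tolerance `tol` are hypothetical, and the norms in `main` are copied from the table.

```go
package main

import (
	"fmt"
	"math"
)

// L2 returns the Euclidean norm of a gradient, accumulated in float64.
func L2(g []float32) float64 {
	var sum float64
	for _, v := range g {
		sum += float64(v) * float64(v)
	}
	return math.Sqrt(sum)
}

// RelDiff returns |cpu-gpu| / |cpu|, or |gpu| when the CPU norm is zero.
func RelDiff(cpu, gpu float64) float64 {
	if cpu == 0 {
		return math.Abs(gpu)
	}
	return math.Abs(cpu-gpu) / math.Abs(cpu)
}

func main() {
	// Norms copied from the table; tol is an arbitrary illustrative threshold.
	cpu := []float64{4.315349, 3.737152, 2.698082, 1.859067}
	gpu := []float64{4.315349, 2.5e-06, 1.98e-18, 0}
	const tol = 1e-5
	for i := range cpu {
		d := RelDiff(cpu[i], gpu[i])
		fmt.Printf("batch %d: rel diff %.3g diverged=%v\n", i, d, d > tol)
	}
}
```

A relative difference near 1 means the GPU gradient norm has collapsed to a negligible fraction of the CPU norm, which is the pattern for every batch after the first reset.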
Expected
After ResetPool(), a fresh forward+backward must compute correct gradients. Reused arena memory must not overlap tensors still pinned by the Saver for the in-flight backward.
Likely fix area
Arena allocator / save-for-backward lifetime: either (a) the Saver should pin its tensors in non-resettable storage, or (b) the allocator must not hand out regions that alias live Saver-pinned tensors within a forward/backward pass.
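One possible shape for option (b), as a toy sketch only: the arena records pinned regions and a reset never rewinds below the end of the highest one. `PinArena`, `Pin`, `Unpin`, and `PinnedDemo` are hypothetical names and do not reflect zerfoo's API or its actual Saver.

```go
package main

import "fmt"

// PinArena is a toy bump allocator that remembers pinned regions.
// Hypothetical design for illustration, not zerfoo's arena.
type PinArena struct {
	buf    []float32
	off    int
	pinned map[int]int // start offset -> end offset of each pinned region
}

func NewPinArena(n int) *PinArena {
	return &PinArena{buf: make([]float32, n), pinned: map[int]int{}}
}

// Alloc returns the start offset (used as a handle) and the slice.
func (a *PinArena) Alloc(n int) (int, []float32) {
	if a.off+n > len(a.buf) {
		panic("arena exhausted")
	}
	start := a.off
	a.off += n
	return start, a.buf[start:a.off:a.off]
}

// Pin marks [start, start+n) as live until Unpin is called.
func (a *PinArena) Pin(start, n int) { a.pinned[start] = start + n }

// Unpin releases a region, e.g. once backward has consumed it.
func (a *PinArena) Unpin(start int) { delete(a.pinned, start) }

// Reset rewinds the bump pointer, but never below the end of a pinned region.
func (a *PinArena) Reset() {
	floor := 0
	for _, end := range a.pinned {
		if end > floor {
			floor = end
		}
	}
	a.off = floor
}

// PinnedDemo runs a save/reset/scratch sequence with the pin in place.
func PinnedDemo() float32 {
	a := NewPinArena(8)
	h, saved := a.Alloc(4)
	a.Pin(h, 4)
	for i := range saved {
		saved[i] = 1.5
	}
	a.Reset() // off stays at 4 because [0,4) is pinned
	_, scratch := a.Alloc(4)
	for i := range scratch {
		scratch[i] = 0
	}
	v := saved[0] // backward reads an intact value
	a.Unpin(h)
	return v
}

func main() { fmt.Println(PinnedDemo()) }
```

The trade-off in this sketch is fragmentation: any unpinned space below the highest pinned region is not reclaimed until that pin is released. Option (a), a separate non-resettable store for saved tensors, avoids that at the cost of a second allocation path.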
Wolf interim
Hoisted ResetPool to once-per-epoch (feza-ai/wolf#218) — works for the capture-replay (production) path because replay uses a frozen recorded layout. The eager path (capture disabled) still diverges over epochs, which this upstream fix should resolve.
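The hoist amounts to moving the reset out of the inner loop. The sketch below is not the code from feza-ai/wolf#218: `Pool`, `countingPool`, and `TrainEpochs` are hypothetical, and placing the reset at the start of each epoch rather than the end is an assumption.

```go
package main

import "fmt"

// Pool is a hypothetical one-method view of whatever owns ResetPool.
type Pool interface{ ResetPool() }

// countingPool records how often it was reset; it allocates nothing.
type countingPool struct{ resets int }

func (p *countingPool) ResetPool() { p.resets++ }

// TrainEpochs resets once per epoch, never between mini-batches.
// Resetting at the start of the epoch is an assumption of this sketch.
func TrainEpochs(p Pool, epochs, batches int, step func(epoch, batch int)) {
	for e := 0; e < epochs; e++ {
		p.ResetPool()
		for b := 0; b < batches; b++ {
			step(e, b) // forward + backward + optimizer step
		}
	}
}

func main() {
	p := &countingPool{}
	steps := 0
	TrainEpochs(p, 3, 4, func(int, int) { steps++ })
	fmt.Println(p.resets, steps)
}
```

This only reduces how often a reset precedes a forward pass; it does not remove the aliasing itself, which is consistent with the eager path still diverging across epochs.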
Platform: linux/arm64, NVIDIA GB10, CUDA 13 / sm_121. Models built from the existing Containerfile. Happy to provide the env-gated probe branch and full logs.