Add docs_per_step for dynamic microbatch accumulation#520

Open
jlamypoirier wants to merge 1 commit into jlp_fp32_lm_head from jlp_rl_features
@jlamypoirier (Collaborator) commented May 19, 2026

Stacked on top of #526 — base is jlp_fp32_lm_head. Diff view shows only the docs_per_step delta. Rebase to main once #526 lands.

Summary

A schedule config field that replaces the static microbatch count with a runtime document-count target. Each step accumulates microbatches one at a time, all-reduces the per-microbatch document count, and stops once the global cumulative total reaches the target. Matches DeepSpeed's gradient_accumulation_passes semantics for RL: each microbatch holds one rollout and the step boundary is set by total rollouts rather than a fixed microbatch count.

  • ScheduleConfig.docs_per_step — when > 0, Trainer._prefetch_to_doc_target drives the dynamic accumulation. The final step total is broadcast back to every microbatch so the loss normalization denominator stays consistent.
  • Trainer._get_or_build_schedule(N) builds and caches a per-N Schedule with _depth_first_override = N // breadth_first_micro_batches, reusing existing schedule machinery without touching the runner.
  • Schedule._eff_{depth_first,sequential_micro_batches,num_inputs} expose the effective values under an override.
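The per-N caching and override described in the last two bullets can be sketched roughly as follows. This is an illustrative stand-in, not the code in this PR: `ToySchedule` and `ToyScheduleCache` are hypothetical names, and the relation `sequential = breadth_first * depth_first` is an assumption made for the example. How a microbatch count that is not a multiple of `breadth_first_micro_batches` is handled is not covered by the description, so the sketch simply applies the floor division as written.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySchedule:
    """Hypothetical stand-in for a Schedule built with a depth-first override."""

    breadth_first_micro_batches: int
    depth_first_override: int

    @property
    def eff_depth_first(self) -> int:
        # Effective depth-first count under the override.
        return self.depth_first_override

    @property
    def eff_sequential_micro_batches(self) -> int:
        # Assumption for this sketch: sequential = breadth-first * depth-first.
        return self.breadth_first_micro_batches * self.depth_first_override


class ToyScheduleCache:
    """Builds one schedule per microbatch count N and reuses it afterwards."""

    def __init__(self, breadth_first_micro_batches: int):
        self._breadth = breadth_first_micro_batches
        self._cache: dict[int, ToySchedule] = {}

    def get_or_build(self, num_micro_batches: int) -> ToySchedule:
        if num_micro_batches not in self._cache:
            # Same floor division as in the PR description.
            override = num_micro_batches // self._breadth
            self._cache[num_micro_batches] = ToySchedule(self._breadth, override)
        return self._cache[num_micro_batches]


cache = ToyScheduleCache(breadth_first_micro_batches=2)
schedule = cache.get_or_build(6)
print(schedule.eff_depth_first, schedule.eff_sequential_micro_batches)
```

Caching by N means a step that happens to need the same number of microbatches as an earlier one reuses the already-built schedule instead of rebuilding it.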

Off by default (docs_per_step=0) — the disabled path takes the original static-schedule branch.
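For reviewers, the accumulation rule from the summary can be illustrated with a single-process sketch. It is not the implementation of `Trainer._prefetch_to_doc_target`: the function name, the dictionary keys, and the use of a plain `sum` over per-rank counts in place of a real all-reduce are all assumptions made for illustration.

```python
from typing import Iterable


def accumulate_to_doc_target(
    per_rank_doc_counts: Iterable[list[int]], docs_per_step: int
) -> list[dict[str, int]]:
    """Pull microbatches until the global document total reaches the target.

    Each item of `per_rank_doc_counts` holds the document count seen by every
    rank for one microbatch (hypothetical input format for this sketch).
    """
    if docs_per_step <= 0:
        # Disabled: the caller would take the static-schedule path instead.
        raise ValueError("docs_per_step must be > 0 for dynamic accumulation")

    global_counts: list[int] = []
    total = 0
    for local_counts in per_rank_doc_counts:
        # Stand-in for an all-reduce (sum) of the per-microbatch document count.
        global_count = sum(local_counts)
        global_counts.append(global_count)
        total += global_count
        if total >= docs_per_step:
            break

    # Stamp the final step total on every microbatch so each one can use the
    # same loss normalization denominator.
    return [{"docs": count, "step_total_docs": total} for count in global_counts]


stream = iter([[1, 1], [2, 0], [1, 1], [3, 3]])
step = accumulate_to_doc_target(stream, docs_per_step=5)
print(len(step), step[0]["step_total_docs"])
```

With a target of 5, the running global totals are 2, 4 and 6, so the step closes after three microbatches and the fourth stays in the stream for the next step. In this sketch the step total can overshoot the target, since accumulation only stops once a whole microbatch has been counted.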

Test plan

  • pytest tests/layers/test_docs_per_step.py — passes

Originally part of #502.

@jlamypoirier jlamypoirier mentioned this pull request May 19, 2026
4 tasks
@jlamypoirier jlamypoirier changed the title RL training features (#502 minus GSPO) Deepspeed parity hacks May 21, 2026
@jlamypoirier jlamypoirier changed the title Deepspeed parity hacks Deepspeed parity tweaks May 25, 2026
A schedule config field that replaces the static microbatch count with a
runtime document-count target. Matches DeepSpeed's
gradient_accumulation_passes semantics for RL: each microbatch holds one
rollout and the step boundary is set by total rollouts rather than a
fixed microbatch count.

- ScheduleConfig.docs_per_step — when >0, Trainer._prefetch_to_doc_target
  fetches microbatches one at a time, all-reduces the per-microbatch doc
  count, and stops once the global total reaches the target. The final
  step total is broadcast to every microbatch so the loss normalization
  stays consistent.
- Trainer._get_or_build_schedule(N) builds and caches a per-N Schedule
  with _depth_first_override = N // breadth_first_micro_batches, reusing
  the schedule machinery without touching the runner.
- Schedule._eff_{depth_first,sequential_micro_batches,num_inputs} expose
  the effective values under an override.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier jlamypoirier changed the title Deepspeed parity tweaks Add docs_per_step for dynamic microbatch accumulation May 27, 2026
@jlamypoirier jlamypoirier changed the base branch from main to jlp_fp32_lm_head May 27, 2026 19:37