Skip to content

feat(engine): introduce Engine class; consolidate LLM/VLM recipes#2269

Draft
HuiyingLi wants to merge 34 commits into
mainfrom
huiyingl/feat/engine
Draft

feat(engine): introduce Engine class; consolidate LLM/VLM recipes#2269
HuiyingLi wants to merge 34 commits into
mainfrom
huiyingl/feat/engine

Conversation

@HuiyingLi
Copy link
Copy Markdown
Contributor

Summary

  • New `nemo_automodel.engine.Engine` (TorchTitan-style): `init` owns mesh, model, optimizer, and lr-scheduler construction. `VLMEngine` subclass adds multimodal CP pre-embed and PP media chunking.
  • Recipes routed through Engine: `recipes/llm/train_ft.py`, `recipes/vlm/finetune.py`, and `recipes/llm/benchmark.py` no longer build infra inline.
  • Builders moved out of recipes into `components/training/build.py` with typed-only signatures (no `ConfigNode` coupling at the component layer); `build_vlm_model` collapsed into `build_model`.
  • Recipe consolidation into `BaseRecipe` (7 phases): distributed/RNG env setup, remote-logger wiring (WandB/MLflow/Comet), multi-validation-dataset loop, PP-aware `_run_validation_epoch`, MoE load-balance collection + logging, QAT delayed-fake-quant scheduling. Net diff: train_ft.py `-1200` lines, finetune.py `-881` lines.
  • Shared metric loggers: `log_training_metrics` + `log_validation_metrics` extracted to `components/loggers/metric_logger.py`.
  • Dedupe: `is_nemo_auto_factory()` predicate in `_transformers/auto_model.py`; `MoEParallelizerConfig.coerce()` classmethod.
  • Fix: Engine uses `cfg.model.instantiate` so nested sub-ConfigNodes (e.g., `backend:`) recurse via `_instantiate_value` instead of leaking `target` into kwargs.

Test plan

  • Unit tests: `tests/unit_tests/engine/` covers Engine + VLMEngine construction, Config dataclass, forward/backward path, list-of-batches.
  • Smoke sweep on 7 training configs (LLM FSDP2/DDP/HSDP, VLM FSDP2, MoE EP, PP) — pre-refactor.
  • Qwen3-VL-MoE-30B EP=4 PP=2 smoke (5 steps, loss 2.19 → 2.33).
  • Qwen3-VL-MoE-30B EP=8 smoke (5 steps) — losses match pre-refactor `283df77c` and latest `origin/main` 91890e4 within ±0.04 across all 5 steps.
  • CI green on this PR.

🤖 Generated with Claude Code

HuiyingLi added 30 commits May 8, 2026 01:22
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
- components/training/batch_split.py: split_into_microbatches helper that
  slices a batch dict into N microbatches along dim 0; passes through opaque
  values (multimodal dicts, scalars) unchanged.
- components/checkpoint/api.py: thin save_checkpoint / load_checkpoint /
  export_weights wrappers around Checkpointer. Derive dp/tp/pp ranks from
  MeshContext so callers stop wiring them by hand. export_weights materializes
  DTensor via .full_tensor() and optionally adapts state-dict keys to HF
  format via the model's state_dict_adapter.
- components/distributed/device.py: offload / onload helpers for moving
  model + optimizer state between CPU and CUDA (used for train↔rollout
  offload in RL pipelines).

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
A single-file, plain concrete class that exposes the atomic training-step
surface for RL frameworks, recipes, and downstream backends:

  build / forward_backward / zero_grad / optimizer_step / lr_scheduler_step
  train_mode / eval_mode / save_checkpoint / load_checkpoint
  export_weights / to(device)

Design properties:
  - No ABC, no registry, no ModelHandle wrapper. State lives as direct
    attributes (engine.model, engine.optimizer, engine.mesh). Subclass to
    swap a method; fork the file to swap everything.
  - Method bodies are flat. forward_backward has the microbatch loop
    inline — gradient-accumulation prep, the MoEAuxLossAutoScaler
    main_loss_backward_scale mutation, CP / PP / FusedLinearCE handling,
    sync_ctx for FSDP grad-sync gating, and backward are visible
    top-to-bottom in one method. No hidden helpers in another file.
  - optimizer_step inlines the 9-arg scale_grads_and_clip_grad_norm call
    plus the finite-grad check; returns (ok, grad_norm).
  - Tier conventions are informal (duck-typed): backends that omit
    export_weights / to() simply don't support those tiers — no
    NotImplementedError boilerplate or tier introspection.

Engine.build() uses PR 2190's typed component builders
(components.optim.build.build_optimizer takes factory + kwargs separately),
resolving _target_ inline via _callable_and_kwargs so the Engine never
imports from recipes/_*. build_model still lives in
recipes/llm/train_ft.py — moving it is a documented follow-up.

Also exports Engine lazily at the top of the package via _LAZY_ATTRS.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
- tests/unit_tests/test_engine.py: 8 smoke tests covering the import path,
  construction without distributed init, introspection defaults when the
  Engine is unbuilt, expected method / property surface (matches the
  design doc), forward_backward signature, and the split_into_microbatches
  primitive.
- tests/unit_tests/test_engine_methods.py: 13 functional tests that drive
  the Engine end-to-end against a manually-constructed tiny nn.Module
  (bypassing the HF model-build pipeline). Exercises forward_backward
  (loss_fn=None and explicit, with / without microbatching, forward_only),
  optimizer_step (clip + step + parameter movement), export_weights,
  train_mode / eval_mode restoration, to('cpu'), _build_lr_scheduler
  direct construction (with / without warmup ratio, missing-total_steps
  error), and a full SFT-style loop iteration.

Tests use gloo backend with WORLD_SIZE=1 so they run on any host without
CUDA / multi-GPU requirements.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
- Engine.forward_backward now accepts either a single dict (split internally
  into num_microbatches) or a pre-split list[dict] of microbatches (recipes
  pass their dataloader-split batches directly). Return shape adds 'losses'
  (per-microbatch detached losses) so recipes can compute their own reporting
  loss.
- New CP/THD shaping attrs on Engine (cp_use_te, cp_padding_token_id,
  cp_num_chunks) — recipes set them after construction.
- New optional attribute hooks for the LLM path: fp8_autocast (TE FP8
  context manager factory) and extra_loss_fn (e.g. MTP auxiliary loss
  for DeepSeek-V3-style models). Both default None; orthogonal to the
  main loop body.
- New subclass hook methods _pre_cp_hook(mb) and _pre_pp_schedule_hook(mb,
  pp, input_ids) — base no-ops; VLMEngine overrides them.
- nemo_automodel/vlm_engine.py: VLMEngine(Engine) overriding the two hooks
  for (1) CP multimodal pre-embed and (2) PP media-tensor chunking.
  chunk_vlm_media moves here alongside.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
setup(): construct Engine (or VLMEngine for vlm) and inject the already-built
model/optimizer/lr_scheduler/mesh. We don't call engine.build() because the
recipe handles construction itself.

_run_train_optim_step shrinks to ~15 lines per recipe: engine.zero_grad(),
engine.forward_backward(batches, loss_fn, num_label_tokens), engine.
optimizer_step(), engine.lr_scheduler_step(). The microbatch loop,
prepare_for_grad_accumulation / prepare_for_final_backward, MoEAuxLossAuto-
Scaler, scale_grads_and_clip_grad_norm, sync_ctx, CP/PP shaping, FusedLinearCE
lm_head extraction, FP8 autocast, MTP extra loss (LLM), CP multimodal
pre-embed and PP media chunking (VLM) are all inside the (VLM)Engine.

_forward_backward_step removed from both recipes. _run_validation_epoch
calls engine.forward_backward(..., forward_only=True). vlm/finetune.py's
inline _chunk_vlm_media is now a thin deprecated wrapper that delegates to
nemo_automodel.vlm_engine.chunk_vlm_media.

Net: train_ft.py shrinks ~115 lines; vlm/finetune.py shrinks ~175 lines.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…tests

- test_engine_methods.py: add tests for the list-of-batches forward_backward
  path and the MoE aux-loss main_loss_backward_scale setting (replaces
  TestRunTrainOptimStepSetsMoEScale that mocked the deleted recipe-internal
  _forward_backward_step).
- tests/unit_tests/recipes/test_train_ft.py: drop obsolete tests that mocked
  the deleted _forward_backward_step (4x test_forward_backward_step_pp_*,
  2x test_run_validation_epoch_pp_*, TestRunTrainOptimStepSetsMoEScale).
- tests/unit_tests/recipes/test_finetune_vlm_helpers.py: drop the parallel
  obsolete VLM tests (TestForwardBackwardStepPP, TestForwardBackwardStepNonPP,
  3x test_run_train_step_*).

Six remaining test_train_ft.py / test_finetune_vlm_helpers.py failures are
pre-existing PR 2190 issues around build_optimizer's typed signature
(verified pre-migration via git stash) — not introduced by this change.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Match the native engine pattern used by both TorchTitan and bumblebee:

  TorchTitan: Trainer.__init__(config: Trainer.Config) — one nested Config
              dataclass with ~18 sub-config fields.
  bumblebee:  Megatron5DRuntime.__init__(hf_path, cfg: BBConfig) — one path
              plus one Config dataclass with ~7 sub-fields.

Before this change, Engine.__init__ took 14 keyword arguments, which was the
outlier. Now:

  @DataClass
  class Engine.Config:
      model, distributed, optimizer, lr_scheduler, dist_env,
      peft, quantization, fp8, pipeline, moe,
      activation_checkpointing, max_grad_norm, defer_fsdp_grad_sync, seed

  Engine.__init__(self, config: Engine.Config)        # one argument

Recipes change from Engine(model_cfg=..., distributed_cfg=..., ...) to
Engine(Engine.Config(model=..., distributed=..., ...)). The ``Engine.Config``
is a nested class (TorchTitan-style) so the caller types ``Engine.Config(...)``
right next to the ``Engine(...)`` call. Property accessors on Engine
(model_cfg, distributed_cfg, etc.) keep internal call sites unchanged.

Updated:
- nemo_automodel/engine.py: new Engine.Config dataclass; properties expose
  fields under their existing internal names.
- recipes/llm/train_ft.py + recipes/vlm/finetune.py: build Engine.Config()
  at the construction site.
- tests/unit_tests/test_engine{,_methods}.py: all 9 Engine call sites use
  the new shape.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The previous Engine.__init__ body initialized 5 extra attributes
(cp_use_te, cp_padding_token_id, cp_num_chunks, fp8_autocast, extra_loss_fn)
that were really configuration values — recipes set them after construction,
not runtime state populated by build(). Move them into Engine.Config:

  ── CP / THD batch shaping (passed to make_cp_batch_and_ctx) ──
  cp_use_te: bool = False
  cp_padding_token_id: int = 0
  cp_num_chunks: int = 1

  ── Optional callable hooks ──
  fp8_autocast: Callable[[], Any] | None = None
  extra_loss_fn: Callable[..., torch.Tensor | None] | None = None

Property accessors on Engine proxy reads/writes through to Engine.config so
the recipes' existing post-construction assignments (engine.cp_use_te = ...)
keep working unchanged.

After this change Engine.__init__'s body holds only runtime state:
  - model, optimizer, lr_scheduler, mesh  (built by build() or recipe-injected)
  - _num_label_tokens                     (per-step state)

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…p explicitly

Previously Engine stashed num_label_tokens during forward_backward as a
side-channel for optimizer_step to read at scale_grads_and_clip_grad_norm
time. That's a hidden coupling between two API calls — and it forced
__init__ to initialize a per-step attribute on the instance.

Now optimizer_step(num_label_tokens=None) takes it explicitly. The recipe
already computes num_label_tokens once per step (sum across batches +
DP allreduce); passing it to both forward_backward and optimizer_step is
trivial.

Engine.__init__ body shrinks accordingly — runtime state only:
  self.model
  self.optimizer
  self.lr_scheduler
  self.mesh

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The _forward_backward_step / _run_train_optim_step inner-loop deletion
left several imports stranded:

train_ft.py:
  - nullcontext, get_sync_ctx, make_cp_batch_and_ctx, filter_forward_kwargs,
    get_final_hidden_states  (used by deleted _forward_backward_step body)
  - MoEAuxLossAutoScaler  (Engine.forward_backward sets the scale now)
  - prepare_for_grad_accumulation, prepare_after_first_microbatch,
    prepare_for_final_backward  (Engine drives the microbatch lifecycle)
  - scale_grads_and_clip_grad_norm  (Engine.optimizer_step calls it)

vlm/finetune.py: same set, minus prepare_after_first_microbatch which the
VLM version never imported.

No behavior change. count_tail_padding stays — both recipes still use it
to compute num_tokens_in_batch for TPS reporting. FirstRankPerNode also
stays — used during dataloader setup. calculate_loss / calculate_mtp_loss
also stay — the MTP closure passed to Engine.extra_loss_fn calls them.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Adds ``target_and_kwargs(cfg) -> (factory, kwargs)`` to
``components/config/loader.py`` — the canonical location for Hydra/YAML
schema knowledge, sibling to ``_resolve_target`` already there.

Both prior duplicates now import from this single source:
  - ``nemo_automodel.engine`` (alias as ``_callable_and_kwargs``)
  - ``nemo_automodel.recipes._component_builders`` (alias as
    ``_callable_and_kwargs``)

No behavior change. With this, engine.py no longer carries ``_target_``
resolution logic; all ``_target_`` mentions inside engine.py are comments
or docstrings. Components' builder files (PR 2190) remain clean of
``_target_`` knowledge — they take pre-resolved ``(factory, kwargs)``.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Smoke sweep on real models (moonlight-16B-deepep-te, qwen3-moe-30b-deepep)
caught that BenchmarkingRecipeForNextTokenPrediction inherits from
TrainFinetuneRecipeForNextTokenPrediction and overrides the training loop,
but its inner per-iteration body still called the now-deleted
``self._forward_backward_step``.

Replace the inner gradient-accumulation loop with the same
``engine.forward_backward(batches, loss_fn, num_label_tokens) +
engine.optimizer_step()`` pattern train_ft.py uses. The Engine drives
prepare_for_grad_accumulation → microbatch loop → MoE aux-loss scale →
prepare_after_first_microbatch → prepare_for_final_backward internally,
so the recipe stops importing and calling those helpers directly.

Drops unused imports of prepare_for_grad_accumulation /
prepare_after_first_microbatch / prepare_for_final_backward.

Verified end-to-end:
  - moonlight 16B DeepEP TE  : 5 iters, 56s on 8xH100
  - qwen3 MoE 30B DeepEP     : 5 iters, 136s on 8xH100
  - qwen 2.5 7B PEFT         : 5 steps, ~8s on 8xH100

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Moves ~615 lines from recipes/llm/train_ft.py and ~340 lines from
recipes/vlm/finetune.py into components/:

- components/training/build.py: build_model (LLM), build_vlm_model
- components/datasets/llm/build.py: build_dataloader, build_validation_dataloader,
  _uses_te_dot_product_attention, _uses_thd_collater, _get_num_thd_chunks
- components/datasets/vlm/build.py: build_vlm_dataloader
- components/loss/calculate.py: calculate_loss
- components/loss/mtp.py: calculate_mtp_loss
- _transformers/auto_tokenizer.py: _build_tokenizer, _get_model_name,
  compute_trust_remote_code_from_model

Recipes keep back-compat re-imports so kd.py and external test patches
continue to resolve. Test patch sites for build_dataloader and
resolve_trust_remote_code updated to point at the new module locations.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
components/training/build.py::build_model and build_vlm_model now take
typed parameters only — model_factory, model_kwargs, is_nemo_auto_model,
peft_config, fp8_config, qat_config, moe_config, etc. — and do no
.get("_target_") / .to_dict() / .instantiate() calls on ConfigNode.

YAML coupling moves up to the recipe layer:
- recipes/llm/train_ft.py::build_model is now a ~50-line wrapper that
  translates cfg_* into typed args via cfg_model.instantiate as the
  factory, then delegates to the component impl.
- recipes/vlm/finetune.py::build_model mirrors the same shape, with the
  VLM nemo-auto-target check still raising at the recipe boundary.

Tests / external callers (kd.py) keep the cfg_* surface — the wrapper
preserves it. The component below it is now safe to call from
non-recipe call sites with pre-resolved typed configs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…s build

Engine.build() is gone; construction moves into __init__:
- Resolves mesh (accepts MeshContext or a dist-setup namespace)
- Calls the typed components.training.build.build_model with cfg→typed
  translation inline (model_factory via target_and_kwargs, fp8_config,
  qat_config, moe_config, freeze_config, ...)
- Calls the typed components.optim.build.build_optimizer
- Builds the LR scheduler when Engine.Config.lr_scheduler is set

Engine.Config gains the fields the build pipeline needs (qat, compile,
sdpa_method, has_packed_sequence, unfreeze_modules, freeze_config) and
drops the dead _cfg_get / _cfg_to_dict / convenience-accessor surface.
Construction is skipped when ``config.model is None`` so tests can
construct an Engine shell and inject ``self.model`` manually.

Recipe surgery:
- train_ft.py and vlm/finetune.py construct the Engine eagerly right
  after distributed setup. The recipe's prior build_model / build_optimizer
  call sites are gone; the lr_scheduler is still built post-step_scheduler
  by the recipe and attached to engine.lr_scheduler (the recipe path
  needs step_scheduler-derived total_steps).
- vlm/finetune.py uses VLMEngine the same way.
- components/training/build.py::build_model gains a freeze_config kwarg
  so the LLM and VLM paths share the same builder.

Tests:
- Engine tests construct with model=None to skip the build chain.
- test_engine.py drops the build() existence assertion.
- Recipe-setup tests get a _FakeEngine stand-in so the dummy cfg doesn't
  trip Engine's real builders.

Drops 34 dead imports auto-removed by ruff that were leftover from the
earlier recipe-build moves.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
The helper was a leftover from an earlier draft where Engine.Config.lr_scheduler
was set from the recipe's YAML. The current flow builds the LR scheduler at
the recipe layer (driven by step_scheduler) and assigns it directly to
``engine.lr_scheduler``, so the helper has no production caller. Its only
remaining consumer was a self-referential test, which is also removed —
the typed ``LRSchedulerConfig(total_steps: int)`` dataclass enforces the
required-field invariant structurally.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…cipes

Code-reuse cleanup found by /simplify review:

- 3-way duplication of NeMoAutoModel target predicate (engine.py +
  recipes/llm/train_ft.py + recipes/vlm/finetune.py) collapses into a
  single ``is_nemo_auto_factory(target)`` exported from
  ``_transformers/auto_model.py`` (alongside the classes it tests).
- 3-way duplication of ConfigNode → MoEParallelizerConfig coercion
  collapses into ``MoEParallelizerConfig.coerce(cfg)`` classmethod.
- Engine.Config drops dead ``moe`` and ``pipeline`` fields — both were
  always overridden by ``self.mesh.{moe_config,pipeline_config}``.
- ``bool(torch.isfinite(torch.tensor(grad_norm_val)))`` → ``math.isfinite(...)``
  in optimizer_step; avoids a host-side tensor alloc every step.
- Drop refactor-narrative comments ("TorchTitan-style", "Engine was
  constructed earlier", "Lives here (not in components/) because ...")
  — context belongs in git history, not the source.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
build_vlm_model was a strict subset of build_model (no quantization /
qat / sdpa_method / unfreeze_modules / is_nemo_auto_model dispatch).
Now that build_model accepts ``freeze_config`` and ``is_nemo_auto_model``,
there's no behavior the VLM-specific variant provides that build_model
can't.

The VLM recipe wrapper at recipes/vlm/finetune.py::build_model calls
build_model directly with ``is_nemo_auto_model=True``. The "no fallback
path" invariant (VLM requires NeMoAutoModelFor*) is enforced at the
wrapper level, where the friendlier ValueError already lives.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…th recipes

train_ft.py and vlm/finetune.py each had a ~30-line ``log_val_metrics``
that fanned out a ``MetricsSample`` to wandb / mlflow / comet / JSONL and
emitted a ``[val] ...`` info line. The LLM version was a superset (multi-
val-dataset + mlflow + comet); the VLM version a subset (single dataset,
wandb only).

Both implementations collapse to a shared
``components.loggers.metric_logger.log_validation_metrics`` helper that
takes the loggers as kwargs. ``val_name`` is annotated in the wandb /
comet payload (and the info line) only when it's not the default —
preserving each recipe's existing log format exactly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
… recipes

Mirrors the log_validation_metrics extraction. The shared helper does:
- WandB / MLflow / Comet remote logging gated by ``is_remote_logging_step``
- Always-on JSONL log
- Standard ``step ... loss ... grad_norm ...`` info line
- ``torch.cuda.reset_peak_memory_stats()``

Recipe ``log_train_metrics`` methods become 8-line shims. The LLM recipe
keeps its MoE load-balance metric logging inline after the helper call
(it's recipe-specific state on ``self._moe_layer_loads``).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…ted_env

VLM also gains apply_te_patches() and enable_nvtx so its setup() prefix
becomes byte-identical to LLM's. Both recipes now collapse the ~14-line
boot block (dist_env + setup_logging + apply_*_patches + StatefulRNG +
setup_distributed + mesh attrs) into a single ``self._setup_distributed_env()``
call.

Test fixtures' monkeypatches move from
``nemo_automodel.recipes.{llm.train_ft,vlm.finetune}.<symbol>`` to
``nemo_automodel.recipes.base_recipe.<symbol>`` since the imports live
there now.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
…gers

The wandb / mlflow / comet initialization block moves into a shared
BaseRecipe method. VLM gains MLflow + Comet support (previously
wandb-only).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Adds build_validation_dataloader to components/datasets/vlm/build.py
that scans cfg.to_dict() for keys starting with validation_dataset
(mirroring the LLM helper) and builds one DataLoader per match keyed
by the suffix ("default", "val", "test", "foo", ...).

VLM recipe switches self.val_dataloader → self.val_dataloaders dict,
iterates in run_train_validation_loop, and maintains per-dataset JSONL
loggers (validation.jsonl + validation_<name>.jsonl). log_val_metrics
signature now matches LLM's (val_name, log_data, metric_logger).

VLM also gains MLflow + Comet logging on validation (passed through
the shared log_validation_metrics helper).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
VLM adopts LLM's pattern: engine.forward_backward(num_label_tokens=None)
sums unnormalized batch losses, dp_allreduce, then divides by the total
label-token count. Under PP, the value travels from the last pipeline
stage to rank 0 so whichever rank logs has the data.

The shared method lives on BaseRecipe. Both recipes' duplicates removed.
VLM gains PP-aware validation; the prior "Validation is not supported
for pipeline parallelism" warn-and-skip behavior is dropped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
HuiyingLi added 4 commits May 17, 2026 22:05
_collect_moe_load_balance + _log_moe_metrics move to BaseRecipe and are
now called from VLM's training loop too. When training MoE VLMs (e.g.
Qwen3-VL-MoE-30B) with ``moe_metrics.enabled: true`` in YAML, the wandb
/ comet trackers now receive the same load-balance summaries as LLM MoE.

VLM's log_train_metrics shim also gains mlflow_logger and comet_logger
pass-through (was wandb-only).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
_setup_qat + _enable_qat_if_delayed move to BaseRecipe. VLM's setup()
now sets up the QAT delayed-fake-quant scheduler the same way LLM
does, and its training loop calls _enable_qat_if_delayed on each step.

Also fixes a NameError introduced in the prior Phase 4 commit: the
save_checkpoint call inside VLM's run_train_validation_loop still
referenced the old singular ``val_loss`` name; renamed to ``val_losses``
to match the new dict.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
No behavior changes: ruff format applied at line length 120, plus a
handful of stale F401 unused imports autofixed across the touched files.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Engine._construct previously called ``target_and_kwargs(self.config.model)``
to split into ``(factory, model_kwargs)`` and then forwarded ``model_kwargs``
to the typed builder. The ``to_dict()`` path in ``target_and_kwargs`` only
pops the *top-level* ``_target_``; nested ConfigNodes (notably the model
``backend:`` sub-config) keep their ``_target_`` key intact, so the
downstream ``BackendConfig(**kwargs["backend"])`` call in model_init.py
hit a ``TypeError: BackendConfig.__init__() got an unexpected keyword
argument '_target_'``.

The legacy recipe path called ``cfg_model.instantiate(**infra_kwargs)``,
which delegates to ``ConfigNode._instantiate_value`` and recursively
instantiates sub-ConfigNodes that carry their own ``_target_``. Adopt
that pattern in Engine: pass ``cfg.model.instantiate`` as ``model_factory``
and use ``target_and_kwargs`` only to recover the original target callable
for the ``is_nemo_auto_factory`` identity check.

Validated on Qwen3-VL-MoE-30B EP4 PP2 smoke (5 steps, loss 2.19 → 2.33).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant