From 0bf466b935948a22f32c8477c832fe209f543a3a Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sat, 13 Jun 2026 22:03:14 +0000 Subject: [PATCH 1/5] add domino support Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/main.py | 9 + .../torch/export/plugins/hf_spec_export.py | 31 ++ modelopt/torch/speculative/config.py | 17 + .../torch/speculative/dflash/conversion.py | 30 +- .../torch/speculative/plugins/__init__.py | 1 + .../torch/speculative/plugins/hf_dflash.py | 8 +- .../torch/speculative/plugins/hf_domino.py | 364 ++++++++++++++++++ .../speculative/plugins/modeling_domino.py | 90 +++++ .../general/speculative_decoding/domino.yaml | 86 +++++ .../speculative/plugins/test_hf_domino.py | 213 ++++++++++ .../Qwen/Qwen3-8B/hf_online_domino.yaml | 60 +++ 11 files changed, 899 insertions(+), 10 deletions(-) create mode 100644 modelopt/torch/speculative/plugins/hf_domino.py create mode 100644 modelopt/torch/speculative/plugins/modeling_domino.py create mode 100644 modelopt_recipes/general/speculative_decoding/domino.yaml create mode 100644 tests/unit/torch/speculative/plugins/test_hf_domino.py create mode 100644 tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index f62b099121d..b998c702a7f 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -276,6 +276,15 @@ def train(): and recipe.eagle.eagle_base_lora_warmup_steps > 0 ): callbacks.append(LoRAWarmupCallback(recipe.eagle.eagle_base_lora_warmup_steps)) + # Domino (dflash recipe with projector_type=domino) needs the lambda_base + # curriculum schedule driven by the trainer's global step. + if ( + isinstance(recipe, ModelOptDFlashRecipe) + and recipe.dflash.dflash_architecture_config.get("projector_type") == "domino" + ): + from modelopt.torch.speculative.plugins.hf_domino import DominoLambdaCallback + + callbacks.append(DominoLambdaCallback()) # Leave training_args.ignore_data_skip at its default (False). The dataset is # map-style, so HF Trainer's resume skips consumed indices at the batch-sampler # level (accelerate.skip_first_batches) without re-fetching them, landing at the diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 54d6e493c25..f3d153b0bc4 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -435,3 +435,34 @@ def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): f"Exported DFlash draft model: {len(drafter_sd)} tensors, " f"config keys: {list(drafter_config.keys())[:5]}..." ) + + +class DominoExporter(DFlashExporter): + """Draft model exporter for Domino (DFlash backbone + causal correction head). + + Same z-lab-compatible format as DFlash, plus the Domino head weights + (``prefix_gru.*`` / ``embed_proj.*``, already captured by the inherited + ``dflash_module.`` stripping) and the extra config fields the loader needs to + rebuild the head (``projector_type``, ``emb_dim``, ``gru_hidden_dim``, + ``pure_draft_prefix_len``, ``shift_label``). + """ + + def _export_config(self): + """Extend the DFlash config with the Domino head fields.""" + config = super()._export_config() + draft_config = self.model.dflash_config + + emb_dim = getattr(draft_config, "emb_dim") + gru_hidden_dim = getattr(draft_config, "gru_hidden_dim") + # Mirror the reference checkpoint: emb_dim also appears at the top level. + config["emb_dim"] = emb_dim + config["dflash_config"].update( + { + "projector_type": getattr(draft_config, "projector_type", "domino"), + "shift_label": getattr(draft_config, "shift_label", True), + "pure_draft_prefix_len": getattr(draft_config, "pure_draft_prefix_len", 1), + "gru_hidden_dim": gru_hidden_dim, + "emb_dim": emb_dim, + } + ) + return config diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 7649b2d0357..2982273265c 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -132,6 +132,23 @@ class DFlashConfig(ModeloptBaseConfig): description="Whether to use torch.compile on DFlash forward/loss methods.", ) + dflash_lambda_base_start: float = ModeloptField( + default=1.0, + description=( + "Domino only: initial weight of the base (backbone-only) loss in the " + "loss = (1 - lambda)*final + lambda*base mixture; linearly decayed to 0. " + "Ignored unless dflash_architecture_config.projector_type == 'domino'." + ), + ) + + dflash_lambda_base_decay_ratio: float = ModeloptField( + default=1.0, + description=( + "Domino only: fraction of total training steps over which lambda_base " + "decays from dflash_lambda_base_start to 0." + ), + ) + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/conversion.py b/modelopt/torch/speculative/dflash/conversion.py index 943be90ca0f..03406088dd1 100644 --- a/modelopt/torch/speculative/dflash/conversion.py +++ b/modelopt/torch/speculative/dflash/conversion.py @@ -24,26 +24,38 @@ from ..config import DFlashConfig DFlashDMRegistry = _DMRegistryCls(prefix="DFlash") # global instance for the registry +# Domino reuses the dflash mode/config/recipe but converts the base model to a +# DFlash module augmented with a causal correction head. It is selected via +# ``dflash_architecture_config.projector_type == "domino"`` and lives in its own +# registry so its wrapper (HFDominoModel) does not overwrite HFDFlashModel. +DominoDMRegistry = _DMRegistryCls(prefix="Domino") def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType: - """Convert the model to a DFlash model as per `config`.""" + """Convert the model to a DFlash (or Domino) model as per `config`.""" model = model.init_modellike() if isinstance(model, ModelLikeModule) else model - original_cls = type(model) - if original_cls not in DFlashDMRegistry: - for cls in DFlashDMRegistry._registry: - if issubclass(original_cls, cls): - DFlashDMRegistry.register({original_cls: "base_model_class"})(DFlashDMRegistry[cls]) - break - # merge custom config with default config (lazy import to avoid circular) from .default_config import default_dflash_config custom_config = config.dflash_architecture_config config.dflash_architecture_config = {**default_dflash_config, **custom_config} - dflash_model = DFlashDMRegistry.convert(model) + # Route to the Domino registry when the architecture asks for the Domino head. + registry = ( + DominoDMRegistry + if config.dflash_architecture_config.get("projector_type") == "domino" + else DFlashDMRegistry + ) + + original_cls = type(model) + if original_cls not in registry: + for cls in registry._registry: + if issubclass(original_cls, cls): + registry.register({original_cls: "base_model_class"})(registry[cls]) + break + + dflash_model = registry.convert(model) dflash_model.modify(config) metadata = {} diff --git a/modelopt/torch/speculative/plugins/__init__.py b/modelopt/torch/speculative/plugins/__init__.py index c30a65b2b47..ec90b8c0fda 100644 --- a/modelopt/torch/speculative/plugins/__init__.py +++ b/modelopt/torch/speculative/plugins/__init__.py @@ -31,5 +31,6 @@ with import_plugin("transformers"): from .hf_dflash import * + from .hf_domino import * from .hf_eagle import * from .hf_medusa import * diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 1760cb2072d..a2bddf0be75 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -180,7 +180,9 @@ def modify(self, config): self._find_base_model_parts() - self.dflash_module = DFlashModule(self.dflash_config) + # Factory hook: subclasses (e.g. Domino) override to build an augmented + # draft module while reusing all of DFlash's modify() setup. + self.dflash_module = self._build_draft_module(self.dflash_config) # Match base model dtype/device. Skip if base is on meta (during from_pretrained # restore — the model will be moved to the correct device after weight loading). if self.dflash_offline: @@ -197,6 +199,10 @@ def modify(self, config): self.is_quantized = False self._num_anchors = self.dflash_num_anchors + def _build_draft_module(self, dflash_config): + """Build the draft module. Subclasses override to use an augmented module.""" + return DFlashModule(dflash_config) + def get_exporter(self): """Get the exporter for the DFlash draft model.""" from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter diff --git a/modelopt/torch/speculative/plugins/hf_domino.py b/modelopt/torch/speculative/plugins/hf_domino.py new file mode 100644 index 00000000000..e4d37412d13 --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_domino.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Domino speculative decoding plugin for HuggingFace models. + +Domino reuses the DFlash draft backbone and training pipeline (anchor sampling, +noise/mask construction, KV-injection attention) and adds a lightweight causal +correction head (see ``modeling_domino.DominoModule``): + +- Backbone produces *base* logits for a full draft block in parallel. +- A GRU runs over the block's previously decoded (teacher-forced) token + embeddings to produce a causal state, which is fused with the backbone hidden + state and projected to a vocab-sized logit correction added to the suffix + positions. This injects the intra-block causal dependency the parallel + backbone lacks. + +Training uses next-token (shift_label) alignment and a two-term loss:: + + loss = (1 - lambda_base) * final_loss + lambda_base * base_loss + +where ``final_loss`` is CE on the corrected logits and ``base_loss`` is CE on the +backbone-only logits. ``lambda_base`` decays linearly from ``lambda_base_start`` +to 0 over ``lambda_base_decay_ratio`` of training (curriculum: learn a good +parallel backbone first, then the causal correction). The schedule is driven by +``DominoLambdaCallback`` from the HF Trainer's global step. +""" + +import logging + +import torch +import torch.nn.functional as F +from transformers import PreTrainedModel +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_pt_utils import LabelSmoother +from transformers.utils import ModelOutput + +from ..dflash.conversion import DominoDMRegistry +from .hf_dflash import HFDFlashModel +from .modeling_dflash import DFlashBaseModelOutput +from .modeling_domino import DominoModule + +logger = logging.getLogger(__name__) + +__all__ = ["DominoLambdaCallback", "HFDominoModel", "compute_lambda_base"] + + +def compute_lambda_base( + global_step: int, + total_steps: int, + lambda_start: float = 1.0, + decay_ratio: float = 1.0, +) -> float: + """Linearly decay lambda_base from ``lambda_start`` to 0. + + Decay completes after ``decay_ratio * total_steps`` steps; clamped to [0, 1]. + """ + decay_steps = max(1, int(total_steps * decay_ratio)) + progress = min(global_step / decay_steps, 1.0) + lambda_base = lambda_start * (1.0 - progress) + return max(0.0, min(1.0, lambda_base)) + + +@DominoDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) +class HFDominoModel(HFDFlashModel): + """DFlash model with the Domino causal correction head (HF transformers). + + Registered in ``DominoDMRegistry`` so that ``convert_to_dflash_model`` can + route to it when ``dflash_architecture_config.projector_type == "domino"``. + """ + + def _build_draft_module(self, dflash_config): + """Build the Domino draft module (DFlash backbone + GRU correction head).""" + return DominoModule(dflash_config) + + def modify(self, config): + """Initialize the Domino draft module and read the lambda_base schedule.""" + super().modify(config) + # Curriculum schedule for the base/final loss mixing weight. Read here + # (DFlashConfig carries the two fields); updated each step by + # DominoLambdaCallback. Defaults keep a single forward (e.g. unit tests) + # well-defined without a scheduler. + self.dflash_lambda_base_start = getattr(config, "dflash_lambda_base_start", 1.0) + self.dflash_lambda_base_decay_ratio = getattr(config, "dflash_lambda_base_decay_ratio", 1.0) + self._lambda_base = self.dflash_lambda_base_start + if not getattr(self.dflash_module, "shift_label", True): + raise NotImplementedError( + "Domino currently supports shift_label=True (next-token alignment) only." + ) + + def get_exporter(self): + """Get the exporter for the Domino draft model.""" + from modelopt.torch.export.plugins.hf_spec_export import DominoExporter + + return DominoExporter(self) + + def _current_lambda_base(self) -> float: + return float(getattr(self, "_lambda_base", self.dflash_lambda_base_start)) + + def _apply_domino_head(self, hidden, base_logits, input_ids, anchor_positions, n_blocks): + """Add the GRU causal correction to the suffix positions of each block. + + Args: + hidden: Draft backbone output [B, N*block_size, H]. + base_logits: Backbone logits [B, N*block_size, vocab]. + input_ids: Original token IDs [B, seq_len]. + anchor_positions: Anchor positions per block [B, N]. + n_blocks: Number of blocks N. + + Returns: + Corrected logits [B, N*block_size, vocab]. + """ + bsz, seq_len = input_ids.shape + bs = self.dflash_block_size + device = input_ids.device + suffix_start = self.dflash_module.pure_draft_prefix_len + + hidden4d = hidden.reshape(bsz, n_blocks, bs, hidden.size(-1)) + base4d = base_logits.reshape(bsz, n_blocks, bs, -1) + + # Teacher-forced previous tokens: the real token at anchor+j for j in [0, bs). + prev_offsets = torch.arange(bs, device=device).view(1, 1, -1) + prev_idx = (anchor_positions.unsqueeze(-1) + prev_offsets).clamp(max=seq_len - 1) + prev_ids = torch.gather(input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, prev_idx) + + block_emb = self._base_model_embeddings(prev_ids) # [B, N, bs, H] + gru_in = block_emb.reshape(bsz * n_blocks, bs, block_emb.size(-1)) + gru_out, _ = self.dflash_module.prefix_gru(gru_in) + gru_out = gru_out.reshape(bsz, n_blocks, bs, -1) + + # Causal state for suffix positions: gru_out[p] summarizes anchor+0..anchor+p. + prefix_states = gru_out[:, :, suffix_start:, :] + z_n = hidden4d[:, :, suffix_start:, :] + logits_e = self.dflash_module.embed_proj(torch.cat([z_n, prefix_states], dim=-1)) + + prefix_logits = base4d[:, :, :suffix_start, :] + suffix_logits = base4d[:, :, suffix_start:, :] + logits_e + final4d = torch.cat([prefix_logits, suffix_logits], dim=2) + return final4d.reshape(bsz, n_blocks * bs, -1) + + def _compute_domino_loss( + self, base_logits, final_logits, input_ids, anchor_positions, block_keep_mask, loss_mask + ): + """Compute the (1-lambda)*final + lambda*base weighted CE loss and accuracies. + + Uses next-token (shift_label) alignment: position k predicts the token at + anchor+k+1, and position 0 is *not* excluded (unlike base DFlash). + """ + bsz, seq_len = input_ids.shape + bs = self.dflash_block_size + n_blocks = anchor_positions.shape[1] + device = input_ids.device + + # shift_label=True: label for block position k is the token at anchor+k+1. + label_offsets = torch.arange(1, 1 + bs, device=device).view(1, 1, -1) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets + valid_label = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + + # Weight mask: valid block * in bounds * loss_mask. No pos-0 exclusion. + weight_mask = block_keep_mask.unsqueeze(-1).expand(-1, -1, bs).float() + weight_mask = weight_mask * valid_label.float() + orig_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + weight_mask = weight_mask * orig_loss_mask + + binary_eval_mask = weight_mask.view(-1) + + # Loss decay: exp(-k/gamma) so the first prediction (k=0) gets weight 1.0. + if self.dflash_loss_decay_factor > 0: + k = torch.arange(bs, device=device).view(1, 1, -1) + decay = torch.exp(-k.clamp(min=0).float() / self.dflash_loss_decay_factor) + weight_mask = weight_mask * decay + + flat_final = final_logits.reshape(-1, final_logits.size(-1)) + flat_base = base_logits.reshape(-1, base_logits.size(-1)) + flat_targets = target_ids.reshape(-1) + flat_weights = weight_mask.reshape(-1) + valid_count = flat_weights.sum() + 1e-6 + + lambda_base = self._current_lambda_base() + + if valid_count > 1.0: + final_loss = ( + F.cross_entropy(flat_final, flat_targets, reduction="none") * flat_weights + ).sum() / valid_count + base_loss = ( + F.cross_entropy(flat_base, flat_targets, reduction="none") * flat_weights + ).sum() / valid_count + loss = (1.0 - lambda_base) * final_loss + lambda_base * base_loss + + with torch.no_grad(): + eval_count = binary_eval_mask.sum() + 1e-6 + final_correct = (flat_final.argmax(dim=-1) == flat_targets) & ( + binary_eval_mask > 0.5 + ) + base_correct = (flat_base.argmax(dim=-1) == flat_targets) & (binary_eval_mask > 0.5) + accuracy = (final_correct.sum().float() / eval_count).item() + base_accuracy = (base_correct.sum().float() / eval_count).item() + else: + loss = flat_final.sum() * 0.0 + final_loss = loss + base_loss = loss + accuracy = 0.0 + base_accuracy = 0.0 + + metrics = { + "final_loss": final_loss.detach().item(), + "base_loss": base_loss.detach().item(), + "base_accuracy": base_accuracy, + "lambda_base": lambda_base, + } + return loss, accuracy, metrics + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + """Domino training forward: DFlash backbone + causal correction head + dual loss. + + Mirrors ``HFDFlashModel.forward`` for data preparation (reusing the inherited + anchor/noise/mask/position helpers), then applies the Domino head and the + two-term loss. Eval/offline-eval is delegated to the DFlash parent. + """ + if not self.training: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + bsz, seq_len = input_ids.shape + block_size = self.dflash_block_size + device = input_ids.device + + if seq_len % block_size != 0: + raise ValueError( + f"seq_len ({seq_len}) must be divisible by block_size ({block_size}). " + f"Adjust training_seq_len or use padding." + ) + + # 1. Target hidden states (Domino does not use target-logit KD). + if self.dflash_offline: + assert "base_model_outputs" in kwargs + base_outputs = DFlashBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) + target_hidden = base_outputs.target_hidden + else: + with torch.no_grad(): + base_out = self._base_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + offset = 1 + selected = [base_out.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + + # 2. Build loss mask (same convention as DFlash). + if labels is not None: + loss_mask = (labels != LabelSmoother.ignore_index).float() + elif attention_mask is not None: + loss_mask = attention_mask.float() + else: + loss_mask = torch.ones(bsz, seq_len, device=device) + if kwargs.get("loss_mask") is not None: + loss_mask = loss_mask * kwargs["loss_mask"] + + # 3. Random anchor sampling (inherited). + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + + if n_blocks == 0 or not block_keep_mask.any(): + # Zero loss that still flows through all draft params for DDP sync. + dummy = sum(p.sum() for p in self.dflash_module.parameters()) * 0.0 + return ModelOutput(loss=dummy, logits=None, train_acc=[[0.0]]) + + # 4. Build draft inputs (inherited helpers). + noise_embedding = self._build_noise_embedding( + input_ids, anchor_positions, block_keep_mask, n_blocks + ) + full_pos = self._build_position_ids(seq_len, anchor_positions, device) + attn_mask = self._build_draft_attention_mask( + seq_len, anchor_positions, block_keep_mask, n_blocks, target_hidden.dtype, device + ) + + # 5. Draft backbone forward. + hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=full_pos, + attention_mask=attn_mask, + ) + + # 6. Base + corrected logits, then dual loss. + base_logits = self._base_model_lm_head(hidden) + final_logits = self._apply_domino_head( + hidden, base_logits, input_ids, anchor_positions, n_blocks + ) + loss, accuracy, metrics = self._compute_domino_loss( + base_logits, final_logits, input_ids, anchor_positions, block_keep_mask, loss_mask + ) + + return ModelOutput(loss=loss, logits=None, train_acc=[[accuracy]], domino_metrics=metrics) + + +class DominoLambdaCallback(TrainerCallback): + """Update the model's ``lambda_base`` from the HF Trainer global step. + + Linearly decays the base-loss weight from ``lambda_base_start`` to 0 over + ``lambda_base_decay_ratio`` of total training steps. + """ + + def on_step_begin(self, args, state, control, **kwargs): + """Set ``model._lambda_base`` for the upcoming step.""" + model = kwargs.get("model") + if model is None: + return + # Unwrap DDP/FSDP if needed. + inner = getattr(model, "module", model) + if not hasattr(inner, "dflash_lambda_base_start"): + return + total_steps = state.max_steps if state.max_steps and state.max_steps > 0 else 1 + inner._lambda_base = compute_lambda_base( + state.global_step, + total_steps, + inner.dflash_lambda_base_start, + inner.dflash_lambda_base_decay_ratio, + ) diff --git a/modelopt/torch/speculative/plugins/modeling_domino.py b/modelopt/torch/speculative/plugins/modeling_domino.py new file mode 100644 index 00000000000..85dbcc31a8f --- /dev/null +++ b/modelopt/torch/speculative/plugins/modeling_domino.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Domino draft module — DFlash backbone plus a lightweight causal correction head. + +Domino extends the parallel DFlash draft backbone (``DFlashModule``) with a small +GRU-based correction head. The backbone produces *base* logits for a full draft +block in one parallel forward; the head then injects intra-block causal dependency +(which the parallel backbone lacks) by running a GRU over the block's previously +decoded tokens and adding a logit correction to the suffix positions. + +The head consists of: + - ``prefix_gru``: single-layer GRU over token embeddings of the block prefix, + producing a causal state summarizing tokens seen so far in the block. + - ``embed_proj``: MLP mapping ``[backbone_hidden ; gru_state]`` to a vocab-sized + logit correction. + +These two submodules live on ``DominoModule`` so they export under the +``dflash_module.`` prefix and serialize alongside the backbone (matching the +z-lab/SpecForge ``prefix_gru.*`` / ``embed_proj.*`` checkpoint layout). + +The head is *applied* by the training wrapper (``HFDominoModel``), which owns the +base model's embedding table; this module only holds the parameters. See +``hf_domino.py`` for the forward/loss orchestration. +""" + +from torch import nn + +from .modeling_dflash import DFlashModule + +__all__ = ["DominoModule"] + + +class DominoModule(DFlashModule): + """DFlash draft module augmented with the Domino causal correction head.""" + + def __init__(self, config): + """Initialize the DFlash backbone, then add the GRU + projection head.""" + super().__init__(config) + + self.projector_type = getattr(config, "projector_type", "domino") + self.gru_hidden_dim = config.gru_hidden_dim + self.emb_dim = config.emb_dim + # pure_draft_prefix_len positions at the block start keep base logits only + # (no causal correction); the GRU correction applies to the suffix. + self.pure_draft_prefix_len = getattr(config, "pure_draft_prefix_len", 1) + self.shift_label = getattr(config, "shift_label", True) + + # Causal state over the block's token embeddings. bias=False matches the + # reference checkpoint (only weight_ih_l0 / weight_hh_l0 are stored). + self.prefix_gru = nn.GRU( + input_size=config.hidden_size, + hidden_size=self.gru_hidden_dim, + num_layers=1, + batch_first=True, + bias=False, + ) + # [backbone_hidden ; gru_state] -> emb_dim -> vocab logit correction. + in_dim = config.hidden_size + self.gru_hidden_dim + self.embed_proj = nn.Sequential( + nn.Linear(in_dim, self.emb_dim, bias=False), + nn.SiLU(), + nn.Linear(self.emb_dim, config.vocab_size, bias=False), + ) + + # DFlashModule.__init__ already ran _init_weights before these modules + # existed, so initialize the new Linear layers explicitly. The GRU keeps + # PyTorch's default (uniform) init. + self._init_head_weights(config) + + def _init_head_weights(self, config): + """Initialize the correction-head Linear layers (GRU keeps default init).""" + std = getattr(config, "initializer_range", 0.02) + for module in self.embed_proj.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + nn.init.zeros_(module.bias) diff --git a/modelopt_recipes/general/speculative_decoding/domino.yaml b/modelopt_recipes/general/speculative_decoding/domino.yaml new file mode 100644 index 00000000000..256b9e03a45 --- /dev/null +++ b/modelopt_recipes/general/speculative_decoding/domino.yaml @@ -0,0 +1,86 @@ +# Domino speculative-decoding training recipe. +# +# Domino reuses the DFlash mode/pipeline and adds a causal correction head +# (GRU + projection), selected via dflash_architecture_config.projector_type=domino. +# Online training is the default path (data.mode=online). Override fields via an +# OmegaConf dotlist on the CLI. + +metadata: + recipe_type: speculative_dflash + description: Domino training recipe (DFlash backbone + causal correction head). + +# maps to ModelArguments (main.py) +model: + model_name_or_path: + trust_remote_code: false + use_fake_base_for_offline: false + +# maps to DataArguments (main.py) +data: + mode: online + data_path: + offline_data_path: + # Jinja chat template with {% generation %} tags for answer_only_loss. + chat_template: + +# maps to TrainingArguments (main.py) +training: + # --- commonly modified --- + output_dir: + num_train_epochs: 6 + per_device_train_batch_size: 1 + learning_rate: 6.0e-4 + warmup_ratio: 0.04 + training_seq_len: 3072 + logging_steps: 50 + save_steps: 2000 + cp_size: 1 + dp_shard_size: 1 + disable_tqdm: true + estimate_ar: false + ar_validate_steps: 0 + answer_only_loss: true + + # --- rarely modified --- + do_eval: false + lr_scheduler_type: linear + save_strategy: steps + weight_decay: 0.0 + max_grad_norm: 1.0 + dataloader_drop_last: true + bf16: true + tf32: true + remove_unused_columns: false + ddp_find_unused_parameters: true + ddp_timeout: 1800 + report_to: tensorboard + +# maps to DFlashConfig (modelopt/torch/speculative/config.py). +dflash: + dflash_block_size: 16 + dflash_num_anchors: 256 + dflash_use_torch_compile: false + # Domino does not use target-logit KD; it trains its own base/final CE losses. + dflash_self_logit_distillation: false + # gamma for exponential loss decay (block_size=16 -> 7). + dflash_loss_decay_factor: 7.0 + # Qwen3 has no native mask token; 151669 is an unused id used by the reference. + dflash_mask_token_id: 151669 + # lambda_base curriculum: start fully on base loss, decay to final over all steps. + dflash_lambda_base_start: 1.0 + dflash_lambda_base_decay_ratio: 1.0 + dflash_architecture_config: + num_hidden_layers: 5 + # Draft attention/MLP dims. DFlash's draft is an independent model and does + # NOT auto-inherit these from the base (a fresh Qwen3Config already carries + # defaults, so modify()'s inherit-if-missing guard is a no-op). Set them + # explicitly to match the Qwen3-8B reference drafter (GQA: 8 KV heads). + num_attention_heads: 32 + num_key_value_heads: 8 + head_dim: 128 + intermediate_size: 12288 + projector_type: domino + emb_dim: 256 + gru_hidden_dim: 1024 + pure_draft_prefix_len: 1 + shift_label: true diff --git a/tests/unit/torch/speculative/plugins/test_hf_domino.py b/tests/unit/torch/speculative/plugins/test_hf_domino.py new file mode 100644 index 00000000000..1cdfa74a45f --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_hf_domino.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for the Domino speculative decoding plugin. + +Domino reuses the DFlash mode/pipeline and adds a GRU-based causal correction +head. These tests cover conversion routing, the training forward (base + final +dual loss), and the export format (weights + config) against the z-lab reference +layout (``prefix_gru.*`` / ``embed_proj.*``). +""" + +import json +from copy import deepcopy + +import torch +from _test_utils.torch.transformers_models import get_tiny_llama + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG +from modelopt.torch.speculative.plugins.hf_dflash import HFDFlashModel +from modelopt.torch.speculative.plugins.hf_domino import HFDominoModel, compute_lambda_base +from modelopt.torch.speculative.plugins.modeling_domino import DominoModule + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be a multiple of BLOCK_SIZE +GRU_HIDDEN_DIM = 32 +EMB_DIM = 16 + + +def _get_domino_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a Domino config for testing (dflash mode + projector_type=domino).""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_mask_token_id"] = 0 # token 0 as mask for the tiny model + config["dflash_self_logit_distillation"] = False + config["dflash_loss_decay_factor"] = 4.0 + config["dflash_lambda_base_start"] = 1.0 + config["dflash_lambda_base_decay_ratio"] = 1.0 + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "projector_type": "domino", + "gru_hidden_dim": GRU_HIDDEN_DIM, + "emb_dim": EMB_DIM, + "pure_draft_prefix_len": 1, + "shift_label": True, + } + return config + + +class TestDominoConvert: + """Test Domino model conversion routing.""" + + def test_convert_creates_domino_model(self): + """projector_type=domino routes to HFDominoModel (a HFDFlashModel subclass).""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_domino_config())]) + assert isinstance(model, HFDominoModel) + assert isinstance(model, HFDFlashModel) + + def test_convert_attaches_domino_module_with_head(self): + """The draft module is a DominoModule with prefix_gru + embed_proj.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_domino_config())]) + assert isinstance(model.dflash_module, DominoModule) + assert isinstance(model.dflash_module.prefix_gru, torch.nn.GRU) + assert model.dflash_module.prefix_gru.bias is False + # embed_proj: Linear(H+gru -> emb) -> SiLU -> Linear(emb -> vocab) + assert model.dflash_module.embed_proj[0].in_features == ( + model.dflash_config.hidden_size + GRU_HIDDEN_DIM + ) + assert model.dflash_module.embed_proj[2].out_features == model.dflash_config.vocab_size + + def test_head_params_trainable(self): + """The GRU + projection head parameters are trainable.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_domino_config())]) + head = [ + (n, p) for n, p in model.named_parameters() if "prefix_gru" in n or "embed_proj" in n + ] + assert len(head) >= 3 # weight_ih_l0, weight_hh_l0, 2x embed_proj + assert all(p.requires_grad for _, p in head) + + def test_dflash_mode_still_creates_plain_dflash(self): + """Without projector_type=domino, conversion still yields a plain DFlash model.""" + from modelopt.torch.speculative.plugins.modeling_dflash import DFlashModule + + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_mask_token_id"] = 0 + config["dflash_architecture_config"] = {"num_hidden_layers": NUM_DRAFT_LAYERS} + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", config)]) + assert isinstance(model, HFDFlashModel) + assert not isinstance(model, HFDominoModel) + assert type(model.dflash_module) is DFlashModule + + +class TestDominoForward: + """Test the Domino training forward (online path on CPU).""" + + def _make_batch(self, vocab_size): + torch.manual_seed(0) + input_ids = torch.randint(1, vocab_size, (2, SEQ_LEN)) + attention_mask = torch.ones_like(input_ids) + labels = input_ids.clone() + return input_ids, attention_mask, labels + + def test_forward_produces_dual_loss_and_grads(self): + """Forward returns a scalar loss; backward populates head + backbone grads.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_domino_config())]) + model.train() + + vocab = model.dflash_config.vocab_size + input_ids, attention_mask, labels = self._make_batch(vocab) + + out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert out.loss.requires_grad + assert out.loss.dim() == 0 + # dual-loss bookkeeping + assert "base_loss" in out.domino_metrics + assert "final_loss" in out.domino_metrics + assert "base_accuracy" in out.domino_metrics + assert out.domino_metrics["lambda_base"] == 1.0 # default before any callback + + out.loss.backward() + gru_grad = model.dflash_module.prefix_gru.weight_ih_l0.grad + proj_grad = model.dflash_module.embed_proj[2].weight.grad + backbone_grad = model.dflash_module.fc.weight.grad + assert gru_grad is not None and torch.isfinite(gru_grad).all() + assert proj_grad is not None and torch.isfinite(proj_grad).all() + assert backbone_grad is not None and torch.isfinite(backbone_grad).all() + + def test_lambda_zero_uses_final_only(self): + """With lambda_base=0 the loss equals final_loss (correction head only).""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_domino_config())]) + model.train() + model._lambda_base = 0.0 + + vocab = model.dflash_config.vocab_size + input_ids, attention_mask, labels = self._make_batch(vocab) + out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert abs(out.loss.item() - out.domino_metrics["final_loss"]) < 1e-4 + + +class TestLambdaSchedule: + """Test the lambda_base curriculum schedule.""" + + def test_linear_decay(self): + assert compute_lambda_base(0, 100, 1.0, 1.0) == 1.0 + assert abs(compute_lambda_base(50, 100, 1.0, 1.0) - 0.5) < 1e-6 + assert compute_lambda_base(100, 100, 1.0, 1.0) == 0.0 + # decay_ratio=0.5 → fully decayed at the halfway point + assert compute_lambda_base(50, 100, 1.0, 0.5) == 0.0 + + +class TestDominoExporter: + """Test the Domino checkpoint export format (z-lab reference layout).""" + + def _export(self, tmp_path): + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_domino_config())]) + exporter = model.get_exporter() + export_dir = tmp_path / "exported" + exporter.export(export_dir) + return export_dir + + def test_export_weight_keys_match_reference(self, tmp_path): + """Exported weights include head tensors under the reference names, no prefix.""" + from safetensors.torch import load_file + + export_dir = self._export(tmp_path) + sd = load_file(str(export_dir / "model.safetensors")) + for key in sd: + assert "dflash_module." not in key + assert "rotary_emb" not in key + assert "prefix_gru.weight_ih_l0" in sd + assert "prefix_gru.weight_hh_l0" in sd + assert "embed_proj.0.weight" in sd + assert "embed_proj.2.weight" in sd + # GRU stores no bias (bias=False) + assert "prefix_gru.bias_ih_l0" not in sd + + def test_export_config_has_domino_fields(self, tmp_path): + """config.json carries the dflash_config domino fields + top-level emb_dim.""" + export_dir = self._export(tmp_path) + with open(export_dir / "config.json") as f: + cfg = json.load(f) + + assert cfg["architectures"] == ["DFlashDraftModel"] + assert cfg["emb_dim"] == EMB_DIM + dc = cfg["dflash_config"] + assert dc["projector_type"] == "domino" + assert dc["shift_label"] is True + assert dc["pure_draft_prefix_len"] == 1 + assert dc["gru_hidden_dim"] == GRU_HIDDEN_DIM + assert dc["emb_dim"] == EMB_DIM + assert "mask_token_id" in dc + assert "target_layer_ids" in dc diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml new file mode 100644 index 00000000000..8d23c22101e --- /dev/null +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml @@ -0,0 +1,60 @@ +# Domino online speculative decoding training for Qwen3-8B. +# +# Domino = the parallel DFlash draft backbone + a lightweight causal correction +# head (GRU over the block's previously decoded tokens -> logit correction on the +# block suffix), trained with a base/final dual loss whose lambda_base weight is +# decayed from 1->0 over training (curriculum). See the domino.yaml recipe and +# modelopt/torch/speculative/plugins/{modeling,hf}_domino.py. +# +# 1-step pipeline (training only): +# task_0: Online Domino training + export of the drafter checkpoint +# +# NOTE: the inference side (vLLM / AR evaluation) is intentionally not wired up +# yet — the Domino correction head is not applied in pseudo_speculative_generate +# or in the serving stack. Add the vLLM smoke-test / MT-Bench AR-eval steps +# (see hf_online_dflash.yaml task_1/task_2) once the inference path lands. +# +# Reference: SpecForge PR #571 (z-lab) | drafter format: +# huggingface.co/Huang2020/Qwen3-8B-Domino-b16 +# +# Usage: +# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_domino.yaml --yes +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml --yes + +job_name: Qwen3-8B_Domino_online +pipeline: + global_vars: + hf_model: /hf-local/Qwen/Qwen3-8B + + # Step 1: Online Domino training (the script exports the drafter at the end). + task_0: + script: common/specdec/dflash_online_training.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/domino.yaml + - model.model_name_or_path=<> + - data.data_path=/hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-100K-openai.jsonl + - data.chat_template=examples/Qwen/Qwen3-8B/chat_template_train.jinja + - training.output_dir=/scratchspace/domino_bs16 + - training.per_device_train_batch_size=1 + - training.num_train_epochs=1 + - training.training_seq_len=4096 + - training.save_steps=5000 + - training.logging_steps=100 + - training.disable_tqdm=true + - training.answer_only_loss=true + # Domino knobs (also set in the recipe; repeated here for visibility). + - dflash.dflash_block_size=16 + - dflash.dflash_num_anchors=512 + - dflash.dflash_loss_decay_factor=7 + - dflash.dflash_mask_token_id=151669 + - dflash.dflash_self_logit_distillation=false + - dflash.dflash_lambda_base_start=1.0 + - dflash.dflash_lambda_base_decay_ratio=1.0 + environment: + - MAX_FINAL_LOSS: "5.0" + - MIN_FINAL_ACC: "0.15" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 From f0a1a99de1e6d95b3e3a86141559e77b16bc6c98 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 14 Jun 2026 01:43:24 +0000 Subject: [PATCH 2/5] add domino support Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../Qwen/Qwen3-8B/hf_online_domino.yaml | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml index 8d23c22101e..f3095079a15 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_domino.yaml @@ -6,8 +6,9 @@ # decayed from 1->0 over training (curriculum). See the domino.yaml recipe and # modelopt/torch/speculative/plugins/{modeling,hf}_domino.py. # -# 1-step pipeline (training only): -# task_0: Online Domino training + export of the drafter checkpoint +# 2-step pipeline: +# task_0: Build training conversations (Daring-Anteater multi-turn SFT, 50K) +# task_1: Online Domino training + export of the drafter checkpoint # # NOTE: the inference side (vLLM / AR evaluation) is intentionally not wired up # yet — the Domino correction head is not applied in pseudo_speculative_generate @@ -26,17 +27,35 @@ pipeline: global_vars: hf_model: /hf-local/Qwen/Qwen3-8B - # Step 1: Online Domino training (the script exports the drafter at the end). + # Step 1: Build input conversations. example_data_config.yaml enables only the + # daring-anteater source (train: 50000) — multi-turn SFT with real assistant + # completions. --full-conversations keeps those completions so answer_only_loss + # has assistant spans to mask. make_dataset.sh writes /scratchspace/data/train.jsonl. task_0: + script: common/eagle3/make_dataset.sh + args: + - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml + - --full-conversations + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + # Step 2: Online Domino training (the script exports the drafter at the end). + # Consumes the conversations built in task_0 (shared via /scratchspace). + task_1: script: common/specdec/dflash_online_training.sh args: - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/domino.yaml - model.model_name_or_path=<> - - data.data_path=/hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-100K-openai.jsonl + - data.data_path=/scratchspace/data/train.jsonl - data.chat_template=examples/Qwen/Qwen3-8B/chat_template_train.jinja - training.output_dir=/scratchspace/domino_bs16 - training.per_device_train_batch_size=1 - training.num_train_epochs=1 + - training.max_steps=2000 - training.training_seq_len=4096 - training.save_steps=5000 - training.logging_steps=100 From 9d904c39deb72c2ebb08f653944a5ebf5978787c Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 15 Jun 2026 18:53:32 +0000 Subject: [PATCH 3/5] Add Domino changelog entry under 0.46 Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 49c58586674..e220aff20f0 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ Changelog - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. - Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``. +- Add **Domino** speculative-decoding training: the parallel DFlash draft backbone plus a lightweight GRU causal correction head, selected via ``dflash_architecture_config.projector_type=domino``. Trained with a base/final dual loss whose ``dflash_lambda_base_start``/``dflash_lambda_base_decay_ratio`` curriculum decays the base-loss weight 1→0. Exports in the z-lab drafter format; recipe at ``modelopt_recipes/general/speculative_decoding/domino.yaml``. Training only — the inference path is not wired up yet. 0.45 (2026-06-xx) ^^^^^^^^^^^^^^^^^ From ba737bc13ec5957d372679068555a1056dfd500c Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 15 Jun 2026 21:40:40 +0000 Subject: [PATCH 4/5] coderabbit comments Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/main.py | 3 +-- modelopt/torch/speculative/config.py | 4 ++++ modelopt/torch/speculative/dflash/conversion.py | 15 ++++++++++----- .../torch/speculative/plugins/modeling_domino.py | 5 +++++ .../torch/speculative/plugins/test_hf_domino.py | 6 ++---- 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index b998c702a7f..cd350bef1ed 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -54,6 +54,7 @@ ModelOptMedusaRecipe, ModelOptSpeculativeRecipeBase, ) +from modelopt.torch.speculative.plugins.hf_domino import DominoLambdaCallback from modelopt.torch.speculative.plugins.hf_training_args import ( TrainingArguments as SpecTrainingArgs, ) @@ -282,8 +283,6 @@ def train(): isinstance(recipe, ModelOptDFlashRecipe) and recipe.dflash.dflash_architecture_config.get("projector_type") == "domino" ): - from modelopt.torch.speculative.plugins.hf_domino import DominoLambdaCallback - callbacks.append(DominoLambdaCallback()) # Leave training_args.ignore_data_skip at its default (False). The dataset is # map-style, so HF Trainer's resume skips consumed indices at the batch-sampler diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 2982273265c..e554355958f 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -134,6 +134,8 @@ class DFlashConfig(ModeloptBaseConfig): dflash_lambda_base_start: float = ModeloptField( default=1.0, + ge=0.0, + le=1.0, description=( "Domino only: initial weight of the base (backbone-only) loss in the " "loss = (1 - lambda)*final + lambda*base mixture; linearly decayed to 0. " @@ -143,6 +145,8 @@ class DFlashConfig(ModeloptBaseConfig): dflash_lambda_base_decay_ratio: float = ModeloptField( default=1.0, + gt=0.0, + le=1.0, description=( "Domino only: fraction of total training steps over which lambda_base " "decays from dflash_lambda_base_start to 0." diff --git a/modelopt/torch/speculative/dflash/conversion.py b/modelopt/torch/speculative/dflash/conversion.py index 03406088dd1..b5cb82c4db1 100644 --- a/modelopt/torch/speculative/dflash/conversion.py +++ b/modelopt/torch/speculative/dflash/conversion.py @@ -42,11 +42,16 @@ def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertRe config.dflash_architecture_config = {**default_dflash_config, **custom_config} # Route to the Domino registry when the architecture asks for the Domino head. - registry = ( - DominoDMRegistry - if config.dflash_architecture_config.get("projector_type") == "domino" - else DFlashDMRegistry - ) + projector_type = config.dflash_architecture_config.get("projector_type") + if projector_type == "domino": + registry = DominoDMRegistry + elif projector_type in (None, "dflash"): + registry = DFlashDMRegistry + else: + raise ValueError( + f"Unsupported dflash_architecture_config.projector_type: {projector_type!r}. " + "Expected 'dflash' (default) or 'domino'." + ) original_cls = type(model) if original_cls not in registry: diff --git a/modelopt/torch/speculative/plugins/modeling_domino.py b/modelopt/torch/speculative/plugins/modeling_domino.py index 85dbcc31a8f..de598d9fd3a 100644 --- a/modelopt/torch/speculative/plugins/modeling_domino.py +++ b/modelopt/torch/speculative/plugins/modeling_domino.py @@ -56,6 +56,11 @@ def __init__(self, config): # pure_draft_prefix_len positions at the block start keep base logits only # (no causal correction); the GRU correction applies to the suffix. self.pure_draft_prefix_len = getattr(config, "pure_draft_prefix_len", 1) + if not 0 <= self.pure_draft_prefix_len < self.block_size: + raise ValueError( + f"pure_draft_prefix_len must be in [0, {self.block_size - 1}] " + f"(block_size={self.block_size}), got {self.pure_draft_prefix_len}." + ) self.shift_label = getattr(config, "shift_label", True) # Causal state over the block's token embeddings. bias=False matches the diff --git a/tests/unit/torch/speculative/plugins/test_hf_domino.py b/tests/unit/torch/speculative/plugins/test_hf_domino.py index 1cdfa74a45f..cff275895b7 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_domino.py +++ b/tests/unit/torch/speculative/plugins/test_hf_domino.py @@ -26,11 +26,13 @@ import torch from _test_utils.torch.transformers_models import get_tiny_llama +from safetensors.torch import load_file import modelopt.torch.speculative as mtsp from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG from modelopt.torch.speculative.plugins.hf_dflash import HFDFlashModel from modelopt.torch.speculative.plugins.hf_domino import HFDominoModel, compute_lambda_base +from modelopt.torch.speculative.plugins.modeling_dflash import DFlashModule from modelopt.torch.speculative.plugins.modeling_domino import DominoModule BLOCK_SIZE = 4 @@ -96,8 +98,6 @@ def test_head_params_trainable(self): def test_dflash_mode_still_creates_plain_dflash(self): """Without projector_type=domino, conversion still yields a plain DFlash model.""" - from modelopt.torch.speculative.plugins.modeling_dflash import DFlashModule - config = deepcopy(DFLASH_DEFAULT_CFG["config"]) config["dflash_mask_token_id"] = 0 config["dflash_architecture_config"] = {"num_hidden_layers": NUM_DRAFT_LAYERS} @@ -181,8 +181,6 @@ def _export(self, tmp_path): def test_export_weight_keys_match_reference(self, tmp_path): """Exported weights include head tensors under the reference names, no prefix.""" - from safetensors.torch import load_file - export_dir = self._export(tmp_path) sd = load_file(str(export_dir / "model.safetensors")) for key in sd: From 35437c5f692799eb30e1c77ed3337a885af3380f Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Mon, 15 Jun 2026 21:52:04 +0000 Subject: [PATCH 5/5] address comments Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../torch/export/plugins/hf_spec_export.py | 5 ++-- .../torch/speculative/plugins/hf_domino.py | 30 ++++++++++++++++++- .../general/speculative_decoding/domino.yaml | 4 +++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index f3d153b0bc4..5664f138a7e 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -452,8 +452,9 @@ def _export_config(self): config = super()._export_config() draft_config = self.model.dflash_config - emb_dim = getattr(draft_config, "emb_dim") - gru_hidden_dim = getattr(draft_config, "gru_hidden_dim") + # Present because HFDominoModel.modify validates them at convert time. + emb_dim = draft_config.emb_dim + gru_hidden_dim = draft_config.gru_hidden_dim # Mirror the reference checkpoint: emb_dim also appears at the top level. config["emb_dim"] = emb_dim config["dflash_config"].update( diff --git a/modelopt/torch/speculative/plugins/hf_domino.py b/modelopt/torch/speculative/plugins/hf_domino.py index e4d37412d13..6ee7a52ddce 100644 --- a/modelopt/torch/speculative/plugins/hf_domino.py +++ b/modelopt/torch/speculative/plugins/hf_domino.py @@ -86,6 +86,15 @@ def _build_draft_module(self, dflash_config): def modify(self, config): """Initialize the Domino draft module and read the lambda_base schedule.""" + # Validate head fields up front: a clear error here beats a cryptic + # AttributeError later in DominoModule.__init__ or the exporter. + arch_config = config.dflash_architecture_config + missing = [k for k in ("emb_dim", "gru_hidden_dim") if arch_config.get(k) is None] + if missing: + raise ValueError( + f"Domino (projector_type='domino') requires {missing} in " + "dflash_architecture_config (the GRU correction head dimensions)." + ) super().modify(config) # Curriculum schedule for the base/final loss mixing weight. Read here # (DFlashConfig carries the two fields); updated each step by @@ -249,6 +258,14 @@ def forward( two-term loss. Eval/offline-eval is delegated to the DFlash parent. """ if not self.training: + # Eval delegates to the DFlash backbone; the correction head is not + # applied yet, so warn once that acceptance rates are backbone-only. + if not getattr(self, "_warned_eval_head_bypass", False): + logger.warning( + "Domino eval uses the DFlash backbone only (correction head not " + "applied yet); reported acceptance rates are backbone-only." + ) + self._warned_eval_head_bypass = True return super().forward( input_ids=input_ids, attention_mask=attention_mask, @@ -355,7 +372,18 @@ def on_step_begin(self, args, state, control, **kwargs): inner = getattr(model, "module", model) if not hasattr(inner, "dflash_lambda_base_start"): return - total_steps = state.max_steps if state.max_steps and state.max_steps > 0 else 1 + if state.max_steps and state.max_steps > 0: + total_steps = state.max_steps + else: + # No max_steps -> decay window is one step -> lambda_base is 0 from the + # start, disabling the curriculum. Warn once instead of doing it silently. + total_steps = 1 + if not getattr(self, "_warned_no_max_steps", False): + logger.warning( + "DominoLambdaCallback: state.max_steps unset (<=0); lambda_base " + "curriculum disabled (decays to 0 from the first step)." + ) + self._warned_no_max_steps = True inner._lambda_base = compute_lambda_base( state.global_step, total_steps, diff --git a/modelopt_recipes/general/speculative_decoding/domino.yaml b/modelopt_recipes/general/speculative_decoding/domino.yaml index 256b9e03a45..6dc21ea66af 100644 --- a/modelopt_recipes/general/speculative_decoding/domino.yaml +++ b/modelopt_recipes/general/speculative_decoding/domino.yaml @@ -37,6 +37,8 @@ training: cp_size: 1 dp_shard_size: 1 disable_tqdm: true + # Keep off: eval runs the DFlash backbone only (correction head not applied + # yet), so AR would reflect the backbone alone, not the trained model. estimate_ar: false ar_validate_steps: 0 answer_only_loss: true @@ -51,6 +53,8 @@ training: bf16: true tf32: true remove_unused_columns: false + # Required: while lambda_base == 1 the head params are unused in the backward + # graph, so DDP needs unused params allowed. ddp_find_unused_parameters: true ddp_timeout: 1800 report_to: tensorboard