[Feat]: Domino support #1710
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -435,3 +435,34 @@ def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): | |
| f"Exported DFlash draft model: {len(drafter_sd)} tensors, " | ||
| f"config keys: {list(drafter_config.keys())[:5]}..." | ||
| ) | ||
|
|
||
|
|
||
| class DominoExporter(DFlashExporter): | ||
| """Draft model exporter for Domino (DFlash backbone + causal correction head). | ||
|
|
||
| Same z-lab-compatible format as DFlash, plus the Domino head weights | ||
| (``prefix_gru.*`` / ``embed_proj.*``, already captured by the inherited | ||
| ``dflash_module.`` stripping) and the extra config fields the loader needs to | ||
| rebuild the head (``projector_type``, ``emb_dim``, ``gru_hidden_dim``, | ||
| ``pure_draft_prefix_len``, ``shift_label``). | ||
| """ | ||
|
|
||
| def _export_config(self): | ||
| """Extend the DFlash config with the Domino head fields.""" | ||
| config = super()._export_config() | ||
| draft_config = self.model.dflash_config | ||
|
|
||
| emb_dim = getattr(draft_config, "emb_dim") | ||
| gru_hidden_dim = getattr(draft_config, "gru_hidden_dim") | ||
|
Comment on lines +455 to +456
[SUGGESTION] The error happens at export time, after a long training run, far from where the misconfiguration was introduced. Two cleaner options:
Prefer the second; it also catches the equivalent |
||
| # Mirror the reference checkpoint: emb_dim also appears at the top level. | ||
| config["emb_dim"] = emb_dim | ||
| config["dflash_config"].update( | ||
| { | ||
| "projector_type": getattr(draft_config, "projector_type", "domino"), | ||
| "shift_label": getattr(draft_config, "shift_label", True), | ||
| "pure_draft_prefix_len": getattr(draft_config, "pure_draft_prefix_len", 1), | ||
| "gru_hidden_dim": gru_hidden_dim, | ||
| "emb_dim": emb_dim, | ||
| } | ||
| ) | ||
| return config | ||
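The concern raised on `emb_dim` / `gru_hidden_dim` is about when a missing field gets noticed. A minimal sketch of the fail-fast idea follows, assuming a hypothetical `DominoHeadConfig` dataclass that is not part of this PR and is not necessarily either of the reviewer's options. It checks the required Domino fields when the config is built rather than when the exporter reads them.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DominoHeadConfig:
    """Hypothetical stand-in for the Domino part of the draft config."""

    projector_type: str = "domino"
    emb_dim: Optional[int] = None
    gru_hidden_dim: Optional[int] = None

    def __post_init__(self):
        # Only the Domino head needs these two fields.
        if self.projector_type != "domino":
            return
        missing = [
            name
            for name in ("emb_dim", "gru_hidden_dim")
            if getattr(self, name) is None
        ]
        if missing:
            # Raised at construction time, before any training step runs.
            raise ValueError(f"Domino config is missing required fields: {missing}")
```

With a check like this, a config that omits `gru_hidden_dim` fails when it is created, so the export path can read both fields without a late `AttributeError`.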
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -132,6 +132,23 @@ class DFlashConfig(ModeloptBaseConfig): | |
| description="Whether to use torch.compile on DFlash forward/loss methods.", | ||
| ) | ||
|
|
||
| dflash_lambda_base_start: float = ModeloptField( | ||
| default=1.0, | ||
| description=( | ||
| "Domino only: initial weight of the base (backbone-only) loss in the " | ||
| "loss = (1 - lambda)*final + lambda*base mixture; linearly decayed to 0. " | ||
| "Ignored unless dflash_architecture_config.projector_type == 'domino'." | ||
| ), | ||
| ) | ||
|
|
||
| dflash_lambda_base_decay_ratio: float = ModeloptField( | ||
| default=1.0, | ||
| description=( | ||
| "Domino only: fraction of total training steps over which lambda_base " | ||
| "decays from dflash_lambda_base_start to 0." | ||
| ), | ||
| ) | ||
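The two field descriptions above imply a simple schedule. The sketch below is one plausible reading, not the PR's actual callback: it assumes `lambda_base` falls linearly from its start value to 0 over `decay_ratio * total_steps` steps and stays at 0 afterwards. The function names are hypothetical.

```python
def lambda_base(step: int, total_steps: int, start: float = 1.0, decay_ratio: float = 1.0) -> float:
    """Linearly decay the base-loss weight from `start` to 0 (assumed schedule)."""
    decay_steps = decay_ratio * total_steps
    if decay_steps <= 0:
        # A zero or negative decay window would make the schedule undefined.
        raise ValueError("decay_ratio * total_steps must be positive")
    return start * max(0.0, 1.0 - step / decay_steps)


def mixed_loss(final_loss: float, base_loss: float, lam: float) -> float:
    """Mixture quoted in the field description: (1 - lambda)*final + lambda*base."""
    return (1.0 - lam) * final_loss + lam * base_loss
```

Under this reading, the default `start=1.0` means training begins on the backbone-only loss and shifts entirely to the final loss once the decay window ends.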
|
Comment on lines +135 to +150
Contributor
Add schema bounds for Domino lambda fields to fail fast on invalid configs.
Suggested fix:
  dflash_lambda_base_start: float = ModeloptField(
      default=1.0,
+     ge=0.0,
+     le=1.0,
      description=(
          "Domino only: initial weight of the base (backbone-only) loss in the "
          "loss = (1 - lambda)*final + lambda*base mixture; linearly decayed to 0. "
          "Ignored unless dflash_architecture_config.projector_type == 'domino'."
      ),
  )
  dflash_lambda_base_decay_ratio: float = ModeloptField(
      default=1.0,
+     gt=0.0,
+     le=1.0,
      description=(
          "Domino only: fraction of total training steps over which lambda_base "
          "decays from dflash_lambda_base_start to 0."
      ),
  )
As per coding guidelines, “Validate external input once at the interface boundary.” |
||
|
|
||
|
|
||
| class MedusaConfig(ModeloptBaseConfig): | ||
| """Medusa config.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,26 +24,38 @@ | |
| from ..config import DFlashConfig | ||
|
|
||
| DFlashDMRegistry = _DMRegistryCls(prefix="DFlash") # global instance for the registry | ||
| # Domino reuses the dflash mode/config/recipe but converts the base model to a | ||
| # DFlash module augmented with a causal correction head. It is selected via | ||
| # ``dflash_architecture_config.projector_type == "domino"`` and lives in its own | ||
| # registry so its wrapper (HFDominoModel) does not overwrite HFDFlashModel. | ||
| DominoDMRegistry = _DMRegistryCls(prefix="Domino") | ||
|
|
||
|
|
||
| def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType: | ||
| """Convert the model to a DFlash model as per `config`.""" | ||
| """Convert the model to a DFlash (or Domino) model as per `config`.""" | ||
| model = model.init_modellike() if isinstance(model, ModelLikeModule) else model | ||
|
|
||
| original_cls = type(model) | ||
| if original_cls not in DFlashDMRegistry: | ||
| for cls in DFlashDMRegistry._registry: | ||
| if issubclass(original_cls, cls): | ||
| DFlashDMRegistry.register({original_cls: "base_model_class"})(DFlashDMRegistry[cls]) | ||
| break | ||
|
|
||
| # merge custom config with default config (lazy import to avoid circular) | ||
| from .default_config import default_dflash_config | ||
|
|
||
| custom_config = config.dflash_architecture_config | ||
| config.dflash_architecture_config = {**default_dflash_config, **custom_config} | ||
|
|
||
| dflash_model = DFlashDMRegistry.convert(model) | ||
| # Route to the Domino registry when the architecture asks for the Domino head. | ||
| registry = ( | ||
| DominoDMRegistry | ||
| if config.dflash_architecture_config.get("projector_type") == "domino" | ||
| else DFlashDMRegistry | ||
| ) | ||
|
Comment on lines +44 to +49
Contributor
Reject unsupported `projector_type` values. Line 47/48 currently treats any unknown `projector_type` value as DFlash.
Suggested fix:
- registry = (
-     DominoDMRegistry
-     if config.dflash_architecture_config.get("projector_type") == "domino"
-     else DFlashDMRegistry
- )
+ projector_type = config.dflash_architecture_config.get("projector_type")
+ if projector_type == "domino":
+     registry = DominoDMRegistry
+ elif projector_type in (None, "dflash"):
+     registry = DFlashDMRegistry
+ else:
+     raise ValueError(f"Unsupported dflash_architecture_config.projector_type: {projector_type!r}")
As per coding guidelines, “Validate external input once at the interface boundary.” |
||
|
|
||
| original_cls = type(model) | ||
| if original_cls not in registry: | ||
| for cls in registry._registry: | ||
| if issubclass(original_cls, cls): | ||
| registry.register({original_cls: "base_model_class"})(registry[cls]) | ||
| break | ||
|
|
||
| dflash_model = registry.convert(model) | ||
| dflash_model.modify(config) | ||
|
|
||
| metadata = {} | ||
|
|
||
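The conversion hunk above does two things: it picks a registry from `projector_type`, then lets a model class reuse the wrapper registered for one of its base classes. A simplified sketch of both steps follows. `TinyRegistry` and `pick_registry` are hypothetical stand-ins, not the real `_DMRegistryCls` API, and the selection logic follows the stricter variant from the review suggestion rather than the code as submitted.

```python
class TinyRegistry:
    """Hypothetical, simplified stand-in for a dynamic-module registry."""

    def __init__(self, prefix: str):
        self.prefix = prefix
        self._registry = {}  # maps base model class -> wrapper class

    def register(self, base_cls: type, wrapper_cls: type) -> None:
        self._registry[base_cls] = wrapper_cls

    def resolve(self, model_cls: type) -> type:
        """Return the wrapper for model_cls, reusing a registered base class if needed."""
        if model_cls not in self._registry:
            for base_cls, wrapper_cls in list(self._registry.items()):
                if issubclass(model_cls, base_cls):
                    # Cache the subclass so later lookups are direct.
                    self._registry[model_cls] = wrapper_cls
                    break
        return self._registry[model_cls]  # KeyError if nothing matched


def pick_registry(arch_config: dict, dflash: TinyRegistry, domino: TinyRegistry) -> TinyRegistry:
    """Route to the Domino registry only on an explicit request; reject unknown values."""
    projector_type = arch_config.get("projector_type")
    if projector_type == "domino":
        return domino
    if projector_type in (None, "dflash"):
        return dflash
    raise ValueError(f"Unsupported projector_type: {projector_type!r}")
```

Keeping two registries means registering a subclass in one never touches the other, which matches the stated reason for giving Domino its own registry.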
Move `DominoLambdaCallback` import to module scope or add explicit justification.
Line 285 introduces an in-function import without documenting a circular/optional/heavy-import reason, which conflicts with the repo import-placement rule and can defer import failures to runtime.
As per coding guidelines, “Keep imports at the top of the file… Put an import inside a function only when there is a concrete reason… Add a brief comment in those cases naming the reason.”
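For illustration, a function-local import that satisfies the quoted guideline might look like the sketch below. The function name is hypothetical, and `json` merely stands in for a heavy or optional dependency; this is not the code at line 285.

```python
def parse_payload(payload: str) -> dict:
    """Parse a JSON payload, importing the parser only when this path runs."""
    # Function-local import. Reason: in this sketch, json stands in for a
    # heavy or optional dependency that should not load at module import time.
    import json

    return json.loads(payload)
```

The comment names the reason next to the import, so a reader does not have to guess whether the placement is deliberate.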