feat: add LoRA infrastructure for Archon engine (Phase 1 & 2) #1000
NJX-njx wants to merge 2 commits into inclusionAI:main
Conversation
Port LoRA core infrastructure and checkpointing from the fw/archon-lora branch to main, implementing Phase 1 (core LoRA modules) and Phase 2 (PEFT-compatible checkpoint I/O) from the lora-global-plan.

Phase 1 - Core LoRA Infrastructure:
- LoRALinear module following torchtune patterns (FSDP2-compatible)
- AdapterModule protocol for parameter extraction
- Utilities: get_adapter_params, set_trainable_params, disable/enable_adapter
- PEFT-compatible naming (lowercase lora_a/lora_b internally)
- Zero-initialization of lora_b ensures initial output matches base model

Phase 2 - Checkpointing & PEFT Conversion:
- save_lora_adapter: save adapter in PEFT format (safetensors + config)
- load_lora_adapter: load adapter from PEFT format with key validation
- is_lora_adapter_checkpoint: detect PEFT adapter checkpoints
- Qwen2StateDictAdapter: 16 LoRA key mappings + to_peft_module_map
- BaseStateDictAdapter: create_peft_adapter_config method
- ArchonEngine: lora_config attribute, LoRA-aware save/load

Tests:
- 18+ unit tests for LoRALinear (forward, gradient, from_linear, dropout)
- PEFT compatibility tests (skipped if PEFT not installed)
- State dict key conversion tests (all 16 mappings)
- Checkpoint detection tests
- Round-trip conversion tests

Ref inclusionAI#945
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Archon engine by integrating Low-Rank Adaptation (LoRA) capabilities, enabling efficient fine-tuning of large language models. It establishes the foundational modules for LoRA, ensures compatibility with the HuggingFace PEFT checkpoint format for seamless model sharing, and integrates these features directly into the engine's save and load mechanisms. The design adheres to torchtune patterns, supporting FSDP2, and provides robust testing for the new components.
Code Review
This pull request introduces LoRA infrastructure for the Archon engine, including core LoRA modules, PEFT-compatible checkpointing, engine integration, and tests. The changes involve adding new files and modifying existing ones to support LoRA fine-tuning within AReaL's Archon engine. The code adheres to torchtune design patterns and ensures FSDP2 compatibility. The review focuses on correctness and maintainability, with suggestions for improving code clarity and error handling.
```python
alpha=float(config.lora_alpha),
target_modules=config.target_modules if config.target_modules else [],
)
```
It's good to see the alpha value being converted to float to ensure type consistency. However, consider adding a try-except block to handle potential ValueError if config.lora_alpha cannot be converted to a float. This will make the code more robust.
```python
rank=config.lora_rank,
alpha=float(config.lora_alpha) if isinstance(config.lora_alpha, (int, float)) else config.lora_alpha,
target_modules=config.target_modules if config.target_modules else [],
```

```python
if engine.lora_config is None:
    raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine")
```
This RuntimeError is appropriate for preventing the saving of LoRA adapters when LoRA is not enabled. Consider adding a more descriptive error message to provide better guidance to the user, such as suggesting they enable LoRA in the engine configuration.
```python
raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine. Please ensure 'use_lora' is set to True in the training configuration.")
```

```python
if engine.lora_config is None:
    raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine")
```
Similar to the save_lora_adapter function, this RuntimeError is appropriate for preventing the loading of LoRA adapters when LoRA is not enabled. Consider adding a more descriptive error message to provide better guidance to the user, such as suggesting they enable LoRA in the engine configuration.
```python
raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine. Please ensure 'use_lora' is set to True in the training configuration.")
```

```python
if not os.path.exists(weights_path):
    raise FileNotFoundError(
        f"Adapter weights not found at {path}. "
        "Expected adapter_model.safetensors or adapter_model.bin"
    )
```
This FileNotFoundError is appropriate for handling missing adapter weights. However, the message interpolates path (the checkpoint directory) rather than weights_path (the file that was actually checked); using weights_path helps the user quickly identify the exact missing file.
```python
raise FileNotFoundError(
    f"Adapter weights not found at {weights_path}. "
    "Expected adapter_model.safetensors or adapter_model.bin"
)
```

```python
raise ValueError(
    "Adapter checkpoint keys don't match model. " + " ".join(error_msg)
)
```
Pull request overview
This PR adds LoRA (Low-Rank Adaptation) infrastructure for the Archon engine, implementing Phase 1 (Core LoRA Modules) and Phase 2 (PEFT-Compatible Checkpointing). It introduces a custom LoRALinear module following torchtune patterns, adapter utilities for parameter management, PEFT-format checkpoint save/load, Qwen2 key mappings, and engine integration for LoRA-aware save/load routing.
Changes:
- Core LoRA modules: LoRALinear with FSDP2-compatible design, plus the AdapterModule protocol with utilities for parameter extraction, freezing, and adapter toggling
- PEFT-compatible checkpointing: save/load LoRA adapters in HuggingFace PEFT format with key conversion between HF (lora_A/lora_B) and Archon (lora_a/lora_b) conventions, plus Qwen2-specific key mappings
- Engine integration: ArchonEngine save/load methods routed to LoRA-specific checkpoint paths when LoRA is enabled
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| areal/experimental/models/archon/lora/lora_linear.py | New LoRALinear module with kaiming/zero init, forward pass, from_linear classmethod, and adapter protocol |
| areal/experimental/models/archon/lora/adapter.py | AdapterModule protocol and utilities: get_adapter_params, set_trainable_params, get_adapter_state_dict, disable/enable_adapter |
| areal/experimental/models/archon/lora/__init__.py | Package init with public API exports |
| areal/experimental/engine/archon_lora_checkpoint.py | PEFT-format save/load/detect functions for LoRA adapter checkpoints |
| areal/experimental/models/archon/base.py | Added to_peft_module_map and create_peft_adapter_config() to base state dict adapter |
| areal/experimental/models/archon/qwen2/model/state_dict_adapter.py | Added 16 LoRA key mappings and PEFT module name mapping for Qwen2 |
| areal/experimental/engine/archon_engine.py | LoRA config extraction in __init__, LoRA-aware save/load routing |
| areal/utils/logging.py | Added "LoRACheckpoint" logger color |
| tests/experimental/archon/test_lora_linear.py | Comprehensive unit tests for LoRALinear and adapter utilities |
| tests/experimental/archon/test_archon_lora_checkpoint.py | Tests for key conversion, config generation, checkpoint detection, and round-trip consistency |
```python
logger.info(f"Loaded {loaded_count} adapter parameters into model")

if dist.is_initialized():
    dist.barrier()
```
Same issue as the save path: dist.barrier() should use engine.cpu_group to match the convention in archon_checkpoint.py (e.g., line 437, 450, 463).
```diff
- dist.barrier()
+ dist.barrier(group=engine.cpu_group)
```
```python
if self.lora_config is not None:
    from areal.experimental.engine.archon_lora_checkpoint import (
        save_lora_adapter,
    )

    save_lora_adapter(self, meta.path, meta.base_model_path)
    return
```
When lora_config is not None, the save method unconditionally saves only the LoRA adapter and returns early, ignoring meta.weight_format and meta.with_optim. This means optimizer state is never saved during LoRA training, making it impossible to resume training from a checkpoint. Consider also saving optimizer state when meta.with_optim is True, similar to the non-LoRA code path.
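The routing fix being suggested can be sketched with toy stand-ins — every name here (SaveMeta, _save_optimizer_state, and so on) is illustrative, not the PR's actual API:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SaveMeta:
    path: str
    base_model_path: Optional[str] = None
    with_optim: bool = False


class EngineSketch:
    """Toy stand-in for the engine's save routing (illustrative names only)."""

    def __init__(self, lora_config=None):
        self.lora_config = lora_config
        self.saved = []  # records actions so the routing is observable

    def _save_lora_adapter(self, path, base_model_path):
        self.saved.append(("adapter", path))

    def _save_optimizer_state(self, path):
        self.saved.append(("optim", path))

    def _save_full_checkpoint(self, meta):
        self.saved.append(("full", meta.path))

    def save(self, meta: SaveMeta):
        if self.lora_config is not None:
            self._save_lora_adapter(meta.path, meta.base_model_path)
            if meta.with_optim:  # keep optimizer state so LoRA training can resume
                self._save_optimizer_state(meta.path)
            return
        self._save_full_checkpoint(meta)
```

The point is only that the early-return LoRA branch still consults meta.with_optim instead of dropping it.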
```python
from areal.experimental.engine.archon_lora_checkpoint import (
    is_lora_adapter_checkpoint,
    load_lora_adapter,
)

if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path):
    load_lora_adapter(self, meta.path)
    return
```
The LoRA checkpoint imports are done unconditionally on every load() call, even when LoRA is not enabled. In save(), the import is correctly guarded inside if self.lora_config is not None. For consistency and to avoid unnecessary imports, move the import inside the if block, similar to save():
```diff
- from areal.experimental.engine.archon_lora_checkpoint import (
-     is_lora_adapter_checkpoint,
-     load_lora_adapter,
- )
- if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path):
-     load_lora_adapter(self, meta.path)
-     return
+ if self.lora_config is not None:
+     from areal.experimental.engine.archon_lora_checkpoint import (
+         is_lora_adapter_checkpoint,
+         load_lora_adapter,
+     )
+     if is_lora_adapter_checkpoint(meta.path):
+         load_lora_adapter(self, meta.path)
+         return
```
```python
if hasattr(config, "use_lora") and config.use_lora:
    from dataclasses import dataclass

    @dataclass
    class LoRAConfig:
        enabled: bool
        rank: int
        alpha: float
        target_modules: list[str]

    self.lora_config = LoRAConfig(
        enabled=True,
        rank=config.lora_rank,
        alpha=float(config.lora_alpha),
        target_modules=config.target_modules if config.target_modules else [],
    )
```
Defining the LoRAConfig dataclass inside the __init__ method body is an anti-pattern. This creates a new class object on every ArchonEngine instantiation, makes the class non-importable/non-reusable, and complicates type annotations elsewhere. The dataclass should be defined at module level or in a separate config module. Additionally, the enabled field is redundant since self.lora_config being non-None already indicates LoRA is enabled.
```python
adapter_params = get_adapter_params(engine.model)

if not adapter_params:
    logger.warning("No adapter parameters found in model")
    if rank == 0:
        logger.warning("Creating empty adapter checkpoint")

# Convert to HF format using state dict adapter
archon_state = {k: v.cpu().detach().clone() for k, v in adapter_params.items()}
hf_state = engine.state_dict_adapter.to_hf(archon_state)

# Add PEFT prefix: base_model.model.{key}
peft_state = {f"base_model.model.{k}": v for k, v in hf_state.items()}
```
Under FSDP2, parameters returned by get_adapter_params(engine.model) will be DTensor (sharded across ranks). Calling .cpu().detach().clone() on a DTensor only captures the local shard, not the full tensor. This means each rank would save different fragments of the LoRA weights, and the rank-0 checkpoint would be incomplete/incorrect.
The existing checkpoint code in archon_checkpoint.py uses get_model_state_dict() with StateDictOptions to properly gather/unshard tensors before saving. The LoRA save path should do the same — e.g., use get_model_state_dict then filter to adapter keys, or use full_tensor() on DTensors before saving.
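The full_tensor() variant can be sketched as follows (a minimal illustration, not the PR's code; the function name is hypothetical):

```python
import torch


def gather_full_adapter_state(adapter_params):
    """Materialize full tensors before saving. Under FSDP2 the values are
    DTensors, whose .full_tensor() gathers every rank's shard; plain
    tensors pass through unchanged. (get_model_state_dict with
    StateDictOptions(full_state_dict=True) is the alternative used by
    archon_checkpoint.py.)"""
    full = {}
    for name, param in adapter_params.items():
        t = param.detach()
        if hasattr(t, "full_tensor"):  # torch.distributed.tensor.DTensor
            t = t.full_tensor()
        full[name] = t.cpu().clone()
    return full
```

On a single-process model this is a no-op copy; under FSDP2 each rank ends up holding the complete adapter weights, so the rank-0 checkpoint is no longer a fragment.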
```python
# Load adapter weights into model
loaded_count = 0
for key, value in archon_state.items():
    if key in expected_adapter_params:
        param = expected_adapter_params[key]
        value = value.to(device=param.device, dtype=param.dtype)
        param.data.copy_(value)
        loaded_count += 1
```
Similar to the save path issue: under FSDP2, expected_adapter_params[key] will be a DTensor (sharded). Directly assigning to param.data.copy_(value) on a DTensor won't properly distribute the loaded weights across the FSDP2 mesh. The existing load code in archon_checkpoint.py uses set_model_state_dict() to handle this correctly. The LoRA load path should use an equivalent mechanism to ensure proper shard distribution.
```python
self.register_parameter("bias", None)

# LoRA adapters (trainable)
# Note: naming lora_a, lora_b (lowercase) matches PEFT convention
```
The comment says "naming lora_a, lora_b (lowercase) matches PEFT convention", but PEFT actually uses lora_A and lora_B (uppercase A/B). The lowercase naming is an Archon convention, with case conversion handled in the state dict adapter. Consider updating the comment to reflect this accurately, e.g., "naming lora_a, lora_b (lowercase) follows torchtune convention; PEFT uses lora_A/lora_B which is handled by the state dict adapter."
```diff
- # Note: naming lora_a, lora_b (lowercase) matches PEFT convention
+ # Note: naming lora_a, lora_b (lowercase) follows torchtune/Archon convention;
+ # PEFT uses lora_A/lora_B, which is handled by the state dict adapter.
```
```python
# Synchronize all ranks
if dist.is_initialized():
    dist.barrier()
```
The barrier calls here use the default process group (dist.barrier()), but the existing checkpoint code in archon_checkpoint.py consistently uses dist.barrier(group=engine.cpu_group) to synchronize. Using different groups for barriers can cause hangs or incorrect synchronization in multi-group setups. Consider passing engine.cpu_group to dist.barrier() for consistency with the existing checkpoint code.
```diff
- dist.barrier()
+ dist.barrier(group=engine.cpu_group)
```
Hi @NJX-njx, I've fixed several bugs in LoRA-Archon training in PR #1015. It can now successfully enter the training loop, and checkpoints can be saved and loaded correctly from the training/inference backends. I will continue to test the stability of LoRA training on an RLVR experiment with the DAPO-17k dataset.
This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days. Please add a comment or push new commits to keep it active. Thank you for your contribution!
Summary
Adds LoRA (Low-Rank Adaptation) infrastructure for the Archon engine, implementing Phase 1 (Core LoRA Modules) and Phase 2 (PEFT-Compatible Checkpointing) from the LoRA global plan.
This PR provides the foundational building blocks for LoRA fine-tuning within AReaL's Archon engine, following torchtune design patterns with FSDP2 compatibility.
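To make the building blocks concrete, here is an illustrative LoRA linear layer in the torchtune style this PR follows — a sketch, not the PR's actual implementation; it only mirrors the described design (frozen base weight, raw lora_a/lora_b parameters, zero-initialized lora_b):

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinearSketch(nn.Module):
    """Illustrative LoRA linear: frozen base weight plus a rank-r update
    scaled by alpha / rank, registered as raw parameters (no nn.Linear
    sub-modules)."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8,
                 alpha: float = 16.0, dropout: float = 0.0):
        super().__init__()
        # frozen base weight
        self.weight = nn.Parameter(torch.empty(out_dim, in_dim), requires_grad=False)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # kaiming init for lora_a, zeros for lora_b =>
        # initial output exactly matches the base layer
        self.lora_a = nn.Parameter(torch.empty(rank, in_dim))
        self.lora_b = nn.Parameter(torch.zeros(out_dim, rank))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        self.scaling = alpha / rank
        self.dropout = nn.Dropout(dropout)

    def adapter_params(self) -> list:
        # AdapterModule-style protocol: names of trainable adapter parameters
        return ["lora_a", "lora_b"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = F.linear(x, self.weight)
        lora = F.linear(F.linear(self.dropout(x), self.lora_a), self.lora_b)
        return base + self.scaling * lora
```

Because lora_b starts at zero, the adapted layer is a drop-in replacement whose initial outputs are identical to the wrapped layer — the property the PR's forward tests check.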
Closes #945
What's Changed
Phase 1: Core LoRA Modules
- areal/experimental/models/archon/lora/lora_linear.py — LoRALinear module with:
  - kaiming init for lora_a, zero init for lora_b
  - from_linear() classmethod for wrapping existing nn.Linear layers
  - adapter_params() protocol method for AdapterModule compatibility
  - raw parameters (no nn.Linear sub-modules)
- areal/experimental/models/archon/lora/adapter.py — Adapter utilities:
  - AdapterModule protocol (runtime-checkable)
  - get_adapter_params() — collect all LoRA parameters from a model
  - set_trainable_params() — freeze base model, keep only adapter params trainable
  - get_adapter_state_dict() — filter state dict to adapter-only keys
  - disable_adapter() / enable_adapter() — toggle LoRA at inference time

Phase 2: PEFT-Compatible Checkpointing
- areal/experimental/engine/archon_lora_checkpoint.py — Checkpoint utilities:
  - save_lora_adapter() — saves in HuggingFace PEFT format (adapter_model.safetensors + adapter_config.json)
  - load_lora_adapter() — loads PEFT-format checkpoints with key conversion (HF ↔ Archon)
  - is_lora_adapter_checkpoint() — detects PEFT LoRA checkpoints
- areal/experimental/models/archon/base.py — Extended BaseStateDictAdapter:
  - to_peft_module_map for Archon→HF module name mapping
  - create_peft_adapter_config() for generating PEFT-format config
- areal/experimental/models/archon/qwen2/model/state_dict_adapter.py — Qwen2 LoRA key mappings:
  - HF lora_A/lora_B ↔ Archon lora_a/lora_b
  - to_peft_module_map for all target modules

Engine Integration
- areal/experimental/engine/archon_engine.py — LoRA-aware save/load: extracts lora_config and routes to the LoRA-specific checkpoint path

Tests (18+ test cases)
- tests/experimental/archon/test_lora_linear.py — LoRALinear unit tests: from_linear, adapter protocol
- tests/experimental/archon/test_archon_lora_checkpoint.py — Checkpoint tests: key conversion, config generation, checkpoint detection, round-trip consistency

Design Decisions
- … transformers + peft directly.

Testing
No GPU required for these tests (CPU-only).
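To give a concrete picture of the PEFT on-disk layout targeted by the Phase 2 utilities, here is a stdlib-only sketch of config generation and checkpoint detection (field values are illustrative defaults, not necessarily the PR's exact output; the key names follow PEFT's LoraConfig schema):

```python
import json
import os
import tempfile


def create_peft_adapter_config(rank, alpha, target_modules, base_model_path=None):
    """Sketch of the adapter_config.json contents; keys follow PEFT's
    LoraConfig schema, values are illustrative."""
    return {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "base_model_name_or_path": base_model_path,
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": 0.0,
        "target_modules": sorted(target_modules),
        "bias": "none",
    }


def is_lora_adapter_checkpoint(path):
    """A PEFT adapter directory holds adapter_config.json plus
    adapter_model.safetensors (or adapter_model.bin)."""
    has_config = os.path.isfile(os.path.join(path, "adapter_config.json"))
    has_weights = any(
        os.path.isfile(os.path.join(path, name))
        for name in ("adapter_model.safetensors", "adapter_model.bin")
    )
    return has_config and has_weights
```

A directory satisfying both checks can then be consumed by peft.PeftModel.from_pretrained, which is what makes the saved adapters shareable outside AReaL.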