
feat: add LoRA infrastructure for Archon engine (Phase 1 & 2) #1000

Open

NJX-njx wants to merge 2 commits into inclusionAI:main from NJX-njx:feature/archon-lora-infra

Conversation

@NJX-njx
Contributor

@NJX-njx NJX-njx commented Mar 6, 2026

Summary

Adds LoRA (Low-Rank Adaptation) infrastructure for the Archon engine, implementing Phase 1 (Core LoRA Modules) and Phase 2 (PEFT-Compatible Checkpointing) from the LoRA global plan.

This PR provides the foundational building blocks for LoRA fine-tuning within AReaL's Archon engine, following torchtune design patterns with FSDP2 compatibility.

Closes #945

What's Changed

Phase 1: Core LoRA Modules

  • areal/experimental/models/archon/lora/lora_linear.py — LoRALinear module with:

    • Kaiming-uniform init for lora_a, zero init for lora_b
    • from_linear() classmethod for wrapping existing nn.Linear layers
    • adapter_params() protocol method for AdapterModule compatibility
    • Disable/enable support for inference without adapter overhead
    • FSDP2-compatible design (no custom autograd, standard nn.Linear sub-modules)
  • areal/experimental/models/archon/lora/adapter.py — Adapter utilities:

    • AdapterModule protocol (runtime-checkable)
    • get_adapter_params() — collect all LoRA parameters from a model
    • set_trainable_params() — freeze base model, keep only adapter params trainable
    • get_adapter_state_dict() — filter state dict to adapter-only keys
    • disable_adapter() / enable_adapter() — toggle LoRA at inference time
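The Phase 1 pieces can be condensed into a minimal sketch. This is an illustration based on the PR description, not the actual implementation: the class name, `adapter_params()` protocol method, and init scheme follow what is listed above, while the exact shapes and defaults are assumptions.

```python
import math
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal sketch of a torchtune-style LoRALinear (names from the PR description).

    The base weight is frozen; lora_a gets Kaiming-uniform init and lora_b is
    zero-initialized, so the initial output matches the base linear layer.
    Using plain nn.Linear sub-modules (no custom autograd) keeps it FSDP2-friendly.
    """

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        self.linear.weight.requires_grad_(False)  # freeze base weight
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)
        self.scaling = alpha / rank
        self.disabled = False  # toggled by disable_adapter()/enable_adapter()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x)
        if self.disabled:
            return out  # inference without adapter overhead
        return out + self.scaling * self.lora_b(self.lora_a(x))

    def adapter_params(self) -> list[str]:
        # AdapterModule protocol: names of the trainable adapter parameters
        return ["lora_a.weight", "lora_b.weight"]

# Zero init of lora_b means the wrapped layer initially matches the base layer.
layer = LoRALinear(4, 4, rank=2)
x = torch.randn(1, 4)
assert torch.allclose(layer(x), layer.linear(x))
```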

Phase 2: PEFT-Compatible Checkpointing

  • areal/experimental/engine/archon_lora_checkpoint.py — Checkpoint utilities:

    • save_lora_adapter() — saves in HuggingFace PEFT format (adapter_model.safetensors + adapter_config.json)
    • load_lora_adapter() — loads PEFT-format checkpoints with key conversion (HF ↔ Archon)
    • is_lora_adapter_checkpoint() — detects PEFT LoRA checkpoints
  • areal/experimental/models/archon/base.py — Extended BaseStateDictAdapter:

    • Added to_peft_module_map for Archon→HF module name mapping
    • Added create_peft_adapter_config() for generating PEFT-format config
  • areal/experimental/models/archon/qwen2/model/state_dict_adapter.py — Qwen2 LoRA key mappings:

    • 18 LoRA key conversion entries (attention, MLP, lm_head)
    • Case conversion: HF lora_A/lora_B ↔ Archon lora_a/lora_b
    • to_peft_module_map for all target modules
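The HF ↔ Archon case conversion described above can be sketched with plain string rewriting. The key layouts here are illustrative, not the PR's actual 18 mapping entries, and the helper names are invented for this example:

```python
def hf_to_archon(key: str) -> str:
    """Convert a PEFT-style key (base_model.model. prefix, lora_A/lora_B)
    to the lowercase Archon convention. Illustrative sketch only."""
    key = key.removeprefix("base_model.model.")
    return key.replace("lora_A", "lora_a").replace("lora_B", "lora_b")

def archon_to_hf(key: str) -> str:
    """Inverse direction: restore PEFT prefix and uppercase A/B."""
    return "base_model.model." + key.replace("lora_a", "lora_A").replace("lora_b", "lora_B")

hf_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
archon_key = hf_to_archon(hf_key)
assert archon_key == "model.layers.0.self_attn.q_proj.lora_a.weight"
# Round trip restores the original PEFT key
assert archon_to_hf(archon_key) == hf_key
```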

Engine Integration

  • areal/experimental/engine/archon_engine.py — LoRA-aware save/load:
    • Detects lora_config and routes to LoRA-specific checkpoint path
    • Preserves existing non-LoRA checkpoint behavior
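The detection step that drives this routing can be sketched as follows. This is a heuristic stand-in for the PR's `is_lora_adapter_checkpoint()`: the `adapter_config.json` filename and the `peft_type` field come from PEFT's published checkpoint format, but the exact checks in the PR may differ:

```python
import json
import os
import tempfile

def is_lora_adapter_checkpoint(path: str) -> bool:
    """Heuristic: a PEFT LoRA checkpoint is a directory containing an
    adapter_config.json whose peft_type is LORA. Sketch, not the PR code."""
    config_path = os.path.join(path, "adapter_config.json")
    if not os.path.isfile(config_path):
        return False
    try:
        with open(config_path) as f:
            config = json.load(f)
    except json.JSONDecodeError:
        return False  # invalid JSON => not a usable adapter checkpoint
    return config.get("peft_type") == "LORA"

with tempfile.TemporaryDirectory() as d:
    assert not is_lora_adapter_checkpoint(d)  # missing config file
    with open(os.path.join(d, "adapter_config.json"), "w") as f:
        json.dump({"peft_type": "LORA", "r": 8}, f)
    assert is_lora_adapter_checkpoint(d)
```

An engine's load path can then branch on this predicate and fall through to the regular full-checkpoint loader otherwise.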

Tests (18+ test cases)

  • tests/experimental/archon/test_lora_linear.py — LoRALinear unit tests:

    • Initialization, forward pass, gradient flow, from_linear, adapter protocol
    • Bias handling, dropout modes, PEFT numerical equivalence (conditional)
  • tests/experimental/archon/test_archon_lora_checkpoint.py — Checkpoint tests:

    • Key conversion (attention, MLP, lm_head, case conversion)
    • PEFT adapter config generation
    • Checkpoint detection (valid, missing, non-LoRA, invalid JSON)
    • State dict round-trip consistency

Design Decisions

  1. Custom LoRALinear over PEFT library: Follows torchtune patterns for native FSDP2 compatibility. PEFT's injection model doesn't align with Archon's at-construction approach.
  2. PEFT checkpoint format: Ensures HuggingFace ecosystem interoperability — adapters can be loaded by transformers + peft directly.
  3. At-construction injection: LoRA layers are built into the model at construction time (not post-hoc injected), consistent with Archon's model building pattern.

Testing

pytest tests/experimental/archon/test_lora_linear.py -v
pytest tests/experimental/archon/test_archon_lora_checkpoint.py -v

No GPU required for these tests (CPU-only).

Port LoRA core infrastructure and checkpointing from fw/archon-lora branch
to main, implementing Phase 1 (core LoRA modules) and Phase 2 (PEFT-
compatible checkpoint I/O) from the lora-global-plan.

Phase 1 - Core LoRA Infrastructure:
- LoRALinear module following torchtune patterns (FSDP2-compatible)
- AdapterModule protocol for parameter extraction
- Utilities: get_adapter_params, set_trainable_params, disable/enable_adapter
- PEFT-compatible naming (lowercase lora_a/lora_b internally)
- Zero-initialization of lora_b ensures initial output matches base model

Phase 2 - Checkpointing & PEFT Conversion:
- save_lora_adapter: Save adapter in PEFT format (safetensors + config)
- load_lora_adapter: Load adapter from PEFT format with key validation
- is_lora_adapter_checkpoint: Detect PEFT adapter checkpoints
- Qwen2StateDictAdapter: 16 LoRA key mappings + to_peft_module_map
- BaseStateDictAdapter: create_peft_adapter_config method
- ArchonEngine: lora_config attribute, LoRA-aware save/load
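For reference, a PEFT-format adapter_config.json written alongside adapter_model.safetensors typically carries fields like the following. The values here are illustrative; the exact set emitted by create_peft_adapter_config may differ:

```json
{
  "peft_type": "LORA",
  "base_model_name_or_path": "Qwen/Qwen2-7B",
  "r": 8,
  "lora_alpha": 16,
  "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
  "task_type": "CAUSAL_LM"
}
```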

Tests:
- 18+ unit tests for LoRALinear (forward, gradient, from_linear, dropout)
- PEFT compatibility tests (skipped if PEFT not installed)
- State dict key conversion tests (all 16 mappings)
- Checkpoint detection tests
- Round-trip conversion tests

Ref inclusionAI#945
Copilot AI review requested due to automatic review settings March 6, 2026 13:20
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Archon engine by integrating Low-Rank Adaptation (LoRA) capabilities, enabling efficient fine-tuning of large language models. It establishes the foundational modules for LoRA, ensures compatibility with the HuggingFace PEFT checkpoint format for seamless model sharing, and integrates these features directly into the engine's save and load mechanisms. The design adheres to torchtune patterns, supporting FSDP2, and provides robust testing for the new components.

Highlights

  • Core LoRA Modules Introduced: Implemented a custom LoRALinear module with Kaiming-uniform initialization for lora_a and zero initialization for lora_b, including a from_linear() classmethod for wrapping existing nn.Linear layers and FSDP2-compatible design. An AdapterModule protocol was defined, along with utilities to extract, manage, and toggle adapter parameters (e.g., get_adapter_params, set_trainable_params, disable_adapter).
  • PEFT-Compatible Checkpointing: Added functionality to save and load LoRA adapters in HuggingFace PEFT format (adapter_model.safetensors + adapter_config.json), ensuring interoperability with the HuggingFace ecosystem. This includes save_lora_adapter(), load_lora_adapter(), and is_lora_adapter_checkpoint() functions.
  • Archon Engine Integration & Model Adaptations: The ArchonEngine was updated to be LoRA-aware, routing save/load operations to LoRA-specific checkpoint paths when enabled. The BaseStateDictAdapter was extended with to_peft_module_map and create_peft_adapter_config() for generating PEFT-format configurations, and Qwen2StateDictAdapter received specific LoRA key mappings for attention, MLP, and lm_head modules.
  • Comprehensive Testing: New unit tests were added for LoRALinear covering initialization, forward pass, gradient flow, from_linear conversion, adapter protocol, bias handling, and dropout modes. Checkpoint tests were also introduced for key conversion, PEFT adapter config generation, checkpoint detection, and state dict round-trip consistency.


Changelog
  • areal/experimental/engine/archon_engine.py
    • Added lora_config attribute to ArchonEngine to store LoRA parameters if enabled.
    • Modified save method to conditionally call save_lora_adapter when LoRA is active.
    • Modified load method to conditionally call load_lora_adapter if LoRA is active and the checkpoint is a PEFT adapter.
  • areal/experimental/engine/archon_lora_checkpoint.py
    • Added new module for LoRA adapter checkpoint I/O in PEFT format.
    • Implemented save_lora_adapter to save LoRA weights and PEFT configuration.
    • Implemented load_lora_adapter to load PEFT-format checkpoints with key conversion.
    • Implemented is_lora_adapter_checkpoint to detect PEFT LoRA checkpoints.
  • areal/experimental/models/archon/base.py
    • Added to_peft_module_map attribute to BaseStateDictAdapter for mapping Archon module names to PEFT names.
    • Added create_peft_adapter_config method to generate PEFT adapter configuration JSON.
  • areal/experimental/models/archon/lora/__init__.py
    • Added new module to expose LoRA-related classes and functions.
  • areal/experimental/models/archon/lora/adapter.py
    • Added new module defining the AdapterModule protocol.
    • Implemented get_adapter_params to extract adapter parameters from a model.
    • Implemented set_trainable_params to freeze base model parameters and keep only adapter parameters trainable.
    • Implemented get_adapter_state_dict to filter a state dictionary to include only adapter parameters.
    • Implemented disable_adapter and enable_adapter functions to control LoRA activation during inference.
  • areal/experimental/models/archon/lora/lora_linear.py
    • Added new module implementing the LoRALinear layer for LoRA.
    • Included __init__ with Kaiming-uniform init for lora_a and zero init for lora_b.
    • Provided from_linear classmethod to convert nn.Linear layers to LoRALinear.
    • Implemented adapter_params method to comply with the AdapterModule protocol.
  • areal/experimental/models/archon/qwen2/model/state_dict_adapter.py
    • Extended Qwen2StateDictAdapter with LoRA key mappings for attention, MLP, and LM head modules.
    • Populated to_peft_module_map with Qwen2-specific module name conversions for PEFT compatibility.
  • areal/utils/logging.py
    • Added 'LoRACheckpoint' to the LOG_COLORS dictionary for colored logging output.
  • tests/experimental/archon/test_archon_lora_checkpoint.py
    • Added new test file for LoRA adapter checkpointing.
    • Included tests for Qwen2 LoRA key conversion, case conversion, and all 16 LoRA mappings.
    • Added tests for PEFT adapter config generation, including base model path handling.
    • Implemented tests for LoRA adapter checkpoint detection, covering valid, missing, non-LoRA, and invalid JSON scenarios.
    • Provided tests for state dict round-trip consistency with LoRA keys and mixed base/LoRA keys.
  • tests/experimental/archon/test_lora_linear.py
    • Added new test file for the LoRALinear module and adapter utilities.
    • Included tests for LoRALinear initialization, forward pass, gradient flow, from_linear conversion, AdapterModule protocol, and disabled flag functionality.
    • Added tests for get_adapter_params, set_trainable_params, get_adapter_state_dict, disable_adapter, and enable_adapter utilities.
    • Provided tests for LoRALinear with bias and dropout functionality.
    • Included compatibility tests against PEFT's LoRA Linear module for forward pass, gradient flow, scaling factor, and initialization.
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces LoRA infrastructure for the Archon engine, including core LoRA modules, PEFT-compatible checkpointing, engine integration, and tests. The changes involve adding new files and modifying existing ones to support LoRA fine-tuning within AReaL's Archon engine. The code adheres to torchtune design patterns and ensures FSDP2 compatibility. The review focuses on correctness and maintainability, with suggestions for improving code clarity and error handling.

Comment on lines +210 to +212
alpha=float(config.lora_alpha),
target_modules=config.target_modules if config.target_modules else [],
)

Severity: medium

It's good to see the alpha value being converted to float to ensure type consistency. However, consider guarding the conversion so that a config.lora_alpha that cannot be converted to float does not raise an unhandled ValueError. This will make the code more robust.

                rank=config.lora_rank,
                alpha=float(config.lora_alpha) if isinstance(config.lora_alpha, (int, float)) else config.lora_alpha,
                target_modules=config.target_modules if config.target_modules else [],

Comment on lines +53 to +54
if engine.lora_config is None:
raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine")

Severity: medium

This RuntimeError is appropriate for preventing the saving of LoRA adapters when LoRA is not enabled. Consider adding a more descriptive error message to provide better guidance to the user, such as suggesting they enable LoRA in the engine configuration.

        raise RuntimeError("Cannot save LoRA adapter: LoRA not enabled on engine. Please ensure 'use_lora' is set to True in the training configuration.")

Comment on lines +136 to +137
if engine.lora_config is None:
raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine")

Severity: medium

Similar to the save_lora_adapter function, this RuntimeError is appropriate for preventing the loading of LoRA adapters when LoRA is not enabled. Consider adding a more descriptive error message to provide better guidance to the user, such as suggesting they enable LoRA in the engine configuration.

        raise RuntimeError("Cannot load LoRA adapter: LoRA not enabled on engine. Please ensure 'use_lora' is set to True in the training configuration.")

Comment on lines +152 to +156
if not os.path.exists(weights_path):
raise FileNotFoundError(
f"Adapter weights not found at {path}. "
"Expected adapter_model.safetensors or adapter_model.bin"
)

Severity: medium

This FileNotFoundError is appropriate for handling missing adapter weights. However, the f-string interpolates path rather than weights_path; including the actual weights path in the message would help the user quickly identify the missing file.

            raise FileNotFoundError(
                f"Adapter weights not found at {weights_path}. "
                "Expected adapter_model.safetensors or adapter_model.bin"
            )

Comment on lines +191 to +193
raise ValueError(
"Adapter checkpoint keys don't match model. " + " ".join(error_msg)
)

Severity: medium

This ValueError is appropriate for handling key mismatches. Consider including the path in the error message to help the user quickly identify the incorrect checkpoint.

            raise ValueError(
                f"Adapter checkpoint keys don't match model at {path}. " + " ".join(error_msg)
            )

Copilot AI left a comment


Pull request overview

This PR adds LoRA (Low-Rank Adaptation) infrastructure for the Archon engine, implementing Phase 1 (Core LoRA Modules) and Phase 2 (PEFT-Compatible Checkpointing). It introduces a custom LoRALinear module following torchtune patterns, adapter utilities for parameter management, PEFT-format checkpoint save/load, Qwen2 key mappings, and engine integration for LoRA-aware save/load routing.

Changes:

  • Core LoRA modules: LoRALinear with FSDP2-compatible design and AdapterModule protocol with utilities for parameter extraction, freezing, and adapter toggling
  • PEFT-compatible checkpointing: Save/load LoRA adapters in HuggingFace PEFT format with key conversion between HF (lora_A/lora_B) and Archon (lora_a/lora_b) conventions, plus Qwen2-specific key mappings
  • Engine integration: ArchonEngine save/load methods routed to LoRA-specific checkpoint paths when LoRA is enabled

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
areal/experimental/models/archon/lora/lora_linear.py New LoRALinear module with kaiming/zero init, forward pass, from_linear classmethod, and adapter protocol
areal/experimental/models/archon/lora/adapter.py AdapterModule protocol and utilities: get_adapter_params, set_trainable_params, get_adapter_state_dict, disable/enable_adapter
areal/experimental/models/archon/lora/__init__.py Package init with public API exports
areal/experimental/engine/archon_lora_checkpoint.py PEFT-format save/load/detect functions for LoRA adapter checkpoints
areal/experimental/models/archon/base.py Added to_peft_module_map and create_peft_adapter_config() to base state dict adapter
areal/experimental/models/archon/qwen2/model/state_dict_adapter.py Added 16 LoRA key mappings and PEFT module name mapping for Qwen2
areal/experimental/engine/archon_engine.py LoRA config extraction in __init__, LoRA-aware save/load routing
areal/utils/logging.py Added "LoRACheckpoint" logger color
tests/experimental/archon/test_lora_linear.py Comprehensive unit tests for LoRALinear and adapter utilities
tests/experimental/archon/test_archon_lora_checkpoint.py Tests for key conversion, config generation, checkpoint detection, and round-trip consistency


logger.info(f"Loaded {loaded_count} adapter parameters into model")

if dist.is_initialized():
dist.barrier()

Copilot AI Mar 6, 2026


Same issue as the save path: dist.barrier() should use engine.cpu_group to match the convention in archon_checkpoint.py (e.g., line 437, 450, 463).

Suggested change
dist.barrier()
dist.barrier(group=engine.cpu_group)

Comment on lines +688 to +694
if self.lora_config is not None:
from areal.experimental.engine.archon_lora_checkpoint import (
save_lora_adapter,
)

save_lora_adapter(self, meta.path, meta.base_model_path)
return

Copilot AI Mar 6, 2026


When lora_config is not None, the save method unconditionally saves only the LoRA adapter and returns early, ignoring meta.weight_format and meta.with_optim. This means optimizer state is never saved during LoRA training, making it impossible to resume training from a checkpoint. Consider also saving optimizer state when meta.with_optim is True, similar to the non-LoRA code path.

Comment on lines +712 to +719
from areal.experimental.engine.archon_lora_checkpoint import (
is_lora_adapter_checkpoint,
load_lora_adapter,
)

if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path):
load_lora_adapter(self, meta.path)
return

Copilot AI Mar 6, 2026


The LoRA checkpoint imports are done unconditionally on every load() call, even when LoRA is not enabled. In save(), the import is correctly guarded inside if self.lora_config is not None. For consistency and to avoid unnecessary imports, move the import inside the if block, similar to save():

if self.lora_config is not None:
    from areal.experimental.engine.archon_lora_checkpoint import (
        is_lora_adapter_checkpoint,
        load_lora_adapter,
    )
    if is_lora_adapter_checkpoint(meta.path):
        load_lora_adapter(self, meta.path)
        return
Suggested change
from areal.experimental.engine.archon_lora_checkpoint import (
is_lora_adapter_checkpoint,
load_lora_adapter,
)
if self.lora_config is not None and is_lora_adapter_checkpoint(meta.path):
load_lora_adapter(self, meta.path)
return
if self.lora_config is not None:
from areal.experimental.engine.archon_lora_checkpoint import (
is_lora_adapter_checkpoint,
load_lora_adapter,
)
if is_lora_adapter_checkpoint(meta.path):
load_lora_adapter(self, meta.path)
return

Comment on lines +197 to +212
if hasattr(config, "use_lora") and config.use_lora:
from dataclasses import dataclass

@dataclass
class LoRAConfig:
enabled: bool
rank: int
alpha: float
target_modules: list[str]

self.lora_config = LoRAConfig(
enabled=True,
rank=config.lora_rank,
alpha=float(config.lora_alpha),
target_modules=config.target_modules if config.target_modules else [],
)

Copilot AI Mar 6, 2026


Defining the LoRAConfig dataclass inside the __init__ method body is an anti-pattern. This creates a new class object on every ArchonEngine instantiation, makes the class non-importable/non-reusable, and complicates type annotations elsewhere. The dataclass should be defined at module level or in a separate config module. Additionally, the enabled field is redundant since self.lora_config being non-None already indicates LoRA is enabled.

Comment on lines +66 to +78
adapter_params = get_adapter_params(engine.model)

if not adapter_params:
logger.warning("No adapter parameters found in model")
if rank == 0:
logger.warning("Creating empty adapter checkpoint")

# Convert to HF format using state dict adapter
archon_state = {k: v.cpu().detach().clone() for k, v in adapter_params.items()}
hf_state = engine.state_dict_adapter.to_hf(archon_state)

# Add PEFT prefix: base_model.model.{key}
peft_state = {f"base_model.model.{k}": v for k, v in hf_state.items()}

Copilot AI Mar 6, 2026


Under FSDP2, parameters returned by get_adapter_params(engine.model) will be DTensor (sharded across ranks). Calling .cpu().detach().clone() on a DTensor only captures the local shard, not the full tensor. This means each rank would save different fragments of the LoRA weights, and the rank-0 checkpoint would be incomplete/incorrect.

The existing checkpoint code in archon_checkpoint.py uses get_model_state_dict() with StateDictOptions to properly gather/unshard tensors before saving. The LoRA save path should do the same — e.g., use get_model_state_dict then filter to adapter keys, or use full_tensor() on DTensors before saving.

Comment on lines +206 to +213
# Load adapter weights into model
loaded_count = 0
for key, value in archon_state.items():
if key in expected_adapter_params:
param = expected_adapter_params[key]
value = value.to(device=param.device, dtype=param.dtype)
param.data.copy_(value)
loaded_count += 1

Copilot AI Mar 6, 2026


Similar to the save path issue: under FSDP2, expected_adapter_params[key] will be a DTensor (sharded). Directly assigning to param.data.copy_(value) on a DTensor won't properly distribute the loaded weights across the FSDP2 mesh. The existing load code in archon_checkpoint.py uses set_model_state_dict() to handle this correctly. The LoRA load path should use an equivalent mechanism to ensure proper shard distribution.

self.register_parameter("bias", None)

# LoRA adapters (trainable)
# Note: naming lora_a, lora_b (lowercase) matches PEFT convention

Copilot AI Mar 6, 2026


The comment says "naming lora_a, lora_b (lowercase) matches PEFT convention", but PEFT actually uses lora_A and lora_B (uppercase A/B). The lowercase naming is an Archon convention, with case conversion handled in the state dict adapter. Consider updating the comment to reflect this accurately, e.g., "naming lora_a, lora_b (lowercase) follows torchtune convention; PEFT uses lora_A/lora_B which is handled by the state dict adapter."

Suggested change
# Note: naming lora_a, lora_b (lowercase) matches PEFT convention
# Note: naming lora_a, lora_b (lowercase) follows torchtune/Archon convention;
# PEFT uses lora_A/lora_B, which is handled by the state dict adapter.


# Synchronize all ranks
if dist.is_initialized():
dist.barrier()

Copilot AI Mar 6, 2026


The barrier calls here use the default process group (dist.barrier()), but the existing checkpoint code in archon_checkpoint.py consistently uses dist.barrier(group=engine.cpu_group) to synchronize. Using different groups for barriers can cause hangs or incorrect synchronization in multi-group setups. Consider passing engine.cpu_group to dist.barrier() for consistency with the existing checkpoint code.

Suggested change
dist.barrier()
dist.barrier(group=engine.cpu_group)

@MikaStars39

MikaStars39 commented Mar 11, 2026

Hi @NJX-njx, I've fixed several bugs in LoRA-Archon training in PR #1015. It now successfully enters the training loop, and checkpoints can be saved and loaded correctly from the training/inference backends. I will continue to test the stability of LoRA training on an RLVR experiment with the DAPO-17k dataset.
cc @rchardx

@mike1729 mike1729 mentioned this pull request Mar 28, 2026
21 tasks
@github-actions

github-actions Bot commented Apr 8, 2026

This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days.

Please add a comment or push new commits to keep it active.

Thank you for your contribution!


Development

Successfully merging this pull request may close these issues.

[Question] Supporting RL training of other model families

4 participants