@tscholak commented Jan 12, 2026

Summary

Adds a vLLM-optimized Apriel2 model implementation to `fast_llm_external_models`.

  • Uses vLLM's ModelRegistry.register_model() for runtime registration (no vLLM patching required)
  • Supports hybrid architectures: attention, mamba, GDN, and KDA mixers

Attribution

Model implementation based on work by @nandahkrishna from the apriel2-vllm branch. This PR adapts that implementation for plugin-based registration as an alternative to patching vLLM directly.

Goal

Evaluate whether vLLM's plugin/registration mechanism can work for us as a short-term solution, avoiding the need to maintain a patched vLLM fork.

Usage

```python
from fast_llm_external_models.apriel2.vllm import register
from vllm import LLM

register()
llm = LLM(model="path/to/apriel2/checkpoint")
```
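
For reference, a minimal sketch of what a plugin-style `register()` can look like with vLLM's public `ModelRegistry.register_model()`; the architecture name and module path below are assumptions for illustration, not necessarily what this PR ships:

```python
# Hedged sketch only: architecture name and module path are assumed.
from vllm import ModelRegistry

def register() -> None:
    # The "module:ClassName" string form registers lazily, so the model code
    # (and its CUDA kernels) is only imported when vLLM instantiates the model.
    ModelRegistry.register_model(
        "Apriel2ForCausalLM",
        "fast_llm_external_models.apriel2.vllm.modeling_apriel2:Apriel2ForCausalLM",
    )
```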

vLLM vs Transformers Alignment Verification

Statistical comparison using the `test_apriel2.py stats` command with:

  • 64 prompts from C4 dataset (deterministic sampling with seed=42)
  • 128 tokens prompt length
  • 16 decode steps
  • Identical token IDs sent to both backends (controlled tokenization)
  • Per-position logprob comparison with percentile statistics
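
For illustration, a minimal sketch of the kind of per-position statistics reported below; array names and shapes are assumptions, not the actual `test_apriel2.py` code:

```python
# Hedged sketch: per-position logprob comparison between two backends.
# vllm_lp / tf_lp: [num_prompts, num_positions] chosen-token logprobs,
# vllm_ids / tf_ids: [num_prompts, num_positions] greedy token IDs.
import numpy as np

def per_position_stats(vllm_lp, tf_lp, vllm_ids, tf_ids):
    diff = np.abs(vllm_lp - tf_lp)
    return {
        "match%": 100.0 * (vllm_ids == tf_ids).mean(axis=0),  # token match rate
        "mean_diff": diff.mean(axis=0),
        "p95_diff": np.percentile(diff, 95, axis=0),
        "max_diff": diff.max(axis=0),
    }
```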

Models Tested

| Model | Description |
|---|---|
| pure-gdn | 100% GDN layers |
| attn-swa | 100% attention (sliding window) |
| every5th-kda | 80% attention + 20% KDA |

Results Summary

| Model | Mode | Match% | Mean Diff | p95 Diff | Max Diff | Outliers |
|---|---|---|---|---|---|---|
| GDN | no-compile | 84.4% | 1.05 | 7.83 | 22.80 | 142 (14.0%) |
| GDN | compiled | 83.4% | 1.07 | 7.87 | 13.99 | 155 (15.4%) |
| SWA | no-compile | 87.5% | 0.83 | 7.55 | 18.59 | 111 (10.8%) |
| SWA | compiled | 84.7% | 1.14 | 8.85 | 23.38 | 140 (13.7%) |
| KDA | no-compile | 87.1% | 0.83 | 7.30 | 15.70 | 120 (11.7%) |
| KDA | compiled | 84.8% | 1.03 | 8.39 | 15.15 | 141 (13.8%) |

Per-Position Token Match Rate (no-compile mode)

| Position | GDN | SWA | KDA |
|---|---|---|---|
| prefill | 95.3% | 96.9% | 98.4% |
| decode1 | 96.9% | 93.8% | 95.3% |
| decode2 | 92.2% | 93.8% | 89.1% |
| decode3 | 92.2% | 92.2% | 90.6% |
| decode4 | 90.6% | 92.2% | 90.6% |
| decode5 | 90.6% | 90.6% | 87.5% |
| decode6 | 85.7% | 93.8% | 89.1% |
| decode7 | 85.7% | 89.1% | 89.1% |
| decode8 | 84.1% | 87.5% | 89.1% |
| decode9 | 79.4% | 82.8% | 85.9% |
| decode10 | 81.0% | 85.9% | 84.4% |
| decode11 | 79.4% | 84.4% | 82.8% |
| decode12 | 73.0% | 81.2% | 82.8% |
| decode13 | 74.6% | 79.7% | 82.8% |
| decode14 | 74.2% | 79.7% | 78.1% |
| decode15 | 73.8% | 76.6% | 78.1% |

Key Findings

1. Divergence is NOT mixer-specific

All models (GDN, SWA, KDA) show similar divergence patterns between vLLM and Transformers. This indicates the issue is in shared model code (RMSNorm, MLP, embeddings) rather than mixer implementations.

2. torch.compile has minimal impact

Compile vs no-compile produces nearly identical results:

  • GDN: 84.4% vs 83.4% match
  • SWA: 87.5% vs 84.7% match
  • KDA: 87.1% vs 84.8% match

Previous reports of GDN torch.compile issues appear to have been measurement artifacts.

3. Divergence accumulates over decode steps

  • Prefill: 95-98% token match rate
  • Decode15: 73-78% token match rate

Small numerical differences compound during autoregressive generation, causing progressive divergence.

4. Prefill is well-aligned

All models show excellent prefill alignment (95-98% match, avg diff ~0.04), making them reliable for likelihood-based evaluation (MMLU, etc.).


Implications

For likelihood-based evaluation (MMLU)

All models are reliable: prefill-only evaluation shows 95-98% alignment

For generative evaluation (GSM8K)

⚠️ All models show accumulating divergence - vLLM and Transformers will produce different outputs over long generations, regardless of mixer type or compilation mode

Root Cause Investigation Needed

The divergence affects all model types equally, suggesting the issue is in:

  • RMSNorm implementation differences
  • MLP/SwiGLU numerical precision
  • Embedding layer handling
  • KV cache management differences
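
One cheap way to probe the first candidate is to run both RMSNorm variants on identical activations outside the full model. The snippet below only illustrates how fp32 versus pure-bf16 accumulation in the norm can already produce small per-token differences; none of it is taken from the PR's code:

```python
# Hedged illustration: RMSNorm with fp32 accumulation vs. computed purely in
# bf16 on the same input. Differences of this size compound over decode steps.
import torch

def rmsnorm_fp32(x, w, eps=1e-5):
    x32 = x.float()
    x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (w.float() * x32).to(x.dtype)

def rmsnorm_bf16(x, w, eps=1e-5):
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return w * x

x = torch.randn(8, 4096, dtype=torch.bfloat16)
w = torch.randn(4096, dtype=torch.bfloat16)
print((rmsnorm_fp32(x, w) - rmsnorm_bf16(x, w)).abs().max())
```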

Test Configuration

```bash
# Run statistical comparison
python test_apriel2.py stats /path/to/model \
    --num-prompts 64 \
    --prompt-length 128 \
    --decode-length 16 \
    --batch-size 1 \
    --dtype bfloat16 \
    --tf-kernels vllm \
    [--no-compile]  # Add for no-compile mode
```

Test plan

  • Test registration mechanism with vLLM
  • Verify model loads correctly
  • Statistical comparison of vLLM vs Transformers (GDN, SWA, KDA)
  • Tested compile vs no-compile modes
  • Per-position analysis of divergence patterns
  • Investigate shared code divergence (RMSNorm, MLP, embeddings)

🤖 Generated with Claude Code

tscholak and others added 17 commits January 10, 2026 12:38
- Add README.md documenting the algebraic structure of the conversion system
  (surgery monoid, action law, plan composition, total vs partial operations)
- Add prune_supernet_step1.yaml and prune_supernet_step2.yaml examples
  demonstrating the two-step workflow for pruning a homogeneous supernet
  to a heterogeneous network with different mixer types per layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add modeling_apriel2.py with full vLLM-optimized implementation
  supporting attention, mamba, GDN, and KDA mixer types
- Add register() function for runtime model registration via
  vLLM's ModelRegistry (no patching required)
- Based on Nanda's vllm_diff.patch, adapted for external package use

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Refactor weight loading: each mixer module (Attention, MLP, GDN, KDA)
  now handles its own weight structure via load_weights() methods
- Fix KDA mamba_type to use "gdn_attention" for vLLM backend registration
- Add KDA op registration import for custom op support
- Remove unused positions parameter from KDA forward
- Add config_convertor.py for Apriel2TextConfig to vLLM config mapping
- Add test_apriel2.py for coherence and logit comparison testing
  between vLLM and Transformers implementations

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Remove all PyTorch fallback implementations to ensure fast CUDA kernels
are always used. The module now fails loudly at import/instantiation
if required kernels are missing.

Changes:
- Remove torch_causal_conv1d_fn and torch_causal_conv1d_update fallbacks
- Remove torch_selective_scan_fn and torch_selective_state_update stubs
- Remove torch_chunk_gated_delta_rule function
- Remove _recurrent_gated_delta_rule method from Apriel2GatedDeltaNet
- Remove _forward_local method from GatedRMSNormalization
- Remove TestFastVsSlowPath test class (no longer needed)
- Handle CausalConv1d seq_len==1 edge case via update() instead of fallback
- Add ImportError at module load for missing causal_conv1d/mamba_ssm
- Add ImportError at class init for missing FLA kernels

Required packages:
- causal_conv1d (for CausalConv1d)
- mamba_ssm (for Mamba/SSM operations)
- fla (for GDN, KDA, GatedRMSNormalization)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The chunk_gated_delta_rule call was always passing initial_state=None,
ignoring any existing recurrent state from previous decode cycles.
This broke continued generation scenarios (prefill -> decode -> prefill).

Changed initial_state=None to initial_state=recurrent_state to match
the correct behavior already present in KDA's chunk_kda call.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add test_vs_qwen3next_with_cache and test_vs_fla_with_cache tests that
verify mixer implementations through all inference phases:
- Phase 1: Initial prefill with cache population
- Phase 2: Single-token decode using cached states
- Phase 3: Prefill again (decode→prefill transition)

Tests compare outputs and recurrent states at each phase. Convolution
states are not compared due to different storage formats between
implementations (Apriel2 stores kernel_size-1, references store
kernel_size).

For GDN, Phase 3 documents expected divergence from Qwen3Next due to
its bug where chunk mode ignores initial_state.

For KDA, all phases should match since FLA correctly passes
initial_state in chunk mode.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Merge test_vs_qwen3next and test_vs_qwen3next_with_cache into single
  parameterized test with use_cache fixture
- Merge test_vs_fla and test_vs_fla_with_cache similarly
- Add use_cache (False/True) and decode_steps (4) fixtures
- Use proper Apriel2Cache from cache.py instead of ad-hoc SimpleCache
- Use same total sequence length for both cache and non-cache modes
- Skip cache tests when seq_len < decode_steps + 2 (too small for 3 phases)
- Split sequence as: prefill=2/3, decode=4, prefill2=1/3 of remaining

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix KDA mode selection to match FLA: use fused_recurrent only when
  seq_len <= 64 AND not training (single expression instead of override)
- Replace use_cache fixture with explicit phase fixtures (prefill_len,
  decode_steps, prefill2_len) for clearer test parameterization
- Update test_chunked_vs_recurrent to use Apriel2Cache and fixtures
- Rename config_dict to mixer_config for consistency across all tests
- Remove unused qwen3_config fixture (recreated inline where needed)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
CausalConv1d is now tested through KDA equivalence tests which use
CausalConv1d for q_conv, k_conv, v_conv. The isolated tests were also
obsolete since CPU fallback was removed.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Move all cache classes (_AttentionCache, _SSMCache, _DummyCacheLayer,
Apriel2Cache, _LayerListAccessor) into modeling_apriel2.py for better
tooling compatibility - modeling code is expected to be together.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Enable "fast" mode (bf16/sdpa) tests that were previously skipped
- Add test_dtype fixture parameter to all tests that create models
- Convert models to correct dtype with .to(device="cuda", dtype=test_dtype)
- Create input tensors with explicit dtype parameter
- Fix assert_close to cast tensors to same dtype before comparison

All 1718 mixer equivalence tests now pass in both fp32 and bf16 modes.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add gdn_mixer_config and kda_mixer_config fixtures to centralize
  mixer config dict construction (eliminates 6 duplicate dicts)
- Add kda_hidden_size fixture for derived hidden_size calculation
- Add make_apriel2_config() helper for minimal Apriel2TextConfig
  construction (eliminates 4 duplicate config blocks)
- Update all GDN and KDA tests to use new fixtures
- Consolidate duplicate imports within test methods

Net reduction: 47 lines (-125/+78)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@tscholak tscholak changed the base branch from main to fix/require-cuda-kernels-no-fallbacks January 18, 2026 22:59
oleksost and others added 3 commits January 19, 2026 14:15
- Fix rope_theta parameter: use 'rope_theta' key instead of 'base' in
  get_rope() call. This fixes attention alignment (0.002 fp32 / 0.05 bf16)
- Switch GDN from qwen3_fused_gdn_gating to fused_gdn_gating
- Add commented-out GQA head expansion code for GDN (WIP)
- Add dtype parameter to test_apriel2.py for bf16/fp32 comparison
- Use flash_attention_2 for bf16 transformers to match vLLM backend

Current alignment status:
- attn-swa: ✅ MATCH (0.002 fp32 / 0.05 bf16)
- KDA: ✅ MATCH (0.003 fp32 / 0.07 bf16)
- GDN: ❌ MISMATCH (14.6 - investigation ongoing)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Base automatically changed from fix/require-cuda-kernels-no-fallbacks to oo/apriel_modeling_bug January 19, 2026 14:42
Base automatically changed from oo/apriel_modeling_bug to main January 19, 2026 14:50
tscholak and others added 6 commits January 19, 2026 14:51
…sigmoid

The vLLM KDA implementation was hardcoding activation="sigmoid" for the
output normalization, while the transformers implementation defaults to
"silu" when not specified in config. This caused significant logprob
differences (avg 1.1) between vLLM and transformers.

Now reads norm_activation from mixer_config.normalization.activation
with default "silu" to match transformers behavior.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Changes to transformers model (modeling_apriel2.py):
- Add USE_VLLM_CONV, USE_VLLM_GDN_OPS, USE_VLLM_GATED_NORM flags
- Restructure kernel imports to use vLLM ops when flags enabled
- Add _debug_enabled, _debug_layer, _debug_final flags for debugging
- Handle vLLM vs FLA signature differences for fused_recurrent_gated_delta_rule

Changes to vLLM model (vllm/modeling_apriel2.py):
- Add _debug_enabled, _debug_layer flags for GDN mixer
- Add _debug_final, _debug_lm_head flags for final norm and LM head
- Gate debug prints with boolean flags instead of num_tokens checks

Changes to test script (vllm/test_apriel2.py):
- Add comprehensive comparison command for vLLM vs TF logprob testing
- Test across prompt sizes, decode lengths, and batch sizes

Results: Prefill logprobs now match perfectly between vLLM and TF
when using vLLM kernels (USE_VLLM_GDN_OPS=True, USE_VLLM_GATED_NORM=True).
Some divergence remains during multi-token decode for certain prompt lengths.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add _debug_state flag and _debug_state_stats() method to both TF and
vLLM GDN mixer classes to track recurrent state evolution during
prefill and decode phases.

Key additions:
- TF: Debug state after prefill and during decode for layer 1
- vLLM: Debug state with correct slot indexing for decode phase
- Print state statistics (mean, std, min, max, first8 values)

This helps investigate the decode divergence at specific prompt lengths
(50, 51, 59, 60, 70 tokens) where vLLM and TF produce different results.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add pure_gdn_step1.yaml: converts fixed -> pattern with all GDN blocks
- Add pure_gdn_step2.yaml: unwraps stochastic -> pure GDN mixer
- Improve TF GDN debug logging with try/except for tensor access
- Add vLLM GDN debug output logging during decode phase
- Add first mismatch details in test_apriel2.py compare output

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
tscholak and others added 11 commits January 21, 2026 16:10
Replace scattered class-level and function-local debug flags with
top-level DEBUG_* constants for easier control:

- DEBUG_GDN_LAYER: GDN layer forward pass (tensors, shapes)
- DEBUG_GDN_STATE: GDN recurrent state during decode
- DEBUG_GDN_OUTPUT: GDN output hidden states during decode
- DEBUG_KDA_LAYER: KDA layer outputs
- DEBUG_DECODER_LAYER: Decoder layer outputs (residual, norm)
- DEBUG_FINAL_NORM: Final norm before LM head
- DEBUG_LM_HEAD: LM head input/output

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Always call repeat_interleave for K→V head expansion (no-op when
value_heads_per_key == 1) to avoid conditional branches that confuse
torch.compile's shape inference.
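
A small, self-contained illustration of why the unconditional expansion is safe (tensor layout is assumed):

```python
# Hedged illustration: repeat_interleave with a repeat count of 1 returns an
# equal tensor, so the K->V head expansion can run unconditionally and
# torch.compile traces a single branch-free graph for every configuration.
import torch

key = torch.randn(2, 4, 64)          # (batch, num_kv_heads, head_dim) -- assumed layout
value_heads_per_key = 1              # grouped configurations would use > 1
expanded = key.repeat_interleave(value_heads_per_key, dim=1)
assert torch.equal(expanded, key)    # no-op when the ratio is 1
```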

Also temporarily comment out compilation_config in test script while
investigating hybrid model compilation issues.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Also keep USE_VLLM_* flags at False for upstream kernel testing.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Change AttentionDecoderLayer.forward signature: move positions to optional kwarg
- All layers now accept (hidden_states, residual, positions=None, **kwargs)
- Remove isinstance dispatch in Apriel2Model.forward loop
- Call all layer types uniformly with same arguments

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Match Llama's approach: use torch._check to assert relationship between
positions and input_ids sizes without hardcoding values. This helps the
compiler understand dynamic shapes during chunked prefill warmup.
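
A hedged sketch of the pattern (names assumed; not the PR's exact code):

```python
# Hedged sketch: torch._check records a size relationship for the compiler
# without pinning concrete values, keeping chunked-prefill warmup shapes symbolic.
import torch

def forward(input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    torch._check(positions.size(0) == input_ids.size(0))
    return input_ids + positions
```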

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Debug code with f-strings (e.g., f"num_tokens={num_tokens}") caused
torch.compile to fail with ConstraintViolationError because f-strings
are evaluated before the function call, causing tensor.size() calls
to be traced even when debug flags are False.

Also commented out debug-related code that converts tensor values to
Python integers (e.g., int(tensor[0])) which breaks CUDA graph capture.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add 'stats' command for rigorous vLLM vs Transformers comparison
- Use C4 dataset for reproducible, diverse prompts
- Controlled tokenization: same token IDs to both backends via TokensPrompt
- Per-position statistics (prefill + each decode step)
- Percentile-based analysis (p10, p50, p90, p95, p99)
- Outlier detection and reporting
- Configurable: num_prompts, prompt_length, decode_length, tf_kernels, seed
- Fix --no-compile argparse bug in compare command

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Implements support for loading stochastic mixer models directly in vLLM
without conversion. Key changes:

- Add Apriel2StochasticMixer class that contains all sub-mixers and
  routes inputs to the active mixer at runtime
- Add Apriel2StochasticDecoderLayer for stochastic decoder blocks
- Implement "convex hull" page size computation that considers ALL
  sub-mixer types to ensure unified page size fits any mixer
- Use virtual layer indices (Falcon H1 style) to give each sub-mixer
  type its own cache allocation without conflicts
- Add test_loading.py for testing model loading without generation

The stochastic mixer allocates caches for all mixer types, enabling
future runtime mixer switching capability.
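
A hedged sketch of the "convex hull" rule described above; the helper and the numbers are illustrative only:

```python
# Hedged sketch: a stochastic layer must be able to host any of its sub-mixers,
# so the unified cache page size is the maximum of the per-mixer page sizes.
def unified_page_size(per_mixer_page_sizes: dict[str, int]) -> int:
    return max(per_mixer_page_sizes.values())

print(unified_page_size({"attention": 32768, "gdn": 18944, "kda": 21504}))  # 32768
```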

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Extract _create_mixer_params helper to eliminate ~90 lines of duplication
  in get_block_params for stochastic mixer handling
- Fix MIXER_TYPE_OFFSETS bug: use mixer index instead of type to prevent
  collisions when multiple mixers share the same type (e.g., attention and
  sliding_window both have type "attention")
- Remove dead class-level get_kv_cache_spec method (vLLM calls instance
  methods on each layer, not the class-level method)
- Remove unused get_block_specs and get_block_name_for_layer functions

Net reduction of ~200 lines.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Cache get_unified_page_size_for_config results by object identity.
This avoids redundant computation when vLLM calls each layer's
get_kv_cache_spec independently (96 calls → 1 for 24-layer model
with 4 stochastic sub-mixers).
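
A hedged sketch of identity-keyed memoization; the config layout and the stand-in computation are assumptions:

```python
# Hedged sketch: cache the per-config result by object identity so the page-size
# computation runs once per config instance, however many layers request it.
_page_size_cache: dict[int, int] = {}

def _compute_unified_page_size(config) -> int:    # stand-in for the real computation
    return max(config["per_mixer_page_sizes"].values())

def get_unified_page_size_for_config(config) -> int:
    key = id(config)                              # identity, not structural equality
    if key not in _page_size_cache:
        _page_size_cache[key] = _compute_unified_page_size(config)
    return _page_size_cache[key]
```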

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…tching

All mixers now use the vLLM-standard signature:
  forward(hidden_states, output, positions=None, **kwargs) -> None

This enables runtime placement switching between mixer types (attention,
gdn, kda, mamba) via collective_rpc without signature mismatches.

Changes:
- Apriel2Attention: write to output buffer instead of returning
- Apriel2MambaMixer/GDN/KDA: add positions parameter for uniformity
- Apriel2AttentionDecoderLayer: allocate buffer and pass to mixer
- Apriel2StochasticMixer: delegate to active mixer with unified signature
- Add worker monkey-patching for collective_rpc placement methods
- Add test_placement_comparison.py to validate output equivalence

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>