
Support neox partial RoPE (head_dim=256) for Qwen3.5#7043

Open
wangna11BD wants to merge 2 commits into PaddlePaddle:develop from wangna11BD:spport_neox_partial_rope_head_dim256

Conversation


@wangna11BD wangna11BD commented Mar 27, 2026

Motivation

Add inference support for the Qwen3.5 model to FastDeploy. Qwen3.5 uses neox-style partial rotary position embedding (partial RoPE) with head_dim=256 and partial_rotary_factor=0.25 (i.e. rotary_dim=64), a configuration the framework did not previously support. This PR adds the corresponding GPU kernel and Python-layer support, and also fixes a shared-memory bug in the quantized KV cache write kernels.

Modifications

1. New Qwen3.5 partial neox RoPE CUDA kernel (custom_ops/gpu_ops/append_attn/qwen3_rope.h)

  • New header file containing two CUDA kernels and their launchers:
    • GQAVariableLengthRotarySplitKernel_Qwen3 (migrated from the .cu file; head_dim=128, Qwen3 fully interleaved RoPE)
    • GQAVariableLengthNeoxPartialRotarySplitKernel_Qwen3_5 (new; head_dim=256, Qwen3.5 neox-style partial RoPE)
  • The Qwen3.5 kernel rotates only the [0, rotary_dim=64) portion of each head (rotate_half semantics); the [rotary_dim, head_dim) portion passes through unchanged.
  • New DISPATCH_GQA_ROPE_HEAD_DIM macro that dispatches the head_dim=128 (Qwen3) and head_dim=256 (Qwen3.5) paths through a single entry point.
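The rotate-then-pass-through semantics described above can be sketched as a NumPy reference (a minimal sketch, not the CUDA kernel itself; the function name, the duplicated neox cos/sin layout, and the position id used below are illustrative assumptions):

```python
import numpy as np

def neox_partial_rope_ref(x, cos, sin, rotary_dim=64):
    """Reference sketch of neox-style partial RoPE: only x[..., :rotary_dim]
    is rotated with rotate_half semantics; x[..., rotary_dim:] is untouched."""
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    half = rotary_dim // 2
    # rotate_half pairs channel i with channel i + half: (a, b) -> (-b, a)
    rotated = np.concatenate([-x_rot[..., half:], x_rot[..., :half]], axis=-1)
    x_rot = x_rot * cos + rotated * sin
    return np.concatenate([x_rot, x_pass], axis=-1)

# Qwen3.5 configuration from this PR: head_dim=256, rotary_dim=64
head_dim, rotary_dim = 256, 64
rng = np.random.default_rng(0)
x = rng.standard_normal((1, head_dim)).astype(np.float32)
inv_freq = 10000.0 ** (-np.arange(0, rotary_dim, 2) / rotary_dim)
angles = 7.0 * inv_freq                                  # position id 7 (example)
cos = np.concatenate([np.cos(angles), np.cos(angles)])   # neox layout: duplicated
sin = np.concatenate([np.sin(angles), np.sin(angles)])   # across both halves
out = neox_partial_rope_ref(x, cos, sin, rotary_dim)
```

Because each (i, i+half) channel pair undergoes a pure 2D rotation, the norm of the rotated slice is preserved, and the pass-through slice is bit-identical to the input.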

2. Routing logic update (custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu)

  • New head_dim==256 branch: rotary_dim is inferred automatically from the embedding tensor shape and the call is routed to the Qwen3.5 kernel.
  • The previously hard-coded AppendCacheKV<data_t, 128, 64> calls are replaced by DISPATCH_GQA_ROPE_HEAD_DIM macro dispatch, so the KV cache write path also supports head_dim=256.
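The shape-based inference mentioned above can be illustrated with a small sketch (the function name, the cos/sin cache layout, and the example shape are hypothetical assumptions for illustration, not the PR's exact code):

```python
def infer_rotary_dim(rotary_embs_shape, head_dim):
    """Hypothetical sketch: assuming a neox-style cos/sin cache whose last
    dimension holds rotary_dim entries, partial RoPE (rotary_dim < head_dim)
    can be detected from the tensor shape alone, with no extra config flag."""
    rotary_dim = rotary_embs_shape[-1]
    # sanity checks: must fit in the head and pair up for rotate_half
    assert 0 < rotary_dim <= head_dim and rotary_dim % 2 == 0
    return rotary_dim

# e.g. a [2, max_seq_len, 1, 64] cos/sin cache with head_dim=256 -> rotary_dim=64
print(infer_rotary_dim((2, 8192, 1, 64), 256))
```

Inferring rotary_dim from the tensor rather than a separate parameter keeps the kernel launch signature unchanged for the existing head_dim=128 path.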

3. Python-layer RoPE fix and MRoPE support (fastdeploy/model_executor/layers/rotary_embedding.py)

  • Bug fix: QwenRotaryEmbedding previously accepted a partial_rotary_factor argument but never applied it; it is now applied correctly: rotary_dim = int(head_dim * partial_rotary_factor).
  • New mrope_section parameter supporting Qwen3.5-VL multimodal RoPE, plus a new apply_interleaved_mrope method that merges the T/H/W position-frequency groups in an interleaved layout.
  • get_rope_impl now reads mrope_section from rope_parameters and passes it to QwenRotaryEmbedding.
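The interleaved T/H/W merge can be sketched as follows (a reference sketch modeled on common interleaved-MRoPE implementations; the exact slot convention and function name are assumptions, not taken verbatim from this PR):

```python
import numpy as np

def apply_interleaved_mrope_ref(freqs, mrope_section):
    """freqs: [3, seq_len, n_freq] position-frequency tables for T/H/W.
    Merge them channel-wise in an interleaved layout: start from the temporal
    table, then overwrite every 3rd slot (offset 1) with height entries and
    every 3rd slot (offset 2) with width entries, each capped by its section
    size. Slot-assignment convention is an assumption for illustration."""
    t, h, w = freqs
    out = t.copy()
    out[..., 1 : mrope_section[1] * 3 : 3] = h[..., 1 : mrope_section[1] * 3 : 3]
    out[..., 2 : mrope_section[2] * 3 : 3] = w[..., 2 : mrope_section[2] * 3 : 3]
    return out

# Example: constant tables (T=0, H=1, W=2) make the interleaving visible
freqs = np.stack([np.full((2, 24), v, dtype=np.float32) for v in (0.0, 1.0, 2.0)])
merged = apply_interleaved_mrope_ref(freqs, [8, 8, 8])
```

With the constant tables above, the merged channels read 0, 1, 2, 0, 1, 2, … (T, H, W repeating), which is the interleaved layout as opposed to the block-contiguous split used by non-interleaved MRoPE.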

4. Shared-memory bug fix (custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh)

  • The static __shared__ array declarations in the append_write_cache_kv_c8_qkv, append_write_cache_kv_c8_qkv_dynamic, and append_write_cache_kv_c4_qkv kernels are converted to dynamic shared memory (extern __shared__ char dyn_smem_buf[]), avoiding compile-time or runtime failures at large HEAD_DIM.
  • Fixes a bug in the CascadeAppendWriteCacheKVC8QKV and CascadeAppendWriteCacheKVC4QKV launchers where the shared-memory size argument at kernel launch was passed as 0; the correctly computed smem_size is now passed instead.

5. Unit tests (tests/layers/test_qwen35_rope.py)

  • TestQwenRotaryEmbedding: verifies that partial_rotary_factor takes effect, plus output shapes, cos/sin numerical correctness, and various MRoPE-section scenarios.
  • TestGqaRopeWriteCacheQwen35: end-to-end CUDA kernel tests covering Qwen3.5 (head_dim=256) neox partial RoPE numerical correctness (compared against a rotate_half reference implementation), verification that V is not rotated, verification that the pass-through region is unchanged, and a Qwen3 (head_dim=128) regression test.

Usage or Command

python tests/layers/test_qwen35_rope.py

Accuracy Tests

The new CUDA kernel's output is numerically compared against a PyTorch rotate_half reference implementation in the test_neox_partial_rope_correctness test case, with accuracy verified at atol=1e-2.

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[Feature]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.

@paddle-bot

paddle-bot bot commented Mar 27, 2026

Thanks for your contribution!

@wangna11BD wangna11BD changed the title spport neox_partial_rope head_dim=256 for qwen3.5 Support neox partial RoPE (head_dim=256) for Qwen3.5 Mar 27, 2026
