Skip to content

Fuse LayerNorm modulation into Triton kernel and remove RoPE dtype casts#133

Open
Jordanyang wants to merge 1 commit into
thu-ml:mainfrom
Jordanyang:opt_dtype_kernel_fusion
Open

Fuse LayerNorm modulation into Triton kernel and remove RoPE dtype casts#133
Jordanyang wants to merge 1 commit into
thu-ml:mainfrom
Jordanyang:opt_dtype_kernel_fusion

Conversation

@Jordanyang

Copy link
Copy Markdown

Summary

  • Remove unnecessary RoPE dtype conversions in the Wan2.1/Wan2.2 network paths.
  • Add FastLayerNorm.modulate.
  • Fuse layernorm + scale + shift into the Triton LayerNorm kernel path.

Motivation

This reduces extra dtype conversion overhead around RoPE and moves the modulation step into the fused LayerNorm kernel path, avoiding separate scale/shift operations after normalization.

Changes

  • Updated turbodiffusion/ops/core.py to support fused LayerNorm modulation.
  • Updated turbodiffusion/rcm/networks/wan2pt1.py to use the fused modulation path and remove RoPE dtype casts.
  • Updated turbodiffusion/rcm/networks/wan2pt2.py to use the same fused modulation path and remove RoPE dtype casts.

Validation

  • Successfully run Wan2.1 14B inference tests for both 720p and 480p resolutions

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant