
Expose Batch Invariant Kernels #2986

@wdykas

Description


Batch-invariant kernels have become common in many post-training workloads to achieve zero log-prob mismatch between training and inference. We can monkey-patch around this in TE, but it would be very useful if TE exposed efficient batch-invariant kernels for common operations like grouped GEMM and regular GEMM. This would help our team's Megatron inference and Megatron RL work scale up experimentation behind a cleaner interface.
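For context, "batch-invariant" means each row's result is bitwise identical no matter how many other rows are in the batch; floating-point reductions that change order with batch size break this, which is what causes the train/inference log-prob mismatch. A minimal sketch of a check for this property, in NumPy rather than TE/CUDA (function names here are illustrative, not part of any TE API):

```python
import numpy as np

def is_batch_invariant(matmul, batch_sizes=(1, 4, 32), dim=64, seed=0):
    """Return True if `matmul(x, w)` produces a bitwise-identical result
    for row 0 regardless of how many rows are in the batch."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((max(batch_sizes), dim), dtype=np.float32)
    w = rng.standard_normal((dim, dim), dtype=np.float32)
    reference = matmul(x[:1], w)[0]
    return all(np.array_equal(matmul(x[:b], w)[0], reference)
               for b in batch_sizes)

def rowwise_matmul(x, w):
    # Trivially batch-invariant: each row runs the identical code path,
    # independent of how many other rows are present.
    return np.stack([row @ w for row in x])
```

A real GEMM kernel may fail this check because its tiling or split-K strategy changes with the batch dimension, reordering the floating-point accumulation.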


Describe the solution you'd like

A batch-invariant mode in TE.

Describe alternatives you've considered

I am currently monkey-patching grouped GEMM, RMSNorm, and regular GEMM with my own inefficient batch-invariant kernels.
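The monkey-patched replacements can be correct but slow because they force a fixed reduction order in plain code instead of a tuned kernel. A hedged NumPy sketch of the idea (not the author's actual patch, and not TE code):

```python
import numpy as np

def batch_invariant_matmul(x, w, k_chunk=16):
    """Matmul whose per-row result never depends on the batch size:
    every row is reduced over K in the same fixed chunk order,
    regardless of how many rows x has.  Correct but inefficient --
    the feature request is for TE to provide fast kernels with the
    same guarantee."""
    n, k = x.shape
    out = np.empty((n, w.shape[1]), dtype=np.float32)
    for i in range(n):
        acc = np.zeros(w.shape[1], dtype=np.float32)
        for s in range(0, k, k_chunk):  # fixed reduction order over K
            acc += x[i, s:s + k_chunk] @ w[s:s + k_chunk]
        out[i] = acc
    return out
```

Because row `i` never touches data from other rows and always accumulates K in the same chunk sequence, the result for any given row is bitwise identical at batch size 1 and batch size N.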

