Using batch-invariant kernels has become common in many post-training workloads to get an exact (zero) log-prob match between training and inference. We can monkey-patch around this in TE, but it would be very useful if TE exposed efficient batch-invariant kernels for common operations like grouped GEMM and regular GEMM. This would help our team's Megatron inference and Megatron RL work run larger experiments with a cleaner interface. A short snippet demonstrating the batch-size dependence of current kernels is included under Additional context below.
Describe the solution you'd like
A batch-invariant mode in TE that routes common operations (grouped GEMM, regular GEMM, RMSNorm) through efficient batch-invariant kernel variants; a sketch of one possible interface follows.
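A minimal sketch of what such a switch could look like. This is purely hypothetical (`batch_invariant()` / `set_batch_invariant()` do not exist in TE today); a module flag, context manager, or environment variable would all work equally well for us:

```python
import transformer_engine.pytorch as te

# Hypothetical interface, for illustration only -- not an existing TE API.
# Inside the scope, TE ops (grouped GEMM, regular GEMM, RMSNorm) would use
# batch-invariant kernel variants with a fixed, shape-independent reduction order.
with te.batch_invariant():           # hypothetical context manager
    logits = model(input_ids)        # model / input_ids: placeholders for a TE-based model

# or, equivalently, a process-wide toggle:
# te.set_batch_invariant(True)       # hypothetical
```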
Describe alternatives you've considered
I am currently monkey-patching grouped GEMM, RMSNorm, and general GEMM with my own, less efficient kernels; a minimal sketch of the approach follows.
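As an example of the kind of replacement involved, here is a minimal batch-invariant RMSNorm in plain PyTorch (per-row reduction in fp32 with a fixed reduction order). The actual patching point inside TE is omitted because the internal entry points vary between TE versions; this is a sketch of the workaround, not of TE's implementation:

```python
import torch

def batch_invariant_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """RMSNorm whose result for a given row depends only on that row.

    Each row is reduced independently over the last dimension in fp32, so in
    practice the output for a row is identical whether it is normalized alone
    or as part of a larger batch. Noticeably slower than TE's fused kernel,
    which is exactly why native support would help.
    """
    x_fp32 = x.float()
    rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_fp32 * rms).to(x.dtype) * weight
```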
Additional context
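To make the motivation concrete, here is a plain PyTorch/cuBLAS snippet (not TE-specific) showing that the same logical row can produce slightly different results depending on batch size, because the GEMM kernel's tiling and reduction strategy change with the problem shape. Batch-invariant kernels make this difference exactly zero:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

out_batched = x @ w      # row 0 computed as part of a 2048-row batch
out_single = x[:1] @ w   # the same row 0 computed as a batch of one

# With default kernels this max difference is often non-zero;
# with batch-invariant kernels it is exactly 0.
print((out_batched[0] - out_single[0]).abs().max())
```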