Describe the bug
The goal is to use TEOpsSequential with MXFP8 for MoE, in the following form:
from transformer_engine.pytorch.ops import GroupedLinear as TEOpsGroupedLinear
from transformer_engine.pytorch.ops import Sequential as TEOpsSequential
from transformer_engine.pytorch.ops import SwiGLU as TEOpsSwiGLU

self.experts_gate_up = TEOpsGroupedLinear(...)  # grouped gate/up projection, one GEMM per expert
self.experts_down = TEOpsGroupedLinear(...)     # grouped down projection, one GEMM per expert
object.__setattr__(
    self,
    "_experts_ffn_op",
    TEOpsSequential(self.experts_gate_up, TEOpsSwiGLU(), self.experts_down),
)
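For concreteness, the shorthand above expands roughly as follows. This is only a sketch: the Mixtral 8x7B-style shapes (8 experts, hidden 4096, intermediate 14336) and the assumption that the ops GroupedLinear takes (num_gemms, in_features, out_features) like the module-level te.pytorch.GroupedLinear are mine; the exact constructor arguments are not the point of the report.

from transformer_engine.pytorch.ops import GroupedLinear as TEOpsGroupedLinear
from transformer_engine.pytorch.ops import Sequential as TEOpsSequential
from transformer_engine.pytorch.ops import SwiGLU as TEOpsSwiGLU

num_experts, hidden_size, intermediate_size = 8, 4096, 14336  # assumed Mixtral 8x7B-style shapes

# Assumed signature: (num_gemms, in_features, out_features), mirroring te.pytorch.GroupedLinear.
experts_gate_up = TEOpsGroupedLinear(num_experts, hidden_size, 2 * intermediate_size)
experts_down = TEOpsGroupedLinear(num_experts, intermediate_size, hidden_size)

experts_ffn_op = TEOpsSequential(experts_gate_up, TEOpsSwiGLU(), experts_down)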
This results in an error in the backward pass with MXFP8. (BF16 backward works).
grad_output = SwiGLU.backward() # grad_output is MXFP8TensorStorage
grad_output.reshape(-1, self.out_features)
AttributeError: 'MXFP8TensorStorage' object has no attribute 'reshape'
Then, I read test_grouped_mlp in test_fusible_ops.py:
fc1 = te_ops.GroupedLinear(...)
scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
fc2 = te_ops.GroupedLinear(...)
te_ops.Sequential(
    fc1,
    scaled_act,
    fc2,
)
which uses ScaledSwiGLU to bypass the problem. But it requires two major changes:
- The gate/up weights must be interleaved in blocks of 32, which does not match the pretrained Hugging Face layout (e.g. Mixtral 8x7B). The current HF gate_up weight is laid out as
  gate[14336], up[14336]  (all gate rows, then all up rows, along the intermediate dimension)
  but TE needs
  gate[32], up[32], gate[32], up[32], ...
  (see the interleaving sketch after this list).
- The router probabilities must be moved inside the SwiGLU.
  Before:
    expert_out = down(silu(gate) * up)
    final = unpermute(expert_out, probs)
  After:
    hidden = ScaledSwiGLU(gate_up, probs)
    expert_out = down(hidden)
    final = unpermute(expert_out, probs=None)
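For reference, the re-layout in the first point is a plain tensor reshuffle. Below is a minimal sketch, assuming the fused HF gate_up weight stores all gate rows followed by all up rows along the output dimension (as when Mixtral's w1 and w3 are concatenated); interleave_gate_up is a hypothetical helper name, not TE API.

import torch

def interleave_gate_up(gate_up: torch.Tensor, block: int = 32) -> torch.Tensor:
    # gate_up: [2 * intermediate, hidden], gate rows first, then up rows (assumed HF layout).
    # Returns the same rows reordered as gate[0:32], up[0:32], gate[32:64], up[32:64], ...
    two_inter, hidden = gate_up.shape
    inter = two_inter // 2
    gate, up = gate_up[:inter], gate_up[inter:]
    gate = gate.reshape(inter // block, block, hidden)
    up = up.reshape(inter // block, block, hidden)
    # Pair each block of gate rows with the matching block of up rows, then flatten back.
    return torch.stack([gate, up], dim=1).reshape(two_inter, hidden)

This conversion would have to be applied per expert when loading the pretrained checkpoint, and it still leaves the second change (moving the routing probabilities into the activation).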
Is this the recommended way to use TEOpsSequential with MXFP8?