
TE MoE TEOpsSequential with MXFP8 AttributeError: 'MXFP8TensorStorage' object has no attribute 'reshape' #2973

@faradawn

Description


Describe the bug

The goal is to use TEOpsSequential with MXFP8 for an MoE FFN, in the following form:

    from transformer_engine.pytorch.ops import GroupedLinear as TEOpsGroupedLinear
    from transformer_engine.pytorch.ops import Sequential as TEOpsSequential
    from transformer_engine.pytorch.ops import SwiGLU as TEOpsSwiGLU

    # constructor arguments omitted for brevity
    self.experts_gate_up = TEOpsGroupedLinear(...)
    self.experts_down = TEOpsGroupedLinear(...)

    object.__setattr__(
        self,
        "_experts_ffn_op",
        TEOpsSequential(self.experts_gate_up, TEOpsSwiGLU(), self.experts_down),
    )

This results in an error in the backward pass with MXFP8 (the BF16 backward works):

    grad_output = SwiGLU.backward()  # grad_output is an MXFP8TensorStorage
    grad_output.reshape(-1, self.out_features)

    AttributeError: 'MXFP8TensorStorage' object has no attribute 'reshape'

Then I read `test_grouped_mlp` in `test_fusible_ops.py`:

    fc1 = te_ops.GroupedLinear(...)  # constructor arguments omitted
    scaled_act = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size)
    fc2 = te_ops.GroupedLinear(...)

    te_ops.Sequential(
        fc1,
        scaled_act,
        fc2,
    )

which uses ScaledSwiGLU to bypass the problem, but it requires two major changes:

  1. The weights must be interleaved into blocks of 32, which does not match the pretrained Hugging Face checkpoint layout (e.g. Mixtral 8x7B). The current HF gate_up weight is concatenated:

         gate[14336], up[14336]   # intermediate size 14336

     but TE needs the two halves interleaved in blocks of 32:

         gate[32], up[32], gate[32], up[32], ...
  2. The router probabilities must be moved inside the SwiGLU.

     Before:

         expert_out = down(silu(gate) * up)
         final = unpermute(expert_out, probs)

     After:

         hidden = ScaledSwiGLU(gate_up, probs)
         expert_out = down(hidden)
         final = unpermute(expert_out, probs=None)
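On point 1, the re-layout of a pretrained checkpoint is mechanical. A minimal sketch in plain PyTorch (the helper name `interleave_gate_up` and the assumed HF layout of gate rows followed by up rows are mine, not TE API):

```python
import torch

def interleave_gate_up(weight: torch.Tensor, block: int = 32) -> torch.Tensor:
    """Re-layout a concatenated [gate; up] weight into block-interleaved form.

    weight: [2*d, in_features], first d rows gate, last d rows up (HF-style).
    Returns [2*d, in_features] ordered gate[0:block], up[0:block], gate[block:2*block], ...
    """
    two_d, in_features = weight.shape
    d = two_d // 2
    assert d % block == 0, "intermediate size must be divisible by the block size"
    gate, up = weight[:d], weight[d:]
    g = gate.reshape(d // block, block, in_features)
    u = up.reshape(d // block, block, in_features)
    # [d//block, 2, block, in_features] -> flatten so gate/up blocks alternate
    return torch.stack([g, u], dim=1).reshape(two_d, in_features)
```

For Mixtral 8x7B this would be applied per expert to the fused gate_up weight (d = 14336) before loading it into the TE module.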
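On point 2, the numerical change is just moving the per-token router scale inside the activation. A plain-PyTorch sketch of the math that ScaledSwiGLU computes under the interleaved layout (my reference implementation, not the TE kernel):

```python
import torch
import torch.nn.functional as F

def scaled_swiglu_reference(x: torch.Tensor, probs: torch.Tensor, block: int = 32) -> torch.Tensor:
    """x: [tokens, 2*d] with gate/up interleaved in blocks of `block`;
    probs: [tokens] router probabilities. Returns [tokens, d]."""
    tokens, two_d = x.shape
    d = two_d // 2
    xb = x.reshape(tokens, d // block, 2, block)
    gate, up = xb[:, :, 0], xb[:, :, 1]        # de-interleave the two halves
    out = (F.silu(gate) * up).reshape(tokens, d)
    return out * probs[:, None]                # router prob folded into the activation
```

Because the scale is applied here, the later unpermute is called with probs=None.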

Is this the recommended way to use TEOpsSequential with MXFP8?
