Fix: Handle LowLevelZeroPlugin with use_fp8=True #6419

Open
Truong5724 wants to merge 3 commits into hpcaitech:main from Truong5724:Truong-Fix-FP8Scaling


@Truong5724

📌 Checklist before creating the PR

  • I have created an issue for this PR for traceability
  • The title follows the standard format: [doc/gemini/tensor/...]: A concise description
  • I have added relevant tags if possible for better categorization
  • I have installed pre-commit: pip install pre-commit && pre-commit install

🚨 Issue number

Fixed #6387

📝 What does this PR do?

Problem

  • When running training with FP8 enabled under TorchDynamo (see main.py in the linked issue), execution fails during the forward/backward passes with a runtime error:
RuntimeError: shape '[32, 512]' is invalid for input of size 512.
  • The failure stems from an incompatibility between FakeTensor (which TorchDynamo uses for tracing) and torch._scaled_mm().

  • FakeTensor only tracks metadata (shape, dtype, device) and requires all operations to be traceable without real data. However, _scaled_mm() is a low-level kernel that depends on real tensor data and does not support FakeTensor.

  • As a result, TorchDynamo cannot trace the operation correctly, which surfaces as runtime errors such as the shape mismatch above.


Solution

To ensure compatibility with TorchDynamo (FakeTensor), this PR removes the dependency on torch._scaled_mm() and replaces it with explicit dequantization followed by a standard matrix multiplication (changes in fp8.py).
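
The rewrite is mathematically faithful for per-tensor scales, because a scalar factor commutes out of a matrix product: (x_q * s_x) @ (W_q * s_w).T equals (s_x * s_w) * (x_q @ W_q.T). A minimal numpy sketch of this identity (plain float arrays stand in for the FP8 payloads; all names here are illustrative, not taken from the PR):

```python
import numpy as np

# Scalar factors commute out of a matrix product:
#   (x_q * s_x) @ (w_q * s_w).T == (s_x * s_w) * (x_q @ w_q.T)
rng = np.random.default_rng(0)
x_q = rng.standard_normal((4, 8))   # stand-in for the FP8 payload of x
w_q = rng.standard_normal((6, 8))   # stand-in for the FP8 payload of W
s_x, s_w = 0.125, 0.5               # per-tensor inverse scales (illustrative)

dequant_then_mm = (x_q * s_x) @ (w_q * s_w).T   # what the patched code does
scale_after_mm = (s_x * s_w) * (x_q @ w_q.T)    # what a fused scaled matmul computes

assert np.allclose(dequant_then_mm, scale_after_mm)
```

This is why dequantizing each operand first reproduces the scaled product, up to floating-point rounding.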

1. Forward pass

Original implementation:

out = torch._scaled_mm(
    x_fp8,
    ctx.w_fp8_t,
    bias=bias,
    out_dtype=ctx.out_dtype,
    scale_a=inv_scale_x,
    scale_b=inv_scale_w,
    use_fast_accum=True,
)[0]

Replaced with:

x_deq = x_fp8.to(ctx.out_dtype) * inv_scale_x
w_t_deq = ctx.w_fp8_t.to(ctx.out_dtype) * inv_scale_w

out = x_deq @ w_t_deq
if bias is not None:
    out = out + bias.to(ctx.out_dtype)
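
As a CPU-runnable sanity check of the dequantize-then-matmul pattern, int8 can stand in for FP8 (torch's float8 dtypes and _scaled_mm require recent GPUs). This is a hedged numpy simulation, not the PR's code; the quantize helper and all shapes are invented for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((32, 16)).astype(np.float32)
w = rng.standard_normal((8, 16)).astype(np.float32)

def quantize(t, qmax=127):
    """Symmetric per-tensor quantization; int8 stands in for FP8 here."""
    scale = np.abs(t).max() / qmax          # real value = q * scale
    q = np.clip(np.round(t / scale), -qmax, qmax).astype(np.int8)
    return q, np.float32(scale)             # scale plays the role of inv_scale_*

x_q, s_x = quantize(x)
w_q, s_w = quantize(w)

# Mirrors the patched forward: cast up, apply the inverse scale, then plain matmul.
x_deq = x_q.astype(np.float32) * s_x
w_t_deq = (w_q.astype(np.float32) * s_w).T
out = x_deq @ w_t_deq                       # shape (32, 8)

ref = x @ w.T
rel_err = np.linalg.norm(out - ref) / np.linalg.norm(ref)
assert rel_err < 0.05  # small, bounded quantization error
```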

2. Backward pass

Original implementation:

x_grad = torch._scaled_mm(
    out_grad_fp8,
    ctx.w_fp8_t.contiguous().t(),
    out_dtype=ctx.out_dtype,
    scale_a=out_grad_scale,
    scale_b=ctx.inv_scale_w,
    use_fast_accum=True,
)[0]

w_grad = torch._scaled_mm(
    out_grad_fp8.t().contiguous(),
    ctx.x_fp8.t().contiguous().t(),
    out_dtype=ctx.out_dtype,
    scale_a=out_grad_scale,
    scale_b=ctx.inv_scale_x,
    use_fast_accum=True,
)[0]

Replaced with:

out_grad_deq = (out_grad_fp8.to(ctx.out_dtype) * out_grad_scale).contiguous()
w_t_deq = (ctx.w_fp8_t.to(ctx.out_dtype) * ctx.inv_scale_w).contiguous()
x_deq = (ctx.x_fp8.to(ctx.out_dtype) * ctx.inv_scale_x).contiguous()

x_grad = out_grad_deq @ w_t_deq.t()
w_grad = out_grad_deq.t() @ x_deq
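
The two products above are the standard matmul gradients: with out = x @ W.T and upstream gradient g, dL/dx = g @ W and dL/dW = g.T @ x. A finite-difference spot check of those formulas (numpy; the shapes are arbitrary examples, not the PR's):

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.standard_normal((5, 7))
w = rng.standard_normal((3, 7))
g = rng.standard_normal((5, 3))      # upstream gradient for out = x @ w.T

def loss(x_, w_):
    return float(np.sum(g * (x_ @ w_.T)))

# Analytic gradients, matching the patched backward:
#   x_grad = out_grad_deq @ w_t_deq.t()  ->  g @ w
#   w_grad = out_grad_deq.t() @ x_deq    ->  g.T @ x
x_grad = g @ w
w_grad = g.T @ x

# Finite-difference check on one element of each gradient.
eps = 1e-4
dx = np.zeros_like(x); dx[2, 3] = eps
dw = np.zeros_like(w); dw[1, 4] = eps
fd_x = (loss(x + dx, w) - loss(x - dx, w)) / (2 * eps)
fd_w = (loss(x, w + dw) - loss(x, w - dw)) / (2 * eps)

assert abs(fd_x - x_grad[2, 3]) < 1e-6
assert abs(fd_w - w_grad[1, 4]) < 1e-6
```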

Implications

  • Removes reliance on _scaled_mm(), which is not FakeTensor-safe.
  • Ensures compatibility with TorchDynamo execution.
  • Falls back to a standard matrix multiplication in out_dtype, which is traceable and numerically well understood (at the cost of the fused FP8 kernel's speed).

Verification

  • Training runs successfully under TorchDynamo.
  • No runtime errors related to FakeTensor or _scaled_mm().
  • Numerical behavior remains consistent.

💥 Checklist before requesting a review

  • I have linked my PR to an issue.
  • My issue clearly describes the problem.
  • I have performed a self-review of my code.
  • I have added thorough tests.
  • I have added docstrings for all implemented functions.

⭐️ Do you enjoy contributing to Colossal-AI?

  • 🌝 Yes, I do.
  • 🌚 No, I don't.

@Truong5724 Truong5724 requested a review from a team as a code owner April 15, 2026 12:00


Development

Successfully merging this pull request may close these issues.

[BUG]: LowLevelZeroPlugin with use_fp8=true causes shape mismatch error: shape [32, 512] is invalid for input of size 512
