Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.venv
benchmarks/microbenchmarks/asv/results/
*.o
*.swp
*.ii
Expand Down
171 changes: 171 additions & 0 deletions benchmarks/microbenchmarks/asv/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# TransformerEngine Microbenchmarks

GPU microbenchmarks for TE ops (GEMM, FP8 GEMM, grouped GEMM, attention,
casting, normalization), run in-process by `driver.py`. Each suite is a
`bench_*.py` file with a `Bench*` class; the driver times every `time_*` method,
prints a table with throughput, and saves raw per-call samples to JSON for
statistical comparison.

## Prerequisites

- TransformerEngine built and installed in the current Python environment.
- A ROCm or CUDA GPU.

## Running

```bash
cd benchmarks/microbenchmarks/asv
python driver.py --all # run every suite
python driver.py bench_gemm # run one suite via the driver
python bench_gemm.py # run one suite directly
python bench_gemm.py time_forward # filter to methods containing a string
python bench_gemm.py -w 5 -n 20 # custom warmup / timed iterations
python bench_casting.py --no-save # don't write a result file
python bench_casting.py --cold-cache # flush GPU cache before each sample
python bench_gemm.py --inner 50 # fix the inner-loop count to 50
python bench_gemm.py --kernel-profile # per-kernel CUDA-time breakdown
```

Results are written to `benchmarks/microbenchmarks/asv/results/<commit-hash>.json`
(gitignored), one raw-sample record per benchmark + parameter combination.

## Timing model: inner loop and cache state

Each `time_*` method runs its kernel `_inner` times inside one CUDA-event window
and divides by `_inner`, amortizing kernel-launch and CUDA-event jitter
(`~0.5 µs` on AMD). By default the driver auto-tunes `_inner` per (combo, method)
so each window lasts at least `--target-window-ms` (default `1.0 ms`).

| Flag | Effect |
|---|---|
| `--inner auto` (default) | Probe one invocation, then pick `_inner` so the next window lasts ≥ `--target-window-ms` (capped at 10000). |
| `--inner N` | Force a fixed `_inner = N`. |
| `--target-window-ms T` | Target window duration for `--inner auto` (default `1.0`). |
| `--cold-cache` | Write a `--cache-flush-mb` scratch buffer before each sample to evict L2 + Infinity Cache. Implies `--inner=1` (otherwise later inner iterations refill the cache). |
| `--cache-flush-mb M` | Scratch buffer size for `--cold-cache` (default `256`, sized for the MI300 Infinity Cache). |

- **Warm cache, large `_inner`** (default): steady-state throughput, lowest variance.
- **Cold cache, `_inner=1`**: isolated cold-memory cost — higher variance; bandwidth-bound benches (cast, norm) run ~1.5–3× slower than warm.

## Kernel profiling

`--kernel-profile` runs each benchmark once under `torch.profiler` instead of
collecting timing distributions, and prints the GPU kernels it launched, sorted
by total device time:

```bash
python driver.py bench_gemm --kernel-profile
python bench_attention.py time_forward --kernel-profile # one method
```

For each `(method, parameter combo)` it reports per-kernel total/avg CUDA time,
launch count, and share of total — useful for spotting which kernel dominates or
whether an op is launch-bound. This bypasses the timing machinery (`--inner`,
`--cold-cache`, interleaving); `--profile-inner N` sets how many invocations are
profiled per run (default `1`). Output is saved to
`results/<commit-hash>-kernelprofile.json` unless `--no-save`.

## Sample scheduling: interleaving

By default the driver does **not** collect a benchmark's samples in one
contiguous block. It samples in round-robin chunks: it sets up a group of
`(method, combo)` benchmarks, then takes one sample from each per round, for `-n`
rounds. Sequential scheduling (all of A, then all of B) makes wall-clock time a
proxy for benchmark identity, so any time-correlated GPU noise (thermal ramp,
DVFS throttle, a neighbor on a shared GPU) becomes a systematic **bias** between
benchmarks rather than noise. Round-robin spreads every benchmark across the same
window, so a transient lands on one sample of each. The per-round visit order is
also randomly permuted (seeded, so runs are reproducible) to remove residual
within-round phase/predecessor bias.

| Flag | Effect |
|---|---|
| `--interleave-group N` (default `8`) | Benchmarks sampled round-robin together. Each keeps a live GPU instance, so **lower this if a group runs out of memory**. |
| `--sequential` | Collect each benchmark's samples contiguously (≡ `--interleave-group 1`). Lowest memory, biased under thermal drift. |
| `--seed S` (default `0`) | Seed for the per-round shuffle. |
| `--no-shuffle` | Fixed round-robin order instead of permuting each round (debugging). |

Interleaving removes *within-run* time-position bias. It does **not** remove a
whole-run thermal offset between two separately produced result files, so for the
comparison below, produce the baseline and candidate files back-to-back under
similar conditions.

## Comparing two checkouts statistically

The driver records raw per-call samples; `compare_results.py` compares two result
files with a Brunner-Munzel test via
[benchstats](https://github.com/Arech/benchstats):

```bash
pip install -r requirements.txt # benchstats (pulls rich, scipy, numpy)
cd benchmarks/microbenchmarks/asv

python driver.py --all -n 20 # on the baseline checkout -> results/<base>.json
python driver.py --all -n 20 # on the candidate checkout -> results/<cand>.json
python compare_results.py results/<base>.json results/<cand>.json
```

It marks each `(benchmark, parameter combination)` faster (`>`), slower (`<`), or
not significant (`~`), and exits `1` on a significant difference (CI gating).

Two runs on the **same** commit (e.g. a dirty working tree, where `HEAD` is
unchanged) would overwrite each other; pass `--label` to keep them distinct:

```bash
python driver.py --all -n 20 --label base # -> results/<hash>-base.json
python driver.py --all -n 20 --label cand # -> results/<hash>-cand.json
python compare_results.py results/<hash>-base.json results/<hash>-cand.json
```

| Flag | Effect |
|---|---|
| `--alpha A` | Significance level (default `0.001`). |
| `--method M` | Statistical test (default `brunnermunzel`). |
| `--filter REGEX` | Only compare benchmarks whose name matches `REGEX`. |
| `--always-show-pvalues` | Show p-values for non-significant rows too. |
| `--export-to FILE` | Save the report to `.txt`/`.svg`/`.html`. |

The rank test needs a reasonable sample count (≥ ~10); the default `-n 20`
satisfies this. Only timing is tested — throughput is a constant-work transform
of time, so a rank test on it is identical.

## Writing a new benchmark

Add `bench_<name>.py` with a `Bench*` class subclassing `BenchBase`. Pull model
shapes from `models.py` so configs stay in one place.

```python
import torch
import transformer_engine.pytorch as te

from driver import BenchBase, run_as_main
from models import M_SIZES

class BenchSomething(BenchBase):
params = [M_SIZES, ["config_a", "config_b"]]
param_names = ["M", "config"]

def setup(self, M, config):
# Allocate tensors / modules. Runs once per (combo, method); the same
# instance is reused for warmup and timed iterations.
self.module = ...
self.x = ...

def time_forward(self, M, config):
# self._time runs the callable _inner times in one CUDA-event window
# and returns seconds per single invocation (handles --cold-cache).
return self._time(lambda: self.module(self.x))

# Optional: work_<name> returns per-call work for throughput columns.
def work_forward(self, M, config):
return {"flops": 2 * M * self.N * self.K} # or {"bytes": ...}

if __name__ == "__main__":
run_as_main(__file__)
```

Rules:
- `time_*` methods are timed automatically; time through `self._time(fn)`.
- `work_<name>` companions return **per-call** work and yield TFLOPS (`flops`) or GB/s (`bytes`) columns.
- Clear `.grad` attributes in backward benchmarks to prevent accumulation.
- `params` is a cross-product — keep the matrix size reasonable.
59 changes: 59 additions & 0 deletions benchmarks/microbenchmarks/asv/bench_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""Attention benchmarks via te.DotProductAttention (causal, GQA).

Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim
(Q@K^T and attn@V, each 2*b*h*s^2*d).
Backward FLOPs ~= 2 * Forward FLOPs.
"""

import torch
import transformer_engine.pytorch as te

from driver import BenchBase, run_as_main
from models import M_SIZES, attention_configs

BATCH = 2
MODELS = attention_configs() # name -> (num_q_heads, num_kv_heads, head_dim, tp)


class BenchAttention(BenchBase):
params = [M_SIZES, list(MODELS)] # M_SIZES used as seq_len
param_names = ["seq_len", "model"]

def setup(self, seq_len, model):
n_q, n_kv, hd, tp = MODELS[model]
qh, kvh = n_q // tp, n_kv // tp
dtype = torch.bfloat16
self.attn = te.DotProductAttention(
num_attention_heads=qh, kv_channels=hd,
num_gqa_groups=kvh, attn_mask_type="causal",
).to(device="cuda", dtype=dtype)
self.q = torch.randn(seq_len, BATCH, qh, hd, dtype=dtype, device="cuda", requires_grad=True)
self.k = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True)
self.v = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True)
self.grad_out = torch.randn_like(self.attn(self.q, self.k, self.v))

def work_forward(self, seq_len, model):
n_q, _, hd, tp = MODELS[model]
return {"flops": 4 * BATCH * (n_q // tp) * seq_len * seq_len * hd}

def work_forward_backward(self, seq_len, model):
n_q, _, hd, tp = MODELS[model]
return {"flops": 3 * 4 * BATCH * (n_q // tp) * seq_len * seq_len * hd}

def time_forward(self, seq_len, model):
return self._time(lambda: self.attn(self.q, self.k, self.v))

def time_forward_backward(self, seq_len, model):
t = self._time(lambda: self.attn(self.q, self.k, self.v).backward(self.grad_out))
self.q.grad = self.k.grad = self.v.grad = None
return t


if __name__ == "__main__":
run_as_main(__file__)
59 changes: 59 additions & 0 deletions benchmarks/microbenchmarks/asv/bench_casting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""Quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) benchmarks.

Covers E4M3 (activations/weights) and E5M2 (gradients). These casts are
memory-bound, so we report GB/s (input + output bytes).
"""

import torch
from transformer_engine.pytorch import Float8CurrentScalingQuantizer
from transformer_engine_torch import DType as TE_DType

from driver import BenchBase, run_as_main
from models import M_SIZES, hidden_sizes

HIDDEN = hidden_sizes()

# cast name -> (direction, fp8 dtype)
CAST_CONFIGS = {
"BF16_to_E4M3": ("quantize", TE_DType.kFloat8E4M3),
"E4M3_to_BF16": ("dequantize", TE_DType.kFloat8E4M3),
"BF16_to_E5M2": ("quantize", TE_DType.kFloat8E5M2),
"E5M2_to_BF16": ("dequantize", TE_DType.kFloat8E5M2),
}


class BenchCasting(BenchBase):
params = [M_SIZES, list(HIDDEN), list(CAST_CONFIGS)]
param_names = ["M", "model", "cast"]

def setup(self, M, model, cast):
hidden = HIDDEN[model]
direction, fp8_dtype = CAST_CONFIGS[cast]
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=fp8_dtype, device=torch.device("cuda"),
rowwise=True, columnwise=False,
)
if direction == "dequantize":
x = quantizer.quantize(torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda"))
self._call = lambda: x.dequantize(dtype=torch.bfloat16)
else:
x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
self._call = lambda: quantizer.quantize(x)

def work_cast(self, M, model, cast):
# quantize: read BF16 (2B) + write FP8 (1B) + scale; dequantize: the
# reverse -- 3 bytes/element either way.
return {"bytes": M * HIDDEN[model] * 3}

def time_cast(self, M, model, cast):
return self._time(self._call)


if __name__ == "__main__":
run_as_main(__file__)
52 changes: 52 additions & 0 deletions benchmarks/microbenchmarks/asv/bench_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""BF16 GEMM benchmarks via te.Linear.

Shapes are the four transformer projections (QKV, AttnOut, GateUp, Down)
derived from the models in models.py.
"""

import torch
import transformer_engine.pytorch as te

from driver import BenchBase, run_as_main
from models import M_SIZES, gemm_shapes

SHAPES = gemm_shapes()


class BenchGemm(BenchBase):
params = [M_SIZES, list(SHAPES)]
param_names = ["M", "shape"]

def setup(self, M, shape):
N, K = SHAPES[shape]
dtype = torch.bfloat16
self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype)
self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True)
self.grad_out = torch.randn_like(self.linear(self.x))

def work_forward(self, M, shape):
N, K = SHAPES[shape]
return {"flops": 2 * M * N * K}

def work_forward_backward(self, M, shape):
N, K = SHAPES[shape]
return {"flops": 3 * 2 * M * N * K}

def time_forward(self, M, shape):
return self._time(lambda: self.linear(self.x))

def time_forward_backward(self, M, shape):
t = self._time(lambda: self.linear(self.x).backward(self.grad_out))
self.x.grad = None
self.linear.weight.grad = None
return t


if __name__ == "__main__":
run_as_main(__file__)
Loading