diff --git a/.gitignore b/.gitignore index fca0d9389..6aae86cad 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .venv +benchmarks/microbenchmarks/asv/results/ *.o *.swp *.ii diff --git a/benchmarks/microbenchmarks/asv/README.md b/benchmarks/microbenchmarks/asv/README.md new file mode 100644 index 000000000..10644f8f7 --- /dev/null +++ b/benchmarks/microbenchmarks/asv/README.md @@ -0,0 +1,171 @@ +# TransformerEngine Microbenchmarks + +GPU microbenchmarks for TE ops (GEMM, FP8 GEMM, grouped GEMM, attention, +casting, normalization), run in-process by `driver.py`. Each suite is a +`bench_*.py` file with a `Bench*` class; the driver times every `time_*` method, +prints a table with throughput, and saves raw per-call samples to JSON for +statistical comparison. + +## Prerequisites + +- TransformerEngine built and installed in the current Python environment. +- A ROCm or CUDA GPU. + +## Running + +```bash +cd benchmarks/microbenchmarks/asv +python driver.py --all # run every suite +python driver.py bench_gemm # run one suite via the driver +python bench_gemm.py # run one suite directly +python bench_gemm.py time_forward # filter to methods containing a string +python bench_gemm.py -w 5 -n 20 # custom warmup / timed iterations +python bench_casting.py --no-save # don't write a result file +python bench_casting.py --cold-cache # flush GPU cache before each sample +python bench_gemm.py --inner 50 # fix the inner-loop count to 50 +python bench_gemm.py --kernel-profile # per-kernel CUDA-time breakdown +``` + +Results are written to `benchmarks/microbenchmarks/asv/results/.json` +(gitignored), one raw-sample record per benchmark + parameter combination. + +## Timing model: inner loop and cache state + +Each `time_*` method runs its kernel `_inner` times inside one CUDA-event window +and divides by `_inner`, amortizing kernel-launch and CUDA-event jitter +(`~0.5 µs` on AMD). By default the driver auto-tunes `_inner` per (combo, method) +so each window lasts at least `--target-window-ms` (default `1.0 ms`). + +| Flag | Effect | +|---|---| +| `--inner auto` (default) | Probe one invocation, then pick `_inner` so the next window lasts ≥ `--target-window-ms` (capped at 10000). | +| `--inner N` | Force a fixed `_inner = N`. | +| `--target-window-ms T` | Target window duration for `--inner auto` (default `1.0`). | +| `--cold-cache` | Write a `--cache-flush-mb` scratch buffer before each sample to evict L2 + Infinity Cache. Implies `--inner=1` (otherwise later inner iterations refill the cache). | +| `--cache-flush-mb M` | Scratch buffer size for `--cold-cache` (default `256`, sized for the MI300 Infinity Cache). | + +- **Warm cache, large `_inner`** (default): steady-state throughput, lowest variance. +- **Cold cache, `_inner=1`**: isolated cold-memory cost — higher variance; bandwidth-bound benches (cast, norm) run ~1.5–3× slower than warm. + +## Kernel profiling + +`--kernel-profile` runs each benchmark once under `torch.profiler` instead of +collecting timing distributions, and prints the GPU kernels it launched, sorted +by total device time: + +```bash +python driver.py bench_gemm --kernel-profile +python bench_attention.py time_forward --kernel-profile # one method +``` + +For each `(method, parameter combo)` it reports per-kernel total/avg CUDA time, +launch count, and share of total — useful for spotting which kernel dominates or +whether an op is launch-bound. This bypasses the timing machinery (`--inner`, +`--cold-cache`, interleaving); `--profile-inner N` sets how many invocations are +profiled per run (default `1`). Output is saved to +`results/-kernelprofile.json` unless `--no-save`. + +## Sample scheduling: interleaving + +By default the driver does **not** collect a benchmark's samples in one +contiguous block. It samples in round-robin chunks: it sets up a group of +`(method, combo)` benchmarks, then takes one sample from each per round, for `-n` +rounds. Sequential scheduling (all of A, then all of B) makes wall-clock time a +proxy for benchmark identity, so any time-correlated GPU noise (thermal ramp, +DVFS throttle, a neighbor on a shared GPU) becomes a systematic **bias** between +benchmarks rather than noise. Round-robin spreads every benchmark across the same +window, so a transient lands on one sample of each. The per-round visit order is +also randomly permuted (seeded, so runs are reproducible) to remove residual +within-round phase/predecessor bias. + +| Flag | Effect | +|---|---| +| `--interleave-group N` (default `8`) | Benchmarks sampled round-robin together. Each keeps a live GPU instance, so **lower this if a group runs out of memory**. | +| `--sequential` | Collect each benchmark's samples contiguously (≡ `--interleave-group 1`). Lowest memory, biased under thermal drift. | +| `--seed S` (default `0`) | Seed for the per-round shuffle. | +| `--no-shuffle` | Fixed round-robin order instead of permuting each round (debugging). | + +Interleaving removes *within-run* time-position bias. It does **not** remove a +whole-run thermal offset between two separately produced result files, so for the +comparison below, produce the baseline and candidate files back-to-back under +similar conditions. + +## Comparing two checkouts statistically + +The driver records raw per-call samples; `compare_results.py` compares two result +files with a Brunner-Munzel test via +[benchstats](https://github.com/Arech/benchstats): + +```bash +pip install -r requirements.txt # benchstats (pulls rich, scipy, numpy) +cd benchmarks/microbenchmarks/asv + +python driver.py --all -n 20 # on the baseline checkout -> results/.json +python driver.py --all -n 20 # on the candidate checkout -> results/.json +python compare_results.py results/.json results/.json +``` + +It marks each `(benchmark, parameter combination)` faster (`>`), slower (`<`), or +not significant (`~`), and exits `1` on a significant difference (CI gating). + +Two runs on the **same** commit (e.g. a dirty working tree, where `HEAD` is +unchanged) would overwrite each other; pass `--label` to keep them distinct: + +```bash +python driver.py --all -n 20 --label base # -> results/-base.json +python driver.py --all -n 20 --label cand # -> results/-cand.json +python compare_results.py results/-base.json results/-cand.json +``` + +| Flag | Effect | +|---|---| +| `--alpha A` | Significance level (default `0.001`). | +| `--method M` | Statistical test (default `brunnermunzel`). | +| `--filter REGEX` | Only compare benchmarks whose name matches `REGEX`. | +| `--always-show-pvalues` | Show p-values for non-significant rows too. | +| `--export-to FILE` | Save the report to `.txt`/`.svg`/`.html`. | + +The rank test needs a reasonable sample count (≥ ~10); the default `-n 20` +satisfies this. Only timing is tested — throughput is a constant-work transform +of time, so a rank test on it is identical. + +## Writing a new benchmark + +Add `bench_.py` with a `Bench*` class subclassing `BenchBase`. Pull model +shapes from `models.py` so configs stay in one place. + +```python +import torch +import transformer_engine.pytorch as te + +from driver import BenchBase, run_as_main +from models import M_SIZES + +class BenchSomething(BenchBase): + params = [M_SIZES, ["config_a", "config_b"]] + param_names = ["M", "config"] + + def setup(self, M, config): + # Allocate tensors / modules. Runs once per (combo, method); the same + # instance is reused for warmup and timed iterations. + self.module = ... + self.x = ... + + def time_forward(self, M, config): + # self._time runs the callable _inner times in one CUDA-event window + # and returns seconds per single invocation (handles --cold-cache). + return self._time(lambda: self.module(self.x)) + + # Optional: work_ returns per-call work for throughput columns. + def work_forward(self, M, config): + return {"flops": 2 * M * self.N * self.K} # or {"bytes": ...} + +if __name__ == "__main__": + run_as_main(__file__) +``` + +Rules: +- `time_*` methods are timed automatically; time through `self._time(fn)`. +- `work_` companions return **per-call** work and yield TFLOPS (`flops`) or GB/s (`bytes`) columns. +- Clear `.grad` attributes in backward benchmarks to prevent accumulation. +- `params` is a cross-product — keep the matrix size reasonable. diff --git a/benchmarks/microbenchmarks/asv/bench_attention.py b/benchmarks/microbenchmarks/asv/bench_attention.py new file mode 100644 index 000000000..395bb07cb --- /dev/null +++ b/benchmarks/microbenchmarks/asv/bench_attention.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Attention benchmarks via te.DotProductAttention (causal, GQA). + +Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim + (Q@K^T and attn@V, each 2*b*h*s^2*d). +Backward FLOPs ~= 2 * Forward FLOPs. +""" + +import torch +import transformer_engine.pytorch as te + +from driver import BenchBase, run_as_main +from models import M_SIZES, attention_configs + +BATCH = 2 +MODELS = attention_configs() # name -> (num_q_heads, num_kv_heads, head_dim, tp) + + +class BenchAttention(BenchBase): + params = [M_SIZES, list(MODELS)] # M_SIZES used as seq_len + param_names = ["seq_len", "model"] + + def setup(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh, kvh = n_q // tp, n_kv // tp + dtype = torch.bfloat16 + self.attn = te.DotProductAttention( + num_attention_heads=qh, kv_channels=hd, + num_gqa_groups=kvh, attn_mask_type="causal", + ).to(device="cuda", dtype=dtype) + self.q = torch.randn(seq_len, BATCH, qh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.k = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.v = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.attn(self.q, self.k, self.v)) + + def work_forward(self, seq_len, model): + n_q, _, hd, tp = MODELS[model] + return {"flops": 4 * BATCH * (n_q // tp) * seq_len * seq_len * hd} + + def work_forward_backward(self, seq_len, model): + n_q, _, hd, tp = MODELS[model] + return {"flops": 3 * 4 * BATCH * (n_q // tp) * seq_len * seq_len * hd} + + def time_forward(self, seq_len, model): + return self._time(lambda: self.attn(self.q, self.k, self.v)) + + def time_forward_backward(self, seq_len, model): + t = self._time(lambda: self.attn(self.q, self.k, self.v).backward(self.grad_out)) + self.q.grad = self.k.grad = self.v.grad = None + return t + + +if __name__ == "__main__": + run_as_main(__file__) diff --git a/benchmarks/microbenchmarks/asv/bench_casting.py b/benchmarks/microbenchmarks/asv/bench_casting.py new file mode 100644 index 000000000..9f4399b03 --- /dev/null +++ b/benchmarks/microbenchmarks/asv/bench_casting.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) benchmarks. + +Covers E4M3 (activations/weights) and E5M2 (gradients). These casts are +memory-bound, so we report GB/s (input + output bytes). +""" + +import torch +from transformer_engine.pytorch import Float8CurrentScalingQuantizer +from transformer_engine_torch import DType as TE_DType + +from driver import BenchBase, run_as_main +from models import M_SIZES, hidden_sizes + +HIDDEN = hidden_sizes() + +# cast name -> (direction, fp8 dtype) +CAST_CONFIGS = { + "BF16_to_E4M3": ("quantize", TE_DType.kFloat8E4M3), + "E4M3_to_BF16": ("dequantize", TE_DType.kFloat8E4M3), + "BF16_to_E5M2": ("quantize", TE_DType.kFloat8E5M2), + "E5M2_to_BF16": ("dequantize", TE_DType.kFloat8E5M2), +} + + +class BenchCasting(BenchBase): + params = [M_SIZES, list(HIDDEN), list(CAST_CONFIGS)] + param_names = ["M", "model", "cast"] + + def setup(self, M, model, cast): + hidden = HIDDEN[model] + direction, fp8_dtype = CAST_CONFIGS[cast] + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, device=torch.device("cuda"), + rowwise=True, columnwise=False, + ) + if direction == "dequantize": + x = quantizer.quantize(torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")) + self._call = lambda: x.dequantize(dtype=torch.bfloat16) + else: + x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self._call = lambda: quantizer.quantize(x) + + def work_cast(self, M, model, cast): + # quantize: read BF16 (2B) + write FP8 (1B) + scale; dequantize: the + # reverse -- 3 bytes/element either way. + return {"bytes": M * HIDDEN[model] * 3} + + def time_cast(self, M, model, cast): + return self._time(self._call) + + +if __name__ == "__main__": + run_as_main(__file__) diff --git a/benchmarks/microbenchmarks/asv/bench_gemm.py b/benchmarks/microbenchmarks/asv/bench_gemm.py new file mode 100644 index 000000000..24319cf80 --- /dev/null +++ b/benchmarks/microbenchmarks/asv/bench_gemm.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""BF16 GEMM benchmarks via te.Linear. + +Shapes are the four transformer projections (QKV, AttnOut, GateUp, Down) +derived from the models in models.py. +""" + +import torch +import transformer_engine.pytorch as te + +from driver import BenchBase, run_as_main +from models import M_SIZES, gemm_shapes + +SHAPES = gemm_shapes() + + +class BenchGemm(BenchBase): + params = [M_SIZES, list(SHAPES)] + param_names = ["M", "shape"] + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.linear(self.x)) + + def work_forward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 2 * M * N * K} + + def work_forward_backward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 3 * 2 * M * N * K} + + def time_forward(self, M, shape): + return self._time(lambda: self.linear(self.x)) + + def time_forward_backward(self, M, shape): + t = self._time(lambda: self.linear(self.x).backward(self.grad_out)) + self.x.grad = None + self.linear.weight.grad = None + return t + + +if __name__ == "__main__": + run_as_main(__file__) diff --git a/benchmarks/microbenchmarks/asv/bench_gemm_fp8.py b/benchmarks/microbenchmarks/asv/bench_gemm_fp8.py new file mode 100644 index 000000000..a6f761afa --- /dev/null +++ b/benchmarks/microbenchmarks/asv/bench_gemm_fp8.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""FP8 GEMM benchmarks via te.Linear under fp8_autocast. + +Same shapes as bench_gemm.py but with FP8 (HYBRID) quantized compute. +""" + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling, Format + +from driver import BenchBase, run_as_main +from models import M_SIZES, gemm_shapes + +SHAPES = gemm_shapes() +FP8_RECIPE = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max", +) + + +class BenchGemmFP8(BenchBase): + params = [M_SIZES, list(SHAPES)] + param_names = ["M", "shape"] + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn(M, N, dtype=dtype, device="cuda") + + def work_forward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 2 * M * N * K} + + def work_forward_backward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 3 * 2 * M * N * K} + + def _forward(self): + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + return self.linear(self.x) + + def time_forward(self, M, shape): + return self._time(self._forward) + + def time_forward_backward(self, M, shape): + t = self._time(lambda: self._forward().backward(self.grad_out)) + self.x.grad = None + self.linear.weight.grad = None + return t + + +if __name__ == "__main__": + run_as_main(__file__) diff --git a/benchmarks/microbenchmarks/asv/bench_grouped_gemm.py b/benchmarks/microbenchmarks/asv/bench_grouped_gemm.py new file mode 100644 index 000000000..58b1d27fb --- /dev/null +++ b/benchmarks/microbenchmarks/asv/bench_grouped_gemm.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Grouped GEMM benchmarks via te.GroupedLinear (MoE GateUp / Down).""" + +import torch +import transformer_engine.pytorch as te + +from driver import BenchBase, run_as_main +from models import M_SIZES_MOE, grouped_gemm_configs + +CONFIGS = grouped_gemm_configs() # name -> (num_gemms, N, K) + + +class BenchGroupedGemm(BenchBase): + params = [M_SIZES_MOE, list(CONFIGS)] + param_names = ["M", "config"] + + def setup(self, M, config): + B, N, K = CONFIGS[config] + dtype = torch.bfloat16 + self.module = te.GroupedLinear( + num_gemms=B, in_features=K, out_features=N, bias=False, + ).to(device="cuda", dtype=dtype) + self.xs = [ + torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + for _ in range(B) + ] + self.grad_outs = [torch.randn_like(o) for o in self.module(self.xs)] + + def work_forward(self, M, config): + B, N, K = CONFIGS[config] + return {"flops": B * 2 * M * N * K} + + def work_forward_backward(self, M, config): + B, N, K = CONFIGS[config] + return {"flops": B * 3 * 2 * M * N * K} + + def time_forward(self, M, config): + return self._time(lambda: self.module(self.xs)) + + def time_forward_backward(self, M, config): + t = self._time(lambda: torch.autograd.backward(self.module(self.xs), self.grad_outs)) + for x in self.xs: + x.grad = None + for p in self.module.parameters(): + p.grad = None + return t + + +if __name__ == "__main__": + run_as_main(__file__) diff --git a/benchmarks/microbenchmarks/asv/bench_normalization.py b/benchmarks/microbenchmarks/asv/bench_normalization.py new file mode 100644 index 000000000..3412e4170 --- /dev/null +++ b/benchmarks/microbenchmarks/asv/bench_normalization.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""RMSNorm and LayerNorm benchmarks on activation-sized tensors. + +Memory-bound; we report GB/s. The hidden dimension is swept over the distinct +model hidden sizes and M (batch * seq_len) over typical training sizes. +""" + +import torch +import transformer_engine.pytorch as te + +from driver import BenchBase, run_as_main +from models import M_SIZES, unique_hidden_sizes + +NORMS = {"RMSNorm": te.RMSNorm, "LayerNorm": te.LayerNorm} + + +class BenchNormalization(BenchBase): + params = [M_SIZES, unique_hidden_sizes(), list(NORMS)] + param_names = ["M", "hidden", "norm_type"] + + def setup(self, M, hidden, norm_type): + dtype = torch.bfloat16 + self.norm = NORMS[norm_type](hidden).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, hidden, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.norm(self.x)) + + def work_forward(self, M, hidden, norm_type): + # read input (2B) + write output (2B) + return {"bytes": M * hidden * 4} + + def work_forward_backward(self, M, hidden, norm_type): + # fwd read+write (4B) + bwd read input+grad_out, write grad_in (6B) + return {"bytes": M * hidden * 10} + + def time_forward(self, M, hidden, norm_type): + return self._time(lambda: self.norm(self.x)) + + def time_forward_backward(self, M, hidden, norm_type): + t = self._time(lambda: self.norm(self.x).backward(self.grad_out)) + self.x.grad = None + for p in self.norm.parameters(): + p.grad = None + return t + + +if __name__ == "__main__": + run_as_main(__file__) diff --git a/benchmarks/microbenchmarks/asv/compare_results.py b/benchmarks/microbenchmarks/asv/compare_results.py new file mode 100755 index 000000000..18ea2dd3b --- /dev/null +++ b/benchmarks/microbenchmarks/asv/compare_results.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Statistically compare two result JSON files written by ``driver.py``. + +A point-estimate (median) cannot tell a real regression from measurement noise. +This tool compares the raw per-call samples stored in two result files (one per +checkout) with a statistical test (Brunner-Munzel by default) via the benchstats +package. It marks each (benchmark, parameter combination) as faster (``>``), +slower (``<``), or not significantly different (``~``), prints a per-direction +summary, and exits ``1`` when a significant timing difference is found so it can +gate CI. Requires ``pip install -r requirements.txt``. + +Usage: + # run the suite on each checkout (each saves .json), then: + python compare_results.py results/.json results/.json + python compare_results.py base.json cand.json --alpha 0.01 + python compare_results.py base.json cand.json --export-to report.svg +""" + +import argparse +import json +import os +import re +import sys + +import numpy as np + +_TIME_KEY = "time_s" # metric exposed to benchstats (seconds, lower is better) + + +def _load_samples(path, name_filter=None): + """Load a driver result JSON into ``{bench_name: {"time_s": ndarray}}``. + + One benchstats "benchmark" per (benchmark, parameter combination); the name + is ``.. | name=val, ...``. Only timing is exposed: + throughput is a constant-work transform of time, so a rank test on it is + identical. + """ + with open(path) as f: + data = json.load(f) + pattern = re.compile(name_filter) if name_filter else None + + stats = {} + for bench_key, rec in data.get("results", {}).items(): + param_names = rec.get("param_names") or [] + for combo, samples in zip(rec.get("combos") or [], rec.get("samples") or []): + if not samples: + continue + arr = np.asarray(samples, dtype=np.float64) + arr = arr[np.isfinite(arr)] + if arr.size == 0: + continue + if param_names and len(param_names) == len(combo): + label = ", ".join(f"{n}={v}" for n, v in zip(param_names, combo)) + else: + label = ", ".join(str(v) for v in combo) + name = bench_key + (" | " + label if label else "") + if pattern is not None and pattern.search(name) is None: + continue + stats[name] = {_TIME_KEY: arr} + return stats + + +def run_stats(args): + """Compare two result JSONs; return 1 if a significant difference is found.""" + from benchstats.compare import compareStats + from benchstats.render import renderComparisonResults + from benchstats.common import LoggingConsole, detectExportFormat + + main_metrics = [_TIME_KEY] + export_fmt = detectExportFormat(args.export_to, None) if args.export_to else None + if export_fmt is not None and os.path.isfile(args.export_to): + os.remove(args.export_to) + + console = LoggingConsole( + record=export_fmt is not None, log_level=LoggingConsole.LogLevel.Warning, + ) + + s1 = _load_samples(args.baseline_json, args.filter) + s2 = _load_samples(args.candidate_json, args.filter) + + cr = compareStats( + s1, s2, method=args.method, alpha=args.alpha, + main_metrics=main_metrics, debug_log=console, + ) + renderComparisonResults( + cr, console, main_metrics=main_metrics, + always_show_pvalues=args.always_show_pvalues, + ) + + # benchstats encodes each comparison as baseline-vs-candidate: "<" means + # baseline < candidate (candidate slower -> regression), ">" means candidate + # faster, "~" means not significant at alpha. + for metric in main_metrics: + counts = {"<": 0, ">": 0, "~": 0} + for bm_res in cr.results.values(): + res = bm_res.get(metric) + if res is not None: + counts[res.result] = counts.get(res.result, 0) + 1 + total = sum(counts.values()) + console.print( + f"\nSummary for '{metric}' ({cr.method}, alpha={cr.alpha:g}, " + f"{total} benchmarks):" + ) + console.print(f" candidate faster (significant, '>'): {counts['>']}") + console.print(f" candidate slower (significant, '<'): {counts['<']}") + console.print(f" no significant difference ('~'): {counts['~']}") + + if export_fmt is not None: + {"txt": lambda: console.save_text(args.export_to), + "svg": lambda: console.save_svg(args.export_to, title=""), + "html": lambda: console.save_html(args.export_to)}[export_fmt]() + + if cr.at_least_one_differs: + console.warning("At least one significant timing difference was detected (exit 1).") + return 1 + return 0 + + +def main(): + parser = argparse.ArgumentParser( + description="Statistically compare two driver result JSONs via benchstats.") + parser.add_argument("baseline_json", help="Baseline result JSON") + parser.add_argument("candidate_json", help="Candidate result JSON") + parser.add_argument("--filter", default=None, + help="Only compare benchmarks whose name matches this regex.") + parser.add_argument("--alpha", type=float, default=0.001, + help="Significance level for the test (default: 0.001).") + parser.add_argument("--method", default="brunnermunzel", + help="Statistical test to use (default: brunnermunzel).") + parser.add_argument("--always-show-pvalues", action="store_true", + help="Show p-values for non-significant rows too.") + parser.add_argument("--export-to", default=None, metavar="FILE", + help="Export the report to a .txt/.svg/.html file (format from extension).") + return run_stats(parser.parse_args()) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/benchmarks/microbenchmarks/asv/driver.py b/benchmarks/microbenchmarks/asv/driver.py new file mode 100644 index 000000000..1443515f7 --- /dev/null +++ b/benchmarks/microbenchmarks/asv/driver.py @@ -0,0 +1,593 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""In-process microbenchmark driver. + +Discovers ``Bench*`` classes in ``bench_*.py`` files, runs their ``time_*`` +methods with robust GPU timing (inner-loop amortization, optional cold cache, +round-robin interleaving), prints a table with throughput, and saves the raw +per-call samples to JSON for ``compare_results.py``. + +Usage: + python driver.py [method_filter] [-w W] [-n N] [--no-save] + python driver.py --all [-w W] [-n N] + python bench_gemm.py [method_filter] [-w W] [-n N] # bench file as main +""" + +import argparse +import glob +import importlib +import itertools +import json +import os +import random +import re +import subprocess +import sys +import time + +import numpy as np + + +# --------------------------------------------------------------------------- +# Benchmark base class +# --------------------------------------------------------------------------- + +class BenchBase: + """Base for benchmark classes: driver-controlled knobs + the timing helper. + + The driver sets ``_inner`` (kernel invocations per CUDA-event window, to + amortize launch + event overhead) and ``_scratch`` (a buffer written before + each sample to evict the GPU cache in ``--cold-cache`` mode) per + (combo, method). Subclasses time their kernels through :meth:`_time`. + """ + + _inner = 1 + _scratch = None + + def _time(self, fn): + """Run *fn* ``_inner`` times in one CUDA-event window; return seconds/call. + + Honors ``--cold-cache`` (flush scratch before the window) and ``--inner`` + (loop count). The per-call value is what the driver and throughput + columns consume regardless of inner-loop count. + """ + import torch # deferred: driver stays importable without torch + evt = getattr(self, "_evt", None) + if evt is None: + evt = self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + if self._scratch is not None: + self._scratch.fill_(1.0) + evt[0].record() + for _ in range(self._inner): + fn() + evt[1].record() + torch.cuda.synchronize() + return evt[0].elapsed_time(evt[1]) / 1000 / self._inner + + +# --------------------------------------------------------------------------- +# Results +# --------------------------------------------------------------------------- + +def _get_commit_hash(): + """Current git HEAD hash, or 'unknown' outside a checkout.""" + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except Exception: + return "unknown" + + +def _results_dir(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "results") + + +def save_results(all_results, label=None, results_dir=None): + """Write raw per-call samples to ``/[-