PERF: Speed up permutation cluster tests via Numba union-find + compact graph#13731

Open
sharifhsn wants to merge 2 commits into mne-tools:main from sharifhsn:perf-opt

Conversation

@sharifhsn

@sharifhsn sharifhsn commented Mar 9, 2026

Reference issue

Related: #5439, #7784, #8095, #12609

What does this implement/fix?

Speeds up spatio_temporal_cluster_1samp_test (and the other permutation_cluster_* functions) by ~5-10x on realistic data. The PR is split into 7 incremental commits. Maintainers can accept or reject each layer independently.

Commit 1 — Precompute sum-of-squares for sign-flip t-test (+29/−9 lines, 3.2x)
For the default `ttest_1samp_no_p`, every sign flip satisfies s² = 1, so `sum(X²)` is constant across permutations. Each permutation then reduces to a single `signs @ X` dot product instead of a full `stat_fun` call. Also skips `buffer_size` verification for built-in stat functions.
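A standalone NumPy sketch of the algebra (illustrative only, not the PR's actual code; `t_from_signs` is a hypothetical helper): because each sign satisfies s² = 1, the per-feature sum of squares is invariant under sign flips, so each permutation's t statistic needs only the permuted mean.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n = 15
X = rng.standard_normal((n, 1000))  # subjects x features

# Precompute once: sign flips leave each x_i**2 unchanged,
# so sum(X**2) is constant across permutations.
sum_sq = np.sum(X**2, axis=0)

def t_from_signs(signs):
    # One permutation reduces to a single dot product.
    mean = signs @ X / n
    # Sample variance recovered from the invariant sum of squares.
    var = (sum_sq - n * mean**2) / (n - 1)
    return mean / np.sqrt(var / n)

signs = rng.choice([-1.0, 1.0], size=n)
t_fast = t_from_signs(signs)
# Reference: compute the t statistic on explicitly sign-flipped data.
t_ref = stats.ttest_1samp(signs[:, None] * X, 0.0).statistic
print(np.allclose(t_fast, t_ref))  # True
```

This mirrors the default stat function's formula (no variance regularization); a nonzero `sigma` in `ttest_1samp_no_p` would break the shortcut.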

Commit 2 — Numba union-find for spatio-temporal CCL (+226/−11 lines, 10.3x cumulative)
JIT-compiled union-find kernel (_st_fused_ccl) with path compression and union-by-rank, replacing the Python BFS in _get_clusters_st. Bundles tightly-coupled pieces: pre-computed CSR adjacency arrays, _sums_only flag to skip cluster list construction (uses np.bincount instead), and _csr_data parameter threading. These are bundled because _sums_only only fires inside if has_numba: and CSR data is only consumed by the Numba kernel.
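For readers unfamiliar with the technique, here is a pure-Python sketch of union-find connected-component labeling over a CSR adjacency (illustrative only; the PR's `_st_fused_ccl` is a Numba-compiled kernel, and `ccl_csr`, `find`, and `union` are hypothetical names for this sketch):

```python
import numpy as np

def find(parent, i):
    # Find the root, then path-compress: point every visited node at it.
    root = i
    while parent[root] != root:
        root = parent[root]
    while parent[i] != root:
        parent[i], i = root, parent[i]
    return root

def union(parent, rank, a, b):
    ra, rb = find(parent, a), find(parent, b)
    if ra == rb:
        return
    # Union by rank: attach the shallower tree under the deeper one.
    if rank[ra] < rank[rb]:
        ra, rb = rb, ra
    parent[rb] = ra
    if rank[ra] == rank[rb]:
        rank[ra] += 1

def ccl_csr(indptr, indices, mask):
    """Label connected suprathreshold components of a CSR graph."""
    n = len(mask)
    parent = np.arange(n)
    rank = np.zeros(n, dtype=np.int64)
    for u in range(n):
        if not mask[u]:
            continue
        for v in indices[indptr[u]:indptr[u + 1]]:
            if mask[v]:
                union(parent, rank, u, v)
    # Subthreshold vertices get label -1; others get their root id.
    return np.array([find(parent, i) if mask[i] else -1 for i in range(n)])

# Path graph 0-1-2-3-4 with vertex 2 subthreshold -> two clusters.
indptr = np.array([0, 1, 3, 5, 7, 8])
indices = np.array([1, 0, 2, 1, 3, 2, 4, 3])
mask = np.array([True, True, False, True, True])
labels = ccl_csr(indptr, indices, mask)
```

With labels in hand, per-cluster sums can be had from `np.bincount` with a `weights` argument rather than by building per-cluster index lists, which mirrors what the `_sums_only` path described above exploits.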

Commit 3 — Extract _union helper + simplify (+36/−72 lines)
Extract duplicated find+union logic into _union() with inline="always", simplify _sum_cluster_data, trim docstrings/comments to match codebase style.

Commit 4 — Fix step-down reshape (+1/−1 lines)
Pre-existing bug: `adjacency is None and adjacency is not False` is logically equivalent to just `adjacency is None`, so the `adjacency is False` case, where `step_down_include` still needs reshaping, was silently skipped.
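A minimal truth-table check of the condition (assuming the fix replaces `and` with `or`; this is a sketch, not the PR's exact diff):

```python
results = {}
for adjacency in (None, False, "matrix"):
    # Buggy: the second conjunct is redundant when the first is True,
    # and the whole expression is False whenever adjacency is not None.
    buggy = adjacency is None and adjacency is not False
    fixed = adjacency is None or adjacency is False
    results[adjacency] = (buggy, fixed)
print(results)  # the buggy form misses the adjacency=False case
```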

Commit 5 — Changelog entries

Commit 6 — Test fixture (+1 line)
Patch has_numba in numba_conditional fixture so the "NumPy" test variant actually exercises the Python BFS fallback path for spatio-temporal clustering.

Commit 7 — Docstring (+18/−1 lines)
Expand _st_fused_ccl docstring with algorithm description, complexity analysis, and Wikipedia reference, per reviewer request.

Commits 3-7 are cleanup, bugfix, docs, and tests — they don't affect performance. All optimizations fall back to the original code paths when Numba is not installed. No public API changes.

Benchmarks

Per-commit cumulative speedup (local, Apple M-series, spatio_temporal_cluster_1samp_test, ico-5, 15 subjects x 15 timepoints x 20,484 vertices, threshold=3.0, 512 permutations, median of 3 runs):

| Cumulative through | ms/perm | Speedup | Net lines |
| --- | ---: | ---: | ---: |
| main (baseline) | 16.94 | 1.0x | |
| commit 1 (precomputed sum_sq) | 5.37 | 3.2x | +20 |
| commit 2 (Numba union-find) | 1.64 | 10.3x | +235 |

AWS HPC end-to-end (AMD EPYC 7R13, same data dimensions):

| Permutations | Before | After | Speedup |
| ---: | ---: | ---: | ---: |
| 256 | 4.12 s | 0.86 s | 4.8x |
| 1024 | 16.35 s | 3.25 s | 5.0x |
| 4096 | 65.00 s | 12.64 s | 5.1x |

Per-permutation cost: 15.8 ms → 3.1 ms (5.2x). Projected 10,000 permutations: 31 s vs 159 s.

Reproduce benchmarks locally
"""Quick benchmark: perf-opt vs baseline on realistic source-space data."""
import time
import numpy as np
import mne
from mne.stats import spatio_temporal_cluster_1samp_test
from mne.stats import cluster_level as cl

# Load fsaverage ico-5 adjacency
subjects_dir = mne.datasets.sample.data_path() / "subjects"
src = mne.setup_source_space(
    "fsaverage", spacing="ico5", subjects_dir=subjects_dir, add_dist=False
)
adjacency = mne.spatial_src_adjacency(src)

# Synthetic data: 15 subjects x 15 timepoints x 20,484 vertices
rng = np.random.default_rng(42)
X = rng.standard_normal((15, 15, adjacency.shape[0]))
X[:, 5:10, 1000:1100] += 1.0  # inject focal activation

# Warmup JIT
spatio_temporal_cluster_1samp_test(
    X, adjacency=adjacency, n_permutations=64,
    threshold=3.0, tail=1, verbose=False, seed=42
)

# Optimized
t0 = time.perf_counter()
spatio_temporal_cluster_1samp_test(
    X, adjacency=adjacency, n_permutations=512,
    threshold=3.0, tail=1, verbose=False, seed=42
)
t_opt = time.perf_counter() - t0

# Baseline (disable Numba path)
saved = cl.has_numba
cl.has_numba = False
t0 = time.perf_counter()
spatio_temporal_cluster_1samp_test(
    X, adjacency=adjacency, n_permutations=512,
    threshold=3.0, tail=1, verbose=False, seed=42
)
t_base = time.perf_counter() - t0
cl.has_numba = saved

print(f"Optimized: {t_opt:.2f}s  Baseline: {t_base:.2f}s  Speedup: {t_base/t_opt:.1f}x")

Additional information

  • Numba JIT warmup happens once on first call; subsequent calls pay no warmup cost
  • TFCE (threshold=dict(...)) correctly falls back to the original code path
  • Custom stat functions still benefit from the CCL and overhead optimizations but not the precomputed sum-of-squares
  • AI (Claude) was used to generate the code, which was checked over manually

