xmss: use numba JIT for poseidon2 permute#448

Open
unnawut wants to merge 5 commits into leanEthereum:main from unnawut:numba-jit-subtle

Conversation


@unnawut unnawut commented Mar 11, 2026

🗒️ Description

This PR optimizes keygen for 12 keys from ~8 hours down to ~1.5 hours using Numba JIT, which compiles specific Python functions to machine code. Numba doesn't support all Python operations, so some changes are needed:

  • Move _external_linear_layer() and _internal_linear_layer_jit() out of class Poseidon2Params so they can be compiled by Numba (with the @njit annotation). The implementations are only slightly changed.
  • Move the permute() implementation out of class Poseidon2Params, also so it can be compiled by Numba. Call sites change a bit but the logic stays the same.
  • The biggest change: Numba doesn't support the @ _M4_T matmul directly, so we add _m4_multiply() to do the matmul ourselves.

A single keygen now takes about 35 minutes. A few more optimizations could get a single keygen down to 15-20 minutes, but they make the code diverge further from its previous form, so I opted them out. Listed here in case we want to consider them later:

  1. Replace NumPy array ops with loops: e.g. turning state[:] = ... into for loops. I think this would halve the keygen time but the code will look different.
  2. Inline the M4 matrix math: replace the matrix and matmul with a directly derived function. But then we lose the matrix visibility into this:
    for c in range(chunks.shape[0]):
        a, b, cv, d = chunks[c, 0], chunks[c, 1], chunks[c, 2], chunks[c, 3]
        s = (a + b + cv + d) % p
        result[c, 0] = (s + a + 2 * b) % p
        result[c, 1] = (s + b + 2 * cv) % p
        result[c, 2] = (s + cv + 2 * d) % p
        result[c, 3] = (s + 2 * a + d) % p
  3. Parallelize epochs within a single keygen: right now we parallelize at the key level, so the time is always ceil(num_keys / num_cores) * single_keygen_time. We could switch to parallelizing at the epoch level to get single_keygen_time / num_cores * num_keys, but I'm not sure the epoch compute can keep the cores saturated the whole time. Plus the change would span multiple files, so maybe I can try this out after this PR.
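A back-of-envelope comparison for that trade-off (the 35-minute single keygen and 12 keys are from this PR; the 8-core count is an assumed figure for illustration):

```python
import math

single_keygen_min = 35   # measured single-keygen time from this PR
num_keys = 12
num_cores = 8            # assumed core count, for illustration only

# current approach: parallelize across keys
key_level = math.ceil(num_keys / num_cores) * single_keygen_min

# option 3: parallelize epochs within each keygen (ideal scaling)
epoch_level = single_keygen_min / num_cores * num_keys

print(key_level, epoch_level)  # 70 52.5
```

The epoch-level number assumes perfect scaling, which is exactly the "can it saturate the cores" question above.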

The existing test in test_permutation.py should still cover this, so no tests were added or updated.
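Relatedly, the inlined form in option 2 can be sanity-checked against the matrix multiply it replaces. The 4x4 matrix below is read off from those equations (each output is s plus the listed terms), not copied from the code, and the prime is a toy value:

```python
import numpy as np

p = 97  # toy prime for illustration; the real field modulus differs
M4 = np.array([[2, 3, 1, 1],
               [1, 2, 3, 1],
               [1, 1, 2, 3],
               [3, 1, 1, 2]], dtype=np.int64)

def inlined(chunks):
    # the derived per-chunk arithmetic from option 2
    result = np.empty_like(chunks)
    for c in range(chunks.shape[0]):
        a, b, cv, d = chunks[c]
        s = (a + b + cv + d) % p
        result[c, 0] = (s + a + 2 * b) % p
        result[c, 1] = (s + b + 2 * cv) % p
        result[c, 2] = (s + cv + 2 * d) % p
        result[c, 3] = (s + 2 * a + d) % p
    return result

chunks = np.arange(12, dtype=np.int64).reshape(3, 4)
assert np.array_equal(inlined(chunks), (chunks @ M4.T) % p)
```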

🔗 Related Issues or PRs

✅ Checklist

  • Ran tox checks to avoid unnecessary CI fails:
    uvx tox
  • Considered adding appropriate tests for the changes.
  • Considered updating the online docs in the ./docs/ directory.

@unnawut unnawut requested a review from tcoratger March 11, 2026 06:57