xmss: use numba JIT for poseidon2 permute#448
Open
unnawut wants to merge 5 commits intoleanEthereum:mainfrom
Open
xmss: use numba JIT for poseidon2 permute#448unnawut wants to merge 5 commits intoleanEthereum:mainfrom
unnawut wants to merge 5 commits intoleanEthereum:mainfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
🗒️ Description
This PR optimizes keygen for 12 keys from ~8 hours down to ~1.5 hours using Numba JIT, basically compiling specific python functions to machine code. The conversion does not support all operations so some changes are needed:
_external_linear_layer()and_internal_linear_layer_jit()out ofclass Poseidon2Paramsso they can be compiled by Numba (with@njitannotation). The implementations are only slightly changed.permute()implementation out ofclass Poseidon2Params, also so that it can be compiled by Numba. Function callings changed a bit but the logic stays the same.@ _M4_Tdirectly which Numba doesn't support, we need to add_m4_multiply()to do the matmul ourselves.A single keygen now takes about 35 minutes. There are a few more optimizations to get a single keygen down to 15-20 minutes but the code diverges more from its previous form so I opted them out. Listed down here in case we want to consider them in later:
state[:] = ...into for loops. I think this one would half the keygen time but the code will look different.ceil(num_keys / num_cores) * single_keygen_time. We could switch to parallelize at epoch level to getsingle_keygen_time / num_cores * num_keysbut I'm not sure if the epoch compute can maximize the cores all the time. Plus the change will be all over multiple files so maybe I can try this out after this PR.The test in
test_permutation.pyshould still cover this so no changes added/updated.🔗 Related Issues or PRs
✅ Checklist
toxchecks to avoid unnecessary CI fails:uvx tox