From 516187ddd9d4c215894f37103661ce4a37afa467 Mon Sep 17 00:00:00 2001 From: unnawut Date: Mon, 9 Mar 2026 14:43:04 +0700 Subject: [PATCH 1/5] use numba JIT for poseidon2 permute --- pyproject.toml | 1 + .../subspecs/poseidon2/permutation.py | 250 +++++++++--------- uv.lock | 48 +++- 3 files changed, 177 insertions(+), 122 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index edb94ddd..96ebd8d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "aiohttp>=3.11.0,<4", "cryptography>=46.0.0", "numpy>=2.0.0,<3", + "numba>=0.61.0,<1", "aioquic>=1.2.0,<2", "pyyaml>=6.0.0,<7", ] diff --git a/src/lean_spec/subspecs/poseidon2/permutation.py b/src/lean_spec/subspecs/poseidon2/permutation.py index 54aa05cb..23571d8a 100644 --- a/src/lean_spec/subspecs/poseidon2/permutation.py +++ b/src/lean_spec/subspecs/poseidon2/permutation.py @@ -4,14 +4,15 @@ Based on "Poseidon2: A Faster Version of the Poseidon Hash Function". See https://eprint.iacr.org/2023/323. -Uses numpy arrays for vectorized field operations. +Uses Numba JIT compilation for native-speed permutation. """ from __future__ import annotations -from typing import Final, Self +from typing import Self import numpy as np +from numba import njit from numpy.typing import NDArray from pydantic import Field, model_validator @@ -22,24 +23,125 @@ ROUND_CONSTANTS_24, ) -type State = NDArray[np.int64] -"""State vector as signed 64-bit integers.""" +@njit(cache=True) +def _external_linear_layer_jit(state: NDArray[np.int64], width: int, p: int) -> None: + """ + Apply the external linear layer (M_E) in-place. -_M4_T: Final[NDArray[np.int64]] = np.array( - [ - [2, 3, 1, 1], - [1, 2, 3, 1], - [1, 1, 2, 3], - [3, 1, 1, 2], - ], - dtype=np.int64, -).T -""" -Base 4x4 MDS matrix, pre-transposed. + Multiplies each 4-element chunk by the M4 circulant matrix, + then applies the outer circulant structure for global diffusion. + """ + num_chunks = width // 4 + + # Apply M4 to each 4-element chunk. + for c in range(num_chunks): + base = c * 4 + a = state[base] + b = state[base + 1] + c_val = state[base + 2] + d = state[base + 3] + + s = (a + b + c_val + d) % p + state[base] = (s + a + 2 * b) % p + state[base + 1] = (s + b + 2 * c_val) % p + state[base + 2] = (s + c_val + 2 * d) % p + state[base + 3] = (s + 2 * a + d) % p + + # Outer circulant: sum corresponding positions across chunks, add to each. + for i in range(4): + col_sum = np.int64(0) + for c in range(num_chunks): + col_sum += state[c * 4 + i] + for c in range(num_chunks): + state[c * 4 + i] = (state[c * 4 + i] + col_sum) % p + + +@njit(cache=True) +def _internal_linear_layer_jit( + state: NDArray[np.int64], diag_vector: NDArray[np.int64], width: int, p: int +) -> None: + """ + Apply the internal linear layer (M_I) in-place. -Pre-transposition enables efficient row-vector multiplication: `v @ M.T`. -""" + M_I = J + D where J is the all-ones matrix and D is diagonal. + O(t) computation instead of O(t^2). + """ + state_sum = np.int64(0) + for i in range(width): + state_sum += state[i] + state_sum = state_sum % p + + for i in range(width): + state[i] = (state_sum + diag_vector[i] * state[i] % p) % p + + +@njit(cache=True) +def _permute_jit( + state: NDArray[np.int64], + round_constants: NDArray[np.int64], + diag_vector: NDArray[np.int64], + width: int, + half_rounds_f: int, + rounds_p: int, + p: int, +) -> None: + """ + Full Poseidon2 permutation, compiled to native code. + + Modifies state array in-place. + S-box: x^3 computed as (x*x % p) * x % p to avoid int64 overflow. + """ + const_idx = 0 + + # 1. Initial linear layer. + _external_linear_layer_jit(state, width, p) + + # 2. First half of full rounds. + for _ in range(half_rounds_f): + for i in range(width): + state[i] = (state[i] + round_constants[const_idx + i]) % p + const_idx += width + + for i in range(width): + x = state[i] + state[i] = (x * x % p) * x % p + + _external_linear_layer_jit(state, width, p) + + # 3. Partial rounds. + for _ in range(rounds_p): + state[0] = (state[0] + round_constants[const_idx]) % p + const_idx += 1 + + x = state[0] + state[0] = (x * x % p) * x % p + + _internal_linear_layer_jit(state, diag_vector, width, p) + + # 4. Second half of full rounds. + for _ in range(half_rounds_f): + for i in range(width): + state[i] = (state[i] + round_constants[const_idx + i]) % p + const_idx += width + + for i in range(width): + x = state[i] + state[i] = (x * x % p) * x % p + + _external_linear_layer_jit(state, width, p) + + +# Trigger compilation on import so the first real call is fast. +_permute_jit( + np.zeros(16, dtype=np.int64), + np.zeros(148, dtype=np.int64), + np.zeros(16, dtype=np.int64), + 16, + 4, + 20, + 2130706433, +) class Poseidon2Params(StrictBaseModel): @@ -134,113 +236,19 @@ def permute(self, current_state: list[Fp]) -> list[Fp]: if len(current_state) != self._width: raise ValueError(f"Input state must have length {self._width}") - # Local variable access is faster in Python loops. - width = self._width - p = P - const_idx = 0 - constants = self._round_constants - - # Convert input Fp elements to numpy array. state = np.array([fp.value for fp in current_state], dtype=np.int64) - # 1. Initial Linear Layer - # - # Prevents certain algebraic attacks. - # Ensures the permutation begins with a diffusion layer. - state = self._external_linear_layer(state) - - # 2. First Half of Full Rounds (R_F / 2) - # - # Note: for S_BOX_DEGREE=3, state**3 would overflow int64 before modulo. - # Values reach up to 2^93, but int64 max is 2^63. - # Expand S-box to `(state*state % P) * state % P` to stay in range. - for _ in range(self._half_rounds_f): - # Add round constants to entire state. - state = (state + constants[const_idx : const_idx + width]) % p - const_idx += width - - # Apply S-box (x -> x^d) to full state. - state = (state * state % p) * state % p - - # Apply external linear layer for diffusion. - state = self._external_linear_layer(state) - - # 3. Partial Rounds (R_P) - for _ in range(self._rounds_p): - # Add single round constant to first element. - state[0] = (state[0] + constants[const_idx]) % p - const_idx += 1 - - # Apply S-box to first element only. - # This is the main optimization of the Hades design. - state[0] = (state[0] * state[0] % p) * state[0] % p - - # Apply internal linear layer. - state = self._internal_linear_layer(state) - - # 4. Second Half of Full Rounds (R_F / 2) - for _ in range(self._half_rounds_f): - # Add round constants to entire state. - state = (state + constants[const_idx : const_idx + width]) % p - const_idx += width - - # Apply S-box to full state. - state = (state * state % p) * state % p - - # Apply external linear layer for diffusion. - state = self._external_linear_layer(state) - - # Convert back to Fp objects. - return [Fp(value=int(x)) for x in state] - - def _external_linear_layer(self, state: State) -> State: - """ - Apply the external linear layer (M_E). - - Provides strong diffusion across the entire state. - Used in full rounds. - - For state size t=4k, constructed from M4 to form a circulant-like matrix. - Efficient while ensuring any single element change affects all others. - - See Appendix B of the paper. - """ - # Apply M4 to each 4-element chunk. - # Provides strong local diffusion within each block. - chunks = state.reshape(-1, 4) - chunks = chunks @ _M4_T - - # Apply outer circulant structure for global diffusion. - # Equivalent to multiplying by circ(2*I, I, ..., I) after M4 stage. - sums = chunks.sum(axis=0) - - # Add corresponding sum to each element. - return (chunks + sums).reshape(-1) % P - - def _internal_linear_layer(self, state: State) -> State: - """ - Apply the internal linear layer (M_I). - - Used during partial rounds. - Optimized for speed. - - Matrix structure: M_I = J + D - - - J is the all-ones matrix - - D is a diagonal matrix - - This allows O(t) computation instead of O(t^2): - - M_I * s = J*s + D*s - - J*s is a vector where each element equals the sum of all elements in s. - """ - # J*state: sum of all elements (broadcast to vector). - # D*state: element-wise multiplication with diagonal. - state_sum = state.sum() + _permute_jit( + state, + self._round_constants, + self._diag_vector, + self._width, + self._half_rounds_f, + self._rounds_p, + P, + ) - # new_state[i] = state_sum + diag_vector[i] * state[i] - return (state_sum + (self._diag_vector * state)) % P + return [Fp(value=int(x)) for x in state] # Parameters for WIDTH = 16 diff --git a/uv.lock b/uv.lock index 84fc8429..cb706baf 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" [manifest] @@ -879,6 +879,7 @@ dependencies = [ { name = "cryptography" }, { name = "httpx" }, { name = "lean-multisig-py" }, + { name = "numba" }, { name = "numpy" }, { name = "pydantic" }, { name = "pyyaml" }, @@ -939,6 +940,7 @@ requires-dist = [ { name = "cryptography", specifier = ">=46.0.0" }, { name = "httpx", specifier = ">=0.28.0,<1" }, { name = "lean-multisig-py", git = "https://github.com/anshalshukla/leanMultisig-py?branch=devnet2" }, + { name = "numba", specifier = ">=0.61.0,<1" }, { name = "numpy", specifier = ">=2.0.0,<3" }, { name = "pydantic", specifier = ">=2.12.0,<3" }, { name = "pyyaml", specifier = ">=6.0.0,<7" }, @@ -992,6 +994,26 @@ test = [ { name = "pytest-xdist", specifier = ">=3.6.1,<4" }, ] +[[package]] +name = "llvmlite" +version = "0.46.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/74/cd/08ae687ba099c7e3d21fe2ea536500563ef1943c5105bf6ab4ee3829f68e/llvmlite-0.46.0.tar.gz", hash = "sha256:227c9fd6d09dce2783c18b754b7cd9d9b3b3515210c46acc2d3c5badd9870ceb", size = 193456, upload-time = "2025-12-08T18:15:36.295Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/f8/4db016a5e547d4e054ff2f3b99203d63a497465f81ab78ec8eb2ff7b2304/llvmlite-0.46.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b9588ad4c63b4f0175a3984b85494f0c927c6b001e3a246a3a7fb3920d9a137", size = 37232767, upload-time = "2025-12-08T18:15:00.737Z" }, + { url = "https://files.pythonhosted.org/packages/aa/85/4890a7c14b4fa54400945cb52ac3cd88545bbdb973c440f98ca41591cdc5/llvmlite-0.46.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3535bd2bb6a2d7ae4012681ac228e5132cdb75fefb1bcb24e33f2f3e0c865ed4", size = 56275176, upload-time = "2025-12-08T18:15:03.936Z" }, + { url = "https://files.pythonhosted.org/packages/6a/07/3d31d39c1a1a08cd5337e78299fca77e6aebc07c059fbd0033e3edfab45c/llvmlite-0.46.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4cbfd366e60ff87ea6cc62f50bc4cd800ebb13ed4c149466f50cf2163a473d1e", size = 55128630, upload-time = "2025-12-08T18:15:07.196Z" }, + { url = "https://files.pythonhosted.org/packages/2a/6b/d139535d7590a1bba1ceb68751bef22fadaa5b815bbdf0e858e3875726b2/llvmlite-0.46.0-cp312-cp312-win_amd64.whl", hash = "sha256:398b39db462c39563a97b912d4f2866cd37cba60537975a09679b28fbbc0fb38", size = 38138940, upload-time = "2025-12-08T18:15:10.162Z" }, + { url = "https://files.pythonhosted.org/packages/e6/ff/3eba7eb0aed4b6fca37125387cd417e8c458e750621fce56d2c541f67fa8/llvmlite-0.46.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:30b60892d034bc560e0ec6654737aaa74e5ca327bd8114d82136aa071d611172", size = 37232767, upload-time = "2025-12-08T18:15:13.22Z" }, + { url = "https://files.pythonhosted.org/packages/0e/54/737755c0a91558364b9200702c3c9c15d70ed63f9b98a2c32f1c2aa1f3ba/llvmlite-0.46.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6cc19b051753368a9c9f31dc041299059ee91aceec81bd57b0e385e5d5bf1a54", size = 56275176, upload-time = "2025-12-08T18:15:16.339Z" }, + { url = "https://files.pythonhosted.org/packages/e6/91/14f32e1d70905c1c0aa4e6609ab5d705c3183116ca02ac6df2091868413a/llvmlite-0.46.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bca185892908f9ede48c0acd547fe4dc1bafefb8a4967d47db6cf664f9332d12", size = 55128629, upload-time = "2025-12-08T18:15:19.493Z" }, + { url = "https://files.pythonhosted.org/packages/4a/a7/d526ae86708cea531935ae777b6dbcabe7db52718e6401e0fb9c5edea80e/llvmlite-0.46.0-cp313-cp313-win_amd64.whl", hash = "sha256:67438fd30e12349ebb054d86a5a1a57fd5e87d264d2451bcfafbbbaa25b82a35", size = 38138941, upload-time = "2025-12-08T18:15:22.536Z" }, + { url = "https://files.pythonhosted.org/packages/95/ae/af0ffb724814cc2ea64445acad05f71cff5f799bb7efb22e47ee99340dbc/llvmlite-0.46.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:d252edfb9f4ac1fcf20652258e3f102b26b03eef738dc8a6ffdab7d7d341d547", size = 37232768, upload-time = "2025-12-08T18:15:25.055Z" }, + { url = "https://files.pythonhosted.org/packages/c9/19/5018e5352019be753b7b07f7759cdabb69ca5779fea2494be8839270df4c/llvmlite-0.46.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:379fdd1c59badeff8982cb47e4694a6143bec3bb49aa10a466e095410522064d", size = 56275173, upload-time = "2025-12-08T18:15:28.109Z" }, + { url = "https://files.pythonhosted.org/packages/9f/c9/d57877759d707e84c082163c543853245f91b70c804115a5010532890f18/llvmlite-0.46.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2e8cbfff7f6db0fa2c771ad24154e2a7e457c2444d7673e6de06b8b698c3b269", size = 55128628, upload-time = "2025-12-08T18:15:31.098Z" }, + { url = "https://files.pythonhosted.org/packages/30/a8/e61a8c2b3cc7a597073d9cde1fcbb567e9d827f1db30c93cf80422eac70d/llvmlite-0.46.0-cp314-cp314-win_amd64.whl", hash = "sha256:7821eda3ec1f18050f981819756631d60b6d7ab1a6cf806d9efefbe3f4082d61", size = 39153056, upload-time = "2025-12-08T18:15:33.938Z" }, +] + [[package]] name = "markdown" version = "3.10" @@ -1378,6 +1400,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/7e/a96255f63b7aef032cbee8fc4d6e37def72e3aaedc1f72759235e8f13cb1/nh3-0.3.2-cp38-abi3-win_arm64.whl", hash = "sha256:cf5964d54edd405e68583114a7cba929468bcd7db5e676ae38ee954de1cfc104", size = 584162, upload-time = "2025-10-30T11:17:44.96Z" }, ] +[[package]] +name = "numba" +version = "0.64.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llvmlite" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/c9/a0fb41787d01d621046138da30f6c2100d80857bf34b3390dd68040f27a3/numba-0.64.0.tar.gz", hash = "sha256:95e7300af648baa3308127b1955b52ce6d11889d16e8cfe637b4f85d2fca52b1", size = 2765679, upload-time = "2026-02-18T18:41:20.974Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/a6/9fc52cb4f0d5e6d8b5f4d81615bc01012e3cf24e1052a60f17a68deb8092/numba-0.64.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:69440a8e8bc1a81028446f06b363e28635aa67bd51b1e498023f03b812e0ce68", size = 2683418, upload-time = "2026-02-18T18:40:59.886Z" }, + { url = "https://files.pythonhosted.org/packages/9b/89/1a74ea99b180b7a5587b0301ed1b183a2937c4b4b67f7994689b5d36fc34/numba-0.64.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f13721011f693ba558b8dd4e4db7f2640462bba1b855bdc804be45bbeb55031a", size = 3804087, upload-time = "2026-02-18T18:41:01.699Z" }, + { url = "https://files.pythonhosted.org/packages/91/e1/583c647404b15f807410510fec1eb9b80cb8474165940b7749f026f21cbc/numba-0.64.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0b180b1133f2b5d8b3f09d96b6d7a9e51a7da5dda3c09e998b5bcfac85d222c", size = 3504309, upload-time = "2026-02-18T18:41:03.252Z" }, + { url = "https://files.pythonhosted.org/packages/85/23/0fce5789b8a5035e7ace21216a468143f3144e02013252116616c58339aa/numba-0.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:e63dc94023b47894849b8b106db28ccb98b49d5498b98878fac1a38f83ac007a", size = 2752740, upload-time = "2026-02-18T18:41:05.097Z" }, + { url = "https://files.pythonhosted.org/packages/52/80/2734de90f9300a6e2503b35ee50d9599926b90cbb7ac54f9e40074cd07f1/numba-0.64.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3bab2c872194dcd985f1153b70782ec0fbbe348fffef340264eacd3a76d59fd6", size = 2683392, upload-time = "2026-02-18T18:41:06.563Z" }, + { url = "https://files.pythonhosted.org/packages/42/e8/14b5853ebefd5b37723ef365c5318a30ce0702d39057eaa8d7d76392859d/numba-0.64.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:703a246c60832cad231d2e73c1182f25bf3cc8b699759ec8fe58a2dbc689a70c", size = 3812245, upload-time = "2026-02-18T18:41:07.963Z" }, + { url = "https://files.pythonhosted.org/packages/8a/a2/f60dc6c96d19b7185144265a5fbf01c14993d37ff4cd324b09d0212aa7ce/numba-0.64.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7e2e49a7900ee971d32af7609adc0cfe6aa7477c6f6cccdf6d8138538cf7756f", size = 3511328, upload-time = "2026-02-18T18:41:09.504Z" }, + { url = "https://files.pythonhosted.org/packages/9c/2a/fe7003ea7e7237ee7014f8eaeeb7b0d228a2db22572ca85bab2648cf52cb/numba-0.64.0-cp313-cp313-win_amd64.whl", hash = "sha256:396f43c3f77e78d7ec84cdfc6b04969c78f8f169351b3c4db814b97e7acf4245", size = 2752668, upload-time = "2026-02-18T18:41:11.455Z" }, + { url = "https://files.pythonhosted.org/packages/3d/8a/77d26afe0988c592dd97cb8d4e80bfb3dfc7dbdacfca7d74a7c5c81dd8c2/numba-0.64.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:f565d55eaeff382cbc86c63c8c610347453af3d1e7afb2b6569aac1c9b5c93ce", size = 2683590, upload-time = "2026-02-18T18:41:12.897Z" }, + { url = "https://files.pythonhosted.org/packages/8e/4b/600b8b7cdbc7f9cebee9ea3d13bb70052a79baf28944024ffcb59f0712e3/numba-0.64.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9b55169b18892c783f85e9ad9e6f5297a6d12967e4414e6b71361086025ff0bb", size = 3781163, upload-time = "2026-02-18T18:41:15.377Z" }, + { url = "https://files.pythonhosted.org/packages/ff/73/53f2d32bfa45b7175e9944f6b816d8c32840178c3eee9325033db5bf838e/numba-0.64.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:196bcafa02c9dd1707e068434f6d5cedde0feb787e3432f7f1f0e993cc336c4c", size = 3481172, upload-time = "2026-02-18T18:41:17.281Z" }, + { url = "https://files.pythonhosted.org/packages/b5/00/aebd2f7f1e11e38814bb96e95a27580817a7b340608d3ac085fdbab83174/numba-0.64.0-cp314-cp314-win_amd64.whl", hash = "sha256:213e9acbe7f1c05090592e79020315c1749dd52517b90e94c517dca3f014d4a1", size = 2754700, upload-time = "2026-02-18T18:41:19.277Z" }, +] + [[package]] name = "numpy" version = "2.4.1" From 02578d192e1a877b1838237e4f899fe0a5673d00 Mon Sep 17 00:00:00 2001 From: unnawut Date: Tue, 10 Mar 2026 20:51:01 +0700 Subject: [PATCH 2/5] more subtle change --- .../subspecs/poseidon2/permutation.py | 142 +++++++++++------- 1 file changed, 90 insertions(+), 52 deletions(-) diff --git a/src/lean_spec/subspecs/poseidon2/permutation.py b/src/lean_spec/subspecs/poseidon2/permutation.py index 23571d8a..bde57be5 100644 --- a/src/lean_spec/subspecs/poseidon2/permutation.py +++ b/src/lean_spec/subspecs/poseidon2/permutation.py @@ -25,55 +25,83 @@ @njit(cache=True) -def _external_linear_layer_jit(state: NDArray[np.int64], width: int, p: int) -> None: +def _m4_multiply( + chunks: NDArray[np.int64], p: int +) -> NDArray[np.int64]: """ - Apply the external linear layer (M_E) in-place. + Multiply each row of `chunks` by the M4 circulant matrix. - Multiplies each 4-element chunk by the M4 circulant matrix, - then applies the outer circulant structure for global diffusion. + Replaces `chunks @ M4.T` which requires scipy under Numba. + The circulant structure means each output is a linear combination + of `sum + one_extra_copy + two_extra_copies`. """ - num_chunks = width // 4 + result = np.empty_like(chunks) + for c in range(chunks.shape[0]): + a, b, cv, d = chunks[c, 0], chunks[c, 1], chunks[c, 2], chunks[c, 3] + s = (a + b + cv + d) % p + result[c, 0] = (s + a + 2 * b) % p + result[c, 1] = (s + b + 2 * cv) % p + result[c, 2] = (s + cv + 2 * d) % p + result[c, 3] = (s + 2 * a + d) % p + return result + +@njit(cache=True) +def _external_linear_layer_jit(state: NDArray[np.int64], p: int) -> NDArray[np.int64]: + """ + Apply the external linear layer (M_E). + + Provides strong diffusion across the entire state. + Used in full rounds. + + For state size t=4k, constructed from M4 to form a circulant-like matrix. + Efficient while ensuring any single element change affects all others. + + See Appendix B of the paper. + """ # Apply M4 to each 4-element chunk. - for c in range(num_chunks): - base = c * 4 - a = state[base] - b = state[base + 1] - c_val = state[base + 2] - d = state[base + 3] - - s = (a + b + c_val + d) % p - state[base] = (s + a + 2 * b) % p - state[base + 1] = (s + b + 2 * c_val) % p - state[base + 2] = (s + c_val + 2 * d) % p - state[base + 3] = (s + 2 * a + d) % p - - # Outer circulant: sum corresponding positions across chunks, add to each. - for i in range(4): - col_sum = np.int64(0) - for c in range(num_chunks): - col_sum += state[c * 4 + i] - for c in range(num_chunks): - state[c * 4 + i] = (state[c * 4 + i] + col_sum) % p + # Provides strong local diffusion within each block. + chunks = state.reshape(-1, 4) + chunks = _m4_multiply(chunks, p) + + # Apply outer circulant structure for global diffusion. + # Equivalent to multiplying by circ(2*I, I, ..., I) after M4 stage. + sums = np.zeros(4, dtype=np.int64) + for c in range(chunks.shape[0]): + for i in range(4): + sums[i] += chunks[c, i] + + # Add corresponding sum to each element. + return (chunks + sums).reshape(-1) % p @njit(cache=True) def _internal_linear_layer_jit( - state: NDArray[np.int64], diag_vector: NDArray[np.int64], width: int, p: int -) -> None: + state: NDArray[np.int64], diag_vector: NDArray[np.int64], p: int +) -> NDArray[np.int64]: """ - Apply the internal linear layer (M_I) in-place. + Apply the internal linear layer (M_I). + + Used during partial rounds. + Optimized for speed. + + Matrix structure: M_I = J + D + + - J is the all-ones matrix + - D is a diagonal matrix + + This allows O(t) computation instead of O(t^2): + + M_I * s = J*s + D*s - M_I = J + D where J is the all-ones matrix and D is diagonal. - O(t) computation instead of O(t^2). + J*s is a vector where each element equals the sum of all elements in s. """ - state_sum = np.int64(0) - for i in range(width): - state_sum += state[i] - state_sum = state_sum % p + # J*state: sum of all elements (broadcast to vector). + # D*state: element-wise multiplication with diagonal. + state_sum = state.sum() - for i in range(width): - state[i] = (state_sum + diag_vector[i] * state[i] % p) % p + # new_state[i] = state_sum + diag_vector[i] * state[i] + return (state_sum + (diag_vector * state)) % p @njit(cache=True) @@ -95,41 +123,51 @@ def _permute_jit( const_idx = 0 # 1. Initial linear layer. - _external_linear_layer_jit(state, width, p) + # + # Prevents certain algebraic attacks. + # Ensures the permutation begins with a diffusion layer. + state[:] = _external_linear_layer_jit(state, p) # 2. First half of full rounds. + # + # Note: for S_BOX_DEGREE=3, state**3 would overflow int64 before modulo. + # Values reach up to 2^93, but int64 max is 2^63. + # Expand S-box to `(state*state % P) * state % P` to stay in range. for _ in range(half_rounds_f): - for i in range(width): - state[i] = (state[i] + round_constants[const_idx + i]) % p + # Add round constants to entire state. + state[:] = (state + round_constants[const_idx : const_idx + width]) % p const_idx += width - for i in range(width): - x = state[i] - state[i] = (x * x % p) * x % p + # Apply S-box (x -> x^d) to full state. + state[:] = (state * state % p) * state % p - _external_linear_layer_jit(state, width, p) + # Apply external linear layer for diffusion. + state[:] = _external_linear_layer_jit(state, p) # 3. Partial rounds. for _ in range(rounds_p): + # Add single round constant to first element. state[0] = (state[0] + round_constants[const_idx]) % p const_idx += 1 - x = state[0] - state[0] = (x * x % p) * x % p + # Apply S-box to first element only. + # This is the main optimization of the Hades design. + state[0] = (state[0] * state[0] % p) * state[0] % p - _internal_linear_layer_jit(state, diag_vector, width, p) + # Apply internal linear layer. + state[:] = _internal_linear_layer_jit(state, diag_vector, p) # 4. Second half of full rounds. for _ in range(half_rounds_f): - for i in range(width): - state[i] = (state[i] + round_constants[const_idx + i]) % p + # Add round constants to entire state. + state[:] = (state + round_constants[const_idx : const_idx + width]) % p const_idx += width - for i in range(width): - x = state[i] - state[i] = (x * x % p) * x % p + # Apply S-box to full state. + state[:] = (state * state % p) * state % p - _external_linear_layer_jit(state, width, p) + # Apply external linear layer for diffusion. + state[:] = _external_linear_layer_jit(state, p) # Trigger compilation on import so the first real call is fast. From 5e67d38646366eee509dacf2c482848772c8af1e Mon Sep 17 00:00:00 2001 From: unnawut Date: Tue, 10 Mar 2026 21:06:39 +0700 Subject: [PATCH 3/5] minimize change --- .../subspecs/poseidon2/permutation.py | 56 ++++++++++++------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/src/lean_spec/subspecs/poseidon2/permutation.py b/src/lean_spec/subspecs/poseidon2/permutation.py index bde57be5..dc030c4e 100644 --- a/src/lean_spec/subspecs/poseidon2/permutation.py +++ b/src/lean_spec/subspecs/poseidon2/permutation.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import Self +from typing import Final, Self import numpy as np from numba import njit @@ -23,31 +23,46 @@ ROUND_CONSTANTS_24, ) +_M4_T: Final[NDArray[np.int64]] = np.array( + [ + [2, 3, 1, 1], + [1, 2, 3, 1], + [1, 1, 2, 3], + [3, 1, 1, 2], + ], + dtype=np.int64, +).T +""" +Base 4x4 MDS matrix, pre-transposed. + +Pre-transposition enables efficient row-vector multiplication: `v @ M.T`. +""" + @njit(cache=True) -def _m4_multiply( - chunks: NDArray[np.int64], p: int -) -> NDArray[np.int64]: +def _m4_multiply(chunks: NDArray[np.int64], m4t: NDArray[np.int64], p: int) -> NDArray[np.int64]: """ - Multiply each row of `chunks` by the M4 circulant matrix. + Multiply each row of `chunks` by the M4 matrix. - Replaces `chunks @ M4.T` which requires scipy under Numba. - The circulant structure means each output is a linear combination - of `sum + one_extra_copy + two_extra_copies`. + Equivalent to `chunks @ m4t % p`. + Numba's `@` operator requires scipy and float arrays, + so we use an explicit loop instead. Numba unrolls these + small fixed-size loops, so overhead is ~12% vs native matmul. """ result = np.empty_like(chunks) for c in range(chunks.shape[0]): - a, b, cv, d = chunks[c, 0], chunks[c, 1], chunks[c, 2], chunks[c, 3] - s = (a + b + cv + d) % p - result[c, 0] = (s + a + 2 * b) % p - result[c, 1] = (s + b + 2 * cv) % p - result[c, 2] = (s + cv + 2 * d) % p - result[c, 3] = (s + 2 * a + d) % p + for j in range(4): + s = np.int64(0) + for k in range(4): + s += chunks[c, k] * m4t[k, j] + result[c, j] = s % p return result @njit(cache=True) -def _external_linear_layer_jit(state: NDArray[np.int64], p: int) -> NDArray[np.int64]: +def _external_linear_layer_jit( + state: NDArray[np.int64], m4t: NDArray[np.int64], p: int +) -> NDArray[np.int64]: """ Apply the external linear layer (M_E). @@ -62,7 +77,7 @@ def _external_linear_layer_jit(state: NDArray[np.int64], p: int) -> NDArray[np.i # Apply M4 to each 4-element chunk. # Provides strong local diffusion within each block. chunks = state.reshape(-1, 4) - chunks = _m4_multiply(chunks, p) + chunks = _m4_multiply(chunks, m4t, p) # Apply outer circulant structure for global diffusion. # Equivalent to multiplying by circ(2*I, I, ..., I) after M4 stage. @@ -109,6 +124,7 @@ def _permute_jit( state: NDArray[np.int64], round_constants: NDArray[np.int64], diag_vector: NDArray[np.int64], + m4t: NDArray[np.int64], width: int, half_rounds_f: int, rounds_p: int, @@ -126,7 +142,7 @@ def _permute_jit( # # Prevents certain algebraic attacks. # Ensures the permutation begins with a diffusion layer. - state[:] = _external_linear_layer_jit(state, p) + state[:] = _external_linear_layer_jit(state, m4t, p) # 2. First half of full rounds. # @@ -142,7 +158,7 @@ def _permute_jit( state[:] = (state * state % p) * state % p # Apply external linear layer for diffusion. - state[:] = _external_linear_layer_jit(state, p) + state[:] = _external_linear_layer_jit(state, m4t, p) # 3. Partial rounds. for _ in range(rounds_p): @@ -167,7 +183,7 @@ def _permute_jit( state[:] = (state * state % p) * state % p # Apply external linear layer for diffusion. - state[:] = _external_linear_layer_jit(state, p) + state[:] = _external_linear_layer_jit(state, m4t, p) # Trigger compilation on import so the first real call is fast. @@ -175,6 +191,7 @@ def _permute_jit( np.zeros(16, dtype=np.int64), np.zeros(148, dtype=np.int64), np.zeros(16, dtype=np.int64), + _M4_T, 16, 4, 20, @@ -280,6 +297,7 @@ def permute(self, current_state: list[Fp]) -> list[Fp]: state, self._round_constants, self._diag_vector, + _M4_T, self._width, self._half_rounds_f, self._rounds_p, From 7c0088e04b10d325f8f71c45f71e5f35cef22bdc Mon Sep 17 00:00:00 2001 From: unnawut Date: Wed, 11 Mar 2026 00:45:18 +0700 Subject: [PATCH 4/5] remove warm up call --- src/lean_spec/subspecs/poseidon2/permutation.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/lean_spec/subspecs/poseidon2/permutation.py b/src/lean_spec/subspecs/poseidon2/permutation.py index dc030c4e..60213914 100644 --- a/src/lean_spec/subspecs/poseidon2/permutation.py +++ b/src/lean_spec/subspecs/poseidon2/permutation.py @@ -186,18 +186,6 @@ def _permute_jit( state[:] = _external_linear_layer_jit(state, m4t, p) -# Trigger compilation on import so the first real call is fast. -_permute_jit( - np.zeros(16, dtype=np.int64), - np.zeros(148, dtype=np.int64), - np.zeros(16, dtype=np.int64), - _M4_T, - 16, - 4, - 20, - 2130706433, -) - class Poseidon2Params(StrictBaseModel): """Parameters for a specific Poseidon2 instance.""" From 0a7ca5b611588c2433475aaf35bd5e27a828ee69 Mon Sep 17 00:00:00 2001 From: unnawut Date: Wed, 11 Mar 2026 12:39:41 +0700 Subject: [PATCH 5/5] lint --- src/lean_spec/subspecs/poseidon2/permutation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lean_spec/subspecs/poseidon2/permutation.py b/src/lean_spec/subspecs/poseidon2/permutation.py index 60213914..7db0e654 100644 --- a/src/lean_spec/subspecs/poseidon2/permutation.py +++ b/src/lean_spec/subspecs/poseidon2/permutation.py @@ -186,7 +186,6 @@ def _permute_jit( state[:] = _external_linear_layer_jit(state, m4t, p) - class Poseidon2Params(StrictBaseModel): """Parameters for a specific Poseidon2 instance."""