Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 89 additions & 33 deletions packages/testing/src/consensus_testing/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

- Each key pair is stored in a separate JSON file with hex-encoded SSZ.
- Directory structure: ``test_keys/{scheme}_scheme/{index}.json``
- Each file contains: ``{"public": "0a1b...", "secret": "2c3d..."}``
- Each file has four hex-encoded SSZ fields:
``attestation_public``, ``attestation_secret``,
``proposal_public``, ``proposal_secret``
"""

from __future__ import annotations
Expand All @@ -36,6 +38,7 @@
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from pathlib import Path
from typing import Literal

from lean_spec.config import LEAN_ENV
from lean_spec.subspecs.containers import AttestationData, ValidatorIndex
Expand All @@ -48,7 +51,7 @@
from lean_spec.subspecs.koalabear import Fp
from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof
from lean_spec.subspecs.xmss.constants import TARGET_CONFIG
from lean_spec.subspecs.xmss.containers import KeyPair, PublicKey, Signature
from lean_spec.subspecs.xmss.containers import PublicKey, Signature, ValidatorKeyPair
from lean_spec.subspecs.xmss.interface import (
PROD_SIGNATURE_SCHEME,
TEST_SIGNATURE_SCHEME,
Expand Down Expand Up @@ -162,14 +165,14 @@ def get_keys_dir(scheme_name: str) -> Path:
return Path(__file__).parent / "test_keys" / f"{scheme_name}_scheme"


class LazyKeyDict(Mapping[ValidatorIndex, KeyPair]):
class LazyKeyDict(Mapping[ValidatorIndex, ValidatorKeyPair]):
"""Load pre-generated keys from disk (cached after first call)."""

def __init__(self, scheme_name: str) -> None:
"""Initialize with scheme name for locating key files."""
self._scheme_name = scheme_name
self._keys_dir = get_keys_dir(scheme_name)
self._cache: dict[ValidatorIndex, KeyPair] = {}
self._cache: dict[ValidatorIndex, ValidatorKeyPair] = {}
self._available_indices: set[ValidatorIndex] | None = None

def _ensure_dir_exists(self) -> None:
Expand All @@ -194,15 +197,15 @@ def _get_available_indices(self) -> set[ValidatorIndex]:
)
return self._available_indices

def _load_key(self, idx: ValidatorIndex) -> KeyPair:
def _load_key(self, idx: ValidatorIndex) -> ValidatorKeyPair:
"""Load a single key from disk."""
key_file = self._keys_dir / f"{idx}.json"
if not key_file.exists():
raise KeyError(f"Key file not found: {key_file}")
data = json.loads(key_file.read_text())
return KeyPair.from_dict(data)
return ValidatorKeyPair.from_dict(data)

def __getitem__(self, idx: ValidatorIndex) -> KeyPair:
def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair:
"""Get key pair by validator index, loading from disk if needed."""
if idx not in self._cache:
self._cache[idx] = self._load_key(idx)
Expand Down Expand Up @@ -244,7 +247,7 @@ def __init__(
"""Initialize the manager with optional custom configuration."""
self.max_slot = max_slot
self.scheme = scheme
self._state: dict[ValidatorIndex, KeyPair] = {}
self._state: dict[ValidatorIndex, ValidatorKeyPair] = {}

try:
self.scheme_name = next(
Expand All @@ -260,7 +263,7 @@ def keys(self) -> LazyKeyDict:
_LAZY_KEY_CACHE[self.scheme_name] = LazyKeyDict(self.scheme_name)
return _LAZY_KEY_CACHE[self.scheme_name]

def __getitem__(self, idx: ValidatorIndex) -> KeyPair:
def __getitem__(self, idx: ValidatorIndex) -> ValidatorKeyPair:
"""Get key pair, returning advanced state if available."""
if idx in self._state:
return self._state[idx]
Expand All @@ -282,36 +285,36 @@ def __iter__(self) -> Iterator[ValidatorIndex]:
"""Iterate over validator indices."""
return iter(self.keys)

def get_public_key(self, idx: ValidatorIndex) -> PublicKey:
"""Get a validator's public key."""
return self[idx].public
def get_attestation_public_key(self, idx: ValidatorIndex) -> PublicKey:
"""Get a validator's attestation public key."""
return self[idx].attestation_public

def sign_attestation_data(
def get_proposal_public_key(self, idx: ValidatorIndex) -> PublicKey:
"""Get a validator's proposal public key."""
return self[idx].proposal_public

def _sign_with_secret(
self,
validator_id: ValidatorIndex,
attestation_data: AttestationData,
secret_field: Literal["attestation_secret", "proposal_secret"],
) -> Signature:
"""
Sign an attestation data with automatic key state advancement.
Shared signing logic for attestation/proposal paths.

XMSS is stateful: signing advances the internal key state.
This method handles advancement transparently.
Handles XMSS state advancement until the requested slot is within the
prepared interval, caches the updated secret, and produces the signature.

Args:
validator_id: The validator index to sign the attestation data for.
attestation_data: The attestation data to sign.

Returns:
XMSS signature.

Raises:
ValueError: If slot exceeds key lifetime.
validator_id: Validator index whose key should be used.
attestation_data: Data to sign.
secret_field: Which secret on the key pair should advance.
"""
slot = attestation_data.slot
kp = self[validator_id]
sk = kp.secret
sk = getattr(kp, secret_field)

# Advance key state until slot is in prepared interval
# Advance key state until the slot is ready for signing.
prepared = self.scheme.get_prepared_interval(sk)
while int(slot) not in prepared:
activation = self.scheme.get_activation_interval(sk)
Expand All @@ -320,13 +323,58 @@ def sign_attestation_data(
sk = self.scheme.advance_preparation(sk)
prepared = self.scheme.get_prepared_interval(sk)

# Cache advanced state
self._state[validator_id] = kp._replace(secret=sk)
# Cache advanced state (only the selected secret changes).
self._state[validator_id] = kp._replace(**{secret_field: sk})

# Sign hash tree root of the attestation data
message = attestation_data.data_root_bytes()
return self.scheme.sign(sk, slot, message)

def sign_attestation_data(
self,
validator_id: ValidatorIndex,
attestation_data: AttestationData,
) -> Signature:
"""
Sign attestation data with the attestation key.

XMSS is stateful: this delegates to the shared helper which advances the
attestation key state as needed while leaving the proposal key untouched.

Args:
validator_id: The validator index to sign the attestation data for.
attestation_data: The attestation data to sign.

Returns:
XMSS signature.

Raises:
ValueError: If slot exceeds key lifetime.
"""
return self._sign_with_secret(validator_id, attestation_data, "attestation_secret")

def sign_proposal_data(
self,
validator_id: ValidatorIndex,
attestation_data: AttestationData,
) -> Signature:
"""
Sign proposer attestation data with the proposal key.

Delegates to the shared helper which advances only the proposal key, so
the attestation key remains unchanged.

Args:
validator_id: The validator index to sign the proposal for.
attestation_data: The attestation data to sign.

Returns:
XMSS signature.

Raises:
ValueError: If slot exceeds key lifetime.
"""
return self._sign_with_secret(validator_id, attestation_data, "proposal_secret")

def build_attestation_signatures(
self,
aggregated_attestations: AggregatedAttestations,
Expand All @@ -350,7 +398,9 @@ def build_attestation_signatures(
# Look up pre-computed signatures by attestation data and validator ID.
sigs_for_data = lookup.get(agg.data, {})

public_keys: list[PublicKey] = [self.get_public_key(vid) for vid in validator_ids]
public_keys: list[PublicKey] = [
self.get_attestation_public_key(vid) for vid in validator_ids
]
signatures: list[Signature] = [
sigs_for_data.get(vid) or self.sign_attestation_data(vid, agg.data)
for vid in validator_ids
Expand All @@ -374,10 +424,16 @@ def build_attestation_signatures(
def _generate_single_keypair(
scheme: GeneralizedXmssScheme, num_slots: int, index: int
) -> dict[str, str]:
"""Generate one key pair (module-level for pickling in ProcessPoolExecutor)."""
"""Generate dual key pairs for one validator (module-level for pickling)."""
print(f"Starting key #{index} generation...")
pk, sk = scheme.key_gen(Slot(0), Uint64(num_slots))
return KeyPair(public=pk, secret=sk).to_dict()
att_pk, att_sk = scheme.key_gen(Slot(0), Uint64(num_slots))
prop_pk, prop_sk = scheme.key_gen(Slot(0), Uint64(num_slots))
return ValidatorKeyPair(
attestation_public=att_pk,
attestation_secret=att_sk,
proposal_public=prop_pk,
proposal_secret=prop_sk,
).to_dict()


def _generate_keys(lean_env: str, count: int, max_slot: int) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,14 @@ def make_fixture(self) -> Self:
# Otherwise signature verification will fail.
updated_validators = [
validator.model_copy(
update={"pubkey": key_manager[ValidatorIndex(i)].public.encode_bytes()}
update={
"attestation_pubkey": key_manager[
ValidatorIndex(i)
].attestation_public.encode_bytes(),
"proposal_pubkey": key_manager[
ValidatorIndex(i)
].proposal_public.encode_bytes(),
}
)
for i, validator in enumerate(self.anchor_state.validators)
]
Expand Down Expand Up @@ -288,6 +295,30 @@ def make_fixture(self) -> Self:
scheme=LEAN_ENV_TO_SCHEMES[self.lean_env],
)

# Simulate the proposer's interval 1 gossip attestation.
#
# With dual keys, the proposer gossips a fresh attestation
# using the attestation key. Reuse the attestation data
# from the block envelope β€” it was built from the proposer's
# chain view (which includes their own block as head).
#
# Best-effort: if the attestation data fails validation
# (e.g. source > target after justification advances),
# skip gracefully β€” matches ValidatorService behavior.
proposer_att = signed_block.message.proposer_attestation
try:
store = store.on_gossip_attestation(
SignedAttestation(
validator_id=proposer_att.validator_id,
data=proposer_att.data,
signature=proposer_att.signature,
),
scheme=LEAN_ENV_TO_SCHEMES[self.lean_env],
is_aggregator=True,
)
except (AssertionError, Exception):
pass

case AttestationStep():
# Process a gossip attestation.
# Gossip attestations arrive outside of blocks.
Expand Down Expand Up @@ -472,9 +503,11 @@ def _build_block_from_spec(
"latest_finalized": latest_finalized,
}
)
proposer_attestation = Attestation(
proposer_attestation_data = temp_store.produce_attestation_data(spec.slot)
proposer_attestation = SignedAttestation(
validator_id=proposer_index,
data=temp_store.produce_attestation_data(spec.slot),
data=proposer_attestation_data,
signature=key_manager.sign_attestation_data(proposer_index, proposer_attestation_data),
)

# Sign everything
Expand All @@ -486,7 +519,7 @@ def _build_block_from_spec(
attestation_signatures,
)

proposer_signature = key_manager.sign_attestation_data(
proposer_signature = key_manager.sign_proposal_data(
proposer_attestation.validator_id,
proposer_attestation.data,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
AggregationBits,
Attestation,
AttestationData,
SignedAttestation,
)
from lean_spec.subspecs.containers.block import (
BlockSignatures,
Expand Down Expand Up @@ -231,7 +232,7 @@ def _build_block_from_spec(
# Valid proof but from wrong validators
# Sign with signer_ids but claim validator_ids as participants
signer_public_keys = [
key_manager.get_public_key(vid) for vid in invalid_spec.signer_ids
key_manager.get_attestation_public_key(vid) for vid in invalid_spec.signer_ids
]
signer_signatures = [
key_manager.sign_attestation_data(vid, attestation_data)
Expand Down Expand Up @@ -276,24 +277,31 @@ def _build_block_from_spec(

# Create proposer attestation for this block
block_root = hash_tree_root(final_block)
proposer_attestation = Attestation(
validator_id=proposer_index,
data=AttestationData(
slot=spec.slot,
head=Checkpoint(root=block_root, slot=spec.slot),
target=Checkpoint(root=block_root, slot=spec.slot),
source=Checkpoint(root=parent_root, slot=parent_state.latest_block_header.slot),
),
proposer_attestation_data = AttestationData(
slot=spec.slot,
head=Checkpoint(root=block_root, slot=spec.slot),
target=Checkpoint(root=block_root, slot=spec.slot),
source=Checkpoint(root=parent_root, slot=parent_state.latest_block_header.slot),
)

# Sign proposer attestation - use valid or dummy signature based on spec
# Sign proposer attestation and proposer signature
# use valid or dummy signatures based on spec
if spec.valid_signature:
proposer_attestation_signature = key_manager.sign_attestation_data(
proposer_attestation.validator_id,
proposer_attestation.data,
attestation_signature = key_manager.sign_attestation_data(
proposer_index, proposer_attestation_data
)
proposer_signature = key_manager.sign_proposal_data(
proposer_index,
proposer_attestation_data,
)
else:
proposer_attestation_signature = create_dummy_signature()
attestation_signature = create_dummy_signature()
proposer_signature = create_dummy_signature()

proposer_attestation = SignedAttestation(
validator_id=proposer_index,
data=proposer_attestation_data,
signature=attestation_signature,
)

return SignedBlockWithAttestation(
message=BlockWithAttestation(
Expand All @@ -302,7 +310,7 @@ def _build_block_from_spec(
),
signature=BlockSignatures(
attestation_signatures=attestation_signatures,
proposer_signature=proposer_attestation_signature,
proposer_signature=proposer_signature,
),
)

Expand Down
Loading
Loading