@staticmethod
def _extend_proofs_greedily(
    proofs: set[AggregatedSignatureProof] | None,
    selected: list[AggregatedSignatureProof],
    covered: set[ValidatorIndex],
) -> None:
    """
    Greedily pick proofs that contribute validators not yet covered.

    On each round the proof adding the most validators missing from
    `covered` is appended to `selected` and its participants are merged
    into `covered`. The loop stops as soon as no remaining proof adds a
    new validator.

    Declared static because callers invoke it as
    `self._extend_proofs_greedily(proofs, selected, covered)`; without the
    decorator the bound instance would take the place of `proofs` and the
    call would raise a TypeError.

    Args:
        proofs: Candidate proofs for one attestation data, or None.
        selected: Output list, extended in place with the chosen proofs.
        covered: Output set, updated in place with the chosen participants.
    """
    if not proofs:
        return

    # Decode each participant bitfield once rather than on every comparison.
    candidates = [(proof, set(proof.participants.to_validator_indices())) for proof in proofs]
    while candidates:
        best_index = max(
            range(len(candidates)),
            key=lambda i: len(candidates[i][1] - covered),
        )
        best_proof, best_members = candidates[best_index]
        if not (best_members - covered):
            # Even the best remaining proof adds nobody new.
            break
        selected.append(best_proof)
        covered.update(best_members)
        del candidates[best_index]


def aggregate(
    self,
    gossip_signatures: dict[AttestationData, set[GossipSignatureEntry]] | None = None,
    new_payloads: dict[AttestationData, set[AggregatedSignatureProof]] | None = None,
    known_payloads: dict[AttestationData, set[AggregatedSignatureProof]] | None = None,
) -> list[tuple[AggregatedAttestation, AggregatedSignatureProof]]:
    """
    Aggregate gossip signatures using new payloads, with known payloads as helpers.

    Args:
        gossip_signatures: Raw XMSS signatures learned from gossip keyed by attestation data.
        new_payloads: Aggregated proofs pending processing (child proofs).
        known_payloads: Known aggregated proofs already accepted.

    Returns:
        List of (aggregated attestation, proof) pairs to broadcast.
    """
    results: list[tuple[AggregatedAttestation, AggregatedSignatureProof]] = []

    gossip_signatures = gossip_signatures or {}
    new_payloads = new_payloads or {}
    known_payloads = known_payloads or {}

    # Use only keys from new_payloads and gossip_signatures.
    # known_payloads can extend a proof built from new_payloads and
    # gossip_signatures, but known_payloads are not recursively aggregated
    # into proofs of their own.
    #
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result order does not depend on set iteration order.
    attestation_keys = dict.fromkeys([*new_payloads, *gossip_signatures])

    # Aggregate the proofs for each attestation data.
    for data in attestation_keys:
        child_proofs: list[AggregatedSignatureProof] = []
        covered_validators: set[ValidatorIndex] = set()

        # New payloads get priority; known payloads only fill remaining gaps.
        self._extend_proofs_greedily(new_payloads.get(data), child_proofs, covered_validators)
        self._extend_proofs_greedily(known_payloads.get(data), child_proofs, covered_validators)

        # Collect raw signatures for validators no child proof covers yet.
        # Iterating in validator-id order keeps both lists sorted, so no
        # second sort is needed afterwards.
        raw_ids: list[ValidatorIndex] = []
        raw_xmss: list[tuple[PublicKey, Signature]] = []
        for entry in sorted(gossip_signatures.get(data, ()), key=lambda e: e.validator_id):
            validator_id = entry.validator_id
            if validator_id in covered_validators:
                continue
            if int(validator_id) >= len(self.validators):
                # No validator record to take a public key from.
                continue
            raw_ids.append(validator_id)
            raw_xmss.append((self.validators[validator_id].get_pubkey(), entry.signature))
            covered_validators.add(validator_id)

        if not raw_xmss and len(child_proofs) < 2:
            # Nothing to combine. A lone child proof is forwarded unchanged as
            # a correctly typed (attestation, proof) pair; with no proof at all
            # there is nothing to emit.
            # NOTE(review): confirm that forwarding a proof selected only from
            # known_payloads is intended.
            if child_proofs:
                only_proof = child_proofs[0]
                results.append(
                    (
                        AggregatedAttestation(
                            aggregation_bits=only_proof.participants, data=data
                        ),
                        only_proof,
                    )
                )
            continue

        xmss_participants = AggregationBits.from_validator_indices(
            ValidatorIndices(data=raw_ids)
        )

        proof = AggregatedSignatureProof.aggregate(
            xmss_participants=xmss_participants,
            children=child_proofs,
            raw_xmss=raw_xmss,
            message=data.data_root_bytes(),
            slot=data.slot,
        )
        attestation = AggregatedAttestation(aggregation_bits=proof.participants, data=data)
        results.append((attestation, proof))

    return results


def aggregate_signatures_and_payloads(self) -> tuple["Store", list[SignedAggregatedAttestation]]:
    """
    Aggregate committee signatures and payloads together.

    Calls `aggregate` on the state stored for the current head, passing the
    gossip signatures together with the new and known aggregated payloads.

    Returns:
        Tuple of (new Store with updated payloads, list of new SignedAggregatedAttestation).
        In the returned store `latest_new_aggregated_payloads` holds only the
        proofs produced by this call and `gossip_signatures` is emptied.
    """
    head_state = self.states[self.head]

    aggregated_results = head_state.aggregate(
        gossip_signatures=self.gossip_signatures,
        new_payloads=self.latest_new_aggregated_payloads,
        known_payloads=self.latest_known_aggregated_payloads,
    )

    # Build the broadcast list and the replacement payload map in one pass.
    new_aggregates: list[SignedAggregatedAttestation] = []
    new_aggregated_payloads: dict[AttestationData, set[AggregatedSignatureProof]] = {}

    for att, proof in aggregated_results:
        new_aggregates.append(SignedAggregatedAttestation(data=att.data, proof=proof))
        new_aggregated_payloads.setdefault(att.data, set()).add(proof)

    updated_store = self.model_copy(
        update={
            "latest_new_aggregated_payloads": new_aggregated_payloads,
            "gossip_signatures": {},
        }
    )
    return updated_store, new_aggregates
+ """ + @classmethod def aggregate( cls, - participants: AggregationBits, - public_keys: Sequence[PublicKey], - signatures: Sequence[Signature], + xmss_participants: AggregationBits | None, + children: Sequence[Self], + raw_xmss: Sequence[tuple[PublicKey, Signature]], message: Bytes32, slot: Slot, mode: LeanEnvMode | None = None, ) -> Self: """ - Aggregate individual XMSS signatures into a single proof. + Aggregate raw_xmss signatures and children proofs into a single proof. Args: - participants: Bitfield of validators whose signatures are included. - public_keys: Public keys of the signers (must match signatures order). - signatures: Individual XMSS signatures to aggregate. + xmss_participants: Bitfield of validators whose raw_signatures are provided. + children: Sequence of child proofs to aggregate. + raw_xmss: Sequence of (public key, signature) tuples to aggregate. message: The 32-byte message that was signed. slot: The slot in which the signatures were created. mode: The mode to use for the aggregation (test or prod). Returns: - An aggregated signature proof covering all participants. + An aggregated signature proof covering raw signers and all child participants. Raises: AggregationError: If aggregation fails. 
""" + if not raw_xmss and not children: + raise AggregationError("At least one raw signature or child proof is required") + + if raw_xmss and xmss_participants is None: + raise AggregationError("xmss_participants is required when raw_xmss is provided") + + if not raw_xmss and len(children) < 2: + raise AggregationError( + "At least two child proofs are required when no raw signatures are provided" + ) + + aggregated_validator_ids: set[ValidatorIndex] = set() + aggregated_validator_ids.update(xmss_participants.to_validator_indices()) + + if len(aggregated_validator_ids) != len(raw_xmss): + raise AggregationError("The number of raw signatures does not match the number of XMSS participants") + + # Include child participants in the aggregated participants + for child in children: + aggregated_validator_ids.update(child.participants.to_validator_indices()) + participants = AggregationBits.from_validator_indices(ValidatorIndices(data=aggregated_validator_ids)) + mode = mode or LEAN_ENV setup_prover(mode=mode) try: proof_bytes = aggregate_signatures( - [pk.encode_bytes() for pk in public_keys], - [sig.encode_bytes() for sig in signatures], + [pk.encode_bytes() for pk, _ in raw_xmss], + [sig.encode_bytes() for _, sig in raw_xmss], message, slot, mode=mode,