32 changes: 10 additions & 22 deletions src/lean_spec/__main__.py
@@ -61,18 +61,6 @@
logger = logging.getLogger(__name__)


def is_enr_string(bootnode: str) -> bool:
"""
Check if bootnode string is an ENR (vs multiaddr).

Uses prefix detection rather than attempting full parsing.
This is both faster and avoids import overhead for simple checks.

Per EIP-778, all ENR strings begin with "enr:" followed by base64url content.
"""
return bootnode.startswith("enr:")


def resolve_bootnode(bootnode: str) -> str:
"""
Resolve a bootnode string to a multiaddr.
@@ -92,7 +80,7 @@ def resolve_bootnode(bootnode: str) -> str:
Raises:
ValueError: If ENR is malformed or has no UDP connection info.
"""
if is_enr_string(bootnode):
if bootnode.startswith("enr:"):
enr = ENR.from_string(bootnode)

# Verify structural validity (correct scheme, public key present).
@@ -254,7 +242,7 @@ async def _init_from_checkpoint(
#
# This is defense in depth. We trust the source, but still verify
# basic invariants before using the state.
if not await verify_checkpoint_state(state):
if not verify_checkpoint_state(state):
logger.error("Checkpoint state verification failed")
return None

@@ -675,14 +663,14 @@ def main() -> None:
try:
asyncio.run(
run_node(
args.genesis,
args.bootnodes,
args.listen,
args.checkpoint_sync_url,
args.validator_keys,
args.node_id,
args.genesis_time_now,
args.is_aggregator,
genesis_path=args.genesis,
bootnodes=args.bootnodes,
listen_addr=args.listen,
checkpoint_sync_url=args.checkpoint_sync_url,
validator_keys_path=args.validator_keys,
node_id=args.node_id,
genesis_time_now=args.genesis_time_now,
is_aggregator=args.is_aggregator,
)
)
except KeyboardInterrupt:
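Related to the call-site change above: with this many parameters, positional arguments let two same-typed values swap silently, while keyword arguments fail loudly at the call site. A minimal sketch (the signature below is illustrative, not the project's actual run_node):

import asyncio

async def run_node(*, genesis_path: str, bootnodes: list[str], listen_addr: str,
                   node_id: int, is_aggregator: bool) -> None:
    # Keyword-only parameters: every caller must name each argument.
    print(genesis_path, bootnodes, listen_addr, node_id, is_aggregator)

asyncio.run(
    run_node(
        genesis_path="genesis.json",           # placeholder values only
        bootnodes=[],
        listen_addr="/ip4/0.0.0.0/tcp/9000",
        node_id=0,
        is_aggregator=False,
    )
)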
8 changes: 0 additions & 8 deletions src/lean_spec/snappy/encoding.py
@@ -33,9 +33,7 @@
VARINT_DATA_MASK,
)

# ===========================================================================
# Varint Encoding
# ===========================================================================
#
# Varints encode integers using as few bytes as possible.
# - Small values use fewer bytes.
@@ -184,9 +182,7 @@ def decode_varint32(data: bytes, offset: int = 0) -> tuple[int, int]:
return result, bytes_read


# ===========================================================================
# Tag Byte Encoding - Literals
# ===========================================================================
#
# Literals are raw bytes that couldn't be compressed (no match found).
# A literal tag tells the decoder: "copy the next N bytes as-is".
@@ -296,9 +292,7 @@ def encode_literal_tag(length: int) -> bytes:
raise ValueError(f"Literal length too large: {length}")


# ===========================================================================
# Tag Byte Encoding - Copies
# ===========================================================================
#
# Copies are backreferences to already-decompressed data.
# A copy tag tells the decoder: "go back OFFSET bytes, copy LENGTH bytes".
@@ -540,9 +534,7 @@ def _encode_copy_4(length: int, offset: int) -> bytes:
)


# ===========================================================================
# Tag Decoding
# ===========================================================================
#
# Decoding is the inverse of encoding.
# Given a compressed stream, we parse tags to reconstruct the original data.
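The section comments in this file describe the varint layout; as a self-contained refresher (independent of this module's own encode_varint32/decode_varint32), a little-endian base-128 round trip looks like this:

def encode_varint(value: int) -> bytes:
    """Encode a non-negative int as a little-endian base-128 varint."""
    out = bytearray()
    while True:
        byte = value & 0x7F          # low 7 bits carry data
        value >>= 7
        if value:
            out.append(byte | 0x80)  # high bit set: more bytes follow
        else:
            out.append(byte)         # high bit clear: last byte
            return bytes(out)

def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
    """Return (value, bytes_read) starting at offset."""
    result, shift, read = 0, 0, 0
    while True:
        byte = data[offset + read]
        result |= (byte & 0x7F) << shift
        read += 1
        if not byte & 0x80:
            return result, read
        shift += 7

# Small values fit in one byte; larger ones grow a byte per 7 bits of data.
assert encode_varint(1) == b"\x01"
assert decode_varint(encode_varint(300)) == (300, 2)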
2 changes: 1 addition & 1 deletion src/lean_spec/subspecs/api/endpoints/states.py
@@ -40,7 +40,7 @@ async def handle_finalized(request: web.Request) -> web.Response:
try:
ssz_bytes = await asyncio.to_thread(state.encode_bytes)
except Exception as e:
logger.error(f"Failed to encode state: {e}")
logger.error("Failed to encode state: %s", e)
raise web.HTTPInternalServerError(reason="Encoding failed") from e

return web.Response(body=ssz_bytes, content_type="application/octet-stream")
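The switch from an f-string to %-style arguments defers formatting to the logging framework: the message is only rendered if a handler actually emits the record, and the exception object stays available as a structured argument. A quick standalone illustration:

import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("example")
err = ValueError("bad length")

# f-string: the message is built eagerly, even when the level filters it out.
logger.debug(f"Failed to encode state: {err}")

# %-style: formatting happens only when the record is actually emitted.
logger.error("Failed to encode state: %s", err)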
2 changes: 1 addition & 1 deletion src/lean_spec/subspecs/api/server.py
@@ -99,7 +99,7 @@ async def start(self) -> None:
self._site = web.TCPSite(self._runner, self.config.host, self.config.port)
await self._site.start()

logger.info(f"API server listening on {self.config.host}:{self.config.port}")
logger.info("API server listening on %s:%d", self.config.host, self.config.port)

async def run(self) -> None:
"""
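For context, the aiohttp pieces involved in this start() method follow the standard AppRunner/TCPSite sequence; a generic sketch with placeholder host and port, not the project's configuration:

from aiohttp import web

async def start_api(app: web.Application, host: str = "127.0.0.1", port: int = 8080) -> web.TCPSite:
    # AppRunner manages the application lifecycle outside web.run_app().
    runner = web.AppRunner(app)
    await runner.setup()
    # TCPSite binds the runner to a concrete address and starts listening.
    site = web.TCPSite(runner, host, port)
    await site.start()
    return site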
7 changes: 7 additions & 0 deletions src/lean_spec/subspecs/chain/clock.py
@@ -8,6 +8,7 @@
coordinate block proposals and attestations.
"""

import asyncio
from dataclasses import dataclass
from time import time as wall_time
from typing import Callable
@@ -94,3 +95,9 @@ def seconds_until_next_interval(self) -> float:
# Time until next boundary (may be 0 if exactly at boundary).
ms_until_next = int(MILLISECONDS_PER_INTERVAL) - time_into_interval_ms
return ms_until_next / 1000.0

async def sleep_until_next_interval(self) -> None:
"""Sleep until the next interval boundary."""
sleep_time = self.seconds_until_next_interval()
if sleep_time > 0:
await asyncio.sleep(sleep_time)
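A sketch of the driving loop this helper enables; handle_tick and the is_running flag are illustrative stand-ins for the chain service's own logic:

async def tick_loop(clock, handle_tick, is_running) -> None:
    # Re-align to the next interval boundary on every pass. Because the
    # sleep is recomputed from the wall clock each time, drift from slow
    # tick handlers does not accumulate.
    while is_running():
        await clock.sleep_until_next_interval()
        handle_tick()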
13 changes: 1 addition & 12 deletions src/lean_spec/subspecs/chain/service.py
@@ -104,7 +104,7 @@ async def run(self) -> None:
and total_interval <= last_handled_total_interval
)
if already_handled:
await self._sleep_until_next_interval()
await self.clock.sleep_until_next_interval()
# Check if stopped during sleep.
if not self._running:
break
@@ -214,17 +214,6 @@ async def _initial_tick(self) -> Interval | None:

return None

async def _sleep_until_next_interval(self) -> None:
"""
Sleep until the next interval boundary.

Uses the clock to calculate precise sleep duration, ensuring tick
timing is aligned with network consensus expectations.
"""
sleep_time = self.clock.seconds_until_next_interval()
if sleep_time > 0:
await asyncio.sleep(sleep_time)

def stop(self) -> None:
"""
Stop the service.
12 changes: 2 additions & 10 deletions src/lean_spec/subspecs/containers/state/state.py
@@ -437,10 +437,7 @@ def process_attestations(
start_slot = int(finalized_slot) + 1
root_to_slot: dict[Bytes32, Slot] = {}
for i in range(start_slot, len(self.historical_block_hashes)):
root = self.historical_block_hashes[i]
slot = Slot(i)
if root not in root_to_slot or slot > root_to_slot[root]:
root_to_slot[root] = slot
root_to_slot[self.historical_block_hashes[i]] = Slot(i)

# Process each attestation independently.
#
@@ -994,10 +991,5 @@ def select_aggregated_proofs(
)
remaining -= covered

# Final Assembly
if not results:
return [], []

# Unzip the results into parallel lists.
aggregated_attestations, aggregated_proofs = zip(*results, strict=True)
return list(aggregated_attestations), list(aggregated_proofs)
return [att for att, _ in results], [proof for _, proof in results]
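The replacement comprehensions sidestep the empty-input special case that the old zip(*results) path needed; a minimal demonstration:

pairs: list[tuple[str, int]] = []

# zip(*pairs) yields nothing for an empty list, so tuple unpacking fails:
#   atts, proofs = zip(*pairs)  ->  ValueError: not enough values to unpack
# which is why the old code guarded with `if not results: return [], []`.

# Two comprehensions return ([], []) naturally, with no guard needed.
atts = [att for att, _ in pairs]
proofs = [proof for _, proof in pairs]
assert (atts, proofs) == ([], [])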
98 changes: 36 additions & 62 deletions src/lean_spec/subspecs/forkchoice/store.py
@@ -6,7 +6,6 @@

__all__ = ["Store"]

import copy
from collections import defaultdict

from lean_spec.subspecs.chain.clock import Interval
@@ -40,7 +39,6 @@
from lean_spec.types import (
ZERO_HASH,
Bytes32,
Uint64,
)
from lean_spec.types.container import Container

@@ -401,30 +399,25 @@ def on_gossip_attestation(
), "Signature verification failed"

# Store signature and attestation data for later aggregation
new_commitee_sigs = dict(self.gossip_signatures)
new_committee_sigs = dict(self.gossip_signatures)
new_attestation_data_by_root = dict(self.attestation_data_by_root)
data_root = attestation_data.data_root_bytes()

if is_aggregator:
assert self.validator_id is not None, "Current validator ID must be set for aggregation"
current_validator_subnet = self.validator_id.compute_subnet_id(
ATTESTATION_COMMITTEE_COUNT
)
current_subnet = self.validator_id.compute_subnet_id(ATTESTATION_COMMITTEE_COUNT)
attester_subnet = validator_id.compute_subnet_id(ATTESTATION_COMMITTEE_COUNT)
if current_validator_subnet != attester_subnet:
# Not part of our committee; ignore for committee aggregation.
pass
else:
if current_subnet == attester_subnet:
sig_key = SignatureKey(validator_id, data_root)
new_commitee_sigs[sig_key] = signature
new_committee_sigs[sig_key] = signature

# Store attestation data for later extraction
new_attestation_data_by_root[data_root] = attestation_data

# Return store with updated signature map and attestation data
return self.model_copy(
update={
"gossip_signatures": new_commitee_sigs,
"gossip_signatures": new_committee_sigs,
"attestation_data_by_root": new_attestation_data_by_root,
}
)
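The handler never mutates self; it copies the affected mappings and returns a new store via model_copy. A reduced sketch of that pattern, assuming a frozen pydantic v2 model like the real Store:

from pydantic import BaseModel, ConfigDict

class Counter(BaseModel):
    model_config = ConfigDict(frozen=True)
    seen: dict[str, int] = {}

    def record(self, key: str) -> "Counter":
        # Copy the mapping first; the original instance is never mutated,
        # so earlier snapshots of the object stay valid.
        updated = dict(self.seen)
        updated[key] = updated.get(key, 0) + 1
        return self.model_copy(update={"seen": updated})

c0 = Counter()
c1 = c0.record("a")
assert c0.seen == {} and c1.seen == {"a": 1}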
@@ -465,7 +458,7 @@ def on_gossip_aggregated_attestation(
# Ensure all participants exist in the active set
validators = key_state.validators
for validator_id in validator_ids:
assert validator_id < ValidatorIndex(len(validators)), (
assert validator_id.is_valid(len(validators)), (
f"Validator {validator_id} not found in state {data.target.root.hex()}"
)

@@ -484,9 +477,10 @@
f"Committee aggregation signature verification failed: {exc}"
) from exc

# Copy the aggregated proof map for updates
# Must deep copy the lists to maintain immutability of previous store snapshots
new_aggregated_payloads = copy.deepcopy(self.latest_new_aggregated_payloads)
# Shallow-copy the dict and its list values to maintain immutability
new_aggregated_payloads = {
k: list(v) for k, v in self.latest_new_aggregated_payloads.items()
}
data_root = data.data_root_bytes()

# Store attestation data by root for later retrieval
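The one-level copy is sufficient because only the dict and its lists are mutated here; the proof objects themselves are shared read-only, so deepcopy added cost without adding safety. A small illustration (plain ints stand in for the proof objects):

old = {"root": [1, 2]}

# One level of copying: a new dict whose values are new lists.
shallow = {k: list(v) for k, v in old.items()}

shallow["root"].append(3)        # mutating the copy...
assert old["root"] == [1, 2]     # ...leaves the original snapshot intact

# copy.deepcopy would additionally clone every element inside the lists,
# which is wasted work when those elements are never modified in place.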
@@ -577,20 +571,14 @@ def on_block(
valid_signatures = signed_block_with_attestation.verify_signatures(parent_state, scheme)

# Execute state transition function to compute post-block state
post_state = copy.deepcopy(parent_state).state_transition(block, valid_signatures)
post_state = parent_state.state_transition(block, valid_signatures)

# If post-state has a higher justified checkpoint, update it to the store.
latest_justified = (
post_state.latest_justified
if post_state.latest_justified.slot > self.latest_justified.slot
else self.latest_justified
# Propagate any checkpoint advances from the post-state.
latest_justified = max(
post_state.latest_justified, self.latest_justified, key=lambda c: c.slot
)

# If post-state has a higher finalized checkpoint, update it to the store.
latest_finalized = (
post_state.latest_finalized
if post_state.latest_finalized.slot > self.latest_finalized.slot
else self.latest_finalized
latest_finalized = max(
post_state.latest_finalized, self.latest_finalized, key=lambda c: c.slot
)

# Create new store with the computed data.
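max() with a key function keeps whichever checkpoint has the greater slot; on a tie it returns its first argument, so argument order matters when the two checkpoints can differ. A tiny check, with a namedtuple standing in for the real Checkpoint type:

from collections import namedtuple

Checkpoint = namedtuple("Checkpoint", ["root", "slot"])  # stand-in only

store_cp = Checkpoint(root="aa", slot=10)
state_cp = Checkpoint(root="bb", slot=12)

# Only the slot numbers are compared; the higher-slot checkpoint wins.
latest = max(state_cp, store_cp, key=lambda c: c.slot)
assert latest is state_cp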
@@ -612,11 +600,11 @@
)

# Copy the aggregated proof map for updates
# Must deep copy the lists to maintain immutability of previous store snapshots
# Shallow-copy the dict and its list values to maintain immutability
# Block attestations go directly to "known" payloads (like is_from_block=True)
block_proofs: dict[SignatureKey, list[AggregatedSignatureProof]] = copy.deepcopy(
store.latest_known_aggregated_payloads
)
block_proofs: dict[SignatureKey, list[AggregatedSignatureProof]] = {
k: list(v) for k, v in store.latest_known_aggregated_payloads.items()
}

# Store attestation data by root for later retrieval
new_attestation_data_by_root = dict(store.attestation_data_by_root)
@@ -1040,14 +1028,9 @@ def aggregate_committee_signatures(self) -> tuple["Store", list[SignedAggregated
)

# Create list of aggregated attestations for broadcasting
new_aggregates: list[SignedAggregatedAttestation] = []
for aggregated_attestation, aggregated_signature in aggregated_results:
new_aggregates.append(
SignedAggregatedAttestation(
data=aggregated_attestation.data,
proof=aggregated_signature,
)
)
new_aggregates = [
SignedAggregatedAttestation(data=att.data, proof=sig) for att, sig in aggregated_results
]

# Compute new aggregated payloads
new_gossip_sigs = dict(self.gossip_signatures)
@@ -1122,20 +1105,15 @@ def tick_interval(
current_interval = store.time % INTERVALS_PER_SLOT
new_aggregates: list[SignedAggregatedAttestation] = []

if current_interval == Uint64(0):
# Start of slot - process attestations if proposal exists
if has_proposal:
match int(current_interval):
case 0 if has_proposal:
store = store.accept_new_attestations()
elif current_interval == Uint64(2):
# Aggregation interval - aggregators create proofs
if is_aggregator:
case 2 if is_aggregator:
store, new_aggregates = store.aggregate_committee_signatures()
elif current_interval == Uint64(3):
# Fast confirm - update safe target based on received proofs
store = store.update_safe_target()
elif current_interval == Uint64(4):
# End of slot - accept accumulated attestations
store = store.accept_new_attestations()
case 3:
store = store.update_safe_target()
case 4:
store = store.accept_new_attestations()

return store, new_aggregates
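The rewritten dispatch relies on guarded cases such as `case 0 if has_proposal:`: the pattern must match and the guard must hold, otherwise evaluation falls through to the next case or to nothing at all. A standalone sketch of the same shape (names and return values are illustrative):

def dispatch(interval: int, has_proposal: bool, is_aggregator: bool) -> str:
    # Structural match on the interval number; guards narrow cases 0 and 2.
    match interval:
        case 0 if has_proposal:
            return "accept-new-attestations"
        case 2 if is_aggregator:
            return "aggregate-committee-signatures"
        case 3:
            return "update-safe-target"
        case 4:
            return "accept-new-attestations"
        case _:
            return "no-op"  # e.g. interval 0 without a proposal

assert dispatch(0, has_proposal=False, is_aggregator=False) == "no-op"
assert dispatch(2, has_proposal=False, is_aggregator=True) == "aggregate-committee-signatures"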

@@ -1384,22 +1362,18 @@ def produce_block_with_signatures(
# Locally produced blocks bypass normal block processing.
# We must manually propagate any checkpoint advances.
# Higher slots indicate more recent justified/finalized states.
latest_justified = (
final_post_state.latest_justified
if final_post_state.latest_justified.slot > store.latest_justified.slot
else store.latest_justified
latest_justified = max(
final_post_state.latest_justified, store.latest_justified, key=lambda c: c.slot
)
latest_finalized = (
final_post_state.latest_finalized
if final_post_state.latest_finalized.slot > store.latest_finalized.slot
else store.latest_finalized
latest_finalized = max(
final_post_state.latest_finalized, store.latest_finalized, key=lambda c: c.slot
)

# Persist block and state immutably.
new_store = store.model_copy(
update={
"blocks": {**store.blocks, block_hash: final_block},
"states": {**store.states, block_hash: final_post_state},
"blocks": store.blocks | {block_hash: final_block},
"states": store.states | {block_hash: final_post_state},
"latest_justified": latest_justified,
"latest_finalized": latest_finalized,
}
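The `|` operator used for the blocks and states maps builds a brand-new dict and leaves the store's original mappings untouched, the same effect as the previous `{**old, key: value}` spelling (Python 3.9+). A quick check:

blocks = {"b1": "block-1"}

new_blocks = blocks | {"b2": "block-2"}   # merge into a fresh dict

assert blocks == {"b1": "block-1"}        # original mapping unchanged
assert new_blocks == {"b1": "block-1", "b2": "block-2"}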