diff --git a/src/lean_spec/__main__.py b/src/lean_spec/__main__.py
index 5869fefa..b98a6ca2 100644
--- a/src/lean_spec/__main__.py
+++ b/src/lean_spec/__main__.py
@@ -61,18 +61,6 @@
 logger = logging.getLogger(__name__)


-def is_enr_string(bootnode: str) -> bool:
-    """
-    Check if bootnode string is an ENR (vs multiaddr).
-
-    Uses prefix detection rather than attempting full parsing.
-    This is both faster and avoids import overhead for simple checks.
-
-    Per EIP-778, all ENR strings begin with "enr:" followed by base64url content.
-    """
-    return bootnode.startswith("enr:")
-
-
 def resolve_bootnode(bootnode: str) -> str:
     """
     Resolve a bootnode string to a multiaddr.
@@ -92,7 +80,7 @@ def resolve_bootnode(bootnode: str) -> str:
     Raises:
         ValueError: If ENR is malformed or has no UDP connection info.
     """
-    if is_enr_string(bootnode):
+    if bootnode.startswith("enr:"):
         enr = ENR.from_string(bootnode)

         # Verify structural validity (correct scheme, public key present).
@@ -254,7 +242,7 @@ async def _init_from_checkpoint(
     #
     # This is defense in depth. We trust the source, but still verify
     # basic invariants before using the state.
-    if not await verify_checkpoint_state(state):
+    if not verify_checkpoint_state(state):
         logger.error("Checkpoint state verification failed")
         return None

@@ -675,14 +663,14 @@ def main() -> None:
     try:
         asyncio.run(
             run_node(
-                args.genesis,
-                args.bootnodes,
-                args.listen,
-                args.checkpoint_sync_url,
-                args.validator_keys,
-                args.node_id,
-                args.genesis_time_now,
-                args.is_aggregator,
+                genesis_path=args.genesis,
+                bootnodes=args.bootnodes,
+                listen_addr=args.listen,
+                checkpoint_sync_url=args.checkpoint_sync_url,
+                validator_keys_path=args.validator_keys,
+                node_id=args.node_id,
+                genesis_time_now=args.genesis_time_now,
+                is_aggregator=args.is_aggregator,
             )
         )
     except KeyboardInterrupt:
diff --git a/src/lean_spec/snappy/encoding.py b/src/lean_spec/snappy/encoding.py
index dd9fc509..2711de7f 100644
--- a/src/lean_spec/snappy/encoding.py
+++ b/src/lean_spec/snappy/encoding.py
@@ -33,9 +33,7 @@
     VARINT_DATA_MASK,
 )

-# ===========================================================================
 # Varint Encoding
-# ===========================================================================
 #
 # Varints encode integers using as few bytes as possible.
 # - Small values use fewer bytes.
@@ -184,9 +182,7 @@ def decode_varint32(data: bytes, offset: int = 0) -> tuple[int, int]:
     return result, bytes_read


-# ===========================================================================
 # Tag Byte Encoding - Literals
-# ===========================================================================
 #
 # Literals are raw bytes that couldn't be compressed (no match found).
 # A literal tag tells the decoder: "copy the next N bytes as-is".
@@ -296,9 +292,7 @@ def encode_literal_tag(length: int) -> bytes:
     raise ValueError(f"Literal length too large: {length}")


-# ===========================================================================
 # Tag Byte Encoding - Copies
-# ===========================================================================
 #
 # Copies are backreferences to already-decompressed data.
 # A copy tag tells the decoder: "go back OFFSET bytes, copy LENGTH bytes".
@@ -540,9 +534,7 @@ def _encode_copy_4(length: int, offset: int) -> bytes:
     )


-# ===========================================================================
 # Tag Decoding
-# ===========================================================================
 #
 # Decoding is the inverse of encoding.
 # Given a compressed stream, we parse tags to reconstruct the original data.
diff --git a/src/lean_spec/subspecs/api/endpoints/states.py b/src/lean_spec/subspecs/api/endpoints/states.py
index d79176ee..ce38edeb 100644
--- a/src/lean_spec/subspecs/api/endpoints/states.py
+++ b/src/lean_spec/subspecs/api/endpoints/states.py
@@ -40,7 +40,7 @@ async def handle_finalized(request: web.Request) -> web.Response:
     try:
         ssz_bytes = await asyncio.to_thread(state.encode_bytes)
     except Exception as e:
-        logger.error(f"Failed to encode state: {e}")
+        logger.error("Failed to encode state: %s", e)
         raise web.HTTPInternalServerError(reason="Encoding failed") from e

     return web.Response(body=ssz_bytes, content_type="application/octet-stream")
diff --git a/src/lean_spec/subspecs/api/server.py b/src/lean_spec/subspecs/api/server.py
index bd55fde4..8c6b4983 100644
--- a/src/lean_spec/subspecs/api/server.py
+++ b/src/lean_spec/subspecs/api/server.py
@@ -99,7 +99,7 @@ async def start(self) -> None:
         self._site = web.TCPSite(self._runner, self.config.host, self.config.port)
         await self._site.start()

-        logger.info(f"API server listening on {self.config.host}:{self.config.port}")
+        logger.info("API server listening on %s:%d", self.config.host, self.config.port)

     async def run(self) -> None:
         """
diff --git a/src/lean_spec/subspecs/chain/clock.py b/src/lean_spec/subspecs/chain/clock.py
index 51a371e0..0ef5502c 100644
--- a/src/lean_spec/subspecs/chain/clock.py
+++ b/src/lean_spec/subspecs/chain/clock.py
@@ -8,6 +8,7 @@
 coordinate block proposals and attestations.
 """

+import asyncio
 from dataclasses import dataclass
 from time import time as wall_time
 from typing import Callable
@@ -94,3 +95,9 @@ def seconds_until_next_interval(self) -> float:
         # Time until next boundary (may be 0 if exactly at boundary).
         ms_until_next = int(MILLISECONDS_PER_INTERVAL) - time_into_interval_ms
         return ms_until_next / 1000.0
+
+    async def sleep_until_next_interval(self) -> None:
+        """Sleep until the next interval boundary."""
+        sleep_time = self.seconds_until_next_interval()
+        if sleep_time > 0:
+            await asyncio.sleep(sleep_time)
diff --git a/src/lean_spec/subspecs/chain/service.py b/src/lean_spec/subspecs/chain/service.py
index 791d6c1b..6d2c9131 100644
--- a/src/lean_spec/subspecs/chain/service.py
+++ b/src/lean_spec/subspecs/chain/service.py
@@ -104,7 +104,7 @@ async def run(self) -> None:
                 and total_interval <= last_handled_total_interval
             )
             if already_handled:
-                await self._sleep_until_next_interval()
+                await self.clock.sleep_until_next_interval()
                 # Check if stopped during sleep.
                 if not self._running:
                     break
@@ -214,17 +214,6 @@ async def _initial_tick(self) -> Interval | None:

         return None

-    async def _sleep_until_next_interval(self) -> None:
-        """
-        Sleep until the next interval boundary.
-
-        Uses the clock to calculate precise sleep duration, ensuring tick
-        timing is aligned with network consensus expectations.
-        """
-        sleep_time = self.clock.seconds_until_next_interval()
-        if sleep_time > 0:
-            await asyncio.sleep(sleep_time)
-
     def stop(self) -> None:
         """
         Stop the service.
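Note: hoisting `sleep_until_next_interval` from the two services onto the clock gives both loops a single alignment primitive. A minimal usage sketch, assuming the dataclass in chain/clock.py is exposed as `Clock` (the driver loop below is illustrative, not the actual `run()` body):

    import asyncio

    from lean_spec.subspecs.chain.clock import Clock  # assumed class name


    async def tick_loop(clock: Clock) -> None:
        # Illustrative driver: wake exactly once per interval boundary.
        while True:
            await clock.sleep_until_next_interval()
            # total_intervals() is the tick counter the services consult.
            print("interval", clock.total_intervals())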
diff --git a/src/lean_spec/subspecs/containers/state/state.py b/src/lean_spec/subspecs/containers/state/state.py
index 0d94fd85..57d6e2bb 100644
--- a/src/lean_spec/subspecs/containers/state/state.py
+++ b/src/lean_spec/subspecs/containers/state/state.py
@@ -437,10 +437,7 @@ def process_attestations(
         start_slot = int(finalized_slot) + 1
         root_to_slot: dict[Bytes32, Slot] = {}
         for i in range(start_slot, len(self.historical_block_hashes)):
-            root = self.historical_block_hashes[i]
-            slot = Slot(i)
-            if root not in root_to_slot or slot > root_to_slot[root]:
-                root_to_slot[root] = slot
+            root_to_slot[self.historical_block_hashes[i]] = Slot(i)

         # Process each attestation independently.
         #
@@ -994,10 +991,5 @@ def select_aggregated_proofs(
             )
             remaining -= covered

-        # Final Assembly
-        if not results:
-            return [], []
-
-        # Unzip the results into parallel lists.
-        aggregated_attestations, aggregated_proofs = zip(*results, strict=True)
-        return list(aggregated_attestations), list(aggregated_proofs)
+        return [att for att, _ in results], [proof for _, proof in results]

diff --git a/src/lean_spec/subspecs/forkchoice/store.py b/src/lean_spec/subspecs/forkchoice/store.py
index b60bea00..827b93db 100644
--- a/src/lean_spec/subspecs/forkchoice/store.py
+++ b/src/lean_spec/subspecs/forkchoice/store.py
@@ -6,7 +6,6 @@

 __all__ = ["Store"]

-import copy
 from collections import defaultdict

 from lean_spec.subspecs.chain.clock import Interval
@@ -40,7 +39,6 @@
 from lean_spec.types import (
     ZERO_HASH,
     Bytes32,
-    Uint64,
 )
 from lean_spec.types.container import Container

@@ -401,22 +399,17 @@ def on_gossip_attestation(
         ), "Signature verification failed"

         # Store signature and attestation data for later aggregation
-        new_commitee_sigs = dict(self.gossip_signatures)
+        new_committee_sigs = dict(self.gossip_signatures)
         new_attestation_data_by_root = dict(self.attestation_data_by_root)
         data_root = attestation_data.data_root_bytes()

         if is_aggregator:
             assert self.validator_id is not None, "Current validator ID must be set for aggregation"
-            current_validator_subnet = self.validator_id.compute_subnet_id(
-                ATTESTATION_COMMITTEE_COUNT
-            )
+            current_subnet = self.validator_id.compute_subnet_id(ATTESTATION_COMMITTEE_COUNT)
             attester_subnet = validator_id.compute_subnet_id(ATTESTATION_COMMITTEE_COUNT)
-            if current_validator_subnet != attester_subnet:
-                # Not part of our committee; ignore for committee aggregation.
-                pass
-            else:
+            if current_subnet == attester_subnet:
                 sig_key = SignatureKey(validator_id, data_root)
-                new_commitee_sigs[sig_key] = signature
+                new_committee_sigs[sig_key] = signature

                 # Store attestation data for later extraction
                 new_attestation_data_by_root[data_root] = attestation_data

@@ -424,7 +417,7 @@ def on_gossip_attestation(
         # Return store with updated signature map and attestation data
         return self.model_copy(
             update={
-                "gossip_signatures": new_commitee_sigs,
+                "gossip_signatures": new_committee_sigs,
                 "attestation_data_by_root": new_attestation_data_by_root,
             }
         )
@@ -465,7 +458,7 @@ def on_gossip_aggregated_attestation(
         # Ensure all participants exist in the active set
         validators = key_state.validators
         for validator_id in validator_ids:
-            assert validator_id < ValidatorIndex(len(validators)), (
+            assert validator_id.is_valid(len(validators)), (
                 f"Validator {validator_id} not found in state {data.target.root.hex()}"
             )

@@ -484,9 +477,10 @@ def on_gossip_aggregated_attestation(
                 f"Committee aggregation signature verification failed: {exc}"
             ) from exc

-        # Copy the aggregated proof map for updates
-        # Must deep copy the lists to maintain immutability of previous store snapshots
-        new_aggregated_payloads = copy.deepcopy(self.latest_new_aggregated_payloads)
+        # Shallow-copy the dict and its list values to maintain immutability
+        new_aggregated_payloads = {
+            k: list(v) for k, v in self.latest_new_aggregated_payloads.items()
+        }
         data_root = data.data_root_bytes()

         # Store attestation data by root for later retrieval
@@ -577,20 +571,14 @@ def on_block(
         valid_signatures = signed_block_with_attestation.verify_signatures(parent_state, scheme)

         # Execute state transition function to compute post-block state
-        post_state = copy.deepcopy(parent_state).state_transition(block, valid_signatures)
+        post_state = parent_state.state_transition(block, valid_signatures)

-        # If post-state has a higher justified checkpoint, update it to the store.
-        latest_justified = (
-            post_state.latest_justified
-            if post_state.latest_justified.slot > self.latest_justified.slot
-            else self.latest_justified
+        # Propagate any checkpoint advances from the post-state.
+        latest_justified = max(
+            post_state.latest_justified, self.latest_justified, key=lambda c: c.slot
         )
-
-        # If post-state has a higher finalized checkpoint, update it to the store.
-        latest_finalized = (
-            post_state.latest_finalized
-            if post_state.latest_finalized.slot > self.latest_finalized.slot
-            else self.latest_finalized
+        latest_finalized = max(
+            post_state.latest_finalized, self.latest_finalized, key=lambda c: c.slot
         )

         # Create new store with the computed data.
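Note on the `max(...)` rewrite above: when the two slots are equal, `max` returns its first argument, so a tie now resolves to the post-state checkpoint, where the old ternary kept the store's value. The checkpoints should agree whenever the slots do, but the tie-breaking direction is worth remembering. A self-contained illustration (the `Checkpoint` here is a hypothetical stand-in, not the spec container):

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class Checkpoint:
        root: bytes
        slot: int


    a = Checkpoint(root=b"\x01", slot=5)
    b = Checkpoint(root=b"\x02", slot=5)

    # On a tie, max() returns the first maximal item it encounters.
    assert max(a, b, key=lambda c: c.slot) is a
    assert max(b, a, key=lambda c: c.slot) is b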
@@ -612,11 +600,11 @@ def on_block(
         )

         # Copy the aggregated proof map for updates
-        # Must deep copy the lists to maintain immutability of previous store snapshots
+        # Shallow-copy the dict and its list values to maintain immutability
         # Block attestations go directly to "known" payloads (like is_from_block=True)
-        block_proofs: dict[SignatureKey, list[AggregatedSignatureProof]] = copy.deepcopy(
-            store.latest_known_aggregated_payloads
-        )
+        block_proofs: dict[SignatureKey, list[AggregatedSignatureProof]] = {
+            k: list(v) for k, v in store.latest_known_aggregated_payloads.items()
+        }

         # Store attestation data by root for later retrieval
         new_attestation_data_by_root = dict(store.attestation_data_by_root)
@@ -1040,14 +1028,9 @@ def aggregate_committee_signatures(self) -> tuple["Store", list[SignedAggregated
             )

         # Create list of aggregated attestations for broadcasting
-        new_aggregates: list[SignedAggregatedAttestation] = []
-        for aggregated_attestation, aggregated_signature in aggregated_results:
-            new_aggregates.append(
-                SignedAggregatedAttestation(
-                    data=aggregated_attestation.data,
-                    proof=aggregated_signature,
-                )
-            )
+        new_aggregates = [
+            SignedAggregatedAttestation(data=att.data, proof=sig) for att, sig in aggregated_results
+        ]

         # Compute new aggregated payloads
         new_gossip_sigs = dict(self.gossip_signatures)
@@ -1122,20 +1105,15 @@ def tick_interval(
         current_interval = store.time % INTERVALS_PER_SLOT
         new_aggregates: list[SignedAggregatedAttestation] = []

-        if current_interval == Uint64(0):
-            # Start of slot - process attestations if proposal exists
-            if has_proposal:
+        match int(current_interval):
+            case 0 if has_proposal:
                 store = store.accept_new_attestations()
-        elif current_interval == Uint64(2):
-            # Aggregation interval - aggregators create proofs
-            if is_aggregator:
+            case 2 if is_aggregator:
                 store, new_aggregates = store.aggregate_committee_signatures()
-        elif current_interval == Uint64(3):
-            # Fast confirm - update safe target based on received proofs
-            store = store.update_safe_target()
-        elif current_interval == Uint64(4):
-            # End of slot - accept accumulated attestations
-            store = store.accept_new_attestations()
+            case 3:
+                store = store.update_safe_target()
+            case 4:
+                store = store.accept_new_attestations()

         return store, new_aggregates

@@ -1384,22 +1362,18 @@ def produce_block_with_signatures(
         # Locally produced blocks bypass normal block processing.
         # We must manually propagate any checkpoint advances.
         # Higher slots indicate more recent justified/finalized states.
-        latest_justified = (
-            final_post_state.latest_justified
-            if final_post_state.latest_justified.slot > store.latest_justified.slot
-            else store.latest_justified
+        latest_justified = max(
+            final_post_state.latest_justified, store.latest_justified, key=lambda c: c.slot
         )
-        latest_finalized = (
-            final_post_state.latest_finalized
-            if final_post_state.latest_finalized.slot > store.latest_finalized.slot
-            else store.latest_finalized
+        latest_finalized = max(
+            final_post_state.latest_finalized, store.latest_finalized, key=lambda c: c.slot
         )

         # Persist block and state immutably.
         new_store = store.model_copy(
             update={
-                "blocks": {**store.blocks, block_hash: final_block},
-                "states": {**store.states, block_hash: final_post_state},
+                "blocks": store.blocks | {block_hash: final_block},
+                "states": store.states | {block_hash: final_post_state},
                 "latest_justified": latest_justified,
                 "latest_finalized": latest_finalized,
             }
diff --git a/src/lean_spec/subspecs/koalabear/field.py b/src/lean_spec/subspecs/koalabear/field.py
index 6f75b200..ed71930a 100644
--- a/src/lean_spec/subspecs/koalabear/field.py
+++ b/src/lean_spec/subspecs/koalabear/field.py
@@ -80,7 +80,7 @@ def __init__(self, value: int) -> None:
         Raises:
             TypeError: If value is not an integer.
         """
-        if not isinstance(value, int):
+        if not isinstance(value, int) or isinstance(value, bool):
             raise TypeError(f"Field value must be an integer, got {type(value).__name__}")

         # Normalize to [0, P) - handles negative values correctly
@@ -98,9 +98,8 @@ def get_byte_length(cls) -> int:

     def serialize(self, stream: IO[bytes]) -> int:
         """Serialize the field element to a binary stream."""
-        data = self.value.to_bytes(P_BYTES, byteorder="little")
-        stream.write(data)
-        return len(data)
+        stream.write(self.value.to_bytes(P_BYTES, byteorder="little"))
+        return P_BYTES

     @classmethod
     def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
diff --git a/src/lean_spec/subspecs/networking/client/event_source.py b/src/lean_spec/subspecs/networking/client/event_source.py
index e844bb63..489c270c 100644
--- a/src/lean_spec/subspecs/networking/client/event_source.py
+++ b/src/lean_spec/subspecs/networking/client/event_source.py
@@ -291,9 +291,9 @@ def decode_message(
         # This prevents wasting CPU on malformed or cross-fork messages.
         try:
             topic = GossipTopic.from_string_validated(topic_str, self.fork_digest)
-        except (ValueError, ForkMismatchError) as e:
-            if isinstance(e, ForkMismatchError):
-                raise
+        except ForkMismatchError:
+            raise
+        except ValueError as e:
             raise GossipMessageError(f"Invalid topic: {e}") from e

         # Step 2: Decompress Snappy-framed data.
@@ -346,9 +346,9 @@ def get_topic(self, topic_str: str) -> GossipTopic:
         """
         try:
             return GossipTopic.from_string_validated(topic_str, self.fork_digest)
-        except (ValueError, ForkMismatchError) as e:
-            if isinstance(e, ForkMismatchError):
-                raise
+        except ForkMismatchError:
+            raise
+        except ValueError as e:
             raise GossipMessageError(f"Invalid topic: {e}") from e


diff --git a/src/lean_spec/subspecs/networking/discovery/service.py b/src/lean_spec/subspecs/networking/discovery/service.py
index 8c300b4b..c3a0280f 100644
--- a/src/lean_spec/subspecs/networking/discovery/service.py
+++ b/src/lean_spec/subspecs/networking/discovery/service.py
@@ -55,7 +55,7 @@
 """Interval between node liveness revalidation (5 minutes)."""


-@dataclass
+@dataclass(slots=True)
 class LookupResult:
     """Result of a node lookup operation."""

diff --git a/src/lean_spec/subspecs/networking/discovery/transport.py b/src/lean_spec/subspecs/networking/discovery/transport.py
index 02f9aba8..b9fe97b2 100644
--- a/src/lean_spec/subspecs/networking/discovery/transport.py
+++ b/src/lean_spec/subspecs/networking/discovery/transport.py
@@ -65,7 +65,7 @@
 logger = logging.getLogger(__name__)


-@dataclass
+@dataclass(slots=True)
 class PendingRequest:
     """Tracks a pending request awaiting response."""

@@ -88,7 +88,7 @@ class PendingRequest:
     """Future to complete when response arrives."""


-@dataclass
+@dataclass(slots=True)
 class PendingMultiRequest:
     """Tracks a pending request that may receive multiple responses.
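Note: `@dataclass(slots=True)` (Python 3.10+) generates `__slots__` for the declared fields, shrinking per-instance memory and turning typo'd attribute writes into immediate errors, a good fit for request-tracking records created in volume. A small self-contained illustration (the `Probe` class is hypothetical):

    from dataclasses import dataclass


    @dataclass(slots=True)
    class Probe:
        node_id: bytes
        attempts: int = 0


    p = Probe(node_id=b"\x00" * 32)
    p.attempts += 1  # declared fields behave as usual

    try:
        p.extra = True  # undeclared attribute
    except AttributeError:
        print("slots classes reject undeclared attributes")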
diff --git a/src/lean_spec/subspecs/networking/transport/quic/connection.py b/src/lean_spec/subspecs/networking/transport/quic/connection.py
index 33836e7a..407afc5b 100644
--- a/src/lean_spec/subspecs/networking/transport/quic/connection.py
+++ b/src/lean_spec/subspecs/networking/transport/quic/connection.py
@@ -21,6 +21,7 @@
 from __future__ import annotations

 import asyncio
+import contextlib
 import ssl
 import tempfile
 from collections.abc import Awaitable, Callable
@@ -419,7 +420,9 @@ class QuicConnectionManager:
     _config: QuicConfiguration
     _connections: dict[PeerId, QuicConnection] = field(default_factory=dict)
     _temp_dir: Path | None = None
-    _context_managers: list = field(default_factory=list)
+    _context_managers: list[contextlib.AbstractAsyncContextManager[object]] = field(
+        default_factory=list
+    )

     @classmethod
     async def create(
diff --git a/src/lean_spec/subspecs/storage/sqlite.py b/src/lean_spec/subspecs/storage/sqlite.py
index 5b6d532e..d8e3070f 100644
--- a/src/lean_spec/subspecs/storage/sqlite.py
+++ b/src/lean_spec/subspecs/storage/sqlite.py
@@ -16,7 +16,6 @@

 import sqlite3
 from pathlib import Path
-from typing import TYPE_CHECKING

 from lean_spec.subspecs.containers import Block, Checkpoint, State, ValidatorIndex
 from lean_spec.subspecs.containers.attestation import AttestationData
@@ -31,9 +30,6 @@
     STATES,
 )

-if TYPE_CHECKING:
-    pass
-

 class SQLiteDatabase:
     """
@@ -376,7 +372,7 @@ def put_head_root(self, root: Bytes32) -> None:

     # Slot Index Operations
     #
-    # Slots are time intervals (12 seconds each).
+    # Slots are time intervals.
     # This index maps slot numbers to blocks, enabling historical queries.
     # Note: not every slot has a block (missed slots happen).
diff --git a/src/lean_spec/subspecs/sync/__init__.py b/src/lean_spec/subspecs/sync/__init__.py
index b65141b2..d23051d2 100644
--- a/src/lean_spec/subspecs/sync/__init__.py
+++ b/src/lean_spec/subspecs/sync/__init__.py
@@ -50,7 +50,6 @@
     # Configuration constants
     "MAX_BLOCKS_PER_REQUEST",
     "MAX_CONCURRENT_REQUESTS",
-    "REQUEST_TIMEOUT",
     "MAX_CACHED_BLOCKS",
     "MAX_BACKFILL_DEPTH",
 ]
@@ -67,7 +66,6 @@ from .config import (
     MAX_BLOCKS_PER_REQUEST,
     MAX_CACHED_BLOCKS,
     MAX_CONCURRENT_REQUESTS,
-    REQUEST_TIMEOUT,
 )
 from .head_sync import HeadSync, HeadSyncResult
 from .peer_manager import PeerManager, SyncPeer
diff --git a/src/lean_spec/subspecs/sync/checkpoint_sync.py b/src/lean_spec/subspecs/sync/checkpoint_sync.py
index 9710f68d..8e418804 100644
--- a/src/lean_spec/subspecs/sync/checkpoint_sync.py
+++ b/src/lean_spec/subspecs/sync/checkpoint_sync.py
@@ -19,12 +19,11 @@
 from __future__ import annotations

 import logging
-from typing import Any

 import httpx

 from lean_spec.subspecs.chain.config import VALIDATOR_REGISTRY_LIMIT
-from lean_spec.subspecs.containers import Slot, State
+from lean_spec.subspecs.containers import State
 from lean_spec.subspecs.ssz.hash import hash_tree_root

 logger = logging.getLogger(__name__)
@@ -45,7 +44,7 @@ class CheckpointSyncError(Exception):
     """


-async def fetch_finalized_state(url: str, state_class: type[Any]) -> "State":
+async def fetch_finalized_state(url: str, state_class: type["State"]) -> "State":
     """
     Fetch finalized state from a node via checkpoint sync.

@@ -66,7 +65,7 @@ async def fetch_finalized_state(url: str, state_class: type[Any]) -> "State":
     base_url = url.rstrip("/")
     full_url = f"{base_url}{FINALIZED_STATE_ENDPOINT}"

-    logger.info(f"Fetching finalized state from {full_url}")
+    logger.info("Fetching finalized state from %s", full_url)

     # Request SSZ binary format.
     #
@@ -80,14 +79,14 @@ async def fetch_finalized_state(url: str, state_class: type[Any]) -> "State":
             response.raise_for_status()

             ssz_data = response.content
-            logger.info(f"Downloaded {len(ssz_data)} bytes of SSZ state data")
+            logger.info("Downloaded %d bytes of SSZ state data", len(ssz_data))

             # Deserialize from SSZ bytes.
             #
             # This validates the byte stream matches the expected schema.
             # Malformed data will raise an exception here.
             state = state_class.decode_bytes(ssz_data)
-            logger.info(f"Deserialized state at slot {state.slot}")
+            logger.info("Deserialized state at slot %s", state.slot)

             return state

@@ -103,7 +102,7 @@ async def fetch_finalized_state(url: str, state_class: type[Any]) -> "State":
         raise CheckpointSyncError(f"Failed to fetch state: {e}") from e


-async def verify_checkpoint_state(state: "State") -> bool:
+def verify_checkpoint_state(state: "State") -> bool:
     """
     Verify that a checkpoint state is structurally valid.

@@ -127,11 +126,6 @@ async def verify_checkpoint_state(state: "State") -> bool:
         True if valid, False otherwise.
     """
     try:
-        # Sanity check: slot must be non-negative.
-        if state.slot < Slot(0):
-            logger.error("Invalid state: negative slot")
-            return False
-
         # A state with no validators cannot produce blocks.
         validator_count = len(state.validators)
         if validator_count == 0:
@@ -141,8 +135,9 @@ async def verify_checkpoint_state(state: "State") -> bool:
         # Guard against oversized states that could exhaust memory.
         if validator_count > int(VALIDATOR_REGISTRY_LIMIT):
             logger.error(
-                f"Invalid state: validator count {validator_count} exceeds "
-                f"registry limit {VALIDATOR_REGISTRY_LIMIT}"
+                "Invalid state: validator count %d exceeds registry limit %s",
+                validator_count,
+                VALIDATOR_REGISTRY_LIMIT,
             )
             return False

@@ -151,9 +146,9 @@ async def verify_checkpoint_state(state: "State") -> bool:
         # If the data was corrupted, hashing will likely fail or produce
         # an unexpected result. We log the root for debugging.
         state_root = hash_tree_root(state)
-        logger.info(f"Checkpoint state verified: slot={state.slot}, root={state_root}...")
+        logger.info("Checkpoint state verified: slot=%s, root=%s...", state.slot, state_root)
         return True

     except Exception as e:
-        logger.error(f"State verification failed: {e}")
+        logger.error("State verification failed: %s", e)
         return False
diff --git a/src/lean_spec/subspecs/sync/config.py b/src/lean_spec/subspecs/sync/config.py
index cb7fd2b4..d9de1091 100644
--- a/src/lean_spec/subspecs/sync/config.py
+++ b/src/lean_spec/subspecs/sync/config.py
@@ -14,9 +14,6 @@
 MAX_CONCURRENT_REQUESTS: Final[int] = 2
 """Maximum concurrent requests to a single peer."""

-REQUEST_TIMEOUT: Final[float] = 10.0
-"""Timeout for individual block requests in seconds."""
-
 MAX_CACHED_BLOCKS: Final[int] = 1024
 """Maximum blocks to hold in the pending cache."""
diff --git a/src/lean_spec/subspecs/validator/registry.py b/src/lean_spec/subspecs/validator/registry.py
index 7a291486..b401f23c 100644
--- a/src/lean_spec/subspecs/validator/registry.py
+++ b/src/lean_spec/subspecs/validator/registry.py
@@ -181,16 +181,8 @@ def get(self, index: ValidatorIndex) -> ValidatorEntry | None:
         """
         return self._validators.get(index)

-    def has(self, index: ValidatorIndex) -> bool:
-        """
-        Check if we control this validator.
-
-        Args:
-            index: Validator index to check.
-
-        Returns:
-            True if we have keys for this validator.
- """ + def __contains__(self, index: ValidatorIndex) -> bool: + """Check if we control this validator.""" return index in self._validators def indices(self) -> ValidatorIndices: diff --git a/src/lean_spec/subspecs/validator/service.py b/src/lean_spec/subspecs/validator/service.py index 610ab092..dfa1a685 100644 --- a/src/lean_spec/subspecs/validator/service.py +++ b/src/lean_spec/subspecs/validator/service.py @@ -141,7 +141,7 @@ async def run(self) -> None: and total_interval <= last_handled_total_interval ) if already_handled: - await self._sleep_until_next_interval() + await self.clock.sleep_until_next_interval() total_interval = self.clock.total_intervals() # Skip if we have no validators to manage. @@ -586,16 +586,6 @@ def _ensure_prepared_for_slot( return updated_entry - async def _sleep_until_next_interval(self) -> None: - """ - Sleep until the next interval boundary. - - Uses the clock to calculate precise sleep duration. - """ - sleep_time = self.clock.seconds_until_next_interval() - if sleep_time > 0: - await asyncio.sleep(sleep_time) - def stop(self) -> None: """ Stop the service. diff --git a/src/lean_spec/subspecs/xmss/hypercube.py b/src/lean_spec/subspecs/xmss/hypercube.py index f1bbf86d..32762656 100644 --- a/src/lean_spec/subspecs/xmss/hypercube.py +++ b/src/lean_spec/subspecs/xmss/hypercube.py @@ -240,16 +240,13 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> list[int]: # This loop finds which block of sub-hypercubes the index `x_curr` falls into. # # It skips over full blocks by subtracting their size from `x_curr` until found. - ji = None range_start = max(0, d_curr - (w - 1) * dim_remaining) - for j in range(range_start, min(w, d_curr + 1)): - count = prev_dim_layer_info.sizes[d_curr - j] + for ji in range(range_start, min(w, d_curr + 1)): + count = prev_dim_layer_info.sizes[d_curr - ji] if x_curr < count: - ji = j break x_curr -= count - - if ji is None: + else: raise RuntimeError("Internal logic error: failed to find coordinate") # Convert the block's distance contribution `ji` to a coordinate `ai`. diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index ceef26a8..2c1bbe25 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -346,14 +346,6 @@ def sign(self, sk: SecretKey, slot: Slot, message: Bytes32) -> Signature: boundary = (int(sk.left_bottom_tree_index) + 1) * leaves_per_bottom_tree bottom_tree = sk.left_bottom_tree if slot_int < boundary else sk.right_bottom_tree - # Ensure bottom tree exists - if bottom_tree is None: - raise ValueError( - f"Slot {slot} requires bottom tree but it is not available. " - f"Prepared interval may have been exceeded. Call advance_preparation() " - f"to slide the window forward." - ) - # Generate the combined authentication path path = combined_path(sk.top_tree, bottom_tree, slot) @@ -410,7 +402,7 @@ def verify(self, pk: PublicKey, slot: Slot, message: Bytes32, sig: Signature) -> # # Return False instead of raising to avoid panic on invalid signatures. # The slot is attacker-controlled input. - if slot > self.config.LIFETIME: + if slot >= self.config.LIFETIME: return False # Re-encode the message using the randomness `rho` from the signature. 
diff --git a/src/lean_spec/types/boolean.py b/src/lean_spec/types/boolean.py
index 9c90abc9..4043bd3a 100644
--- a/src/lean_spec/types/boolean.py
+++ b/src/lean_spec/types/boolean.py
@@ -114,13 +114,6 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
             raise SSZSerializationError(f"Boolean: expected 1 byte, got {len(data)}")
         return cls.decode_bytes(data)

-    def _raise_type_error(self, other: Any, op_symbol: str) -> None:
-        """Helper to raise a consistent TypeError for unsupported operations."""
-        raise TypeError(
-            f"Unsupported operand type(s) for {op_symbol}: "
-            f"'{type(self).__name__}' and '{type(other).__name__}'"
-        )
-
     def __add__(self, other: Any) -> Self:
         """Disable the addition operator (`+`)."""
         raise TypeError("Arithmetic operations are not supported for Boolean.")
@@ -140,7 +133,10 @@ def __rsub__(self, other: Any) -> Self:
     def __and__(self, other: Any) -> Self:
         """Handle the bitwise AND operator (`&`) strictly."""
         if not isinstance(other, type(self)):
-            self._raise_type_error(other, "&")
+            raise TypeError(
+                f"Unsupported operand type(s) for &: "
+                f"'{type(self).__name__}' and '{type(other).__name__}'"
+            )
         return type(self)(int(self) & int(other))

     def __rand__(self, other: Any) -> Self:
@@ -150,7 +146,10 @@ def __rand__(self, other: Any) -> Self:
     def __or__(self, other: Any) -> Self:
         """Handle the bitwise OR operator (`|`) strictly."""
         if not isinstance(other, type(self)):
-            self._raise_type_error(other, "|")
+            raise TypeError(
+                f"Unsupported operand type(s) for |: "
+                f"'{type(self).__name__}' and '{type(other).__name__}'"
+            )
         return type(self)(int(self) | int(other))

     def __ror__(self, other: Any) -> Self:
@@ -160,7 +159,10 @@ def __ror__(self, other: Any) -> Self:
     def __xor__(self, other: Any) -> Self:
         """Handle the bitwise XOR operator (`^`) strictly."""
         if not isinstance(other, type(self)):
-            self._raise_type_error(other, "^")
+            raise TypeError(
+                f"Unsupported operand type(s) for ^: "
+                f"'{type(self).__name__}' and '{type(other).__name__}'"
+            )
         return type(self)(int(self) ^ int(other))

     def __rxor__(self, other: Any) -> Self:
@@ -181,9 +183,8 @@ def __ne__(self, other: object) -> bool:
         """
         Handle the inequality operator (`!=`).

-        Allows comparison with native `bool` and `int` types (0 or 1).
-
-        It returns `True` for all other types.
+        Must be defined explicitly because `int.__ne__` would bypass
+        our custom `__eq__` type-checking logic.
         """
         return not self.__eq__(other)

diff --git a/src/lean_spec/types/byte_arrays.py b/src/lean_spec/types/byte_arrays.py
index f761b09f..be887c1d 100644
--- a/src/lean_spec/types/byte_arrays.py
+++ b/src/lean_spec/types/byte_arrays.py
@@ -39,8 +39,7 @@ def _coerce_to_bytes(value: Any) -> bytes:
     if isinstance(value, Iterable):
         # bytes(bytearray(iterable)) enforces each element is an int in 0..255
         return bytes(bytearray(value))
-    # Fall back to Python's bytes() constructor (will raise if unsupported)
-    return bytes(value)
+    raise TypeError(f"Cannot coerce {type(value).__name__} to bytes")


 class BaseBytes(bytes, SSZType):
@@ -53,6 +52,8 @@ class BaseBytes(bytes, SSZType):
     Instances are immutable byte objects with strict length checking.
""" + __slots__ = () + LENGTH: ClassVar[int] """The exact number of bytes (overridden by subclasses).""" diff --git a/src/lean_spec/types/exceptions.py b/src/lean_spec/types/exceptions.py index 013edb2e..c1fbe607 100644 --- a/src/lean_spec/types/exceptions.py +++ b/src/lean_spec/types/exceptions.py @@ -14,35 +14,4 @@ class SSZValueError(SSZError): class SSZSerializationError(SSZError): - """Raised for serialization errors (encoding, decoding, stream issues). - - Supports optional context for better error diagnostics: - - - type_name: The SSZ type being processed - - field_name: The field within a container (if applicable) - - offset: The byte offset where the error occurred - """ - - def __init__( - self, - message: str, - *, - type_name: str | None = None, - field_name: str | None = None, - offset: int | None = None, - ) -> None: - """Initialize with message and optional context for better diagnostics.""" - self.type_name = type_name - self.field_name = field_name - self.offset = offset - - context_parts = [] - if type_name: - context_parts.append(f"type={type_name}") - if field_name: - context_parts.append(f"field={field_name}") - if offset is not None: - context_parts.append(f"offset={offset}") - - context = f" [{', '.join(context_parts)}]" if context_parts else "" - super().__init__(f"{message}{context}") + """Raised for serialization errors (encoding, decoding, stream issues).""" diff --git a/src/lean_spec/types/ssz_base.py b/src/lean_spec/types/ssz_base.py index b9f116fc..92c5f919 100644 --- a/src/lean_spec/types/ssz_base.py +++ b/src/lean_spec/types/ssz_base.py @@ -118,5 +118,5 @@ def __repr__(self) -> str: data: Sequence[Any] | None = getattr(self, "data", None) if data is not None: return f"{self.__class__.__name__}(data={list(data)!r})" - field_strs = [f"{name}={getattr(self, name)!r}" for name in type(self).model_fields.keys()] + field_strs = [f"{name}={getattr(self, name)!r}" for name in type(self).model_fields] return f"{self.__class__.__name__}({' '.join(field_strs)})" diff --git a/src/lean_spec/types/uint.py b/src/lean_spec/types/uint.py index b048851d..42c53f6b 100644 --- a/src/lean_spec/types/uint.py +++ b/src/lean_spec/types/uint.py @@ -14,6 +14,8 @@ class BaseUint(int, SSZType): """A base class for custom unsigned integer types that inherits from `int`.""" + __slots__ = () + BITS: ClassVar[int] """The number of bits in the integer (overridden by subclasses).""" @@ -344,13 +346,13 @@ def __rrshift__(self, other: Any) -> Self: return type(self)(int(other) >> int(self)) def __eq__(self, other: object) -> bool: - """Handle the equality operator (`==`)""" + """Handle the equality operator (`==`).""" if not isinstance(other, BaseUint): self._raise_type_error(other, "==") return super().__eq__(other) def __ne__(self, other: object) -> bool: - """Handle the inequality operator (`!=`)""" + """Handle the inequality operator (`!=`).""" if not isinstance(other, BaseUint): self._raise_type_error(other, "!=") return super().__ne__(other) diff --git a/src/lean_spec/types/union.py b/src/lean_spec/types/union.py index 0078f460..9699b5bc 100644 --- a/src/lean_spec/types/union.py +++ b/src/lean_spec/types/union.py @@ -115,11 +115,6 @@ def selected_type(self) -> type[SSZType] | None: """The type class of the currently selected option.""" return self.OPTIONS[self.selector] - @classmethod - def options(cls) -> tuple[type[SSZType] | None, ...]: - """Get the tuple of possible types for this Union.""" - return cls.OPTIONS - @classmethod def is_fixed_size(cls) -> bool: """Union types are always 
variable-size in SSZ.""" @@ -171,7 +166,7 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self: return cls(selector=selector, value=None) # Handle non-None option - if selected_type.is_fixed_size() and hasattr(selected_type, "get_byte_length"): + if selected_type.is_fixed_size(): required_bytes = selected_type.get_byte_length() if remaining_bytes < required_bytes: raise SSZSerializationError( diff --git a/tests/lean_spec/subspecs/chain/test_service.py b/tests/lean_spec/subspecs/chain/test_service.py index c058023f..2576601b 100644 --- a/tests/lean_spec/subspecs/chain/test_service.py +++ b/tests/lean_spec/subspecs/chain/test_service.py @@ -145,7 +145,7 @@ async def capture_sleep(duration: float) -> None: captured_duration = duration with patch("asyncio.sleep", new=capture_sleep): - await chain_service._sleep_until_next_interval() + await chain_service.clock.sleep_until_next_interval() # Should sleep until next interval boundary. expected = float(genesis) + interval_secs - current_time @@ -172,7 +172,7 @@ async def capture_sleep(duration: float) -> None: captured_duration = duration with patch("asyncio.sleep", new=capture_sleep): - await chain_service._sleep_until_next_interval() + await chain_service.clock.sleep_until_next_interval() # At boundary, next boundary is one full interval away. expected = float(MILLISECONDS_PER_INTERVAL) / 1000.0 @@ -198,7 +198,7 @@ async def capture_sleep(duration: float) -> None: captured_duration = duration with patch("asyncio.sleep", new=capture_sleep): - await chain_service._sleep_until_next_interval() + await chain_service.clock.sleep_until_next_interval() # Should sleep until genesis. expected = float(genesis) - current_time diff --git a/tests/lean_spec/subspecs/sync/test_checkpoint_sync.py b/tests/lean_spec/subspecs/sync/test_checkpoint_sync.py index fb94209f..61425fe4 100644 --- a/tests/lean_spec/subspecs/sync/test_checkpoint_sync.py +++ b/tests/lean_spec/subspecs/sync/test_checkpoint_sync.py @@ -22,7 +22,7 @@ class TestStateVerification: async def test_valid_state_passes_verification(self, genesis_state: State) -> None: """Valid state with validators passes verification checks.""" - result = await verify_checkpoint_state(genesis_state) + result = verify_checkpoint_state(genesis_state) assert result is True async def test_state_without_validators_fails_verification(self, genesis_state: State) -> None: @@ -40,7 +40,7 @@ async def test_state_without_validators_fails_verification(self, genesis_state: justifications_validators=genesis_state.justifications_validators, ) - result = await verify_checkpoint_state(empty_state) + result = verify_checkpoint_state(empty_state) assert result is False async def test_state_exceeding_validator_limit_fails(self) -> None: @@ -53,7 +53,7 @@ async def test_state_exceeding_validator_limit_fails(self) -> None: mock_validators.__len__ = MagicMock(return_value=int(VALIDATOR_REGISTRY_LIMIT) + 1) mock_state.validators = mock_validators - result = await verify_checkpoint_state(mock_state) + result = verify_checkpoint_state(mock_state) assert result is False @@ -73,7 +73,7 @@ async def test_client_fetches_and_deserializes_state(self, base_store: Store) -> assert state is not None assert state.slot == Slot(0) - is_valid = await verify_checkpoint_state(state) + is_valid = verify_checkpoint_state(state) assert is_valid is True finally: diff --git a/tests/lean_spec/subspecs/validator/test_service.py b/tests/lean_spec/subspecs/validator/test_service.py index b9c0d394..0d074a40 100644 --- 
a/tests/lean_spec/subspecs/validator/test_service.py +++ b/tests/lean_spec/subspecs/validator/test_service.py @@ -213,7 +213,7 @@ async def capture_sleep(duration: float) -> None: captured_duration = duration with patch("asyncio.sleep", new=capture_sleep): - await service._sleep_until_next_interval() + await service.clock.sleep_until_next_interval() # Should sleep until next interval boundary expected = interval_seconds / 2 @@ -244,7 +244,7 @@ async def capture_sleep(duration: float) -> None: captured_duration = duration with patch("asyncio.sleep", new=capture_sleep): - await service._sleep_until_next_interval() + await service.clock.sleep_until_next_interval() # Should sleep until genesis expected = float(genesis) - current_time # 100 seconds diff --git a/tests/lean_spec/test_cli.py b/tests/lean_spec/test_cli.py index fef1fa40..90a47f22 100644 --- a/tests/lean_spec/test_cli.py +++ b/tests/lean_spec/test_cli.py @@ -18,7 +18,6 @@ from lean_spec.__main__ import ( _init_from_checkpoint, create_anchor_block, - is_enr_string, resolve_bootnode, ) from lean_spec.subspecs.containers import Block, BlockBody @@ -123,48 +122,6 @@ def _make_enr_without_udp(ip_bytes: bytes) -> str: MULTIADDR_IPV6 = "/ip6/::1/udp/9000/quic-v1" -class TestIsEnrString: - """Tests for is_enr_string() detection function.""" - - def test_enr_string_detected(self) -> None: - """Valid ENR prefix returns True.""" - assert is_enr_string("enr:-IS4QHCYrYZbAKW...") is True - - def test_enr_prefix_minimal(self) -> None: - """Minimal ENR prefix 'enr:' returns True.""" - assert is_enr_string("enr:") is True - - def test_enr_with_valid_content(self) -> None: - """Full valid ENR string returns True.""" - assert is_enr_string(ENR_WITH_UDP) is True - - def test_multiaddr_not_detected(self) -> None: - """Multiaddr string returns False.""" - assert is_enr_string(MULTIADDR_IPV4) is False - assert is_enr_string(MULTIADDR_IPV6) is False - - def test_empty_string(self) -> None: - """Empty string returns False.""" - assert is_enr_string("") is False - - def test_enode_not_detected(self) -> None: - """enode:// format returns False.""" - enode = "enode://abc123@127.0.0.1:30303" - assert is_enr_string(enode) is False - - def test_similar_prefix_not_detected(self) -> None: - """Strings with similar but incorrect prefixes return False.""" - assert is_enr_string("ENR:") is False # Case sensitive - assert is_enr_string("enr") is False # Missing colon - assert is_enr_string("enr-") is False # Wrong separator - assert is_enr_string("enrs:") is False # Extra character - - def test_whitespace_prefix_not_detected(self) -> None: - """Whitespace before prefix returns False.""" - assert is_enr_string(" enr:abc") is False - assert is_enr_string("\tenr:abc") is False - - class TestResolveBootnode: """Tests for resolve_bootnode() resolution function.""" @@ -418,7 +375,6 @@ async def test_checkpoint_sync_verification_failure_returns_none(self) -> None: ), patch( "lean_spec.__main__.verify_checkpoint_state", - new_callable=AsyncMock, return_value=False, # Verification fails ), ): diff --git a/tests/lean_spec/types/test_union.py b/tests/lean_spec/types/test_union.py index 297783dc..fd890902 100644 --- a/tests/lean_spec/types/test_union.py +++ b/tests/lean_spec/types/test_union.py @@ -238,11 +238,11 @@ def test_equality_and_hashing() -> None: assert hash(u1) != hash(u3) -def test_options_helper() -> None: - """Test options() class method.""" - assert NumericUnion.options() == (Uint16, Uint32) - assert OptionalNumericUnion.options() == (None, Uint16, Uint32) - assert 
SimpleUnion.options() == (Uint16,) +def test_options_class_var() -> None: + """Test OPTIONS class variable access.""" + assert NumericUnion.OPTIONS == (Uint16, Uint32) + assert OptionalNumericUnion.OPTIONS == (None, Uint16, Uint32) + assert SimpleUnion.OPTIONS == (Uint16,) def test_is_fixed_size_helper() -> None:
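A closing note on the copy pattern this patch standardizes in forkchoice/store.py: `{k: list(v) for k, v in payloads.items()}` copies exactly the two layers that get mutated (the dict and its list values) without deep-copying the immutable proof objects inside, which is what made `copy.deepcopy` unnecessary here. A self-contained sketch of why one level of copying preserves earlier snapshots:

    # Illustrative stand-in for the aggregated-payload maps.
    old: dict[str, list[int]] = {"a": [1, 2]}

    # Copy the dict and each list; the elements are shared but immutable.
    new = {k: list(v) for k, v in old.items()}
    new["a"].append(3)

    assert old["a"] == [1, 2]      # earlier snapshot untouched
    assert new["a"] == [1, 2, 3]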