diff --git a/src/lean_spec/__main__.py b/src/lean_spec/__main__.py index 5f524c76..8edde6ce 100644 --- a/src/lean_spec/__main__.py +++ b/src/lean_spec/__main__.py @@ -41,7 +41,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.gossipsub import GossipTopic from lean_spec.subspecs.networking.reqresp.message import Status -from lean_spec.subspecs.node import Node, NodeConfig, get_local_validator_id +from lean_spec.subspecs.node import Node, NodeConfig from lean_spec.subspecs.ssz.hash import hash_tree_root from lean_spec.subspecs.sync.checkpoint_sync import ( CheckpointSyncError, @@ -281,7 +281,7 @@ async def _init_from_checkpoint( # # The store treats this as the new "genesis" for fork choice purposes. # All blocks before the checkpoint are effectively pruned. - validator_id = get_local_validator_id(validator_registry) + validator_id = validator_registry.primary_index() if validator_registry else None store = Store.get_forkchoice_store(state, anchor_block, validator_id) logger.info( "Initialized from checkpoint at slot %d (finalized=%s)", @@ -487,7 +487,7 @@ async def run_node( block_topic = str(GossipTopic.block(GOSSIP_FORK_DIGEST)) event_source.subscribe_gossip_topic(block_topic) # Subscribe to attestation subnet topics based on local validator id. - validator_id = get_local_validator_id(validator_registry) + validator_id = validator_registry.primary_index() if validator_registry else None if validator_id is None: subnet_id = 0 logger.info("No local validator id; subscribing to attestation subnet %d", subnet_id) diff --git a/src/lean_spec/subspecs/networking/client/__init__.py b/src/lean_spec/subspecs/networking/client/__init__.py index e2e71016..b83b042d 100644 --- a/src/lean_spec/subspecs/networking/client/__init__.py +++ b/src/lean_spec/subspecs/networking/client/__init__.py @@ -13,10 +13,11 @@ Bridges connection events to NetworkService events. """ -from .event_source import LiveNetworkEventSource +from .event_source import EventSource, LiveNetworkEventSource from .reqresp_client import ReqRespClient __all__ = [ + "EventSource", "LiveNetworkEventSource", "ReqRespClient", ] diff --git a/src/lean_spec/subspecs/networking/client/event_source.py b/src/lean_spec/subspecs/networking/client/event_source.py index feaac2a9..8d27b3aa 100644 --- a/src/lean_spec/subspecs/networking/client/event_source.py +++ b/src/lean_spec/subspecs/networking/client/event_source.py @@ -103,6 +103,7 @@ import asyncio import logging from dataclasses import dataclass, field +from typing import Protocol, Self from lean_spec.snappy import SnappyDecompressionError, frame_decompress from lean_spec.subspecs.containers import SignedBlockWithAttestation @@ -162,6 +163,27 @@ logger = logging.getLogger(__name__) +class EventSource(Protocol): + """Protocol for network event sources. + + Defines the minimal interface needed by NetworkService. + LiveNetworkEventSource satisfies this with real network I/O. + MockEventSource satisfies this for testing. + """ + + def __aiter__(self) -> Self: + """Return self as async iterator.""" + ... + + async def __anext__(self) -> NetworkEvent: + """Yield the next network event.""" + ... + + async def publish(self, topic: str, data: bytes) -> None: + """Broadcast a message to all peers on a topic.""" + ... + + class GossipMessageError(Exception): """Raised when a gossip message cannot be processed.""" diff --git a/src/lean_spec/subspecs/networking/service/service.py b/src/lean_spec/subspecs/networking/service/service.py index 2ea79459..5fc59902 100644 --- a/src/lean_spec/subspecs/networking/service/service.py +++ b/src/lean_spec/subspecs/networking/service/service.py @@ -29,7 +29,7 @@ from lean_spec.snappy import frame_compress from lean_spec.subspecs.containers import SignedBlockWithAttestation from lean_spec.subspecs.containers.attestation import SignedAggregatedAttestation, SignedAttestation -from lean_spec.subspecs.networking.client.event_source import LiveNetworkEventSource +from lean_spec.subspecs.networking.client.event_source import EventSource from lean_spec.subspecs.networking.gossipsub.topic import GossipTopic from lean_spec.subspecs.networking.peer import PeerInfo from lean_spec.subspecs.networking.types import ConnectionState @@ -70,7 +70,7 @@ class NetworkService: sync_service: SyncService """Sync service that receives routed events.""" - event_source: LiveNetworkEventSource + event_source: EventSource """Source of network events from the transport layer.""" fork_digest: str = field(default="0x00000000") diff --git a/src/lean_spec/subspecs/node/__init__.py b/src/lean_spec/subspecs/node/__init__.py index d497ebb1..a5d8bcb1 100644 --- a/src/lean_spec/subspecs/node/__init__.py +++ b/src/lean_spec/subspecs/node/__init__.py @@ -1,5 +1,5 @@ """Node orchestrator for the Lean Ethereum consensus client.""" -from .node import Node, NodeConfig, get_local_validator_id +from .node import Node, NodeConfig -__all__ = ["Node", "NodeConfig", "get_local_validator_id"] +__all__ = ["Node", "NodeConfig"] diff --git a/src/lean_spec/subspecs/node/node.py b/src/lean_spec/subspecs/node/node.py index 82178da2..a6a4c2ca 100644 --- a/src/lean_spec/subspecs/node/node.py +++ b/src/lean_spec/subspecs/node/node.py @@ -19,7 +19,11 @@ from lean_spec.subspecs.api import ApiServer, ApiServerConfig from lean_spec.subspecs.chain import SlotClock -from lean_spec.subspecs.chain.config import ATTESTATION_COMMITTEE_COUNT, INTERVALS_PER_SLOT +from lean_spec.subspecs.chain.config import ( + ATTESTATION_COMMITTEE_COUNT, + INTERVALS_PER_SLOT, + SECONDS_PER_SLOT, +) from lean_spec.subspecs.chain.service import ChainService from lean_spec.subspecs.containers import Block, BlockBody, SignedBlockWithAttestation, State from lean_spec.subspecs.containers.attestation import SignedAttestation @@ -29,13 +33,16 @@ from lean_spec.subspecs.containers.validator import ValidatorIndex from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.networking import NetworkService -from lean_spec.subspecs.networking.client.event_source import LiveNetworkEventSource +from lean_spec.subspecs.networking.client.event_source import EventSource from lean_spec.subspecs.ssz.hash import hash_tree_root from lean_spec.subspecs.storage import Database, SQLiteDatabase from lean_spec.subspecs.sync import BlockCache, NetworkRequester, PeerManager, SyncService from lean_spec.subspecs.validator import ValidatorRegistry, ValidatorService from lean_spec.types import Bytes32, Uint64 +_ZERO_TIME = Uint64(0) +"""Default genesis time for database loading when no genesis time is available.""" + @dataclass(frozen=True, slots=True) class NodeConfig: @@ -51,7 +58,7 @@ class NodeConfig: validators: Validators """Initial validator set for genesis state.""" - event_source: LiveNetworkEventSource + event_source: EventSource """Source of network events.""" network: NetworkRequester @@ -95,6 +102,11 @@ class NodeConfig: """ Whether this node functions as an aggregator. + Aggregator selection is static (node-level flag), not VRF-based rotation. + The spec assumes at least one aggregator node exists in the network. + + With ATTESTATION_COMMITTEE_COUNT = 1, all validators share subnet 0. + When True: - The node performs attestation aggregation operations - The ENR advertises aggregator capability to peers @@ -104,20 +116,6 @@ class NodeConfig: """ -def get_local_validator_id(registry: ValidatorRegistry | None) -> ValidatorIndex | None: - """ - Get the validator index for this node. - - For now, returns None as a default for passive nodes or simple setups. - Future implementations will look up keys in the registry. - """ - if registry is None or len(registry) == 0: - return None - - # For simplicity, use the first validator in the registry. - return registry.indices()[0] - - @dataclass(slots=True) class Node: """ @@ -148,6 +146,9 @@ class Node: validator_service: ValidatorService | None = field(default=None) """Optional validator service for block/attestation production.""" + database: Database | None = field(default=None) + """Optional database reference for lifecycle management.""" + _shutdown: asyncio.Event = field(default_factory=asyncio.Event) """Event signaling shutdown request.""" @@ -170,13 +171,17 @@ def from_genesis(cls, config: NodeConfig) -> Node: # The database is optional - nodes can run without persistence. database: Database | None = None if config.database_path is not None: - database = cls._create_database(config.database_path) + database = SQLiteDatabase(config.database_path) # # If database contains valid state, resume from there. # Otherwise, fall through to genesis initialization. - validator_id = get_local_validator_id(config.validator_registry) - store = cls._try_load_from_database(database, validator_id) + validator_id = ( + config.validator_registry.primary_index() if config.validator_registry else None + ) + store = cls._try_load_from_database( + database, validator_id, config.genesis_time, config.time_fn + ) if store is None: # Generate genesis state from validators. @@ -242,7 +247,7 @@ def from_genesis(cls, config: NodeConfig) -> Node: # # SyncService delegates aggregate publishing to NetworkService # via a callback, avoiding a circular dependency. - sync_service._publish_agg_fn = network_service.publish_aggregated_attestation + sync_service.set_publish_agg_fn(network_service.publish_aggregated_attestation) # Create API server if configured api_server: ApiServer | None = None @@ -261,17 +266,20 @@ def from_genesis(cls, config: NodeConfig) -> Node: # Wire callbacks to publish produced blocks/attestations to the network. validator_service: ValidatorService | None = None if config.validator_registry is not None: - # Create a wrapper for publish_attestation that computes the subnet_id - # from the validator_id in the attestation + # These wrappers serve a dual purpose: + # + # 1. Publish to the network so peers receive the block/attestation. + # 2. Process locally so the node's own store reflects what it produced. + # + # Without local processing, the node would not see its own produced + # blocks/attestations in forkchoice until they arrived back via gossip. async def publish_attestation_wrapper(attestation: SignedAttestation) -> None: subnet_id = attestation.validator_id.compute_subnet_id(ATTESTATION_COMMITTEE_COUNT) await network_service.publish_attestation(attestation, subnet_id) - # Also route locally so we can aggregate our own attestation await sync_service.on_gossip_attestation(attestation) async def publish_block_wrapper(block: SignedBlockWithAttestation) -> None: await network_service.publish_block(block) - # Also route locally so we update our own store await sync_service.on_gossip_block(block, peer_id=None) validator_service = ValidatorService( @@ -290,35 +298,32 @@ async def publish_block_wrapper(block: SignedBlockWithAttestation) -> None: network_service=network_service, api_server=api_server, validator_service=validator_service, + database=database, ) - @staticmethod - def _create_database(path: Path | str) -> Database: - """ - Create database instance from path. - - Args: - path: Path to SQLite database file. - - Returns: - Database instance ready for use. - """ - # SQLite handles its own caching at the filesystem level. - return SQLiteDatabase(path) - @staticmethod def _try_load_from_database( database: Database | None, validator_id: ValidatorIndex | None, + genesis_time: Uint64 | None = None, + time_fn: Callable[[], float] = time.time, ) -> Store | None: """ Try to load forkchoice store from existing database state. Returns None if database is empty or unavailable. + Uses wall-clock time to set the store's time field. This ensures that + after a restart, the store reflects actual elapsed time rather than just + the head block's proposal moment. Without this, the store would reject + valid attestations as "too far in future" until the chain service ticks + catch up. + Args: database: Database to load from. validator_id: Validator index for the store instance. + genesis_time: Unix timestamp of genesis (slot 0). + time_fn: Wall-clock time source. Returns: Loaded Store or None if no valid state exists. @@ -345,12 +350,24 @@ def _try_load_from_database( if justified is None or finalized is None: return None + # Compute store time from wall clock to avoid post-restart drift. + # + # Using only the head block's slot would set the store time to the + # block's proposal moment. After a restart, this makes the store + # think it's in the past, rejecting valid attestations as "future". + # Instead, derive time from wall clock, floored by the block's slot. + gt = genesis_time if genesis_time is not None else _ZERO_TIME + elapsed_seconds = Uint64(max(0, int(time_fn()) - int(gt))) + wall_clock_intervals = elapsed_seconds * INTERVALS_PER_SLOT // SECONDS_PER_SLOT + block_intervals = head_block.slot * INTERVALS_PER_SLOT + store_time = max(wall_clock_intervals, block_intervals) + # Reconstruct minimal store from persisted data. # # The store starts with just the head block and state. # Additional blocks can be loaded on demand or via sync. return Store( - time=Uint64(head_block.slot * INTERVALS_PER_SLOT), + time=store_time, config=head_state.config, head=head_root, safe_target=head_root, @@ -383,14 +400,19 @@ async def run(self, *, install_signal_handlers: bool = True) -> None: # A separate task monitors the shutdown signal. # When triggered, it stops all services. # Once services exit, execution completes. - async with asyncio.TaskGroup() as tg: - tg.create_task(self.chain_service.run()) - tg.create_task(self.network_service.run()) - if self.api_server is not None: - tg.create_task(self.api_server.run()) - if self.validator_service is not None: - tg.create_task(self.validator_service.run()) - tg.create_task(self._wait_shutdown()) + # The finally block ensures the database is closed on shutdown. + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(self.chain_service.run()) + tg.create_task(self.network_service.run()) + if self.api_server is not None: + tg.create_task(self.api_server.run()) + if self.validator_service is not None: + tg.create_task(self.validator_service.run()) + tg.create_task(self._wait_shutdown()) + finally: + if self.database is not None: + self.database.close() def _install_signal_handlers(self) -> None: """ diff --git a/src/lean_spec/subspecs/sync/service.py b/src/lean_spec/subspecs/sync/service.py index 6311c1fc..0d6fd044 100644 --- a/src/lean_spec/subspecs/sync/service.py +++ b/src/lean_spec/subspecs/sync/service.py @@ -198,6 +198,17 @@ class SyncService: Same buffering strategy as individual attestations. """ + def set_publish_agg_fn( + self, fn: Callable[[SignedAggregatedAttestation], Coroutine[Any, Any, None]] + ) -> None: + """Wire the aggregated attestation publisher after construction. + + Breaks circular dependency between SyncService and NetworkService. + NetworkService needs SyncService at construction, but SyncService + needs NetworkService's publish method. This setter resolves the cycle. + """ + self._publish_agg_fn = fn + def __post_init__(self) -> None: """Initialize sync components.""" self._init_components() diff --git a/src/lean_spec/subspecs/validator/registry.py b/src/lean_spec/subspecs/validator/registry.py index 88a03dba..3eac698e 100644 --- a/src/lean_spec/subspecs/validator/registry.py +++ b/src/lean_spec/subspecs/validator/registry.py @@ -202,6 +202,21 @@ def indices(self) -> ValidatorIndices: """ return ValidatorIndices(data=list(self._validators.keys())) + def primary_index(self) -> ValidatorIndex | None: + """ + Get the primary validator index for store-level identity. + + Returns the first validator index in the registry. + With ATTESTATION_COMMITTEE_COUNT = 1, all validators share subnet 0, + so a single ID suffices for store-level operations. + + Returns: + First validator index, or None if registry is empty. + """ + if not self._validators: + return None + return next(iter(self._validators)) + def __len__(self) -> int: """Number of validators in the registry.""" return len(self._validators) diff --git a/tests/lean_spec/subspecs/networking/test_network_service.py b/tests/lean_spec/subspecs/networking/test_network_service.py index 6f8f9cb1..9e657f4d 100644 --- a/tests/lean_spec/subspecs/networking/test_network_service.py +++ b/tests/lean_spec/subspecs/networking/test_network_service.py @@ -78,7 +78,7 @@ async def test_block_added_to_store_blocks_dict( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -113,7 +113,7 @@ async def test_store_head_updated_after_block( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -147,7 +147,7 @@ async def test_block_ignored_in_idle_state_store_unchanged( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -194,7 +194,7 @@ async def test_attestation_processed_by_store( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -237,7 +237,7 @@ async def test_attestation_ignored_in_idle_state( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -270,7 +270,7 @@ async def test_peer_status_triggers_idle_to_syncing( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -296,7 +296,7 @@ async def test_peer_status_updates_peer_manager( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -351,7 +351,7 @@ async def test_full_sync_flow_status_then_block( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -397,7 +397,7 @@ async def test_block_before_status_is_ignored( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() @@ -442,7 +442,7 @@ async def test_multiple_blocks_chain_extension( source = MockEventSource(events=events) network_service = NetworkService( sync_service=sync_service, - event_source=source, # type: ignore[arg-type] + event_source=source, ) await network_service.run() diff --git a/tests/lean_spec/subspecs/node/test_node.py b/tests/lean_spec/subspecs/node/test_node.py index 6720f0cc..4a5be8af 100644 --- a/tests/lean_spec/subspecs/node/test_node.py +++ b/tests/lean_spec/subspecs/node/test_node.py @@ -3,11 +3,16 @@ from __future__ import annotations import asyncio -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from lean_spec.subspecs.chain.config import HISTORICAL_ROOTS_LIMIT +from lean_spec.subspecs.api import ApiServerConfig +from lean_spec.subspecs.chain.config import ( + HISTORICAL_ROOTS_LIMIT, + INTERVALS_PER_SLOT, + SECONDS_PER_SLOT, +) from lean_spec.subspecs.containers import ( Block, BlockBody, @@ -27,17 +32,67 @@ ) from lean_spec.subspecs.containers.validator import ValidatorIndex from lean_spec.subspecs.node import Node, NodeConfig +from lean_spec.subspecs.validator import ValidatorRegistry +from lean_spec.subspecs.validator.registry import ValidatorEntry from lean_spec.types import Bytes32, Uint64 from tests.lean_spec.helpers import MockEventSource, MockNetworkRequester, make_validators +GENESIS_TIME = Uint64(1704067200) + + +_DEFAULT_TEST_SLOT = Slot(10) + + +def _make_mock_db_data( + test_slot: Slot = _DEFAULT_TEST_SLOT, +) -> tuple[MagicMock, Block, State, Checkpoint]: + """Build a mock database with consistent block/state/checkpoint data.""" + head_root = Bytes32(b"\x01" * 32) + block = Block( + slot=test_slot, + proposer_index=ValidatorIndex(0), + parent_root=Bytes32.zero(), + state_root=Bytes32.zero(), + body=BlockBody(attestations=AggregatedAttestations(data=[])), + ) + checkpoint = Checkpoint(root=head_root, slot=test_slot) + state = State( + config=Config(genesis_time=GENESIS_TIME), + slot=test_slot, + latest_block_header=BlockHeader( + slot=test_slot, + proposer_index=ValidatorIndex(0), + parent_root=Bytes32.zero(), + state_root=Bytes32.zero(), + body_root=Bytes32.zero(), + ), + latest_justified=checkpoint, + latest_finalized=checkpoint, + historical_block_hashes=HistoricalBlockHashes(data=[Bytes32.zero()]), + justified_slots=JustifiedSlots(data=[]), + validators=Validators(data=[]), + justifications_roots=JustificationRoots( + data=[Bytes32.zero()] * int(HISTORICAL_ROOTS_LIMIT) + ), + justifications_validators=JustificationValidators(data=[]), + ) + + mock_db = MagicMock() + mock_db.get_head_root.return_value = head_root + mock_db.get_block.return_value = block + mock_db.get_state.return_value = state + mock_db.get_justified_checkpoint.return_value = checkpoint + mock_db.get_finalized_checkpoint.return_value = checkpoint + return mock_db, block, state, checkpoint + @pytest.fixture def node_config() -> NodeConfig: """Provide a basic node configuration for tests.""" return NodeConfig( - genesis_time=Uint64(1704067200), + genesis_time=GENESIS_TIME, validators=make_validators(3), - event_source=MockEventSource(), # type: ignore[arg-type] + event_source=MockEventSource(), network=MockNetworkRequester(), time_fn=lambda: 1704067200.0, ) @@ -50,10 +105,8 @@ def test_creates_store_with_genesis_block(self, node_config: NodeConfig) -> None """Store contains genesis block at slot 0.""" node = Node.from_genesis(node_config) - # Store should have exactly one block (genesis) assert len(node.store.blocks) == 1 - # Head should point to a block at slot 0 head_block = node.store.blocks[node.store.head] assert head_block.slot == Slot(0) @@ -74,68 +127,142 @@ def test_services_share_sync_service(self, node_config: NodeConfig) -> None: """ChainService and NetworkService reference same SyncService.""" node = Node.from_genesis(node_config) - # All services should be wired to the same SyncService instance assert node.chain_service.sync_service is node.sync_service assert node.network_service.sync_service is node.sync_service - def test_store_time_from_database_uses_intervals_not_seconds(self) -> None: - """This test verifies the invariant that store time represents intervals.""" - test_slot = Slot(10) - head_root = Bytes32(b"\x01" * 32) +class TestDatabaseLoading: + """Tests for _try_load_from_database.""" - block = Block( - slot=test_slot, - proposer_index=ValidatorIndex(0), - parent_root=Bytes32.zero(), - state_root=Bytes32.zero(), - body=BlockBody(attestations=AggregatedAttestations(data=[])), - ) + def test_returns_none_when_no_database(self) -> None: + """No database returns None.""" + assert Node._try_load_from_database(None, validator_id=None) is None - checkpoint = Checkpoint(root=head_root, slot=test_slot) + def test_returns_none_when_no_head_root(self) -> None: + """Empty database returns None.""" + mock_db = MagicMock() + mock_db.get_head_root.return_value = None - state = State( - config=Config(genesis_time=Uint64(1704067200)), - slot=test_slot, - latest_block_header=BlockHeader( - slot=test_slot, - proposer_index=ValidatorIndex(0), - parent_root=Bytes32.zero(), - state_root=Bytes32.zero(), - body_root=Bytes32.zero(), - ), - latest_justified=checkpoint, - latest_finalized=checkpoint, - historical_block_hashes=HistoricalBlockHashes(data=[Bytes32.zero()]), - justified_slots=JustifiedSlots(data=[]), - validators=Validators(data=[]), - justifications_roots=JustificationRoots( - data=[Bytes32.zero()] * int(HISTORICAL_ROOTS_LIMIT) - ), - justifications_validators=JustificationValidators(data=[]), - ) + assert Node._try_load_from_database(mock_db, validator_id=None) is None - # Simulates loading from an existing database. - # Exercises the code path where time is computed from slot. + def test_returns_none_when_block_missing(self) -> None: + """Missing block returns None.""" mock_db = MagicMock() - mock_db.get_head_root.return_value = head_root - mock_db.get_block.return_value = block - mock_db.get_state.return_value = state - mock_db.get_justified_checkpoint.return_value = checkpoint - mock_db.get_finalized_checkpoint.return_value = checkpoint + mock_db.get_head_root.return_value = Bytes32(b"\x01" * 32) + mock_db.get_block.return_value = None + mock_db.get_state.return_value = MagicMock() + + assert Node._try_load_from_database(mock_db, validator_id=None) is None + + def test_returns_none_when_state_missing(self) -> None: + """Missing state returns None.""" + mock_db = MagicMock() + mock_db.get_head_root.return_value = Bytes32(b"\x01" * 32) + mock_db.get_block.return_value = MagicMock() + mock_db.get_state.return_value = None + + assert Node._try_load_from_database(mock_db, validator_id=None) is None + + def test_returns_none_when_justified_missing(self) -> None: + """Missing justified checkpoint returns None.""" + mock_db, block, state, _ = _make_mock_db_data() + mock_db.get_justified_checkpoint.return_value = None - # Patching to 8 distinguishes from the seconds per slot. - patched_intervals = Uint64(8) - with patch("lean_spec.subspecs.node.node.INTERVALS_PER_SLOT", patched_intervals): - store = Node._try_load_from_database(mock_db, validator_id=ValidatorIndex(0)) + assert Node._try_load_from_database(mock_db, validator_id=None) is None + + def test_returns_none_when_finalized_missing(self) -> None: + """Missing finalized checkpoint returns None.""" + mock_db, block, state, _ = _make_mock_db_data() + mock_db.get_finalized_checkpoint.return_value = None + mock_db.get_justified_checkpoint.return_value = MagicMock() + + assert Node._try_load_from_database(mock_db, validator_id=None) is None + + def test_successful_load_uses_wall_clock_time(self) -> None: + """Store time uses wall clock when it exceeds block-based time.""" + test_slot = Slot(10) + mock_db, _, _, _ = _make_mock_db_data(test_slot) + + # Simulate 100 seconds after genesis (well past slot 10 at 4s/slot = 40s). + wall_time = float(GENESIS_TIME) + 100.0 + store = Node._try_load_from_database( + mock_db, + validator_id=ValidatorIndex(0), + genesis_time=GENESIS_TIME, + time_fn=lambda: wall_time, + ) + + assert store is not None + + # Wall clock: 100s * 5 intervals / 4 seconds = 125 intervals. + expected_wall = Uint64(100) * INTERVALS_PER_SLOT // SECONDS_PER_SLOT + # Block-based: slot 10 * 5 = 50 intervals. + expected_block = test_slot * INTERVALS_PER_SLOT + assert expected_wall > expected_block + assert store.time == expected_wall + + def test_load_uses_block_time_when_wall_clock_behind(self) -> None: + """Store time floors to block-based time if wall clock is behind.""" + test_slot = Slot(100) + mock_db, _, _, _ = _make_mock_db_data(test_slot) + + # Simulate wall clock only 10 seconds after genesis (slot 100 is at 400s). + wall_time = float(GENESIS_TIME) + 10.0 + store = Node._try_load_from_database( + mock_db, + validator_id=ValidatorIndex(0), + genesis_time=GENESIS_TIME, + time_fn=lambda: wall_time, + ) assert store is not None - expected_time = Uint64(test_slot * patched_intervals) - assert store.time == expected_time, ( - f"Store.time should use INTERVALS_PER_SLOT, not SECONDS_PER_SLOT. " - f"Expected time={expected_time} (slot={test_slot} * intervals={patched_intervals}), " - f"got time={store.time}" + + # Block-based: slot 100 * 5 = 500 intervals. + expected_block = test_slot * INTERVALS_PER_SLOT + assert store.time == expected_block + + +class TestOptionalServiceWiring: + """Tests for optional services (API server, validator service).""" + + def test_api_server_created_when_config_provided(self) -> None: + """API server is created when api_config is set.""" + config = NodeConfig( + genesis_time=GENESIS_TIME, + validators=make_validators(3), + event_source=MockEventSource(), + network=MockNetworkRequester(), + time_fn=lambda: 1704067200.0, + api_config=ApiServerConfig(host="127.0.0.1", port=5052), ) + node = Node.from_genesis(config) + assert node.api_server is not None + + def test_api_server_none_when_no_config(self, node_config: NodeConfig) -> None: + """API server is None when api_config is not set.""" + node = Node.from_genesis(node_config) + assert node.api_server is None + + def test_validator_service_created_when_registry_provided(self) -> None: + """Validator service is created when validator_registry is set.""" + registry = ValidatorRegistry() + registry.add(ValidatorEntry(index=ValidatorIndex(0), secret_key=MagicMock())) + + config = NodeConfig( + genesis_time=GENESIS_TIME, + validators=make_validators(3), + event_source=MockEventSource(), + network=MockNetworkRequester(), + time_fn=lambda: 1704067200.0, + validator_registry=registry, + ) + node = Node.from_genesis(config) + assert node.validator_service is not None + + def test_validator_service_none_when_no_registry(self, node_config: NodeConfig) -> None: + """Validator service is None when no registry is set.""" + node = Node.from_genesis(node_config) + assert node.validator_service is None class TestNodeShutdown: @@ -157,6 +284,103 @@ def test_is_running_reflects_shutdown_state(self, node_config: NodeConfig) -> No node.stop() assert node.is_running is False + async def test_wait_shutdown_stops_chain_service(self, node_config: NodeConfig) -> None: + """Shutdown stops the chain service.""" + node = Node.from_genesis(node_config) + asyncio.get_running_loop().call_later(0.05, node.stop) + await node.run(install_signal_handlers=False) + + assert node.chain_service._running is False + + async def test_wait_shutdown_stops_network_service(self, node_config: NodeConfig) -> None: + """Shutdown stops the network service.""" + node = Node.from_genesis(node_config) + asyncio.get_running_loop().call_later(0.05, node.stop) + await node.run(install_signal_handlers=False) + + assert node.network_service._running is False + + async def test_database_closed_after_run(self) -> None: + """Database is closed after run exits.""" + mock_db = MagicMock() + config = NodeConfig( + genesis_time=GENESIS_TIME, + validators=make_validators(3), + event_source=MockEventSource(), + network=MockNetworkRequester(), + time_fn=lambda: 1704067200.0, + ) + node = Node.from_genesis(config) + node.database = mock_db + + asyncio.get_running_loop().call_later(0.05, node.stop) + await node.run(install_signal_handlers=False) + + mock_db.close.assert_called_once() + + +class TestGenesisPersistence: + """Tests for genesis block/state persistence to database.""" + + def test_from_genesis_persists_block_to_database(self) -> None: + """Genesis block is persisted to the database.""" + config = NodeConfig( + genesis_time=GENESIS_TIME, + validators=make_validators(3), + event_source=MockEventSource(), + network=MockNetworkRequester(), + time_fn=lambda: 1704067200.0, + database_path=":memory:", + ) + node = Node.from_genesis(config) + + # The database should have the genesis block. + assert node.database is not None + head_root = node.database.get_head_root() + assert head_root is not None + block = node.database.get_block(head_root) + assert block is not None + assert block.slot == Slot(0) + + def test_from_genesis_persists_state_to_database(self) -> None: + """Genesis state is persisted to the database.""" + config = NodeConfig( + genesis_time=GENESIS_TIME, + validators=make_validators(3), + event_source=MockEventSource(), + network=MockNetworkRequester(), + time_fn=lambda: 1704067200.0, + database_path=":memory:", + ) + node = Node.from_genesis(config) + + assert node.database is not None + head_root = node.database.get_head_root() + assert head_root is not None + state = node.database.get_state(head_root) + assert state is not None + + def test_from_genesis_persists_checkpoints(self) -> None: + """Justified and finalized checkpoints are persisted.""" + config = NodeConfig( + genesis_time=GENESIS_TIME, + validators=make_validators(3), + event_source=MockEventSource(), + network=MockNetworkRequester(), + time_fn=lambda: 1704067200.0, + database_path=":memory:", + ) + node = Node.from_genesis(config) + + assert node.database is not None + assert node.database.get_justified_checkpoint() is not None + assert node.database.get_finalized_checkpoint() is not None + + def test_from_genesis_without_database_skips_persistence(self, node_config: NodeConfig) -> None: + """No database means no persistence calls.""" + node = Node.from_genesis(node_config) + assert node.database is None + class TestNodeIntegration: """Integration tests for Node orchestration.""" @@ -165,20 +389,15 @@ async def test_run_exits_on_stop(self, node_config: NodeConfig) -> None: """Node.run() exits cleanly when stop() is called.""" node = Node.from_genesis(node_config) - # Schedule stop after a short delay asyncio.get_running_loop().call_later(0.05, node.stop) await node.run(install_signal_handlers=False) - # Should complete without hanging or raising - def test_sync_service_receives_store_from_genesis(self, node_config: NodeConfig) -> None: """Sync service has access to the genesis store.""" node = Node.from_genesis(node_config) - # SyncService should have the same store as the node assert node.sync_service.store is not None assert len(node.sync_service.store.blocks) == 1 - # The store's head should be the genesis block head_block = node.sync_service.store.blocks[node.sync_service.store.head] assert head_block.slot == Slot(0) diff --git a/tests/lean_spec/subspecs/validator/test_registry.py b/tests/lean_spec/subspecs/validator/test_registry.py index 9ab01c95..1967fb88 100644 --- a/tests/lean_spec/subspecs/validator/test_registry.py +++ b/tests/lean_spec/subspecs/validator/test_registry.py @@ -66,6 +66,27 @@ def test_add_multiple_entries(self) -> None: ValidatorIndex(4): key_4, } + def test_primary_index_returns_none_for_empty_registry(self) -> None: + """Empty registry returns None for primary index.""" + registry = ValidatorRegistry() + assert registry.primary_index() is None + + def test_primary_index_returns_first_index(self) -> None: + """Primary index returns the first validator index.""" + registry = ValidatorRegistry() + key = MagicMock(name="key_5") + registry.add(ValidatorEntry(index=ValidatorIndex(5), secret_key=key)) + + assert registry.primary_index() == ValidatorIndex(5) + + def test_primary_index_with_multiple_validators(self) -> None: + """Primary index returns the first inserted index.""" + registry = ValidatorRegistry() + registry.add(ValidatorEntry(index=ValidatorIndex(3), secret_key=MagicMock())) + registry.add(ValidatorEntry(index=ValidatorIndex(1), secret_key=MagicMock())) + + assert registry.primary_index() == ValidatorIndex(3) + def test_from_secret_keys(self) -> None: """Registry from dict preserves exact index-to-key mapping.""" key_0 = MagicMock(name="key_0")