diff --git a/lean_consensus.pdf b/lean_consensus.pdf index d4460c0c..015c96e2 100644 Binary files a/lean_consensus.pdf and b/lean_consensus.pdf differ diff --git a/src/lean_spec/subspecs/networking/discovery/codec.py b/src/lean_spec/subspecs/networking/discovery/codec.py index 9313a561..685930b9 100644 --- a/src/lean_spec/subspecs/networking/discovery/codec.py +++ b/src/lean_spec/subspecs/networking/discovery/codec.py @@ -20,6 +20,8 @@ from __future__ import annotations +import os + from lean_spec.subspecs.networking.types import SeqNumber from lean_spec.types import Uint64, decode_rlp, encode_rlp from lean_spec.types.rlp import RLPDecodingError @@ -298,6 +300,4 @@ def _decode_talkresp(payload: bytes) -> TalkResp: def generate_request_id() -> RequestId: """Generate a random request ID.""" - import os - return RequestId(data=os.urandom(8)) diff --git a/src/lean_spec/subspecs/networking/discovery/handshake.py b/src/lean_spec/subspecs/networking/discovery/handshake.py index 18de6cc7..b57f3cd8 100644 --- a/src/lean_spec/subspecs/networking/discovery/handshake.py +++ b/src/lean_spec/subspecs/networking/discovery/handshake.py @@ -253,6 +253,8 @@ def create_handshake_response( whoareyou: WhoAreYouAuthdata, remote_pubkey: bytes, challenge_data: bytes, + remote_ip: str = "", + remote_port: int = 0, ) -> tuple[bytes, bytes, bytes]: """ Create a HANDSHAKE packet in response to WHOAREYOU. @@ -265,6 +267,8 @@ def create_handshake_response( remote_pubkey: Remote's 33-byte compressed public key. challenge_data: Full WHOAREYOU data for key derivation (masking-iv || static-header || authdata from received packet). + remote_ip: Remote peer's IP address for session keying. + remote_port: Remote peer's UDP port for session keying. Returns: Tuple of (authdata, send_key, recv_key). @@ -312,12 +316,14 @@ def create_handshake_response( is_initiator=True, ) - # Store session. + # Store session keyed by (node_id, ip, port). self._session_cache.create( node_id=remote_node_id, send_key=send_key, recv_key=recv_key, is_initiator=True, + ip=remote_ip, + port=remote_port, ) # Clean up pending handshake. @@ -330,6 +336,8 @@ def handle_handshake( self, remote_node_id: bytes, handshake: HandshakeAuthdata, + remote_ip: str = "", + remote_port: int = 0, ) -> HandshakeResult: """ Process a received HANDSHAKE packet. @@ -339,6 +347,8 @@ def handle_handshake( Args: remote_node_id: 32-byte node ID from packet source. handshake: Decoded HANDSHAKE authdata. + remote_ip: Remote peer's IP address for session keying. + remote_port: Remote peer's UDP port for session keying. Returns: HandshakeResult with established session. @@ -409,12 +419,14 @@ def handle_handshake( is_initiator=False, ) - # Create session. + # Create session keyed by (node_id, ip, port). session = self._session_cache.create( node_id=remote_node_id, send_key=send_key, recv_key=recv_key, is_initiator=False, + ip=remote_ip, + port=remote_port, ) # Clean up pending handshake. diff --git a/src/lean_spec/subspecs/networking/discovery/messages.py b/src/lean_spec/subspecs/networking/discovery/messages.py index 7edde1c9..ff1720bd 100644 --- a/src/lean_spec/subspecs/networking/discovery/messages.py +++ b/src/lean_spec/subspecs/networking/discovery/messages.py @@ -291,20 +291,3 @@ class StaticHeader(StrictBaseModel): authdata_size: Uint16 """Byte length of the authdata section following this header.""" - - -class WhoAreYouAuthdata(StrictBaseModel): - """ - Authdata for WHOAREYOU packets (flag=1). - - Sent when the recipient cannot decrypt an incoming message packet. - The nonce in the packet header is set to the nonce of the failed message. - - Total size: 24 bytes (16 + 8). - """ - - id_nonce: IdNonce - """128-bit random value for identity verification.""" - - enr_seq: SeqNumber - """Recipient's known ENR sequence for the sender. 0 if unknown.""" diff --git a/src/lean_spec/subspecs/networking/discovery/packet.py b/src/lean_spec/subspecs/networking/discovery/packet.py index 6400399d..5962f053 100644 --- a/src/lean_spec/subspecs/networking/discovery/packet.py +++ b/src/lean_spec/subspecs/networking/discovery/packet.py @@ -32,7 +32,6 @@ import os import struct from dataclasses import dataclass -from enum import IntEnum from lean_spec.types import Bytes12, Bytes16, Uint64 @@ -63,14 +62,6 @@ """Fixed portion of handshake authdata: src-id (32) + sig-size (1) + eph-key-size (1).""" -class PacketType(IntEnum): - """Packet type aliases matching PacketFlag for clarity.""" - - MESSAGE = 0 - WHOAREYOU = 1 - HANDSHAKE = 2 - - @dataclass(frozen=True, slots=True) class PacketHeader: """Decoded packet header.""" @@ -135,6 +126,7 @@ def encode_packet( authdata: bytes, message: bytes, encryption_key: bytes | None = None, + masking_iv: Bytes16 | None = None, ) -> bytes: """ Encode a Discovery v5 packet. @@ -147,6 +139,8 @@ def encode_packet( authdata: Authentication data (varies by packet type). message: Message payload (plaintext for WHOAREYOU, encrypted otherwise). encryption_key: 16-byte key for message encryption (None for WHOAREYOU). + masking_iv: Optional 16-byte IV for header masking. Random if not provided. + Must be provided for WHOAREYOU to match the IV used in challenge_data. Returns: Complete encoded packet ready for UDP transmission. @@ -156,13 +150,14 @@ def encode_packet( if len(nonce) != GCM_NONCE_SIZE: raise ValueError(f"Nonce must be {GCM_NONCE_SIZE} bytes, got {len(nonce)}") - # Fresh random IV for header masking. - # - # Using dest_node_id as the masking key is deterministic, - # so the IV MUST be random to prevent ciphertext patterns. - # Without randomness, identical packets would produce - # identical masked headers, enabling traffic analysis. - masking_iv = Bytes16(os.urandom(CTR_IV_SIZE)) + if masking_iv is None: + # Fresh random IV for header masking. + # + # Using dest_node_id as the masking key is deterministic, + # so the IV MUST be random to prevent ciphertext patterns. + # Without randomness, identical packets would produce + # identical masked headers, enabling traffic analysis. + masking_iv = Bytes16(os.urandom(CTR_IV_SIZE)) static_header = _encode_static_header(flag, nonce, len(authdata)) header = static_header + authdata @@ -182,16 +177,17 @@ def encode_packet( if encryption_key is None: raise ValueError("Encryption key required for non-WHOAREYOU packets") - # Masked header as AAD prevents header tampering. + # Per spec: message-ad = masking-iv || header (plaintext). # - # The recipient verifies the header wasn't modified - # without having to decrypt the payload first. + # The AAD binds the plaintext header to the encrypted message. + # The recipient reconstructs this from the decoded header. + message_ad = bytes(masking_iv) + header encrypted_message = aes_gcm_encrypt( - Bytes16(encryption_key), Bytes12(nonce), message, masked_header + Bytes16(encryption_key), Bytes12(nonce), message, message_ad ) # Assemble packet. - packet = masking_iv + masked_header + encrypted_message + packet = bytes(masking_iv) + masked_header + encrypted_message if len(packet) > MAX_PACKET_SIZE: raise ValueError(f"Packet exceeds max size: {len(packet)} > {MAX_PACKET_SIZE}") @@ -199,7 +195,7 @@ def encode_packet( return packet -def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeader, bytes]: +def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeader, bytes, bytes]: """ Decode and unmask a Discovery v5 packet header. @@ -208,7 +204,8 @@ def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeade data: Raw packet bytes. Returns: - Tuple of (header, message_bytes). + Tuple of (header, message_bytes, message_ad). + message_ad is masking-iv || plaintext header, used as AAD for decryption. Raises: ValueError: If packet is malformed. @@ -224,8 +221,7 @@ def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeade masked_data = data[CTR_IV_SIZE:] # Decrypt static header first to get authdata size. - static_header_masked = masked_data[:STATIC_HEADER_SIZE] - static_header = aes_ctr_decrypt(masking_key, masking_iv, static_header_masked) + static_header = aes_ctr_decrypt(masking_key, masking_iv, masked_data[:STATIC_HEADER_SIZE]) # Parse static header. protocol_id = static_header[:6] @@ -245,15 +241,19 @@ def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeade if len(data) < header_end: raise ValueError(f"Packet truncated: need {header_end}, have {len(data)}") - # Decrypt the full header including authdata. - full_masked_header = masked_data[: STATIC_HEADER_SIZE + authdata_size] - full_header = aes_ctr_decrypt(masking_key, masking_iv, full_masked_header) + # Decrypt the full header (static header + authdata) in one pass. + full_header = aes_ctr_decrypt( + masking_key, masking_iv, masked_data[: STATIC_HEADER_SIZE + authdata_size] + ) authdata = full_header[STATIC_HEADER_SIZE:] # Message bytes are everything after the header. message_bytes = data[header_end:] - return PacketHeader(flag=flag, nonce=nonce, authdata=authdata), message_bytes + # Per spec: message-ad = masking-iv || header (plaintext). + message_ad = bytes(masking_iv) + full_header + + return PacketHeader(flag=flag, nonce=nonce, authdata=authdata), message_bytes, message_ad def decode_message_authdata(authdata: bytes) -> MessageAuthdata: @@ -310,7 +310,7 @@ def decrypt_message( encryption_key: bytes, nonce: bytes, ciphertext: bytes, - masked_header: bytes, + message_ad: bytes, ) -> bytes: """ Decrypt an encrypted message payload. @@ -319,12 +319,12 @@ def decrypt_message( encryption_key: 16-byte session key. nonce: 12-byte nonce from packet header. ciphertext: Encrypted message with GCM tag. - masked_header: Masked header bytes (used as AAD). + message_ad: Additional authenticated data (masking-iv || plaintext header). Returns: Decrypted message plaintext. """ - return aes_gcm_decrypt(Bytes16(encryption_key), Bytes12(nonce), ciphertext, masked_header) + return aes_gcm_decrypt(Bytes16(encryption_key), Bytes12(nonce), ciphertext, message_ad) def encode_message_authdata(src_id: bytes) -> bytes: diff --git a/src/lean_spec/subspecs/networking/discovery/routing.py b/src/lean_spec/subspecs/networking/discovery/routing.py index eb4e2f19..92574268 100644 --- a/src/lean_spec/subspecs/networking/discovery/routing.py +++ b/src/lean_spec/subspecs/networking/discovery/routing.py @@ -4,7 +4,7 @@ Kademlia-style routing table for Node Discovery Protocol v5.1. Node Table Structure --------------------- + Nodes keep information about other nodes in their neighborhood. Neighbor nodes are stored in a routing table consisting of 'k-buckets'. For each 0 <= i < 256, every node keeps a k-bucket for nodes of logdistance(self, n) == i. @@ -14,7 +14,7 @@ seen at tail. Distance Metric ---------------- + The 'distance' between two node IDs is the bitwise XOR of the IDs, interpreted as a big-endian number: @@ -26,7 +26,7 @@ logdistance(n1, n2) = log2(distance(n1, n2)) Bucket Eviction Policy ----------------------- + When a new node N1 is encountered, it can be inserted into the corresponding bucket. @@ -37,7 +37,7 @@ removed, and N1 added to the front of the bucket. Liveness Verification ---------------------- + Implementations should perform liveness checks asynchronously and occasionally verify that a random node in a random bucket is live by sending PING. When responding to FINDNODE, implementations must avoid relaying any nodes whose @@ -148,13 +148,13 @@ class KBucket: - New nodes added to tail, eviction candidates at head Eviction Policy - --------------- + When full, ping the head node (least-recently seen). - If it responds, keep it and discard the new node. - If it fails, evict it and add the new node. Replacement Cache - ----------------- + Implementations should maintain a 'replacement cache' alongside each bucket. This cache holds recently-seen nodes which would fall into the corresponding bucket but cannot become a member because it is at capacity. Once a bucket @@ -257,7 +257,6 @@ class RoutingTable: Bucket i contains nodes with log2(distance) == i + 1. Fork Filtering - -------------- When local_fork_digest is set: @@ -266,7 +265,6 @@ class RoutingTable: - Requires eth2 ENR data to be present Lookup Algorithm - ---------------- Locates the k closest nodes to a target ID: @@ -277,7 +275,6 @@ class RoutingTable: 5. Stop when k closest have been queried Table Maintenance - ----------------- - Track close neighbors - Regularly refresh stale buckets diff --git a/src/lean_spec/subspecs/networking/discovery/service.py b/src/lean_spec/subspecs/networking/discovery/service.py index 04a8524d..3d60e7c3 100644 --- a/src/lean_spec/subspecs/networking/discovery/service.py +++ b/src/lean_spec/subspecs/networking/discovery/service.py @@ -24,7 +24,9 @@ from __future__ import annotations import asyncio +import ipaddress import logging +import os import random from dataclasses import dataclass from typing import TYPE_CHECKING, Callable @@ -558,8 +560,6 @@ async def _refresh_loop(self) -> None: await asyncio.sleep(REFRESH_INTERVAL_SECS) try: # Perform lookup for random target. - import os - target = os.urandom(32) await self.find_node(target) except Exception as e: @@ -602,18 +602,7 @@ def _encode_ip_address(self, ip_str: str) -> bytes: Returns: Raw bytes representation of the IP address. """ - import ipaddress - - try: - # Try IPv4 first. - addr = ipaddress.ip_address(ip_str) - return addr.packed - except ValueError: - # Fall back to returning as-is if somehow already bytes. - if isinstance(ip_str, bytes): - return ip_str - # Last resort: encode as UTF-8 (shouldn't happen with valid IPs). - return ip_str.encode() + return ipaddress.ip_address(ip_str).packed def _enr_to_entry(self, enr: ENR) -> NodeEntry: """Convert an ENR to a NodeEntry.""" diff --git a/src/lean_spec/subspecs/networking/discovery/session.py b/src/lean_spec/subspecs/networking/discovery/session.py index 67d820ea..04ce3251 100644 --- a/src/lean_spec/subspecs/networking/discovery/session.py +++ b/src/lean_spec/subspecs/networking/discovery/session.py @@ -69,17 +69,26 @@ def touch(self) -> None: self.last_seen = time.time() +SessionKey = tuple[bytes, str, int] +"""Session cache key: (node_id, ip, port). + +Per spec, sessions are tied to a specific UDP endpoint. +This prevents session confusion if a node changes IP or port. +""" + + @dataclass class SessionCache: """ Cache of active sessions with peers. Thread-safe session storage with automatic expiration cleanup. - Sessions are keyed by node ID. + Sessions are keyed by (node_id, ip, port) per spec requirement + that sessions are tied to a specific UDP endpoint. """ - sessions: dict[bytes, Session] = field(default_factory=dict) - """Node ID -> Session mapping.""" + sessions: dict[SessionKey, Session] = field(default_factory=dict) + """(node_id, ip, port) -> Session mapping.""" timeout_secs: float = DEFAULT_SESSION_TIMEOUT_SECS """Session expiration timeout.""" @@ -90,25 +99,28 @@ class SessionCache: _lock: Lock = field(default_factory=Lock) """Thread safety lock.""" - def get(self, node_id: bytes) -> Session | None: + def get(self, node_id: bytes, ip: str = "", port: int = 0) -> Session | None: """ - Get an active session for a node. + Get an active session for a node at a specific endpoint. Returns None if no session exists or if it has expired. Args: node_id: 32-byte peer node ID. + ip: Peer IP address. + port: Peer UDP port. Returns: Active session or None. """ + key: SessionKey = (node_id, ip, port) with self._lock: - session = self.sessions.get(node_id) + session = self.sessions.get(key) if session is None: return None if session.is_expired(self.timeout_secs): - del self.sessions[node_id] + del self.sessions[key] return None return session @@ -119,11 +131,13 @@ def create( send_key: bytes, recv_key: bytes, is_initiator: bool, + ip: str = "", + port: int = 0, ) -> Session: """ Create and store a new session. - If a session already exists for this node, it is replaced. + If a session already exists for this endpoint, it is replaced. If the cache is full, the oldest session is evicted. Args: @@ -131,6 +145,8 @@ def create( send_key: 16-byte encryption key for outgoing messages. recv_key: 16-byte decryption key for incoming messages. is_initiator: True if we initiated the handshake. + ip: Peer IP address. + port: Peer UDP port. Returns: The newly created session. @@ -142,6 +158,7 @@ def create( if len(recv_key) != 16: raise ValueError(f"Recv key must be 16 bytes, got {len(recv_key)}") + key: SessionKey = (node_id, ip, port) now = time.time() session = Session( node_id=node_id, @@ -154,40 +171,45 @@ def create( with self._lock: # Evict oldest if at capacity. - if len(self.sessions) >= self.max_sessions and node_id not in self.sessions: + if len(self.sessions) >= self.max_sessions and key not in self.sessions: self._evict_oldest() - self.sessions[node_id] = session + self.sessions[key] = session return session - def remove(self, node_id: bytes) -> bool: + def remove(self, node_id: bytes, ip: str = "", port: int = 0) -> bool: """ Remove a session. Args: node_id: 32-byte peer node ID. + ip: Peer IP address. + port: Peer UDP port. Returns: True if session was removed, False if not found. """ + key: SessionKey = (node_id, ip, port) with self._lock: - if node_id in self.sessions: - del self.sessions[node_id] + if key in self.sessions: + del self.sessions[key] return True return False - def touch(self, node_id: bytes) -> bool: + def touch(self, node_id: bytes, ip: str = "", port: int = 0) -> bool: """ Update the last_seen timestamp for a session. Args: node_id: 32-byte peer node ID. + ip: Peer IP address. + port: Peer UDP port. Returns: True if session was updated, False if not found. """ - session = self.get(node_id) + session = self.get(node_id, ip, port) if session is not None: session.touch() return True @@ -202,12 +224,12 @@ def cleanup_expired(self) -> int: """ with self._lock: expired = [ - node_id - for node_id, session in self.sessions.items() + key + for key, session in self.sessions.items() if session.is_expired(self.timeout_secs) ] - for node_id in expired: - del self.sessions[node_id] + for key in expired: + del self.sessions[key] return len(expired) def count(self) -> int: @@ -220,8 +242,8 @@ def _evict_oldest(self) -> None: if not self.sessions: return - oldest_id = min(self.sessions, key=lambda k: self.sessions[k].created_at) - del self.sessions[oldest_id] + oldest_key = min(self.sessions, key=lambda k: self.sessions[k].created_at) + del self.sessions[oldest_key] @dataclass diff --git a/src/lean_spec/subspecs/networking/discovery/transport.py b/src/lean_spec/subspecs/networking/discovery/transport.py index 5eb7119e..f9f489f6 100644 --- a/src/lean_spec/subspecs/networking/discovery/transport.py +++ b/src/lean_spec/subspecs/networking/discovery/transport.py @@ -17,10 +17,12 @@ import asyncio import logging +import os +import struct from dataclasses import dataclass from typing import TYPE_CHECKING, Callable -from lean_spec.types import Uint64 +from lean_spec.types import Bytes16, Uint64 from .codec import ( DiscoveryMessage, @@ -30,8 +32,21 @@ ) from .config import DiscoveryConfig from .handshake import HandshakeManager -from .messages import FindNode, Nonce, PacketFlag, Ping, Pong, TalkReq +from .messages import ( + PROTOCOL_ID, + PROTOCOL_VERSION, + Distance, + FindNode, + Nodes, + Nonce, + PacketFlag, + Ping, + Pong, + TalkReq, + TalkResp, +) from .packet import ( + PacketHeader, decode_handshake_authdata, decode_message_authdata, decode_packet_header, @@ -330,8 +345,6 @@ async def send_findnode( Returns: List of RLP-encoded ENRs from all NODES responses. """ - from .messages import Distance, Nodes - request_id = generate_request_id() findnode = FindNode( request_id=request_id, @@ -369,49 +382,16 @@ async def _send_multi_response_request( Returns: List of response messages (may be empty on timeout/error). """ - from .messages import Nodes - if self._transport is None: raise RuntimeError("Transport not started") # Register address for responses. self._node_addresses[dest_node_id] = dest_addr - # Get or create session. - session = self._session_cache.get(dest_node_id) + # Build and send packet. nonce = generate_nonce() - - # Encode message. message_bytes = encode_message(message) - - if session is not None: - authdata = encode_message_authdata(self._local_node_id) - packet = encode_packet( - dest_node_id=dest_node_id, - src_node_id=self._local_node_id, - flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), - authdata=authdata, - message=message_bytes, - encryption_key=session.send_key, - ) - else: - # Trigger handshake via deliberate decryption failure. - self._handshake_manager.start_handshake(dest_node_id) - authdata = encode_message_authdata(self._local_node_id) - - import os - - dummy_key = os.urandom(16) - packet = encode_packet( - dest_node_id=dest_node_id, - src_node_id=self._local_node_id, - flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), - authdata=authdata, - message=message_bytes, - encryption_key=dummy_key, - ) + packet = self._build_and_send_packet(dest_node_id, dest_addr, nonce, message_bytes) # Create collector for multiple responses. loop = asyncio.get_running_loop() @@ -488,8 +468,6 @@ async def send_talkreq( Returns: Response payload or None on timeout/error. """ - from .messages import TalkResp - request_id = generate_request_id() talkreq = TalkReq( request_id=request_id, @@ -519,53 +497,10 @@ async def _send_request( # Register address for responses. self._node_addresses[dest_node_id] = dest_addr - # Get or create session. - session = self._session_cache.get(dest_node_id) + # Build and send packet. nonce = generate_nonce() - - # Encode message. message_bytes = encode_message(message) - - if session is not None: - # Have session, send encrypted message. - authdata = encode_message_authdata(self._local_node_id) - packet = encode_packet( - dest_node_id=dest_node_id, - src_node_id=self._local_node_id, - flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), - authdata=authdata, - message=message_bytes, - encryption_key=session.send_key, - ) - else: - # Deliberate decryption failure triggers handshake. - # - # Discovery v5's handshake is initiated by failure: - # - # 1. We send a MESSAGE with random encryption key - # 2. Recipient cannot decrypt (they don't have the key) - # 3. Recipient responds with WHOAREYOU challenge - # 4. We complete handshake with HANDSHAKE packet - # - # This approach avoids the need for session negotiation - # before sending the first message. - self._handshake_manager.start_handshake(dest_node_id) - - authdata = encode_message_authdata(self._local_node_id) - - import os - - dummy_key = os.urandom(16) - packet = encode_packet( - dest_node_id=dest_node_id, - src_node_id=self._local_node_id, - flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), - authdata=authdata, - message=message_bytes, - encryption_key=dummy_key, - ) + packet = self._build_and_send_packet(dest_node_id, dest_addr, nonce, message_bytes) # Create pending request. loop = asyncio.get_running_loop() @@ -596,25 +531,84 @@ async def _send_request( finally: self._pending_requests.pop(request_id_bytes, None) + def _build_and_send_packet( + self, + dest_node_id: bytes, + dest_addr: tuple[str, int], + nonce: Nonce, + message_bytes: bytes, + ) -> bytes: + """ + Build a MESSAGE packet, using session key if available or a dummy key + to trigger handshake. + + Args: + dest_node_id: 32-byte destination node ID. + dest_addr: (ip, port) tuple. + nonce: 12-byte message nonce. + message_bytes: Encoded message payload. + + Returns: + Encoded packet bytes. + """ + ip, port = dest_addr + session = self._session_cache.get(dest_node_id, ip, port) + + authdata = encode_message_authdata(self._local_node_id) + + if session is not None: + return encode_packet( + dest_node_id=dest_node_id, + src_node_id=self._local_node_id, + flag=PacketFlag.MESSAGE, + nonce=bytes(nonce), + authdata=authdata, + message=message_bytes, + encryption_key=session.send_key, + ) + + # Deliberate decryption failure triggers handshake. + # + # Discovery v5's handshake is initiated by failure: + # + # 1. We send a MESSAGE with random encryption key + # 2. Recipient cannot decrypt (they don't have the key) + # 3. Recipient responds with WHOAREYOU challenge + # 4. We complete handshake with HANDSHAKE packet + # + # This approach avoids the need for session negotiation + # before sending the first message. + self._handshake_manager.start_handshake(dest_node_id) + dummy_key = os.urandom(16) + return encode_packet( + dest_node_id=dest_node_id, + src_node_id=self._local_node_id, + flag=PacketFlag.MESSAGE, + nonce=bytes(nonce), + authdata=authdata, + message=message_bytes, + encryption_key=dummy_key, + ) + async def _handle_packet(self, data: bytes, addr: tuple[str, int]) -> None: """Handle a received UDP packet.""" try: # Decode packet header. - header, message_bytes = decode_packet_header(self._local_node_id, data) + header, message_bytes, message_ad = decode_packet_header(self._local_node_id, data) if header.flag == PacketFlag.WHOAREYOU: await self._handle_whoareyou(header, message_bytes, addr, data) elif header.flag == PacketFlag.HANDSHAKE: - await self._handle_handshake(header, message_bytes, addr, data) + await self._handle_handshake(header, message_bytes, addr, message_ad) else: - await self._handle_message(header, message_bytes, addr, data) + await self._handle_message(header, message_bytes, addr, message_ad) except Exception as e: logger.debug("Error handling packet from %s: %s", addr, e) async def _handle_whoareyou( self, - header, + header: PacketHeader, message_bytes: bytes, addr: tuple[str, int], raw_packet: bytes, @@ -661,10 +655,6 @@ async def _handle_whoareyou( # - authdata: 24 bytes (id-nonce 16 + enr-seq 8) # # We use the unmasked header, which we can reconstruct from the decoded values. - import struct - - from .messages import PROTOCOL_ID, PROTOCOL_VERSION - masking_iv = raw_packet[:16] static_header = ( PROTOCOL_ID @@ -688,11 +678,14 @@ async def _handle_whoareyou( # Build and send the HANDSHAKE response. try: + ip, port = addr authdata, send_key, recv_key = self._handshake_manager.create_handshake_response( remote_node_id=remote_node_id, whoareyou=whoareyou, remote_pubkey=remote_pubkey, challenge_data=challenge_data, + remote_ip=ip, + remote_port=port, ) # Re-send the original message, now encrypted with the new session key. @@ -722,29 +715,29 @@ async def _handle_whoareyou( async def _handle_handshake( self, - header, + header: PacketHeader, message_bytes: bytes, addr: tuple[str, int], - raw_packet: bytes, + message_ad: bytes, ) -> None: """Handle a HANDSHAKE packet.""" handshake_authdata = decode_handshake_authdata(header.authdata) remote_node_id = handshake_authdata.src_id try: - result = self._handshake_manager.handle_handshake(remote_node_id, handshake_authdata) + ip, port = addr + result = self._handshake_manager.handle_handshake( + remote_node_id, handshake_authdata, remote_ip=ip, remote_port=port + ) logger.debug("Handshake completed with %s", remote_node_id.hex()[:16]) # Decrypt the included message. if len(message_bytes) > 0: - # Extract masked header for AAD. - masked_header = raw_packet[16 : 16 + 23 + len(header.authdata)] - plaintext = decrypt_message( encryption_key=result.session.recv_key, nonce=bytes(header.nonce), ciphertext=message_bytes, - masked_header=masked_header, + message_ad=message_ad, ) message = decode_message(plaintext) @@ -755,31 +748,29 @@ async def _handle_handshake( async def _handle_message( self, - header, + header: PacketHeader, message_bytes: bytes, addr: tuple[str, int], - raw_packet: bytes, + message_ad: bytes, ) -> None: """Handle an ordinary MESSAGE packet.""" message_authdata = decode_message_authdata(header.authdata) remote_node_id = message_authdata.src_id - # Get session. - session = self._session_cache.get(remote_node_id) + # Get session keyed by (node_id, ip, port). + ip, port = addr + session = self._session_cache.get(remote_node_id, ip, port) if session is None: # Can't decrypt - send WHOAREYOU. await self._send_whoareyou(remote_node_id, header.nonce, addr) return try: - # Extract masked header for AAD. - masked_header = raw_packet[16 : 16 + 23 + len(header.authdata)] - plaintext = decrypt_message( encryption_key=session.recv_key, nonce=bytes(header.nonce), ciphertext=message_bytes, - masked_header=masked_header, + message_ad=message_ad, ) message = decode_message(plaintext) @@ -798,7 +789,8 @@ async def _handle_decoded_message( ) -> None: """Process a successfully decoded message.""" # Update session activity. - self._session_cache.touch(remote_node_id) + ip, port = addr + self._session_cache.touch(remote_node_id, ip, port) # Check if this is a response to a pending request. request_id = bytes(message.request_id) @@ -826,8 +818,6 @@ async def _send_whoareyou( addr: tuple[str, int], ) -> None: """Send a WHOAREYOU packet.""" - import os - if self._transport is None: return @@ -855,6 +845,7 @@ async def _send_whoareyou( authdata=authdata, message=b"", encryption_key=None, + masking_iv=Bytes16(masking_iv), ) self._transport.sendto(packet, addr) @@ -889,7 +880,8 @@ async def send_response( # # The requester initiated the handshake. # By the time we respond, session keys must exist. - session = self._session_cache.get(dest_node_id) + ip, port = dest_addr + session = self._session_cache.get(dest_node_id, ip, port) if session is None: logger.debug("No session for response to %s", dest_node_id.hex()[:16]) return False diff --git a/tests/lean_spec/subspecs/networking/discovery/conftest.py b/tests/lean_spec/subspecs/networking/discovery/conftest.py index 3613e469..c86ed80b 100644 --- a/tests/lean_spec/subspecs/networking/discovery/conftest.py +++ b/tests/lean_spec/subspecs/networking/discovery/conftest.py @@ -9,27 +9,28 @@ from lean_spec.types import Bytes64, Uint64 # From devp2p test vectors -_NODE_B_PRIVKEY = bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") -_NODE_B_ID = bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") -_NODE_A_ID = bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") +NODE_B_PRIVKEY = bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") +NODE_B_ID = bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") +NODE_B_PUBKEY = bytes.fromhex("0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91") +NODE_A_ID = bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") @pytest.fixture def local_private_key() -> bytes: """Node B's private key from devp2p test vectors.""" - return _NODE_B_PRIVKEY + return NODE_B_PRIVKEY @pytest.fixture def local_node_id() -> NodeId: """Node B's ID from devp2p test vectors.""" - return NodeId(_NODE_B_ID) + return NodeId(NODE_B_ID) @pytest.fixture def remote_node_id() -> NodeId: """Node A's ID from devp2p test vectors.""" - return NodeId(_NODE_A_ID) + return NodeId(NODE_A_ID) @pytest.fixture @@ -40,9 +41,7 @@ def local_enr() -> ENR: seq=Uint64(1), pairs={ "id": b"v4", - "secp256k1": bytes.fromhex( - "0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91" - ), + "secp256k1": NODE_B_PUBKEY, "ip": bytes([127, 0, 0, 1]), "udp": (9000).to_bytes(2, "big"), }, diff --git a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py index 93ffc65e..81c2cc7c 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py @@ -15,6 +15,7 @@ decode_whoareyou_authdata, ) from lean_spec.subspecs.networking.discovery.session import SessionCache +from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY class TestPendingHandshake: @@ -475,6 +476,107 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa with pytest.raises(HandshakeError, match="ENR required"): manager.handle_handshake(bytes(remote_node_id), fake_authdata) + def test_successful_handshake_with_signature_verification( + self, manager, remote_keypair, session_cache + ): + """Full handshake succeeds when signature is valid. + + Exercises the complete WHOAREYOU -> HANDSHAKE -> session flow. + """ + from lean_spec.subspecs.networking.discovery.crypto import sign_id_nonce + from lean_spec.subspecs.networking.discovery.handshake import HandshakeResult + from lean_spec.subspecs.networking.discovery.packet import ( + decode_handshake_authdata, + encode_handshake_authdata, + ) + from lean_spec.subspecs.networking.enr.enr import ENR + from lean_spec.types import Bytes32, Bytes33, Bytes64, Uint64 + + remote_priv, remote_pub, remote_node_id = remote_keypair + + # Node A (manager) creates WHOAREYOU for remote. + masking_iv = bytes(16) + id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( + bytes(remote_node_id), bytes(12), 0, masking_iv + ) + + # Remote creates handshake response. + eph_priv, eph_pub = generate_secp256k1_keypair() + local_node_id = manager._local_node_id + + # Remote signs the id_nonce proving ownership. + id_signature = sign_id_nonce( + Bytes32(remote_priv), + challenge_data, + Bytes33(eph_pub), + Bytes32(local_node_id), + ) + + # Remote includes their ENR since enr_seq=0. + remote_enr = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={"id": b"v4", "secp256k1": bytes(remote_pub)}, + ) + + authdata_bytes = encode_handshake_authdata( + src_id=bytes(remote_node_id), + id_signature=id_signature, + eph_pubkey=eph_pub, + record=remote_enr.to_rlp(), + ) + + handshake = decode_handshake_authdata(authdata_bytes) + + # Manager processes handshake - should succeed. + result = manager.handle_handshake(bytes(remote_node_id), handshake) + + assert result is not None + assert isinstance(result, HandshakeResult) + assert result.session is not None + assert len(result.session.send_key) == 16 + assert len(result.session.recv_key) == 16 + + def test_handle_handshake_rejects_invalid_signature( + self, manager, remote_keypair, session_cache + ): + """Handshake fails when signature is invalid.""" + from lean_spec.subspecs.networking.discovery.handshake import HandshakeError + from lean_spec.subspecs.networking.discovery.packet import ( + decode_handshake_authdata, + encode_handshake_authdata, + ) + from lean_spec.subspecs.networking.enr.enr import ENR + from lean_spec.types import Bytes64, Uint64 + + remote_priv, remote_pub, remote_node_id = remote_keypair + + # Set up WHOAREYOU state. + masking_iv = bytes(16) + manager.create_whoareyou(bytes(remote_node_id), bytes(12), 0, masking_iv) + + # Generate ephemeral key. + _eph_priv, eph_pub = generate_secp256k1_keypair() + + # Create authdata with INVALID signature (all-zero 64 bytes). + remote_enr = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={"id": b"v4", "secp256k1": bytes(remote_pub)}, + ) + + authdata_bytes = encode_handshake_authdata( + src_id=bytes(remote_node_id), + id_signature=bytes(64), # Wrong signature. + eph_pubkey=eph_pub, + record=remote_enr.to_rlp(), + ) + + handshake = decode_handshake_authdata(authdata_bytes) + + with pytest.raises(HandshakeError, match="Invalid ID signature"): + manager.handle_handshake(bytes(remote_node_id), handshake) + class TestHandshakeConcurrency: """Concurrent handshake handling tests.""" @@ -615,17 +717,14 @@ def test_register_enr_stores_in_cache(self, manager): from lean_spec.subspecs.networking.enr import ENR from lean_spec.types import Bytes64, Uint64 - remote_pub = bytes.fromhex( - "0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91" - ) - remote_node_id = bytes(compute_node_id(remote_pub)) + remote_node_id = bytes(compute_node_id(NODE_B_PUBKEY)) enr = ENR( signature=Bytes64(bytes(64)), seq=Uint64(1), pairs={ "id": b"v4", - "secp256k1": remote_pub, + "secp256k1": NODE_B_PUBKEY, }, ) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_integration.py b/tests/lean_spec/subspecs/networking/discovery/test_integration.py index fbf71136..e8107a66 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_integration.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_integration.py @@ -182,7 +182,7 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): ) # Decode header. - header, ciphertext = decode_packet_header(node_b_keys["node_id"], packet) + header, ciphertext, message_ad = decode_packet_header(node_b_keys["node_id"], packet) assert header.flag == PacketFlag.MESSAGE @@ -200,12 +200,10 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): is_initiator=False, ) - # Extract masked header for AAD. - masked_header = packet[16 : 16 + 23 + len(header.authdata)] - # Node B uses recv_key to decrypt (which equals Node A's send_key). + # message_ad = masking-iv || plaintext header (per spec). plaintext = aes_gcm_decrypt( - Bytes16(b_recv_key), Bytes12(header.nonce), ciphertext, masked_header + Bytes16(b_recv_key), Bytes12(header.nonce), ciphertext, message_ad ) # Decode message. diff --git a/tests/lean_spec/subspecs/networking/discovery/test_messages.py b/tests/lean_spec/subspecs/networking/discovery/test_messages.py new file mode 100644 index 00000000..ab4fbeb5 --- /dev/null +++ b/tests/lean_spec/subspecs/networking/discovery/test_messages.py @@ -0,0 +1,385 @@ +""" +Tests for Discovery v5 protocol messages, types, and constants. + +Validates that protocol constants, message types, custom types, and +configuration match the Discovery v5 specification. +""" + +from __future__ import annotations + +from lean_spec.subspecs.networking.discovery.config import ( + ALPHA, + BOND_EXPIRY_SECS, + BUCKET_COUNT, + HANDSHAKE_TIMEOUT_SECS, + K_BUCKET_SIZE, + MAX_NODES_RESPONSE, + MAX_PACKET_SIZE, + MIN_PACKET_SIZE, + REQUEST_TIMEOUT_SECS, + DiscoveryConfig, +) +from lean_spec.subspecs.networking.discovery.messages import ( + MAX_REQUEST_ID_LENGTH, + PROTOCOL_ID, + PROTOCOL_VERSION, + Distance, + FindNode, + IdNonce, + IPv4, + IPv6, + MessageType, + Nodes, + Nonce, + PacketFlag, + Ping, + Pong, + Port, + RequestId, + StaticHeader, + TalkReq, + TalkResp, +) +from lean_spec.subspecs.networking.discovery.packet import WhoAreYouAuthdata +from lean_spec.subspecs.networking.types import SeqNumber +from lean_spec.types.uint import Uint8, Uint16, Uint64 +from tests.lean_spec.subspecs.networking.discovery.test_vectors import SPEC_ID_NONCE + + +class TestProtocolConstants: + """Verify protocol constants match Discovery v5 specification.""" + + def test_protocol_id(self): + assert PROTOCOL_ID == b"discv5" + assert len(PROTOCOL_ID) == 6 + + def test_protocol_version(self): + assert PROTOCOL_VERSION == 0x0001 + + def test_max_request_id_length(self): + assert MAX_REQUEST_ID_LENGTH == 8 + + def test_k_bucket_size(self): + assert K_BUCKET_SIZE == 16 + + def test_alpha_concurrency(self): + assert ALPHA == 3 + + def test_bucket_count(self): + assert BUCKET_COUNT == 256 + + def test_request_timeout(self): + assert REQUEST_TIMEOUT_SECS == 0.5 + + def test_handshake_timeout(self): + assert HANDSHAKE_TIMEOUT_SECS == 1.0 + + def test_max_nodes_response(self): + assert MAX_NODES_RESPONSE == 16 + + def test_bond_expiry(self): + assert BOND_EXPIRY_SECS == 86400 + + def test_packet_size_limits(self): + assert MAX_PACKET_SIZE == 1280 + assert MIN_PACKET_SIZE == 63 + + +class TestCustomTypes: + """Tests for custom Discovery v5 types.""" + + def test_request_id_limit(self): + req_id = RequestId(data=b"\x01\x02\x03\x04\x05\x06\x07\x08") + assert len(req_id.data) == 8 + + def test_request_id_variable_length(self): + req_id = RequestId(data=b"\x01") + assert len(req_id.data) == 1 + + def test_ipv4_length(self): + ip = IPv4(b"\xc0\xa8\x01\x01") + assert len(ip) == 4 + + def test_ipv6_length(self): + ip = IPv6(b"\x00" * 15 + b"\x01") + assert len(ip) == 16 + + def test_id_nonce_length(self): + nonce = IdNonce(b"\x01" * 16) + assert len(nonce) == 16 + + def test_nonce_length(self): + nonce = Nonce(b"\x01" * 12) + assert len(nonce) == 12 + + def test_distance_type(self): + d = Distance(256) + assert isinstance(d, Uint16) + + def test_port_type(self): + p = Port(30303) + assert isinstance(p, Uint16) + + def test_enr_seq_type(self): + seq = SeqNumber(42) + assert isinstance(seq, Uint64) + + +class TestPacketFlag: + """Tests for packet type flags.""" + + def test_message_flag(self): + assert PacketFlag.MESSAGE == 0 + + def test_whoareyou_flag(self): + assert PacketFlag.WHOAREYOU == 1 + + def test_handshake_flag(self): + assert PacketFlag.HANDSHAKE == 2 + + +class TestMessageTypes: + """Verify message type codes match wire protocol spec.""" + + def test_ping_type(self): + assert MessageType.PING == 0x01 + + def test_pong_type(self): + assert MessageType.PONG == 0x02 + + def test_findnode_type(self): + assert MessageType.FINDNODE == 0x03 + + def test_nodes_type(self): + assert MessageType.NODES == 0x04 + + def test_talkreq_type(self): + assert MessageType.TALKREQ == 0x05 + + def test_talkresp_type(self): + assert MessageType.TALKRESP == 0x06 + + def test_experimental_types(self): + assert MessageType.REGTOPIC == 0x07 + assert MessageType.TICKET == 0x08 + assert MessageType.REGCONFIRMATION == 0x09 + assert MessageType.TOPICQUERY == 0x0A + + +class TestDiscoveryConfig: + """Tests for DiscoveryConfig.""" + + def test_default_values(self): + config = DiscoveryConfig() + + assert config.k_bucket_size == K_BUCKET_SIZE + assert config.alpha == ALPHA + assert config.request_timeout_secs == REQUEST_TIMEOUT_SECS + assert config.handshake_timeout_secs == HANDSHAKE_TIMEOUT_SECS + assert config.max_nodes_response == MAX_NODES_RESPONSE + assert config.bond_expiry_secs == BOND_EXPIRY_SECS + + def test_custom_values(self): + config = DiscoveryConfig( + k_bucket_size=8, + alpha=5, + request_timeout_secs=2.0, + ) + assert config.k_bucket_size == 8 + assert config.alpha == 5 + assert config.request_timeout_secs == 2.0 + + +class TestPing: + """Tests for PING message.""" + + def test_creation_with_types(self): + ping = Ping( + request_id=RequestId(data=b"\x00\x00\x00\x01"), + enr_seq=SeqNumber(2), + ) + + assert ping.request_id.data == b"\x00\x00\x00\x01" + assert ping.enr_seq == SeqNumber(2) + + def test_max_request_id_length(self): + ping = Ping( + request_id=RequestId(data=b"\x01\x02\x03\x04\x05\x06\x07\x08"), + enr_seq=SeqNumber(1), + ) + assert len(ping.request_id.data) == 8 + + +class TestPong: + """Tests for PONG message.""" + + def test_creation_ipv4(self): + pong = Pong( + request_id=RequestId(data=b"\x00\x00\x00\x01"), + enr_seq=SeqNumber(42), + recipient_ip=b"\xc0\xa8\x01\x01", + recipient_port=Port(9000), + ) + + assert pong.enr_seq == SeqNumber(42) + assert len(pong.recipient_ip) == 4 + assert pong.recipient_port == Port(9000) + + def test_creation_ipv6(self): + ipv6 = b"\x00" * 15 + b"\x01" + pong = Pong( + request_id=RequestId(data=b"\x01"), + enr_seq=SeqNumber(1), + recipient_ip=ipv6, + recipient_port=Port(30303), + ) + + assert len(pong.recipient_ip) == 16 + + +class TestFindNode: + """Tests for FINDNODE message.""" + + def test_single_distance(self): + findnode = FindNode( + request_id=RequestId(data=b"\x01"), + distances=[Distance(256)], + ) + assert findnode.distances == [Distance(256)] + + def test_multiple_distances(self): + findnode = FindNode( + request_id=RequestId(data=b"\x01"), + distances=[Distance(0), Distance(1), Distance(255), Distance(256)], + ) + assert Distance(0) in findnode.distances + assert Distance(256) in findnode.distances + + def test_distance_zero_returns_self(self): + findnode = FindNode( + request_id=RequestId(data=b"\x01"), + distances=[Distance(0)], + ) + assert findnode.distances == [Distance(0)] + + +class TestNodes: + """Tests for NODES message.""" + + def test_single_response(self): + nodes = Nodes( + request_id=RequestId(data=b"\x01"), + total=Uint8(1), + enrs=[b"enr:-example"], + ) + assert nodes.total == Uint8(1) + assert len(nodes.enrs) == 1 + + def test_multiple_responses(self): + nodes = Nodes( + request_id=RequestId(data=b"\x01"), + total=Uint8(3), + enrs=[b"enr1", b"enr2"], + ) + assert nodes.total == Uint8(3) + assert len(nodes.enrs) == 2 + + +class TestTalkReq: + """Tests for TALKREQ message.""" + + def test_creation(self): + req = TalkReq( + request_id=RequestId(data=b"\x01"), + protocol=b"portal", + request=b"payload", + ) + assert req.protocol == b"portal" + assert req.request == b"payload" + + +class TestTalkResp: + """Tests for TALKRESP message.""" + + def test_creation(self): + resp = TalkResp( + request_id=RequestId(data=b"\x01"), + response=b"response_data", + ) + assert resp.response == b"response_data" + + def test_empty_response_unknown_protocol(self): + resp = TalkResp( + request_id=RequestId(data=b"\x01"), + response=b"", + ) + assert resp.response == b"" + + +class TestStaticHeader: + """Tests for packet static header.""" + + def test_default_protocol_id(self): + header = StaticHeader( + flag=Uint8(0), + nonce=Nonce(b"\x00" * 12), + authdata_size=Uint16(32), + ) + assert header.protocol_id == b"discv5" + assert header.version == Uint16(0x0001) + + def test_flag_values(self): + for flag in [0, 1, 2]: + header = StaticHeader( + flag=Uint8(flag), + nonce=Nonce(b"\xff" * 12), + authdata_size=Uint16(32), + ) + assert header.flag == Uint8(flag) + + +class TestWhoAreYouAuthdataConstruction: + """Tests for WHOAREYOU authdata construction.""" + + def test_creation(self): + authdata = WhoAreYouAuthdata( + id_nonce=IdNonce(b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10"), + enr_seq=Uint64(0), + ) + assert len(authdata.id_nonce) == 16 + assert authdata.enr_seq == Uint64(0) + + +class TestMessageConstructionFromTestVectors: + """Test message construction using official Discovery v5 test vector inputs.""" + + PING_REQUEST_ID = bytes.fromhex("00000001") + PING_ENR_SEQ = 2 + + def test_ping_message_construction(self): + ping = Ping( + request_id=RequestId(data=self.PING_REQUEST_ID), + enr_seq=SeqNumber(self.PING_ENR_SEQ), + ) + assert ping.request_id.data == self.PING_REQUEST_ID + assert ping.enr_seq == SeqNumber(2) + + def test_whoareyou_authdata_construction(self): + authdata = WhoAreYouAuthdata( + id_nonce=IdNonce(SPEC_ID_NONCE), + enr_seq=Uint64(0), + ) + assert authdata.id_nonce == IdNonce(SPEC_ID_NONCE) + assert authdata.enr_seq == Uint64(0) + + def test_plaintext_message_type(self): + plaintext = bytes.fromhex("01c20101") + assert plaintext[0] == MessageType.PING + + +class TestPacketStructure: + """Tests for Discovery v5 packet structure constants.""" + + def test_static_header_size(self): + expected_size = 6 + 2 + 1 + 12 + 2 + assert expected_size == 23 diff --git a/tests/lean_spec/subspecs/networking/discovery/test_packet.py b/tests/lean_spec/subspecs/networking/discovery/test_packet.py index 61b26179..d4a1d004 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_packet.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_packet.py @@ -230,7 +230,7 @@ def test_decode_packet_header(self): encryption_key=None, ) - header, message_bytes = decode_packet_header(local_node_id, packet) + header, message_bytes, _message_ad = decode_packet_header(local_node_id, packet) assert header.flag == PacketFlag.WHOAREYOU assert bytes(header.nonce) == nonce diff --git a/tests/lean_spec/subspecs/networking/discovery/test_routing.py b/tests/lean_spec/subspecs/networking/discovery/test_routing.py index deafd42f..949f8786 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_routing.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_routing.py @@ -468,6 +468,88 @@ def test_fork_filter_rejects_without_eth2_data(self, local_node_id, remote_node_ assert not table.is_fork_compatible(entry) + def test_fork_filter_rejects_mismatched_fork(self, local_node_id, remote_node_id): + """Node with different fork_digest is rejected.""" + from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH + + local_fork = Bytes4(bytes.fromhex("12345678")) + table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) + + # Build eth2 bytes with a different fork digest. + remote_digest = bytes.fromhex("deadbeef") + eth2_bytes = remote_digest + remote_digest + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") + enr = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={"eth2": eth2_bytes, "id": b"v4"}, + ) + entry = NodeEntry(node_id=remote_node_id, enr=enr) + + assert not table.add(entry) + assert not table.contains(remote_node_id) + + def test_fork_filter_accepts_matching_fork(self, local_node_id, remote_node_id): + """Node with matching fork_digest is accepted.""" + from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH + + local_fork = Bytes4(bytes.fromhex("12345678")) + table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) + + # Build eth2 bytes with the same fork digest. + eth2_bytes = ( + bytes.fromhex("12345678") + + bytes.fromhex("12345678") + + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") + ) + enr = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={"eth2": eth2_bytes, "id": b"v4"}, + ) + entry = NodeEntry(node_id=remote_node_id, enr=enr) + + assert table.add(entry) + assert table.contains(remote_node_id) + + def test_is_fork_compatible_method(self, local_node_id): + """Verify is_fork_compatible for compatible, incompatible, and no-ENR entries.""" + from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH + + local_fork = Bytes4(bytes.fromhex("12345678")) + table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) + + # Compatible entry. + eth2_match = ( + bytes.fromhex("12345678") + + bytes.fromhex("12345678") + + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") + ) + compatible_enr = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={"eth2": eth2_match, "id": b"v4"}, + ) + compatible_entry = NodeEntry(node_id=NodeId(b"\x01" * 32), enr=compatible_enr) + assert table.is_fork_compatible(compatible_entry) + + # Incompatible entry (different fork). + eth2_mismatch = ( + bytes.fromhex("deadbeef") + + bytes.fromhex("deadbeef") + + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") + ) + incompatible_enr = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={"eth2": eth2_mismatch, "id": b"v4"}, + ) + incompatible_entry = NodeEntry(node_id=NodeId(b"\x02" * 32), enr=incompatible_enr) + assert not table.is_fork_compatible(incompatible_entry) + + # Entry without ENR. + no_enr_entry = NodeEntry(node_id=NodeId(b"\x03" * 32)) + assert not table.is_fork_compatible(no_enr_entry) + class TestIPDensityTracking: """Tests for tracking IP address density. diff --git a/tests/lean_spec/subspecs/networking/discovery/test_service.py b/tests/lean_spec/subspecs/networking/discovery/test_service.py index 336cf2c6..8dfcb07f 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_service.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_service.py @@ -19,6 +19,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber from lean_spec.types import Bytes64, Uint64 +from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY class TestDiscoveryServiceInit: @@ -298,9 +299,7 @@ def test_enr_ip4_extraction(self, local_private_key): seq=Uint64(1), pairs={ "id": b"v4", - "secp256k1": bytes.fromhex( - "0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91" - ), + "secp256k1": NODE_B_PUBKEY, "ip": bytes([127, 0, 0, 1]), "udp": (9000).to_bytes(2, "big"), }, @@ -320,9 +319,7 @@ def test_enr_ip6_extraction(self, local_private_key): seq=Uint64(1), pairs={ "id": b"v4", - "secp256k1": bytes.fromhex( - "0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91" - ), + "secp256k1": NODE_B_PUBKEY, "ip6": ipv6_bytes, "udp6": (9000).to_bytes(2, "big"), }, @@ -341,9 +338,7 @@ def test_enr_dual_stack_has_both(self, local_private_key): seq=Uint64(1), pairs={ "id": b"v4", - "secp256k1": bytes.fromhex( - "0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91" - ), + "secp256k1": NODE_B_PUBKEY, "ip": bytes([192, 168, 1, 1]), "udp": (9000).to_bytes(2, "big"), "ip6": ipv6_bytes, @@ -362,9 +357,7 @@ def test_enr_missing_ip_returns_none(self, local_private_key): seq=Uint64(1), pairs={ "id": b"v4", - "secp256k1": bytes.fromhex( - "0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91" - ), + "secp256k1": NODE_B_PUBKEY, }, ) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py index 2138d311..45d8ddfd 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py @@ -26,6 +26,7 @@ decode_message_authdata, decode_packet_header, decode_whoareyou_authdata, + decrypt_message, encode_handshake_authdata, encode_message_authdata, encode_packet, @@ -33,14 +34,12 @@ ) from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes33, Bytes64 from tests.lean_spec.helpers import make_challenge_data - -# Node B's secp256k1 keypair (from devp2p spec) -# Node B's private key is provided in the test vectors. -NODE_B_PRIVKEY = bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") -NODE_B_ID = bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") - -# Node A's ID (from devp2p spec, private key not provided) -NODE_A_ID = bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") +from tests.lean_spec.subspecs.networking.discovery.conftest import ( + NODE_A_ID, + NODE_B_ID, + NODE_B_PRIVKEY, + NODE_B_PUBKEY, +) # Spec test vector values for ECDH and key derivation. SPEC_ID_NONCE = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") @@ -50,6 +49,25 @@ "00180102030405060708090a0b0c0d0e0f100000000000000000" ) +# Spec ephemeral keypair for ECDH / ID nonce signing. +SPEC_EPHEMERAL_KEY = bytes.fromhex( + "fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" +) +SPEC_EPHEMERAL_PUBKEY = bytes.fromhex( + "039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231" +) + +# Derived session keys from spec HKDF test vector. +SPEC_INITIATOR_KEY = bytes.fromhex("dccc82d81bd610f4f76d3ebe97a40571") +SPEC_RECIPIENT_KEY = bytes.fromhex("ac74bb8773749920b0d3a8881c173ec5") + +# AES-GCM test vector values. +SPEC_AES_KEY = bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b") +SPEC_AES_NONCE = bytes.fromhex("27b5af763c446acd2749fe8e") + +# PING message plaintext (type 0x01, RLP [1]). +SPEC_PING_PLAINTEXT = bytes.fromhex("01c20101") + class TestOfficialNodeIdVectors: """Verify node ID computation matches official test vectors.""" @@ -87,17 +105,11 @@ def test_ecdh_shared_secret(self): Per spec, the shared secret is the 33-byte compressed point. """ - secret_key = bytes.fromhex( - "fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" - ) - public_key = bytes.fromhex( - "039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231" - ) expected_shared = bytes.fromhex( "033b11a2a1f214567e1537ce5e509ffd9b21373247f2a3ff6841f4976f53165e7e" ) - shared = ecdh_agree(Bytes32(secret_key), public_key) + shared = ecdh_agree(Bytes32(SPEC_EPHEMERAL_KEY), SPEC_EPHEMERAL_PUBKEY) assert shared == expected_shared @@ -108,32 +120,19 @@ def test_key_derivation_hkdf(self): Derives initiator_key and recipient_key from ECDH shared secret. Uses exact spec challenge_data (with nonce 0102030405060708090a0b0c). """ - ephemeral_key = bytes.fromhex( - "fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" - ) - dest_pubkey = bytes.fromhex( - "0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91" - ) - node_id_a = bytes.fromhex( - "aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb" - ) - node_id_b = bytes.fromhex( - "bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9" - ) - # Compute ECDH shared secret. - shared_secret = ecdh_agree(Bytes32(ephemeral_key), dest_pubkey) + shared_secret = ecdh_agree(Bytes32(SPEC_EPHEMERAL_KEY), NODE_B_PUBKEY) # Derive keys using exact spec challenge_data. initiator_key, recipient_key = derive_keys( secret=shared_secret, - initiator_id=Bytes32(node_id_a), - recipient_id=Bytes32(node_id_b), + initiator_id=Bytes32(NODE_A_ID), + recipient_id=Bytes32(NODE_B_ID), challenge_data=SPEC_CHALLENGE_DATA, ) - assert initiator_key == bytes.fromhex("dccc82d81bd610f4f76d3ebe97a40571") - assert recipient_key == bytes.fromhex("ac74bb8773749920b0d3a8881c173ec5") + assert initiator_key == SPEC_INITIATOR_KEY + assert recipient_key == SPEC_RECIPIENT_KEY def test_id_nonce_signature(self): """ @@ -146,22 +145,12 @@ def test_id_nonce_signature(self): Uses exact spec challenge_data and verifies byte-exact signature output. """ - static_key = bytes.fromhex( - "fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" - ) - ephemeral_pubkey = bytes.fromhex( - "039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231" - ) - node_id_b = bytes.fromhex( - "bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9" - ) - # Sign using exact spec challenge_data. signature = sign_id_nonce( - private_key_bytes=Bytes32(static_key), + private_key_bytes=Bytes32(SPEC_EPHEMERAL_KEY), challenge_data=SPEC_CHALLENGE_DATA, - ephemeral_pubkey=Bytes33(ephemeral_pubkey), - dest_node_id=Bytes32(node_id_b), + ephemeral_pubkey=Bytes33(SPEC_EPHEMERAL_PUBKEY), + dest_node_id=Bytes32(NODE_B_ID), ) expected_sig = bytes.fromhex( @@ -175,7 +164,7 @@ def test_id_nonce_signature(self): from cryptography.hazmat.primitives.asymmetric import ec private_key = ec.derive_private_key( - int.from_bytes(static_key, "big"), + int.from_bytes(SPEC_EPHEMERAL_KEY, "big"), ec.SECP256K1(), ) pubkey_bytes = private_key.public_key().public_bytes( @@ -186,27 +175,27 @@ def test_id_nonce_signature(self): assert verify_id_nonce_signature( signature=Bytes64(signature), challenge_data=SPEC_CHALLENGE_DATA, - ephemeral_pubkey=Bytes33(ephemeral_pubkey), - dest_node_id=Bytes32(node_id_b), + ephemeral_pubkey=Bytes33(SPEC_EPHEMERAL_PUBKEY), + dest_node_id=Bytes32(NODE_B_ID), public_key_bytes=Bytes33(pubkey_bytes), ) def test_id_nonce_signature_different_challenge_data(self): """Different challenge_data produces different signatures.""" - static_key = NODE_B_PRIVKEY - ephemeral_pubkey = bytes.fromhex( - "039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231" - ) - node_id = NODE_A_ID - challenge_data1 = make_challenge_data(bytes(16)) challenge_data2 = make_challenge_data(bytes([1]) + bytes(15)) sig1 = sign_id_nonce( - Bytes32(static_key), challenge_data1, Bytes33(ephemeral_pubkey), Bytes32(node_id) + Bytes32(NODE_B_PRIVKEY), + challenge_data1, + Bytes33(SPEC_EPHEMERAL_PUBKEY), + Bytes32(NODE_A_ID), ) sig2 = sign_id_nonce( - Bytes32(static_key), challenge_data2, Bytes33(ephemeral_pubkey), Bytes32(node_id) + Bytes32(NODE_B_PRIVKEY), + challenge_data2, + Bytes33(SPEC_EPHEMERAL_PUBKEY), + Bytes32(NODE_A_ID), ) assert sig1 != sig2 @@ -217,114 +206,27 @@ def test_aes_gcm_encryption(self): The 16-byte authentication tag is appended to ciphertext. """ - encryption_key = bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b") - nonce = bytes.fromhex("27b5af763c446acd2749fe8e") - plaintext = bytes.fromhex("01c20101") aad = bytes.fromhex("93a7400fa0d6a694ebc24d5cf570f65d04215b6ac00757875e3f3a5f42107903") expected_ciphertext = bytes.fromhex("a5d12a2d94b8ccb3ba55558229867dc13bfa3648") # Encrypt. - ciphertext = aes_gcm_encrypt(Bytes16(encryption_key), Bytes12(nonce), plaintext, aad) + ciphertext = aes_gcm_encrypt( + Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), SPEC_PING_PLAINTEXT, aad + ) assert ciphertext == expected_ciphertext # Verify decryption works. - decrypted = aes_gcm_decrypt(Bytes16(encryption_key), Bytes12(nonce), ciphertext, aad) - assert decrypted == plaintext + decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) + assert decrypted == SPEC_PING_PLAINTEXT class TestOfficialPacketVectors: - """ - Packet encoding test vectors from devp2p spec. + """Decode exact packet bytes from the devp2p spec test vectors. - Note: Full packet encoding verification requires deterministic masking IV, - which the spec test vectors use all-zeros IV for reproducibility. - These tests verify the underlying authdata encoding is correct. + These tests verify interoperability by decoding the spec's exact hex packets. """ - def test_message_authdata_encoding(self): - """MESSAGE packet authdata is just the 32-byte source node ID.""" - src_id = NODE_A_ID - - authdata = encode_message_authdata(src_id) - assert authdata == src_id - assert len(authdata) == 32 - - # Decode and verify. - decoded = decode_message_authdata(authdata) - assert decoded.src_id == src_id - - def test_whoareyou_authdata_encoding(self): - """WHOAREYOU authdata is id-nonce (16) + enr-seq (8).""" - id_nonce = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") - enr_seq = 0 - - authdata = encode_whoareyou_authdata(id_nonce, enr_seq) - assert len(authdata) == 24 - - # Decode and verify. - decoded = decode_whoareyou_authdata(authdata) - assert bytes(decoded.id_nonce) == id_nonce - assert int(decoded.enr_seq) == enr_seq - - def test_whoareyou_authdata_with_nonzero_enr_seq(self): - """WHOAREYOU with known ENR sequence.""" - id_nonce = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") - enr_seq = 1 - - authdata = encode_whoareyou_authdata(id_nonce, enr_seq) - - decoded = decode_whoareyou_authdata(authdata) - assert int(decoded.enr_seq) == 1 - - def test_handshake_authdata_encoding(self): - """HANDSHAKE authdata contains signature and ephemeral key.""" - src_id = NODE_A_ID - id_signature = bytes(64) # Placeholder 64-byte signature. - eph_pubkey = bytes.fromhex( - "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" - ) - - authdata = encode_handshake_authdata( - src_id=src_id, - id_signature=id_signature, - eph_pubkey=eph_pubkey, - record=None, - ) - - # Expected size: 32 (src_id) + 1 (sig_size) + 1 (key_size) + 64 + 33 = 131 - assert len(authdata) == 131 - - # Decode and verify. - decoded = decode_handshake_authdata(authdata) - assert decoded.src_id == src_id - assert decoded.sig_size == 64 - assert decoded.eph_key_size == 33 - assert decoded.id_signature == id_signature - assert decoded.eph_pubkey == eph_pubkey - assert decoded.record is None - - def test_handshake_authdata_with_enr(self): - """HANDSHAKE authdata can include an ENR record.""" - src_id = NODE_A_ID - id_signature = bytes(64) - eph_pubkey = bytes.fromhex( - "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" - ) - # Minimal valid RLP-encoded ENR (just for testing). - enr_record = bytes.fromhex("f84180") # Placeholder. - - authdata = encode_handshake_authdata( - src_id=src_id, - id_signature=id_signature, - eph_pubkey=eph_pubkey, - record=enr_record, - ) - - # Decode and verify. - decoded = decode_handshake_authdata(authdata) - assert decoded.record == enr_record - def test_decode_spec_ping_packet(self): """Decode the exact Ping packet from the spec test vectors. @@ -339,7 +241,7 @@ def test_decode_spec_ping_packet(self): ) packet = bytes.fromhex(packet_hex) - header, ciphertext = decode_packet_header(NODE_B_ID, packet) + header, _ciphertext, _message_ad = decode_packet_header(NODE_B_ID, packet) assert header.flag == PacketFlag.MESSAGE decoded_authdata = decode_message_authdata(header.authdata) @@ -358,7 +260,7 @@ def test_decode_spec_whoareyou_packet(self): ) packet = bytes.fromhex(packet_hex) - header, message = decode_packet_header(NODE_B_ID, packet) + header, _message, _message_ad = decode_packet_header(NODE_B_ID, packet) assert header.flag == PacketFlag.WHOAREYOU decoded_authdata = decode_whoareyou_authdata(header.authdata) @@ -382,7 +284,7 @@ def test_decode_spec_handshake_packet(self): ) packet = bytes.fromhex(packet_hex) - header, ciphertext = decode_packet_header(NODE_B_ID, packet) + header, _ciphertext, _message_ad = decode_packet_header(NODE_B_ID, packet) assert header.flag == PacketFlag.HANDSHAKE decoded_authdata = decode_handshake_authdata(header.authdata) @@ -396,17 +298,15 @@ class TestPacketEncodingRoundtrip: def test_message_packet_roundtrip(self): """MESSAGE packet encodes and decodes correctly.""" - src_id = NODE_A_ID - dest_id = NODE_B_ID nonce = bytes(12) # 12-byte nonce. encryption_key = bytes(16) # 16-byte key. message = b"\x01\xc2\x01\x01" # PING message. - authdata = encode_message_authdata(src_id) + authdata = encode_message_authdata(NODE_A_ID) packet = encode_packet( - dest_node_id=dest_id, - src_node_id=src_id, + dest_node_id=NODE_B_ID, + src_node_id=NODE_A_ID, flag=PacketFlag.MESSAGE, nonce=nonce, authdata=authdata, @@ -415,18 +315,16 @@ def test_message_packet_roundtrip(self): ) # Decode header. - header, ciphertext = decode_packet_header(dest_id, packet) + header, ciphertext, _message_ad = decode_packet_header(NODE_B_ID, packet) assert header.flag == PacketFlag.MESSAGE assert len(header.authdata) == 32 decoded_authdata = decode_message_authdata(header.authdata) - assert decoded_authdata.src_id == src_id + assert decoded_authdata.src_id == NODE_A_ID def test_whoareyou_packet_roundtrip(self): """WHOAREYOU packet encodes and decodes correctly.""" - src_id = NODE_A_ID - dest_id = NODE_B_ID nonce = bytes.fromhex("0102030405060708090a0b0c") id_nonce = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") enr_seq = 0 @@ -434,8 +332,8 @@ def test_whoareyou_packet_roundtrip(self): authdata = encode_whoareyou_authdata(id_nonce, enr_seq) packet = encode_packet( - dest_node_id=dest_id, - src_node_id=src_id, + dest_node_id=NODE_B_ID, + src_node_id=NODE_A_ID, flag=PacketFlag.WHOAREYOU, nonce=nonce, authdata=authdata, @@ -444,7 +342,7 @@ def test_whoareyou_packet_roundtrip(self): ) # Decode header. - header, message = decode_packet_header(dest_id, packet) + header, message, _message_ad = decode_packet_header(NODE_B_ID, packet) assert header.flag == PacketFlag.WHOAREYOU assert bytes(header.nonce) == nonce @@ -455,10 +353,7 @@ def test_whoareyou_packet_roundtrip(self): def test_handshake_packet_roundtrip(self): """HANDSHAKE packet encodes and decodes correctly.""" - src_id = NODE_A_ID - dest_id = NODE_B_ID nonce = bytes(12) - encryption_key = bytes.fromhex("dccc82d81bd610f4f76d3ebe97a40571") message = b"\x01\xc2\x01\x01" # PING message. id_signature = bytes(64) @@ -467,74 +362,32 @@ def test_handshake_packet_roundtrip(self): ) authdata = encode_handshake_authdata( - src_id=src_id, + src_id=NODE_A_ID, id_signature=id_signature, eph_pubkey=eph_pubkey, record=None, ) packet = encode_packet( - dest_node_id=dest_id, - src_node_id=src_id, + dest_node_id=NODE_B_ID, + src_node_id=NODE_A_ID, flag=PacketFlag.HANDSHAKE, nonce=nonce, authdata=authdata, message=message, - encryption_key=encryption_key, + encryption_key=SPEC_INITIATOR_KEY, ) # Decode header. - header, ciphertext = decode_packet_header(dest_id, packet) + header, ciphertext, _message_ad = decode_packet_header(NODE_B_ID, packet) assert header.flag == PacketFlag.HANDSHAKE decoded_authdata = decode_handshake_authdata(header.authdata) - assert decoded_authdata.src_id == src_id + assert decoded_authdata.src_id == NODE_A_ID assert decoded_authdata.eph_pubkey == eph_pubkey -class TestKeyDerivationEdgeCases: - """Additional key derivation tests beyond official vectors.""" - - def test_derive_keys_deterministic(self): - """Same inputs always produce same keys.""" - secret = Bytes33(bytes(33)) - initiator_id = Bytes32(NODE_A_ID) - recipient_id = Bytes32(NODE_B_ID) - challenge_data = make_challenge_data() - - keys1 = derive_keys(secret, initiator_id, recipient_id, challenge_data) - keys2 = derive_keys(secret, initiator_id, recipient_id, challenge_data) - - assert keys1 == keys2 - - def test_derive_keys_id_order_matters(self): - """Swapping initiator/recipient produces different keys.""" - secret = Bytes33(bytes(33)) - id_a = Bytes32(NODE_A_ID) - id_b = Bytes32(NODE_B_ID) - challenge_data = make_challenge_data() - - keys_ab = derive_keys(secret, id_a, id_b, challenge_data) - keys_ba = derive_keys(secret, id_b, id_a, challenge_data) - - assert keys_ab != keys_ba - - def test_derive_keys_challenge_data_matters(self): - """Different challenge_data produces different keys.""" - secret = Bytes33(bytes(33)) - initiator_id = Bytes32(NODE_A_ID) - recipient_id = Bytes32(NODE_B_ID) - - challenge_data1 = make_challenge_data(bytes(16)) - challenge_data2 = make_challenge_data(bytes([1]) + bytes(15)) - - keys1 = derive_keys(secret, initiator_id, recipient_id, challenge_data1) - keys2 = derive_keys(secret, initiator_id, recipient_id, challenge_data2) - - assert keys1 != keys2 - - class TestOfficialPacketEncoding: """Byte-exact packet encoding from devp2p spec wire test vectors. @@ -670,17 +523,15 @@ def test_message_packet_header_structure(self): encode_packet, ) - src_id = NODE_A_ID - dest_id = NODE_B_ID nonce = bytes(12) encryption_key = bytes(16) message = b"\x01\xc2\x01\x01" - authdata = encode_message_authdata(src_id) + authdata = encode_message_authdata(NODE_A_ID) packet = encode_packet( - dest_node_id=dest_id, - src_node_id=src_id, + dest_node_id=NODE_B_ID, + src_node_id=NODE_A_ID, flag=PacketFlag.MESSAGE, nonce=nonce, authdata=authdata, @@ -706,8 +557,6 @@ def test_whoareyou_packet_header_structure(self): encode_whoareyou_authdata, ) - src_id = NODE_A_ID - dest_id = NODE_B_ID nonce = bytes(12) id_nonce = bytes(16) enr_seq = 0 @@ -715,8 +564,8 @@ def test_whoareyou_packet_header_structure(self): authdata = encode_whoareyou_authdata(id_nonce, enr_seq) packet = encode_packet( - dest_node_id=dest_id, - src_node_id=src_id, + dest_node_id=NODE_B_ID, + src_node_id=NODE_A_ID, flag=PacketFlag.WHOAREYOU, nonce=nonce, authdata=authdata, @@ -743,8 +592,6 @@ def test_handshake_packet_header_structure(self): encode_packet, ) - src_id = NODE_A_ID - dest_id = NODE_B_ID nonce = bytes(12) encryption_key = bytes(16) message = b"\x01\xc2\x01\x01" @@ -755,15 +602,15 @@ def test_handshake_packet_header_structure(self): ) authdata = encode_handshake_authdata( - src_id=src_id, + src_id=NODE_A_ID, id_signature=id_signature, eph_pubkey=eph_pubkey, record=None, ) packet = encode_packet( - dest_node_id=dest_id, - src_node_id=src_id, + dest_node_id=NODE_B_ID, + src_node_id=NODE_A_ID, flag=PacketFlag.HANDSHAKE, nonce=nonce, authdata=authdata, @@ -844,34 +691,30 @@ class TestAESCryptoEdgeCases: def test_aes_gcm_empty_plaintext(self): """AES-GCM handles empty plaintext correctly.""" - key = Bytes16(bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b")) - nonce = Bytes12(bytes.fromhex("27b5af763c446acd2749fe8e")) aad = bytes(32) plaintext = b"" - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) + ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) # Empty plaintext should produce just the 16-byte auth tag. assert len(ciphertext) == 16 # Decryption should recover empty plaintext. - decrypted = aes_gcm_decrypt(key, nonce, ciphertext, aad) + decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) assert decrypted == b"" def test_aes_gcm_large_plaintext(self): """AES-GCM handles large plaintext correctly.""" - key = Bytes16(bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b")) - nonce = Bytes12(bytes.fromhex("27b5af763c446acd2749fe8e")) aad = bytes(32) plaintext = bytes(1024) # 1KB of zeros. - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) + ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) # Ciphertext = plaintext length + 16-byte tag. assert len(ciphertext) == len(plaintext) + 16 # Decryption should recover original plaintext. - decrypted = aes_gcm_decrypt(key, nonce, ciphertext, aad) + decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) assert decrypted == plaintext def test_aes_gcm_wrong_key_fails_decryption(self): @@ -879,46 +722,40 @@ def test_aes_gcm_wrong_key_fails_decryption(self): import pytest from cryptography.exceptions import InvalidTag - key = Bytes16(bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b")) wrong_key = Bytes16(bytes.fromhex("00000000000000001a85107ac686990b")) - nonce = Bytes12(bytes.fromhex("27b5af763c446acd2749fe8e")) aad = bytes(32) plaintext = b"secret message" - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) + ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) # Decryption with wrong key should fail with InvalidTag. with pytest.raises(InvalidTag): - aes_gcm_decrypt(wrong_key, nonce, ciphertext, aad) + aes_gcm_decrypt(wrong_key, Bytes12(SPEC_AES_NONCE), ciphertext, aad) def test_aes_gcm_wrong_aad_fails_decryption(self): """AES-GCM decryption fails with wrong AAD.""" import pytest from cryptography.exceptions import InvalidTag - key = Bytes16(bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b")) - nonce = Bytes12(bytes.fromhex("27b5af763c446acd2749fe8e")) aad = bytes(32) wrong_aad = bytes([0xFF] * 32) plaintext = b"secret message" - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) + ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) # Decryption with wrong AAD should fail with InvalidTag. with pytest.raises(InvalidTag): - aes_gcm_decrypt(key, nonce, ciphertext, wrong_aad) + aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, wrong_aad) def test_aes_gcm_tampered_ciphertext_fails(self): """AES-GCM decryption fails with tampered ciphertext.""" import pytest from cryptography.exceptions import InvalidTag - key = Bytes16(bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b")) - nonce = Bytes12(bytes.fromhex("27b5af763c446acd2749fe8e")) aad = bytes(32) plaintext = b"secret message" - ciphertext = aes_gcm_encrypt(key, nonce, plaintext, aad) + ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) # Tamper with ciphertext by flipping a bit. tampered = bytearray(ciphertext) @@ -927,4 +764,92 @@ def test_aes_gcm_tampered_ciphertext_fails(self): # Decryption of tampered ciphertext should fail with InvalidTag. with pytest.raises(InvalidTag): - aes_gcm_decrypt(key, nonce, tampered, aad) + aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), tampered, aad) + + +class TestSpecPacketPayloadDecryption: + """Verify message payload decryption using correct AAD (masking-iv || plaintext header).""" + + def test_message_packet_encrypt_decrypt_roundtrip(self): + """Encrypt a message in a packet and decrypt using message_ad from decode.""" + nonce = bytes(12) + + authdata = encode_message_authdata(NODE_A_ID) + + packet = encode_packet( + dest_node_id=NODE_B_ID, + src_node_id=NODE_A_ID, + flag=PacketFlag.MESSAGE, + nonce=nonce, + authdata=authdata, + message=SPEC_PING_PLAINTEXT, + encryption_key=SPEC_INITIATOR_KEY, + ) + + # Decode header - returns message_ad for AAD. + header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) + + # Decrypt using message_ad as AAD. + decrypted = decrypt_message(SPEC_INITIATOR_KEY, bytes(header.nonce), ciphertext, message_ad) + assert decrypted == SPEC_PING_PLAINTEXT + + def test_handshake_packet_encrypt_decrypt_roundtrip(self): + """Handshake packet encrypts and decrypts using correct AAD.""" + nonce = bytes(12) + + id_signature = bytes(64) + eph_pubkey = bytes.fromhex( + "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" + ) + + authdata = encode_handshake_authdata( + src_id=NODE_A_ID, + id_signature=id_signature, + eph_pubkey=eph_pubkey, + record=None, + ) + + packet = encode_packet( + dest_node_id=NODE_B_ID, + src_node_id=NODE_A_ID, + flag=PacketFlag.HANDSHAKE, + nonce=nonce, + authdata=authdata, + message=SPEC_PING_PLAINTEXT, + encryption_key=SPEC_INITIATOR_KEY, + ) + + # Decode header - returns message_ad for AAD. + header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) + + # Decrypt using message_ad as AAD. + decrypted = decrypt_message(SPEC_INITIATOR_KEY, bytes(header.nonce), ciphertext, message_ad) + assert decrypted == SPEC_PING_PLAINTEXT + + +class TestRoutingWithTestVectorNodeIds: + """Tests using official test vector node IDs with routing functions.""" + + def test_xor_distance_is_symmetric(self): + """XOR distance between test vector nodes is symmetric and non-zero.""" + from lean_spec.subspecs.networking.discovery.routing import xor_distance + from lean_spec.subspecs.networking.types import NodeId + + node_a = NodeId(NODE_A_ID) + node_b = NodeId(NODE_B_ID) + + distance = xor_distance(node_a, node_b) + assert distance > 0 + assert xor_distance(node_a, node_b) == xor_distance(node_b, node_a) + + def test_log2_distance_is_high(self): + """Log2 distance between test vector nodes is high (differ in high bits).""" + from lean_spec.subspecs.networking.discovery.messages import Distance + from lean_spec.subspecs.networking.discovery.routing import log2_distance + from lean_spec.subspecs.networking.types import NodeId + + node_a = NodeId(NODE_A_ID) + node_b = NodeId(NODE_B_ID) + + log_dist = log2_distance(node_a, node_b) + assert log_dist > Distance(200) diff --git a/tests/lean_spec/subspecs/networking/test_discovery.py b/tests/lean_spec/subspecs/networking/test_discovery.py deleted file mode 100644 index 66ebad5c..00000000 --- a/tests/lean_spec/subspecs/networking/test_discovery.py +++ /dev/null @@ -1,890 +0,0 @@ -"""Tests for Discovery v5 Protocol Specification""" - -from typing import TYPE_CHECKING - -from lean_spec.subspecs.networking.discovery import ( - MAX_REQUEST_ID_LENGTH, - PROTOCOL_ID, - PROTOCOL_VERSION, - DiscoveryConfig, - Distance, - FindNode, - IdNonce, - MessageType, - Nodes, - Nonce, - Ping, - Pong, - RequestId, - RoutingTable, - TalkReq, - TalkResp, -) -from lean_spec.subspecs.networking.discovery.config import ( - ALPHA, - BOND_EXPIRY_SECS, - BUCKET_COUNT, - HANDSHAKE_TIMEOUT_SECS, - K_BUCKET_SIZE, - MAX_NODES_RESPONSE, - MAX_PACKET_SIZE, - MIN_PACKET_SIZE, - REQUEST_TIMEOUT_SECS, -) -from lean_spec.subspecs.networking.discovery.messages import ( - IPv4, - IPv6, - PacketFlag, - Port, - StaticHeader, - WhoAreYouAuthdata, -) -from lean_spec.subspecs.networking.discovery.routing import ( - KBucket, - NodeEntry, - log2_distance, - xor_distance, -) -from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types.uint import Uint8, Uint16, Uint64 - -if TYPE_CHECKING: - from lean_spec.subspecs.networking.enr import ENR - - -class TestProtocolConstants: - """Verify protocol constants match Discovery v5 specification.""" - - def test_protocol_id(self) -> None: - """Protocol ID is 'discv5'.""" - assert PROTOCOL_ID == b"discv5" - assert len(PROTOCOL_ID) == 6 - - def test_protocol_version(self) -> None: - """Protocol version is 0x0001 (v5.1).""" - assert PROTOCOL_VERSION == 0x0001 - - def test_max_request_id_length(self) -> None: - """Request ID max length is 8 bytes.""" - assert MAX_REQUEST_ID_LENGTH == 8 - - def test_k_bucket_size(self) -> None: - """K-bucket size is 16 per Kademlia standard.""" - assert K_BUCKET_SIZE == 16 - - def test_alpha_concurrency(self) -> None: - """Alpha (lookup concurrency) is 3.""" - assert ALPHA == 3 - - def test_bucket_count(self) -> None: - """256 buckets for 256-bit node ID space.""" - assert BUCKET_COUNT == 256 - - def test_request_timeout(self) -> None: - """Request timeout is 500ms per spec.""" - assert REQUEST_TIMEOUT_SECS == 0.5 - - def test_handshake_timeout(self) -> None: - """Handshake timeout is 1s per spec.""" - assert HANDSHAKE_TIMEOUT_SECS == 1.0 - - def test_max_nodes_response(self) -> None: - """Max 16 ENRs per NODES response.""" - assert MAX_NODES_RESPONSE == 16 - - def test_bond_expiry(self) -> None: - """Bond expires after 24 hours.""" - assert BOND_EXPIRY_SECS == 86400 - - def test_packet_size_limits(self) -> None: - """Packet size limits per spec.""" - assert MAX_PACKET_SIZE == 1280 - assert MIN_PACKET_SIZE == 63 - - -class TestCustomTypes: - """Tests for custom Discovery v5 types.""" - - def test_request_id_limit(self) -> None: - """RequestId accepts up to 8 bytes.""" - req_id = RequestId(data=b"\x01\x02\x03\x04\x05\x06\x07\x08") - assert len(req_id.data) == 8 - - def test_request_id_variable_length(self) -> None: - """RequestId is variable length.""" - req_id = RequestId(data=b"\x01") - assert len(req_id.data) == 1 - - def test_ipv4_length(self) -> None: - """IPv4 is exactly 4 bytes.""" - ip = IPv4(b"\xc0\xa8\x01\x01") # 192.168.1.1 - assert len(ip) == 4 - - def test_ipv6_length(self) -> None: - """IPv6 is exactly 16 bytes.""" - ip = IPv6(b"\x00" * 15 + b"\x01") # ::1 - assert len(ip) == 16 - - def test_id_nonce_length(self) -> None: - """IdNonce is 16 bytes (128 bits).""" - nonce = IdNonce(b"\x01" * 16) - assert len(nonce) == 16 - - def test_nonce_length(self) -> None: - """Nonce is 12 bytes (96 bits).""" - nonce = Nonce(b"\x01" * 12) - assert len(nonce) == 12 - - def test_distance_type(self) -> None: - """Distance is Uint16.""" - d = Distance(256) - assert isinstance(d, Uint16) - - def test_port_type(self) -> None: - """Port is Uint16.""" - p = Port(30303) - assert isinstance(p, Uint16) - - def test_enr_seq_type(self) -> None: - """SeqNumber is Uint64.""" - seq = SeqNumber(42) - assert isinstance(seq, Uint64) - - -class TestPacketFlag: - """Tests for packet type flags.""" - - def test_message_flag(self) -> None: - """MESSAGE flag is 0.""" - assert PacketFlag.MESSAGE == 0 - - def test_whoareyou_flag(self) -> None: - """WHOAREYOU flag is 1.""" - assert PacketFlag.WHOAREYOU == 1 - - def test_handshake_flag(self) -> None: - """HANDSHAKE flag is 2.""" - assert PacketFlag.HANDSHAKE == 2 - - -class TestMessageTypes: - """Verify message type codes match wire protocol spec.""" - - def test_ping_type(self) -> None: - """PING is message type 0x01.""" - assert MessageType.PING == 0x01 - - def test_pong_type(self) -> None: - """PONG is message type 0x02.""" - assert MessageType.PONG == 0x02 - - def test_findnode_type(self) -> None: - """FINDNODE is message type 0x03.""" - assert MessageType.FINDNODE == 0x03 - - def test_nodes_type(self) -> None: - """NODES is message type 0x04.""" - assert MessageType.NODES == 0x04 - - def test_talkreq_type(self) -> None: - """TALKREQ is message type 0x05.""" - assert MessageType.TALKREQ == 0x05 - - def test_talkresp_type(self) -> None: - """TALKRESP is message type 0x06.""" - assert MessageType.TALKRESP == 0x06 - - def test_experimental_types(self) -> None: - """Experimental topic messages have correct types.""" - assert MessageType.REGTOPIC == 0x07 - assert MessageType.TICKET == 0x08 - assert MessageType.REGCONFIRMATION == 0x09 - assert MessageType.TOPICQUERY == 0x0A - - -class TestDiscoveryConfig: - """Tests for DiscoveryConfig.""" - - def test_default_values(self) -> None: - """Default config uses spec-defined constants.""" - config = DiscoveryConfig() - - assert config.k_bucket_size == K_BUCKET_SIZE - assert config.alpha == ALPHA - assert config.request_timeout_secs == REQUEST_TIMEOUT_SECS - assert config.handshake_timeout_secs == HANDSHAKE_TIMEOUT_SECS - assert config.max_nodes_response == MAX_NODES_RESPONSE - assert config.bond_expiry_secs == BOND_EXPIRY_SECS - - def test_custom_values(self) -> None: - """Custom config values override defaults.""" - config = DiscoveryConfig( - k_bucket_size=8, - alpha=5, - request_timeout_secs=2.0, - ) - assert config.k_bucket_size == 8 - assert config.alpha == 5 - assert config.request_timeout_secs == 2.0 - - -class TestPing: - """Tests for PING message.""" - - def test_creation_with_types(self) -> None: - """PING with strongly typed fields.""" - ping = Ping( - request_id=RequestId(data=b"\x00\x00\x00\x01"), - enr_seq=SeqNumber(2), - ) - - assert ping.request_id.data == b"\x00\x00\x00\x01" - assert ping.enr_seq == SeqNumber(2) - - def test_max_request_id_length(self) -> None: - """Request ID accepts up to 8 bytes.""" - ping = Ping( - request_id=RequestId(data=b"\x01\x02\x03\x04\x05\x06\x07\x08"), - enr_seq=SeqNumber(1), - ) - assert len(ping.request_id.data) == 8 - - -class TestPong: - """Tests for PONG message.""" - - def test_creation_ipv4(self) -> None: - """PONG with IPv4 address (4 bytes).""" - pong = Pong( - request_id=RequestId(data=b"\x00\x00\x00\x01"), - enr_seq=SeqNumber(42), - recipient_ip=b"\xc0\xa8\x01\x01", # 192.168.1.1 - recipient_port=Port(9000), - ) - - assert pong.enr_seq == SeqNumber(42) - assert len(pong.recipient_ip) == 4 - assert pong.recipient_port == Port(9000) - - def test_creation_ipv6(self) -> None: - """PONG with IPv6 address (16 bytes).""" - ipv6 = b"\x00" * 15 + b"\x01" # ::1 - pong = Pong( - request_id=RequestId(data=b"\x01"), - enr_seq=SeqNumber(1), - recipient_ip=ipv6, - recipient_port=Port(30303), - ) - - assert len(pong.recipient_ip) == 16 - - -class TestFindNode: - """Tests for FINDNODE message.""" - - def test_single_distance(self) -> None: - """FINDNODE querying single distance.""" - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(256)], - ) - - assert findnode.distances == [Distance(256)] - - def test_multiple_distances(self) -> None: - """FINDNODE querying multiple distances.""" - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(0), Distance(1), Distance(255), Distance(256)], - ) - - assert Distance(0) in findnode.distances # Distance 0 returns node itself - assert Distance(256) in findnode.distances # Maximum distance - - def test_distance_zero_returns_self(self) -> None: - """Distance 0 is valid and returns recipient's ENR.""" - findnode = FindNode( - request_id=RequestId(data=b"\x01"), - distances=[Distance(0)], - ) - assert findnode.distances == [Distance(0)] - - -class TestNodes: - """Tests for NODES message.""" - - def test_single_response(self) -> None: - """NODES with single response (total=1).""" - nodes = Nodes( - request_id=RequestId(data=b"\x01"), - total=Uint8(1), - enrs=[b"enr:-example"], - ) - - assert nodes.total == Uint8(1) - assert len(nodes.enrs) == 1 - - def test_multiple_responses(self) -> None: - """NODES indicating multiple response messages.""" - nodes = Nodes( - request_id=RequestId(data=b"\x01"), - total=Uint8(3), - enrs=[b"enr1", b"enr2"], - ) - - assert nodes.total == Uint8(3) - assert len(nodes.enrs) == 2 - - -class TestTalkReq: - """Tests for TALKREQ message.""" - - def test_creation(self) -> None: - """TALKREQ with protocol identifier.""" - req = TalkReq( - request_id=RequestId(data=b"\x01"), - protocol=b"portal", - request=b"payload", - ) - - assert req.protocol == b"portal" - assert req.request == b"payload" - - -class TestTalkResp: - """Tests for TALKRESP message.""" - - def test_creation(self) -> None: - """TALKRESP with response data.""" - resp = TalkResp( - request_id=RequestId(data=b"\x01"), - response=b"response_data", - ) - - assert resp.response == b"response_data" - - def test_empty_response_unknown_protocol(self) -> None: - """Empty response indicates unknown protocol.""" - resp = TalkResp( - request_id=RequestId(data=b"\x01"), - response=b"", - ) - assert resp.response == b"" - - -class TestStaticHeader: - """Tests for packet static header.""" - - def test_default_protocol_id(self) -> None: - """Static header has correct default protocol ID.""" - header = StaticHeader( - flag=Uint8(0), - nonce=Nonce(b"\x00" * 12), - authdata_size=Uint16(32), - ) - - assert header.protocol_id == b"discv5" - assert header.version == Uint16(0x0001) - - def test_flag_values(self) -> None: - """Static header accepts different flag values.""" - for flag in [0, 1, 2]: - header = StaticHeader( - flag=Uint8(flag), - nonce=Nonce(b"\xff" * 12), - authdata_size=Uint16(32), - ) - assert header.flag == Uint8(flag) - - -class TestWhoAreYouAuthdata: - """Tests for WHOAREYOU authdata.""" - - def test_creation(self) -> None: - """WHOAREYOU authdata with id_nonce and enr_seq.""" - authdata = WhoAreYouAuthdata( - id_nonce=IdNonce(b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10"), - enr_seq=SeqNumber(0), - ) - - assert len(authdata.id_nonce) == 16 - assert authdata.enr_seq == SeqNumber(0) - - -class TestXorDistance: - """Tests for XOR distance calculation.""" - - def test_identical_ids_zero_distance(self) -> None: - """Identical node IDs have distance 0.""" - node_id = NodeId(b"\x00" * 32) - assert xor_distance(node_id, node_id) == 0 - - def test_complementary_ids_max_distance(self) -> None: - """All-zeros XOR all-ones gives maximum distance.""" - a = NodeId(b"\x00" * 32) - b = NodeId(b"\xff" * 32) - assert xor_distance(a, b) == 2**256 - 1 - - def test_distance_is_symmetric(self) -> None: - """XOR distance satisfies d(a,b) == d(b,a).""" - a = NodeId(b"\x12" * 32) - b = NodeId(b"\x34" * 32) - assert xor_distance(a, b) == xor_distance(b, a) - - def test_specific_xor_values(self) -> None: - """Verify specific XOR calculations.""" - a = NodeId(b"\x00" * 31 + b"\x05") # 5 - b = NodeId(b"\x00" * 31 + b"\x03") # 3 - assert xor_distance(a, b) == 6 # 5 XOR 3 = 6 - - -class TestLog2Distance: - """Tests for log2 distance calculation.""" - - def test_identical_ids_return_zero(self) -> None: - """Identical IDs return log2 distance 0.""" - node_id = NodeId(b"\x00" * 32) - assert log2_distance(node_id, node_id) == Distance(0) - - def test_single_bit_difference(self) -> None: - """Single bit difference in LSB gives distance 1.""" - a = NodeId(b"\x00" * 32) - b = NodeId(b"\x00" * 31 + b"\x01") - assert log2_distance(a, b) == Distance(1) - - def test_high_bit_difference(self) -> None: - """Difference in high bit gives distance 8.""" - a = NodeId(b"\x00" * 32) - b = NodeId(b"\x00" * 31 + b"\x80") # 0b10000000 - assert log2_distance(a, b) == Distance(8) - - def test_maximum_distance(self) -> None: - """Maximum distance is 256 bits.""" - a = NodeId(b"\x00" * 32) - b = NodeId(b"\x80" + b"\x00" * 31) # High bit of first byte set - assert log2_distance(a, b) == Distance(256) - - -class TestKBucket: - """Tests for K-bucket implementation.""" - - def test_new_bucket_is_empty(self) -> None: - """Newly created bucket has no nodes.""" - bucket = KBucket() - - assert bucket.is_empty - assert not bucket.is_full - assert len(bucket) == 0 - - def test_add_single_node(self) -> None: - """Adding a node increases bucket size.""" - bucket = KBucket() - entry = NodeEntry(node_id=NodeId(b"\x01" * 32)) - - assert bucket.add(entry) - assert len(bucket) == 1 - assert bucket.contains(NodeId(b"\x01" * 32)) - - def test_bucket_capacity_limit(self) -> None: - """Bucket rejects nodes when at K_BUCKET_SIZE capacity.""" - bucket = KBucket() - - for i in range(K_BUCKET_SIZE): - entry = NodeEntry(node_id=NodeId(bytes([i]) + b"\x00" * 31)) - assert bucket.add(entry) - - assert bucket.is_full - assert len(bucket) == K_BUCKET_SIZE - - extra = NodeEntry(node_id=NodeId(b"\xff" * 32)) - assert not bucket.add(extra) - assert len(bucket) == K_BUCKET_SIZE - - def test_update_moves_to_tail(self) -> None: - """Re-adding existing node moves it to tail (most recent).""" - bucket = KBucket() - - entry1 = NodeEntry(node_id=NodeId(b"\x01" * 32), enr_seq=SeqNumber(1)) - entry2 = NodeEntry(node_id=NodeId(b"\x02" * 32), enr_seq=SeqNumber(1)) - bucket.add(entry1) - bucket.add(entry2) - - updated = NodeEntry(node_id=NodeId(b"\x01" * 32), enr_seq=SeqNumber(2)) - bucket.add(updated) - - tail = bucket.tail() - assert tail is not None - assert tail.node_id == NodeId(b"\x01" * 32) - assert tail.enr_seq == SeqNumber(2) - - def test_remove_node(self) -> None: - """Removing node decreases bucket size.""" - bucket = KBucket() - entry = NodeEntry(node_id=NodeId(b"\x01" * 32)) - bucket.add(entry) - - assert bucket.remove(NodeId(b"\x01" * 32)) - assert bucket.is_empty - assert not bucket.contains(NodeId(b"\x01" * 32)) - - def test_remove_nonexistent_returns_false(self) -> None: - """Removing nonexistent node returns False.""" - bucket = KBucket() - assert not bucket.remove(NodeId(b"\x01" * 32)) - - def test_get_existing_node(self) -> None: - """Get retrieves node by ID.""" - bucket = KBucket() - entry = NodeEntry(node_id=NodeId(b"\x01" * 32), enr_seq=SeqNumber(42)) - bucket.add(entry) - - retrieved = bucket.get(NodeId(b"\x01" * 32)) - assert retrieved is not None - assert retrieved.enr_seq == SeqNumber(42) - - def test_get_nonexistent_returns_none(self) -> None: - """Get returns None for unknown node.""" - bucket = KBucket() - assert bucket.get(NodeId(b"\x01" * 32)) is None - - def test_head_returns_oldest(self) -> None: - """Head returns least-recently seen node.""" - bucket = KBucket() - bucket.add(NodeEntry(node_id=NodeId(b"\x01" * 32))) - bucket.add(NodeEntry(node_id=NodeId(b"\x02" * 32))) - - head = bucket.head() - assert head is not None - assert head.node_id == NodeId(b"\x01" * 32) - - def test_tail_returns_newest(self) -> None: - """Tail returns most-recently seen node.""" - bucket = KBucket() - bucket.add(NodeEntry(node_id=NodeId(b"\x01" * 32))) - bucket.add(NodeEntry(node_id=NodeId(b"\x02" * 32))) - - tail = bucket.tail() - assert tail is not None - assert tail.node_id == NodeId(b"\x02" * 32) - - def test_iteration(self) -> None: - """Bucket supports iteration over nodes.""" - bucket = KBucket() - bucket.add(NodeEntry(node_id=NodeId(b"\x01" * 32))) - bucket.add(NodeEntry(node_id=NodeId(b"\x02" * 32))) - - node_ids = [entry.node_id for entry in bucket] - assert len(node_ids) == 2 - - -class TestRoutingTable: - """Tests for Kademlia routing table.""" - - def test_new_table_is_empty(self) -> None: - """New routing table has no nodes.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - assert table.node_count() == 0 - - def test_has_256_buckets(self) -> None: - """Routing table has 256 k-buckets.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - assert len(table.buckets) == BUCKET_COUNT - - def test_add_node(self) -> None: - """Adding node increases count.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - entry = NodeEntry(node_id=NodeId(b"\x00" * 31 + b"\x01")) - assert table.add(entry) - assert table.node_count() == 1 - assert table.contains(entry.node_id) - - def test_cannot_add_self(self) -> None: - """Adding local node ID is rejected.""" - local_id = NodeId(b"\xab" * 32) - table = RoutingTable(local_id=local_id) - - entry = NodeEntry(node_id=local_id) - assert not table.add(entry) - assert table.node_count() == 0 - - def test_bucket_assignment_by_distance(self) -> None: - """Nodes placed in correct bucket by log2 distance.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - node_id = NodeId(b"\x00" * 31 + b"\x01") # log2 distance = 1 - entry = NodeEntry(node_id=node_id) - table.add(entry) - - bucket_idx = table.bucket_index(node_id) - assert bucket_idx == 0 # distance 1 -> bucket 0 - assert table.buckets[0].contains(node_id) - - def test_get_existing_node(self) -> None: - """Get retrieves node from table.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - entry = NodeEntry(node_id=NodeId(b"\x01" * 32), enr_seq=SeqNumber(99)) - table.add(entry) - - retrieved = table.get(entry.node_id) - assert retrieved is not None - assert retrieved.enr_seq == SeqNumber(99) - - def test_remove_node(self) -> None: - """Remove deletes node from table.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - entry = NodeEntry(node_id=NodeId(b"\x01" * 32)) - table.add(entry) - assert table.remove(entry.node_id) - assert not table.contains(entry.node_id) - - def test_closest_nodes_sorted_by_distance(self) -> None: - """closest_nodes returns nodes sorted by XOR distance.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - for i in range(1, 5): - entry = NodeEntry(node_id=NodeId(bytes([i]) + b"\x00" * 31)) - table.add(entry) - - target = NodeId(b"\x01" + b"\x00" * 31) - closest = table.closest_nodes(target, count=3) - - assert len(closest) == 3 - assert closest[0].node_id == target # Distance 0 to itself - - def test_closest_nodes_respects_count(self) -> None: - """closest_nodes returns at most count nodes.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - for i in range(10): - entry = NodeEntry(node_id=NodeId(bytes([i + 1]) + b"\x00" * 31)) - table.add(entry) - - closest = table.closest_nodes(NodeId(b"\x05" + b"\x00" * 31), count=3) - assert len(closest) == 3 - - def test_nodes_at_distance(self) -> None: - """nodes_at_distance returns nodes in specific bucket.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - node_id = NodeId(b"\x00" * 31 + b"\x01") # distance 1 - entry = NodeEntry(node_id=node_id) - table.add(entry) - - nodes = table.nodes_at_distance(Distance(1)) - assert len(nodes) == 1 - assert nodes[0].node_id == node_id - - def test_nodes_at_invalid_distance(self) -> None: - """Invalid distances return empty list.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) - - assert table.nodes_at_distance(Distance(0)) == [] - assert table.nodes_at_distance(Distance(257)) == [] - - -class TestRoutingTableForkFiltering: - """Tests for routing table fork compatibility filtering.""" - - def _make_enr_with_eth2(self, fork_digest_hex: str) -> "ENR": - """Create a minimal ENR with eth2 data for testing.""" - from lean_spec.subspecs.networking.enr import ENR - from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH - from lean_spec.types import Bytes64 - from lean_spec.types.byte_arrays import Bytes4 - - # Create eth2 bytes: fork_digest(4) + next_fork_version(4) + next_fork_epoch(8) - fork_digest = Bytes4(bytes.fromhex(fork_digest_hex)) - eth2_bytes = ( - bytes(fork_digest) + bytes(fork_digest) + int(FAR_FUTURE_EPOCH).to_bytes(8, "little") - ) - enr = ENR( - signature=Bytes64(b"\x00" * 64), - seq=SeqNumber(1), - pairs={"eth2": eth2_bytes, "id": b"v4"}, - ) - return enr - - def test_no_filtering_without_local_fork_digest(self) -> None: - """Nodes are accepted when local_fork_digest is not set.""" - local_id = NodeId(b"\x00" * 32) - table = RoutingTable(local_id=local_id) # No fork_digest - - entry = NodeEntry(node_id=NodeId(b"\x01" * 32)) # No ENR - assert table.add(entry) - assert table.contains(entry.node_id) - - def test_filtering_rejects_node_without_enr(self) -> None: - """Node without ENR is rejected when fork filtering is enabled.""" - from lean_spec.types.byte_arrays import Bytes4 - - local_id = NodeId(b"\x00" * 32) - fork_digest = Bytes4(bytes.fromhex("12345678")) - table = RoutingTable(local_id=local_id, local_fork_digest=fork_digest) - - entry = NodeEntry(node_id=NodeId(b"\x01" * 32)) # No ENR - assert not table.add(entry) - assert not table.contains(entry.node_id) - - def test_filtering_rejects_mismatched_fork(self) -> None: - """Node with different fork_digest is rejected.""" - from lean_spec.types.byte_arrays import Bytes4 - - local_id = NodeId(b"\x00" * 32) - local_fork = Bytes4(bytes.fromhex("12345678")) - table = RoutingTable(local_id=local_id, local_fork_digest=local_fork) - - enr = self._make_enr_with_eth2("deadbeef") # Different fork - entry = NodeEntry(node_id=NodeId(b"\x01" * 32), enr=enr) - - assert not table.add(entry) - assert not table.contains(entry.node_id) - - def test_filtering_accepts_matching_fork(self) -> None: - """Node with matching fork_digest is accepted.""" - from lean_spec.types.byte_arrays import Bytes4 - - local_id = NodeId(b"\x00" * 32) - local_fork = Bytes4(bytes.fromhex("12345678")) - table = RoutingTable(local_id=local_id, local_fork_digest=local_fork) - - enr = self._make_enr_with_eth2("12345678") # Same fork - entry = NodeEntry(node_id=NodeId(b"\x01" * 32), enr=enr) - - assert table.add(entry) - assert table.contains(entry.node_id) - - def test_is_fork_compatible_method(self) -> None: - """Test is_fork_compatible method directly.""" - from lean_spec.types.byte_arrays import Bytes4 - - local_id = NodeId(b"\x00" * 32) - local_fork = Bytes4(bytes.fromhex("12345678")) - table = RoutingTable(local_id=local_id, local_fork_digest=local_fork) - - # Compatible entry - compatible_enr = self._make_enr_with_eth2("12345678") - compatible_entry = NodeEntry(node_id=NodeId(b"\x01" * 32), enr=compatible_enr) - assert table.is_fork_compatible(compatible_entry) - - # Incompatible entry (different fork) - incompatible_enr = self._make_enr_with_eth2("deadbeef") - incompatible_entry = NodeEntry(node_id=NodeId(b"\x02" * 32), enr=incompatible_enr) - assert not table.is_fork_compatible(incompatible_entry) - - # Entry without ENR - no_enr_entry = NodeEntry(node_id=NodeId(b"\x03" * 32)) - assert not table.is_fork_compatible(no_enr_entry) - - -class TestNodeEntry: - """Tests for NodeEntry dataclass.""" - - def test_default_values(self) -> None: - """NodeEntry has sensible defaults.""" - entry = NodeEntry(node_id=NodeId(b"\x01" * 32)) - - assert entry.node_id == NodeId(b"\x01" * 32) - assert entry.enr_seq == SeqNumber(0) - assert entry.last_seen == 0.0 - assert entry.endpoint is None - assert entry.verified is False - assert entry.enr is None - - def test_full_construction(self) -> None: - """NodeEntry accepts all fields.""" - entry = NodeEntry( - node_id=NodeId(b"\x01" * 32), - enr_seq=SeqNumber(42), - last_seen=1234567890.0, - endpoint="192.168.1.1:30303", - verified=True, - ) - - assert entry.enr_seq == SeqNumber(42) - assert entry.endpoint == "192.168.1.1:30303" - assert entry.verified is True - - -class TestMessageConstructionFromTestVectors: - """Test message construction using official Discovery v5 test vector inputs.""" - - # From https://github.com/ethereum/devp2p/blob/master/discv5/discv5-wire-test-vectors.md - PING_REQUEST_ID = bytes.fromhex("00000001") - PING_ENR_SEQ = 2 - WHOAREYOU_ID_NONCE = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") - - def test_ping_message_construction(self) -> None: - """Construct PING message matching test vector inputs.""" - ping = Ping( - request_id=RequestId(data=self.PING_REQUEST_ID), - enr_seq=SeqNumber(self.PING_ENR_SEQ), - ) - - assert ping.request_id.data == self.PING_REQUEST_ID - assert ping.enr_seq == SeqNumber(2) - - def test_whoareyou_authdata_construction(self) -> None: - """Construct WHOAREYOU authdata matching test vector inputs.""" - authdata = WhoAreYouAuthdata( - id_nonce=IdNonce(self.WHOAREYOU_ID_NONCE), - enr_seq=SeqNumber(0), - ) - - assert authdata.id_nonce == IdNonce(self.WHOAREYOU_ID_NONCE) - assert authdata.enr_seq == SeqNumber(0) - - def test_plaintext_message_type(self) -> None: - """PING message plaintext starts with message type 0x01.""" - # From AES-GCM test vector plaintext - plaintext = bytes.fromhex("01c20101") - assert plaintext[0] == MessageType.PING - - -class TestPacketStructure: - """Tests for Discovery v5 packet structure constants.""" - - def test_static_header_size(self) -> None: - """Static header is 23 bytes per spec.""" - # protocol-id (6) + version (2) + flag (1) + nonce (12) + authdata-size (2) - expected_size = 6 + 2 + 1 + 12 + 2 - assert expected_size == 23 - - -class TestRoutingWithTestVectorNodeIds: - """Tests using official test vector node IDs with routing functions.""" - - # Node IDs from official test vectors (keccak256 of uncompressed pubkey) - NODE_A_ID = bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") - NODE_B_ID = bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") - - def test_xor_distance_is_symmetric(self) -> None: - """XOR distance between test vector nodes is symmetric and non-zero.""" - node_a = NodeId(self.NODE_A_ID) - node_b = NodeId(self.NODE_B_ID) - - distance = xor_distance(node_a, node_b) - assert distance > 0 - assert xor_distance(node_a, node_b) == xor_distance(node_b, node_a) - - def test_log2_distance_is_high(self) -> None: - """Log2 distance between test vector nodes is high (differ in high bits).""" - node_a = NodeId(self.NODE_A_ID) - node_b = NodeId(self.NODE_B_ID) - - log_dist = log2_distance(node_a, node_b) - assert log_dist > Distance(200)