diff --git a/src/courier.zig b/src/courier.zig index e3d86bf..742c136 100644 --- a/src/courier.zig +++ b/src/courier.zig @@ -3,6 +3,7 @@ const std = @import("std"); const yam = @import("root.zig"); +const message_utils = @import("message_utils.zig"); /// Courier manages a connection to a single Bitcoin peer pub const Courier = struct { @@ -73,7 +74,9 @@ pub const Courier = struct { return error.HandshakeTimeout; } - const message = try self.readMessage(); + // Use shared message reading utility with 4 MB limit and checksum verification + // (courier.zig enforces stricter limits for individual peer connections) + const message = try self.readMessageChecked(); defer if (message.payload.len > 0) self.allocator.free(message.payload); const cmd = std.mem.sliceTo(&message.header.command, 0); @@ -132,7 +135,8 @@ pub const Courier = struct { const elapsed: u64 = @intCast(std.time.milliTimestamp() - start); if (elapsed > timeout_ms) return false; - const message = self.readMessage() catch |err| { + // Use shared message reading utility with 4 MB limit and checksum verification + const message = self.readMessageChecked() catch |err| { if (err == error.WouldBlock) continue; return false; }; @@ -167,40 +171,13 @@ pub const Courier = struct { } } - fn readMessage(self: *Courier) !struct { header: yam.MessageHeader, payload: []u8 } { + /// Helper method to read a message with courier's strict validation settings + /// (4 MB payload limit + checksum verification) + fn readMessageChecked(self: *Courier) !message_utils.Message { const stream = self.stream orelse return error.NotConnected; - - var header_buffer: [24]u8 align(4) = undefined; - var total_read: usize = 0; - while (total_read < header_buffer.len) { - const bytes_read = try stream.read(header_buffer[total_read..]); - if (bytes_read == 0) return error.ConnectionClosed; - total_read += bytes_read; - } - - const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer); - const header = header_ptr.*; - - if (header.magic != 0xD9B4BEF9) return error.InvalidMagic; - - var payload: []u8 = &.{}; - if (header.length > 0) { - if (header.length > 4_000_000) return error.PayloadTooLarge; - - payload = try self.allocator.alloc(u8, header.length); - errdefer self.allocator.free(payload); - - total_read = 0; - while (total_read < header.length) { - const bytes_read = try stream.read(payload[total_read..]); - if (bytes_read == 0) return error.ConnectionClosed; - total_read += bytes_read; - } - - const calculated_checksum = yam.calculateChecksum(payload); - if (calculated_checksum != header.checksum) return error.InvalidChecksum; - } - - return .{ .header = header, .payload = payload }; + return message_utils.readMessage(stream, self.allocator, .{ + .max_payload_size = message_utils.MAX_PAYLOAD_SIZE, + .verify_checksum = true, + }); } }; diff --git a/src/message_utils.zig b/src/message_utils.zig new file mode 100644 index 0000000..9e329e5 --- /dev/null +++ b/src/message_utils.zig @@ -0,0 +1,140 @@ +// message_utils.zig - Shared utilities for Bitcoin P2P message handling +// This module contains shared logic extracted from scout.zig and courier.zig +// to reduce code duplication and improve maintainability. + +const std = @import("std"); +const yam = @import("root.zig"); + +/// Maximum payload size for peer messages (4 MB) +/// This limit prevents memory exhaustion from malicious or misbehaving peers +pub const MAX_PAYLOAD_SIZE: u32 = 4_000_000; + +/// Options for configuring message reading behavior +pub const ReadMessageOptions = struct { + /// Maximum allowed payload size in bytes. If null, no limit is enforced. + /// courier.zig enforces a 4 MB limit for stricter peer connection management. + max_payload_size: ?u32 = null, + + /// Whether to verify the message checksum. If true, returns error.InvalidChecksum + /// when the calculated checksum doesn't match the header checksum. + verify_checksum: bool = false, +}; + +/// Result of reading a Bitcoin P2P protocol message +pub const Message = struct { + header: yam.MessageHeader, + payload: []u8, +}; + +/// Read a Bitcoin P2P protocol message from a stream +/// +/// This function reads a 24-byte message header followed by the payload. +/// It handles partial reads and validates the magic number. +/// +/// Caller is responsible for freeing the returned payload using the same allocator. +/// +/// Parameters: +/// - stream: The network stream to read from +/// - allocator: Memory allocator for payload allocation +/// - options: Configuration options (payload size limit, checksum verification) +/// +/// Returns: Message struct containing header and payload +/// +/// Errors: +/// - ConnectionClosed: Stream closed before full message received +/// - InvalidMagic: Header magic number doesn't match Bitcoin mainnet (0xD9B4BEF9) +/// - PayloadTooLarge: Payload exceeds max_payload_size (if specified in options) +/// - InvalidChecksum: Checksum verification failed (if enabled in options) +pub fn readMessage( + stream: std.net.Stream, + allocator: std.mem.Allocator, + options: ReadMessageOptions, +) !Message { + // Read the 24-byte message header + var header_buffer: [24]u8 align(4) = undefined; + var total_read: usize = 0; + while (total_read < header_buffer.len) { + const bytes_read = try stream.read(header_buffer[total_read..]); + if (bytes_read == 0) return error.ConnectionClosed; + total_read += bytes_read; + } + + // Parse header from buffer + const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer); + const header = header_ptr.*; + + // Validate magic number (Bitcoin mainnet) + if (header.magic != 0xD9B4BEF9) return error.InvalidMagic; + + // Read payload if present + var payload: []u8 = &.{}; + if (header.length > 0) { + // Enforce payload size limit if specified (e.g., 4 MB for courier.zig) + if (options.max_payload_size) |max_size| { + if (header.length > max_size) return error.PayloadTooLarge; + } + + // Allocate buffer for payload + payload = try allocator.alloc(u8, header.length); + errdefer allocator.free(payload); + + // Read payload data (may require multiple reads) + total_read = 0; + while (total_read < header.length) { + const bytes_read = try stream.read(payload[total_read..]); + if (bytes_read == 0) { + return error.ConnectionClosed; + } + total_read += bytes_read; + } + + // Verify checksum if requested (used by courier.zig for individual peer connections) + if (options.verify_checksum) { + const calculated_checksum = yam.calculateChecksum(payload); + if (calculated_checksum != header.checksum) { + return error.InvalidChecksum; + } + } + } + + return .{ .header = header, .payload = payload }; +} + +// ============================================================================ +// Tests +// ============================================================================ + +test "MAX_PAYLOAD_SIZE constant value" { + try std.testing.expectEqual(@as(u32, 4_000_000), MAX_PAYLOAD_SIZE); +} + +test "ReadMessageOptions default values" { + const opts = ReadMessageOptions{}; + try std.testing.expectEqual(@as(?u32, null), opts.max_payload_size); + try std.testing.expectEqual(false, opts.verify_checksum); +} + +test "ReadMessageOptions with custom values" { + const opts = ReadMessageOptions{ + .max_payload_size = MAX_PAYLOAD_SIZE, + .verify_checksum = true, + }; + try std.testing.expectEqual(@as(?u32, 4_000_000), opts.max_payload_size); + try std.testing.expectEqual(true, opts.verify_checksum); +} + +test "Message struct basic usage" { + const allocator = std.testing.allocator; + + const header = yam.MessageHeader.new("test", 0, 0); + const payload = try allocator.alloc(u8, 0); + defer allocator.free(payload); + + const message = Message{ + .header = header, + .payload = payload, + }; + + try std.testing.expectEqual(@as(u32, 0xD9B4BEF9), message.header.magic); + try std.testing.expectEqual(@as(usize, 0), message.payload.len); +} diff --git a/src/scout.zig b/src/scout.zig index 5e2165f..1a7cb46 100644 --- a/src/scout.zig +++ b/src/scout.zig @@ -3,6 +3,7 @@ const std = @import("std"); const yam = @import("root.zig"); +const message_utils = @import("message_utils.zig"); /// DNS seeds for Bitcoin mainnet peer discovery const dns_seeds = [_][]const u8{ @@ -144,7 +145,11 @@ fn queryPeerForAddresses(allocator: std.mem.Allocator, peer: yam.PeerInfo) ![]ya const elapsed: u64 = @intCast(std.time.nanoTimestamp() - start); if (elapsed > timeout_ns) break; - const message = readMessage(stream, allocator) catch break; + // Use same strict validation as courier (4MB limit + checksum verification) + const message = message_utils.readMessage(stream, allocator, .{ + .max_payload_size = message_utils.MAX_PAYLOAD_SIZE, + .verify_checksum = true, + }) catch break; defer if (message.payload.len > 0) allocator.free(message.payload); const cmd = std.mem.sliceTo(&message.header.command, 0); @@ -185,9 +190,19 @@ fn performHandshake(stream: std.net.Stream, allocator: std.mem.Allocator) !void var received_version = false; var received_verack = false; + const timeout_ms: i64 = 30_000; + const start = std.time.milliTimestamp(); while (!received_version or !received_verack) { - const message = try readMessage(stream, allocator); + if (std.time.milliTimestamp() - start > timeout_ms) { + return error.HandshakeTimeout; + } + + // Use same strict validation as courier (4MB limit + checksum verification) + const message = try message_utils.readMessage(stream, allocator, .{ + .max_payload_size = message_utils.MAX_PAYLOAD_SIZE, + .verify_checksum = true, + }); defer if (message.payload.len > 0) allocator.free(message.payload); const cmd = std.mem.sliceTo(&message.header.command, 0); @@ -216,41 +231,6 @@ fn sendMessage(stream: std.net.Stream, command: []const u8, payload: []const u8) } } -fn readMessage(stream: std.net.Stream, allocator: std.mem.Allocator) !struct { header: yam.MessageHeader, payload: []u8 } { - var header_buffer: [24]u8 align(4) = undefined; - var total_read: usize = 0; - while (total_read < header_buffer.len) { - const bytes_read = try stream.read(header_buffer[total_read..]); - if (bytes_read == 0) return error.ConnectionClosed; - total_read += bytes_read; - } - - const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer); - const header = header_ptr.*; - - if (header.magic != 0xD9B4BEF9) return error.InvalidMagic; - - var payload: []u8 = &.{}; - if (header.length > 0) { - if (header.length > 4_000_000) return error.PayloadTooLarge; - - payload = try allocator.alloc(u8, header.length); - errdefer allocator.free(payload); - - total_read = 0; - while (total_read < header.length) { - const bytes_read = try stream.read(payload[total_read..]); - if (bytes_read == 0) return error.ConnectionClosed; - total_read += bytes_read; - } - - const calculated_checksum = yam.calculateChecksum(payload); - if (calculated_checksum != header.checksum) return error.InvalidChecksum; - } - - return .{ .header = header, .payload = payload }; -} - /// Select random peers from a list pub fn selectRandomPeers( allocator: std.mem.Allocator,