Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 13 additions & 36 deletions src/courier.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

const std = @import("std");
const yam = @import("root.zig");
const message_utils = @import("message_utils.zig");

/// Courier manages a connection to a single Bitcoin peer
pub const Courier = struct {
Expand Down Expand Up @@ -73,7 +74,9 @@ pub const Courier = struct {
return error.HandshakeTimeout;
}

const message = try self.readMessage();
// Use shared message reading utility with 4 MB limit and checksum verification
// (courier.zig enforces stricter limits for individual peer connections)
const message = try self.readMessageChecked();
defer if (message.payload.len > 0) self.allocator.free(message.payload);

const cmd = std.mem.sliceTo(&message.header.command, 0);
Expand Down Expand Up @@ -132,7 +135,8 @@ pub const Courier = struct {
const elapsed: u64 = @intCast(std.time.milliTimestamp() - start);
if (elapsed > timeout_ms) return false;

const message = self.readMessage() catch |err| {
// Use shared message reading utility with 4 MB limit and checksum verification
const message = self.readMessageChecked() catch |err| {
if (err == error.WouldBlock) continue;
return false;
};
Expand Down Expand Up @@ -167,40 +171,13 @@ pub const Courier = struct {
}
}

fn readMessage(self: *Courier) !struct { header: yam.MessageHeader, payload: []u8 } {
/// Helper method to read a message with courier's strict validation settings
/// (4 MB payload limit + checksum verification)
fn readMessageChecked(self: *Courier) !message_utils.Message {
const stream = self.stream orelse return error.NotConnected;

var header_buffer: [24]u8 align(4) = undefined;
var total_read: usize = 0;
while (total_read < header_buffer.len) {
const bytes_read = try stream.read(header_buffer[total_read..]);
if (bytes_read == 0) return error.ConnectionClosed;
total_read += bytes_read;
}

const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer);
const header = header_ptr.*;

if (header.magic != 0xD9B4BEF9) return error.InvalidMagic;

var payload: []u8 = &.{};
if (header.length > 0) {
if (header.length > 4_000_000) return error.PayloadTooLarge;

payload = try self.allocator.alloc(u8, header.length);
errdefer self.allocator.free(payload);

total_read = 0;
while (total_read < header.length) {
const bytes_read = try stream.read(payload[total_read..]);
if (bytes_read == 0) return error.ConnectionClosed;
total_read += bytes_read;
}

const calculated_checksum = yam.calculateChecksum(payload);
if (calculated_checksum != header.checksum) return error.InvalidChecksum;
}

return .{ .header = header, .payload = payload };
return message_utils.readMessage(stream, self.allocator, .{
.max_payload_size = message_utils.MAX_PAYLOAD_SIZE,
.verify_checksum = true,
});
}
};
140 changes: 140 additions & 0 deletions src/message_utils.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// message_utils.zig - Shared utilities for Bitcoin P2P message handling
// This module contains shared logic extracted from scout.zig and courier.zig
// to reduce code duplication and improve maintainability.

const std = @import("std");
const yam = @import("root.zig");

/// Maximum payload size for peer messages (4 MB)
/// This limit prevents memory exhaustion from malicious or misbehaving peers
pub const MAX_PAYLOAD_SIZE: u32 = 4_000_000;

/// Options for configuring message reading behavior
pub const ReadMessageOptions = struct {
/// Maximum allowed payload size in bytes. If null, no limit is enforced.
/// courier.zig enforces a 4 MB limit for stricter peer connection management.
max_payload_size: ?u32 = null,

/// Whether to verify the message checksum. If true, returns error.InvalidChecksum
/// when the calculated checksum doesn't match the header checksum.
verify_checksum: bool = false,
};

/// Result of reading a Bitcoin P2P protocol message
pub const Message = struct {
header: yam.MessageHeader,
payload: []u8,
};

/// Read a Bitcoin P2P protocol message from a stream
///
/// This function reads a 24-byte message header followed by the payload.
/// It handles partial reads and validates the magic number.
///
/// Caller is responsible for freeing the returned payload using the same allocator.
///
/// Parameters:
/// - stream: The network stream to read from
/// - allocator: Memory allocator for payload allocation
/// - options: Configuration options (payload size limit, checksum verification)
///
/// Returns: Message struct containing header and payload
///
/// Errors:
/// - ConnectionClosed: Stream closed before full message received
/// - InvalidMagic: Header magic number doesn't match Bitcoin mainnet (0xD9B4BEF9)
/// - PayloadTooLarge: Payload exceeds max_payload_size (if specified in options)
/// - InvalidChecksum: Checksum verification failed (if enabled in options)
pub fn readMessage(
stream: std.net.Stream,
allocator: std.mem.Allocator,
options: ReadMessageOptions,
) !Message {
// Read the 24-byte message header
var header_buffer: [24]u8 align(4) = undefined;
var total_read: usize = 0;
while (total_read < header_buffer.len) {
const bytes_read = try stream.read(header_buffer[total_read..]);
if (bytes_read == 0) return error.ConnectionClosed;
total_read += bytes_read;
}

// Parse header from buffer
const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer);
const header = header_ptr.*;

// Validate magic number (Bitcoin mainnet)
if (header.magic != 0xD9B4BEF9) return error.InvalidMagic;

// Read payload if present
var payload: []u8 = &.{};
if (header.length > 0) {
// Enforce payload size limit if specified (e.g., 4 MB for courier.zig)
if (options.max_payload_size) |max_size| {
if (header.length > max_size) return error.PayloadTooLarge;
}

// Allocate buffer for payload
payload = try allocator.alloc(u8, header.length);
errdefer allocator.free(payload);

// Read payload data (may require multiple reads)
total_read = 0;
while (total_read < header.length) {
const bytes_read = try stream.read(payload[total_read..]);
if (bytes_read == 0) {
return error.ConnectionClosed;
}
total_read += bytes_read;
}

// Verify checksum if requested (used by courier.zig for individual peer connections)
if (options.verify_checksum) {
const calculated_checksum = yam.calculateChecksum(payload);
if (calculated_checksum != header.checksum) {
return error.InvalidChecksum;
}
}
}

return .{ .header = header, .payload = payload };
}
Copy link
Contributor

@kwsantiago kwsantiago Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we have tests? I think you had in a previous commit that were useful.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a few basic ones. I hope that's ok? let me know anything else I should add.


// ============================================================================
// Tests
// ============================================================================

test "MAX_PAYLOAD_SIZE constant value" {
try std.testing.expectEqual(@as(u32, 4_000_000), MAX_PAYLOAD_SIZE);
}

test "ReadMessageOptions default values" {
const opts = ReadMessageOptions{};
try std.testing.expectEqual(@as(?u32, null), opts.max_payload_size);
try std.testing.expectEqual(false, opts.verify_checksum);
}

test "ReadMessageOptions with custom values" {
const opts = ReadMessageOptions{
.max_payload_size = MAX_PAYLOAD_SIZE,
.verify_checksum = true,
};
try std.testing.expectEqual(@as(?u32, 4_000_000), opts.max_payload_size);
try std.testing.expectEqual(true, opts.verify_checksum);
}

test "Message struct basic usage" {
const allocator = std.testing.allocator;

const header = yam.MessageHeader.new("test", 0, 0);
const payload = try allocator.alloc(u8, 0);
defer allocator.free(payload);

const message = Message{
.header = header,
.payload = payload,
};

try std.testing.expectEqual(@as(u32, 0xD9B4BEF9), message.header.magic);
try std.testing.expectEqual(@as(usize, 0), message.payload.len);
}
54 changes: 17 additions & 37 deletions src/scout.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

const std = @import("std");
const yam = @import("root.zig");
const message_utils = @import("message_utils.zig");

/// DNS seeds for Bitcoin mainnet peer discovery
const dns_seeds = [_][]const u8{
Expand Down Expand Up @@ -144,7 +145,11 @@ fn queryPeerForAddresses(allocator: std.mem.Allocator, peer: yam.PeerInfo) ![]ya
const elapsed: u64 = @intCast(std.time.nanoTimestamp() - start);
if (elapsed > timeout_ns) break;

const message = readMessage(stream, allocator) catch break;
// Use same strict validation as courier (4MB limit + checksum verification)
const message = message_utils.readMessage(stream, allocator, .{
.max_payload_size = message_utils.MAX_PAYLOAD_SIZE,
.verify_checksum = true,
}) catch break;
defer if (message.payload.len > 0) allocator.free(message.payload);

const cmd = std.mem.sliceTo(&message.header.command, 0);
Expand Down Expand Up @@ -185,9 +190,19 @@ fn performHandshake(stream: std.net.Stream, allocator: std.mem.Allocator) !void

var received_version = false;
var received_verack = false;
const timeout_ms: i64 = 30_000;
const start = std.time.milliTimestamp();

while (!received_version or !received_verack) {
const message = try readMessage(stream, allocator);
if (std.time.milliTimestamp() - start > timeout_ms) {
return error.HandshakeTimeout;
}

// Use same strict validation as courier (4MB limit + checksum verification)
const message = try message_utils.readMessage(stream, allocator, .{
.max_payload_size = message_utils.MAX_PAYLOAD_SIZE,
.verify_checksum = true,
});
Comment on lines 201 to 205
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing timeout in handshake loop. A malicious peer can hang indefinitely by sending continuous non-version/verack messages. The courier.zig performHandshake has a 30-second timeout, but scout.zig's performHandshake has none.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be covered now. added the same check to scout.

defer if (message.payload.len > 0) allocator.free(message.payload);

const cmd = std.mem.sliceTo(&message.header.command, 0);
Expand Down Expand Up @@ -216,41 +231,6 @@ fn sendMessage(stream: std.net.Stream, command: []const u8, payload: []const u8)
}
}

fn readMessage(stream: std.net.Stream, allocator: std.mem.Allocator) !struct { header: yam.MessageHeader, payload: []u8 } {
var header_buffer: [24]u8 align(4) = undefined;
var total_read: usize = 0;
while (total_read < header_buffer.len) {
const bytes_read = try stream.read(header_buffer[total_read..]);
if (bytes_read == 0) return error.ConnectionClosed;
total_read += bytes_read;
}

const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer);
const header = header_ptr.*;

if (header.magic != 0xD9B4BEF9) return error.InvalidMagic;

var payload: []u8 = &.{};
if (header.length > 0) {
if (header.length > 4_000_000) return error.PayloadTooLarge;

payload = try allocator.alloc(u8, header.length);
errdefer allocator.free(payload);

total_read = 0;
while (total_read < header.length) {
const bytes_read = try stream.read(payload[total_read..]);
if (bytes_read == 0) return error.ConnectionClosed;
total_read += bytes_read;
}

const calculated_checksum = yam.calculateChecksum(payload);
if (calculated_checksum != header.checksum) return error.InvalidChecksum;
}

return .{ .header = header, .payload = payload };
}

/// Select random peers from a list
pub fn selectRandomPeers(
allocator: std.mem.Allocator,
Expand Down