diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..82478b6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,29 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + branches: [master] + +jobs: + build: + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + if: github.event.pull_request.draft != true + + steps: + - uses: actions/checkout@v4 + + - name: Setup Zig + uses: mlugg/setup-zig@v2 + with: + version: 0.15.2 + + - name: Build + run: zig build + + - name: Run tests + run: zig build test diff --git a/README.md b/README.md index e346e2a..ddb6b83 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Yam -Lightweight Bitcoin P2P CLI network tool. Connect to nodes, observe mempool propagation, export data, and broadcast transactions (experimental). +Lightweight, zero-dependency Bitcoin P2P CLI network tool. Connect to nodes, observe mempool propagation, export data, and broadcast transactions (experimental). [Yam](https://en.wikipedia.org/wiki/Yam_(route)) is named after the Mongolian messaging system. diff --git a/src/courier.zig b/src/courier.zig index a801908..742c136 100644 --- a/src/courier.zig +++ b/src/courier.zig @@ -3,6 +3,7 @@ const std = @import("std"); const yam = @import("root.zig"); +const message_utils = @import("message_utils.zig"); /// Courier manages a connection to a single Bitcoin peer pub const Courier = struct { @@ -65,9 +66,17 @@ pub const Courier = struct { var received_version = false; var received_verack = false; + const timeout_ms: i64 = 30_000; + const start = std.time.milliTimestamp(); while (!received_version or !received_verack) { - const message = try self.readMessage(); + if (std.time.milliTimestamp() - start > timeout_ms) { + return error.HandshakeTimeout; + } + + // Use shared message reading utility with 4 MB limit and checksum verification + // (courier.zig enforces stricter limits for individual peer connections) + const message = try self.readMessageChecked(); defer if (message.payload.len > 0) self.allocator.free(message.payload); const cmd = std.mem.sliceTo(&message.header.command, 0); @@ -126,7 +135,8 @@ pub const Courier = struct { const elapsed: u64 = @intCast(std.time.milliTimestamp() - start); if (elapsed > timeout_ms) return false; - const message = self.readMessage() catch |err| { + // Use shared message reading utility with 4 MB limit and checksum verification + const message = self.readMessageChecked() catch |err| { if (err == error.WouldBlock) continue; return false; }; @@ -149,45 +159,6 @@ pub const Courier = struct { } } - /// Wait for a reject message (returns reason if rejected, null if no reject) - pub fn waitForReject(self: *Courier, timeout_ms: u64) !?[]u8 { - const start = std.time.milliTimestamp(); - - while (true) { - const elapsed: u64 = @intCast(std.time.milliTimestamp() - start); - if (elapsed > timeout_ms) return null; - - const message = self.readMessage() catch |err| { - if (err == error.WouldBlock) continue; - return null; - }; - - const cmd = std.mem.sliceTo(&message.header.command, 0); - - if (std.mem.eql(u8, cmd, "reject")) { - var fbs = std.io.fixedBufferStream(message.payload); - const reject = yam.RejectMessage.deserialize(fbs.reader(), self.allocator) catch { - self.allocator.free(message.payload); - return try self.allocator.dupe(u8, "unknown reject"); - }; - defer { - self.allocator.free(reject.message); - self.allocator.free(reject.data); - } - - // Keep the reason, free the rest - if (message.payload.len > 0) self.allocator.free(message.payload); - return reject.reason; - } else if (std.mem.eql(u8, cmd, "ping")) { - // Respond to pings - try self.sendMessage("pong", message.payload); - if (message.payload.len > 0) self.allocator.free(message.payload); - } else { - if (message.payload.len > 0) self.allocator.free(message.payload); - } - } - } - fn sendMessage(self: *Courier, command: []const u8, payload: []const u8) !void { const stream = self.stream orelse return error.NotConnected; @@ -200,46 +171,13 @@ pub const Courier = struct { } } - fn readMessage(self: *Courier) !struct { header: yam.MessageHeader, payload: []u8 } { + /// Helper method to read a message with courier's strict validation settings + /// (4 MB payload limit + checksum verification) + fn readMessageChecked(self: *Courier) !message_utils.Message { const stream = self.stream orelse return error.NotConnected; - - var header_buffer: [24]u8 align(4) = undefined; - var total_read: usize = 0; - while (total_read < header_buffer.len) { - const bytes_read = try stream.read(header_buffer[total_read..]); - if (bytes_read == 0) return error.ConnectionClosed; - total_read += bytes_read; - } - - const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer); - const header = header_ptr.*; - - if (header.magic != 0xD9B4BEF9) return error.InvalidMagic; - - var payload: []u8 = &.{}; - if (header.length > 0) { - if (header.length > 4_000_000) return error.PayloadTooLarge; - - payload = try self.allocator.alloc(u8, header.length); - errdefer self.allocator.free(payload); - - total_read = 0; - while (total_read < header.length) { - const bytes_read = try stream.read(payload[total_read..]); - if (bytes_read == 0) { - self.allocator.free(payload); - return error.ConnectionClosed; - } - total_read += bytes_read; - } - - const calculated_checksum = yam.calculateChecksum(payload); - if (calculated_checksum != header.checksum) { - self.allocator.free(payload); - return error.InvalidChecksum; - } - } - - return .{ .header = header, .payload = payload }; + return message_utils.readMessage(stream, self.allocator, .{ + .max_payload_size = message_utils.MAX_PAYLOAD_SIZE, + .verify_checksum = true, + }); } }; diff --git a/src/main.zig b/src/main.zig index f9f66b9..5cd1be1 100644 --- a/src/main.zig +++ b/src/main.zig @@ -220,9 +220,6 @@ fn broadcastTransaction(allocator: std.mem.Allocator, args: BroadcastArgs) !void if (result.success_count > 0) { std.debug.print("\nTransaction broadcast to {d} peer(s)\n", .{result.success_count}); - if (result.reject_count > 0) { - std.debug.print("Warning: {d} peer(s) rejected the transaction\n", .{result.reject_count}); - } } else { std.debug.print("\nError: Broadcast failed to all peers\n", .{}); } diff --git a/src/message_utils.zig b/src/message_utils.zig new file mode 100644 index 0000000..9e329e5 --- /dev/null +++ b/src/message_utils.zig @@ -0,0 +1,140 @@ +// message_utils.zig - Shared utilities for Bitcoin P2P message handling +// This module contains shared logic extracted from scout.zig and courier.zig +// to reduce code duplication and improve maintainability. + +const std = @import("std"); +const yam = @import("root.zig"); + +/// Maximum payload size for peer messages (4 MB) +/// This limit prevents memory exhaustion from malicious or misbehaving peers +pub const MAX_PAYLOAD_SIZE: u32 = 4_000_000; + +/// Options for configuring message reading behavior +pub const ReadMessageOptions = struct { + /// Maximum allowed payload size in bytes. If null, no limit is enforced. + /// courier.zig enforces a 4 MB limit for stricter peer connection management. + max_payload_size: ?u32 = null, + + /// Whether to verify the message checksum. If true, returns error.InvalidChecksum + /// when the calculated checksum doesn't match the header checksum. + verify_checksum: bool = false, +}; + +/// Result of reading a Bitcoin P2P protocol message +pub const Message = struct { + header: yam.MessageHeader, + payload: []u8, +}; + +/// Read a Bitcoin P2P protocol message from a stream +/// +/// This function reads a 24-byte message header followed by the payload. +/// It handles partial reads and validates the magic number. +/// +/// Caller is responsible for freeing the returned payload using the same allocator. +/// +/// Parameters: +/// - stream: The network stream to read from +/// - allocator: Memory allocator for payload allocation +/// - options: Configuration options (payload size limit, checksum verification) +/// +/// Returns: Message struct containing header and payload +/// +/// Errors: +/// - ConnectionClosed: Stream closed before full message received +/// - InvalidMagic: Header magic number doesn't match Bitcoin mainnet (0xD9B4BEF9) +/// - PayloadTooLarge: Payload exceeds max_payload_size (if specified in options) +/// - InvalidChecksum: Checksum verification failed (if enabled in options) +pub fn readMessage( + stream: std.net.Stream, + allocator: std.mem.Allocator, + options: ReadMessageOptions, +) !Message { + // Read the 24-byte message header + var header_buffer: [24]u8 align(4) = undefined; + var total_read: usize = 0; + while (total_read < header_buffer.len) { + const bytes_read = try stream.read(header_buffer[total_read..]); + if (bytes_read == 0) return error.ConnectionClosed; + total_read += bytes_read; + } + + // Parse header from buffer + const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer); + const header = header_ptr.*; + + // Validate magic number (Bitcoin mainnet) + if (header.magic != 0xD9B4BEF9) return error.InvalidMagic; + + // Read payload if present + var payload: []u8 = &.{}; + if (header.length > 0) { + // Enforce payload size limit if specified (e.g., 4 MB for courier.zig) + if (options.max_payload_size) |max_size| { + if (header.length > max_size) return error.PayloadTooLarge; + } + + // Allocate buffer for payload + payload = try allocator.alloc(u8, header.length); + errdefer allocator.free(payload); + + // Read payload data (may require multiple reads) + total_read = 0; + while (total_read < header.length) { + const bytes_read = try stream.read(payload[total_read..]); + if (bytes_read == 0) { + return error.ConnectionClosed; + } + total_read += bytes_read; + } + + // Verify checksum if requested (used by courier.zig for individual peer connections) + if (options.verify_checksum) { + const calculated_checksum = yam.calculateChecksum(payload); + if (calculated_checksum != header.checksum) { + return error.InvalidChecksum; + } + } + } + + return .{ .header = header, .payload = payload }; +} + +// ============================================================================ +// Tests +// ============================================================================ + +test "MAX_PAYLOAD_SIZE constant value" { + try std.testing.expectEqual(@as(u32, 4_000_000), MAX_PAYLOAD_SIZE); +} + +test "ReadMessageOptions default values" { + const opts = ReadMessageOptions{}; + try std.testing.expectEqual(@as(?u32, null), opts.max_payload_size); + try std.testing.expectEqual(false, opts.verify_checksum); +} + +test "ReadMessageOptions with custom values" { + const opts = ReadMessageOptions{ + .max_payload_size = MAX_PAYLOAD_SIZE, + .verify_checksum = true, + }; + try std.testing.expectEqual(@as(?u32, 4_000_000), opts.max_payload_size); + try std.testing.expectEqual(true, opts.verify_checksum); +} + +test "Message struct basic usage" { + const allocator = std.testing.allocator; + + const header = yam.MessageHeader.new("test", 0, 0); + const payload = try allocator.alloc(u8, 0); + defer allocator.free(payload); + + const message = Message{ + .header = header, + .payload = payload, + }; + + try std.testing.expectEqual(@as(u32, 0xD9B4BEF9), message.header.magic); + try std.testing.expectEqual(@as(usize, 0), message.payload.len); +} diff --git a/src/relay.zig b/src/relay.zig index ae0b6f6..ad21158 100644 --- a/src/relay.zig +++ b/src/relay.zig @@ -25,27 +25,15 @@ pub const BroadcastOptions = struct { pub const BroadcastReport = struct { peer: yam.PeerInfo, success: bool, - rejected: bool, - reject_reason: ?[]u8, elapsed_ms: u64, - - pub fn deinit(self: *BroadcastReport, allocator: std.mem.Allocator) void { - if (self.reject_reason) |reason| { - allocator.free(reason); - } - } }; /// Result of broadcasting to multiple peers pub const BroadcastResult = struct { reports: []BroadcastReport, success_count: usize, - reject_count: usize, pub fn deinit(self: *BroadcastResult, allocator: std.mem.Allocator) void { - for (self.reports) |*report| { - report.deinit(allocator); - } allocator.free(self.reports); } }; @@ -219,17 +207,10 @@ pub const Relay = struct { /// Broadcast a transaction to connected peers /// If max_peers is set, stops after that many successful broadcasts pub fn broadcastTx(self: *Relay, tx_bytes: []const u8, options: BroadcastOptions) !BroadcastResult { - // Allocate reports for actual broadcasts (may be fewer than couriers if max_peers set) var reports_list: std.ArrayList(BroadcastReport) = .empty; - errdefer { - for (reports_list.items) |*report| { - report.deinit(self.allocator); - } - reports_list.deinit(self.allocator); - } + errdefer reports_list.deinit(self.allocator); var success_count: usize = 0; - var reject_count: usize = 0; // Initialize RNG for staggered timing var rng_seed: u64 = undefined; @@ -257,12 +238,9 @@ pub const Relay = struct { var report = BroadcastReport{ .peer = courier.peer, .success = false, - .rejected = false, - .reject_reason = null, .elapsed_ms = 0, }; - // Send transaction courier.sendTx(tx_bytes) catch |err| { std.debug.print("Failed to send tx to peer: {s}\n", .{@errorName(err)}); report.elapsed_ms = @intCast(std.time.milliTimestamp() - start); @@ -270,18 +248,8 @@ pub const Relay = struct { continue; }; - // Wait briefly for potential reject message - const reject_result = courier.waitForReject(1000) catch null; - - if (reject_result) |reject| { - report.rejected = true; - report.reject_reason = reject; - reject_count += 1; - } else { - report.success = true; - success_count += 1; - } - + report.success = true; + success_count += 1; report.elapsed_ms = @intCast(std.time.milliTimestamp() - start); try reports_list.append(self.allocator, report); } @@ -289,7 +257,6 @@ pub const Relay = struct { return .{ .reports = try reports_list.toOwnedSlice(self.allocator), .success_count = success_count, - .reject_count = reject_count, }; } }; @@ -302,21 +269,12 @@ pub fn printBroadcastReport(reports: []const BroadcastReport, allocator: std.mem for (reports) |report| { const addr_str = report.peer.format(); - const status = if (report.success) - "SUCCESS" - else if (report.rejected) - "REJECTED" - else - "FAILED"; + const status = if (report.success) "SUCCESS" else "FAILED"; std.debug.print("{s}: {s} ({d}ms)\n", .{ std.mem.sliceTo(&addr_str, ' '), status, report.elapsed_ms, }); - - if (report.reject_reason) |reason| { - std.debug.print(" Reason: {s}\n", .{reason}); - } } } diff --git a/src/root.zig b/src/root.zig index 844cadf..9ccc734 100644 --- a/src/root.zig +++ b/src/root.zig @@ -111,6 +111,7 @@ pub const VersionPayload = struct { // User Agent: CompactSize length + string bytes const user_agent_len = try readVarInt(reader); + if (user_agent_len > 256) return error.UserAgentTooLong; const user_agent = try allocator.alloc(u8, user_agent_len); errdefer allocator.free(user_agent); _ = try reader.readAll(user_agent); @@ -147,6 +148,7 @@ pub const InvType = enum(u32) { msg_witness_tx = 0x40000001, // Transaction with witness msg_witness_block = 0x40000002, // Block with witness msg_filtered_witness_block = 0x40000003, // Filtered block with witness + _, }; // Inventory vector: type + hash @@ -196,6 +198,7 @@ pub const InvMessage = struct { pub fn deserialize(reader: anytype, allocator: std.mem.Allocator) !InvMessage { // Read count as CompactSize const count = try readVarInt(reader); + if (count > 50000) return error.TooManyInvVectors; // Allocate array for vectors const vectors = try allocator.alloc(InvVector, count); @@ -253,6 +256,7 @@ pub const RejectMessage = struct { pub fn deserialize(reader: anytype, allocator: std.mem.Allocator) !RejectMessage { // Read message type (var_str) const msg_len = try readVarInt(reader); + if (msg_len > 12) return error.RejectMessageTooLong; const message = try allocator.alloc(u8, msg_len); errdefer allocator.free(message); _ = try reader.readAll(message); @@ -262,6 +266,7 @@ pub const RejectMessage = struct { // Read reason (var_str) const reason_len = try readVarInt(reader); + if (reason_len > 111) return error.RejectReasonTooLong; const reason = try allocator.alloc(u8, reason_len); errdefer allocator.free(reason); _ = try reader.readAll(reason); @@ -340,6 +345,7 @@ pub const TxInput = struct { // Read script length (CompactSize) const script_len = try readVarInt(reader); + if (script_len > 10000) return error.ScriptTooLarge; // Read script const script = try allocator.alloc(u8, script_len); @@ -391,6 +397,7 @@ pub const TxOutput = struct { // Read script length (CompactSize) const script_len = try readVarInt(reader); + if (script_len > 10000) return error.ScriptTooLarge; // Read script const script = try allocator.alloc(u8, script_len); @@ -486,6 +493,7 @@ pub const Transaction = struct { if (is_segwit) { for (inputs) |*input| { const witness_count = try readVarInt(reader); + if (witness_count > 500) return error.TooManyWitnessItems; const witness = try allocator.alloc([]u8, witness_count); var witness_initialized: usize = 0; errdefer { @@ -497,6 +505,7 @@ pub const Transaction = struct { for (witness) |*item| { const item_len = try readVarInt(reader); + if (item_len > 520) return error.WitnessItemTooLarge; item.* = try allocator.alloc(u8, item_len); errdefer allocator.free(item.*); _ = try reader.readAll(item.*); diff --git a/src/scout.zig b/src/scout.zig index 7b0b801..1a7cb46 100644 --- a/src/scout.zig +++ b/src/scout.zig @@ -3,6 +3,7 @@ const std = @import("std"); const yam = @import("root.zig"); +const message_utils = @import("message_utils.zig"); /// DNS seeds for Bitcoin mainnet peer discovery const dns_seeds = [_][]const u8{ @@ -144,7 +145,11 @@ fn queryPeerForAddresses(allocator: std.mem.Allocator, peer: yam.PeerInfo) ![]ya const elapsed: u64 = @intCast(std.time.nanoTimestamp() - start); if (elapsed > timeout_ns) break; - const message = readMessage(stream, allocator) catch break; + // Use same strict validation as courier (4MB limit + checksum verification) + const message = message_utils.readMessage(stream, allocator, .{ + .max_payload_size = message_utils.MAX_PAYLOAD_SIZE, + .verify_checksum = true, + }) catch break; defer if (message.payload.len > 0) allocator.free(message.payload); const cmd = std.mem.sliceTo(&message.header.command, 0); @@ -185,9 +190,19 @@ fn performHandshake(stream: std.net.Stream, allocator: std.mem.Allocator) !void var received_version = false; var received_verack = false; + const timeout_ms: i64 = 30_000; + const start = std.time.milliTimestamp(); while (!received_version or !received_verack) { - const message = try readMessage(stream, allocator); + if (std.time.milliTimestamp() - start > timeout_ms) { + return error.HandshakeTimeout; + } + + // Use same strict validation as courier (4MB limit + checksum verification) + const message = try message_utils.readMessage(stream, allocator, .{ + .max_payload_size = message_utils.MAX_PAYLOAD_SIZE, + .verify_checksum = true, + }); defer if (message.payload.len > 0) allocator.free(message.payload); const cmd = std.mem.sliceTo(&message.header.command, 0); @@ -216,39 +231,6 @@ fn sendMessage(stream: std.net.Stream, command: []const u8, payload: []const u8) } } -fn readMessage(stream: std.net.Stream, allocator: std.mem.Allocator) !struct { header: yam.MessageHeader, payload: []u8 } { - var header_buffer: [24]u8 align(4) = undefined; - var total_read: usize = 0; - while (total_read < header_buffer.len) { - const bytes_read = try stream.read(header_buffer[total_read..]); - if (bytes_read == 0) return error.ConnectionClosed; - total_read += bytes_read; - } - - const header_ptr = std.mem.bytesAsValue(yam.MessageHeader, &header_buffer); - const header = header_ptr.*; - - if (header.magic != 0xD9B4BEF9) return error.InvalidMagic; - - var payload: []u8 = &.{}; - if (header.length > 0) { - payload = try allocator.alloc(u8, header.length); - errdefer allocator.free(payload); - - total_read = 0; - while (total_read < header.length) { - const bytes_read = try stream.read(payload[total_read..]); - if (bytes_read == 0) { - allocator.free(payload); - return error.ConnectionClosed; - } - total_read += bytes_read; - } - } - - return .{ .header = header, .payload = payload }; -} - /// Select random peers from a list pub fn selectRandomPeers( allocator: std.mem.Allocator,