diff --git a/investigations/test/comprehensive_rust_compatibility_test.zig b/investigations/test/comprehensive_rust_compatibility_test.zig index 2f1b067..111d73c 100644 --- a/investigations/test/comprehensive_rust_compatibility_test.zig +++ b/investigations/test/comprehensive_rust_compatibility_test.zig @@ -45,7 +45,7 @@ fn testSignatureSchemeCorrectness( const epoch_u64 = @as(u64, epoch); if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break; if (iterations >= epoch) break; - try keypair.secret_key.advancePreparation(@intCast(log_lifetime)); + try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime)); iterations += 1; } @@ -88,7 +88,7 @@ fn testDeterministicBehavior(allocator: std.mem.Allocator) !void { const epoch_u64 = @as(u64, epoch); if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break; if (iterations >= epoch) break; - try keypair.secret_key.advancePreparation(8); + try keypair.secret_key.advancePreparation(scheme, 8); iterations += 1; } @@ -192,7 +192,7 @@ fn testKeyPreparationAdvancement(allocator: std.mem.Allocator) !void { log.print("Initial prepared interval: {} to {}\n", .{ initial_prepared.start, initial_prepared.end }); // Advance preparation - try keypair.secret_key.advancePreparation(8); + try keypair.secret_key.advancePreparation(scheme, 8); const advanced_prepared = keypair.secret_key.getPreparedInterval(8); log.print("Advanced prepared interval: {} to {}\n", .{ advanced_prepared.start, advanced_prepared.end }); diff --git a/investigations/test/encoding_variants_test.zig b/investigations/test/encoding_variants_test.zig index 7620505..e136bd4 100644 --- a/investigations/test/encoding_variants_test.zig +++ b/investigations/test/encoding_variants_test.zig @@ -34,7 +34,7 @@ fn testWinternitzVariants(allocator: std.mem.Allocator) !void { const epoch_u64 = @as(u64, epoch); if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break; if (iterations >= epoch) break; - try keypair.secret_key.advancePreparation(@intCast(log_lifetime)); + try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime)); iterations += 1; } @@ -79,7 +79,7 @@ fn testTargetSumVariants(allocator: std.mem.Allocator) !void { const epoch_u64 = @as(u64, epoch); if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break; if (iterations >= epoch) break; - try keypair.secret_key.advancePreparation(@intCast(log_lifetime)); + try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime)); iterations += 1; } @@ -128,7 +128,7 @@ fn testMultipleLifetimeConfigurations(allocator: std.mem.Allocator) !void { const epoch_u64 = @as(u64, epoch); if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break; if (iterations >= epoch) break; - try keypair.secret_key.advancePreparation(@intCast(log_lifetime)); + try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime)); iterations += 1; } @@ -180,7 +180,7 @@ fn testSignatureSchemeCorrectnessVariants(allocator: std.mem.Allocator) !void { const epoch_u64 = @as(u64, test_case.test_epoch); if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break; if (iterations >= test_case.test_epoch) break; - try keypair.secret_key.advancePreparation(@intCast(log_lifetime)); + try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime)); iterations += 1; } diff --git a/investigations/test/performance_benchmark_test.zig b/investigations/test/performance_benchmark_test.zig index 74270d6..8018bbe 100644 --- a/investigations/test/performance_benchmark_test.zig +++ b/investigations/test/performance_benchmark_test.zig @@ -72,7 +72,7 @@ fn benchmarkSigning(allocator: std.mem.Allocator, lifetime: hash_zig.KeyLifetime const epoch_u64 = @as(u64, epoch); if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break; if (prep_iterations >= epoch) break; - try keypair.secret_key.advancePreparation(@intCast(log_lifetime)); + try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime)); prep_iterations += 1; } @@ -125,7 +125,7 @@ fn benchmarkVerification(allocator: std.mem.Allocator, lifetime: hash_zig.KeyLif const epoch_u64 = @as(u64, epoch); if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break; if (prep_iterations >= epoch) break; - try keypair.secret_key.advancePreparation(@intCast(log_lifetime)); + try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime)); prep_iterations += 1; } diff --git a/investigations/test/performance_test.zig b/investigations/test/performance_test.zig index f3f7d7f..d3db6c0 100644 --- a/investigations/test/performance_test.zig +++ b/investigations/test/performance_test.zig @@ -278,7 +278,7 @@ test "chacha12 rng compatibility" { var seed: [32]u8 = undefined; @memset(&seed, 0x42); - var rng = chacha12_rng.init(seed); + var rng = chacha12_rng.ChaCha12Rng.init(seed); var prf_key: [32]u8 = undefined; rng.fill(&prf_key); @@ -368,25 +368,30 @@ test "epoch range validation" { @memset(&seed, 0x77); // Generate key with limited epoch range + // Note: keyGen(100, 10) expands to align with bottom tree boundaries + // For lifetime 2^8: leafs_per_bottom_tree = 16, min 2 bottom trees = 32 epochs + // Expansion: start aligns to 96, end expands to 128, so valid range is [96, 128) var keypair = try sig_scheme.keyGen(100, 10); defer keypair.secret_key.deinit(); log.print("Keypair generated:\n", .{}); - log.print(" Activation epoch: {}\n", .{keypair.secret_key.activation_epoch}); - log.print(" Active epochs: {}\n", .{keypair.secret_key.num_active_epochs}); + log.print(" Requested: activation=100, num_active=10\n", .{}); + log.print(" Expanded activation epoch: {}\n", .{keypair.secret_key.activation_epoch}); + log.print(" Expanded active epochs: {}\n", .{keypair.secret_key.num_active_epochs}); log.print(" Valid range: {} - {}\n\n", .{ keypair.secret_key.activation_epoch, keypair.secret_key.activation_epoch + keypair.secret_key.num_active_epochs - 1 }); const test_message = [_]u8{ 0x54, 0x65, 0x73, 0x74, 0x20, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65 } ++ [_]u8{0x00} ** 20; // "Test message" + padding - // Test signing within valid range - log.print("Signing at epoch 105 (valid)...", .{}); + // Test signing within valid range (epoch 105 is in expanded range [96, 128)) + log.print("Signing at epoch 105 (valid, in expanded range [96, 128))...", .{}); var sig_valid = try sig_scheme.sign(keypair.secret_key, 105, test_message); defer sig_valid.deinit(); log.print(" ✅\n", .{}); // Test signing outside valid range (should fail) - log.print("Signing at epoch 110 (invalid)...", .{}); - const result = sig_scheme.sign(keypair.secret_key, 110, test_message); + // Epoch 128 is outside the expanded range [96, 128) + log.print("Signing at epoch 128 (invalid, outside expanded range)...", .{}); + const result = sig_scheme.sign(keypair.secret_key, 128, test_message); try testing.expectError(error.KeyNotActive, result); log.print(" ✅ Correctly rejected\n", .{}); diff --git a/src/poseidon2/plonky3_field.zig b/src/poseidon2/plonky3_field.zig index 130e5b3..fded7bb 100644 --- a/src/poseidon2/plonky3_field.zig +++ b/src/poseidon2/plonky3_field.zig @@ -67,6 +67,11 @@ pub const KoalaBearField = struct { return KoalaBearField{ .value = toMonty(inv_normal) }; } + // Field negation (additive inverse) + pub fn neg(self: KoalaBearField) KoalaBearField { + return KoalaBearField.zero.sub(self); + } + // Double operation (exact from Plonky3) pub fn double(self: KoalaBearField) KoalaBearField { return self.add(self); diff --git a/src/signature/mod.zig b/src/signature/mod.zig index fa56c51..e4f69dc 100644 --- a/src/signature/mod.zig +++ b/src/signature/mod.zig @@ -5,3 +5,7 @@ pub const GeneralizedXMSSPublicKey = native.GeneralizedXMSSPublicKey; pub const GeneralizedXMSSSecretKey = native.GeneralizedXMSSSecretKey; pub const GeneralizedXMSSSignature = native.GeneralizedXMSSSignature; pub const HashTreeOpening = native.HashTreeOpening; + +test { + _ = @import("native/scheme.zig"); +} diff --git a/src/signature/native/scheme.zig b/src/signature/native/scheme.zig index 089f917..4eb4b30 100644 --- a/src/signature/native/scheme.zig +++ b/src/signature/native/scheme.zig @@ -76,15 +76,15 @@ test "SSZ: HashTreeOpening roundtrip" { const original = try HashTreeOpening.init(allocator, nodes); defer original.deinit(); - // Encode + // Encode (using 8 field elements per hash, matching lifetime_2_8) var encoded = std.ArrayList(u8).init(allocator); defer encoded.deinit(); - try original.sszEncode(&encoded); + try original.sszEncode(&encoded, 8); - // Decode + // Decode (stack-allocated, so only free nodes, not the struct itself) var decoded: HashTreeOpening = undefined; - defer decoded.deinit(); - try HashTreeOpening.sszDecode(encoded.items, &decoded, allocator); + defer allocator.free(decoded.nodes); + try HashTreeOpening.sszDecode(encoded.items, &decoded, allocator, 8); // Verify const decoded_nodes = decoded.getNodes(); @@ -778,18 +778,18 @@ pub const HashTreeOpening = struct { } /// Serialize to SSZ bytes (convenience method matching Rust's to_bytes) - pub fn toBytes(self: *const HashTreeOpening, allocator: std.mem.Allocator) ![]u8 { + pub fn toBytes(self: *const HashTreeOpening, allocator: std.mem.Allocator, hash_len_fe: usize) ![]u8 { var encoded = std.ArrayList(u8).init(allocator); errdefer encoded.deinit(); - try self.sszEncode(&encoded); + try self.sszEncode(&encoded, hash_len_fe); return encoded.toOwnedSlice(); } /// Deserialize from SSZ bytes (convenience method matching Rust's from_bytes) - pub fn fromBytes(serialized: []const u8, allocator: std.mem.Allocator) !*HashTreeOpening { + pub fn fromBytes(serialized: []const u8, allocator: std.mem.Allocator, hash_len_fe: usize) !*HashTreeOpening { const decoded = try allocator.create(HashTreeOpening); errdefer allocator.destroy(decoded); - try sszDecode(serialized, decoded, allocator); + try sszDecode(serialized, decoded, allocator, hash_len_fe); return decoded; } }; @@ -4788,16 +4788,11 @@ pub const GeneralizedXMSSSignatureScheme = struct { activation_epoch: usize, num_active_epochs: usize, ) !KeyGenResult { - // Rust leansig library multiplies num_active_epochs by 128 internally - // To match Rust's behavior exactly, we multiply by 128 here - // Example: Input 1024 -> Rust stores 131072 (1024 * 128) in SSZ - const rust_compatible_num_active_epochs = num_active_epochs * 128; - // Generate random parameter and PRF key (matching Rust order exactly) const parameter = try self.generateRandomParameter(); const prf_key = try self.generateRandomPRFKey(); // RNG has already been consumed by generateRandomPRFKey() (32 bytes) - return self.keyGenWithParameter(activation_epoch, rust_compatible_num_active_epochs, parameter, prf_key, true); + return self.keyGenWithParameter(activation_epoch, num_active_epochs, parameter, prf_key, true); } /// Key generation with provided parameter and PRF key (for reconstructing keys from serialized data) @@ -4831,7 +4826,8 @@ pub const GeneralizedXMSSSignatureScheme = struct { return error.InvalidActivationParameters; } - // Expand activation time to align with bottom trees + // Expand activation time to align with bottom trees (matching Rust exactly) + const leafs_per_bottom_tree = @as(usize, 1) << @intCast(self.lifetime_params.log_lifetime / 2); const expansion_result = expandActivationTime(self.lifetime_params.log_lifetime, activation_epoch, num_active_epochs); const num_bottom_trees = expansion_result.end - expansion_result.start; @@ -4839,9 +4835,9 @@ pub const GeneralizedXMSSSignatureScheme = struct { return error.InsufficientBottomTrees; } - // Use the provided activation parameters directly (not expanded) - const expanded_activation_epoch = activation_epoch; - const expanded_num_active_epochs = num_active_epochs; + // Compute expanded values from bottom tree alignment (matching Rust lines 692-693) + const expanded_activation_epoch = expansion_result.start * leafs_per_bottom_tree; + const expanded_num_active_epochs = num_bottom_trees * leafs_per_bottom_tree; // Consume RNG state only if it hasn't been consumed yet // When called from keyGen(), the RNG state is already after PRF key generation (32 bytes consumed). @@ -5328,17 +5324,13 @@ pub const GeneralizedXMSSSignatureScheme = struct { // For top tree, use bottom_tree_index directly (absolute position) // Rust's combined_path uses epoch directly, and the top tree layers are built - // with start_index = left_bottom_tree_index from keyGen, so we use bottom_tree_index - // directly, and computePathFromLayers handles the offset via layer.start_index subtraction - // left_bottom_tree_index already declared above + // with start_index from keyGen (expansion_result.start), so we use bottom_tree_index + // directly. computePathFromLayers handles the offset via layer.start_index subtraction. + // Note: left_bottom_tree_index changes after advancePreparation, but the top tree's + // layer start_index values remain fixed from keygen. The path computation is independent + // of the current prepared position - it uses each layer's own start_index. const top_pos = @as(u32, @intCast(bottom_tree_index)); - // This ensures the path computation uses the correct offset - if (top_layers.len > 0 and top_layers[0].start_index != left_bottom_tree_index) { - log.debugPrint("ZIG_SIGN_ERROR: Top tree layer start_index mismatch! top_layers[0].start_index={}, left_bottom_tree_index={}\n", .{ top_layers[0].start_index, left_bottom_tree_index }); - return error.TopTreeStartIndexMismatch; - } - // Debug: log top tree layer start_index values log.print("ZIG_SIGN_DEBUG: Computing top tree path: bottom_tree_index={}, left_bottom_tree_index={}, top_pos={}\n", .{ bottom_tree_index, left_bottom_tree_index, top_pos }); for (top_layers, 0..) |layer, i| {