Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions investigations/test/comprehensive_rust_compatibility_test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn testSignatureSchemeCorrectness(
const epoch_u64 = @as(u64, epoch);
if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break;
if (iterations >= epoch) break;
try keypair.secret_key.advancePreparation(@intCast(log_lifetime));
try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime));
iterations += 1;
}

Expand Down Expand Up @@ -88,7 +88,7 @@ fn testDeterministicBehavior(allocator: std.mem.Allocator) !void {
const epoch_u64 = @as(u64, epoch);
if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break;
if (iterations >= epoch) break;
try keypair.secret_key.advancePreparation(8);
try keypair.secret_key.advancePreparation(scheme, 8);
iterations += 1;
}

Expand Down Expand Up @@ -192,7 +192,7 @@ fn testKeyPreparationAdvancement(allocator: std.mem.Allocator) !void {
log.print("Initial prepared interval: {} to {}\n", .{ initial_prepared.start, initial_prepared.end });

// Advance preparation
try keypair.secret_key.advancePreparation(8);
try keypair.secret_key.advancePreparation(scheme, 8);
const advanced_prepared = keypair.secret_key.getPreparedInterval(8);
log.print("Advanced prepared interval: {} to {}\n", .{ advanced_prepared.start, advanced_prepared.end });

Expand Down
8 changes: 4 additions & 4 deletions investigations/test/encoding_variants_test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn testWinternitzVariants(allocator: std.mem.Allocator) !void {
const epoch_u64 = @as(u64, epoch);
if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break;
if (iterations >= epoch) break;
try keypair.secret_key.advancePreparation(@intCast(log_lifetime));
try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime));
iterations += 1;
}

Expand Down Expand Up @@ -79,7 +79,7 @@ fn testTargetSumVariants(allocator: std.mem.Allocator) !void {
const epoch_u64 = @as(u64, epoch);
if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break;
if (iterations >= epoch) break;
try keypair.secret_key.advancePreparation(@intCast(log_lifetime));
try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime));
iterations += 1;
}

Expand Down Expand Up @@ -128,7 +128,7 @@ fn testMultipleLifetimeConfigurations(allocator: std.mem.Allocator) !void {
const epoch_u64 = @as(u64, epoch);
if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break;
if (iterations >= epoch) break;
try keypair.secret_key.advancePreparation(@intCast(log_lifetime));
try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime));
iterations += 1;
}

Expand Down Expand Up @@ -180,7 +180,7 @@ fn testSignatureSchemeCorrectnessVariants(allocator: std.mem.Allocator) !void {
const epoch_u64 = @as(u64, test_case.test_epoch);
if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break;
if (iterations >= test_case.test_epoch) break;
try keypair.secret_key.advancePreparation(@intCast(log_lifetime));
try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime));
iterations += 1;
}

Expand Down
4 changes: 2 additions & 2 deletions investigations/test/performance_benchmark_test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ fn benchmarkSigning(allocator: std.mem.Allocator, lifetime: hash_zig.KeyLifetime
const epoch_u64 = @as(u64, epoch);
if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break;
if (prep_iterations >= epoch) break;
try keypair.secret_key.advancePreparation(@intCast(log_lifetime));
try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime));
prep_iterations += 1;
}

Expand Down Expand Up @@ -125,7 +125,7 @@ fn benchmarkVerification(allocator: std.mem.Allocator, lifetime: hash_zig.KeyLif
const epoch_u64 = @as(u64, epoch);
if (epoch_u64 >= prepared_interval.start and epoch_u64 < prepared_interval.end) break;
if (prep_iterations >= epoch) break;
try keypair.secret_key.advancePreparation(@intCast(log_lifetime));
try keypair.secret_key.advancePreparation(scheme, @intCast(log_lifetime));
prep_iterations += 1;
}

Expand Down
19 changes: 12 additions & 7 deletions investigations/test/performance_test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ test "chacha12 rng compatibility" {
var seed: [32]u8 = undefined;
@memset(&seed, 0x42);

var rng = chacha12_rng.init(seed);
var rng = chacha12_rng.ChaCha12Rng.init(seed);
var prf_key: [32]u8 = undefined;
rng.fill(&prf_key);

Expand Down Expand Up @@ -368,25 +368,30 @@ test "epoch range validation" {
@memset(&seed, 0x77);

// Generate key with limited epoch range
// Note: keyGen(100, 10) expands to align with bottom tree boundaries
// For lifetime 2^8: leafs_per_bottom_tree = 16, min 2 bottom trees = 32 epochs
// Expansion: start aligns to 96, end expands to 128, so valid range is [96, 128)
var keypair = try sig_scheme.keyGen(100, 10);
defer keypair.secret_key.deinit();

log.print("Keypair generated:\n", .{});
log.print(" Activation epoch: {}\n", .{keypair.secret_key.activation_epoch});
log.print(" Active epochs: {}\n", .{keypair.secret_key.num_active_epochs});
log.print(" Requested: activation=100, num_active=10\n", .{});
log.print(" Expanded activation epoch: {}\n", .{keypair.secret_key.activation_epoch});
log.print(" Expanded active epochs: {}\n", .{keypair.secret_key.num_active_epochs});
log.print(" Valid range: {} - {}\n\n", .{ keypair.secret_key.activation_epoch, keypair.secret_key.activation_epoch + keypair.secret_key.num_active_epochs - 1 });

const test_message = [_]u8{ 0x54, 0x65, 0x73, 0x74, 0x20, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65 } ++ [_]u8{0x00} ** 20; // "Test message" + padding

// Test signing within valid range
log.print("Signing at epoch 105 (valid)...", .{});
// Test signing within valid range (epoch 105 is in expanded range [96, 128))
log.print("Signing at epoch 105 (valid, in expanded range [96, 128))...", .{});
var sig_valid = try sig_scheme.sign(keypair.secret_key, 105, test_message);
defer sig_valid.deinit();
log.print(" ✅\n", .{});

// Test signing outside valid range (should fail)
log.print("Signing at epoch 110 (invalid)...", .{});
const result = sig_scheme.sign(keypair.secret_key, 110, test_message);
// Epoch 128 is outside the expanded range [96, 128)
log.print("Signing at epoch 128 (invalid, outside expanded range)...", .{});
const result = sig_scheme.sign(keypair.secret_key, 128, test_message);
try testing.expectError(error.KeyNotActive, result);
log.print(" ✅ Correctly rejected\n", .{});

Expand Down
5 changes: 5 additions & 0 deletions src/poseidon2/plonky3_field.zig
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ pub const KoalaBearField = struct {
return KoalaBearField{ .value = toMonty(inv_normal) };
}

// Field negation (additive inverse)
pub fn neg(self: KoalaBearField) KoalaBearField {
return KoalaBearField.zero.sub(self);
}

// Double operation (exact from Plonky3)
pub fn double(self: KoalaBearField) KoalaBearField {
return self.add(self);
Expand Down
4 changes: 4 additions & 0 deletions src/signature/mod.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ pub const GeneralizedXMSSPublicKey = native.GeneralizedXMSSPublicKey;
pub const GeneralizedXMSSSecretKey = native.GeneralizedXMSSSecretKey;
pub const GeneralizedXMSSSignature = native.GeneralizedXMSSSignature;
pub const HashTreeOpening = native.HashTreeOpening;

test {
_ = @import("native/scheme.zig");
}
48 changes: 20 additions & 28 deletions src/signature/native/scheme.zig
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ test "SSZ: HashTreeOpening roundtrip" {
const original = try HashTreeOpening.init(allocator, nodes);
defer original.deinit();

// Encode
// Encode (using 8 field elements per hash, matching lifetime_2_8)
var encoded = std.ArrayList(u8).init(allocator);
defer encoded.deinit();
try original.sszEncode(&encoded);
try original.sszEncode(&encoded, 8);

// Decode
// Decode (stack-allocated, so only free nodes, not the struct itself)
var decoded: HashTreeOpening = undefined;
defer decoded.deinit();
try HashTreeOpening.sszDecode(encoded.items, &decoded, allocator);
defer allocator.free(decoded.nodes);
try HashTreeOpening.sszDecode(encoded.items, &decoded, allocator, 8);

// Verify
const decoded_nodes = decoded.getNodes();
Expand Down Expand Up @@ -778,18 +778,18 @@ pub const HashTreeOpening = struct {
}

/// Serialize to SSZ bytes (convenience method matching Rust's to_bytes)
pub fn toBytes(self: *const HashTreeOpening, allocator: std.mem.Allocator) ![]u8 {
pub fn toBytes(self: *const HashTreeOpening, allocator: std.mem.Allocator, hash_len_fe: usize) ![]u8 {
var encoded = std.ArrayList(u8).init(allocator);
errdefer encoded.deinit();
try self.sszEncode(&encoded);
try self.sszEncode(&encoded, hash_len_fe);
return encoded.toOwnedSlice();
}

/// Deserialize from SSZ bytes (convenience method matching Rust's from_bytes)
pub fn fromBytes(serialized: []const u8, allocator: std.mem.Allocator) !*HashTreeOpening {
pub fn fromBytes(serialized: []const u8, allocator: std.mem.Allocator, hash_len_fe: usize) !*HashTreeOpening {
const decoded = try allocator.create(HashTreeOpening);
errdefer allocator.destroy(decoded);
try sszDecode(serialized, decoded, allocator);
try sszDecode(serialized, decoded, allocator, hash_len_fe);
return decoded;
}
};
Expand Down Expand Up @@ -4788,16 +4788,11 @@ pub const GeneralizedXMSSSignatureScheme = struct {
activation_epoch: usize,
num_active_epochs: usize,
) !KeyGenResult {
// Rust leansig library multiplies num_active_epochs by 128 internally
// To match Rust's behavior exactly, we multiply by 128 here
// Example: Input 1024 -> Rust stores 131072 (1024 * 128) in SSZ
const rust_compatible_num_active_epochs = num_active_epochs * 128;

// Generate random parameter and PRF key (matching Rust order exactly)
const parameter = try self.generateRandomParameter();
const prf_key = try self.generateRandomPRFKey();
// RNG has already been consumed by generateRandomPRFKey() (32 bytes)
return self.keyGenWithParameter(activation_epoch, rust_compatible_num_active_epochs, parameter, prf_key, true);
return self.keyGenWithParameter(activation_epoch, num_active_epochs, parameter, prf_key, true);
}

/// Key generation with provided parameter and PRF key (for reconstructing keys from serialized data)
Expand Down Expand Up @@ -4831,17 +4826,18 @@ pub const GeneralizedXMSSSignatureScheme = struct {
return error.InvalidActivationParameters;
}

// Expand activation time to align with bottom trees
// Expand activation time to align with bottom trees (matching Rust exactly)
const leafs_per_bottom_tree = @as(usize, 1) << @intCast(self.lifetime_params.log_lifetime / 2);
const expansion_result = expandActivationTime(self.lifetime_params.log_lifetime, activation_epoch, num_active_epochs);
const num_bottom_trees = expansion_result.end - expansion_result.start;

if (num_bottom_trees < 2) {
return error.InsufficientBottomTrees;
}

// Use the provided activation parameters directly (not expanded)
const expanded_activation_epoch = activation_epoch;
const expanded_num_active_epochs = num_active_epochs;
// Compute expanded values from bottom tree alignment (matching Rust lines 692-693)
const expanded_activation_epoch = expansion_result.start * leafs_per_bottom_tree;
const expanded_num_active_epochs = num_bottom_trees * leafs_per_bottom_tree;

// Consume RNG state only if it hasn't been consumed yet
// When called from keyGen(), the RNG state is already after PRF key generation (32 bytes consumed).
Expand Down Expand Up @@ -5328,17 +5324,13 @@ pub const GeneralizedXMSSSignatureScheme = struct {

// For top tree, use bottom_tree_index directly (absolute position)
// Rust's combined_path uses epoch directly, and the top tree layers are built
// with start_index = left_bottom_tree_index from keyGen, so we use bottom_tree_index
// directly, and computePathFromLayers handles the offset via layer.start_index subtraction
// left_bottom_tree_index already declared above
// with start_index from keyGen (expansion_result.start), so we use bottom_tree_index
// directly. computePathFromLayers handles the offset via layer.start_index subtraction.
// Note: left_bottom_tree_index changes after advancePreparation, but the top tree's
// layer start_index values remain fixed from keygen. The path computation is independent
// of the current prepared position - it uses each layer's own start_index.
const top_pos = @as(u32, @intCast(bottom_tree_index));

// This ensures the path computation uses the correct offset
if (top_layers.len > 0 and top_layers[0].start_index != left_bottom_tree_index) {
log.debugPrint("ZIG_SIGN_ERROR: Top tree layer start_index mismatch! top_layers[0].start_index={}, left_bottom_tree_index={}\n", .{ top_layers[0].start_index, left_bottom_tree_index });
return error.TopTreeStartIndexMismatch;
}

// Debug: log top tree layer start_index values
log.print("ZIG_SIGN_DEBUG: Computing top tree path: bottom_tree_index={}, left_bottom_tree_index={}, top_pos={}\n", .{ bottom_tree_index, left_bottom_tree_index, top_pos });
for (top_layers, 0..) |layer, i| {
Expand Down
Loading