Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change paxe format #6

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.github.trex_paxos.paxe;

public record Channel(byte value) {
public static final byte CONSENSUS = 0;
public static final byte KEY_EXCHANGE = (byte)255;
public record Channel(short value) {
public static final short CONSENSUS = 0;
public static final short KEY_EXCHANGE = 255;
static final Channel CONSENSUS_CHANNEL = new Channel(CONSENSUS);
static final Channel KEY_EXCHANGE_CHANNEL = new Channel(KEY_EXCHANGE);
}
162 changes: 102 additions & 60 deletions trex-paxe/src/main/java/com/github/trex_paxos/paxe/PaxePacket.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.security.GeneralSecurityException;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;

import javax.crypto.Cipher;
Expand All @@ -12,92 +13,144 @@

public record PaxePacket(
NodeId from,
NodeId to,
NodeId to,
Channel channel,
byte flags,
byte[] nonce,
byte[] authTag,
Optional<byte[]> nonce,
Optional<byte[]> authTag,
byte[] payload) {

public static final int HEADER_SIZE = 6; // from(2), to(2), channel(1), flags(1)
public static final int AUTHENCIATED_DATA_SIZE = 5; // from(2), to(2), channel(1)
public static final int HEADER_SIZE = 8; // from(2) + to(2) + channel(2) + length(2)
public static final int AUTHENCIATED_DATA_SIZE = 6; // from(2) + to(2) + channel(2)
public static final int NONCE_SIZE = 12;
public static final int AUTH_TAG_SIZE = 16;
public static final int MAX_PACKET_LENGTH = 65535;

public PaxePacket {
Objects.requireNonNull(from, "from cannot be null");
Objects.requireNonNull(to, "to cannot be null");
Objects.requireNonNull(channel, "channel cannot be null");
Objects.requireNonNull(payload, "payload cannot be null");
Objects.requireNonNull(nonce, "nonce cannot be null");
Objects.requireNonNull(authTag, "authTag cannot be null");
Objects.requireNonNull(payload, "payload cannot be null");

if (nonce.length != NONCE_SIZE)
throw new IllegalArgumentException("Invalid nonce size");
if (authTag.length != AUTH_TAG_SIZE)
throw new IllegalArgumentException("Invalid auth tag size");
var totalSize = HEADER_SIZE + payload.length;
if (nonce.isPresent()) {
totalSize += NONCE_SIZE + AUTH_TAG_SIZE;
}
if (totalSize > MAX_PACKET_LENGTH) {
throw new IllegalArgumentException(
String.format("Total payload size %d when adding headers exceeds UDP limit of %d as %d", payload.length, MAX_PACKET_LENGTH, totalSize));
}

nonce.ifPresent(n -> {
if (n.length != NONCE_SIZE)
throw new IllegalArgumentException("Invalid nonce size");
});

authTag.ifPresent(t -> {
if (t.length != AUTH_TAG_SIZE)
throw new IllegalArgumentException("Invalid auth tag size");
});

if (nonce.isPresent() != authTag.isPresent()) {
throw new IllegalArgumentException("Both nonce and authTag must be present for encrypted packets");
}
}

// Legacy constructor for compatibility
public PaxePacket(NodeId from, NodeId to, Channel channel, byte flags, byte[] nonce, byte[] authTag, byte[] payload) {
this(from, to, channel,
Optional.of(nonce),
Optional.of(authTag),
payload);
}

// Constructor for unencrypted packets
public PaxePacket(NodeId from, NodeId to, Channel channel, byte[] payload) {
this(from, to, channel, Optional.empty(), Optional.empty(), payload);
}

private static void putLength(ByteBuffer buffer, int length) {
buffer.put((byte) ((length >>> 8) & 0xFF));
buffer.put((byte) (length & 0xFF));
}

private static int getLength(ByteBuffer buffer) {
return ((buffer.get() & 0xFF) << 8) | (buffer.get() & 0xFF);
}

public byte[] toBytes() {
var size = HEADER_SIZE + NONCE_SIZE + AUTH_TAG_SIZE + payload.length;
var size = HEADER_SIZE +
(nonce.isPresent() ? NONCE_SIZE + AUTH_TAG_SIZE : 0) +
payload.length;

var buffer = ByteBuffer.allocate(size);
buffer.putShort(from.id());
buffer.putShort(to.id());
buffer.put(channel.value());
buffer.put(flags);
buffer.put(nonce);
buffer.put(authTag);
buffer.putShort(channel.value());
putLength(buffer, payload.length);

nonce.ifPresent(buffer::put);
authTag.ifPresent(buffer::put);
buffer.put(payload);

return buffer.array();
}

public static PaxePacket fromBytes(byte[] bytes) {
var buffer = ByteBuffer.wrap(bytes);
var from = new NodeId(buffer.getShort());
var to = new NodeId(buffer.getShort());
var channel = new Channel(buffer.get());
var flags = buffer.get();
var channel = new Channel(buffer.getShort());
var payloadLength = getLength(buffer);

var nonce = new byte[NONCE_SIZE];
buffer.get(nonce);
var remaining = buffer.remaining();
var isEncrypted = remaining > payloadLength;

var authTag = new byte[AUTH_TAG_SIZE];
buffer.get(authTag);
Optional<byte[]> nonce = Optional.empty();
Optional<byte[]> authTag = Optional.empty();

if (isEncrypted) {
var n = new byte[NONCE_SIZE];
buffer.get(n);
nonce = Optional.of(n);

var payload = new byte[buffer.remaining()];
var t = new byte[AUTH_TAG_SIZE];
buffer.get(t);
authTag = Optional.of(t);
}

var payload = new byte[payloadLength];
buffer.get(payload);

return new PaxePacket(from, to, channel, flags, nonce, authTag, payload);
return new PaxePacket(from, to, channel, nonce, authTag, payload);
}

public byte[] authenticatedData() {
var buffer = ByteBuffer.allocate(AUTHENCIATED_DATA_SIZE);
buffer.putShort(from.id());
buffer.putShort(to.id());
buffer.put(channel.value());
buffer.putShort(channel.value());
return buffer.array();
}

public byte[] ciphertext() {
return payload; // Encrypted payload is the ciphertext
}
public static PaxeMessage decrypt(PaxePacket packet, byte[] key) {
if (packet.nonce.isEmpty() || packet.authTag.isEmpty()) {
throw new SecurityException("Cannot decrypt unencrypted packet");
}

static PaxeMessage decrypt(PaxePacket packet, byte[] key) {
try {
// Real decryption using AES-GCM
var cipher = Cipher.getInstance("AES/GCM/NoPadding");
var gcmSpec = new GCMParameterSpec(
PaxePacket.AUTH_TAG_SIZE * 8, packet.nonce());
var gcmSpec = new GCMParameterSpec(AUTH_TAG_SIZE * 8, packet.nonce.get());
cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"), gcmSpec);
cipher.updateAAD(packet.authenticatedData());

// Combine ciphertext and tag for decryption
var combined = new byte[packet.payload().length + packet.authTag().length];
System.arraycopy(packet.payload(), 0, combined, 0, packet.payload().length);
System.arraycopy(packet.authTag(), 0, combined, packet.payload().length, packet.authTag().length);
var combined = new byte[packet.payload.length + AUTH_TAG_SIZE];
System.arraycopy(packet.payload, 0, combined, 0, packet.payload.length);
System.arraycopy(packet.authTag.get(), 0, combined, packet.payload.length, AUTH_TAG_SIZE);

var decrypted = cipher.doFinal(combined);
return PaxeMessage.deserialize(packet.from(), packet.to(), packet.channel(), decrypted);
return PaxeMessage.deserialize(packet.from, packet.to, packet.channel, decrypted);
} catch (GeneralSecurityException e) {
throw new SecurityException("Decryption failed", e);
}
Expand All @@ -111,54 +164,43 @@ public static PaxePacket encrypt(PaxeMessage message, NodeId from, byte[] key) t
var gcmSpec = new GCMParameterSpec(AUTH_TAG_SIZE * 8, nonce);
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"), gcmSpec);

var tempPacket = new PaxePacket(
from,
message.to(),
message.channel(),
(byte) 0,
nonce,
new byte[AUTH_TAG_SIZE],
message.serialize());

var tempPacket = new PaxePacket(from, message.to(), message.channel(), message.serialize());
cipher.updateAAD(tempPacket.authenticatedData());

var ciphertext = cipher.doFinal(message.serialize());

// Extract the authentication tag from the end of the ciphertext
var authTag = new byte[AUTH_TAG_SIZE];
System.arraycopy(ciphertext, ciphertext.length - AUTH_TAG_SIZE, authTag, 0, AUTH_TAG_SIZE);

// Remove the authentication tag from the ciphertext
var actualCiphertext = new byte[ciphertext.length - AUTH_TAG_SIZE];
System.arraycopy(ciphertext, 0, actualCiphertext, 0, ciphertext.length - AUTH_TAG_SIZE);

return new PaxePacket(
from,
message.to(),
message.channel(),
(byte) 0,
nonce,
authTag,
Optional.of(nonce),
Optional.of(authTag),
actualCiphertext);
}
}

@Override
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof PaxePacket that)) return false;
return flags == that.flags
&& from.equals(that.from)
return from.equals(that.from)
&& to.equals(that.to)
&& channel.equals(that.channel)
&& Arrays.equals(nonce, that.nonce)
&& Arrays.equals(authTag, that.authTag)
&& Arrays.equals(nonce.orElse(null), that.nonce.orElse(null))
&& Arrays.equals(authTag.orElse(null), that.authTag.orElse(null))
&& Arrays.equals(payload, that.payload);
}

@Override
public int hashCode() {
int result = Objects.hash(from, to, channel, flags);
result = 31 * result + Arrays.hashCode(nonce);
result = 31 * result + Arrays.hashCode(authTag);
int result = Objects.hash(from, to, channel);
result = 31 * result + Arrays.hashCode(nonce.orElse(null));
result = 31 * result + Arrays.hashCode(authTag.orElse(null));
result = 31 * result + Arrays.hashCode(payload);
return result;
}
Expand Down
37 changes: 13 additions & 24 deletions trex-paxe/src/main/java/com/github/trex_paxos/paxe/PicklePaxe.java
Original file line number Diff line number Diff line change
@@ -1,29 +1,19 @@
package com.github.trex_paxos.paxe;

import com.github.trex_paxos.msg.*;

import com.github.trex_paxos.*;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.IntStream;

/// The PaXE protocol is designed to be compatible with QUIC or even raw UDP.
///
/// This class pickles and unpickles TrexMessages with a standard from/to/type header.
/// In the case of a DirectMessage, the to field is used to specify the destination node.
/// In the case of a BroadcastMessage, the to field is set to 0.
/// The purpose of the fixed size header is for rapid multiplexing and demultiplexing of messages.
public class PicklePaxe {

private static final int HEADER_SIZE = 5; // fromNode(2) + toNode(2) + type(1)
private static final int BALLOT_NUMBER_SIZE = Integer.BYTES + 2; // counter(4) + nodeId(2)

public static byte[] pickle(TrexMessage msg) {
// Calculate size needed for the message
int size = HEADER_SIZE + calculateMessageSize(msg);
ByteBuffer buffer = ByteBuffer.allocate(size);
writeHeader(msg, buffer);
Expand All @@ -43,7 +33,6 @@ private static int calculateMessageSize(TrexMessage msg) {
};
}

/// This must match the lookup table in the unpickle method
public static byte toByte(TrexMessage msg) {
return switch (msg) {
case Prepare _ -> 1;
Expand All @@ -68,19 +57,6 @@ private static void writeHeader(TrexMessage msg, ByteBuffer buffer) {
buffer.put(type);
}

private static void writeMessageBody(TrexMessage msg, ByteBuffer buffer) {
switch (msg) {
case Prepare p -> write(p, buffer);
case PrepareResponse p -> write(p, buffer);
case Accept a -> write(a, buffer);
case AcceptResponse a -> write(a, buffer);
case Fixed f -> write(f, buffer);
case Catchup c -> write(c, buffer);
case CatchupResponse c -> write(c, buffer);
default -> throw new IllegalArgumentException("Unknown message type: " + msg.getClass());
}
}

public static TrexMessage unpickle(ByteBuffer buffer) {
short fromNode = buffer.getShort();
short toNode = buffer.getShort();
Expand All @@ -97,6 +73,19 @@ public static TrexMessage unpickle(ByteBuffer buffer) {
};
}

private static void writeMessageBody(TrexMessage msg, ByteBuffer buffer) {
switch (msg) {
case Prepare p -> write(p, buffer);
case PrepareResponse p -> write(p, buffer);
case Accept a -> write(a, buffer);
case AcceptResponse a -> write(a, buffer);
case Fixed f -> write(f, buffer);
case Catchup c -> write(c, buffer);
case CatchupResponse c -> write(c, buffer);
default -> throw new IllegalArgumentException("Unknown message type: " + msg.getClass());
}
}

public static void write(PrepareResponse m, ByteBuffer buffer) {
write(m.vote(), buffer);
buffer.putLong(m.highestAcceptedIndex());
Expand Down
Loading
Loading