diff --git a/trex-paxe/src/main/java/com/github/trex_paxos/paxe/Channel.java b/trex-paxe/src/main/java/com/github/trex_paxos/paxe/Channel.java index 034011c..ae72950 100644 --- a/trex-paxe/src/main/java/com/github/trex_paxos/paxe/Channel.java +++ b/trex-paxe/src/main/java/com/github/trex_paxos/paxe/Channel.java @@ -1,8 +1,8 @@ package com.github.trex_paxos.paxe; -public record Channel(byte value) { - public static final byte CONSENSUS = 0; - public static final byte KEY_EXCHANGE = (byte)255; +public record Channel(short value) { + public static final short CONSENSUS = 0; + public static final short KEY_EXCHANGE = 255; static final Channel CONSENSUS_CHANNEL = new Channel(CONSENSUS); static final Channel KEY_EXCHANGE_CHANNEL = new Channel(KEY_EXCHANGE); } \ No newline at end of file diff --git a/trex-paxe/src/main/java/com/github/trex_paxos/paxe/PaxePacket.java b/trex-paxe/src/main/java/com/github/trex_paxos/paxe/PaxePacket.java index 0e7c547..27961db 100644 --- a/trex-paxe/src/main/java/com/github/trex_paxos/paxe/PaxePacket.java +++ b/trex-paxe/src/main/java/com/github/trex_paxos/paxe/PaxePacket.java @@ -4,6 +4,7 @@ import java.security.GeneralSecurityException; import java.util.Arrays; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.ThreadLocalRandom; import javax.crypto.Cipher; @@ -12,42 +13,87 @@ public record PaxePacket( NodeId from, - NodeId to, + NodeId to, Channel channel, - byte flags, - byte[] nonce, - byte[] authTag, + Optional nonce, + Optional authTag, byte[] payload) { - public static final int HEADER_SIZE = 6; // from(2), to(2), channel(1), flags(1) - public static final int AUTHENCIATED_DATA_SIZE = 5; // from(2), to(2), channel(1) + public static final int HEADER_SIZE = 8; // from(2) + to(2) + channel(2) + length(2) + public static final int AUTHENCIATED_DATA_SIZE = 6; // from(2) + to(2) + channel(2) public static final int NONCE_SIZE = 12; public static final int AUTH_TAG_SIZE = 16; + public static final int MAX_PACKET_LENGTH = 65535; public PaxePacket { Objects.requireNonNull(from, "from cannot be null"); Objects.requireNonNull(to, "to cannot be null"); Objects.requireNonNull(channel, "channel cannot be null"); + Objects.requireNonNull(payload, "payload cannot be null"); Objects.requireNonNull(nonce, "nonce cannot be null"); Objects.requireNonNull(authTag, "authTag cannot be null"); - Objects.requireNonNull(payload, "payload cannot be null"); - if (nonce.length != NONCE_SIZE) - throw new IllegalArgumentException("Invalid nonce size"); - if (authTag.length != AUTH_TAG_SIZE) - throw new IllegalArgumentException("Invalid auth tag size"); + var totalSize = HEADER_SIZE + payload.length; + if (nonce.isPresent()) { + totalSize += NONCE_SIZE + AUTH_TAG_SIZE; + } + if (totalSize > MAX_PACKET_LENGTH) { + throw new IllegalArgumentException( + String.format("Total payload size %d when adding headers exceeds UDP limit of %d as %d", payload.length, MAX_PACKET_LENGTH, totalSize)); + } + + nonce.ifPresent(n -> { + if (n.length != NONCE_SIZE) + throw new IllegalArgumentException("Invalid nonce size"); + }); + + authTag.ifPresent(t -> { + if (t.length != AUTH_TAG_SIZE) + throw new IllegalArgumentException("Invalid auth tag size"); + }); + + if (nonce.isPresent() != authTag.isPresent()) { + throw new IllegalArgumentException("Both nonce and authTag must be present for encrypted packets"); + } + } + + // Legacy constructor for compatibility + public PaxePacket(NodeId from, NodeId to, Channel channel, byte flags, byte[] nonce, byte[] authTag, byte[] payload) { + this(from, to, channel, + Optional.of(nonce), + Optional.of(authTag), + payload); + } + + // Constructor for unencrypted packets + public PaxePacket(NodeId from, NodeId to, Channel channel, byte[] payload) { + this(from, to, channel, Optional.empty(), Optional.empty(), payload); + } + + private static void putLength(ByteBuffer buffer, int length) { + buffer.put((byte) ((length >>> 8) & 0xFF)); + buffer.put((byte) (length & 0xFF)); + } + + private static int getLength(ByteBuffer buffer) { + return ((buffer.get() & 0xFF) << 8) | (buffer.get() & 0xFF); } public byte[] toBytes() { - var size = HEADER_SIZE + NONCE_SIZE + AUTH_TAG_SIZE + payload.length; + var size = HEADER_SIZE + + (nonce.isPresent() ? NONCE_SIZE + AUTH_TAG_SIZE : 0) + + payload.length; + var buffer = ByteBuffer.allocate(size); buffer.putShort(from.id()); buffer.putShort(to.id()); - buffer.put(channel.value()); - buffer.put(flags); - buffer.put(nonce); - buffer.put(authTag); + buffer.putShort(channel.value()); + putLength(buffer, payload.length); + + nonce.ifPresent(buffer::put); + authTag.ifPresent(buffer::put); buffer.put(payload); + return buffer.array(); } @@ -55,49 +101,56 @@ public static PaxePacket fromBytes(byte[] bytes) { var buffer = ByteBuffer.wrap(bytes); var from = new NodeId(buffer.getShort()); var to = new NodeId(buffer.getShort()); - var channel = new Channel(buffer.get()); - var flags = buffer.get(); + var channel = new Channel(buffer.getShort()); + var payloadLength = getLength(buffer); - var nonce = new byte[NONCE_SIZE]; - buffer.get(nonce); + var remaining = buffer.remaining(); + var isEncrypted = remaining > payloadLength; - var authTag = new byte[AUTH_TAG_SIZE]; - buffer.get(authTag); + Optional nonce = Optional.empty(); + Optional authTag = Optional.empty(); + + if (isEncrypted) { + var n = new byte[NONCE_SIZE]; + buffer.get(n); + nonce = Optional.of(n); - var payload = new byte[buffer.remaining()]; + var t = new byte[AUTH_TAG_SIZE]; + buffer.get(t); + authTag = Optional.of(t); + } + + var payload = new byte[payloadLength]; buffer.get(payload); - return new PaxePacket(from, to, channel, flags, nonce, authTag, payload); + return new PaxePacket(from, to, channel, nonce, authTag, payload); } public byte[] authenticatedData() { var buffer = ByteBuffer.allocate(AUTHENCIATED_DATA_SIZE); buffer.putShort(from.id()); buffer.putShort(to.id()); - buffer.put(channel.value()); + buffer.putShort(channel.value()); return buffer.array(); } - public byte[] ciphertext() { - return payload; // Encrypted payload is the ciphertext - } + public static PaxeMessage decrypt(PaxePacket packet, byte[] key) { + if (packet.nonce.isEmpty() || packet.authTag.isEmpty()) { + throw new SecurityException("Cannot decrypt unencrypted packet"); + } - static PaxeMessage decrypt(PaxePacket packet, byte[] key) { try { - // Real decryption using AES-GCM var cipher = Cipher.getInstance("AES/GCM/NoPadding"); - var gcmSpec = new GCMParameterSpec( - PaxePacket.AUTH_TAG_SIZE * 8, packet.nonce()); + var gcmSpec = new GCMParameterSpec(AUTH_TAG_SIZE * 8, packet.nonce.get()); cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"), gcmSpec); cipher.updateAAD(packet.authenticatedData()); - // Combine ciphertext and tag for decryption - var combined = new byte[packet.payload().length + packet.authTag().length]; - System.arraycopy(packet.payload(), 0, combined, 0, packet.payload().length); - System.arraycopy(packet.authTag(), 0, combined, packet.payload().length, packet.authTag().length); + var combined = new byte[packet.payload.length + AUTH_TAG_SIZE]; + System.arraycopy(packet.payload, 0, combined, 0, packet.payload.length); + System.arraycopy(packet.authTag.get(), 0, combined, packet.payload.length, AUTH_TAG_SIZE); var decrypted = cipher.doFinal(combined); - return PaxeMessage.deserialize(packet.from(), packet.to(), packet.channel(), decrypted); + return PaxeMessage.deserialize(packet.from, packet.to, packet.channel, decrypted); } catch (GeneralSecurityException e) { throw new SecurityException("Decryption failed", e); } @@ -111,23 +164,14 @@ public static PaxePacket encrypt(PaxeMessage message, NodeId from, byte[] key) t var gcmSpec = new GCMParameterSpec(AUTH_TAG_SIZE * 8, nonce); cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"), gcmSpec); - var tempPacket = new PaxePacket( - from, - message.to(), - message.channel(), - (byte) 0, - nonce, - new byte[AUTH_TAG_SIZE], - message.serialize()); - + var tempPacket = new PaxePacket(from, message.to(), message.channel(), message.serialize()); cipher.updateAAD(tempPacket.authenticatedData()); + var ciphertext = cipher.doFinal(message.serialize()); - // Extract the authentication tag from the end of the ciphertext var authTag = new byte[AUTH_TAG_SIZE]; System.arraycopy(ciphertext, ciphertext.length - AUTH_TAG_SIZE, authTag, 0, AUTH_TAG_SIZE); - // Remove the authentication tag from the ciphertext var actualCiphertext = new byte[ciphertext.length - AUTH_TAG_SIZE]; System.arraycopy(ciphertext, 0, actualCiphertext, 0, ciphertext.length - AUTH_TAG_SIZE); @@ -135,30 +179,28 @@ public static PaxePacket encrypt(PaxeMessage message, NodeId from, byte[] key) t from, message.to(), message.channel(), - (byte) 0, - nonce, - authTag, + Optional.of(nonce), + Optional.of(authTag), actualCiphertext); - } + } - @Override + @Override public boolean equals(Object o) { if (this == o) return true; if (!(o instanceof PaxePacket that)) return false; - return flags == that.flags - && from.equals(that.from) + return from.equals(that.from) && to.equals(that.to) && channel.equals(that.channel) - && Arrays.equals(nonce, that.nonce) - && Arrays.equals(authTag, that.authTag) + && Arrays.equals(nonce.orElse(null), that.nonce.orElse(null)) + && Arrays.equals(authTag.orElse(null), that.authTag.orElse(null)) && Arrays.equals(payload, that.payload); } @Override public int hashCode() { - int result = Objects.hash(from, to, channel, flags); - result = 31 * result + Arrays.hashCode(nonce); - result = 31 * result + Arrays.hashCode(authTag); + int result = Objects.hash(from, to, channel); + result = 31 * result + Arrays.hashCode(nonce.orElse(null)); + result = 31 * result + Arrays.hashCode(authTag.orElse(null)); result = 31 * result + Arrays.hashCode(payload); return result; } diff --git a/trex-paxe/src/main/java/com/github/trex_paxos/paxe/PicklePaxe.java b/trex-paxe/src/main/java/com/github/trex_paxos/paxe/PicklePaxe.java index 5e5ba06..40eec96 100644 --- a/trex-paxe/src/main/java/com/github/trex_paxos/paxe/PicklePaxe.java +++ b/trex-paxe/src/main/java/com/github/trex_paxos/paxe/PicklePaxe.java @@ -1,9 +1,7 @@ package com.github.trex_paxos.paxe; import com.github.trex_paxos.msg.*; - import com.github.trex_paxos.*; - import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; @@ -11,19 +9,11 @@ import java.util.UUID; import java.util.stream.IntStream; -/// The PaXE protocol is designed to be compatible with QUIC or even raw UDP. -/// -/// This class pickles and unpickles TrexMessages with a standard from/to/type header. -/// In the case of a DirectMessage, the to field is used to specify the destination node. -/// In the case of a BroadcastMessage, the to field is set to 0. -/// The purpose of the fixed size header is for rapid multiplexing and demultiplexing of messages. public class PicklePaxe { - private static final int HEADER_SIZE = 5; // fromNode(2) + toNode(2) + type(1) private static final int BALLOT_NUMBER_SIZE = Integer.BYTES + 2; // counter(4) + nodeId(2) public static byte[] pickle(TrexMessage msg) { - // Calculate size needed for the message int size = HEADER_SIZE + calculateMessageSize(msg); ByteBuffer buffer = ByteBuffer.allocate(size); writeHeader(msg, buffer); @@ -43,7 +33,6 @@ private static int calculateMessageSize(TrexMessage msg) { }; } - /// This must match the lookup table in the unpickle method public static byte toByte(TrexMessage msg) { return switch (msg) { case Prepare _ -> 1; @@ -68,19 +57,6 @@ private static void writeHeader(TrexMessage msg, ByteBuffer buffer) { buffer.put(type); } - private static void writeMessageBody(TrexMessage msg, ByteBuffer buffer) { - switch (msg) { - case Prepare p -> write(p, buffer); - case PrepareResponse p -> write(p, buffer); - case Accept a -> write(a, buffer); - case AcceptResponse a -> write(a, buffer); - case Fixed f -> write(f, buffer); - case Catchup c -> write(c, buffer); - case CatchupResponse c -> write(c, buffer); - default -> throw new IllegalArgumentException("Unknown message type: " + msg.getClass()); - } - } - public static TrexMessage unpickle(ByteBuffer buffer) { short fromNode = buffer.getShort(); short toNode = buffer.getShort(); @@ -97,6 +73,19 @@ public static TrexMessage unpickle(ByteBuffer buffer) { }; } + private static void writeMessageBody(TrexMessage msg, ByteBuffer buffer) { + switch (msg) { + case Prepare p -> write(p, buffer); + case PrepareResponse p -> write(p, buffer); + case Accept a -> write(a, buffer); + case AcceptResponse a -> write(a, buffer); + case Fixed f -> write(f, buffer); + case Catchup c -> write(c, buffer); + case CatchupResponse c -> write(c, buffer); + default -> throw new IllegalArgumentException("Unknown message type: " + msg.getClass()); + } + } + public static void write(PrepareResponse m, ByteBuffer buffer) { write(m.vote(), buffer); buffer.putLong(m.highestAcceptedIndex()); diff --git a/trex-paxe/src/test/java/com/github/trex_paxos/paxe/PaxePacketTest.java b/trex-paxe/src/test/java/com/github/trex_paxos/paxe/PaxePacketTest.java index cc1adcd..a7f44cb 100644 --- a/trex-paxe/src/test/java/com/github/trex_paxos/paxe/PaxePacketTest.java +++ b/trex-paxe/src/test/java/com/github/trex_paxos/paxe/PaxePacketTest.java @@ -5,10 +5,11 @@ import javax.crypto.KeyGenerator; import java.security.GeneralSecurityException; import java.util.Arrays; +import java.util.Optional; + import static org.junit.jupiter.api.Assertions.*; class PaxePacketTest { - static { System.setProperty(SRPUtils.class.getName() + ".useHash", "SHA-1"); } @@ -19,36 +20,36 @@ class PaxePacketTest { void testConstructorAndGetters() { NodeId from = new NodeId((short) 1); NodeId to = new NodeId((short) 2); - Channel channel = new Channel((byte) 3); - byte flags = 0x04; + Channel channel = new Channel((short) 3); byte[] nonce = new byte[PaxePacket.NONCE_SIZE]; byte[] authTag = new byte[PaxePacket.AUTH_TAG_SIZE]; byte[] payload = "Test payload".getBytes(); - PaxePacket packet = new PaxePacket(from, to, channel, flags, nonce, authTag, payload); + PaxePacket packet = new PaxePacket(from, to, channel, Optional.of(nonce), Optional.of(authTag), payload); assertEquals(from, packet.from()); assertEquals(to, packet.to()); assertEquals(channel, packet.channel()); - assertEquals(flags, packet.flags()); - assertArrayEquals(nonce, packet.nonce()); - assertArrayEquals(authTag, packet.authTag()); + assertArrayEquals(nonce, packet.nonce().orElseThrow()); + assertArrayEquals(authTag, packet.authTag().orElseThrow()); assertArrayEquals(payload, packet.payload()); } @Test void testConstructorWithInvalidNonceSize() { assertThrows(IllegalArgumentException.class, - () -> new PaxePacket(new NodeId((short) 1), new NodeId((short) 2), new Channel((byte) 3), - (byte) 0, new byte[PaxePacket.NONCE_SIZE - 1], new byte[PaxePacket.AUTH_TAG_SIZE], + () -> new PaxePacket(new NodeId((short) 1), new NodeId((short) 2), new Channel((short) 3), + Optional.of(new byte[PaxePacket.NONCE_SIZE - 1]), + Optional.of(new byte[PaxePacket.AUTH_TAG_SIZE]), new byte[0])); } @Test void testConstructorWithInvalidAuthTagSize() { assertThrows(IllegalArgumentException.class, - () -> new PaxePacket(new NodeId((short) 1), new NodeId((short) 2), new Channel((byte) 3), - (byte) 0, new byte[PaxePacket.NONCE_SIZE], new byte[PaxePacket.AUTH_TAG_SIZE - 1], + () -> new PaxePacket(new NodeId((short) 1), new NodeId((short) 2), new Channel((short) 3), + Optional.of(new byte[PaxePacket.NONCE_SIZE]), + Optional.of(new byte[PaxePacket.AUTH_TAG_SIZE - 1]), new byte[0])); } @@ -56,39 +57,37 @@ void testConstructorWithInvalidAuthTagSize() { void testToBytes() { NodeId from = new NodeId((short) 1); NodeId to = new NodeId((short) 2); - Channel channel = new Channel((byte) 3); - byte flags = 0x04; + Channel channel = new Channel((short) 3); byte[] nonce = new byte[PaxePacket.NONCE_SIZE]; byte[] authTag = new byte[PaxePacket.AUTH_TAG_SIZE]; byte[] payload = "Test payload".getBytes(); - PaxePacket packet = new PaxePacket(from, to, channel, flags, nonce, authTag, payload); + PaxePacket packet = new PaxePacket(from, to, channel, Optional.of(nonce), Optional.of(authTag), payload); byte[] bytes = packet.toBytes(); assertEquals(PaxePacket.HEADER_SIZE + PaxePacket.NONCE_SIZE + PaxePacket.AUTH_TAG_SIZE + payload.length, bytes.length); - assertEquals((short) ((bytes[0] << 8) | (bytes[1] & 0xFF)), from.id()); - assertEquals((short) ((bytes[2] << 8) | (bytes[3] & 0xFF)), to.id()); - assertEquals(channel.value(), bytes[4]); - assertEquals(flags, bytes[5]); - assertArrayEquals(nonce, Arrays.copyOfRange(bytes, 6, 6 + PaxePacket.NONCE_SIZE)); - assertArrayEquals(authTag, Arrays.copyOfRange(bytes, 6 + PaxePacket.NONCE_SIZE, - 6 + PaxePacket.NONCE_SIZE + PaxePacket.AUTH_TAG_SIZE)); + assertEquals(from.id(), (short) ((bytes[0] << 8) | (bytes[1] & 0xFF))); + assertEquals(to.id(), (short) ((bytes[2] << 8) | (bytes[3] & 0xFF))); + assertEquals(channel.value(), (short) ((bytes[4] << 8) | (bytes[5] & 0xFF))); + assertEquals(payload.length, ((bytes[6] & 0xFF) << 8) | (bytes[7] & 0xFF)); + assertArrayEquals(nonce, Arrays.copyOfRange(bytes, 8, 8 + PaxePacket.NONCE_SIZE)); + assertArrayEquals(authTag, Arrays.copyOfRange(bytes, 8 + PaxePacket.NONCE_SIZE, + 8 + PaxePacket.NONCE_SIZE + PaxePacket.AUTH_TAG_SIZE)); assertArrayEquals(payload, - Arrays.copyOfRange(bytes, 6 + PaxePacket.NONCE_SIZE + PaxePacket.AUTH_TAG_SIZE, bytes.length)); + Arrays.copyOfRange(bytes, 8 + PaxePacket.NONCE_SIZE + PaxePacket.AUTH_TAG_SIZE, bytes.length)); } @Test void testFromBytes() { NodeId from = new NodeId((short) 1); NodeId to = new NodeId((short) 2); - Channel channel = new Channel((byte) 3); - byte flags = 0x04; + Channel channel = new Channel((short) 3); byte[] nonce = new byte[PaxePacket.NONCE_SIZE]; byte[] authTag = new byte[PaxePacket.AUTH_TAG_SIZE]; byte[] payload = "Test payload".getBytes(); - PaxePacket originalPacket = new PaxePacket(from, to, channel, flags, nonce, authTag, payload); + PaxePacket originalPacket = new PaxePacket(from, to, channel, Optional.of(nonce), Optional.of(authTag), payload); byte[] bytes = originalPacket.toBytes(); PaxePacket reconstructedPacket = PaxePacket.fromBytes(bytes); @@ -100,9 +99,8 @@ void testFromBytes() { void testAuthenticatedData() { NodeId from = new NodeId((short) 1); NodeId to = new NodeId((short) 2); - Channel channel = new Channel((byte) 3); - PaxePacket packet = new PaxePacket(from, to, channel, (byte) 0, new byte[PaxePacket.NONCE_SIZE], - new byte[PaxePacket.AUTH_TAG_SIZE], new byte[0]); + Channel channel = new Channel((short) 3); + PaxePacket packet = new PaxePacket(from, to, channel, Optional.empty(), Optional.empty(), new byte[0]); byte[] authenticatedData = packet.authenticatedData(); @@ -111,7 +109,8 @@ void testAuthenticatedData() { assertEquals((byte) from.id(), authenticatedData[1]); assertEquals((byte) (to.id() >> 8), authenticatedData[2]); assertEquals((byte) to.id(), authenticatedData[3]); - assertEquals(channel.value(), authenticatedData[4]); + assertEquals((byte) (channel.value() >> 8), authenticatedData[4]); + assertEquals((byte) channel.value(), authenticatedData[5]); } @Test @@ -120,20 +119,16 @@ void testEncryptDecrypt() throws GeneralSecurityException { keyGen.init(AES_KEY_SIZE); SecretKey key = keyGen.generateKey(); - // Create input data NodeId from = new NodeId((short) 1); PaxeMessage originalMessage = new PaxeMessage( from, new NodeId((short) 2), - new Channel((byte) 1), + new Channel((short) 1), "Hello, World!".getBytes()); - // Perform encryption and decryption PaxePacket encryptedPacket = PaxePacket.encrypt(originalMessage, from, key.getEncoded()); PaxeMessage decryptedMessage = PaxePacket.decrypt(encryptedPacket, key.getEncoded()); - // Assert equality assertEquals(originalMessage, decryptedMessage); } - -} +} \ No newline at end of file