Skip to content

Commit

Permalink
Finish Pseudorandom Signature Tree
Browse files Browse the repository at this point in the history
  • Loading branch information
h114mx001 committed Jun 4, 2024
1 parent c6c428f commit 8f559cd
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 51 deletions.
2 changes: 1 addition & 1 deletion LamportSignature/Lamport.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_verify_key_pair_as_tuple(self) -> tuple:
'''
return (self._verify_key_0, self._verify_key_1)

def get_verify_key_pair_as_byte(self) -> bytes:
def get_verify_key_pair_as_bytes(self) -> bytes:
'''
Get the verify key pair as a bytestream
'''
Expand Down
104 changes: 58 additions & 46 deletions LamportSignature/PRSignatureTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,38 @@ class LamportSignatureTree_Signature:
'''
Define the signature layout for the Lamport signature tree
'''
def __init__(self, counter, authentication_path: list[bytes], message: bytes, signature: bytes):
def __init__(self, counter, authentication_path_signature: list[LamportSignature], signature: LamportSignature):
'''
Initialize the signature with the authentication path, message, and signature
+ authentication_path: The authentication path for the signature
+ message: The message that is signed
+ signature: The signature
'''
self.counter = counter
self.authentication_path = authentication_path
self.message = message
self.authentication_path_signature = authentication_path_signature
self.signature = signature

def serialize(self):
def serialize(self) -> str:
'''
Serialize the signature
'''
return json.dumps({
"counter": self.counter,
"authentication_path": [auth_path.hex() for auth_path in self.authentication_path],
"message": self.message.hex(),
"signature": self.signature.hex()
"authentication_path_signature": [signature.serialize() for signature in self.authentication_path_signature],
"signature": self.signature.serialize()
})

@staticmethod
def deserialize(serialized: str):
def deserialize(serialized_signature: str):
'''
Deserialize the signature
'''
data = json.loads(serialized)
return LamportSignatureTree_Signature(
data["counter"],
[bytes.fromhex(auth_path) for auth_path in data["authentication_path"]],
bytes.fromhex(data["message"]),
bytes.fromhex(data["signature"])
)
signature_dict = json.loads(serialized_signature)
counter = signature_dict["counter"]
authentication_path_signature = [LamportSignature.deserialize(signature) for signature in signature_dict["authentication_path_signature"]]
signature = LamportSignature.deserialize(signature_dict["signature"])
return LamportSignatureTree_Signature(counter, authentication_path_signature, signature)


class LamportPRSignatureTree_Node:
'''
Expand All @@ -67,12 +64,9 @@ def __init__(self, id: int, prf: AES_PRF):
self.signed = False
# signature caching, for the authentication path
self.signature = None

key = self.prf.eval(self.id)
lamport = Lamport_ChaCha20_SHA256_keygen(key)
self.verify_key = lamport.get_verify_key_pair()
# don't want to store the whole lamport object, just the verify key pair
del lamport
self.verify_key = lamport.get_verify_key_pairs()

def is_leaf_node(self):
'''
Expand All @@ -87,48 +81,59 @@ def sign(self, message: bytes):
'''
# if the node is not the leaf node, the signature is actually for authentication path. we can cache it to avoid recomputation.
# else, if the leaf node reuse the sign key, raise error

if self.signed:
if (not self.is_leaf_node()):
return self.signature
raise Exception("No reuse leaf's sign key for different message!")

key = self.prf.eval(self.id)
lamport = Lamport_ChaCha20_SHA256_keygen(key)
signature = lamport.sign(message)
signature = lamport.hash_and_sign(message)
self.signed = True
self.signature = signature
return signature

def get_verify_key_pair_as_bytes(self) -> bytes:
'''
Get the verify key pair as bytes
'''
return b"".join(self.verify_key_0 + self.verify_key_1)

def verify(self, message: bytes, signature: LamportSignature):
'''
Verify a message by the node's verify key pair
+ message: The message to verify
+ signature: The signature to verify
'''
return Lamport_ChaCha20_SHA256_Signature(message, self.verify_key_0, self.verify_key_1, signature).verify()
return self.verify_key.hash_and_verify(message, signature)

def get_children_verify_keys(self):
'''
Get the verify keys of the children's node
'''
left_node_verify_keys = self.left.verify_key.get_verify_key_pair_as_bytes()
right_node_verify_keys = self.right.verify_key.get_verify_key_pair_as_bytes()
return left_node_verify_keys + right_node_verify_keys

def get_authentication_path_signature(self):
def get_authentication_path_signature(self) -> LamportSignature:
'''
Return the signature of the children's node for the authentication path.
'''
# if the node is a leaf node, return None
if self.is_leaf_node():
return None
left_node_verify_keys = self.left.get_verify_key_pair_as_bytes()
right_node_verify_keys = self.right.get_verify_key_pair_as_bytes()
# Only need to have the signature, but not the whole object. we can regen them in verification step later on
return self.sign(left_node_verify_keys + right_node_verify_keys).signature
return self.sign(self.get_children_verify_keys())

def traverse_up_to_root(self):
'''
Traverse up to the root node and return the authentication path
'''
node = self
authentication_path = []
while node is not None:
authentication_path.append(node)
node = node.parent
return authentication_path[:0:-1]

class LamportPRSignatureTree:
'''
Design of a SignatureTree, with prescribed security level L
Design of a SignatureTree, with prescribed security level L.
This signature has a counter start from 0, and the leaf node has the same value with the counter.
'''
def __init__(self, L: int, key: bytes = None):
'''
Expand Down Expand Up @@ -192,7 +197,7 @@ def __get_leaf_with_id(self, counter: int) -> LamportPRSignatureTree_Node:
'''
return self.leaves[counter]

def __get_authentication_path(self, counter: int):
def __get_authentication_path_signature(self, counter: int):
'''
Get the authentication path for a node with a specific id
'''
Expand All @@ -203,30 +208,37 @@ def __get_authentication_path(self, counter: int):
node = node.parent
return authentication_path[:0:-1]

def sign(self, message: bytes):
def sign(self, message: bytes) -> LamportSignatureTree_Signature:
'''
Sign a message on the signature tree
'''
if self.counter > self.capacity:
raise Exception("Signature tree is full")
# get the current node to sign
node = self.__get_leaf_with_id(self.counter)
authentication_path = self.__get_authentication_path(self.counter)
authentication_path_signature = self.__get_authentication_path_signature(self.counter)
signature = node.sign(message)
# again, we don't want to store the whole lamport object, just the signature, as everything has been generated.
signature = LamportSignatureTree_Signature(self.counter, authentication_path, signature.message, signature.signature)
signature = LamportSignatureTree_Signature(self.counter, authentication_path_signature, signature)
self.counter += 1
return signature

def verify(self, signature: LamportSignatureTree_Signature):
def verify(self, offset: int, message: bytes, signature: LamportSignatureTree_Signature):
'''
Verify a signature
Verify a signature, with the offset on the signature tree
+ offset: The offset to verify the signature
+ message: The message to verify
+ signature: The signature to verify
'''
node = self.__get_leaf_with_id(signature.counter)
if not node.verify(signature.message, signature.signature):
leaf_node = self.__get_leaf_with_id(offset)
if not leaf_node.verify(message, signature.signature):
return False
for i, auth_sig in enumerate(signature.authentication_path):
if not node.verify(auth_sig, signature.authentication_path[i]):
authentication_path_signature = signature.authentication_path_signature
authentication_path_nodes = leaf_node.traverse_up_to_root()
if len(authentication_path_signature) != len(authentication_path_nodes):
return False
for n, s in zip(authentication_path_nodes, authentication_path_signature):
node_verify_keys = n.get_children_verify_keys()
if not n.verify(node_verify_keys, s):
return False
node = node.parent
return True
return True
7 changes: 4 additions & 3 deletions LamportSignature/SignatureChain.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def deserialize(serialized_signature: str):

class LamportSignatureChain:
'''
Implementation of the signer for Lamport Signature, as a stateful signature chain
Implementation of the signer for Lamport Signature, as a stateful signature chain.
This signature chain has the state started from 1, and the state will be updated after each signing.
'''
def __init__(self, secret_seed: bytes):
'''
Expand All @@ -75,7 +76,7 @@ def sign(self, message: bytes) -> LamportSignatureChain_Signature:
# generate a new Lamport pair. Seed sampled from /dev/urandom for your CSPRNG fetish :)))
new_Lamport = Lamport_ChaCha20_SHA256_keygen(urandom(32))
# sign the message with the current Lamport pair
new_Lamport_verify_key = new_Lamport.get_verify_key_pairs().get_verify_key_pair_as_byte()
new_Lamport_verify_key = new_Lamport.get_verify_key_pairs().get_verify_key_pair_as_bytes()
# hash the message || new_Lamport_verify_key first, to assert 256-bit message
current_message = message + new_Lamport_verify_key
current_signature = self.current_Lamport.hash_and_sign(current_message)
Expand All @@ -99,7 +100,7 @@ def verify(state: int, message: bytes, signature: LamportSignatureChain_Signatur
# print(signature.past_signatures)
state_signature = signature.past_signatures[state-1]
state_verify_key_pair = signature.past_verify_keys[state-1]
state_appended_verify_key_bytes = signature.past_verify_keys[state].get_verify_key_pair_as_byte()
state_appended_verify_key_bytes = signature.past_verify_keys[state].get_verify_key_pair_as_bytes()
# print(state_signature.message.hex())
padded_message = message + state_appended_verify_key_bytes
# print(SHA256.new(padded_message).hexdigest())
Expand Down
65 changes: 64 additions & 1 deletion test/test_LamportSignature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from LamportSignature.Lamport import LamportSignature, LamportSigningKeyPair, Lamport_ChaCha20_SHA256_keygen
from LamportSignature.SignatureChain import LamportSignatureChain_Signature, LamportSignatureChain
from LamportSignature.PRSignatureTree import LamportPRSignatureTree, LamportSignatureTree_Signature
from os import urandom
import configparser
from base64 import b64decode
Expand Down Expand Up @@ -98,4 +99,66 @@ def test_five_times_serialize_deserialize_signature_chain():
assert LamportSignatureChain.verify(state, message, signature) == True
serialized = signature.serialize()
new_signature = LamportSignatureChain_Signature.deserialize(serialized)
assert LamportSignatureChain.verify(state, message, new_signature) == True
assert LamportSignatureChain.verify(state, message, new_signature) == True


def test_one_time_signature_tree():
key = urandom(16)
L = 3
signature_tree = LamportPRSignatureTree(L, key)
message = urandom(32)
signature = signature_tree.sign(message)
result = signature_tree.verify(0, message, signature)
assert result == True

def test_one_time_signature_tree_serialization():
key = urandom(16)
L = 3
signature_tree = LamportPRSignatureTree(L, key)
message = urandom(32)
signature = signature_tree.sign(message)
result = signature_tree.verify(0, message, signature)
assert result == True
serialized = signature.serialize()
new_signature = LamportSignatureTree_Signature.deserialize(serialized)
result = signature_tree.verify(0, message, new_signature)
assert result == True

def test_five_time_signature_tree_serialization():
key = urandom(16)
L = 3
signature_tree = LamportPRSignatureTree(L, key)
for i in range(0, 5):
message = urandom(32)
signature = signature_tree.sign(message)
result = signature_tree.verify(i, message, signature)
assert result == True
serialized = signature.serialize()
new_signature = LamportSignatureTree_Signature.deserialize(serialized)
result = signature_tree.verify(i, message, new_signature)
assert result == True

def test_forty_time_signature_tree_serialization():
key = urandom(16)
L = 6
signature_tree = LamportPRSignatureTree(L, key)
for i in range(0, 40):
message = urandom(32)
signature = signature_tree.sign(message)
result = signature_tree.verify(i, message, signature)
assert result == True
serialized = signature.serialize()
new_signature = LamportSignatureTree_Signature.deserialize(serialized)
result = signature_tree.verify(i, message, new_signature)
assert result == True

def test_overload_sign_tree():
key = urandom(16)
L = 1
signature_tree = LamportPRSignatureTree(L, key)
with pytest.raises(Exception):
for i in range(100):
message = urandom(32)
signature = signature_tree.sign(message)
result = signature_tree.verify(i, message, signature)

0 comments on commit 8f559cd

Please sign in to comment.