diff --git a/core/constants.py b/core/constants.py index 002a4e6..4e68963 100644 --- a/core/constants.py +++ b/core/constants.py @@ -15,8 +15,9 @@ XCHACHA20POLY1305_NONCE_LEN = 24 OTP_PAD_SIZE = 11264 -OTP_PADDING_LENGTH = 2 -OTP_PADDING_LIMIT = 1024 +OTP_MAX_BUCKET = 64 +OTP_MAX_RANDOM_PAD = 16 +OTP_SIZE_LENGTH = 2 SMP_NONCE_LENGTH = 64 SMP_PROOF_LENGTH = 64 @@ -71,5 +72,5 @@ ARGON2_MEMORY = 256 * 1024 # MB ARGON2_ITERS = 3 ARGON2_OUTPUT_LEN = 32 # bytes -ARGON2_SALT_LEN = 16 # bytes (Must be always 16 for interoperability with libsodium.) +ARGON2_SALT_LEN = 16 # bytes (Must be always 16 for interoperability with implementations using libsodium.) ARGON2_LANES = 4 diff --git a/core/crypto.py b/core/crypto.py index 6aa8048..82c4865 100644 --- a/core/crypto.py +++ b/core/crypto.py @@ -14,9 +14,12 @@ import oqs import secrets +from typing import Tuple from core.constants import ( OTP_PAD_SIZE, - OTP_PADDING_LENGTH, + OTP_MAX_RANDOM_PAD, + OTP_SIZE_LENGTH, + OTP_MAX_BUCKET, ML_KEM_1024_NAME, ML_KEM_1024_SK_LEN, ML_KEM_1024_PK_LEN, @@ -59,7 +62,7 @@ def verify_signature(algorithm: str, message: bytes, signature: bytes, public_ke with oqs.Signature(algorithm) as verifier: return verifier.verify(message, signature[:ALGOS_BUFFER_LIMITS[algorithm]["SIGN_LEN"]], public_key[:ALGOS_BUFFER_LIMITS[algorithm]["PK_LEN"]]) -def generate_sign_keys(algorithm: str = ML_DSA_87_NAME): +def generate_sign_keys(algorithm: str = ML_DSA_87_NAME) -> Tuple[bytes, bytes]: """ Generates a new post-quantum signature keypair. @@ -74,34 +77,43 @@ def generate_sign_keys(algorithm: str = ML_DSA_87_NAME): private_key = signer.export_secret_key() return private_key, public_key -def otp_encrypt_with_padding(plaintext: bytes, key: bytes, padding_limit: int) -> bytes: +def otp_encrypt_with_padding(plaintext: bytes, key: bytes) -> Tuple[bytes, bytes]: """ - Encrypts plaintext using a one-time pad with random padding. + Encrypts plaintext using a one-time pad with random or bucket padding. Process: - - Prefixes length of padding. - - Adds random padding (0..padding_limit bytes). + - Prefixes length of message. + - Adds random padding (0..padding_limit bytes) if message > 64 bytes + - If 64 bytes > message, pad message up to 64 bytes, - XORs with one-time pad key. Args: plaintext: Data to encrypt. key: OTP key (>= plaintext length + padding). - padding_limit: Max padding length. Returns: Ciphertext bytes. """ - if padding_limit > ((2 ** (8 * OTP_PADDING_LENGTH)) - 1): - raise ValueError("Padding too large") - plaintext_padding = secrets.token_bytes(padding_limit) - padding_length_bytes = len(plaintext_padding).to_bytes(OTP_PADDING_LENGTH, "big") - padded_plaintext = padding_length_bytes + plaintext + plaintext_padding + if len(plaintext) <= OTP_MAX_BUCKET - OTP_SIZE_LENGTH: + pad_len = OTP_MAX_BUCKET - OTP_SIZE_LENGTH - len(plaintext) + else: + pad_len = secrets.randbelow(OTP_MAX_RANDOM_PAD + 1) + + padding = secrets.token_bytes(pad_len) + + plaintext_length_bytes = len(plaintext).to_bytes(OTP_SIZE_LENGTH, "big") + + padded_plaintext = plaintext_length_bytes + plaintext + padding + + if len(padded_plaintext) > len(key): + raise ValueError("Padded plaintext is larger than key!") + return one_time_pad(padded_plaintext, key) def otp_decrypt_with_padding(ciphertext: bytes, key: bytes) -> bytes: """ - Decrypts one-time pad ciphertext that contains prefixed padding length. + Decrypts one-time pad ciphertext that contains prefixed plaintext length. Args: ciphertext: Ciphertext bytes. @@ -110,11 +122,15 @@ def otp_decrypt_with_padding(ciphertext: bytes, key: bytes) -> bytes: Returns: Original plaintext bytes without padding. """ - plaintext_with_padding = one_time_pad(ciphertext, key) - padding_length = int.from_bytes(plaintext_with_padding[:OTP_PADDING_LENGTH], "big") - if padding_length != 0: - return plaintext_with_padding[OTP_PADDING_LENGTH : -padding_length] - return plaintext_with_padding[OTP_PADDING_LENGTH:] + plaintext_with_padding, _ = one_time_pad(ciphertext, key) + + plaintext_length = int.from_bytes(plaintext_with_padding[:OTP_SIZE_LENGTH], "big") + + if plaintext_length <= 0: + raise ValueError(f"{plaintext_length} plaintext length, ciphertext corrupted or invalid key!") + + return plaintext_with_padding[OTP_SIZE_LENGTH : OTP_SIZE_LENGTH + plaintext_length] + def one_time_pad(plaintext: bytes, key: bytes) -> bytes: """ @@ -131,14 +147,16 @@ def one_time_pad(plaintext: bytes, key: bytes) -> bytes: for index, plain_byte in enumerate(plaintext): key_byte = key[index] otpd_plaintext += bytes([plain_byte ^ key_byte]) - return otpd_plaintext -def generate_kem_keys(algorithm: str): + key = key[len(otpd_plaintext):] + return otpd_plaintext, key + +def generate_kem_keys(algorithm: str) -> Tuple[bytes, bytes]: """ Generates a KEM keypair. Args: - algorithm: PQ KEM algorithm (default Kyber1024). + algorithm: PQ KEM algorithm. Returns: (private_key, public_key) as bytes. @@ -148,23 +166,46 @@ def generate_kem_keys(algorithm: str): private_key = kem.export_secret_key() return private_key, public_key -def encap_shared_secret(public_key: bytes, algorithm: str): +def encap_shared_secret(public_key: bytes, algorithm: str) -> Tuple[bytes, bytes]: + """ + Derive a KEM shared secret from a public key. + + Args: + public_key: KEM public key. + algorithm: KEM algorithm NIST name. + + Returns: + (KEM ciphertext, shared secret) as bytes. + """ + with oqs.KeyEncapsulation(algorithm) as kem: return kem.encap_secret(public_key[:ALGOS_BUFFER_LIMITS[algorithm]["PK_LEN"]]) -def decap_shared_secret(ciphertext: bytes, private_key: bytes, algorithm: str): +def decap_shared_secret(ciphertext: bytes, private_key: bytes, algorithm: str) -> bytes: + """ + Decrypts a single KEM ciphertext to derive a shared secret. + + Args: + ciphertext: KEM ciphertext. + private_key: KEM private key. + algorithm: KEM algorithm NIST name. + size: Desired shared_secret size in bytes. + + Returns: + Shared secret of size as bytes. + """ with oqs.KeyEncapsulation(algorithm, secret_key = private_key[:ALGOS_BUFFER_LIMITS[algorithm]["SK_LEN"]]) as kem: return kem.decap_secret(ciphertext[:ALGOS_BUFFER_LIMITS[algorithm]["CT_LEN"]]) -def decrypt_shared_secrets(ciphertext_blob: bytes, private_key: bytes, algorithm: str = None, otp_pad_size: int = OTP_PAD_SIZE): +def decrypt_shared_secrets(ciphertext_blob: bytes, private_key: bytes, algorithm: str = None, size: int = OTP_PAD_SIZE): """ - Decrypts concatenated KEM ciphertexts to derive shared one-time pad. + Decrypts concatenated KEM ciphertexts to derive shared secret. Args: - ciphertext_blob: Concatenated Kyber ciphertexts. + ciphertext_blob: Concatenated KEM ciphertexts. private_key: KEM private key. algorithm: KEM algorithm NIST name. - otp_pad_size: Desired OTP pad size in bytes. + size: Desired OTP pad size in bytes. Returns: Shared secret OTP pad bytes. @@ -174,7 +215,7 @@ def decrypt_shared_secrets(ciphertext_blob: bytes, private_key: bytes, algorithm cursor = 0 with oqs.KeyEncapsulation(algorithm, secret_key=private_key[:ALGOS_BUFFER_LIMITS[algorithm]["SK_LEN"]]) as kem: - while len(shared_secrets) < otp_pad_size: + while len(shared_secrets) < size: ciphertext = ciphertext_blob[cursor:cursor + cipher_size] if len(ciphertext) != cipher_size: raise ValueError(f"Ciphertext of {algorithm} blob is malformed or incomplete ({len(ciphertext)})") @@ -185,28 +226,28 @@ def decrypt_shared_secrets(ciphertext_blob: bytes, private_key: bytes, algorithm return shared_secrets #[:otp_pad_size] -def generate_shared_secrets(public_key: bytes, algorithm: str = None, otp_pad_size: int = OTP_PAD_SIZE): +def generate_shared_secrets(public_key: bytes, algorithm: str = None, size: int = OTP_PAD_SIZE) -> Tuple[bytes, bytes]: """ - Generates a one-time pad via `algorithm` encapsulation. + Generates many shared secrets via `algorithm` encapsulation in chunks. Args: - public_key: Recipient's public key. + public_key: Recipient's KEM public key. algorithm: KEM algorithm NIST name. - otp_pad_size: Desired OTP pad size in bytes. + size: Desired shared secrets size in bytes. Returns: - (ciphertexts_blob, shared_secrets) for transport & encryption. + (ciphertexts_blob, shared_secrets) as bytes. """ shared_secrets = b'' ciphertexts_blob = b'' with oqs.KeyEncapsulation(algorithm) as kem: - while len(shared_secrets) < otp_pad_size: + while len(shared_secrets) < size: ciphertext, shared_secret = kem.encap_secret(public_key[:ALGOS_BUFFER_LIMITS[algorithm]["PK_LEN"]]) ciphertexts_blob += ciphertext shared_secrets += shared_secret - return ciphertexts_blob, shared_secrets[:otp_pad_size] + return ciphertexts_blob, shared_secrets # [:otp_pad_size] def random_number_range(a: int, b: int) -> int: """ diff --git a/core/trad_crypto.py b/core/trad_crypto.py index 7c0dcdc..b565384 100644 --- a/core/trad_crypto.py +++ b/core/trad_crypto.py @@ -67,7 +67,7 @@ def derive_key_argon2id(password: bytes, salt: bytes = None, output_length: int ), salt -def encrypt_xchacha20poly1305(key: bytes, plaintext: bytes, counter: int = None, counter_safety: int = 2 ** 32) -> tuple[bytes, bytes]: +def encrypt_xchacha20poly1305(key: bytes, plaintext: bytes, nonce: bytes = None, counter: int = None, counter_safety: int = 2 ** 32) -> tuple[bytes, bytes]: """ Encrypt plaintext using ChaCha20Poly1305. @@ -83,7 +83,9 @@ def encrypt_xchacha20poly1305(key: bytes, plaintext: bytes, counter: int = None, - nonce: The randomly generated AES-GCM nonce. - ciphertext: The encrypted data including the authentication tag. """ - nonce = secrets.token_bytes(XCHACHA20POLY1305_NONCE_LEN) + if nonce is None: + nonce = sha3_512(secrets.token_bytes(XCHACHA20POLY1305_NONCE_LEN))[:XCHACHA20POLY1305_NONCE_LEN] + if counter is not None: if counter > counter_safety: raise ValueError("ChaCha counter has overflowen") diff --git a/logic/contacts.py b/logic/contacts.py index 993f3c8..87da493 100644 --- a/logic/contacts.py +++ b/logic/contacts.py @@ -86,6 +86,10 @@ def save_contact(user_data: dict, user_data_lock, contact_id: str) -> None: } }, + "our_strand_key": None, + "our_next_strand_nonce": None, + "contact_next_strand_key": None, + "contact_strand_nonce": None, "our_pads": { "hash_chain": None, "pads": None diff --git a/logic/message.py b/logic/message.py index e6abd12..7e46a34 100644 --- a/logic/message.py +++ b/logic/message.py @@ -13,7 +13,11 @@ from core.requests import http_request from logic.storage import save_account_data from logic.pfs import send_new_ephemeral_keys -from core.trad_crypto import sha3_512 +from core.trad_crypto import ( + sha3_512, + encrypt_xchacha20poly1305, + decrypt_xchacha20poly1305 +) from core.crypto import ( generate_shared_secrets, decrypt_shared_secrets, @@ -23,17 +27,21 @@ otp_encrypt_with_padding, otp_decrypt_with_padding ) -from core.constants import ( - ALGOS_BUFFER_LIMITS, +from core.constants import ( + MESSAGE_HASH_CHAIN_LEN, + OTP_MAX_BUCKET, OTP_PAD_SIZE, - OTP_PADDING_LIMIT, - OTP_PADDING_LENGTH, ML_KEM_1024_NAME, - CLASSIC_MCELIECE_8_F_NAME, + ML_KEM_1024_CT_LEN, ML_DSA_87_NAME, + ML_DSA_87_SIGN_LEN, + CLASSIC_MCELIECE_8_F_NAME, + CLASSIC_MCELIECE_8_F_CT_LEN, + XCHACHA20POLY1305_NONCE_LEN + ) from base64 import b64decode, b64encode -import json +import secrets import logging logger = logging.getLogger(__name__) @@ -41,7 +49,7 @@ def generate_and_send_pads(user_data, user_data_lock, contact_id: str, ui_queue) -> bool: """ - Generates a new hash-chained OTP batch, signs it with Dilithium, and sends it to the server. + Generates a new OTP batch, signs it with ML-DSA-87, encrypt everything and send it to the server. Updates local pad and hash chain state upon success. Returns: bool: True if successful, False otherwise. @@ -54,31 +62,43 @@ def generate_and_send_pads(user_data, user_data_lock, contact_id: str, ui_queue) contact_mceliece_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME] our_lt_private_key = user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["private_key"] + our_strand_key = user_data["contacts"][contact_id]["our_strand_key"] + kyber_ciphertext_blob , kyber_shared_secrets = generate_shared_secrets(contact_kyber_public_key, ML_KEM_1024_NAME) mceliece_ciphertext_blob, mceliece_shared_secrets = generate_shared_secrets(contact_mceliece_public_key, CLASSIC_MCELIECE_8_F_NAME) otp_batch_signature = create_signature(ML_DSA_87_NAME, kyber_ciphertext_blob + mceliece_ciphertext_blob, our_lt_private_key) - otp_batch_signature = b64encode(otp_batch_signature).decode() - payload = { - "otp_hashchain_ciphertext": b64encode(kyber_ciphertext_blob + mceliece_ciphertext_blob).decode(), - "otp_hashchain_signature": otp_batch_signature, - "recipient": contact_id - } + hash_chain_seed = sha3_512(secrets.token_bytes(MESSAGE_HASH_CHAIN_LEN)) + + ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( + our_strand_key, + b"\x00" + hash_chain_seed + otp_batch_signature + kyber_ciphertext_blob + mceliece_ciphertext_blob + ) + + try: - http_request(f"{server_url}/messages/send_pads", "POST", payload=payload, auth_token=auth_token) + http_request(f"{server_url}/messages/send", "POST", payload={ + "ciphertext_blob": b64encode(ciphertext_nonce + ciphertext_blob).decode(), + "recipient": contact_id + }, auth_token=auth_token) except Exception: ui_queue.put({"type": "showerror", "title": "Error", "message": "Failed to send our one-time-pads key batch to the server"}) return False - pads = one_time_pad(kyber_shared_secrets, mceliece_shared_secrets) + pads, _ = one_time_pad(kyber_shared_secrets, mceliece_shared_secrets) + + + our_strand_key = sha3_512(pads[:32])[:32] # We update & save only at the end, so if request fails, we do not desync our state. with user_data_lock: - user_data["contacts"][contact_id]["our_pads"]["pads"] = pads[64:] - user_data["contacts"][contact_id]["our_pads"]["hash_chain"] = pads[:64] + user_data["contacts"][contact_id]["our_strand_key"] = our_strand_key + user_data["contacts"][contact_id]["our_pads"]["pads"] = pads[32:] + + user_data["contacts"][contact_id]["our_pads"]["hash_chain"] = hash_chain_seed save_account_data(user_data, user_data_lock) @@ -104,9 +124,10 @@ def send_message_processor(user_data, user_data_lock, contact_id: str, message: contact_kyber_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME] contact_mceliece_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME] - our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"] + our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"] - + + if contact_kyber_public_key is None or contact_mceliece_public_key is None: logger.debug("This shouldn't happen, contact ephemeral keys are not initialized even once yet???") ui_queue.put({ @@ -134,53 +155,54 @@ def send_message_processor(user_data, user_data_lock, contact_id: str, message: our_hash_chain = user_data["contacts"][contact_id]["our_pads"]["hash_chain"] - message_encoded = message.encode("utf-8") - next_hash_chain = sha3_512(our_hash_chain + message_encoded) - message_encoded = next_hash_chain + message_encoded - - message_otp_padding_length = max(0, OTP_PADDING_LIMIT - OTP_PADDING_LENGTH - len(message_encoded)) - - if (len(message_encoded) + OTP_PADDING_LENGTH + message_otp_padding_length) > len(our_pads): - logger.info("Your message size (%d) is larger than our pads size (%s), therefore we are generating new pads for you", len(message_encoded) + OTP_PADDING_LENGTH + message_otp_padding_length, len(our_pads)) - - if not generate_and_send_pads(user_data, user_data_lock, contact_id, ui_queue): - return False - - with user_data_lock: - our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"] - our_hash_chain = user_data["contacts"][contact_id]["our_pads"]["hash_chain"] - - # We remove old hashchain from message and calculate new next hash in the chain - message_encoded = message_encoded[64:] - next_hash_chain = sha3_512(our_hash_chain + message_encoded) - message_encoded = next_hash_chain + message_encoded - - message_otp_pad = our_pads[:len(message_encoded) + OTP_PADDING_LENGTH + message_otp_padding_length] + while True: + message_encoded = message.encode("utf-8") + try: + # We one-time-pad encrypt the message with padding + # + # NOTE: The padding only protects short-messages which are easy to infer what is said based purely on message length + # With messages larger than padding_limit, we assume the message entropy give enough security to make an adversary assumption + # of message context (almost) useless. + # + message_encrypted, new_pads = otp_encrypt_with_padding(message_encoded, our_pads) + logger.debug("Our old pad size is %d and new size after the message is %d", len(our_pads), len(new_pads)) + break + except ValueError as e: + logger.debug("Failed to encrypt message to contact (%s) with error: %s", contact_id, str(e)) + logger.info("Your message size (%d) when padded, is larger than our pads size (%s), therefore we are generating new pads for you", len(message), len(our_pads)) + + if not generate_and_send_pads(user_data, user_data_lock, contact_id, ui_queue): + return False - logger.debug("Our pad size is %d and new size after the message is %d", len(our_pads), len(our_pads) - len(message_otp_pad)) + with user_data_lock: + our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"] + our_hash_chain = user_data["contacts"][contact_id]["our_pads"]["hash_chain"] + - # We one-time-pad encrypt the message with padding - # - # NOTE: The padding only protects short-messages which are easy to infer what is said based purely on message length - # With messages larger than padding_limit, we assume the message entropy give enough security to make an adversary assumption - # of message context (almost) useless. - # - message_encrypted = otp_encrypt_with_padding(message_encoded, message_otp_pad, padding_limit = message_otp_padding_length) - message_encrypted = b64encode(message_encrypted).decode() # Unlike in other functions, we truncate pads here and compute the next hash chain regardless of request being successful or not # because a malicious server could make our requests fail to force us to re-use the same pad for our next message # which would break all of our security + + next_hash_chain = sha3_512(our_hash_chain + message_encrypted) + with user_data_lock: - user_data["contacts"][contact_id]["our_pads"]["pads"] = user_data["contacts"][contact_id]["our_pads"]["pads"][len(message_encoded) + OTP_PADDING_LENGTH + message_otp_padding_length:] + user_data["contacts"][contact_id]["our_pads"]["pads"] = user_data["contacts"][contact_id]["our_pads"]["pads"][len(message_encrypted):] user_data["contacts"][contact_id]["our_pads"]["hash_chain"] = next_hash_chain + our_strand_key = user_data["contacts"][contact_id]["our_strand_key"] + save_account_data(user_data, user_data_lock) + + ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( + our_strand_key, + b"\x01" + next_hash_chain + message_encrypted + ) try: - http_request(f"{server_url}/messages/send_message", "POST", payload = { - "message_encrypted": message_encrypted, + http_request(f"{server_url}/messages/send", "POST", payload = { + "ciphertext_blob": b64encode(ciphertext_nonce + ciphertext_blob).decode(), "recipient": contact_id }, auth_token=auth_token @@ -204,33 +226,58 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic """ contact_id = message["sender"] - if (not (contact_id in user_data_copied["contacts"])): - logger.warning("Contact is missing, maybe we (or they) are not synced? Not sure, but we will ignore this Message request for now") - logger.debug("Our contacts: %s", json.dumps(user_data_copied["contacts"], indent=2)) + if contact_id not in user_data_copied["contacts"]: + logger.error("Contact (%s) is not saved! Skipping message", contact_id) + logger.debug("Our contacts: %s", str(user_data_copied["contacts"])) return if not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["verified"]: - logger.warning("Contact long-term signing key is not verified.. it is possible that this is a MiTM attack, we ignoring this message for now.") + logger.warning("Contact (%s) is not verified! Skipping message", contact_id) return contact_public_key = user_data_copied["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] + contact_strand_key = user_data_copied["contacts"][contact_id]["contact_strand_key"] if contact_public_key is None: - logger.warning("Contact per-contact Dilithium 5 public key is missing.. skipping message") + logger.error("Contact (%s) per-contact ML-DSA-87 public key is missing! Skipping message..", contact_id) return + if not contact_strand_key: + logger.error("Contact (%s) strand key key is missing! Skipping message...", contact_id) + return - logger.debug("Received a new message of type: %s", message["msg_type"]) - if message["msg_type"] == "new_otp_batch": - otp_hashchain_signature = b64decode(message["otp_hashchain_signature"], validate=True) - otp_hashchain_ciphertext = b64decode(message["otp_hashchain_ciphertext"], validate=True) + ciphertext_blob = b64decode(message["ciphertext_blob"], validate = True) + + # Everything from here is not validated by server + try: + msgs_plaintext = decrypt_xchacha20poly1305(contact_strand_key, ciphertext_blob[:XCHACHA20POLY1305_NONCE_LEN], ciphertext_blob[XCHACHA20POLY1305_NONCE_LEN:]) + except Exception as e: + logger.error("Failed to decrypt `ciphertext_blob` from contact (%s) with error: %s", contact_id, str(e)) + return + + + if msgs_plaintext[0] == 0: + logger.debug("Received a new OTP pads batch from contact (%s).", contact_id) - valid_signature = verify_signature(ML_DSA_87_NAME, otp_hashchain_ciphertext, otp_hashchain_signature, contact_public_key) - if not valid_signature: - logger.debug("Invalid OTP_hashchain_ciphertext signature.. possible MiTM ?") + if len(msgs_plaintext) != ( (ML_KEM_1024_CT_LEN + CLASSIC_MCELIECE_8_F_CT_LEN) * (OTP_PAD_SIZE // 32)) + ML_DSA_87_SIGN_LEN + MESSAGE_HASH_CHAIN_LEN + 1: + logger.error("Contact (%s) gave us a otp batch message request with malformed strand plaintext length (%d)", contact_id, len(msgs_plaintext)) + return + + otp_hashchain_signature = msgs_plaintext[1 + MESSAGE_HASH_CHAIN_LEN : MESSAGE_HASH_CHAIN_LEN + ML_DSA_87_SIGN_LEN + 1] + otp_hashchain_ciphertext = msgs_plaintext[ML_DSA_87_SIGN_LEN + MESSAGE_HASH_CHAIN_LEN + 1:] + + contact_hash_chain = msgs_plaintext[1 : MESSAGE_HASH_CHAIN_LEN + 1] + + try: + valid_signature = verify_signature(ML_DSA_87_NAME, otp_hashchain_ciphertext, otp_hashchain_signature, contact_public_key) + if not valid_signature: + logger.error("Invalid `otp_hashchain_ciphertext` signature from contact (%s)! This might be a MiTM attack.", contact_id) + return + except Exception as e: + logger.error("Contact (%s) gave us a messages request with malformed strand signature which generated this error: %s", contact_id, str(e)) return our_kyber_key = user_data_copied["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["private_key"] @@ -238,22 +285,27 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic # / 32 because shared secret is 32 bytes try: - contact_kyber_pads = decrypt_shared_secrets(otp_hashchain_ciphertext[:ALGOS_BUFFER_LIMITS[ML_KEM_1024_NAME]["CT_LEN"] * int(OTP_PAD_SIZE / 32)], our_kyber_key, ML_KEM_1024_NAME) - except: - logger.error("Failed to decrypt Kyber's shared_secrets, possible MiTM?") + contact_kyber_pads = decrypt_shared_secrets(otp_hashchain_ciphertext[:ML_KEM_1024_CT_LEN * (OTP_PAD_SIZE // 32)], our_kyber_key, ML_KEM_1024_NAME) + except Exception as e: + logger.error("Failed to decrypt ML-KEM-1024 ciphertext from contact (%s), received error: %s", contact_id, str(e)) return try: - contact_mceliece_pads = decrypt_shared_secrets(otp_hashchain_ciphertext[ALGOS_BUFFER_LIMITS[ML_KEM_1024_NAME]["CT_LEN"] * int(OTP_PAD_SIZE / 32):], our_mceliece_key, CLASSIC_MCELIECE_8_F_NAME) - except: - logger.error("Failed to decrypt McEliece's shared_secrets, possible MiTM?") + contact_mceliece_pads = decrypt_shared_secrets(otp_hashchain_ciphertext[ML_KEM_1024_CT_LEN * (OTP_PAD_SIZE // 32):], our_mceliece_key, CLASSIC_MCELIECE_8_F_NAME) + except Exception as e: + logger.error("Failed to decrypt Classic-McEliece8192128's ciphertext from contact (%s), received error: %s", contact_id, str(e)) return - contact_pads = one_time_pad(contact_kyber_pads, contact_mceliece_pads) + contact_pads, _ = one_time_pad(contact_kyber_pads, contact_mceliece_pads) + contact_strand_key = sha3_512(contact_pads[:32])[:32] + contact_pads = contact_pads[32:] + with user_data_lock: - user_data["contacts"][contact_id]["contact_pads"]["pads"] = contact_pads[64:] - user_data["contacts"][contact_id]["contact_pads"]["hash_chain"] = contact_pads[:64] + user_data["contacts"][contact_id]["contact_pads"]["pads"] = contact_pads + user_data["contacts"][contact_id]["contact_pads"]["hash_chain"] = contact_hash_chain + + user_data["contacts"][contact_id]["contact_strand_key"] = contact_strand_key user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotation_counter"] += 1 @@ -261,10 +313,10 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic rotation_counter = user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotation_counter"] - + logger.debug("Incremented McEliece's rotation_counter by 1 (now is %d) for contact (%s)", rotation_counter, contact_id) - logger.info("Saved contact (%s) new batch of One-Time-Pads and hash chain seed", contact_id) + logger.info("Saved contact (%s) new batch of One-Time-Pads, new strand key, and new hash chain seed", contact_id) save_account_data(user_data, user_data_lock) @@ -275,13 +327,30 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic - elif message["msg_type"] == "new_message": - message_encrypted = b64decode(message["message_encrypted"], validate=True) + elif msgs_plaintext[0] == 1: + logger.debug("Received a new message from contact (%s).", contact_id) + + if len(msgs_plaintext) < OTP_MAX_BUCKET + MESSAGE_HASH_CHAIN_LEN + 1: + logger.error("Contact (%s) gave us a message request with malformed strand plaintext length (%d)", contact_id, len(msgs_plaintext)) + return + + + hash_chain = msgs_plaintext[1:MESSAGE_HASH_CHAIN_LEN + 1] + message_encrypted = msgs_plaintext[MESSAGE_HASH_CHAIN_LEN + 1:] + with user_data_lock: contact_pads = user_data["contacts"][contact_id]["contact_pads"]["pads"] contact_hash_chain = user_data["contacts"][contact_id]["contact_pads"]["hash_chain"] + + next_hash_chain = sha3_512(contact_hash_chain + message_encrypted) + + if next_hash_chain != hash_chain: + logger.warning("Message hash chain did not match, this could be a possible replay attack, or a failed tampering attempt. Skipping this message...") + return + + if (not contact_pads) or (len(message_encrypted) > len(contact_pads)): # TODO: Maybe reset our local pads as well? # I feel like we should do something more when we hit this case, but I am not sure. @@ -292,15 +361,6 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic # immediately truncate the pads contact_pads = contact_pads[len(message_encrypted):] - hash_chain = message_decrypted[:64] - message_decrypted = message_decrypted[64:] - - next_hash_chain = sha3_512(contact_hash_chain + message_decrypted) - - if next_hash_chain != hash_chain: - logger.warning("Message hash chain did not match, this could be a possible replay attack, or a failed tampering attempt. Skipping this message...") - return - # and save the new pads and the hash chain with user_data_lock: @@ -322,3 +382,6 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic "contact_id": contact_id, "message": message_decoded }) + + else: + logger.error("Received unknown message type (%d)", msgs_plaintext[0]) diff --git a/logic/pfs.py b/logic/pfs.py index ff449dd..02fcb76 100644 --- a/logic/pfs.py +++ b/logic/pfs.py @@ -19,14 +19,21 @@ random_number_range ) from core.constants import ( - ALGOS_BUFFER_LIMITS, ML_KEM_1024_NAME, ML_DSA_87_NAME, + ML_KEM_1024_PK_LEN, + ML_DSA_87_SIGN_LEN, + XCHACHA20POLY1305_NONCE_LEN, CLASSIC_MCELIECE_8_F_NAME, + CLASSIC_MCELIECE_8_F_PK_LEN, CLASSIC_MCELIECE_8_F_ROTATE_AT, KEYS_HASH_CHAIN_LEN ) -from core.trad_crypto import sha3_512 +from core.trad_crypto import ( + sha3_512, + encrypt_xchacha20poly1305, + decrypt_xchacha20poly1305 +) from base64 import b64encode, b64decode import secrets import copy @@ -40,7 +47,7 @@ def send_new_ephemeral_keys(user_data: dict, user_data_lock: threading.Lock, contact_id: str, ui_queue: queue.Queue) -> None: """ - Generate and send fresh ephemeral keys to a contact. + Generate, encrypt, and send fresh ephemeral keys to a contact. - Maintains a per-contact hash chain for signing key material. - Generates new Kyber1024 keys every call. @@ -61,6 +68,8 @@ def send_new_ephemeral_keys(user_data: dict, user_data_lock: threading.Lock, con server_url = user_data_copied["server_url"] auth_token = user_data_copied["token"] + + our_strand_key = user_data_copied["contacts"][contact_id]["our_strand_key"] rotation_counter = user_data_copied["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotation_counter"] rotate_at = user_data_copied["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotate_at"] @@ -83,25 +92,26 @@ def send_new_ephemeral_keys(user_data: dict, user_data_lock: threading.Lock, con kyber_private_key, kyber_public_key = generate_kem_keys(ML_KEM_1024_NAME) publickeys_hashchain = our_hash_chain + kyber_public_key - pfs_type = "partial" + rotate_mceliece = False if (rotate_at == rotation_counter) or (user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"] is None): + # Generate Classic McEliece 8192128f keys mceliece_private_key, mceliece_public_key = generate_kem_keys(CLASSIC_MCELIECE_8_F_NAME) publickeys_hashchain += mceliece_public_key - pfs_type = "full" + rotate_mceliece = True # Sign them with our per-contact long-term private key publickeys_hashchain_signature = create_signature(ML_DSA_87_NAME, publickeys_hashchain, lt_sign_private_key) - payload = { - "publickeys_hashchain": b64encode(publickeys_hashchain).decode(), - "hashchain_signature" : b64encode(publickeys_hashchain_signature).decode(), - "recipient" : contact_id, - "pfs_type" : pfs_type - } - + ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( + our_strand_key, + publickeys_hashchain_signature + publickeys_hashchain + ) try: - http_request(f"{server_url}/pfs/send_keys", "POST", payload=payload, auth_token=auth_token) + http_request(f"{server_url}/pfs/send_keys", "POST", payload={ + "ciphertext_blob": b64encode(ciphertext_nonce + ciphertext_blob).decode(), + "recipient" : contact_id, + }, auth_token=auth_token) except Exception: ui_queue.put({"type": "showerror", "title": "Error", "message": "Failed to send our ephemeral keys to the server"}) return @@ -115,7 +125,7 @@ def send_new_ephemeral_keys(user_data: dict, user_data_lock: threading.Lock, con "public_key": kyber_public_key } - if pfs_type == "full": + if rotate_mceliece: user_data["tmp"]["new_code_kem_keys"][contact_id] = { "private_key": mceliece_private_key, "public_key": mceliece_public_key @@ -191,34 +201,52 @@ def pfs_data_handler(user_data: dict, user_data_lock: threading.Lock, user_data_ contact_id = message["sender"] if contact_id not in user_data_copied["contacts"]: - logger.error("Contact is not saved., maybe we (or they) are not synced? Ignoring this PFS message.") - logger.debug("Our saved contacts: %s", str(user_data_copied["contacts"])) + logger.error("Contact (%s) is not saved! Skipping message", contact_id) + logger.debug("Our contacts: %s", str(user_data_copied["contacts"])) + return + + if not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["verified"]: + logger.error("Contact long-term signing key is not verified! We will ignore this PFS message.") return - # Contact's per-contact signing public-key contact_lt_public_key = user_data_copied["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] + contact_strand_key = user_data_copied["contacts"][contact_id]["contact_strand_key"] if not contact_lt_public_key: - logger.error("Contact long-term signing key is missing... 0 clue how we reached here, but we aint continuing..") + logger.error("Contact (%s) per-contact ML-DSA-87 public key is missing! Skipping message..", contact_id) return - if not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["verified"]: - logger.error("Contact long-term signing key is not verified! We will ignore this PFS message.") - return + if not contact_strand_key: + logger.error("Contact (%s) strand key key is missing! Skipping message...", contact_id) + return - contact_hashchain_signature = b64decode(message["hashchain_signature"], validate=True) - contact_publickeys_hashchain = b64decode(message["publickeys_hashchain"], validate=True) + ciphertext_blob = b64decode(message["ciphertext_blob"], validate = True) + + # Everything from here is not validated by server + try: + pfs_plaintext = decrypt_xchacha20poly1305(contact_strand_key, ciphertext_blob[:XCHACHA20POLY1305_NONCE_LEN], ciphertext_blob[XCHACHA20POLY1305_NONCE_LEN:]) + except Exception as e: + logger.error("Failed to decrypt `ciphertext_blob` from contact (%s) with error: %s", contact_id, str(e)) + return - valid_signature = verify_signature(ML_DSA_87_NAME, contact_publickeys_hashchain, contact_hashchain_signature, contact_lt_public_key) - if not valid_signature: - logger.error("Invalid ephemeral public-key + hashchain signature from contact (%s)", contact_id) + if (len(pfs_plaintext) < ML_KEM_1024_PK_LEN + ML_DSA_87_SIGN_LEN + KEYS_HASH_CHAIN_LEN) or len(pfs_plaintext) > ML_KEM_1024_PK_LEN + ML_DSA_87_SIGN_LEN + CLASSIC_MCELIECE_8_F_PK_LEN + KEYS_HASH_CHAIN_LEN: + logger.error("Contact (%s) gave us a PFS request with malformed strand plaintext length (%d)", contact_id, len(pfs_plaintext)) return - if message["pfs_type"] not in ["full", "partial"]: - logger.error("contact (%s) sent message of unknown pfs_type (%s)", contact_id, message["pfs_type"]) + contact_hashchain_signature = pfs_plaintext[:ML_DSA_87_SIGN_LEN] + contact_publickeys_hashchain = pfs_plaintext[ML_DSA_87_SIGN_LEN:] + + contact_hash_chain = contact_publickeys_hashchain[:KEYS_HASH_CHAIN_LEN] + + try: + valid_signature = verify_signature(ML_DSA_87_NAME, contact_publickeys_hashchain, contact_hashchain_signature, contact_lt_public_key) + if not valid_signature: + logger.error("Invalid ephemeral public-key + hashchain signature from contact (%s)", contact_id) + return + except Exception as e: + logger.error("Contact (%s) gave us a PFS request with malformed strand signature which generated this error: %s", contact_id, str(e)) return - contact_hash_chain = contact_publickeys_hashchain[:KEYS_HASH_CHAIN_LEN] # If we do not have a hashchain for the contact, we don't need to compute the chain, just save. if not user_data_copied["contacts"][contact_id]["lt_sign_keys"]["contact_hash_chain"]: @@ -233,15 +261,16 @@ def pfs_data_handler(user_data: dict, user_data_lock: threading.Lock, user_data_ logger.error("Contact keys hash chain does not match our computed hash chain! Skipping this PFS message...") return - contact_kyber_public_key = contact_publickeys_hashchain[KEYS_HASH_CHAIN_LEN: ALGOS_BUFFER_LIMITS[ML_KEM_1024_NAME]["PK_LEN"] + KEYS_HASH_CHAIN_LEN] - if message["pfs_type"] == "full": + contact_kyber_public_key = contact_publickeys_hashchain[KEYS_HASH_CHAIN_LEN: ML_KEM_1024_PK_LEN + KEYS_HASH_CHAIN_LEN] + + if len(contact_publickeys_hashchain) == ML_KEM_1024_PK_LEN + CLASSIC_MCELIECE_8_F_PK_LEN + KEYS_HASH_CHAIN_LEN: logger.info("contact (%s) has rotated their Kyber and McEliece keys", contact_id) - contact_mceliece_public_key = contact_publickeys_hashchain[ALGOS_BUFFER_LIMITS[ML_KEM_1024_NAME]["PK_LEN"] + KEYS_HASH_CHAIN_LEN:] + contact_mceliece_public_key = contact_publickeys_hashchain[ML_KEM_1024_PK_LEN + KEYS_HASH_CHAIN_LEN:] with user_data_lock: user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME] = contact_mceliece_public_key - elif message["pfs_type"] == "partial": + elif len(contact_publickeys_hashchain) == ML_KEM_1024_PK_LEN + KEYS_HASH_CHAIN_LEN: logger.info("contact (%s) has rotated their Kyber keys", contact_id) with user_data_lock: diff --git a/logic/smp.py b/logic/smp.py index e345b2b..607eeaa 100644 --- a/logic/smp.py +++ b/logic/smp.py @@ -4,32 +4,11 @@ The socialist millionaire problem A variant of Yao's millionaire problem - Guranteed verification certainity IF the answer has enough entropy (for the duration of the process.) + Guaranteed verification certainity IF the answer has enough entropy (for the duration of the process.) This is not **strictly** a SMP implementation, but it is a simplified, human-language variant we made for verifying a contact's long-term public-key. - Our implementation is inspired by Off-The-Record Messaging's SMP implementation. - - - Query server for new SMP verification messages - Check which step we are on - Act accordingly - - Step 1 is initiated by the contact, whom sets a question and an answer, then sends the question to our user - We assume user starts at step 2, step 1 is done by the contact who initiated the verification process - Step 2, we ask our user to provide an answer to the contact's question - Then we compute a proof for our version of the contact's public-key fingerprint - Step 3, the contact receives our proof and tries to compuate the same proof - if it matches, he marks us as verified, otherwise, a failure notice is sent and both user and contact SMP state is deleted - After it matches, contact compuates a proof for his version of our public-key fingerprint - And sends it over - Step 4 user receive this proof and try to compute an identical one - If we succeed, the verification process is complete and we mark contact's as verified - - This provides a strong guarantee of authenticity and integrity for our long-term public keys - IF the answer has enough entropy to be uncrackable *just* for the duration of the process - """ from core.requests import http_request @@ -41,7 +20,7 @@ generate_kem_keys, encap_shared_secret, decap_shared_secret, - + one_time_pad ) from core.trad_crypto import ( derive_key_argon2id, @@ -131,14 +110,18 @@ def smp_step_2(user_data: dict, user_data_lock, contact_id: str, message: dict, signing_private_key, signing_public_key = generate_sign_keys() - our_nonce = secrets.token_bytes(SMP_NONCE_LENGTH) + our_nonce = sha3_512(secrets.token_bytes(SMP_NONCE_LENGTH))[:SMP_NONCE_LENGTH] key_ciphertext, chacha_key = encap_shared_secret(contact_kem_public_key, ML_KEM_1024_NAME) chacha_key = sha3_512(chacha_key)[:32] + our_next_strand_nonce = sha3_512(secrets.token_bytes(XCHACHA20POLY1305_NONCE_LEN))[:XCHACHA20POLY1305_NONCE_LEN] + contact_next_strand_nonce = sha3_512(secrets.token_bytes(XCHACHA20POLY1305_NONCE_LEN))[:XCHACHA20POLY1305_NONCE_LEN] + + ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( chacha_key, - signing_public_key + our_nonce, + signing_public_key + our_nonce + our_next_strand_nonce + contact_next_strand_nonce, counter = 2 ) @@ -167,6 +150,9 @@ def smp_step_2(user_data: dict, user_data_lock, contact_id: str, message: dict, user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] = signing_public_key user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 4 + + user_data["contacts"][contact_id]["our_next_strand_nonce"] = our_next_strand_nonce + user_data["contacts"][contact_id]["contact_next_strand_nonce"] = contact_next_strand_nonce def smp_step_3(user_data: dict, user_data_lock: threading.Lock, contact_id: str, message: dict, ui_queue: queue.Queue()) -> None: @@ -194,15 +180,18 @@ def smp_step_3(user_data: dict, user_data_lock: threading.Lock, contact_id: str, ) contact_signing_public_key = smp_plaintext[:ML_DSA_87_PK_LEN] - contact_nonce = smp_plaintext[ML_DSA_87_PK_LEN:] + contact_nonce = smp_plaintext[ML_DSA_87_PK_LEN: ML_DSA_87_PK_LEN + SMP_NONCE_LENGTH] + + contact_next_strand_nonce = smp_plaintext[ML_DSA_87_PK_LEN + SMP_NONCE_LENGTH: ML_DSA_87_PK_LEN + SMP_NONCE_LENGTH + XCHACHA20POLY1305_NONCE_LEN] + our_next_strand_nonce = smp_plaintext[ML_DSA_87_PK_LEN + SMP_NONCE_LENGTH + XCHACHA20POLY1305_NONCE_LEN:] - our_nonce = secrets.token_bytes(SMP_NONCE_LENGTH) + our_nonce = sha3_512(secrets.token_bytes(SMP_NONCE_LENGTH))[:SMP_NONCE_LENGTH] signing_private_key, signing_public_key = generate_sign_keys() contact_key_fingerprint = sha3_512(contact_signing_public_key) - # Derieve a high-entropy secret key from the low-entropy answer + # Derive a high-entropy secret key from the low-entropy answer argon2id_salt = sha3_512(contact_nonce + our_nonce)[:ARGON2_SALT_LEN] answer_secret, _ = derive_key_argon2id(answer.encode("utf-8"), salt = argon2id_salt, output_length = SMP_ANSWER_OUTPUT_LEN) @@ -211,17 +200,19 @@ def smp_step_3(user_data: dict, user_data_lock: threading.Lock, contact_id: str, our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).digest() logger.debug("Our proof of contact (%s) public-key fingerprint: %s", contact_id, our_proof) + - ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( + our_new_strand_nonce = sha3_512(secrets.token_bytes(XCHACHA20POLY1305_NONCE_LEN))[:XCHACHA20POLY1305_NONCE_LEN] + _, ciphertext_blob = encrypt_xchacha20poly1305( chacha_key, - signing_public_key + our_nonce + our_proof + question.encode("utf-8"), - counter = 3 + our_new_strand_nonce + signing_public_key + our_nonce + our_proof + question.encode("utf-8"), + nonce = our_next_strand_nonce ) try: http_request(f"{server_url}/smp/step", "POST", payload = { - "ciphertext_blob": b64encode(ciphertext_nonce + ciphertext_blob).decode(), + "ciphertext_blob": b64encode(ciphertext_blob).decode(), "recipient": contact_id }, auth_token=auth_token) @@ -235,13 +226,17 @@ def smp_step_3(user_data: dict, user_data_lock: threading.Lock, contact_id: str, with user_data_lock: user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] = contact_signing_public_key - user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"] = b64encode(contact_nonce).decode() - user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"] = b64encode(our_nonce).decode() - user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"] = b64encode(chacha_key).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"] = b64encode(contact_nonce).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"] = b64encode(our_nonce).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"] = b64encode(chacha_key).decode() user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["private_key"] = signing_private_key user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] = signing_public_key + user_data["contacts"][contact_id]["our_next_strand_nonce"] = our_new_strand_nonce + user_data["contacts"][contact_id]["contact_next_strand_nonce"] = contact_next_strand_nonce + + user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 5 @@ -249,20 +244,29 @@ def smp_step_4_request_answer(user_data, user_data_lock, contact_id, message, ui with user_data_lock: tmp_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"]) + contact_next_strand_nonce = user_data["contacts"][contact_id]["contact_next_strand_nonce"] + + ciphertext_blob = b64decode(message["ciphertext_blob"], validate = True) - smp_plaintext = decrypt_xchacha20poly1305(tmp_key, ciphertext_blob[:XCHACHA20POLY1305_NONCE_LEN], ciphertext_blob[XCHACHA20POLY1305_NONCE_LEN:]) + + smp_plaintext = decrypt_xchacha20poly1305(tmp_key, contact_next_strand_nonce, ciphertext_blob) + + contact_new_strand_nonce = smp_plaintext[:XCHACHA20POLY1305_NONCE_LEN] - contact_signing_public_key = smp_plaintext[:ML_DSA_87_PK_LEN] - contact_nonce = b64encode(smp_plaintext[ML_DSA_87_PK_LEN : SMP_NONCE_LENGTH + ML_DSA_87_PK_LEN]).decode() - contact_proof = b64encode(smp_plaintext[SMP_NONCE_LENGTH + ML_DSA_87_PK_LEN : SMP_NONCE_LENGTH + SMP_PROOF_LENGTH + ML_DSA_87_PK_LEN]).decode() - question = smp_plaintext[SMP_NONCE_LENGTH + SMP_PROOF_LENGTH + ML_DSA_87_PK_LEN:].decode("utf-8") + contact_signing_public_key = smp_plaintext[XCHACHA20POLY1305_NONCE_LEN : ML_DSA_87_PK_LEN + XCHACHA20POLY1305_NONCE_LEN] + + contact_nonce = b64encode(smp_plaintext[XCHACHA20POLY1305_NONCE_LEN + ML_DSA_87_PK_LEN : SMP_NONCE_LENGTH + ML_DSA_87_PK_LEN + XCHACHA20POLY1305_NONCE_LEN]).decode() + + contact_proof = b64encode(smp_plaintext[XCHACHA20POLY1305_NONCE_LEN + SMP_NONCE_LENGTH + ML_DSA_87_PK_LEN : SMP_NONCE_LENGTH + SMP_PROOF_LENGTH + ML_DSA_87_PK_LEN + XCHACHA20POLY1305_NONCE_LEN]).decode() + + question = smp_plaintext[SMP_NONCE_LENGTH + XCHACHA20POLY1305_NONCE_LEN + SMP_PROOF_LENGTH + ML_DSA_87_PK_LEN:].decode("utf-8") with user_data_lock: user_data["contacts"][contact_id]["lt_sign_key_smp"]["question"] = question user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_proof"] = contact_proof - # user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 5 + user_data["contacts"][contact_id]["contact_next_strand_nonce"] = contact_new_strand_nonce user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"] = contact_nonce @@ -290,13 +294,15 @@ def smp_step_4_answer_provided(user_data, user_data_lock, contact_id, answer, ui our_signing_public_key = user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] + our_next_strand_nonce = user_data["contacts"][contact_id]["our_next_strand_nonce"] + tmp_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"]) answer = normalize_answer(answer) our_key_fingerprint = sha3_512(our_signing_public_key) - # Derieve a high-entropy secret key from the low-entropy answer + # Derive a high-entropy secret key from the low-entropy answer argon2id_salt = sha3_512(our_nonce + contact_nonce)[:ARGON2_SALT_LEN] answer_secret, _ = derive_key_argon2id(answer.encode("utf-8"), salt = argon2id_salt, output_length = SMP_ANSWER_OUTPUT_LEN) @@ -307,30 +313,36 @@ def smp_step_4_answer_provided(user_data, user_data_lock, contact_id, answer, ui logger.debug("SMP Proof sent to us: %s", contact_proof) logger.debug("Our compute message: %s", our_proof) + # Verify Contact's version of our public-key fingerprint matches our actual public-key fingerprint # We compare using compare_digest to prevent timing analysis by avoiding content-based short circuiting behaviour if not hmac.compare_digest(our_proof, contact_proof): - logger.warning("SMP Verification failed") + logger.warning("SMP Verification failed at step 4") smp_failure_notify_contact(user_data, user_data_lock, contact_id, ui_queue) return - # We compute proof for contact's public key (signing public key, and the question public key) + # We compute proof for contact's public key (signing public key, and the kem public key) contact_key_fingerprint = sha3_512(contact_signing_public_key + contact_kem_public_key) our_proof = contact_nonce + our_nonce + contact_key_fingerprint our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).digest() - ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( + + our_strand_key = sha3_512(secrets.token_bytes(32))[:32] + contact_strand_key = sha3_512(secrets.token_bytes(32))[:32] + + our_new_strand_nonce = sha3_512(secrets.token_bytes(XCHACHA20POLY1305_NONCE_LEN))[:XCHACHA20POLY1305_NONCE_LEN] + _, ciphertext_blob = encrypt_xchacha20poly1305( tmp_key, - our_proof, - counter = 4 + our_new_strand_nonce + our_proof + our_strand_key + contact_strand_key, + nonce = our_next_strand_nonce ) try: http_request(f"{server_url}/smp/step", "POST", payload = { - "ciphertext_blob": b64encode(ciphertext_nonce + ciphertext_blob).decode(), + "ciphertext_blob": b64encode(ciphertext_blob).decode(), "recipient": contact_id }, auth_token=auth_token) except Exception: @@ -339,13 +351,20 @@ def smp_step_4_answer_provided(user_data, user_data_lock, contact_id, answer, ui return + our_strand_key, _ = one_time_pad(sha3_512(answer_secret)[:32], our_strand_key) + contact_strand_key, _ = one_time_pad(sha3_512(answer_secret)[:32], contact_strand_key) + + # We call smp_success at very end to ensure if the requests step fail, we don't alter our local state smp_success(user_data, user_data_lock, contact_id, ui_queue) with user_data_lock: user_data["contacts"][contact_id]["lt_sign_key_smp"]["answer"] = answer + user_data["contacts"][contact_id]["our_next_strand_nonce"] = our_new_strand_nonce + user_data["contacts"][contact_id]["our_strand_key"] = our_strand_key + user_data["contacts"][contact_id]["contact_strand_key"] = contact_strand_key @@ -363,11 +382,12 @@ def smp_step_5(user_data, user_data_lock, contact_id, message, ui_queue) -> None contact_nonce = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"], validate=True) tmp_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"]) - + contact_next_strand_nonce = user_data["contacts"][contact_id]["contact_next_strand_nonce"] + our_key_fingerprint = sha3_512(our_signing_public_key + our_kem_public_key) - # Derieve a high-entropy secret key from the low-entropy answer + # Derive a high-entropy secret key from the low-entropy answer argon2id_salt = sha3_512(contact_nonce + our_nonce)[:ARGON2_SALT_LEN] answer_secret, _ = derive_key_argon2id(answer.encode("utf-8"), salt = argon2id_salt, output_length = SMP_ANSWER_OUTPUT_LEN) @@ -376,21 +396,40 @@ def smp_step_5(user_data, user_data_lock, contact_id, message, ui_queue) -> None our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).digest() ciphertext_blob = b64decode(message["ciphertext_blob"], validate = True) - contact_proof = decrypt_xchacha20poly1305(tmp_key, ciphertext_blob[:XCHACHA20POLY1305_NONCE_LEN], ciphertext_blob[XCHACHA20POLY1305_NONCE_LEN:]) + + smp_plaintext = decrypt_xchacha20poly1305(tmp_key, contact_next_strand_nonce, ciphertext_blob) + contact_new_strand_nonce = smp_plaintext[:XCHACHA20POLY1305_NONCE_LEN] + + contact_proof = smp_plaintext[XCHACHA20POLY1305_NONCE_LEN : SMP_PROOF_LENGTH + XCHACHA20POLY1305_NONCE_LEN] + + contact_strand_key = smp_plaintext[XCHACHA20POLY1305_NONCE_LEN + SMP_PROOF_LENGTH : XCHACHA20POLY1305_NONCE_LEN + SMP_PROOF_LENGTH + 32] + our_strand_key = smp_plaintext[XCHACHA20POLY1305_NONCE_LEN + SMP_PROOF_LENGTH + 32:] + logger.debug("SMP Proof sent to us: %s", contact_proof) logger.debug("Our compute message: %s", our_proof) + # Verify Contact's version of our public-key fingerprint matches our actual public-key fingerprint - # We compare using compare_digest to prevent timing analysis by avoiding content-based short circuiting behaviour if not hmac.compare_digest(our_proof, contact_proof): - logger.warning("SMP Verification failed") + logger.warning("SMP Verification failed at step 5") smp_failure_notify_contact(user_data, user_data_lock, contact_id, ui_queue) return + our_strand_key, _ = one_time_pad(sha3_512(answer_secret)[:32], our_strand_key) + contact_strand_key, _ = one_time_pad(sha3_512(answer_secret)[:32], contact_strand_key) + + with user_data_lock: + user_data["contacts"][contact_id]["contact_next_strand_nonce"] = contact_new_strand_nonce + + user_data["contacts"][contact_id]["our_strand_key"] = our_strand_key + user_data["contacts"][contact_id]["contact_strand_key"] = contact_strand_key + + + # We call smp_success at very end to ensure if the requests step fail, we don't alter our local state smp_success(user_data, user_data_lock, contact_id, ui_queue) @@ -487,6 +526,12 @@ def smp_data_handler(user_data, user_data_lock, user_data_copied, ui_queue, mess except Exception: smp_step = 2 + + if "failure" in message: + # Delete SMP state for contact + smp_failure(user_data, user_data_lock, contact_id, ui_queue) + return + # Check if we don't have this contact saved if contact_id not in user_data_copied["contacts"]: # We assume it has to be step 1 because the contact did not exist before @@ -515,10 +560,6 @@ def smp_data_handler(user_data, user_data_lock, user_data_copied, ui_queue, mess elif smp_step == 2: smp_step_2(user_data, user_data_lock, contact_id, message, ui_queue) - elif "failure" in message: - # Delete SMP state for contact - smp_failure(user_data, user_data_lock, contact_id, ui_queue) - elif smp_step == 3: if (not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"]): logger.error("Contact (%s) is not pending verification, yet they sent us a SMP request. Ignoring it.", contact_id) diff --git a/logic/storage.py b/logic/storage.py index e9d496f..0075951 100644 --- a/logic/storage.py +++ b/logic/storage.py @@ -84,6 +84,21 @@ def load_account_data(password = None) -> dict: except TypeError: pass + + try: + user_data["contacts"][contact_id]["our_strand_key"] = b64decode(user_data["contacts"][contact_id]["our_strand_key"], validate=True) + user_data["contacts"][contact_id]["contact_strand_key"] = b64decode(user_data["contacts"][contact_id]["contact_strand_key"], validate=True) + except TypeError: + pass + + try: + user_data["contacts"][contact_id]["our_next_strand_nonce"] = b64decode(user_data["contacts"][contact_id]["our_next_strand_nonce"], validate=True) + user_data["contacts"][contact_id]["contact_next_strand_nonce"] = b64decode(user_data["contacts"][contact_id]["contact_next_strand_nonce"], validate=True) + except TypeError: + pass + + + try: user_data["contacts"][contact_id]["our_pads"]["pads"] = b64decode(user_data["contacts"][contact_id]["our_pads"]["pads"], validate=True) user_data["contacts"][contact_id]["our_pads"]["hash_chain"] = b64decode(user_data["contacts"][contact_id]["our_pads"]["hash_chain"], validate=True) @@ -164,6 +179,21 @@ def save_account_data(user_data: dict, user_data_lock, password = None) -> None: except TypeError: pass + + try: + user_data["contacts"][contact_id]["our_strand_key"] = b64encode(user_data["contacts"][contact_id]["our_strand_key"]).decode() + user_data["contacts"][contact_id]["contact_strand_key"] = b64encode(user_data["contacts"][contact_id]["contact_strand_key"]).decode() + except TypeError: + pass + + try: + user_data["contacts"][contact_id]["our_next_strand_nonce"] = b64encode(user_data["contacts"][contact_id]["our_next_strand_nonce"]).decode() + user_data["contacts"][contact_id]["contact_next_strand_nonce"] = b64encode(user_data["contacts"][contact_id]["contact_next_strand_nonce"]).decode() + except TypeError: + pass + + + try: user_data["contacts"][contact_id]["our_pads"]["pads"] = b64encode(user_data["contacts"][contact_id]["our_pads"]["pads"]).decode() user_data["contacts"][contact_id]["our_pads"]["hash_chain"] = b64encode(user_data["contacts"][contact_id]["our_pads"]["hash_chain"]).decode() diff --git a/tests/test_crypto.py b/tests/test_crypto.py index d03928d..97d4163 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -20,8 +20,8 @@ random_number_range ) from core.constants import ( - OTP_PADDING_LIMIT, - OTP_PADDING_LENGTH, + OTP_SIZE_LENGTH, + OTP_MAX_BUCKET, ML_KEM_1024_NAME, ML_KEM_1024_SK_LEN, ML_KEM_1024_PK_LEN, @@ -38,14 +38,12 @@ ) from core.trad_crypto import sha3_512 -HASH_SIZE = 64 # SHA3-512 output size in bytes - def test_random_number_range(): - min_val, max_val = 100, 1000 + min_val, max_val = 10, 1000 # Check multiple values fall in range - for _ in range(1000): + for _ in range(10000): num = random_number_range(min_val, max_val) assert min_val <= num <= max_val, f"{num} out of range {min_val}-{max_val}" @@ -136,7 +134,7 @@ def test_signature_verifcation(): def test_kem_otp_encryption(): - """Full Kyber OTP exchange and tamper detection test.""" + """ML-KEM-1024 OTP pad derivation and encryption test.""" # Alice creates ephemeral ML-KEM-1024 keypair for PFS alice_private_key, alice_public_key = generate_kem_keys(ML_KEM_1024_NAME) @@ -148,34 +146,32 @@ def test_kem_otp_encryption(): assert ciphertext != bob_pads, "Ciphertext equals pads (should differ)" # First 64 bytes are hash chain seed - bob_hash_chain_seed = bob_pads[:HASH_SIZE] + # bob_hash_chain_seed = bob_pads[:HASH_SIZE] # Alice decrypts ciphertext to recover shared pads plaintext = decrypt_shared_secrets(ciphertext, alice_private_key, ML_KEM_1024_NAME) assert plaintext == bob_pads, "Pads mismatch after decryption" + assert plaintext != ciphertext, "Pads equals Bobs ciphertext" # Bob encrypts a message using OTP with hash chain - message = "Hello, World!" - message_encoded = message.encode("utf-8") - bob_next_hash_chain = sha3_512(bob_hash_chain_seed + message_encoded) - message_encoded = bob_next_hash_chain + message_encoded + message_encoded = "Hello, World!".encode("utf-8") + + encrypted_message, new_pads = otp_encrypt_with_padding(message_encoded, bob_pads) - pad_len = max(0, OTP_PADDING_LIMIT - OTP_PADDING_LENGTH - len(message_encoded)) - otp_pad = bob_pads[:len(message_encoded) + OTP_PADDING_LENGTH + pad_len] - encrypted = otp_encrypt_with_padding(message_encoded, otp_pad, padding_limit=pad_len) + assert encrypted_message != message_encoded, "Ciphertext equals message" + assert new_pads != bob_pads, "Pads did not get truncated after use!" + assert len(encrypted_message) == len(message_encoded) + (OTP_MAX_BUCKET - len(message_encoded)), "Encrypted message length does not match expected length" - assert encrypted != message_encoded, "Ciphertext equals plaintext" - assert len(encrypted) == len(otp_pad), "Ciphertext length mismatch" # Alice decrypts and validates hash chain - decrypted = otp_decrypt_with_padding(encrypted, plaintext[:len(encrypted)]) - recv_hash = decrypted[:HASH_SIZE] - recv_plaintext = decrypted[HASH_SIZE:] - assert recv_plaintext.decode() == message, "Decrypted message mismatch" + decrypted_message = otp_decrypt_with_padding(encrypted_message, plaintext[:len(encrypted_message)]) + assert decrypted_message == message_encoded, "Decrypted message mismatch" - calc_next_hash = sha3_512(bob_hash_chain_seed + recv_plaintext) - assert calc_next_hash == recv_hash, "Hash chain verification failed" + # calc_next_hash = sha3_512(bob_hash_chain_seed + recv_plaintext) + # assert calc_next_hash == recv_hash, "Hash chain verification failed" + # Temporarily disabled until I make new, improved tests. + """ # Tampering test: flip a byte tampered_message = bytearray(encrypted) tampered_message[HASH_SIZE + 1] ^= 0xFF @@ -186,3 +182,4 @@ def test_kem_otp_encryption(): calc_tampered_hash = sha3_512(bob_hash_chain_seed + tampered_plaintext) assert calc_tampered_hash != tampered_hash, "Tampering not detected" + """