diff --git a/core/constants.py b/core/constants.py index 51271bf..a8af7e9 100644 --- a/core/constants.py +++ b/core/constants.py @@ -2,6 +2,9 @@ APP_NAME = "Coldwire" APP_VERSION = "0.1" +# hard-coded filepaths +ACCOUNT_FILE_PATH = "account.coldwire" + # network defaults (seconds) LONGPOLL_MIN = 5 LONGPOLL_MAX = 30 @@ -17,6 +20,7 @@ ML_KEM_1024_NAME = "Kyber1024" ML_KEM_1024_SK_LEN = 3168 ML_KEM_1024_PK_LEN = 1568 +ML_KEM_1024_CT_LEN = 1568 ML_DSA_87_NAME = "Dilithium5" @@ -24,20 +28,37 @@ ML_DSA_87_PK_LEN = 2592 ML_DSA_87_SIGN_LEN = 4595 -ML_BUFFER_LIMITS = { + +CLASSIC_MCELIECE_8_F_NAME = "Classic-McEliece-8192128f" +CLASSIC_MCELIECE_8_F_SK_LEN = 14120 +CLASSIC_MCELIECE_8_F_PK_LEN = 1357824 +CLASSIC_MCELIECE_8_F_CT_LEN = 208 + + +CLASSIC_MCELIECE_8_F_ROTATE_AT = 3 # Default OTP batches needed to be sent for a key rotation to occur + + + +ALGOS_BUFFER_LIMITS = { ML_KEM_1024_NAME: { "SK_LEN": ML_KEM_1024_SK_LEN, - "PK_LEN": ML_KEM_1024_PK_LEN + "PK_LEN": ML_KEM_1024_PK_LEN, + "CT_LEN": ML_KEM_1024_CT_LEN }, ML_DSA_87_NAME: { "SK_LEN" : ML_DSA_87_SK_LEN, "PK_LEN" : ML_DSA_87_PK_LEN, "SIGN_LEN": ML_DSA_87_SIGN_LEN - } + }, + CLASSIC_MCELIECE_8_F_NAME: { + "SK_LEN": CLASSIC_MCELIECE_8_F_SK_LEN, + "PK_LEN": CLASSIC_MCELIECE_8_F_PK_LEN, + "CT_LEN": CLASSIC_MCELIECE_8_F_CT_LEN + }, } # hash parameters -ARGON2_MEMORY = 256 * 1024 # KB +ARGON2_MEMORY = 256 * 1024 # MB ARGON2_ITERS = 3 ARGON2_OUTPUT_LEN = 32 # bytes ARGON2_SALT_LEN = 32 # bytes diff --git a/core/crypto.py b/core/crypto.py index 15e55bb..ec2fcf8 100644 --- a/core/crypto.py +++ b/core/crypto.py @@ -28,7 +28,7 @@ ML_DSA_87_SK_LEN, ML_DSA_87_PK_LEN, ML_DSA_87_SIGN_LEN, - ML_BUFFER_LIMITS + ALGOS_BUFFER_LIMITS ) @@ -44,7 +44,7 @@ def create_signature(algorithm: str, message: bytes, private_key: bytes) -> byte Returns: Signature bytes of fixed size defined by the algorithm. """ - with oqs.Signature(algorithm, secret_key = private_key[:ML_BUFFER_LIMITS[algorithm]["SK_LEN"]]) as signer: + with oqs.Signature(algorithm, secret_key = private_key[:ALGOS_BUFFER_LIMITS[algorithm]["SK_LEN"]]) as signer: return signer.sign(message) def verify_signature(algorithm: str, message: bytes, signature: bytes, public_key: bytes) -> bool: @@ -61,7 +61,7 @@ def verify_signature(algorithm: str, message: bytes, signature: bytes, public_ke True if valid, False if invalid. """ with oqs.Signature(algorithm) as verifier: - return verifier.verify(message, signature[:ML_BUFFER_LIMITS[algorithm]["SIGN_LEN"]], public_key[:ML_BUFFER_LIMITS[algorithm]["PK_LEN"]]) + return verifier.verify(message, signature[:ALGOS_BUFFER_LIMITS[algorithm]["SIGN_LEN"]], public_key[:ALGOS_BUFFER_LIMITS[algorithm]["PK_LEN"]]) def generate_sign_keys(algorithm: str = ML_DSA_87_NAME): """ @@ -137,9 +137,9 @@ def one_time_pad(plaintext: bytes, key: bytes) -> bytes: otpd_plaintext += bytes([plain_byte ^ key_byte]) return otpd_plaintext -def generate_kem_keys(algorithm: str = ML_KEM_1024_NAME): +def generate_kem_keys(algorithm: str): """ - Generates ML-KEM-1024 keypair (Kyber). + Generates a KEM keypair. Args: algorithm: PQ KEM algorithm (default Kyber1024). @@ -152,39 +152,42 @@ def generate_kem_keys(algorithm: str = ML_KEM_1024_NAME): private_key = kem.export_secret_key() return private_key, public_key -def decrypt_kyber_shared_secrets(ciphertext_blob: bytes, private_key: bytes, otp_pad_size: int = OTP_PAD_SIZE): +def decrypt_shared_secrets(ciphertext_blob: bytes, private_key: bytes, algorithm: str = None, otp_pad_size: int = OTP_PAD_SIZE): """ - Decrypts concatenated Kyber ciphertexts to derive shared one-time pad. + Decrypts concatenated KEM ciphertexts to derive shared one-time pad. Args: ciphertext_blob: Concatenated Kyber ciphertexts. - private_key: ML-KEM-1024 private key. + private_key: KEM private key. + algorithm: KEM algorithm NIST name. otp_pad_size: Desired OTP pad size in bytes. Returns: Shared secret OTP pad bytes. """ - cipher_size = 1568 # Kyber1024 ciphertext size + cipher_size = ALGOS_BUFFER_LIMITS[algorithm]["CT_LEN"] # KEM ciphertext size shared_secrets = b'' cursor = 0 - with oqs.KeyEncapsulation(ML_KEM_1024_NAME, secret_key=private_key[:ML_BUFFER_LIMITS[ML_KEM_1024_NAME]["SK_LEN"]]) as kem: + with oqs.KeyEncapsulation(algorithm, secret_key=private_key[:ALGOS_BUFFER_LIMITS[algorithm]["SK_LEN"]]) as kem: while len(shared_secrets) < otp_pad_size: ciphertext = ciphertext_blob[cursor:cursor + cipher_size] if len(ciphertext) != cipher_size: - raise ValueError("Ciphertext blob is malformed or incomplete") + raise ValueError(f"Ciphertext of {algorithm} blob is malformed or incomplete ({len(ciphertext)})") + shared_secret = kem.decap_secret(ciphertext) shared_secrets += shared_secret cursor += cipher_size return shared_secrets[:otp_pad_size] -def generate_kyber_shared_secrets(public_key: bytes, otp_pad_size: int = OTP_PAD_SIZE): +def generate_shared_secrets(public_key: bytes, algorithm: str = None, otp_pad_size: int = OTP_PAD_SIZE): """ - Generates a one-time pad via Kyber encapsulation. + Generates a one-time pad via `algorithm` encapsulation. Args: - public_key: Recipient's ML-KEM-1024 public key. + public_key: Recipient's public key. + algorithm: KEM algorithm NIST name. otp_pad_size: Desired OTP pad size in bytes. Returns: @@ -193,9 +196,9 @@ def generate_kyber_shared_secrets(public_key: bytes, otp_pad_size: int = OTP_PAD shared_secrets = b'' ciphertexts_blob = b'' - with oqs.KeyEncapsulation(ML_KEM_1024_NAME) as kem: + with oqs.KeyEncapsulation(algorithm) as kem: while len(shared_secrets) < otp_pad_size: - ciphertext, shared_secret = kem.encap_secret(public_key[:ML_BUFFER_LIMITS[ML_KEM_1024_NAME]["PK_LEN"]]) + ciphertext, shared_secret = kem.encap_secret(public_key[:ALGOS_BUFFER_LIMITS[algorithm]["PK_LEN"]]) ciphertexts_blob += ciphertext shared_secrets += shared_secret diff --git a/core/trad_crypto.py b/core/trad_crypto.py index c310bb8..d2e94f1 100644 --- a/core/trad_crypto.py +++ b/core/trad_crypto.py @@ -11,6 +11,7 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.kdf.argon2 import Argon2id from core.constants import ( + OTP_PAD_SIZE, AES_GCM_NONCE_LEN, ARGON2_ITERS, ARGON2_MEMORY, @@ -22,6 +23,7 @@ import secrets + def sha3_512(data: bytes) -> bytes: """ Compute a SHA3-512 hash of the given data. @@ -37,12 +39,7 @@ def sha3_512(data: bytes) -> bytes: return h.digest() -def derive_key_argon2id( - password: bytes, - salt: bytes = None, - salt_length: int = ARGON2_SALT_LEN, - output_length: int = ARGON2_OUTPUT_LEN -) -> tuple[bytes, bytes]: +def derive_key_argon2id(password: bytes, salt: bytes = None, salt_length: int = ARGON2_SALT_LEN, output_length: int = ARGON2_OUTPUT_LEN) -> tuple[bytes, bytes]: """ Derive a symmetric key from a password using Argon2id. diff --git a/logic/background_worker.py b/logic/background_worker.py index 26e68a9..ead5598 100644 --- a/logic/background_worker.py +++ b/logic/background_worker.py @@ -1,6 +1,6 @@ from core.requests import http_request from logic.smp import smp_unanswered_questions, smp_data_handler -from logic.pfs import pfs_data_handler +from logic.pfs import pfs_data_handler, update_ephemeral_keys from logic.message import messages_data_handler from core.constants import ( LONGPOLL_MIN, @@ -29,9 +29,11 @@ def background_worker(user_data, user_data_lock, ui_queue, stop_flag): logger.debug("Data longpoll request has timed out, retrying...") continue - logger.debug("SMP messages: %s", json.dumps(response, indent = 2)) + # logger.debug("Data received: %s", json.dumps(response, indent = 2)[:2000]) for message in response["messages"]: + logger.debug("Received data message: %s", json.dumps(message, indent = 2)[:5000]) + # Sanity check universal message fields if (not "sender" in message) or (not message["sender"].isdigit()) or (len(message["sender"]) != 16): logger.error("Impossible condition, either you have discovered a bug in Coldwire, or the server is attempting to denial-of-service you. Skipping data message with no (or malformed) sender...") @@ -52,9 +54,16 @@ def background_worker(user_data, user_data_lock, ui_queue, stop_flag): elif message["data_type"] == "message": messages_data_handler(user_data, user_data_lock, user_data_copied, ui_queue, message) - else: logger.error( "Impossible condition, either you have discovered a bug in Coldwire, or the server is attempting to denial-of-service you. Skipping data message with unknown data type (%s)...", message["data_type"] ) + + # *Sigh* I had to put this here because if we rotate before finishing reading all of the messages + # we would literally overwrite our own key. + # TODO: We need to keep the last used key and use it when decapsulation with new key gives invalid output + # because it might actually take some time for our keys to be uploaded to server + other servers and to the contact. + # + update_ephemeral_keys(user_data, user_data_lock) + diff --git a/logic/contacts.py b/logic/contacts.py index 6ca21d0..a30a30a 100644 --- a/logic/contacts.py +++ b/logic/contacts.py @@ -3,6 +3,11 @@ import json import math +from core.constants import ( + ML_KEM_1024_NAME, + CLASSIC_MCELIECE_8_F_NAME, + CLASSIC_MCELIECE_8_F_ROTATE_AT +) def generate_nickname_id(length: int = 4) -> str: # Calculate nickname ID: digits get >= letters @@ -56,14 +61,23 @@ def save_contact(user_data: dict, user_data_lock, contact_id: str) -> None: "smp_step": None, }, "ephemeral_keys": { - "contact_public_key": None, + "contact_public_keys": { + CLASSIC_MCELIECE_8_F_NAME: None, + ML_KEM_1024_NAME: None + }, "our_keys": { - "public_key": None, - "private_key": None, + CLASSIC_MCELIECE_8_F_NAME: { + "public_key": None, + "private_key": None, + "rotation_counter": 0, + "rotate_at": CLASSIC_MCELIECE_8_F_ROTATE_AT, + }, + ML_KEM_1024_NAME: { + "public_key": None, + "private_key": None, }, - "rotation_counter": None, - "rotate_at": None, + } }, "our_pads": { "hash_chain": None, diff --git a/logic/message.py b/logic/message.py index fc47c4a..a8fc348 100644 --- a/logic/message.py +++ b/logic/message.py @@ -3,7 +3,7 @@ ----------- Message sending, receiving, and one-time-pad key exchange logic. Handles: -- Generation and transmission of Kyber-encrypted OTP batches +- Generation and transmission of hybrid ciphertext OTP batches - Ephemeral key rotation enforcement for PFS - Message encryption/decryption with hash chain integrity checks - Incoming message processing and replay/tampering protection @@ -14,16 +14,21 @@ from logic.pfs import send_new_ephemeral_keys from core.trad_crypto import sha3_512 from core.crypto import ( - generate_kyber_shared_secrets, - decrypt_kyber_shared_secrets, + generate_shared_secrets, + decrypt_shared_secrets, create_signature, verify_signature, + one_time_pad, otp_encrypt_with_padding, otp_decrypt_with_padding ) from core.constants import ( + ALGOS_BUFFER_LIMITS, + OTP_PAD_SIZE, OTP_PADDING_LIMIT, OTP_PADDING_LENGTH, + ML_KEM_1024_NAME, + CLASSIC_MCELIECE_8_F_NAME, ML_DSA_87_NAME, ) from base64 import b64decode, b64encode @@ -35,7 +40,7 @@ def generate_and_send_pads(user_data, user_data_lock, contact_id: str, ui_queue) -> bool: """ - Generates a new Kyber OTP batch, signs it with Dilithium, and sends it to the server. + Generates a new hash-chained OTP batch, signs it with Dilithium, and sends it to the server. Updates local pad and hash chain state upon success. Returns: bool: True if successful, False otherwise. @@ -44,17 +49,20 @@ def generate_and_send_pads(user_data, user_data_lock, contact_id: str, ui_queue) server_url = user_data["server_url"] auth_token = user_data["token"] - contact_kyber_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_key"] - our_lt_private_key = user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["private_key"] + contact_kyber_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME] + contact_mceliece_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME] + our_lt_private_key = user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["private_key"] - ciphertext_blob, pads = generate_kyber_shared_secrets(contact_kyber_public_key) - otp_batch_signature = create_signature(ML_DSA_87_NAME, ciphertext_blob, our_lt_private_key) + kyber_ciphertext_blob , kyber_shared_secrets = generate_shared_secrets(contact_kyber_public_key, ML_KEM_1024_NAME) + mceliece_ciphertext_blob, mceliece_shared_secrets = generate_shared_secrets(contact_mceliece_public_key, CLASSIC_MCELIECE_8_F_NAME) + + otp_batch_signature = create_signature(ML_DSA_87_NAME, kyber_ciphertext_blob + mceliece_ciphertext_blob, our_lt_private_key) otp_batch_signature = b64encode(otp_batch_signature).decode() payload = { - "otp_hashchain_ciphertext": b64encode(ciphertext_blob).decode(), + "otp_hashchain_ciphertext": b64encode(kyber_ciphertext_blob + mceliece_ciphertext_blob).decode(), "otp_hashchain_signature": otp_batch_signature, "recipient": contact_id } @@ -63,7 +71,9 @@ def generate_and_send_pads(user_data, user_data_lock, contact_id: str, ui_queue) except Exception: ui_queue.put({"type": "showerror", "title": "Error", "message": "Failed to send our one-time-pads key batch to the server"}) return False - + + pads = one_time_pad(kyber_shared_secrets, mceliece_shared_secrets) + # We update & save only at the end, so if request fails, we do not desync our state. with user_data_lock: user_data["contacts"][contact_id]["our_pads"]["pads"] = pads[64:] @@ -87,60 +97,30 @@ def send_message_processor(user_data, user_data_lock, contact_id: str, message: """ with user_data_lock: - if contact_id in user_data["tmp"]["ephemeral_key_send_lock"]: - ui_queue.put({ - "type": "showwarning", - "title": "Warning", - "message": f"We are waiting for ({contact_id[:32]}) to come online to exchange keys" - }) - return False - server_url = user_data["server_url"] auth_token = user_data["token"] + contact_kyber_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME] + contact_mceliece_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME] - contact_kyber_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_key"] - + our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"] + - if (not contact_kyber_public_key): - logger.debug("This shouldn't usually happen, contact kyber keys are not initialized even once yet???") + if contact_kyber_public_key is None or contact_mceliece_public_key is None: + logger.debug("This shouldn't happen, contact ephemeral keys are not initialized even once yet???") ui_queue.put({ "type": "showwarning", "title": f"Warning for {contact_id[:32]}", - "message": "Ephemeral keys have not yet initialized, maybe contact is offline. We will notify you when keys are initialized" + "message": "Ephemeral keys have not yet initialized, and we are not sure why." }) - send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) return False - - with user_data_lock: - our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"] - - rotation_counter = user_data["contacts"][contact_id]["ephemeral_keys"]["rotation_counter"] - rotate_at = user_data["contacts"][contact_id]["ephemeral_keys"]["rotate_at"] - - - # We rotate keys before generating and sending new batch of pads because - # ephemeral key exchanges always get processed before messages do. - # Which means if we generate and send pads with contact's, we would be using his old key, which would get overriden by the request, even if we send pads first - # This is because of our server archiecture which prioritizes PFS requests before messages. - # - # Another note, that means after batch ends, and rotation time comes, you won't be able to send messages until other contact is online. - # This will (hopefully) change in a future update - if rotation_counter == rotate_at: - logger.info("We are rotating our ephemeral keys for contact (%s)", contact_id) - ui_queue.put({"type": "showinfo", "title": "Perfect Forward Secrecy", "message": f"We are rotating our ephemeral keys for contact ({contact_id[:32]})"}) - send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) - - save_account_data(user_data, user_data_lock) - return False - - - # If we have keys, but no one-time-pads, we send new pads to the contact + + # If we don't have any one-time-pads, we send new pads to the contact if not our_pads: - logger.debug("We have no pads to send message") + logger.debug("We have no OTP pads to use.") if not generate_and_send_pads(user_data, user_data_lock, contact_id, ui_queue): return False @@ -148,11 +128,7 @@ def send_message_processor(user_data, user_data_lock, contact_id: str, message: with user_data_lock: our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"] - user_data["contacts"][contact_id]["ephemeral_keys"]["rotation_counter"] += 1 - - logger.debug("Incremented rotation_counter by 1. (%d)", rotation_counter) - - + with user_data_lock: our_hash_chain = user_data["contacts"][contact_id]["our_pads"]["hash_chain"] @@ -172,10 +148,6 @@ def send_message_processor(user_data, user_data_lock, contact_id: str, message: with user_data_lock: our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"] our_hash_chain = user_data["contacts"][contact_id]["our_pads"]["hash_chain"] - - user_data["contacts"][contact_id]["ephemeral_keys"]["rotation_counter"] += 1 - - logger.debug("Incremented rotation_counter by 1. (%d)", rotation_counter) # We remove old hashchain from message and calculate new next hash in the chain message_encoded = message_encoded[64:] @@ -244,7 +216,7 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic contact_public_key = user_data_copied["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] - if not contact_public_key: + if contact_public_key is None: logger.warning("Contact per-contact Dilithium 5 public key is missing.. skipping message") return @@ -260,23 +232,48 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic logger.debug("Invalid OTP_hashchain_ciphertext signature.. possible MiTM ?") return - our_kyber_key = user_data_copied["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["private_key"] + our_kyber_key = user_data_copied["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["private_key"] + our_mceliece_key = user_data_copied["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"] + # / 32 because shared secret is 32 bytes try: - contact_pads = decrypt_kyber_shared_secrets(otp_hashchain_ciphertext, our_kyber_key) + contact_kyber_pads = decrypt_shared_secrets(otp_hashchain_ciphertext[:ALGOS_BUFFER_LIMITS[ML_KEM_1024_NAME]["CT_LEN"] * int(OTP_PAD_SIZE / 32)], our_kyber_key, ML_KEM_1024_NAME) except: - logger.debug("Failed to decrypt shared_secrets, possible MiTM?") + logger.error("Failed to decrypt Kyber's shared_secrets, possible MiTM?") return + try: + contact_mceliece_pads = decrypt_shared_secrets(otp_hashchain_ciphertext[ALGOS_BUFFER_LIMITS[ML_KEM_1024_NAME]["CT_LEN"] * int(OTP_PAD_SIZE / 32):], our_mceliece_key, CLASSIC_MCELIECE_8_F_NAME) + except: + logger.error("Failed to decrypt McEliece's shared_secrets, possible MiTM?") + return + + contact_pads = one_time_pad(contact_kyber_pads, contact_mceliece_pads) with user_data_lock: user_data["contacts"][contact_id]["contact_pads"]["pads"] = contact_pads[64:] user_data["contacts"][contact_id]["contact_pads"]["hash_chain"] = contact_pads[:64] - logger.info("Saved contact (%s) new batch of One-Time-Pads and hash chain seed", contact_id) + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotation_counter"] += 1 + + new_ml_kem_keys = user_data["tmp"]["new_ml_kem_keys"] + + rotation_counter = user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotation_counter"] + + logger.debug("Incremented McEliece's rotation_counter by 1 (now is %d) for contact (%s)", rotation_counter, contact_id) + + logger.info("Saved contact (%s) new batch of One-Time-Pads and hash chain seed", contact_id) save_account_data(user_data, user_data_lock) + + if contact_id not in new_ml_kem_keys: + logger.info("Rotating our ephemeral keys") + send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) + save_account_data(user_data, user_data_lock) + + + elif message["msg_type"] == "new_message": message_encrypted = b64decode(message["message_encrypted"], validate=True) @@ -286,7 +283,8 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic if (not contact_pads) or (len(message_encrypted) > len(contact_pads)): # TODO: Maybe reset our local pads as well? - logger.warning("Message payload is larger than our local pads for the contact, we are skipping this message..") + # I feel like we should do something more when we hit this case, but I am not sure. + logger.error("Message payload is larger than our local pads for the contact (%s), we are skipping this message.. This is most likely a bug, please open an issue on Github (https://github.com/Freedom-Club-Sec/Coldwire)", contact_id) return message_decrypted = otp_decrypt_with_padding(message_encrypted, contact_pads[:len(message_encrypted)]) @@ -303,7 +301,7 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic return - # and immediately save the new pads and the hash chain + # and save the new pads and the hash chain with user_data_lock: user_data["contacts"][contact_id]["contact_pads"]["pads"] = contact_pads user_data["contacts"][contact_id]["contact_pads"]["hash_chain"] = next_hash_chain diff --git a/logic/pfs.py b/logic/pfs.py index 409890b..4567bd8 100644 --- a/logic/pfs.py +++ b/logic/pfs.py @@ -6,6 +6,12 @@ create_signature, random_number_range ) +from core.constants import ( + ALGOS_BUFFER_LIMITS, + ML_KEM_1024_NAME, + CLASSIC_MCELIECE_8_F_NAME, + CLASSIC_MCELIECE_8_F_ROTATE_AT +) from core.trad_crypto import sha3_512 from base64 import b64encode, b64decode import secrets @@ -19,10 +25,12 @@ def send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) -> None: with user_data_lock: user_data_copied = copy.deepcopy(user_data) + + rotation_counter = user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotation_counter"] + rotate_at = user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotate_at"] - - server_url = user_data_copied["server_url"] - auth_token = user_data_copied["token"] + server_url = user_data["server_url"] + auth_token = user_data["token"] lt_sign_private_key = user_data_copied["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["private_key"] @@ -39,15 +47,23 @@ def send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) -> our_hash_chain = sha3_512(our_hash_chain) # Generate new Kyber1024 keys for us - kyber_private_key, kyber_public_key = generate_kem_keys() + kyber_private_key, kyber_public_key = generate_kem_keys(ML_KEM_1024_NAME) + publickeys_hashchain = our_hash_chain + kyber_public_key + + pfs_type = "partial" + if (rotate_at == rotation_counter) or (user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"] is None): + mceliece_private_key, mceliece_public_key = generate_kem_keys(CLASSIC_MCELIECE_8_F_NAME) + publickeys_hashchain += mceliece_public_key + pfs_type = "full" # Sign them with our per-contact long-term private key - kyber_key_hashchain_signature = create_signature("Dilithium5", our_hash_chain + kyber_public_key, lt_sign_private_key) + publickeys_hashchain_signature = create_signature("Dilithium5", publickeys_hashchain, lt_sign_private_key) payload = { - "kyber_publickey_hashchain": b64encode(our_hash_chain + kyber_public_key).decode(), - "kyber_hashchain_signature": b64encode(kyber_key_hashchain_signature).decode(), - "recipient" : contact_id, + "publickeys_hashchain": b64encode(publickeys_hashchain).decode(), + "hashchain_signature" : b64encode(publickeys_hashchain_signature).decode(), + "recipient" : contact_id, + "pfs_type" : pfs_type } @@ -60,24 +76,50 @@ def send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) -> # We update at the very end to ensure if any of previous steps fail, we do not desync our state with user_data_lock: - user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["private_key"] = kyber_private_key - user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["public_key"] = kyber_public_key + + user_data["tmp"]["new_ml_kem_keys"][contact_id] = { + "private_key": kyber_private_key, + "public_key": kyber_public_key + } + + if pfs_type == "full": + user_data["tmp"]["new_code_kem_keys"][contact_id] = { + "private_key": mceliece_private_key, + "public_key": mceliece_public_key + } + + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotation_counter"] = 0 + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotate_at"] = CLASSIC_MCELIECE_8_F_ROTATE_AT - # This one should prevent any pad generation and sending, until contact sends us his new ephemeral keys too - # user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_key"] = None user_data["contacts"][contact_id]["lt_sign_keys"]["our_hash_chain"] = our_hash_chain - # Set rotation counters to rotate every 2 pad batches sent - # TODO: Maybe rotate on every batch instead? and rework the counters, like we don't even need counters if we rotate on every batch sent. - user_data["contacts"][contact_id]["ephemeral_keys"]["rotation_counter"] = 0 - user_data["contacts"][contact_id]["ephemeral_keys"]["rotate_at"] = 2 - # = True, to make it easy for us to delete it later when we receive keys from contact - user_data["tmp"]["ephemeral_key_send_lock"][contact_id] = True +def update_ephemeral_keys(user_data, user_data_lock) -> None: + with user_data_lock: + new_ml_kem_keys = user_data["tmp"]["new_ml_kem_keys"] + new_code_kem_keys = user_data["tmp"]["new_code_kem_keys"] + + for contact_id, v in new_ml_kem_keys.items(): + with user_data_lock: + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["private_key"] = v["private_key"] + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["public_key"] = v["public_key"] + + for contact_id, v in new_code_kem_keys.items(): + with user_data_lock: + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"] = v["private_key"] + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["public_key"] = v["public_key"] + + + with user_data_lock: + user_data["tmp"]["new_ml_kem_keys"] = {} + user_data["tmp"]["new_code_kem_keys"] = {} + + + save_account_data(user_data, user_data_lock) @@ -85,7 +127,7 @@ def send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) -> def pfs_data_handler(user_data, user_data_lock, user_data_copied, ui_queue, message) -> None: contact_id = message["sender"] - if (not (contact_id in user_data_copied["contacts"])): + if contact_id not in user_data_copied["contacts"]: logger.error("Contact is missing, maybe we (or they) are not synced? Not sure, but we will ignore this PFS request for now") logger.debug("Our saved contacts: %s", json.dumps(user_data_copied["contacts"], indent=2)) return @@ -99,20 +141,22 @@ def pfs_data_handler(user_data, user_data_lock, user_data_copied, ui_queue, mess return if not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["verified"]: - logger.error("Contact long-term signing key is not verified! it is possible that this is a MiTM attack by the server, we ignoring this PFS for now.") + logger.error("Contact long-term signing key is not verified! it is possible that this is a MiTM attack, we ignoring this PFS for now.") return - contact_kyber_hashchain_signature = b64decode(message["kyber_hashchain_signature"], validate=True) - contact_kyber_publickey_hashchain = b64decode(message["kyber_publickey_hashchain"], validate=True) + contact_hashchain_signature = b64decode(message["hashchain_signature"], validate=True) + contact_publickeys_hashchain = b64decode(message["publickeys_hashchain"], validate=True) - valid_signature = verify_signature("Dilithium5", contact_kyber_publickey_hashchain, contact_kyber_hashchain_signature, contact_lt_public_key) + valid_signature = verify_signature("Dilithium5", contact_publickeys_hashchain, contact_hashchain_signature, contact_lt_public_key) if not valid_signature: - logger.error("Invalid ephemeral kyber public-key + hashchain signature! possible MiTM ?") + logger.error("Invalid ephemeral public-key + hashchain signature from contact (%s)", contact_id) return + if message["pfs_type"] not in ["full", "partial"]: + logger.error("contact (%s) sent message of unknown pfs_type (%s)", contact_id, message["pfs_type"]) + return - contact_kyber_public_key = contact_kyber_publickey_hashchain[64:] - contact_hash_chain = contact_kyber_publickey_hashchain[:64] + contact_hash_chain = contact_publickeys_hashchain[:64] # If we do not have a hashchain for the contact, we don't need to compute the chain, just save. if not user_data_copied["contacts"][contact_id]["lt_sign_keys"]["contact_hash_chain"]: @@ -127,40 +171,30 @@ def pfs_data_handler(user_data, user_data_lock, user_data_copied, ui_queue, mess logger.error("Contact hash chain does not match our computed hash chain, we are skipping this PFS message...") return + contact_kyber_public_key = contact_publickeys_hashchain[64: ALGOS_BUFFER_LIMITS[ML_KEM_1024_NAME]["PK_LEN"] + 64] + if message["pfs_type"] == "full": + logger.info("contact (%s) has rotated their Kyber and McEliece keys", contact_id) + + contact_mceliece_public_key = contact_publickeys_hashchain[ALGOS_BUFFER_LIMITS[ML_KEM_1024_NAME]["PK_LEN"] + 64:] + with user_data_lock: + user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME] = contact_mceliece_public_key + + elif message["pfs_type"] == "partial": + logger.info("contact (%s) has rotated their Kyber keys", contact_id) + with user_data_lock: user_data["contacts"][contact_id]["lt_sign_keys"]["contact_hash_chain"] = contact_hash_chain - user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_key"] = contact_kyber_public_key - - - # TODO: Investigate possible infinite loopback - # Details: What if contact_id wasn't in tmp (i.e. user closed the app and re-opened later ) - # then we re-send the keys ?! And not only that, if the contact also offline, and he receive it - # he will also resend keys - # and if we are offline, we also resend keys. - # so on and so fourth - # Maybe ephemeral_key_send_lock need to be in the contact info, not in tmp ? - - if contact_id in user_data_copied["tmp"]["ephemeral_key_send_lock"]: - logger.debug("We don't have to re-send keys, as we already have sent them, time to inform user of success :)") - - # Incase this was auto-fired by SMP's step 4, we don't want to give the user another popup - if not (contact_id in user_data_copied["tmp"]["pfs_do_not_inform"]): - logger.info("Successfully initialized ephemeral keys with contacts (%s)", contact_id) - ui_queue.put({"type": "showinfo", "title": "Success", "message": f"Successfully initialized ephemeral keys with contact ({contact_id[:32]})"}) - else: - logger.info("Not informing the user of successful ephemeral keys initialization because step was likely automatically fired") - - # We delete incase user end up rotating per-contact long-term keys within the same session - with user_data_lock: - del user_data["tmp"]["pfs_do_not_inform"][contact_id] - else: - # Send our ephemeral keys back to the contact - send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) + user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME] = contact_kyber_public_key - with user_data_lock: - if contact_id in user_data["tmp"]["ephemeral_key_send_lock"]: - del user_data["tmp"]["ephemeral_key_send_lock"][contact_id] + our_kyber_private_key = user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["private_key"] + our_mceliece_private_key = user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"] + + new_ml_kem_keys = user_data["tmp"]["new_ml_kem_keys"] + new_code_kem_keys = user_data["tmp"]["new_code_kem_keys"] + if (our_kyber_private_key is None or our_mceliece_private_key is None) and ((contact_id not in new_ml_kem_keys) and (contact_id not in new_code_kem_keys)): + send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) + logger.info("We are sending the contact (%s) our ephemeral keys because we didnt do it before.", contact_id) save_account_data(user_data, user_data_lock) diff --git a/logic/smp.py b/logic/smp.py index 8cee97e..81bc6ec 100644 --- a/logic/smp.py +++ b/logic/smp.py @@ -263,15 +263,10 @@ def smp_step_4(user_data, user_data_lock, contact_id, message, ui_queue) -> None smp_success(user_data, user_data_lock, contact_id, ui_queue) - with user_data_lock: - # = True to make it easy to remove contact_id later - user_data["tmp"]["pfs_do_not_inform"][contact_id] = True # Attempt to automatically exchanger per-contact and ephemeral keys - # We only attempt here and not inside of smp_success because we don't want both contact's attempting to exchange keys at the same time - # - # NOTE: Maybe we need a delay here to ensure if a failure occured in contact's step 3, we catch it and mark - # the contact as unverified despite verifiying him ? + # We only attempt here and not inside of smp_success because we don't want both contact's attempting to exchange keys at the same time + # cuz contact likely still hasnt verified us yet.. (ik its confuysing but just pretend u understand) # send_new_ephemeral_keys(user_data, user_data_lock, contact_id, ui_queue) diff --git a/logic/storage.py b/logic/storage.py index a32fd45..eec4d0d 100644 --- a/logic/storage.py +++ b/logic/storage.py @@ -1,13 +1,17 @@ from pathlib import Path from base64 import b64encode, b64decode +from core.constants import ( + ML_KEM_1024_NAME, + CLASSIC_MCELIECE_8_F_NAME, + ACCOUNT_FILE_PATH + +) import core.trad_crypto as crypto import json import copy import logging -ACCOUNT_FILE_PATH = "account.coldwire" - logger = logging.getLogger(__name__) @@ -34,9 +38,9 @@ def load_account_data(password = None) -> dict: user_data["tmp"] = { - "ephemeral_key_send_lock": {}, - "pfs_do_not_inform": {}, - "password": password + "password": password, + "new_ml_kem_keys": {}, + "new_code_kem_keys": {} } @@ -47,13 +51,24 @@ def load_account_data(password = None) -> dict: for contact_id in user_data["contacts"]: # They probably haven't exchanged yet, so it's fine to skip decoding them try: - user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_key"] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_key"], validate=True) + user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME], validate=True) + except TypeError: + pass + + try: + user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME], validate=True) except TypeError: pass try: - user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["private_key"] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["private_key"], validate=True) - user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["public_key"] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["public_key"], validate=True) + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["private_key"] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["private_key"], validate=True) + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["public_key"] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["public_key"], validate=True) + except TypeError: + pass + + try: + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"], validate=True) + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["public_key"] = b64decode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["public_key"], validate=True) except TypeError: pass @@ -112,16 +127,30 @@ def save_account_data(user_data: dict, user_data_lock, password = None) -> None: for contact_id in user_data["contacts"]: # They probably haven't exchanged yet, so it's fine to skip decoding them try: - user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_key"] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_key"]).decode() + user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME]).decode() except TypeError: pass + try: - user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["private_key"] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["private_key"]).decode() - user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["public_key"] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"]["public_key"]).decode() + user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME]).decode() except TypeError: pass + + try: + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["private_key"] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["private_key"]).decode() + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["public_key"] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][ML_KEM_1024_NAME]["public_key"]).decode() + except TypeError: + pass + + try: + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["private_key"]).decode() + user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["public_key"] = b64encode(user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["public_key"]).decode() + except TypeError: + pass + + try: user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] = b64encode(user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"]).decode() @@ -157,6 +186,7 @@ def save_account_data(user_data: dict, user_data_lock, password = None) -> None: pass + # logger.debug("User_data before saving: %s", str(user_data)) if password is None: with open(ACCOUNT_FILE_PATH, "w", encoding="utf-8") as f: diff --git a/main.py b/main.py index a25a8eb..8e212e4 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ class LevelBasedFormatter(logging.Formatter): FORMATS = { logging.DEBUG: "%(asctime)s [%(levelname)s] %(name)s:%(funcName)s:%(lineno)d - %(message)s", logging.INFO: "%(asctime)s [%(levelname)s] - %(message)s", - logging.WARNING: "%(asctime)s %(levelname)s] %(name)s - %(message)s", + logging.WARNING: "%(asctime)s [%(levelname)s] %(name)s - %(message)s", logging.ERROR: "%(asctime)s [%(levelname)s] %(name)s:%(funcName)s:%(lineno)d - %(message)s", logging.CRITICAL: "%(asctime)s [%(levelname)s] %(name)s:%(funcName)s:%(lineno)d - %(message)s" } diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 252fd9a..175d62f 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -14,8 +14,8 @@ generate_sign_keys, create_signature, verify_signature, - generate_kyber_shared_secrets, - decrypt_kyber_shared_secrets, + generate_shared_secrets, + decrypt_shared_secrets, otp_encrypt_with_padding, otp_decrypt_with_padding, random_number_range @@ -51,7 +51,7 @@ def test_mlkem_keygen_basic(): seen_public_keys = set() for _ in range(10): - private_key, public_key = generate_kem_keys(algorithm = ML_KEM_1024_NAME) + private_key, public_key = generate_kem_keys(ML_KEM_1024_NAME) assert private_key not in seen_private_keys, "Duplicate private key detected" assert public_key not in seen_public_keys, "Duplicate public key detected" @@ -113,20 +113,20 @@ def test_signature_verifcation(): def test_kem_otp_encryption(): """Full Kyber OTP exchange and tamper detection test.""" # Alice creates ephemeral ML-KEM-1024 keypair for PFS - alice_private_key, alice_public_key = generate_kem_keys() + alice_private_key, alice_public_key = generate_kem_keys(ML_KEM_1024_NAME) # Bob creates his own ephemeral keypair - bob_private_key, bob_public_key = generate_kem_keys() + bob_private_key, bob_public_key = generate_kem_keys(ML_KEM_1024_NAME) # Bob derives shared pads from Alice's public key - ciphertext, bob_pads = generate_kyber_shared_secrets(alice_public_key) + ciphertext, bob_pads = generate_shared_secrets(alice_public_key, ML_KEM_1024_NAME) assert ciphertext != bob_pads, "Ciphertext equals pads (should differ)" # First 64 bytes are hash chain seed bob_hash_chain_seed = bob_pads[:HASH_SIZE] # Alice decrypts ciphertext to recover shared pads - plaintext = decrypt_kyber_shared_secrets(ciphertext, alice_private_key) + plaintext = decrypt_shared_secrets(ciphertext, alice_private_key, ML_KEM_1024_NAME) assert plaintext == bob_pads, "Pads mismatch after decryption" # Bob encrypts a message using OTP with hash chain