diff --git a/core/constants.py b/core/constants.py index 1f263dd..002a4e6 100644 --- a/core/constants.py +++ b/core/constants.py @@ -12,7 +12,7 @@ # crypto parameters (bytes) CHALLENGE_LEN = 11264 -AES_GCM_NONCE_LEN = 12 +XCHACHA20POLY1305_NONCE_LEN = 24 OTP_PAD_SIZE = 11264 OTP_PADDING_LENGTH = 2 @@ -71,5 +71,5 @@ ARGON2_MEMORY = 256 * 1024 # MB ARGON2_ITERS = 3 ARGON2_OUTPUT_LEN = 32 # bytes -ARGON2_SALT_LEN = 32 # bytes +ARGON2_SALT_LEN = 16 # bytes (Must be always 16 for interoperability with libsodium.) ARGON2_LANES = 4 diff --git a/core/crypto.py b/core/crypto.py index b9458dc..6aa8048 100644 --- a/core/crypto.py +++ b/core/crypto.py @@ -148,6 +148,14 @@ def generate_kem_keys(algorithm: str): private_key = kem.export_secret_key() return private_key, public_key +def encap_shared_secret(public_key: bytes, algorithm: str): + with oqs.KeyEncapsulation(algorithm) as kem: + return kem.encap_secret(public_key[:ALGOS_BUFFER_LIMITS[algorithm]["PK_LEN"]]) + +def decap_shared_secret(ciphertext: bytes, private_key: bytes, algorithm: str): + with oqs.KeyEncapsulation(algorithm, secret_key = private_key[:ALGOS_BUFFER_LIMITS[algorithm]["SK_LEN"]]) as kem: + return kem.decap_secret(ciphertext[:ALGOS_BUFFER_LIMITS[algorithm]["CT_LEN"]]) + def decrypt_shared_secrets(ciphertext_blob: bytes, private_key: bytes, algorithm: str = None, otp_pad_size: int = OTP_PAD_SIZE): """ Decrypts concatenated KEM ciphertexts to derive shared one-time pad. diff --git a/core/trad_crypto.py b/core/trad_crypto.py index d2e94f1..7c0dcdc 100644 --- a/core/trad_crypto.py +++ b/core/trad_crypto.py @@ -8,11 +8,10 @@ These functions rely on the cryptography library and are intended for use within Coldwire's higher-level protocol logic. """ -from cryptography.hazmat.primitives.ciphers.aead import AESGCM -from cryptography.hazmat.primitives.kdf.argon2 import Argon2id +from nacl import pwhash, bindings from core.constants import ( OTP_PAD_SIZE, - AES_GCM_NONCE_LEN, + XCHACHA20POLY1305_NONCE_LEN, ARGON2_ITERS, ARGON2_MEMORY, ARGON2_LANES, @@ -39,7 +38,7 @@ def sha3_512(data: bytes) -> bytes: return h.digest() -def derive_key_argon2id(password: bytes, salt: bytes = None, salt_length: int = ARGON2_SALT_LEN, output_length: int = ARGON2_OUTPUT_LEN) -> tuple[bytes, bytes]: +def derive_key_argon2id(password: bytes, salt: bytes = None, output_length: int = ARGON2_OUTPUT_LEN) -> tuple[bytes, bytes]: """ Derive a symmetric key from a password using Argon2id. @@ -57,55 +56,60 @@ def derive_key_argon2id(password: bytes, salt: bytes = None, salt_length: int = - salt: The salt used for derivation. """ if salt is None: - salt = secrets.token_bytes(salt_length) + salt = secrets.token_bytes(ARGON2_SALT_LEN) - kdf = Argon2id( - salt=salt, - iterations=ARGON2_ITERS, - memory_cost=ARGON2_MEMORY, - length=output_length, - lanes=ARGON2_LANES - ) - derived_key = kdf.derive(password) - return derived_key, salt + return pwhash.argon2id.kdf( + output_length, + password, + salt, + opslimit = ARGON2_ITERS, + memlimit = ARGON2_MEMORY + ), salt -def encrypt_aes_gcm(key: bytes, plaintext: bytes) -> tuple[bytes, bytes]: +def encrypt_xchacha20poly1305(key: bytes, plaintext: bytes, counter: int = None, counter_safety: int = 2 ** 32) -> tuple[bytes, bytes]: """ - Encrypt plaintext using AES-256 in GCM mode. + Encrypt plaintext using ChaCha20Poly1305. A random nonce is generated for each encryption. Args: - key: A 32-byte AES key. + key: A 32-byte ChaCha20Poly1305 key. plaintext: Data to encrypt. + counter: an (optional) number to add to nonce Returns: A tuple (nonce, ciphertext) where: - nonce: The randomly generated AES-GCM nonce. - ciphertext: The encrypted data including the authentication tag. """ - nonce = secrets.token_bytes(AES_GCM_NONCE_LEN) - aes_gcm = AESGCM(key) - ciphertext = aes_gcm.encrypt(nonce, plaintext, None) + nonce = secrets.token_bytes(XCHACHA20POLY1305_NONCE_LEN) + if counter is not None: + if counter > counter_safety: + raise ValueError("ChaCha counter has overflowen") + + nonce = nonce[:XCHACHA20POLY1305_NONCE_LEN - 4] + counter.to_bytes(4, "big") + + ciphertext = bindings.crypto_aead_xchacha20poly1305_ietf_encrypt(plaintext, None, nonce, key) + return nonce, ciphertext -def decrypt_aes_gcm(key: bytes, nonce: bytes, ciphertext: bytes) -> bytes: +def decrypt_xchacha20poly1305(key: bytes, nonce: bytes, ciphertext: bytes) -> bytes: """ - Decrypt ciphertext using AES-256 in GCM mode. + Decrypt ciphertext using ChaCha20Poly1305. Raises an exception if authentication fails. Args: - key: The 32-byte AES key used for encryption. + key: The 32-byte ChaCha20Poly1305 key used for encryption. nonce: The nonce used during encryption. ciphertext: The encrypted data including the authentication tag. Returns: The decrypted plaintext bytes. """ - aes_gcm = AESGCM(key) - return aes_gcm.decrypt(nonce, ciphertext, None) + + return bindings.crypto_aead_xchacha20poly1305_ietf_decrypt(ciphertext, None, nonce, key) diff --git a/logic/background_worker.py b/logic/background_worker.py index 6fdd078..8821567 100644 --- a/logic/background_worker.py +++ b/logic/background_worker.py @@ -32,10 +32,7 @@ def background_worker(user_data, user_data_lock, ui_queue, stop_flag): # logger.debug("Data received: %s", json.dumps(response, indent = 2)[:2000]) for message in response["messages"]: - try: - logger.debug("Received data message: %s", json.dumps(message, indent = 2)[:5000]) - except: - print("################# ", message) + logger.debug("Received data message: %s", json.dumps(message, indent = 2)[:5000]) # Sanity check universal message fields if (not "sender" in message) or (not message["sender"].isdigit()) or (len(message["sender"]) != 16): diff --git a/logic/contacts.py b/logic/contacts.py index ce93f5f..993f3c8 100644 --- a/logic/contacts.py +++ b/logic/contacts.py @@ -60,6 +60,7 @@ def save_contact(user_data: dict, user_data_lock, contact_id: str) -> None: "contact_nonce": None, "smp_step": None, "tmp_proof": None, + "tmp_key": None, "contact_kem_public_key": None, "our_kem_keys": { "private_key": None, diff --git a/logic/smp.py b/logic/smp.py index c1172f4..e345b2b 100644 --- a/logic/smp.py +++ b/logic/smp.py @@ -39,45 +39,58 @@ from core.crypto import ( generate_sign_keys, generate_kem_keys, - generate_shared_secrets, - decrypt_shared_secrets, - one_time_pad + encap_shared_secret, + decap_shared_secret, + +) +from core.trad_crypto import ( + derive_key_argon2id, + sha3_512, + encrypt_xchacha20poly1305, + decrypt_xchacha20poly1305 ) -from core.trad_crypto import derive_key_argon2id, sha3_512 from base64 import b64encode, b64decode from core.constants import ( SMP_NONCE_LENGTH, + SMP_PROOF_LENGTH, SMP_QUESTION_MAX_LEN, SMP_ANSWER_OUTPUT_LEN, - ML_KEM_1024_NAME + ARGON2_SALT_LEN, + ML_KEM_1024_NAME, + ML_KEM_1024_CT_LEN, + ML_DSA_87_PK_LEN, + XCHACHA20POLY1305_NONCE_LEN ) import hashlib import secrets import hmac import logging +import threading +import queue logger = logging.getLogger(__name__) def normalize_answer(s: str) -> str: - return s.strip().lower() + s = s.strip() + + # lowercase the 1st character + s = s[0].lower() + s[1:] if s else s + return s # This is step 1. -def initiate_smp(user_data: dict, user_data_lock, contact_id: str, question: str, answer: str) -> None: +def initiate_smp(user_data: dict, user_data_lock: threading.Lock, contact_id: str, question: str, answer: str) -> None: with user_data_lock: server_url = user_data["server_url"] auth_token = user_data["token"] - our_nonce = b64encode(secrets.token_bytes(SMP_NONCE_LENGTH)).decode() - - signing_private_key, signing_public_key = generate_sign_keys() + kem_private_key, kem_public_key = generate_kem_keys(ML_KEM_1024_NAME) try: response = http_request(f"{server_url}/smp/initiate", "POST", payload = { - "nonce": our_nonce, - "signing_public_key": b64encode(signing_public_key).decode(), + "kem_public_key": b64encode(kem_public_key).decode(), "recipient": contact_id }, auth_token=auth_token) @@ -88,20 +101,18 @@ def initiate_smp(user_data: dict, user_data_lock, contact_id: str, question: str if "error" in response: raise ValueError(response["error"][:512]) raise ValueError("Server sent malformed response") - - + answer = normalize_answer(answer) with user_data_lock: user_data["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"] = True user_data["contacts"][contact_id]["lt_sign_key_smp"]["question"] = question user_data["contacts"][contact_id]["lt_sign_key_smp"]["answer"] = answer - user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"] = our_nonce - user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 1 + user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 3 - user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["private_key"] = signing_private_key - user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] = signing_public_key + user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_kem_keys"]["private_key"] = b64encode(kem_private_key).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_kem_keys"]["public_key"] = b64encode(kem_public_key).decode() @@ -110,22 +121,30 @@ def initiate_smp(user_data: dict, user_data_lock, contact_id: str, question: str -def smp_step_2(user_data, user_data_lock, contact_id, message, ui_queue) -> None: +def smp_step_2(user_data: dict, user_data_lock, contact_id: str, message: dict, ui_queue: queue.Queue) -> None: with user_data_lock: server_url = user_data["server_url"] auth_token = user_data["token"] + our_id = user_data["user_id"] + + contact_kem_public_key = b64decode(message["kem_public_key"], validate = True) signing_private_key, signing_public_key = generate_sign_keys() - - question_private_key, question_public_key = generate_kem_keys(ML_KEM_1024_NAME) - our_nonce = b64encode(secrets.token_bytes(SMP_NONCE_LENGTH)).decode() + our_nonce = secrets.token_bytes(SMP_NONCE_LENGTH) + + key_ciphertext, chacha_key = encap_shared_secret(contact_kem_public_key, ML_KEM_1024_NAME) + chacha_key = sha3_512(chacha_key)[:32] + + ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( + chacha_key, + signing_public_key + our_nonce, + counter = 2 + ) try: - http_request(f"{server_url}/smp/step_2", "POST", payload = { - "nonce": our_nonce, - "signing_public_key": b64encode(signing_public_key).decode(), - "question_public_key": b64encode(question_public_key).decode(), + http_request(f"{server_url}/smp/step", "POST", payload = { + "ciphertext_blob": b64encode(key_ciphertext + ciphertext_nonce + ciphertext_blob).decode(), "recipient": contact_id }, auth_token=auth_token) @@ -139,35 +158,70 @@ def smp_step_2(user_data, user_data_lock, contact_id, message, ui_queue) -> None with user_data_lock: user_data["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"] = True - user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"] = our_nonce - user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"] = message["nonce"] - - user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_kem_keys"]["private_key"] = b64encode(question_private_key).decode() - user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_kem_keys"]["public_key"] = b64encode(question_public_key).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"] = b64encode(our_nonce).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"] = b64encode(chacha_key).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_kem_public_key"] = message["kem_public_key"] + user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["private_key"] = signing_private_key user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] = signing_public_key - user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] = b64decode(message["signing_public_key"], validate = True) - + user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 4 -def smp_step_3(user_data, user_data_lock, contact_id, message, ui_queue) -> None: +def smp_step_3(user_data: dict, user_data_lock: threading.Lock, contact_id: str, message: dict, ui_queue: queue.Queue()) -> None: with user_data_lock: server_url = user_data["server_url"] auth_token = user_data["token"] - + our_id = user_data["user_id"] + question = user_data["contacts"][contact_id]["lt_sign_key_smp"]["question"] + answer = user_data["contacts"][contact_id]["lt_sign_key_smp"]["answer"] + + our_kem_private_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_kem_keys"]["private_key"]) + + ciphertext_blob = b64decode(message["ciphertext_blob"], validate = True) + key_ciphertext = ciphertext_blob[:ML_KEM_1024_CT_LEN] + + chacha_key = decap_shared_secret(key_ciphertext, our_kem_private_key, ML_KEM_1024_NAME) + + chacha_key = sha3_512(chacha_key)[:32] + + smp_plaintext = decrypt_xchacha20poly1305( + chacha_key, + ciphertext_blob[ML_KEM_1024_CT_LEN : ML_KEM_1024_CT_LEN + XCHACHA20POLY1305_NONCE_LEN], + ciphertext_blob[ML_KEM_1024_CT_LEN + XCHACHA20POLY1305_NONCE_LEN:] + ) + + contact_signing_public_key = smp_plaintext[:ML_DSA_87_PK_LEN] + contact_nonce = smp_plaintext[ML_DSA_87_PK_LEN:] + + our_nonce = secrets.token_bytes(SMP_NONCE_LENGTH) + + signing_private_key, signing_public_key = generate_sign_keys() + + contact_key_fingerprint = sha3_512(contact_signing_public_key) + + # Derieve a high-entropy secret key from the low-entropy answer + argon2id_salt = sha3_512(contact_nonce + our_nonce)[:ARGON2_SALT_LEN] + answer_secret, _ = derive_key_argon2id(answer.encode("utf-8"), salt = argon2id_salt, output_length = SMP_ANSWER_OUTPUT_LEN) - contact_question_public_key = b64decode(message["question_public_key"], validate = True) + # Compute our proof + our_proof = contact_nonce + our_nonce + contact_key_fingerprint + our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).digest() + + logger.debug("Our proof of contact (%s) public-key fingerprint: %s", contact_id, our_proof) + + ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( + chacha_key, + signing_public_key + our_nonce + our_proof + question.encode("utf-8"), + counter = 3 + ) - question_pads_ciphertext, question_pads = generate_shared_secrets(contact_question_public_key, ML_KEM_1024_NAME, otp_pad_size = SMP_QUESTION_MAX_LEN) - question_ciphertext = b64encode(one_time_pad(question.encode("utf-8"), question_pads)).decode() try: - http_request(f"{server_url}/smp/step_3", "POST", payload = { - "question_ciphertext": question_ciphertext, - "question_pads_ciphertext": b64encode(question_pads_ciphertext).decode(), + http_request(f"{server_url}/smp/step", "POST", payload = { + "ciphertext_blob": b64encode(ciphertext_nonce + ciphertext_blob).decode(), "recipient": contact_id }, auth_token=auth_token) @@ -179,28 +233,46 @@ def smp_step_3(user_data, user_data_lock, contact_id, message, ui_queue) -> None # We only update after the request is sent successfully with user_data_lock: - user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] = b64decode(message["signing_public_key"], validate=True) - user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_kem_public_key"] = message["question_public_key"] - user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"] = message["nonce"] - + user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] = contact_signing_public_key + + user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"] = b64encode(contact_nonce).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"] = b64encode(our_nonce).decode() + user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"] = b64encode(chacha_key).decode() + + user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["private_key"] = signing_private_key + user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] = signing_public_key + + user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 5 def smp_step_4_request_answer(user_data, user_data_lock, contact_id, message, ui_queue) -> None: with user_data_lock: - our_question_private_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_kem_keys"]["private_key"]) + tmp_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"]) + + ciphertext_blob = b64decode(message["ciphertext_blob"], validate = True) + smp_plaintext = decrypt_xchacha20poly1305(tmp_key, ciphertext_blob[:XCHACHA20POLY1305_NONCE_LEN], ciphertext_blob[XCHACHA20POLY1305_NONCE_LEN:]) + + contact_signing_public_key = smp_plaintext[:ML_DSA_87_PK_LEN] + contact_nonce = b64encode(smp_plaintext[ML_DSA_87_PK_LEN : SMP_NONCE_LENGTH + ML_DSA_87_PK_LEN]).decode() + contact_proof = b64encode(smp_plaintext[SMP_NONCE_LENGTH + ML_DSA_87_PK_LEN : SMP_NONCE_LENGTH + SMP_PROOF_LENGTH + ML_DSA_87_PK_LEN]).decode() + question = smp_plaintext[SMP_NONCE_LENGTH + SMP_PROOF_LENGTH + ML_DSA_87_PK_LEN:].decode("utf-8") - pads = decrypt_shared_secrets(b64decode(message["question_pads_ciphertext"], validate = True), our_question_private_key, ML_KEM_1024_NAME, otp_pad_size = SMP_QUESTION_MAX_LEN) - question = one_time_pad(b64decode(message["question_ciphertext"], validate = True), pads) with user_data_lock: - user_data["contacts"][contact_id]["lt_sign_key_smp"]["question"] = question.decode("utf-8") - user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 4 + user_data["contacts"][contact_id]["lt_sign_key_smp"]["question"] = question + user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_proof"] = contact_proof + # user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] = 5 + + + user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"] = contact_nonce + + user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] = contact_signing_public_key ui_queue.put({ "type": "smp_question", "contact_id": contact_id, - "question": question.decode("utf-8") + "question": question }) @@ -210,88 +282,55 @@ def smp_step_4_answer_provided(user_data, user_data_lock, contact_id, answer, ui auth_token = user_data["token"] contact_signing_public_key = user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] - contact_nonce = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"], validate=True) + contact_kem_public_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_kem_public_key"], validate = True) + contact_nonce = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"], validate=True) + contact_proof = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_proof"], validate=True) our_nonce = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"], validate=True) - user_data["contacts"][contact_id]["lt_sign_key_smp"]["answer"] = answer + our_signing_public_key = user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] + + tmp_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"]) answer = normalize_answer(answer) - contact_key_fingerprint = sha3_512(contact_signing_public_key) + our_key_fingerprint = sha3_512(our_signing_public_key) # Derieve a high-entropy secret key from the low-entropy answer - argon2id_salt = sha3_512(our_nonce + contact_nonce) + argon2id_salt = sha3_512(our_nonce + contact_nonce)[:ARGON2_SALT_LEN] answer_secret, _ = derive_key_argon2id(answer.encode("utf-8"), salt = argon2id_salt, output_length = SMP_ANSWER_OUTPUT_LEN) # Compute our proof - our_proof = contact_nonce + our_nonce + contact_key_fingerprint - our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).hexdigest() - - logger.debug("Our proof of contact (%s) public-key fingerprint: %s", contact_id, our_proof) - - - try: - http_request(f"{server_url}/smp/step_4", "POST", payload = { - "proof": our_proof, - "recipient": contact_id - }, auth_token=auth_token) - except Exception: - logger.error("Failed to send proof request to server, either you are offline or the server is down") - smp_failure_notify_contact(user_data, user_data_lock, contact_id, ui_queue) - return - - - -def smp_step_5(user_data, user_data_lock, contact_id, message, ui_queue) -> None: - with user_data_lock: - server_url = user_data["server_url"] - auth_token = user_data["token"] - - answer = user_data["contacts"][contact_id]["lt_sign_key_smp"]["answer"] - - contact_signing_public_key = user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] - contact_question_public_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_kem_public_key"]) - - our_public_key = user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] - our_nonce = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"], validate=True) - contact_nonce = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"], validate=True) - - our_key_fingerprint = sha3_512(our_public_key) - - # Derieve a high-entropy secret key from the low-entropy answer - argon2id_salt = sha3_512(contact_nonce + our_nonce) - answer_secret, _ = derive_key_argon2id(answer.encode("utf-8"), salt = argon2id_salt, output_length = SMP_ANSWER_OUTPUT_LEN) - - # Compute the proof our_proof = our_nonce + contact_nonce + our_key_fingerprint our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).digest() - contact_proof_raw = bytes.fromhex(message["proof"]) - - logger.debug("SMP Proof sent to us: %s", contact_proof_raw) + logger.debug("SMP Proof sent to us: %s", contact_proof) logger.debug("Our compute message: %s", our_proof) - # Verify Contact's version of our public-key fingerprint matches our actual public-key fingerprint # We compare using compare_digest to prevent timing analysis by avoiding content-based short circuiting behaviour - if not hmac.compare_digest(our_proof, contact_proof_raw): + if not hmac.compare_digest(our_proof, contact_proof): logger.warning("SMP Verification failed") smp_failure_notify_contact(user_data, user_data_lock, contact_id, ui_queue) return # We compute proof for contact's public key (signing public key, and the question public key) - contact_key_fingerprint = sha3_512(contact_signing_public_key + contact_question_public_key) + contact_key_fingerprint = sha3_512(contact_signing_public_key + contact_kem_public_key) our_proof = contact_nonce + our_nonce + contact_key_fingerprint - our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).hexdigest() + our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).digest() + + ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305( + tmp_key, + our_proof, + counter = 4 + ) - logger.debug("Our proof to contact: %s", our_proof) try: - http_request(f"{server_url}/smp/step_5", "POST", payload = { - "proof": our_proof, + http_request(f"{server_url}/smp/step", "POST", payload = { + "ciphertext_blob": b64encode(ciphertext_nonce + ciphertext_blob).decode(), "recipient": contact_id }, auth_token=auth_token) except Exception: @@ -299,45 +338,61 @@ def smp_step_5(user_data, user_data_lock, contact_id, message, ui_queue) -> None smp_failure_notify_contact(user_data, user_data_lock, contact_id, ui_queue) return + # We call smp_success at very end to ensure if the requests step fail, we don't alter our local state smp_success(user_data, user_data_lock, contact_id, ui_queue) + with user_data_lock: + user_data["contacts"][contact_id]["lt_sign_key_smp"]["answer"] = answer + + + -def smp_step_6(user_data, user_data_lock, contact_id, message, ui_queue) -> None: + +def smp_step_5(user_data, user_data_lock, contact_id, message, ui_queue) -> None: with user_data_lock: + server_url = user_data["server_url"] + auth_token = user_data["token"] + answer = user_data["contacts"][contact_id]["lt_sign_key_smp"]["answer"] - contact_signing_public_key = user_data["contacts"][contact_id]["lt_sign_keys"]["contact_public_key"] - - our_question_public_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_kem_keys"]["public_key"]) - our_signing_public_key = user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] + our_signing_public_key = user_data["contacts"][contact_id]["lt_sign_keys"]["our_keys"]["public_key"] + our_kem_public_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_kem_keys"]["public_key"]) our_nonce = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["our_nonce"], validate=True) contact_nonce = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["contact_nonce"], validate=True) - our_key_fingerprint = sha3_512(our_signing_public_key + our_question_public_key) + tmp_key = b64decode(user_data["contacts"][contact_id]["lt_sign_key_smp"]["tmp_key"]) + + + our_key_fingerprint = sha3_512(our_signing_public_key + our_kem_public_key) # Derieve a high-entropy secret key from the low-entropy answer - argon2id_salt = sha3_512(our_nonce + contact_nonce) + argon2id_salt = sha3_512(contact_nonce + our_nonce)[:ARGON2_SALT_LEN] answer_secret, _ = derive_key_argon2id(answer.encode("utf-8"), salt = argon2id_salt, output_length = SMP_ANSWER_OUTPUT_LEN) # Compute the proof our_proof = our_nonce + contact_nonce + our_key_fingerprint our_proof = hmac.new(answer_secret, our_proof, hashlib.sha3_512).digest() - contact_proof_raw = bytes.fromhex(message["proof"]) + ciphertext_blob = b64decode(message["ciphertext_blob"], validate = True) + contact_proof = decrypt_xchacha20poly1305(tmp_key, ciphertext_blob[:XCHACHA20POLY1305_NONCE_LEN], ciphertext_blob[XCHACHA20POLY1305_NONCE_LEN:]) + - logger.debug("SMP Proof sent to us: %s", contact_proof_raw) + logger.debug("SMP Proof sent to us: %s", contact_proof) logger.debug("Our compute message: %s", our_proof) # Verify Contact's version of our public-key fingerprint matches our actual public-key fingerprint # We compare using compare_digest to prevent timing analysis by avoiding content-based short circuiting behaviour - if not hmac.compare_digest(our_proof, contact_proof_raw): + if not hmac.compare_digest(our_proof, contact_proof): logger.warning("SMP Verification failed") smp_failure_notify_contact(user_data, user_data_lock, contact_id, ui_queue) return + + + # We call smp_success at very end to ensure if the requests step fail, we don't alter our local state smp_success(user_data, user_data_lock, contact_id, ui_queue) @@ -352,19 +407,20 @@ def smp_step_6(user_data, user_data_lock, contact_id, message, ui_queue) -> None def smp_success(user_data, user_data_lock, contact_id, ui_queue) -> None: with user_data_lock: user_data["contacts"][contact_id]["lt_sign_key_smp"] = { - "verified": True, - "pending_verification": False, - "question": None, - "answer": None, - "our_nonce": None, - "contact_nonce": None, - "smp_step": None, - "contact_kem_public_key": None, - "our_kem_keys": { - "private_key": None, - "public_key": None - } - + "verified": True, + "pending_verification": False, + "question": None, + "answer": None, + "our_nonce": None, + "contact_nonce": None, + "smp_step": None, + "tmp_proof": None, + "tmp_key": None, + "contact_kem_public_key": None, + "our_kem_keys": { + "private_key": None, + "public_key": None + } } @@ -375,20 +431,20 @@ def smp_success(user_data, user_data_lock, contact_id, ui_queue) -> None: def smp_failure(user_data, user_data_lock, contact_id, ui_queue) -> None: with user_data_lock: user_data["contacts"][contact_id]["lt_sign_key_smp"] = { - "verified": False, - "pending_verification": False, - "question": None, - "answer": None, - "our_nonce": None, - "contact_nonce": None, - "smp_step": None, - "contact_kem_public_key": None, - "our_kem_keys": { - "private_key": None, - "public_key": None - } - - + "verified": False, + "pending_verification": False, + "question": None, + "answer": None, + "our_nonce": None, + "contact_nonce": None, + "smp_step": None, + "tmp_proof": None, + "tmp_key": None, + "contact_kem_public_key": None, + "our_kem_keys": { + "private_key": None, + "public_key": None + } } ui_queue.put({"type": "showerror", "title": "Error", "message": "Verification has failed! Please re-try."}) @@ -424,23 +480,21 @@ def smp_unanswered_questions(user_data, user_data_lock, ui_queue): def smp_data_handler(user_data, user_data_lock, user_data_copied, ui_queue, message): contact_id = message["sender"] - if (not "step" in message): - logger.error("Message has no 'step'. Maybe malicious server ? anyhow, we will ignore this SMP request. Message: %s", repr(message)) - return - - if not (message["step"] in [1, 2, 3, 4, 5, 6, -1]): - logger.error("SMP 'step' is not in range of values we accept. We will ignore this SMP request. Step: %d", message["step"]) - return + try: + smp_step = user_data["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] + if smp_step is None: + raise Exception() + except Exception: + smp_step = 2 # Check if we don't have this contact saved - if (not (contact_id in user_data_copied["contacts"])): + if contact_id not in user_data_copied["contacts"]: # We assume it has to be step 1 because the contact did not exist before - if message["step"] != 1: - logger.error("something wrong, we or they are not synced? Not sure, but we will ignore this SMP request because the step should've been 1, instead we got (%d)", message["step"]) + if smp_step != 2: + logger.error("Unknown contact sent SMP request of step (%d)", smp_step) return - logger.info("We received a new SMP request for a contact we did not have saved") - + logger.info("We received a new SMP request for a contact (%s) we did not have saved", contact_id) # Save them in-memory save_contact(user_data, user_data_lock, contact_id) @@ -458,45 +512,35 @@ def smp_data_handler(user_data, user_data_lock, user_data_copied, ui_queue, mess # Same thing as above code, except that we don't fetch nor save the contact here # as they're already fetched and saved - elif message["step"] == 1: + elif smp_step == 2: smp_step_2(user_data, user_data_lock, contact_id, message, ui_queue) - elif message["step"] == 2: - if (not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"]) or (user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] != 1): - logger.error("something wrong, we or they are not synced? Not sure, but we will ignore this SMP request for now") + elif "failure" in message: + # Delete SMP state for contact + smp_failure(user_data, user_data_lock, contact_id, ui_queue) + + elif smp_step == 3: + if (not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"]): + logger.error("Contact (%s) is not pending verification, yet they sent us a SMP request. Ignoring it.", contact_id) return smp_step_3(user_data, user_data_lock, contact_id, message, ui_queue) - elif message["step"] == 3: - if (not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"]): # or (user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] != 2): - logger.error("something wrong, we or they are not synced? Not sure, but we will ignore this SMP request for now") + elif smp_step == 4: + if (not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"]): + logger.error("Contact (%s) is not pending verification, yet they sent us a SMP request. Ignoring it.", contact_id) return smp_step_4_request_answer(user_data, user_data_lock, contact_id, message, ui_queue) - elif message["step"] == 4: - if (not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"]): # or (user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] != 3): - logger.error("something wrong, we or they are not synced? Not sure, but we will ignore this SMP request for now") + elif smp_step == 5: + if (not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"]): + logger.error("Contact (%s) is not pending verification, yet they sent us a SMP request. Ignoring it.", contact_id) return smp_step_5(user_data, user_data_lock, contact_id, message, ui_queue) - elif message["step"] == 5: - if (not user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["pending_verification"]): # or (user_data_copied["contacts"][contact_id]["lt_sign_key_smp"]["smp_step"] != 4): - logger.error("something wrong, we or they are not synced? Not sure, but we will ignore this SMP request for now") - return - - smp_step_6(user_data, user_data_lock, contact_id, message, ui_queue) - - - - # SMP failure on contact side - elif (message["step"] == -1): - # Delete SMP state for contact - smp_failure(user_data, user_data_lock, contact_id, ui_queue) - else: - logger.error("This is an impossible condition, either you have discovered a bug in Coldwire, or the server is malicious. Skipping weird SMP step (%d)...", message["step"]) + logger.error("This is an impossible condition, You may have discovered a bug in Coldwire. Skipping weird SMP step (%d)...", smp_step) return save_account_data(user_data, user_data_lock) diff --git a/logic/storage.py b/logic/storage.py index eec4d0d..e9d496f 100644 --- a/logic/storage.py +++ b/logic/storage.py @@ -3,7 +3,8 @@ from core.constants import ( ML_KEM_1024_NAME, CLASSIC_MCELIECE_8_F_NAME, - ACCOUNT_FILE_PATH + ACCOUNT_FILE_PATH, + ARGON2_SALT_LEN ) import core.trad_crypto as crypto @@ -29,11 +30,11 @@ def load_account_data(password = None) -> dict: # first 12 bytes is nonce, and last 32 bytes is the password salt, # and the ciphertext is inbetween. - password_kdf, _ = crypto.derive_key_argon2id(password.encode(), salt=blob[-32:]) + password_kdf, _ = crypto.derive_key_argon2id(password.encode(), salt=blob[-ARGON2_SALT_LEN:]) - blob = blob[:-32] + blob = blob[:-ARGON2_SALT_LEN] - user_data = json.loads(crypto.decrypt_aes_gcm(password_kdf, blob[:12], blob[12:])) + user_data = json.loads(crypto.decrypt_xchacha20poly1305(password_kdf, blob[:12], blob[12:])) @@ -195,7 +196,7 @@ def save_account_data(user_data: dict, user_data_lock, password = None) -> None: password_kdf, password_salt = crypto.derive_key_argon2id(password.encode()) - nonce, ciphertext = crypto.encrypt_aes_gcm(password_kdf, json.dumps(user_data).encode("utf-8")) + nonce, ciphertext = crypto.encrypt_xchacha20poly1305(password_kdf, json.dumps(user_data).encode("utf-8")) with open(ACCOUNT_FILE_PATH, "wb") as f: f.write(nonce + ciphertext + password_salt) diff --git a/requirements.txt b/requirements.txt index 0d38bc5..2b203db 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -cryptography +pynacl diff --git a/tests/test_trad_crypto.py b/tests/test_trad_crypto.py index 9e258ef..fdd8f05 100644 --- a/tests/test_trad_crypto.py +++ b/tests/test_trad_crypto.py @@ -1,13 +1,13 @@ # tests/test_trad_crypto.py """ - Tests for AES-256 GCM encryption/decryption and Argon2id key derivation. - Focus: Correctness of encryption/decryption flow and tamper detection. + Tests for XChaCha20Poly1305 encryption & decryption and Argon2id key derivation. + Focus: Correctness of encryption & decryption flow and tamper detection. """ import pytest from core.trad_crypto import ( - encrypt_aes_gcm, - decrypt_aes_gcm, + encrypt_xchacha20poly1305, + decrypt_xchacha20poly1305, derive_key_argon2id ) @@ -23,12 +23,12 @@ def test_aes_encrypt_decrypt(): assert key != password, "Derived key should not match plaintext password" # Encrypt plaintext using AES-GCM - nonce, ciphertext = encrypt_aes_gcm(key, data) + nonce, ciphertext = encrypt_xchacha20poly1305(key, data) assert nonce != ciphertext, "Nonce and ciphertext should not be equal" assert ciphertext != data, "Ciphertext should differ from plaintext" # Decrypt ciphertext and verify correctness - plaintext = decrypt_aes_gcm(key, nonce, ciphertext) + plaintext = decrypt_xchacha20poly1305(key, nonce, ciphertext) assert plaintext == data, "Decrypted plaintext does not match original" # Tampering test: Modify ciphertext and expect decryption failure diff --git a/ui/smp_setup_window.py b/ui/smp_setup_window.py index 7ab1d3c..1bad6ec 100644 --- a/ui/smp_setup_window.py +++ b/ui/smp_setup_window.py @@ -2,6 +2,9 @@ from tkinter import messagebox from ui.utils import * from logic.smp import initiate_smp +from core.constants import ( + SMP_QUESTION_MAX_LEN +) class SMPSetupWindow(tk.Toplevel): def __init__(self, master, contact_id): @@ -52,23 +55,24 @@ def __init__(self, master, contact_id): def submit(self): question = self.question_entry.get().strip() - answer = self.answer_entry.get().strip().lower() + answer = self.answer_entry.get().strip() + if not question or not answer: messagebox.showerror("Error", "Both fields are required.") return - if question.lower() == answer: + if question.lower() == answer.lower(): messagebox.showerror("Error", "The question and answer must be different!") return - if answer in question.lower(): + if answer.lower() in question.lower(): messagebox.showerror("Error", "Question must not contain the answer!") return - if len(question) > 512: - messagebox.showerror("Error", "Question must be under 512 characters long.") + if len(question) > SMP_QUESTION_MAX_LEN: + messagebox.showerror("Error", f"Question must be under {SMP_QUESTION_MAX_LEN} characters long.") # This is just unacceptable, 4 characaters is the bare minimum.