diff --git a/django_crypto_fields/cipher/__init__.py b/django_crypto_fields/cipher/__init__.py new file mode 100644 index 0000000..d9ae13d --- /dev/null +++ b/django_crypto_fields/cipher/__init__.py @@ -0,0 +1,4 @@ +from .cipher import Cipher +from .cipher_parser import CipherParser + +__all__ = ["Cipher", "CipherParser"] diff --git a/django_crypto_fields/cipher/cipher.py b/django_crypto_fields/cipher/cipher.py new file mode 100644 index 0000000..cd3b445 --- /dev/null +++ b/django_crypto_fields/cipher/cipher.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import Callable + +from ..constants import CIPHER_PREFIX, HASH_PREFIX +from ..utils import make_hash, safe_encode_utf8 + +__all__ = ["Cipher"] + + +class Cipher: + """A class that given a value builds a cipher of the format + hash_prefix + hashed_value + cipher_prefix + secret. + + The secret is encrypted using the passed `encrypt` callable. + """ + + def __init__( + self, + value: str | bytes, + salt_key: bytes, + encrypt: Callable[[bytes], bytes] | None = None, + ): + encoded_value = safe_encode_utf8(value) + self.hash_prefix = b"" + self.hashed_value = b"" + self.cipher_prefix = b"" + self.secret = b"" + if salt_key: + self.hash_prefix: bytes = safe_encode_utf8(HASH_PREFIX) + self.hashed_value: bytes = make_hash(encoded_value, salt_key) + if encrypt: + self.secret = encrypt(encoded_value) + self.cipher_prefix: bytes = safe_encode_utf8(CIPHER_PREFIX) + + @property + def cipher(self) -> bytes: + return self.hash_prefix + self.hashed_value + self.cipher_prefix + self.secret + + def hash_with_prefix(self) -> bytes: + return self.hash_prefix + self.hashed_value + + def secret_with_prefix(self) -> bytes: + return self.cipher_prefix + self.secret diff --git a/django_crypto_fields/cipher/cipher_parser.py b/django_crypto_fields/cipher/cipher_parser.py new file mode 100644 index 0000000..ee04c61 --- /dev/null +++ b/django_crypto_fields/cipher/cipher_parser.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from ..constants import CIPHER_PREFIX, HASH_PREFIX +from ..exceptions import MalformedCiphertextError +from ..utils import make_hash, safe_encode_utf8 + +__all__ = ["CipherParser"] + + +class CipherParser: + def __init__(self, cipher: bytes, salt_key: bytes | None = None): + self._cipher_prefix = None + self._hash_prefix = None + self._hashed_value = None + self._secret = None + self.cipher = safe_encode_utf8(cipher) + self.salt_key = salt_key + self.validate_hashed_value() + self.validate_secret() + + @property + def hash_prefix(self) -> bytes | None: + if self.cipher: + hash_prefix = safe_encode_utf8(HASH_PREFIX) + self._hash_prefix = hash_prefix if self.cipher.startswith(hash_prefix) else None + return self._hash_prefix + + @property + def cipher_prefix(self) -> bytes | None: + if self.cipher: + cipher_prefix = safe_encode_utf8(CIPHER_PREFIX) + self._cipher_prefix = cipher_prefix if cipher_prefix in self.cipher else None + return self._cipher_prefix + + @property + def hashed_value(self) -> bytes | None: + if self.cipher and self.cipher.startswith(self.hash_prefix): + self._hashed_value = self.cipher.split(self.hash_prefix)[1].split( + self.cipher_prefix + )[0] + return self._hashed_value + + @property + def secret(self) -> bytes | None: + if self.cipher and safe_encode_utf8(CIPHER_PREFIX) in self.cipher: + self._secret = self.cipher.split(self.cipher_prefix)[1] + return self._secret + + def validate_hashed_value(self) -> None: + if self.hash_prefix and not self.hashed_value: + raise MalformedCiphertextError("Invalid hashed_value. Got None.") + elif self.salt_key and len(self.hashed_value) != len(make_hash("Foo", self.salt_key)): + raise MalformedCiphertextError("Invalid hashed_value. Incorrect size.") + + def validate_secret(self) -> None: + if self.cipher_prefix and not self.secret: + raise MalformedCiphertextError("Invalid secret. Got None.") diff --git a/django_crypto_fields/cryptor.py b/django_crypto_fields/cryptor.py index 2851365..bc701fb 100644 --- a/django_crypto_fields/cryptor.py +++ b/django_crypto_fields/cryptor.py @@ -6,7 +6,7 @@ from Cryptodome import Random from Cryptodome.Cipher import AES as AES_CIPHER -from .constants import AES, ENCODING, PRIVATE, PUBLIC, RSA +from .constants import AES, ENCODING, LOCAL_MODE, PRIVATE, PUBLIC, RSA from .exceptions import EncryptionError from .keys import encryption_keys from .utils import get_keypath_from_settings @@ -15,8 +15,6 @@ from Cryptodome.Cipher._mode_cbc import CbcMode from Cryptodome.Cipher.PKCS1_OAEP import PKCS1OAEP_Cipher - from .keys import Keys - class Cryptor: """Base class for all classes providing RSA and AES encryption @@ -26,9 +24,17 @@ class Cryptor: of this except the filenames are replaced with the actual keys. """ - def __init__(self): + def __init__(self, algorithm: AES | RSA, access_mode: PRIVATE | LOCAL_MODE = None) -> None: + self.algorithm = algorithm self.aes_encryption_mode: int = AES_CIPHER.MODE_CBC - self.keys: Keys = encryption_keys + aes_key_attr: str = "_".join([AES, access_mode, PRIVATE, "key"]) + self.aes_key: bytes = getattr(encryption_keys, aes_key_attr) + rsa_key_attr = "_".join([RSA, access_mode, PUBLIC, "key"]) + self.rsa_public_key: PKCS1OAEP_Cipher = getattr(encryption_keys, rsa_key_attr) + rsa_key_attr = "_".join([RSA, access_mode, PRIVATE, "key"]) + self.rsa_private_key: PKCS1OAEP_Cipher = getattr(encryption_keys, rsa_key_attr) + self.encrypt = getattr(self, f"_{self.algorithm.lower()}_encrypt") + self.decrypt = getattr(self, f"_{self.algorithm.lower()}_decrypt") def get_with_padding(self, plaintext: str | bytes, block_size: int) -> bytes: """Return string padded so length is a multiple of the block size. @@ -73,42 +79,34 @@ def get_without_padding(self, plaintext: str | bytes) -> bytes: return plaintext[:-1] return plaintext[:-padding_length] - def aes_encrypt(self, plaintext: str | bytes, mode: str) -> bytes: - aes_key_attr: str = "_".join([AES, mode, PRIVATE, "key"]) - aes_key: bytes = getattr(self.keys, aes_key_attr) + def _aes_encrypt(self, plaintext: str | bytes) -> bytes: iv: bytes = Random.new().read(AES_CIPHER.block_size) - cipher: CbcMode = AES_CIPHER.new(aes_key, self.aes_encryption_mode, iv) + cipher: CbcMode = AES_CIPHER.new(self.aes_key, self.aes_encryption_mode, iv) padded_plaintext = self.get_with_padding(plaintext, cipher.block_size) return iv + cipher.encrypt(padded_plaintext) - def aes_decrypt(self, ciphertext: bytes, mode: str) -> str: - aes_key_attr: str = "_".join([AES, mode, PRIVATE, "key"]) - aes_key: bytes = getattr(self.keys, aes_key_attr) + def _aes_decrypt(self, ciphertext: bytes) -> str: iv = ciphertext[: AES_CIPHER.block_size] - cipher: CbcMode = AES_CIPHER.new(aes_key, self.aes_encryption_mode, iv) + cipher: CbcMode = AES_CIPHER.new(self.aes_key, self.aes_encryption_mode, iv) plaintext = cipher.decrypt(ciphertext)[AES_CIPHER.block_size :] return self.get_without_padding(plaintext).decode() - def rsa_encrypt(self, plaintext: str | bytes, mode: int) -> bytes: - rsa_key_attr = "_".join([RSA, mode, PUBLIC, "key"]) - rsa_key: PKCS1OAEP_Cipher = getattr(self.keys, rsa_key_attr) + def _rsa_encrypt(self, plaintext: str | bytes) -> bytes: try: plaintext = plaintext.encode(ENCODING) except AttributeError: pass try: - ciphertext = rsa_key.encrypt(plaintext) + ciphertext = self.rsa_public_key.encrypt(plaintext) except (ValueError, TypeError) as e: raise EncryptionError(f"RSA encryption failed for value. Got '{e}'") return ciphertext - def rsa_decrypt(self, ciphertext: bytes, mode: str) -> str: - rsa_key_attr = "_".join([RSA, mode, PRIVATE, "key"]) - rsa_key: PKCS1OAEP_Cipher = getattr(self.keys, rsa_key_attr) + def _rsa_decrypt(self, ciphertext: bytes) -> str: try: - plaintext = rsa_key.decrypt(ciphertext) + plaintext = self.rsa_private_key.decrypt(ciphertext) except ValueError as e: raise EncryptionError( - f"{e} Using {rsa_key_attr} from key_path=`{get_keypath_from_settings()}`." + f"{e} Using RSA from key_path=`{get_keypath_from_settings()}`." ) return plaintext.decode(ENCODING) diff --git a/django_crypto_fields/field_cryptor.py b/django_crypto_fields/field_cryptor.py index e07b6e4..044d5c7 100644 --- a/django_crypto_fields/field_cryptor.py +++ b/django_crypto_fields/field_cryptor.py @@ -1,39 +1,17 @@ from __future__ import annotations -import binascii -import hashlib from typing import TYPE_CHECKING, Type from Cryptodome.Cipher import AES as AES_CIPHER from django.apps import apps as django_apps from django.core.exceptions import ObjectDoesNotExist -from .constants import ( - AES, - CIPHER_PREFIX, - ENCODING, - HASH_ALGORITHM, - HASH_PREFIX, - HASH_ROUNDS, - PRIVATE, - RSA, - SALT, -) +from .cipher import Cipher, CipherParser +from .constants import AES, CIPHER_PREFIX, ENCODING, HASH_PREFIX, PRIVATE, RSA, SALT from .cryptor import Cryptor -from .exceptions import ( - CipherError, - EncryptionError, - EncryptionKeyError, - InvalidEncryptionAlgorithm, -) +from .exceptions import EncryptionError, EncryptionKeyError, InvalidEncryptionAlgorithm from .keys import encryption_keys -from .utils import ( - get_crypt_model_cls, - has_valid_value_or_raise, - is_valid_ciphertext_or_raise, - safe_decode, - safe_encode_utf8, -) +from .utils import get_crypt_model_cls, make_hash, safe_decode, safe_encode_utf8 if TYPE_CHECKING: from .models import Crypt @@ -52,6 +30,8 @@ class FieldCryptor: """ crypt_model = "django_crypto_fields.crypt" + cryptor_cls = Cryptor + cipher_cls = Cipher def __init__(self, algorithm: str, access_mode: str): self._using = None @@ -62,7 +42,7 @@ def __init__(self, algorithm: str, access_mode: str): self.cipher_buffer_key = f"{self.algorithm}_{self.access_mode}" self.cipher_buffer = {self.cipher_buffer_key: {}} self.keys = encryption_keys - self.cryptor = Cryptor() + self.cryptor = self.cryptor_cls(algorithm=algorithm, access_mode=access_mode) self.hash_size: int = len(self.hash("Foo")) def __repr__(self) -> str: @@ -80,6 +60,9 @@ def algorithm(self, value): f"Invalid encryption algorithm. Expected 'aes' or 'rsa'. Got {value}" ) + def hash(self, value): + return make_hash(value, self.salt_key) + @property def salt_key(self): attr = "_".join([SALT, self.access_mode, PRIVATE]) @@ -96,21 +79,11 @@ def crypt_model_cls(self) -> Type[Crypt]: """ return get_crypt_model_cls() - def hash(self, plaintext): - """Returns a hexified hash of a plaintext value (as bytes). - - The hashed value is used as a signature of the "secret". - """ - plaintext = safe_encode_utf8(plaintext) - dk = hashlib.pbkdf2_hmac(HASH_ALGORITHM, plaintext, self.salt_key, HASH_ROUNDS) - return binascii.hexlify(dk) - def encrypt(self, value: str | bytes | None, update: bool | None = None): - """Returns ciphertext as byte data using either an - RSA or AES cipher. + """Returns either an RSA or AES cipher. * 'value' is either plaintext or ciphertext - * 'ciphertext' is a byte value of hash_prefix + * 'cipher' is a byte value of hash_prefix + hashed_value + cipher_prefix + secret. For example: enc1:::234234ed234a24enc2::\x0e\xb9\xae\x13s\x8d @@ -118,31 +91,28 @@ def encrypt(self, value: str | bytes | None, update: bool | None = None): * 'value' is not re-encrypted if already encrypted and properly formatted 'ciphertext'. """ - ciphertext = None + cipher = None update = True if update is None else update - value = safe_encode_utf8(value) - if value is not None and value != b"" and not self.is_encrypted(value): - ciphertext = self.get_ciphertext(value) + encoded_value = safe_encode_utf8(value) + if encoded_value and not self.is_encrypted(encoded_value): + cipher = self.cipher_cls(value, self.salt_key, encrypt=self.cryptor.encrypt) if update: - self.update_crypt(ciphertext) - return ciphertext + self.update_crypt(cipher) + return getattr(cipher, "cipher", encoded_value) - def decrypt(self, hash_with_prefix: str): + def decrypt(self, hash_with_prefix: str | bytes): """Returns decrypted secret or None. - Secret is retrieved from `Crypt` using the hash. + Secret is retrieved from `Crypt` using the hash_with_prefix + coming from the field of the user model. - hash_with_prefix = hash_prefix+hash. + hash_with_prefix = hash_prefix+hash_value. """ - plaintext = None hash_with_prefix = safe_encode_utf8(hash_with_prefix) - if self.is_encrypted(hash_with_prefix): + if hash_with_prefix and self.is_encrypted(hash_with_prefix): if secret := self.fetch_secret(hash_with_prefix): - if self.algorithm == AES: - plaintext = self.cryptor.aes_decrypt(secret, self.access_mode) - elif self.algorithm == RSA: - plaintext = self.cryptor.rsa_decrypt(secret, self.access_mode) - return plaintext + return self.cryptor.decrypt(secret) + return None @property def using(self): @@ -151,102 +121,83 @@ def using(self): self._using = app_config.crypt_model_using return self._using - def update_crypt(self, ciphertext): - """Updates cipher model (Crypt) and temporary buffer.""" - if is_valid_ciphertext_or_raise(ciphertext, self.hash_size): - hashed_value = self.get_hash(ciphertext) - secret = self.get_secret(ciphertext) - self.cipher_buffer[self.cipher_buffer_key].update({hashed_value: secret}) - try: - crypt = self.crypt_model_cls.objects.using(self.using).get( - hash=hashed_value, algorithm=self.algorithm, mode=self.access_mode - ) - crypt.secret = secret - crypt.save() - except ObjectDoesNotExist: - self.crypt_model_cls.objects.using(self.using).create( - hash=hashed_value, - secret=secret, - algorithm=self.algorithm, - cipher_mode=self.aes_encryption_mode, - mode=self.access_mode, - ) + def update_crypt(self, cipher: Cipher): + """Updates Crypt model and cipher_buffer.""" + self.cipher_buffer[self.cipher_buffer_key].update({cipher.hashed_value: cipher.secret}) + try: + crypt = self.crypt_model_cls.objects.using(self.using).get( + hash=cipher.hashed_value, algorithm=self.algorithm, mode=self.access_mode + ) + crypt.secret = cipher.secret + crypt.save() + except ObjectDoesNotExist: + self.crypt_model_cls.objects.using(self.using).create( + hash=cipher.hashed_value, + secret=cipher.secret, + algorithm=self.algorithm, + cipher_mode=self.aes_encryption_mode, + mode=self.access_mode, + ) def get_prep_value(self, value: str | bytes | None) -> str | bytes | None: - """Returns the prefix + hash as stored in the DB table column of + """Returns the prefix + hash_value as stored in the DB table column of your model's "encrypted" field. Used by get_prep_value() """ + hash_with_prefix = None if value is None or value in ["", b""]: pass # return None or empty string/byte else: - ciphertext = self.encrypt(value) - value = ciphertext.split(CIPHER_PREFIX.encode(ENCODING))[0] - value = safe_decode(value) - return value - - def get_ciphertext(self, value): - cipher = None - if self.algorithm == AES: - cipher = self.cryptor.aes_encrypt - elif self.algorithm == RSA: - cipher = self.cryptor.rsa_encrypt - ciphertext = ( - HASH_PREFIX.encode(ENCODING) - + self.hash(value) - + CIPHER_PREFIX.encode(ENCODING) - + cipher(value, self.access_mode) - ) - return is_valid_ciphertext_or_raise(ciphertext, self.hash_size) - - def get_hash(self, ciphertext: bytes) -> bytes | None: - """Returns the hashed_value given a ciphertext or None.""" - ciphertext = safe_encode_utf8(ciphertext) - return ciphertext[len(HASH_PREFIX) :][: self.hash_size] or None - - def get_secret(self, ciphertext: bytes) -> bytes | None: - """Returns the secret given a ciphertext.""" - if ciphertext is None: - secret = None - elif self.is_encrypted(ciphertext): - secret = ciphertext.split(CIPHER_PREFIX.encode(ENCODING))[1] - else: - raise CipherError("Expected a ciphertext or None") - return secret + cipher = self.encrypt(value) + hash_with_prefix = cipher.split(CIPHER_PREFIX.encode(ENCODING))[0] + hash_with_prefix = safe_decode(hash_with_prefix) + return hash_with_prefix or value def fetch_secret(self, hash_with_prefix: bytes): - hashed_value = self.get_hash(hash_with_prefix) + """Fetch the secret from the DB or the buffer using + the hashed_value as the lookup. + + If not found in buffer, lookup in DB and update the buffer. + + A secret is the segment to follow the `enc2:::`. + """ + hash_with_prefix = safe_encode_utf8(hash_with_prefix) + hashed_value = hash_with_prefix[len(HASH_PREFIX) :][: self.hash_size] or None secret = self.cipher_buffer[self.cipher_buffer_key].get(hashed_value) if not secret: try: - cipher = ( + data = ( self.crypt_model_cls.objects.using(self.using) .values("secret") .get(hash=hashed_value, algorithm=self.algorithm, mode=self.access_mode) ) - secret = cipher.get("secret") + secret = data.get("secret") self.cipher_buffer[self.cipher_buffer_key].update({hashed_value: secret}) except ObjectDoesNotExist: raise EncryptionError( - f"Failed to get secret for given {self.algorithm} " - f"{self.access_mode} hash. Got '{hash_with_prefix}'" + f"EncryptionError. Failed to get secret for given {self.algorithm} " + f"{self.access_mode} hash. Got '{str(hash_with_prefix)}'" ) return secret def is_encrypted(self, value: str | bytes | None) -> bool: """Returns True if value is encrypted. - Value can be: - * a string value + + An encrypted value starts with the hash_prefix. + + Inspects a value that is: + * a string value -> False * a well-formed hash - * a well-formed hash+secret. + * a well-formed hash_prefix + hash -> True + * a well-formed hash + secret. """ is_encrypted = False if value is not None: value = safe_encode_utf8(value) - if value[: len(HASH_PREFIX)] == HASH_PREFIX.encode(ENCODING): - has_secret = value[: len(CIPHER_PREFIX)] == CIPHER_PREFIX.encode(ENCODING) - has_valid_value_or_raise(value, self.hash_size, has_secret=has_secret) + if value.startswith(safe_encode_utf8(HASH_PREFIX)): + p = CipherParser(value, self.salt_key) + p.validate_hashed_value() is_encrypted = True return is_encrypted diff --git a/django_crypto_fields/fields/base_field.py b/django_crypto_fields/fields/base_field.py index 6baa9d6..d5392f7 100644 --- a/django_crypto_fields/fields/base_field.py +++ b/django_crypto_fields/fields/base_field.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from typing import TYPE_CHECKING from django.conf import settings @@ -10,11 +9,9 @@ from ..constants import ENCODING, HASH_PREFIX, LOCAL_MODE, RSA from ..exceptions import ( - CipherError, DjangoCryptoFieldsKeysNotLoaded, EncryptionError, EncryptionLookupError, - MalformedCiphertextError, ) from ..field_cryptor import FieldCryptor from ..keys import encryption_keys @@ -78,24 +75,10 @@ def formfield(self, **kwargs): return super(BaseField, self).formfield(**defaults) def decrypt(self, value): - decrypted_value = None if value is None or value in ["", b""]: - return value - try: + decrypted_value = value + else: decrypted_value = self.field_cryptor.decrypt(value) - if not decrypted_value: - self.readonly = True # did not decrypt - decrypted_value = value - except CipherError as e: - sys.stdout.write(style.ERROR(f"CipherError. Got {e}\n")) - sys.stdout.flush() - except EncryptionError as e: - sys.stdout.write(style.ERROR(f"EncryptionError. Got {e}\n")) - sys.stdout.flush() - raise - except MalformedCiphertextError as e: - sys.stdout.write(style.ERROR(f"MalformedCiphertextError. Got {e}\n")) - sys.stdout.flush() return decrypted_value def from_db_value(self, value, *args): @@ -119,24 +102,28 @@ def get_prep_lookup(self, lookup_type, value): Since the available value is the hash, only exact match lookup types are supported. """ - supported_lookups = ["iexact", "exact", "in", "isnull"] - if value is None or value in ["", b""] or lookup_type not in supported_lookups: + # TODO: why value in ["", b""] and not just value == b"" + if value is None or value in ["", b""]: pass else: - supported_lookups = ["iexact", "exact", "in", "isnull"] - if lookup_type not in supported_lookups: - raise EncryptionLookupError( - f"Field type only supports supports '{supported_lookups}' " - f"lookups. Got '{lookup_type}'" - ) + self.raise_if_unsupported_lookup(lookup_type) if lookup_type == "isnull": value = self.get_isnull_as_lookup(value) elif lookup_type == "in": - self.get_in_as_lookup(value) + value = self.get_in_as_lookup(value) else: value = HASH_PREFIX.encode(ENCODING) + self.field_cryptor.hash(value) return super().get_prep_lookup(lookup_type, value) + @staticmethod + def raise_if_unsupported_lookup(lookup_type): + supported_lookups = ["iexact", "exact", "in", "isnull"] + if lookup_type not in supported_lookups: + raise EncryptionLookupError( + f"Field type only supports supports '{supported_lookups}' " + f"lookups. Got '{lookup_type}'" + ) + def get_isnull_as_lookup(self, value): return value diff --git a/django_crypto_fields/tests/crypto_keys/django_crypto_fields b/django_crypto_fields/tests/crypto_keys/django_crypto_fields index c9db89d..7eb33c3 100644 --- a/django_crypto_fields/tests/crypto_keys/django_crypto_fields +++ b/django_crypto_fields/tests/crypto_keys/django_crypto_fields @@ -1,2 +1,2 @@ path,date -/Users/erikvw/source/edc_source/django-crypto-fields/django_crypto_fields/tests/crypto_keys,2024-03-19 04:17:34.681606+00:00 +/Users/erikvw/source/edc_source/django-crypto-fields/django_crypto_fields/tests/crypto_keys,2024-03-20 06:26:42.256762+00:00 diff --git a/django_crypto_fields/tests/tests/test_cryptor.py b/django_crypto_fields/tests/tests/test_cryptor.py index d911797..80cfce1 100644 --- a/django_crypto_fields/tests/tests/test_cryptor.py +++ b/django_crypto_fields/tests/tests/test_cryptor.py @@ -23,73 +23,73 @@ def test_mode_support(self): def test_encrypt_rsa(self): """Assert successful RSA roundtrip.""" - cryptor = Cryptor() - plaintext = "erik is a pleeb!!" for mode in encryption_keys.rsa_modes_supported: - ciphertext = cryptor.rsa_encrypt(plaintext, mode) - self.assertEqual(plaintext, cryptor.rsa_decrypt(ciphertext, mode)) + cryptor = Cryptor(algorithm=RSA, access_mode=mode) + plaintext = "erik is a pleeb!!" + ciphertext = cryptor.encrypt(plaintext) + self.assertEqual(plaintext, cryptor.decrypt(ciphertext)) def test_encrypt_aes(self): """Assert successful AES roundtrip.""" - cryptor = Cryptor() - plaintext = "erik is a pleeb!!" for mode in encryption_keys.aes_modes_supported: - ciphertext = cryptor.aes_encrypt(plaintext, mode) - self.assertEqual(plaintext, cryptor.aes_decrypt(ciphertext, mode)) + cryptor = Cryptor(algorithm=AES, access_mode=mode) + plaintext = "erik is a pleeb!!" + ciphertext = cryptor.encrypt(plaintext) + self.assertEqual(plaintext, cryptor.decrypt(ciphertext)) def test_encrypt_rsa_length(self): """Assert RSA raises EncryptionError if plaintext is too long.""" - cryptor = Cryptor() for mode in encryption_keys.rsa_modes_supported: + cryptor = Cryptor(algorithm=RSA, access_mode=mode) max_length = encryption_keys.rsa_key_info[mode]["max_message_length"] plaintext = "".join(["a" for _ in range(0, max_length)]) - cryptor.rsa_encrypt(plaintext, mode) - self.assertRaises(EncryptionError, cryptor.rsa_encrypt, plaintext + "a", mode) + cryptor.encrypt(plaintext) + self.assertRaises(EncryptionError, cryptor.encrypt, plaintext + "a") def test_rsa_encoding(self): """Assert successful RSA roundtrip of byte return str.""" - cryptor = Cryptor() + cryptor = Cryptor(algorithm=RSA, access_mode=LOCAL_MODE) plaintext = "erik is a pleeb!!∂ƒ˜∫˙ç".encode("utf-8") - ciphertext = cryptor.rsa_encrypt(plaintext, LOCAL_MODE) - t2 = type(cryptor.rsa_decrypt(ciphertext, LOCAL_MODE)) + ciphertext = cryptor.encrypt(plaintext) + t2 = type(cryptor.decrypt(ciphertext)) self.assertTrue(type(t2), "str") def test_rsa_type(self): """Assert fails for anything but str and byte.""" - cryptor = Cryptor() + cryptor = Cryptor(algorithm=RSA, access_mode=LOCAL_MODE) plaintext = 1 - self.assertRaises(EncryptionError, cryptor.rsa_encrypt, plaintext, LOCAL_MODE) + self.assertRaises(EncryptionError, cryptor.encrypt, plaintext) plaintext = 1.0 - self.assertRaises(EncryptionError, cryptor.rsa_encrypt, plaintext, LOCAL_MODE) + self.assertRaises(EncryptionError, cryptor.encrypt, plaintext) plaintext = datetime.today() - self.assertRaises(EncryptionError, cryptor.rsa_encrypt, plaintext, LOCAL_MODE) + self.assertRaises(EncryptionError, cryptor.encrypt, plaintext) def test_no_re_encrypt(self): """Assert raise error if attempting to encrypt a cipher.""" - cryptor = Cryptor() + cryptor = Cryptor(algorithm=RSA, access_mode=LOCAL_MODE) plaintext = "erik is a pleeb!!" - ciphertext1 = cryptor.rsa_encrypt(plaintext, LOCAL_MODE) - self.assertRaises(EncryptionError, cryptor.rsa_encrypt, ciphertext1, LOCAL_MODE) + ciphertext1 = cryptor.encrypt(plaintext) + self.assertRaises(EncryptionError, cryptor.encrypt, ciphertext1) def test_rsa_roundtrip(self): - cryptor = Cryptor() plaintext = ( "erik is a pleeb! ERIK IS A PLEEB 0123456789!@#$%^&*()" "_-+={[}]|\"':;>.<,?/~`±§" ) - for mode in cryptor.keys.get(RSA): + for mode in encryption_keys.rsa_modes_supported: + cryptor = Cryptor(algorithm=RSA, access_mode=mode) try: - ciphertext = cryptor.rsa_encrypt(plaintext, mode) + ciphertext = cryptor.encrypt(plaintext) except (AttributeError, TypeError) as e: self.fail(f"Failed encrypt: {mode} public ({e})\n") - self.assertTrue(plaintext == cryptor.rsa_decrypt(ciphertext, mode)) + self.assertTrue(plaintext == cryptor.decrypt(ciphertext)) def test_aes_roundtrip(self): - cryptor = Cryptor() plaintext = ( "erik is a pleeb!\nERIK IS A PLEEB\n0123456789!@#$%^&*()_" "-+={[}]|\"':;>.<,?/~`±§\n" ) - for mode in cryptor.keys.get(AES): - ciphertext = cryptor.aes_encrypt(plaintext, mode) + for mode in encryption_keys.aes_modes_supported: + cryptor = Cryptor(algorithm=AES, access_mode=mode) + ciphertext = cryptor.encrypt(plaintext) self.assertTrue(plaintext != ciphertext) - self.assertTrue(plaintext == cryptor.aes_decrypt(ciphertext, mode)) + self.assertTrue(plaintext == cryptor.decrypt(ciphertext)) diff --git a/django_crypto_fields/tests/tests/test_field_cryptor.py b/django_crypto_fields/tests/tests/test_field_cryptor.py index 4f44424..a45d801 100644 --- a/django_crypto_fields/tests/tests/test_field_cryptor.py +++ b/django_crypto_fields/tests/tests/test_field_cryptor.py @@ -1,18 +1,15 @@ from django.db import transaction from django.db.utils import IntegrityError -from django.test import TestCase +from django.test import TestCase, tag +from django_crypto_fields.cipher import CipherParser from django_crypto_fields.constants import AES, ENCODING, HASH_PREFIX, LOCAL_MODE, RSA from django_crypto_fields.cryptor import Cryptor from django_crypto_fields.exceptions import MalformedCiphertextError from django_crypto_fields.field_cryptor import FieldCryptor from django_crypto_fields.keys import encryption_keys +from django_crypto_fields.utils import has_valid_hash_or_raise -from ...utils import ( - has_valid_hash_or_raise, - has_valid_value_or_raise, - is_valid_ciphertext_or_raise, -) from ..models import TestModel @@ -67,23 +64,16 @@ def test_can_verify_hash_raises(self): MalformedCiphertextError, has_valid_hash_or_raise, value, field_cryptor.hash_size ) - def test_verify_with_secret(self): + def test_verify_hashed_value(self): field_cryptor = FieldCryptor(RSA, LOCAL_MODE) value = field_cryptor.encrypt("Mohammed Ali floats like a butterfly") - self.assertTrue(is_valid_ciphertext_or_raise(value, field_cryptor.hash_size)) - - def test_raises_on_verify_without_secret(self): - field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - value = HASH_PREFIX.encode(ENCODING) + field_cryptor.hash( - "Mohammed Ali floats like a butterfly" - ) - self.assertRaises( - MalformedCiphertextError, - is_valid_ciphertext_or_raise, - value, - field_cryptor.hash_size, - ) + p = CipherParser(value, field_cryptor.salt_key) + try: + p.validate_hashed_value() + except MalformedCiphertextError: + self.fail("MalformedCiphertextError unexpectedly raised") + @tag("6") def test_verify_is_encrypted(self): field_cryptor = FieldCryptor(RSA, LOCAL_MODE) value = HASH_PREFIX.encode(ENCODING) + field_cryptor.hash( @@ -100,12 +90,9 @@ def test_verify_is_not_encrypted(self): def test_verify_value(self): field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - value = "Mohammed Ali floats like a butterfly" - self.assertRaises( - MalformedCiphertextError, has_valid_value_or_raise, value, field_cryptor.hash_size - ) - value = field_cryptor.encrypt("Mohammed Ali floats like a butterfly") - self.assertEqual(value, has_valid_value_or_raise(value, field_cryptor.hash_size)) + cipher = field_cryptor.encrypt("Mohammed Ali floats like a butterfly") + p = CipherParser(cipher) + self.assertIsNotNone(p.secret) def test_rsa_field_encryption(self): """Assert successful RSA field roundtrip.""" @@ -166,41 +153,44 @@ def test_rsa_update_crypt_model(self): retrieved by hash, and decrypted. """ plaintext = "erik is a pleeb!!∂ƒ˜∫˙ç" - cryptor = Cryptor() + cryptor = Cryptor(algorithm=RSA, access_mode=LOCAL_MODE) field_cryptor = FieldCryptor(RSA, LOCAL_MODE) hashed_value = field_cryptor.hash(plaintext) - ciphertext1 = field_cryptor.encrypt(plaintext, update=False) - field_cryptor.update_crypt(ciphertext1) + field_cryptor.encrypt(plaintext, update=True) secret = field_cryptor.crypt_model_cls.objects.get(hash=hashed_value).secret field_cryptor.fetch_secret(HASH_PREFIX.encode(ENCODING) + hashed_value) - self.assertEqual(plaintext, cryptor.rsa_decrypt(secret, LOCAL_MODE)) + self.assertEqual(plaintext, cryptor.decrypt(secret)) def test_aes_update_crypt_model(self): """Asserts plaintext can be encrypted, saved to model, retrieved by hash, and decrypted. """ plaintext = "erik is a pleeb!!∂ƒ˜∫˙ç" - cryptor = Cryptor() field_cryptor = FieldCryptor(AES, LOCAL_MODE) + field_cryptor.encrypt(plaintext, update=True) hashed_value = field_cryptor.hash(plaintext) - ciphertext1 = field_cryptor.encrypt(plaintext, update=False) - field_cryptor.update_crypt(ciphertext1) secret = field_cryptor.crypt_model_cls.objects.get(hash=hashed_value).secret field_cryptor.fetch_secret(HASH_PREFIX.encode(ENCODING) + hashed_value) - self.assertEqual(plaintext, cryptor.aes_decrypt(secret, LOCAL_MODE)) + self.assertEqual(plaintext, field_cryptor.cryptor.decrypt(secret)) + @tag("3") def test_get_secret(self): """Asserts secret is returned either as None or the secret.""" - cryptor = Cryptor() field_cryptor = FieldCryptor(RSA, LOCAL_MODE) + plaintext = None - ciphertext = field_cryptor.encrypt(plaintext) - secret = field_cryptor.get_secret(ciphertext) + cipher = field_cryptor.encrypt(plaintext, update=True) + secret = CipherParser(cipher).secret self.assertIsNone(secret) + self.assertEqual(plaintext, field_cryptor.decrypt(secret)) + plaintext = "erik is a pleeb!!∂ƒ˜∫˙ç" - ciphertext = field_cryptor.encrypt(plaintext) - secret = field_cryptor.get_secret(ciphertext) - self.assertEqual(plaintext, cryptor.rsa_decrypt(secret, LOCAL_MODE)) + cipher = field_cryptor.encrypt(plaintext, update=True) + cipher = CipherParser(cipher) + self.assertIsNotNone(cipher.secret) + self.assertEqual( + plaintext, field_cryptor.decrypt(cipher.hash_prefix + cipher.hashed_value) + ) def test_rsa_field_as_none(self): """Asserts RSA roundtrip on None.""" diff --git a/django_crypto_fields/utils.py b/django_crypto_fields/utils.py index 7b3b66b..919384d 100644 --- a/django_crypto_fields/utils.py +++ b/django_crypto_fields/utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +import binascii +import hashlib import sys from typing import TYPE_CHECKING, Type @@ -7,7 +9,7 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from .constants import CIPHER_PREFIX, ENCODING, HASH_PREFIX +from .constants import CIPHER_PREFIX, ENCODING, HASH_ALGORITHM, HASH_PREFIX, HASH_ROUNDS from .exceptions import MalformedCiphertextError if TYPE_CHECKING: @@ -103,15 +105,9 @@ def has_valid_hash_or_raise(ciphertext: bytes, hash_size: int) -> bool: raises an exception if not OK. """ ciphertext = safe_encode_utf8(ciphertext) - hash_prefix = HASH_PREFIX.encode(ENCODING) - if ciphertext == HASH_PREFIX.encode(ENCODING): - raise MalformedCiphertextError(f"Ciphertext has not hash. Got {ciphertext}") - if not ciphertext[: len(hash_prefix)] == hash_prefix: - raise MalformedCiphertextError( - f"Ciphertext must start with {hash_prefix}. " - f"Got {ciphertext[:len(hash_prefix)]}" - ) - hash_value = ciphertext[len(hash_prefix) :].split(CIPHER_PREFIX.encode(ENCODING))[0] + hash_value = ciphertext[len(safe_encode_utf8(HASH_PREFIX)) :].split( + safe_encode_utf8(CIPHER_PREFIX) + )[0] if len(hash_value) != hash_size: raise MalformedCiphertextError( "Expected hash prefix to be followed by a hash. Got something else or nothing" @@ -119,57 +115,11 @@ def has_valid_hash_or_raise(ciphertext: bytes, hash_size: int) -> bool: return True -def has_valid_value_or_raise( - value: str | bytes, hash_size: int, has_secret=None -) -> str | bytes: - """Encodes the value, validates its format, and returns it - or raises an exception. - - A value is either a value that can be encrypted or one that - already is encrypted. +def make_hash(value, salt_key) -> bytes: + """Returns a hexified hash of a plaintext value (as bytes). - * A value cannot just be equal to HASH_PREFIX or CIPHER_PREFIX; - * A value prefixed with HASH_PREFIX must be followed by a - valid hash (by length); - * A value prefixed with HASH_PREFIX + hashed_value + - CIPHER_PREFIX must be followed by some text; - * A value prefix by CIPHER_PREFIX must be followed by - some text; + The hashed value is used as a signature of the "secret". """ - has_secret = True if has_secret is None else has_secret encoded_value = safe_encode_utf8(value) - if encoded_value is not None and encoded_value != b"": - if encoded_value in [ - HASH_PREFIX.encode(ENCODING), - CIPHER_PREFIX.encode(ENCODING), - ]: - raise MalformedCiphertextError("Expected a value, got just the encryption prefix.") - has_valid_hash_or_raise(encoded_value, hash_size) - if has_secret: - is_valid_ciphertext_or_raise(encoded_value) - return value # note, is original passed value - - -def is_valid_ciphertext_or_raise(ciphertext: bytes, hash_size: int | None = None): - """Returns an unchanged ciphertext after verifying format hash_prefix + - hash + cipher_prefix + secret. - """ - try: - ciphertext.split(HASH_PREFIX.encode(ENCODING))[1] - except IndexError: - raise MalformedCiphertextError( - f"Malformed ciphertext. Expected prefixes {HASH_PREFIX}" - ) - try: - ciphertext.split(CIPHER_PREFIX.encode(ENCODING))[1] - except IndexError: - raise MalformedCiphertextError( - f"Malformed ciphertext. Expected prefixes {CIPHER_PREFIX}" - ) - if ciphertext[: len(HASH_PREFIX)] != HASH_PREFIX.encode(ENCODING): - raise MalformedCiphertextError( - f"Malformed ciphertext. Expected hash prefix {HASH_PREFIX}" - ) - if hash_size is not None: - has_valid_hash_or_raise(ciphertext, hash_size) - return ciphertext + dk = hashlib.pbkdf2_hmac(HASH_ALGORITHM, encoded_value, salt_key, HASH_ROUNDS) + return binascii.hexlify(dk)