diff --git a/.gitignore b/.gitignore index d2c49d7..94e78bb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ _version.py .etc/ .env/ django_crypto_fields/tests/etc/django_crypto_fields +django_crypto_fields/tests/crypto_keys/django_crypto_fields .pypirc .settings .project diff --git a/CHANGES b/CHANGES index f0bf637..b223bd8 100644 --- a/CHANGES +++ b/CHANGES @@ -1,5 +1,19 @@ CHANGES +unreleased +---------- +- add support for postgreSQL +- add field classes for additional datatypes: + EncryptedIntegerField, EncryptedDecimalField, EncryptedDateField + EncryptedDateTimeField +- empty strings are now encrypted. Only None values are ignored. +- refactor signatures and typing between encrypt and decrypt + +0.4.1 +----- +- CACHE_CRYPTO_KEY_PREFIX, settings attribute to customize the + cache prefix. + 0.4.0 ----- - merge functionality of key_creator and key_files into keys module, diff --git a/README.rst b/README.rst index c9ff1aa..5f03948 100644 --- a/README.rst +++ b/README.rst @@ -12,9 +12,11 @@ version >= 0.3.8 < 0.4.0 version 0.4.0+ Python 3.11+ Django 4.2+ using mysql, cache framework +version 0.4.2+ + Python 3.11+ Django 4.2+ mysql or postgres, cache framework + * Uses ``pycryptodomex`` -* This module has known problems with `postgres`. (I hope to address this soon) Add encrypted field classes to your Django models where ``unique=True`` and ``unique_together`` attributes work as expected. @@ -68,6 +70,7 @@ Add KEY_PREFIX (optional, the default is "user"): # optional filename prefix for encryption keys files: KEY_PREFIX = 'bhp066' + Run ``migrate`` to create the ``django_crypto_fields.crypt`` table: .. code-block:: python diff --git a/django_crypto_fields/admin/__init__.py b/django_crypto_fields/admin/__init__.py new file mode 100644 index 0000000..f5faa03 --- /dev/null +++ b/django_crypto_fields/admin/__init__.py @@ -0,0 +1,2 @@ +from .crypt_model_admin import CryptModelAdmin +from .formfield_overrides import formfield_overrides diff --git a/django_crypto_fields/admin.py b/django_crypto_fields/admin/crypt_model_admin.py similarity index 86% rename from django_crypto_fields/admin.py rename to django_crypto_fields/admin/crypt_model_admin.py index 734df2b..75cb318 100644 --- a/django_crypto_fields/admin.py +++ b/django_crypto_fields/admin/crypt_model_admin.py @@ -1,7 +1,7 @@ from django.contrib import admin -from .admin_site import encryption_admin -from .utils import get_crypt_model_cls +from ..admin_site import encryption_admin +from ..utils import get_crypt_model_cls @admin.register(get_crypt_model_cls(), site=encryption_admin) diff --git a/django_crypto_fields/admin/formfield_overrides.py b/django_crypto_fields/admin/formfield_overrides.py new file mode 100644 index 0000000..f39a824 --- /dev/null +++ b/django_crypto_fields/admin/formfield_overrides.py @@ -0,0 +1,21 @@ +from django.contrib.admin.options import FORMFIELD_FOR_DBFIELD_DEFAULTS +from django.db import models + +from django_crypto_fields.fields import ( + EncryptedCharField, + EncryptedDateField, + EncryptedDateTimeField, + EncryptedIntegerField, +) + +FORMFIELD_FOR_DBFIELD_DEFAULTS.update( + { + EncryptedCharField: FORMFIELD_FOR_DBFIELD_DEFAULTS[models.CharField], + EncryptedDateField: FORMFIELD_FOR_DBFIELD_DEFAULTS[models.DateField], + EncryptedDateTimeField: FORMFIELD_FOR_DBFIELD_DEFAULTS[models.DateTimeField], + EncryptedIntegerField: FORMFIELD_FOR_DBFIELD_DEFAULTS[models.IntegerField], + } +) + + +formfield_overrides = FORMFIELD_FOR_DBFIELD_DEFAULTS diff --git a/django_crypto_fields/cipher/cipher.py b/django_crypto_fields/cipher/cipher.py index bc9478b..cdfe04d 100644 --- a/django_crypto_fields/cipher/cipher.py +++ b/django_crypto_fields/cipher/cipher.py @@ -3,7 +3,7 @@ from typing import Callable from ..constants import CIPHER_PREFIX, HASH_PREFIX -from ..utils import make_hash, safe_encode_utf8 +from ..utils import make_hash __all__ = ["Cipher"] @@ -25,17 +25,17 @@ def __init__( salt_key: bytes, encrypt: Callable[[bytes], bytes] | None = None, ): - encoded_value = safe_encode_utf8(value) + # encoded_value = safe_encode(value) self.hash_prefix = b"" self.hashed_value = b"" self.cipher_prefix = b"" self.secret = b"" if salt_key: - self.hash_prefix: bytes = safe_encode_utf8(HASH_PREFIX) - self.hashed_value: bytes = make_hash(encoded_value, salt_key) + self.hash_prefix: bytes = HASH_PREFIX.encode() + self.hashed_value: bytes = make_hash(value, salt_key) if encrypt: - self.secret = encrypt(encoded_value) - self.cipher_prefix: bytes = safe_encode_utf8(CIPHER_PREFIX) + self.secret = encrypt(value) + self.cipher_prefix: bytes = CIPHER_PREFIX.encode() @property def cipher(self) -> bytes: diff --git a/django_crypto_fields/cipher/cipher_parser.py b/django_crypto_fields/cipher/cipher_parser.py index ee04c61..48071f1 100644 --- a/django_crypto_fields/cipher/cipher_parser.py +++ b/django_crypto_fields/cipher/cipher_parser.py @@ -2,7 +2,7 @@ from ..constants import CIPHER_PREFIX, HASH_PREFIX from ..exceptions import MalformedCiphertextError -from ..utils import make_hash, safe_encode_utf8 +from ..utils import make_hash __all__ = ["CipherParser"] @@ -13,7 +13,7 @@ def __init__(self, cipher: bytes, salt_key: bytes | None = None): self._hash_prefix = None self._hashed_value = None self._secret = None - self.cipher = safe_encode_utf8(cipher) + self.cipher = cipher self.salt_key = salt_key self.validate_hashed_value() self.validate_secret() @@ -21,14 +21,14 @@ def __init__(self, cipher: bytes, salt_key: bytes | None = None): @property def hash_prefix(self) -> bytes | None: if self.cipher: - hash_prefix = safe_encode_utf8(HASH_PREFIX) + hash_prefix = HASH_PREFIX.encode() self._hash_prefix = hash_prefix if self.cipher.startswith(hash_prefix) else None return self._hash_prefix @property def cipher_prefix(self) -> bytes | None: if self.cipher: - cipher_prefix = safe_encode_utf8(CIPHER_PREFIX) + cipher_prefix = CIPHER_PREFIX.encode() self._cipher_prefix = cipher_prefix if cipher_prefix in self.cipher else None return self._cipher_prefix @@ -42,7 +42,7 @@ def hashed_value(self) -> bytes | None: @property def secret(self) -> bytes | None: - if self.cipher and safe_encode_utf8(CIPHER_PREFIX) in self.cipher: + if self.cipher and CIPHER_PREFIX.encode() in self.cipher: self._secret = self.cipher.split(self.cipher_prefix)[1] return self._secret diff --git a/django_crypto_fields/constants.py b/django_crypto_fields/constants.py index 4319b5e..92a0aeb 100644 --- a/django_crypto_fields/constants.py +++ b/django_crypto_fields/constants.py @@ -1,7 +1,6 @@ AES = "aes" CIPHER_BUFFER_SIZE = 10 CIPHER_PREFIX = "enc2:::" -ENCODING = "utf-8" HASH_ALGORITHM = "sha256" HASH_PREFIX = "enc1:::" HASH_ROUNDS = 100000 diff --git a/django_crypto_fields/cryptor.py b/django_crypto_fields/cryptor.py index d7ffeb9..f507167 100644 --- a/django_crypto_fields/cryptor.py +++ b/django_crypto_fields/cryptor.py @@ -5,15 +5,11 @@ from Cryptodome import Random from Cryptodome.Cipher import AES as AES_CIPHER -from .constants import AES, ENCODING, PRIVATE, PUBLIC, RSA +from .constants import AES, PRIVATE, PUBLIC, RSA +from .encoding import safe_encode from .exceptions import EncryptionError from .keys import encryption_keys -from .utils import ( - append_padding, - get_keypath_from_settings, - remove_padding, - safe_encode_utf8, -) +from .utils import append_padding, get_keypath_from_settings, remove_padding if TYPE_CHECKING: from Cryptodome.Cipher._mode_cbc import CbcMode @@ -43,7 +39,7 @@ def __init__(self, algorithm, access_mode) -> None: self.decrypt = getattr(self, f"_{self.algorithm.lower()}_decrypt") def _aes_encrypt(self, value: str | bytes) -> bytes: - encoded_value = safe_encode_utf8(value) + encoded_value = safe_encode(value) iv: bytes = Random.new().read(AES_CIPHER.block_size) cipher: CbcMode = AES_CIPHER.new(self.aes_key, self.aes_encryption_mode, iv) encoded_value = append_padding(encoded_value, cipher.block_size) @@ -55,13 +51,11 @@ def _aes_decrypt(self, secret: bytes) -> str: cipher: CbcMode = AES_CIPHER.new(self.aes_key, self.aes_encryption_mode, iv) encoded_value = cipher.decrypt(secret)[AES_CIPHER.block_size :] encoded_value = remove_padding(encoded_value) - value = encoded_value.decode() - return value + return encoded_value.decode() if encoded_value is not None else None def _rsa_encrypt(self, value: str | bytes) -> bytes: - encoded_value = safe_encode_utf8(value) try: - secret = self.rsa_public_key.encrypt(encoded_value) + secret = self.rsa_public_key.encrypt(safe_encode(value)) except (ValueError, TypeError) as e: raise EncryptionError(f"RSA encryption failed for value. Got '{e}'") return secret @@ -73,4 +67,4 @@ def _rsa_decrypt(self, secret: bytes) -> str: raise EncryptionError( f"{e} Using RSA from key_path=`{get_keypath_from_settings()}`." ) - return encoded_value.decode(ENCODING) + return encoded_value.decode() if encoded_value is not None else None diff --git a/django_crypto_fields/encoding.py b/django_crypto_fields/encoding.py new file mode 100644 index 0000000..3f5c5d6 --- /dev/null +++ b/django_crypto_fields/encoding.py @@ -0,0 +1,68 @@ +from datetime import date, datetime +from decimal import Decimal +from typing import Any + +from django_crypto_fields.exceptions import ( + DjangoCryptoFieldsDecodingError, + DjangoCryptoFieldsEncodingError, +) + +ENCODING = "utf-8" +DATETIME_STRING = "%Y-%m-%d %H:%M:%S %z" +DATE_STRING = "%Y-%m-%d" + + +def safe_encode(value: str | int | Decimal | float | date | datetime | bytes) -> bytes | None: + if value is None: + return None + if type(value) in [str, int, Decimal, float]: + value = str(value).encode() + elif type(value) in [date, datetime]: + value = safe_encode_date(value) + else: + raise DjangoCryptoFieldsEncodingError( + f"Value must be of type str, date or number. Got {value} is {type(value)}" + ) + return value + + +def decode_to_type(value: bytes, to_type: type) -> Any: + if to_type in [date, datetime]: + value = safe_decode_date(value) + elif to_type in [Decimal]: + value = Decimal(value.decode()) + elif to_type in [int, float]: + value = to_type(value.decode()) + elif to_type in [str]: + value = value.decode() + else: + raise DjangoCryptoFieldsDecodingError(f"Unhandled type. Got {to_type}.") + return value + + +def safe_decode_date(value: bytes) -> [date, datetime]: + """Convert bytes to string and confirm date/datetime format""" + value = value.decode() + try: + value = datetime.strptime(value, "%Y-%m-%d %H:%M:%S %z") + except ValueError: + try: + value = datetime.strptime(value, "%Y-%m-%d") + except ValueError: + raise DjangoCryptoFieldsDecodingError( + f"Decoded string value must be in ISO date or datetime format. Got {value}" + ) + return value + + +def safe_encode_date(value: [date, datetime]) -> bytes: + """Convert date to string and encode.""" + if type(value) is datetime: + value = datetime.strftime(value, DATETIME_STRING) + elif type(value) is date: + value = datetime.strftime(value, DATE_STRING) + else: + raise DjangoCryptoFieldsEncodingError( + f"Value must be either a date or datetime. Got {value}." + ) + return value.encode() diff --git a/django_crypto_fields/exceptions.py b/django_crypto_fields/exceptions.py index 6c78f04..2fe778e 100644 --- a/django_crypto_fields/exceptions.py +++ b/django_crypto_fields/exceptions.py @@ -34,6 +34,14 @@ class DjangoCryptoFieldsKeyPathDoesNotExist(Exception): pass +class DjangoCryptoFieldsEncodingError(Exception): + pass + + +class DjangoCryptoFieldsDecodingError(Exception): + pass + + class EncryptionError(Exception): pass diff --git a/django_crypto_fields/field_cryptor.py b/django_crypto_fields/field_cryptor.py index cf8c8f8..ccf94e0 100644 --- a/django_crypto_fields/field_cryptor.py +++ b/django_crypto_fields/field_cryptor.py @@ -10,7 +10,6 @@ from .constants import ( AES, CIPHER_PREFIX, - ENCODING, HASH_PREFIX, LOCAL_MODE, PRIVATE, @@ -19,9 +18,14 @@ SALT, ) from .cryptor import Cryptor -from .exceptions import EncryptionError, EncryptionKeyError, InvalidEncryptionAlgorithm +from .exceptions import ( + DjangoCryptoFieldsError, + EncryptionError, + EncryptionKeyError, + InvalidEncryptionAlgorithm, +) from .keys import encryption_keys -from .utils import get_crypt_model_cls, make_hash, safe_decode, safe_encode_utf8 +from .utils import get_crypt_model_cls, make_hash __all__ = ["FieldCryptor"] @@ -41,21 +45,33 @@ class FieldCryptor: cryptor_cls = Cryptor cipher_cls = Cipher - def __init__(self, algorithm: str, access_mode: str): + def __init__( + self, + algorithm: str, + access_mode: str, + ): self._using = None self._algorithm = None self._access_mode = None + self._cryptor = None self.algorithm = algorithm self.access_mode = access_mode self.cipher_buffer_key = b"{self.algorithm}_{self.access_mode}" self.cipher_buffer = {self.cipher_buffer_key: {}} self.keys = encryption_keys - self.cryptor = self.cryptor_cls(algorithm=algorithm, access_mode=access_mode) self.hash_size: int = len(self.hash("Foo")) def __repr__(self) -> str: return f"FieldCryptor(algorithm='{self.algorithm}', mode='{self.access_mode}')" + @property + def cryptor(self) -> Cryptor: + if not self._cryptor: + self._cryptor = self.cryptor_cls( + algorithm=self.algorithm, access_mode=self.access_mode + ) + return self._cryptor + @property def algorithm(self) -> str: return self._algorithm @@ -81,25 +97,22 @@ def access_mode(self, value: str): f"'{LOCAL_MODE}' or '{PRIVATE}' or {RESTRICTED_MODE}. Got {value}." ) - def hash(self, value) -> bytes: + def hash(self, value: str) -> bytes: return make_hash(value, self.salt_key) @property - def salt_key(self): - attr = "_".join([SALT, self.access_mode, PRIVATE]) + def salt_key(self) -> bytes: + attr: str = "_".join([SALT, self.access_mode, PRIVATE]) try: - salt = getattr(self.keys, attr) + salt: bytes = getattr(self.keys, attr) except AttributeError as e: raise EncryptionKeyError(f"Invalid key. Got {attr}. {e}") return salt - def encrypt(self, value: bytes | None, update: bool | None = None) -> bytes: + def encrypt(self, value: str | None, update: bool | None = None) -> bytes: """Returns either an RSA or AES cipher of the format hash_prefix + hashed_value + cipher_prefix + secret. - - * 'value' may or may not be encoded * 'update' if True updates the value in the Crypt model - * `cipher.cipher` instance formats the cipher. For example: enc1:::234234ed234a24enc2::\x0e\xb9\xae\x13s\x8d\xe7O\xbb\r\x99. * 'value' is not re-encrypted if already encrypted and properly @@ -107,14 +120,11 @@ def encrypt(self, value: bytes | None, update: bool | None = None) -> bytes: """ cipher = None update = True if update is None else update - encoded_value = safe_encode_utf8(value) - if encoded_value and not self.is_encrypted(encoded_value): - cipher = self.cipher_cls( - encoded_value, self.salt_key, encrypt=self.cryptor.encrypt - ) + if value is not None and not self.is_encrypted(value): + cipher = self.cipher_cls(value, self.salt_key, encrypt=self.cryptor.encrypt) if update: self.update_crypt(cipher) - return getattr(cipher, "cipher", encoded_value) + return getattr(cipher, "cipher", value) def decrypt(self, hash_with_prefix: bytes) -> str | None: """Returns decrypted secret or None. @@ -141,15 +151,9 @@ def using(self): @property def cache_key_prefix(self) -> bytes: - algorithm = safe_encode_utf8(self.algorithm) - access_mode = safe_encode_utf8(self.access_mode) - prefix = safe_encode_utf8( - getattr( - settings, - "CACHE_CRYPTO_KEY_PREFIX", - "crypto", - ) - ) + algorithm = self.algorithm.encode() + access_mode = self.access_mode.encode() + prefix = getattr(settings, "CACHE_CRYPTO_KEY_PREFIX", "crypto").encode() return prefix + algorithm + b"-" + access_mode + b"-" def update_crypt(self, cipher: Cipher) -> None: @@ -170,24 +174,17 @@ def update_crypt(self, cipher: Cipher) -> None: ) cache.set(self.cache_key_prefix + cipher.hashed_value, cipher.secret) - def get_prep_value(self, value: str | bytes | None) -> str | bytes | None: + def get_prep_value(self, value: str | None) -> str | None: """Returns the prefix + hash_value, an empty string, or None - as stored in the DB table column of your model's "encrypted" + prepared for saving into the column of your model's "encrypted" field. Used by field_cls.get_prep_value() """ - hash_with_prefix = None - encoded_value = safe_encode_utf8(value) - if encoded_value == b"": - encoded_value = "" - elif encoded_value is None: - pass - else: - cipher = self.encrypt(encoded_value) - hash_with_prefix = cipher.split(CIPHER_PREFIX.encode(ENCODING))[0] - hash_with_prefix = safe_decode(hash_with_prefix) - return hash_with_prefix or encoded_value + if value is not None: + cipher = self.encrypt(value) + return cipher.split(CIPHER_PREFIX.encode())[0].decode() + return value def fetch_secret(self, hash_with_prefix: bytes) -> bytes | None: """Fetch the secret from the DB or the buffer using @@ -198,7 +195,9 @@ def fetch_secret(self, hash_with_prefix: bytes) -> bytes | None: A secret is the segment to follow the `enc2:::`. """ secret = None - hash_with_prefix = safe_encode_utf8(hash_with_prefix) + # hash_with_prefix = self.safe_encode(hash_with_prefix.encode() + if type(hash_with_prefix) is not bytes: + raise DjangoCryptoFieldsError("hash_with_prefix must be bytes") if hashed_value := hash_with_prefix[len(HASH_PREFIX) :][: self.hash_size] or None: secret = cache.get(self.cache_key_prefix + hashed_value, None) if not secret: @@ -224,13 +223,14 @@ def fetch_secret(self, hash_with_prefix: bytes) -> bytes | None: return secret @staticmethod - def is_encrypted(value: bytes | None) -> bool: + def is_encrypted(value: str | bytes | None) -> bool: """Returns True if value is encrypted. An encrypted value starts with the hash_prefix. """ - encoded_value = safe_encode_utf8(value) - if encoded_value and encoded_value.startswith(safe_encode_utf8(HASH_PREFIX)): + if type(value) is not bytes: + value = value.encode() if value is not None else value + if value and value.startswith(HASH_PREFIX.encode()): return True return False diff --git a/django_crypto_fields/fields/__init__.py b/django_crypto_fields/fields/__init__.py index 6a328c8..2e9d28e 100644 --- a/django_crypto_fields/fields/__init__.py +++ b/django_crypto_fields/fields/__init__.py @@ -2,6 +2,8 @@ from .base_field import BaseField from .base_rsa_field import BaseRsaField from .encrypted_char_field import EncryptedCharField +from .encrypted_date_field import EncryptedDateField +from .encrypted_datetime_field import EncryptedDateTimeField from .encrypted_decimal_field import EncryptedDecimalField from .encrypted_integer_field import EncryptedIntegerField from .encrypted_text_field import EncryptedTextField @@ -10,14 +12,16 @@ from .lastname_field import LastnameField __all__ = [ - "BaseField", "BaseAesField", + "BaseField", "BaseRsaField", "EncryptedCharField", + "EncryptedDateField", + "EncryptedDateTimeField", "EncryptedDecimalField", "EncryptedIntegerField", "EncryptedTextField", "FirstnameField", - "LastnameField", "IdentityField", + "LastnameField", ] diff --git a/django_crypto_fields/fields/base_field.py b/django_crypto_fields/fields/base_field.py index 9e00c7f..d48c968 100644 --- a/django_crypto_fields/fields/base_field.py +++ b/django_crypto_fields/fields/base_field.py @@ -6,7 +6,7 @@ from django.db import models from django.forms import widgets -from ..constants import ENCODING, HASH_PREFIX, LOCAL_MODE, RSA +from ..constants import HASH_PREFIX, LOCAL_MODE, RSA from ..exceptions import ( DjangoCryptoFieldsKeysNotLoaded, EncryptionError, @@ -14,7 +14,6 @@ ) from ..field_cryptor import FieldCryptor from ..keys import encryption_keys -from ..utils import safe_encode_utf8 if TYPE_CHECKING: from ..keys import Keys @@ -26,20 +25,18 @@ class BaseField(models.Field): description = "Field class that stores values as encrypted" def __init__(self, algorithm: str, access_mode: str, *args, **kwargs): + self._field_cryptor = None + self._keys = None self.readonly = False - self.keys: Keys = encryption_keys - if not encryption_keys.loaded: - raise DjangoCryptoFieldsKeysNotLoaded( - "Encryption keys not loaded. You need to run initialize()" - ) self.algorithm = algorithm or RSA self.mode = access_mode or LOCAL_MODE + self.help_text: str = kwargs.get("help_text", "") if not self.help_text.startswith(" (Encryption:"): self.help_text = "{} (Encryption: {} {})".format( self.help_text.split(" (Encryption:")[0], algorithm.upper(), self.mode ) - self.field_cryptor = FieldCryptor(self.algorithm, self.mode) + min_length: int = len(HASH_PREFIX) + self.field_cryptor.hash_size max_length: int = kwargs.get("max_length", min_length) self.max_length: int = min_length if max_length < min_length else max_length @@ -52,11 +49,29 @@ def __init__(self, algorithm: str, access_mode: str, *args, **kwargs): self.__class__.__name__, max_message_length, self.max_length ) ) + kwargs["max_length"] = self.max_length kwargs["help_text"] = self.help_text kwargs.setdefault("blank", True) + super().__init__(*args, **kwargs) + @property + def keys(self) -> Keys: + if not self._keys: + if not encryption_keys.loaded: + raise DjangoCryptoFieldsKeysNotLoaded( + "Encryption keys not loaded. You need to run initialize()" + ) + self._keys = encryption_keys + return self._keys + + @property + def field_cryptor(self) -> FieldCryptor: + if not self._field_cryptor: + self._field_cryptor = FieldCryptor(self.algorithm, self.mode) + return self._field_cryptor + def get_internal_type(self): """This is a `CharField` as we only ever store the hash_prefix + hash, which is a fixed length char. @@ -70,30 +85,26 @@ def deconstruct(self): return name, path, args, kwargs def formfield(self, **kwargs): - defaults = kwargs - try: - show_encrypted_values = settings.SHOW_CRYPTO_FORM_DATA - except AttributeError: - show_encrypted_values = True - if not show_encrypted_values: - defaults = {"disabled": True, "widget": widgets.PasswordInput} - defaults.update(kwargs) - return super(BaseField, self).formfield(**defaults) - - def from_db_value(self, value: bytes | None, *args) -> bytes | str | None: + if not getattr(settings, "SHOW_CRYPTO_FORM_DATA", True): + kwargs.update({"disabled": True, "widget": widgets.PasswordInput}) + return super().formfield(**kwargs) + + def from_db_value(self, value: str | None, *args) -> str | None: """Returns the decrypted value, an empty string, or None.""" - value = safe_encode_utf8(value) + value = value.encode() if value is not None else value if value == b"": return "" - return self.field_cryptor.decrypt(value) if value else None + return self.field_cryptor.decrypt(value) if value is not None else None - def get_prep_value(self, value): + def get_prep_value(self, value: str | None) -> str | None: """Returns prefix + hash_value, an empty string, or None - for use as a parameter in a query. + for use as a parameter in a query or for saving into + the database. Note: partial matches do not work. See get_prep_lookup(). """ - return self.field_cryptor.get_prep_value(value) + value = self.field_cryptor.get_prep_value(value) + return super().get_prep_value(value) def get_prep_lookup(self, lookup_type, value): """Convert the value to a hash with prefix and pass to super. @@ -111,7 +122,7 @@ def get_prep_lookup(self, lookup_type, value): elif lookup_type == "in": value = self.get_in_as_lookup(value) else: - value = HASH_PREFIX.encode(ENCODING) + self.field_cryptor.hash(value) + value = HASH_PREFIX.encode() + self.field_cryptor.hash(value) return super().get_prep_lookup(lookup_type, value) @staticmethod @@ -129,7 +140,7 @@ def get_isnull_as_lookup(self, value): def get_in_as_lookup(self, values): hashed_values = [] for value in values: - hashed_values.append(HASH_PREFIX.encode(ENCODING) + self.field_cryptor.hash(value)) + hashed_values.append(HASH_PREFIX.encode() + self.field_cryptor.hash(value)) return hashed_values def mask(self, value, mask=None): diff --git a/django_crypto_fields/fields/encrypted_date_field.py b/django_crypto_fields/fields/encrypted_date_field.py new file mode 100644 index 0000000..60350c0 --- /dev/null +++ b/django_crypto_fields/fields/encrypted_date_field.py @@ -0,0 +1,98 @@ +from datetime import date, datetime + +from django import forms +from django.core.exceptions import ValidationError +from django.db.models.fields import DateTimeCheckMixin, _to_naive +from django.utils.dateparse import parse_date +from django.utils.translation import gettext as _ + +from ..encoding import DATE_STRING +from .base_rsa_field import BaseRsaField + +__all__ = ["EncryptedDateField"] + + +class EncryptedDateField(DateTimeCheckMixin, BaseRsaField): + description = "local-rsa encrypted field for 'DateField'" + default_error_messages = { + "invalid": _( + "“%(value)s” value has an invalid date format. It must be " "in YYYY-MM-DD format." + ), + "invalid_date": _( + "“%(value)s” value has the correct format (YYYY-MM-DD) " + "but it is an invalid date." + ), + } + + def __init__(self, auto_now=False, auto_now_add=False, **kwargs): + self.auto_now, self.auto_now_add = auto_now, auto_now_add + if auto_now or auto_now_add: + kwargs["editable"] = False + kwargs["blank"] = True + super().__init__(**kwargs) + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if self.auto_now: + kwargs["auto_now"] = True + if self.auto_now_add: + kwargs["auto_now_add"] = True + if self.auto_now or self.auto_now_add: + del kwargs["editable"] + del kwargs["blank"] + return name, path, args, kwargs + + def _check_fix_default_value(self): + """ + Warn that using an actual date or datetime value is probably wrong; + it's only evaluated on server startup. + """ + if not self.has_default(): + return [] + + value = self.default + if isinstance(value, datetime): + value = _to_naive(value).date() + elif isinstance(value, date): + pass + else: + return [] + return self._check_if_value_fixed(value) + + def pre_save(self, model_instance, add): + if self.auto_now or (self.auto_now_add and add): + value = date.today() + setattr(model_instance, self.attname, value) + return value + else: + return super().pre_save(model_instance, add) + + def get_prep_value(self, value: date | None) -> str | None: + if value: + value = datetime.strftime(value, DATE_STRING) + return super().get_prep_value(value) + + def to_python(self, value: str | date | None) -> date | None: + if value is None: + return value + if type(value) is date: + return value + try: + parsed = parse_date(value) + if parsed is not None: + return parsed + except ValueError: + raise ValidationError( + self.error_messages["invalid_date"], + code="invalid_date", + params={"value": value}, + ) + raise ValidationError( + self.error_messages["invalid"], + code="invalid", + params={"value": value}, + ) + + def formfield(self, **kwargs): + kwargs.update(form_class=forms.DateField) + return super().formfield(**kwargs) diff --git a/django_crypto_fields/fields/encrypted_datetime_field.py b/django_crypto_fields/fields/encrypted_datetime_field.py new file mode 100644 index 0000000..c4042f7 --- /dev/null +++ b/django_crypto_fields/fields/encrypted_datetime_field.py @@ -0,0 +1,99 @@ +from datetime import date, datetime + +from django import forms +from django.core.exceptions import ValidationError +from django.db.models.fields import DateTimeCheckMixin +from django.utils import timezone +from django.utils.dateparse import parse_date +from django.utils.translation import gettext as _ + +from ..encoding import DATETIME_STRING +from .base_rsa_field import BaseRsaField + +__all__ = ["EncryptedDateTimeField"] + + +class EncryptedDateTimeField(DateTimeCheckMixin, BaseRsaField): + description = "local-rsa encrypted field for 'DateTimeField'" + default_error_messages = { + "invalid": _( + "“%(value)s” value has an invalid format. It must be in " + "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format." + ), + "invalid_date": _( + "“%(value)s” value has the correct format " + "(YYYY-MM-DD) but it is an invalid date." + ), + "invalid_datetime": _( + "“%(value)s” value has the correct format " + "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) " + "but it is an invalid date/time." + ), + } + + def __init__(self, auto_now=False, auto_now_add=False, **kwargs): + self.auto_now, self.auto_now_add = auto_now, auto_now_add + if auto_now or auto_now_add: + kwargs["editable"] = False + kwargs["blank"] = True + super().__init__(**kwargs) + + def _check_fix_default_value(self): + """ + Warn that using an actual date or datetime value is probably wrong; + it's only evaluated on server startup. + """ + if not self.has_default(): + return [] + + value = self.default + if isinstance(value, (datetime, date)): + return self._check_if_value_fixed(value) + return [] + + def from_db_value(self, value: str | None, *args) -> datetime | None: + """Returns the decrypted value, an empty string, or None.""" + if value is None: + return None + date_string = self.field_cryptor.decrypt(value.encode()) + if not date_string: + return None + return datetime.strptime(date_string, DATETIME_STRING) + + def get_prep_value(self, value: date | None) -> str | None: + if value: + value = datetime.strftime(value, DATETIME_STRING) + return super().get_prep_value(value) + + def to_python(self, value: str | datetime | None) -> date | None: + if value is None: + return value + if type(value) is datetime: + return value + try: + parsed = parse_date(value) + if parsed is not None: + return parsed + except ValueError: + raise ValidationError( + self.error_messages["invalid_date"], + code="invalid_date", + params={"value": value}, + ) + raise ValidationError( + self.error_messages["invalid"], + code="invalid", + params={"value": value}, + ) + + def pre_save(self, model_instance, add): + if self.auto_now or (self.auto_now_add and add): + value = timezone.now() + setattr(model_instance, self.attname, value) + return value + else: + return super().pre_save(model_instance, add) + + def formfield(self, **kwargs): + kwargs.update(form_class=forms.SplitDateTimeField) + return super().formfield(**kwargs) diff --git a/django_crypto_fields/fields/encrypted_decimal_field.py b/django_crypto_fields/fields/encrypted_decimal_field.py index b7f1c54..1b30d96 100644 --- a/django_crypto_fields/fields/encrypted_decimal_field.py +++ b/django_crypto_fields/fields/encrypted_decimal_field.py @@ -1,4 +1,6 @@ -from decimal import Decimal +from decimal import Decimal, InvalidOperation + +from django.core.exceptions import ValidationError from .base_rsa_field import BaseRsaField @@ -8,56 +10,33 @@ class EncryptedDecimalField(BaseRsaField): description = "local-rsa encrypted field for 'IntegerField'" - def __init__(self, *args, **kwargs): - self.validate_max_digits(kwargs) - self.validate_decimal_places(kwargs) - decimal_decimal_places = int(kwargs.get("decimal_places")) - decimal_max_digits = int(kwargs.get("max_digits")) - del kwargs["decimal_places"] - del kwargs["max_digits"] + def __init__(self, *args, max_digits=None, decimal_places=None, **kwargs): + self.decimal_places = int(decimal_places or 2) + self.max_digits = int(max_digits or 8) super().__init__(*args, **kwargs) - self.decimal_decimal_places = decimal_decimal_places - self.decimal_max_digits = decimal_max_digits - - def to_string(self, value): - if isinstance(value, (str,)): - raise TypeError("Expected basestring. Got {0}".format(value)) - return str(value) - - def to_python(self, value): - """Returns as integer""" - retval = super(EncryptedDecimalField, self).to_python(value) - if retval: - if not self.field_cryptor.is_encrypted(retval): - retval = Decimal(retval).to_eng_string() - return retval - - @staticmethod - def validate_max_digits(kwargs): - if "max_digits" not in kwargs: - raise AttributeError( - "EncryptedDecimalField requires attribute 'max_digits. " "Got none" - ) - elif "max_digits" in kwargs: - try: - int(kwargs.get("max_digits")) - except (TypeError, ValueError): - raise ValueError( - f"EncryptedDecimalField attribute 'max_digits must be an " - f'integer. Got {kwargs.get("max_digits")}' - ) - - @staticmethod - def validate_decimal_places(kwargs): - if "decimal_places" not in kwargs: - raise AttributeError( - "EncryptedDecimalField requires attribute 'decimal_places. " "Got none" + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + kwargs["decimal_places"] = self.decimal_places + kwargs["max_digits"] = self.max_digits + return name, path, args, kwargs + + def get_prep_value(self, value: Decimal | None) -> str | None: + if value is not None: + value = str(value) + return super().get_prep_value(value) + + def to_python(self, value: str | Decimal | None) -> Decimal | None: + if value is None: + return value + if isinstance(value, Decimal): + return value + try: + value = Decimal(value) + except InvalidOperation: + raise ValidationError( + "Invalid value. Expected a decimal", + code="invalid", + params={"value": value}, ) - elif "decimal_places" in kwargs: - try: - int(kwargs.get("decimal_places")) - except (TypeError, ValueError): - raise ValueError( - f"EncryptedDecimalField attribute 'decimal_places must be an " - f'integer. Got {kwargs.get("decimal_places")}' - ) + return value diff --git a/django_crypto_fields/fields/encrypted_integer_field.py b/django_crypto_fields/fields/encrypted_integer_field.py index 3a5f8bd..8d66498 100644 --- a/django_crypto_fields/fields/encrypted_integer_field.py +++ b/django_crypto_fields/fields/encrypted_integer_field.py @@ -1,3 +1,5 @@ +from django.core.exceptions import ValidationError + from .base_rsa_field import BaseRsaField __all__ = ["EncryptedIntegerField"] @@ -6,8 +8,22 @@ class EncryptedIntegerField(BaseRsaField): description = "local-rsa encrypted field for 'IntegerField'" - def to_python(self, value): - """Returns as integer""" - retval = super().to_python(value) - retval = int(retval) - return retval + def get_prep_value(self, value: int | None) -> str | None: + if value is not None: + value = str(value) + return super().get_prep_value(value) + + def to_python(self, value: str | int | None) -> int | None: + if value is None: + return value + if isinstance(value, int): + return value + try: + value = int(value) + except ValueError: + raise ValidationError( + "Invalid value. Expected a whole number (integer)", + code="invalid", + params={"value": value}, + ) + return value diff --git a/django_crypto_fields/tests/crypto_keys/django_crypto_fields b/django_crypto_fields/tests/crypto_keys/django_crypto_fields deleted file mode 100644 index d33aac3..0000000 --- a/django_crypto_fields/tests/crypto_keys/django_crypto_fields +++ /dev/null @@ -1,2 +0,0 @@ -path,date -/Users/erikvw/source/edc_source/django-crypto-fields/django_crypto_fields/tests/crypto_keys,2024-03-25 06:53:16.697430+00:00 diff --git a/django_crypto_fields/tests/tests/test_cryptor.py b/django_crypto_fields/tests/tests/test_cryptor.py index 80cfce1..f06695e 100644 --- a/django_crypto_fields/tests/tests/test_cryptor.py +++ b/django_crypto_fields/tests/tests/test_cryptor.py @@ -1,10 +1,13 @@ -from datetime import datetime +from datetime import date, datetime -from django.test import TestCase +from django.test import TestCase, tag from django_crypto_fields.constants import AES, LOCAL_MODE, RESTRICTED_MODE, RSA from django_crypto_fields.cryptor import Cryptor -from django_crypto_fields.exceptions import EncryptionError +from django_crypto_fields.exceptions import ( + DjangoCryptoFieldsEncodingError, + EncryptionError, +) from django_crypto_fields.keys import encryption_keys @@ -46,30 +49,31 @@ def test_encrypt_rsa_length(self): cryptor.encrypt(plaintext) self.assertRaises(EncryptionError, cryptor.encrypt, plaintext + "a") + @tag("1") def test_rsa_encoding(self): """Assert successful RSA roundtrip of byte return str.""" cryptor = Cryptor(algorithm=RSA, access_mode=LOCAL_MODE) - plaintext = "erik is a pleeb!!∂ƒ˜∫˙ç".encode("utf-8") + plaintext = "erik is a pleeb!!∂ƒ˜∫˙ç" ciphertext = cryptor.encrypt(plaintext) t2 = type(cryptor.decrypt(ciphertext)) self.assertTrue(type(t2), "str") def test_rsa_type(self): - """Assert fails for anything but str and byte.""" cryptor = Cryptor(algorithm=RSA, access_mode=LOCAL_MODE) - plaintext = 1 - self.assertRaises(EncryptionError, cryptor.encrypt, plaintext) - plaintext = 1.0 - self.assertRaises(EncryptionError, cryptor.encrypt, plaintext) - plaintext = datetime.today() - self.assertRaises(EncryptionError, cryptor.encrypt, plaintext) + for value in ["", 1, 1.0, date.today(), datetime.today()]: + with self.subTest(value=value): + try: + cryptor.encrypt(value) + except EncryptionError as e: + self.fail(e) + @tag("1") def test_no_re_encrypt(self): """Assert raise error if attempting to encrypt a cipher.""" cryptor = Cryptor(algorithm=RSA, access_mode=LOCAL_MODE) plaintext = "erik is a pleeb!!" ciphertext1 = cryptor.encrypt(plaintext) - self.assertRaises(EncryptionError, cryptor.encrypt, ciphertext1) + self.assertRaises(DjangoCryptoFieldsEncodingError, cryptor.encrypt, ciphertext1) def test_rsa_roundtrip(self): plaintext = ( diff --git a/django_crypto_fields/tests/tests/test_field_cryptor.py b/django_crypto_fields/tests/tests/test_field_cryptor.py index 8814967..c4a5977 100644 --- a/django_crypto_fields/tests/tests/test_field_cryptor.py +++ b/django_crypto_fields/tests/tests/test_field_cryptor.py @@ -1,19 +1,21 @@ +from datetime import date + from django.db import transaction from django.db.utils import IntegrityError from django.test import TestCase, tag from django_crypto_fields.cipher import CipherParser -from django_crypto_fields.constants import AES, ENCODING, HASH_PREFIX, LOCAL_MODE, RSA +from django_crypto_fields.constants import AES, HASH_PREFIX, LOCAL_MODE, RSA from django_crypto_fields.cryptor import Cryptor -from django_crypto_fields.exceptions import MalformedCiphertextError +from django_crypto_fields.exceptions import ( + DjangoCryptoFieldsError, + MalformedCiphertextError, +) from django_crypto_fields.field_cryptor import FieldCryptor from django_crypto_fields.keys import encryption_keys -from django_crypto_fields.utils import ( - get_crypt_model_cls, - has_valid_hash_or_raise, - safe_encode_utf8, -) +from django_crypto_fields.utils import get_crypt_model_cls +from ...encoding import safe_encode from ..models import TestModel @@ -26,48 +28,6 @@ def setUp(self): def tearDown(self): encryption_keys.reset_and_delete_keys(verbose=False) - def test_can_verify_hash_as_none(self): - field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - value = None - self.assertRaises(TypeError, has_valid_hash_or_raise, value, field_cryptor.hash_size) - value = "" - self.assertRaises( - MalformedCiphertextError, has_valid_hash_or_raise, value, field_cryptor.hash_size - ) - value = b"" - self.assertRaises( - MalformedCiphertextError, has_valid_hash_or_raise, value, field_cryptor.hash_size - ) - - def test_can_verify_hash_not_raises(self): - """Assert does NOT raise on valid hash.""" - field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - value = HASH_PREFIX.encode(ENCODING) + field_cryptor.hash( - "Mohammed Ali floats like a butterfly" - ) - try: - has_valid_hash_or_raise(value, field_cryptor.hash_size) - except MalformedCiphertextError: - self.fail("MalformedCiphertextError unexpectedly raised") - else: - pass - - def test_can_verify_hash_raises(self): - """Assert does raise on invalid hash.""" - field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - value = "erik" # missing prefix - self.assertRaises( - MalformedCiphertextError, has_valid_hash_or_raise, value, field_cryptor.hash_size - ) - value = HASH_PREFIX + "blah" # incorrect prefix - self.assertRaises( - MalformedCiphertextError, has_valid_hash_or_raise, value, field_cryptor.hash_size - ) - value = HASH_PREFIX # no hash following prefix - self.assertRaises( - MalformedCiphertextError, has_valid_hash_or_raise, value, field_cryptor.hash_size - ) - def test_verify_hashed_value(self): field_cryptor = FieldCryptor(RSA, LOCAL_MODE) value = field_cryptor.encrypt("Mohammed Ali floats like a butterfly") @@ -79,7 +39,7 @@ def test_verify_hashed_value(self): def test_verify_is_encrypted(self): field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - value = HASH_PREFIX.encode(ENCODING) + field_cryptor.hash( + value = HASH_PREFIX.encode() + field_cryptor.hash( "Mohammed Ali floats like a butterfly" ) self.assertTrue(field_cryptor.is_encrypted(value)) @@ -108,52 +68,47 @@ def test_rsa_field_encryption(self): def test_rsa_field_encryption_update_secret(self): """Assert successful RSA field roundtrip for same value.""" value = "erik is a pleeb!!∂ƒ˜∫˙ç" - encoded_value = safe_encode_utf8(value) for mode in encryption_keys.get(RSA): field_cryptor = FieldCryptor(RSA, mode) - cipher1 = field_cryptor.encrypt(encoded_value) - self.assertEqual(encoded_value.decode(), field_cryptor.decrypt(cipher1)) - cipher2 = field_cryptor.encrypt(encoded_value) - self.assertEqual(encoded_value.decode(), field_cryptor.decrypt(cipher2)) + cipher1 = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(cipher1)) + cipher2 = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(cipher2)) self.assertFalse(cipher1 == cipher2) def test_aes_field_encryption(self): """Assert successful RSA field roundtrip.""" value = "erik is a pleeb!!" - encoded_value = safe_encode_utf8(value) for mode in encryption_keys.get(AES): field_cryptor = FieldCryptor(AES, mode) - ciphertext = field_cryptor.encrypt(encoded_value) - self.assertEqual(encoded_value.decode(), field_cryptor.decrypt(ciphertext)) + ciphertext = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(ciphertext)) def test_rsa_field_encryption_encoded(self): """Assert successful RSA field roundtrip.""" value = "erik is a pleeb!!∂ƒ˜∫˙ç" - encoded_value = safe_encode_utf8(value) for mode in encryption_keys.get(RSA): field_cryptor = FieldCryptor(RSA, mode) - ciphertext = field_cryptor.encrypt(encoded_value) - self.assertEqual(encoded_value.decode(), field_cryptor.decrypt(ciphertext)) + ciphertext = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(ciphertext)) def test_aes_field_encryption_encoded(self): """Assert successful AES field roundtrip.""" value = "erik is a pleeb!!∂ƒ˜∫˙ç" - encoded_value = safe_encode_utf8(value) for mode in encryption_keys.get(AES): field_cryptor = FieldCryptor(AES, mode) - ciphertext = field_cryptor.encrypt(encoded_value) - self.assertEqual(encoded_value.decode(), field_cryptor.decrypt(ciphertext)) + ciphertext = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(ciphertext)) def test_aes_field_encryption_update_secret(self): """Assert successful AES field roundtrip for same value.""" value = "erik is a pleeb!!∂ƒ˜∫˙ç" - encoded_value = safe_encode_utf8(value) for mode in encryption_keys.get(AES): field_cryptor = FieldCryptor(AES, mode) - ciphertext1 = field_cryptor.encrypt(encoded_value) - self.assertEqual(encoded_value.decode(), field_cryptor.decrypt(ciphertext1)) - ciphertext2 = field_cryptor.encrypt(encoded_value) - self.assertEqual(encoded_value.decode(), field_cryptor.decrypt(ciphertext2)) + ciphertext1 = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(ciphertext1)) + ciphertext2 = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(ciphertext2)) self.assertFalse(ciphertext1 == ciphertext2) def test_rsa_update_crypt_model(self): @@ -161,13 +116,12 @@ def test_rsa_update_crypt_model(self): retrieved by hash, and decrypted. """ value = "erik is a pleeb!!∂ƒ˜∫˙ç" - encoded_value = safe_encode_utf8(value) cryptor = Cryptor(algorithm=RSA, access_mode=LOCAL_MODE) field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - hashed_value = field_cryptor.hash(encoded_value) - field_cryptor.encrypt(encoded_value, update=True) + hashed_value = field_cryptor.hash(value) + field_cryptor.encrypt(value, update=True) secret = get_crypt_model_cls().objects.get(hash=hashed_value.decode()).secret - field_cryptor.fetch_secret(HASH_PREFIX.encode(ENCODING) + hashed_value) + field_cryptor.fetch_secret(HASH_PREFIX.encode() + hashed_value) self.assertEqual(value, cryptor.decrypt(secret)) def test_aes_update_crypt_model(self): @@ -175,13 +129,12 @@ def test_aes_update_crypt_model(self): retrieved by hash, and decrypted. """ value = "erik is a pleeb!!∂ƒ˜∫˙ç" - encoded_value = safe_encode_utf8(value) field_cryptor = FieldCryptor(AES, LOCAL_MODE) - field_cryptor.encrypt(encoded_value, update=True) - hashed_value = field_cryptor.hash(encoded_value) + field_cryptor.encrypt(value, update=True) + hashed_value = field_cryptor.hash(value) secret = get_crypt_model_cls().objects.get(hash=hashed_value.decode()).secret - field_cryptor.fetch_secret(HASH_PREFIX.encode(ENCODING) + hashed_value) - self.assertEqual(encoded_value.decode(), field_cryptor.cryptor.decrypt(secret)) + field_cryptor.fetch_secret(HASH_PREFIX.encode() + hashed_value) + self.assertEqual(value, field_cryptor.cryptor.decrypt(secret)) def test_none_value_is_not_added_to_crypt_model(self): self.assertEqual(get_crypt_model_cls().objects.all().count(), 0) @@ -193,27 +146,26 @@ def test_none_value_is_not_added_to_crypt_model(self): self.assertIsNone(p.hash_prefix) self.assertEqual(get_crypt_model_cls().objects.all().count(), 0) - def test_empty_value_is_not_added_to_crypt_model(self): + def test_empty_value_is_added_once_to_crypt_model(self): self.assertEqual(get_crypt_model_cls().objects.all().count(), 0) field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - value = b"" + value = "" cipher = field_cryptor.encrypt(value, update=True) p = CipherParser(cipher) - self.assertIsNone(p.secret) - self.assertIsNone(p.hash_prefix) - self.assertEqual(get_crypt_model_cls().objects.all().count(), 0) + self.assertIsNotNone(p.secret) + self.assertIsNotNone(p.hash_prefix) + self.assertEqual(get_crypt_model_cls().objects.all().count(), 1) + field_cryptor.encrypt(value, update=True) + self.assertEqual(get_crypt_model_cls().objects.all().count(), 1) def test_get_secret(self): self.assertEqual(get_crypt_model_cls().objects.all().count(), 0) field_cryptor = FieldCryptor(RSA, LOCAL_MODE) value = "erik is a pleeb!!∂ƒ˜∫˙ç" - encoded_value = safe_encode_utf8(value) - cipher = field_cryptor.encrypt(encoded_value, update=True) + cipher = field_cryptor.encrypt(value, update=True) p = CipherParser(cipher) self.assertIsNotNone(p.secret) - self.assertEqual( - encoded_value.decode(), field_cryptor.decrypt(p.hash_prefix + p.hashed_value) - ) + self.assertEqual(value, field_cryptor.decrypt(p.hash_prefix + p.hashed_value)) self.assertEqual(get_crypt_model_cls().objects.all().count(), 1) def test_rsa_field_as_none_raises(self): @@ -221,28 +173,50 @@ def test_rsa_field_as_none_raises(self): field_cryptor = FieldCryptor(RSA, LOCAL_MODE) value = None cipher = field_cryptor.encrypt(value) - self.assertRaises(TypeError, field_cryptor.decrypt, cipher) + self.assertRaises(DjangoCryptoFieldsError, field_cryptor.decrypt, cipher) def test_aes_field_as_none_raises(self): """Asserts AES cannot roundtrip on None.""" field_cryptor = FieldCryptor(AES, LOCAL_MODE) value = None cipher = field_cryptor.encrypt(value) - self.assertRaises(TypeError, field_cryptor.decrypt, cipher) + self.assertRaises(DjangoCryptoFieldsError, field_cryptor.decrypt, cipher) - def test_rsa_field_as_empty(self): - """Asserts RSA cannot roundtrip on None.""" + def test_rsa_field_with_empty_string(self): field_cryptor = FieldCryptor(RSA, LOCAL_MODE) - value = b"" + value = "" cipher = field_cryptor.encrypt(value) - self.assertIsNone(field_cryptor.decrypt(cipher)) + self.assertEqual(value, field_cryptor.decrypt(cipher)) - def test_aes_field_as_empty(self): - """Asserts AES cannot roundtrip on None.""" + def test_aes_field_with_empty_string(self): + field_cryptor = FieldCryptor(AES, LOCAL_MODE) + value = "" + cipher = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(cipher)) + + def test_rsa_field_with_zero(self): + field_cryptor = FieldCryptor(RSA, LOCAL_MODE) + value = safe_encode(0).decode() + cipher = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(cipher)) + + def test_aes_field_with_zero(self): + field_cryptor = FieldCryptor(AES, LOCAL_MODE) + value = safe_encode(0).decode() + cipher = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(cipher)) + + def test_rsa_field_with_date(self): + field_cryptor = FieldCryptor(RSA, LOCAL_MODE) + value = safe_encode(date.today()).decode() + cipher = field_cryptor.encrypt(value) + self.assertEqual(value, field_cryptor.decrypt(cipher)) + + def test_aes_field_with_date(self): field_cryptor = FieldCryptor(AES, LOCAL_MODE) - value = b"" + value = safe_encode(date.today()).decode() cipher = field_cryptor.encrypt(value) - self.assertIsNone(field_cryptor.decrypt(cipher)) + self.assertEqual(value, field_cryptor.decrypt(cipher)) @tag("6") def test_model_with_encrypted_fields(self): @@ -260,11 +234,11 @@ def test_model_with_encrypted_fields(self): self.assertEqual(getattr(test_model, attr1), value1) self.assertEqual(get_crypt_model_cls().objects.all().count(), 3) - @tag("6") def test_model_with_encrypted_fields_empty_string(self): """Asserts roundtrip via a model with encrypted fields. - Note: comment is an empty string + Note: firstname is None and comment is an empty string + Expect identity and comment to be added to Crypt """ data = dict(firstname=None, identity="123456789", comment="") TestModel.objects.create(**data) @@ -273,7 +247,7 @@ def test_model_with_encrypted_fields_empty_string(self): test_model = TestModel.objects.get(**{attr: value}) for attr1, value1 in data.items(): self.assertEqual(getattr(test_model, attr1), value1) - self.assertEqual(get_crypt_model_cls().objects.all().count(), 1) + self.assertEqual(get_crypt_model_cls().objects.all().count(), 2) def test_model_with_encrypted_fields_as_none(self): """Asserts roundtrip via a model with encrypted fields. diff --git a/django_crypto_fields/utils.py b/django_crypto_fields/utils.py index 9ebf8dd..58da0ec 100644 --- a/django_crypto_fields/utils.py +++ b/django_crypto_fields/utils.py @@ -3,13 +3,16 @@ import binascii import hashlib import sys +from datetime import date, datetime +from decimal import Decimal from typing import TYPE_CHECKING, Type from django.apps import apps as django_apps from django.conf import settings -from .constants import CIPHER_PREFIX, ENCODING, HASH_ALGORITHM, HASH_PREFIX, HASH_ROUNDS -from .exceptions import EncryptionError, MalformedCiphertextError +from .constants import HASH_ALGORITHM, HASH_ROUNDS +from .encoding import safe_encode_date +from .exceptions import DjangoCryptoFieldsError, EncryptionError if TYPE_CHECKING: from django.db import models @@ -71,44 +74,21 @@ def get_key_prefix_from_settings() -> str: return getattr(settings, "DJANGO_CRYPTO_FIELDS_KEY_PREFIX", "user") -def safe_encode_utf8(value) -> bytes: - try: - value = value.encode(ENCODING) - except AttributeError: - pass - return value - - -def safe_decode(value) -> bytes: - try: - value.decode() - except AttributeError: - pass - return value - - -def has_valid_hash_or_raise(ciphertext: bytes, hash_size: int) -> bool: - """Verifies hash segment of ciphertext (bytes) and - raises an exception if not OK. - """ - ciphertext = safe_encode_utf8(ciphertext) - hash_value = ciphertext[len(safe_encode_utf8(HASH_PREFIX)) :].split( - safe_encode_utf8(CIPHER_PREFIX) - )[0] - if len(hash_value) != hash_size: - raise MalformedCiphertextError( - "Expected hash prefix to be followed by a hash. Got something else or nothing" - ) - return True - - -def make_hash(value, salt_key) -> bytes: +def make_hash( + value: str | date | datetime | int | float | Decimal, salt_key: bytes +) -> bytes | None: """Returns a hexified hash of a plaintext value (as bytes). The hashed value is used as a signature of the "secret". """ - encoded_value = safe_encode_utf8(value) - dk = hashlib.pbkdf2_hmac(HASH_ALGORITHM, encoded_value, salt_key, HASH_ROUNDS) + if value is None: + raise DjangoCryptoFieldsError("Cannot hash None value") + else: + if type(value) in [date, datetime]: + encoded_value = safe_encode_date(value) + else: + encoded_value = value.encode() + dk: bytes = hashlib.pbkdf2_hmac(HASH_ALGORITHM, encoded_value, salt_key, HASH_ROUNDS) return binascii.hexlify(dk)