From 2c0bd0fac6713bd9edafd60b4a12fbac13f3cb28 Mon Sep 17 00:00:00 2001 From: evanandrews-xrd <74883242+evanandrews-xrd@users.noreply.github.com> Date: Tue, 11 Jun 2024 21:00:28 +1000 Subject: [PATCH] Use `attname` to get pk value (#650) * Use `attname` to get pk value * Add tests for model as primary key * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add clarifying comments to _get_pk_value --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- auditlog/models.py | 7 ++++--- auditlog_tests/models.py | 21 +++++++++++++++++++++ auditlog_tests/tests.py | 40 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/auditlog/models.py b/auditlog/models.py index 43d2baa9..1ae8dbcc 100644 --- a/auditlog/models.py +++ b/auditlog/models.py @@ -213,12 +213,13 @@ def _get_pk_value(self, instance): :type instance: Model :return: The primary key value of the given model instance. """ - pk_field = instance._meta.pk.name + # Should be equivalent to `instance.pk`. + pk_field = instance._meta.pk.attname pk = getattr(instance, pk_field, None) # Check to make sure that we got a pk not a model object. - if isinstance(pk, models.Model): - pk = self._get_pk_value(pk) + # Should be guaranteed as we used `attname` above, not `name`. + assert not isinstance(pk, models.Model) return pk def _get_serialized_data_or_none(self, instance): diff --git a/auditlog_tests/models.py b/auditlog_tests/models.py index 2492c6bc..ffacc838 100644 --- a/auditlog_tests/models.py +++ b/auditlog_tests/models.py @@ -57,6 +57,26 @@ class UUIDPrimaryKeyModel(models.Model): history = AuditlogHistoryField(delete_related=True, pk_indexable=False) +class ModelPrimaryKeyModel(models.Model): + """ + A model with another model as primary key. + """ + + key = models.OneToOneField( + "SimpleModel", + primary_key=True, + on_delete=models.CASCADE, + related_name="reverse_primary_key", + ) + + text = models.TextField(blank=True) + boolean = models.BooleanField(default=False) + integer = models.IntegerField(blank=True, null=True) + datetime = models.DateTimeField(auto_now=True) + + history = AuditlogHistoryField(delete_related=True, pk_indexable=False) + + class ProxyModel(SimpleModel): """ A model that is a proxy for another model. @@ -338,6 +358,7 @@ class AutoManyRelatedModel(models.Model): auditlog.register(AltPrimaryKeyModel) auditlog.register(UUIDPrimaryKeyModel) +auditlog.register(ModelPrimaryKeyModel) auditlog.register(ProxyModel) auditlog.register(RelatedModel) auditlog.register(ManyRelatedModel) diff --git a/auditlog_tests/tests.py b/auditlog_tests/tests.py index 2e853d42..0ff960b4 100644 --- a/auditlog_tests/tests.py +++ b/auditlog_tests/tests.py @@ -20,7 +20,7 @@ from django.db.models import JSONField, Value from django.db.models.functions import Now from django.db.models.signals import pre_save -from django.test import RequestFactory, TestCase, override_settings +from django.test import RequestFactory, TestCase, TransactionTestCase, override_settings from django.urls import resolve, reverse from django.utils import dateformat, formats from django.utils import timezone as django_timezone @@ -46,6 +46,7 @@ JSONModel, ManyRelatedModel, ManyRelatedOtherModel, + ModelPrimaryKeyModel, NoDeleteHistoryModel, NullableJSONModel, PostgresArrayFieldModel, @@ -332,6 +333,41 @@ class UUIDPrimaryKeyModelModelWithActorTest( pass +class ModelPrimaryKeyModelBase(SimpleModelTest): + def make_object(self): + self.key = super().make_object() + return ModelPrimaryKeyModel.objects.create(key=self.key, text="I am strange.") + + +class ModelPrimaryKeyModelTest(NoActorMixin, ModelPrimaryKeyModelBase): + pass + + +class ModelPrimaryKeyModelWithActorTest(WithActorMixin, ModelPrimaryKeyModelBase): + pass + + +# Must inherit from TransactionTestCase to use self.assertNumQueries. +class ModelPrimaryKeyTest(TransactionTestCase): + def test_get_pk_value(self): + """ + Test that the primary key can be retrieved without additional database queries. + """ + key = SimpleModel.objects.create(text="I am not difficult.") + obj = ModelPrimaryKeyModel.objects.create(key=key, text="I am strange.") + # Refresh the object so the primary key object is not cached. + obj.refresh_from_db() + with self.assertNumQueries(0): + pk = LogEntry.objects._get_pk_value(obj) + self.assertEqual(pk, obj.pk) + self.assertEqual(pk, key.pk) + # Sanity check: verify accessing obj.key causes database access. + with self.assertNumQueries(1): + pk = obj.key.pk + self.assertEqual(pk, obj.pk) + self.assertEqual(pk, key.pk) + + class ProxyModelBase(SimpleModelTest): def make_object(self): return ProxyModel.objects.create(text="I am not what you think.") @@ -1206,7 +1242,7 @@ def test_register_models_register_app(self): self.assertTrue(self.test_auditlog.contains(SimpleExcludeModel)) self.assertTrue(self.test_auditlog.contains(ChoicesFieldModel)) - self.assertEqual(len(self.test_auditlog.get_models()), 26) + self.assertEqual(len(self.test_auditlog.get_models()), 27) def test_register_models_register_model_with_attrs(self): self.test_auditlog._register_models(