Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions netbox_custom_objects/migrations/0002_ensure_fk_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from django.db import migrations


def ensure_existing_fk_constraints(apps, schema_editor):
"""
Go through all existing CustomObjectType models and ensure FK constraints
are properly set for any OBJECT type fields.
"""
# Import the actual model class (not the historical version) to access methods
from netbox_custom_objects.models import CustomObjectType

for custom_object_type in CustomObjectType.objects.all():
try:
model = custom_object_type.get_model()
custom_object_type._ensure_all_fk_constraints(model)
except Exception as e:
print(f"Warning: Could not ensure FK constraints for {custom_object_type}: {e}")


class Migration(migrations.Migration):

dependencies = [
('netbox_custom_objects', '0001_initial'),
]

operations = [
migrations.RunPython(
ensure_existing_fk_constraints,
reverse_code=migrations.RunPython.noop
),
]
138 changes: 131 additions & 7 deletions netbox_custom_objects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,11 +446,12 @@ def get_model(
"""

# Double-check pattern: check cache again after acquiring lock
if self.is_model_cached(self.id) and not no_cache:
model = self.get_cached_model(self.id)
return model
with self._global_lock:
if self.is_model_cached(self.id) and not no_cache:
model = self.get_cached_model(self.id)
return model

# Generate the model inside the lock to prevent race conditions
# Generate the model outside the lock to avoid holding it during expensive operations
model_name = self.get_table_model_name(self.pk)

# TODO: Add other fields with "index" specified
Expand Down Expand Up @@ -523,8 +524,9 @@ def wrapped_post_through_setup(self, cls):

self._after_model_generation(attrs, model)

# Cache the generated model
self._model_cache[self.id] = model
# Cache the generated model (protected by lock for thread safety)
with self._global_lock:
self._model_cache[self.id] = model

# Do the clear cache now that we have it in the cache so there
# is no recursion.
Expand All @@ -538,11 +540,76 @@ def wrapped_post_through_setup(self, cls):

def get_model_with_serializer(self):
from netbox_custom_objects.api.serializers import get_serializer_class
model = self.get_model(no_cache=True)
model = self.get_model()
get_serializer_class(model)
self.register_custom_object_search_index(model)
return model

def _ensure_field_fk_constraint(self, model, field_name):
"""
Ensure that a foreign key constraint is properly created at the database level
for a specific OBJECT type field with ON DELETE CASCADE. This is necessary because
models are created with managed=False, which may not properly create FK constraints
with CASCADE behavior.

:param model: The model containing the field
:param field_name: The name of the field to ensure FK constraint for
"""
table_name = self.get_database_table_name()

# Get the model field
try:
model_field = model._meta.get_field(field_name)
except Exception:
return

if not (hasattr(model_field, 'remote_field') and model_field.remote_field):
return

# Get the referenced table
related_model = model_field.remote_field.model
related_table = related_model._meta.db_table
column_name = model_field.column

with connection.cursor() as cursor:
# Drop existing FK constraint if it exists
# Query for existing constraints
cursor.execute("""
SELECT constraint_name
FROM information_schema.table_constraints
WHERE table_name = %s
AND constraint_type = 'FOREIGN KEY'
AND constraint_name LIKE %s
""", [table_name, f"%{column_name}%"])

for row in cursor.fetchall():
constraint_name = row[0]
cursor.execute(f'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{constraint_name}"')

# Create new FK constraint with ON DELETE CASCADE
constraint_name = f"{table_name}_{column_name}_fk_cascade"
cursor.execute(f"""
ALTER TABLE "{table_name}"
ADD CONSTRAINT "{constraint_name}"
FOREIGN KEY ("{column_name}")
REFERENCES "{related_table}" ("id")
ON DELETE CASCADE
DEFERRABLE INITIALLY DEFERRED
""")

def _ensure_all_fk_constraints(self, model):
"""
Ensure that foreign key constraints are properly created at the database level
for ALL OBJECT type fields with ON DELETE CASCADE.

:param model: The model to ensure FK constraints for
"""
# Query all OBJECT type fields for this CustomObjectType
object_fields = self.fields.filter(type=CustomFieldTypeChoices.TYPE_OBJECT)

for field in object_fields:
self._ensure_field_fk_constraint(model, field.name)

def create_model(self):
from netbox_custom_objects.api.serializers import get_serializer_class
# Get the model and ensure it's registered
Expand Down Expand Up @@ -796,6 +863,8 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._name = self.__dict__.get("name")
self._original_name = self.name
self._original_type = self.type
self._original_related_object_type_id = self.related_object_type_id

def __str__(self):
return self.label or self.name.replace("_", " ").capitalize()
Expand Down Expand Up @@ -1482,11 +1551,35 @@ def save(self, *args, **kwargs):
# Normal field alteration
schema_editor.alter_field(model, old_field, model_field)

# Ensure FK constraints are properly created for OBJECT fields with CASCADE behavior
should_ensure_fk = False
if self.type == CustomFieldTypeChoices.TYPE_OBJECT:
if self._state.adding:
should_ensure_fk = True
else:
# Existing field - check if type changed to OBJECT or related_object_type changed
type_changed_to_object = (
self._original_type != CustomFieldTypeChoices.TYPE_OBJECT
and self.type == CustomFieldTypeChoices.TYPE_OBJECT
)
related_object_changed = (
self._original_type == CustomFieldTypeChoices.TYPE_OBJECT
and self.related_object_type_id != self._original_related_object_type_id
)
should_ensure_fk = type_changed_to_object or related_object_changed

# Clear and refresh the model cache for this CustomObjectType when a field is modified
self.custom_object_type.clear_model_cache(self.custom_object_type.id)

super().save(*args, **kwargs)

# Ensure FK constraints AFTER the transaction commits to avoid "pending trigger events" errors
if should_ensure_fk:
def ensure_constraint():
self.custom_object_type._ensure_field_fk_constraint(model, self.name)

transaction.on_commit(ensure_constraint)

# Reregister SearchIndex with new set of searchable fields
self.custom_object_type.register_custom_object_search_index(model)

Expand Down Expand Up @@ -1540,3 +1633,34 @@ class CustomObjectObjectType(ObjectType):

class Meta:
proxy = True


# Signal handlers to clear model cache when definitions change


@receiver(post_save, sender=CustomObjectType)
def clear_cache_on_custom_object_type_save(sender, instance, **kwargs):
"""
Clear the model cache when a CustomObjectType is saved.
"""
CustomObjectType.clear_model_cache(instance.id)


@receiver(post_save, sender=CustomObjectTypeField)
def clear_cache_on_field_save(sender, instance, **kwargs):
"""
Clear the model cache when a CustomObjectTypeField is saved.
This ensures the parent CustomObjectType's model is regenerated.
"""
if instance.custom_object_type_id:
CustomObjectType.clear_model_cache(instance.custom_object_type_id)


@receiver(pre_delete, sender=CustomObjectTypeField)
def clear_cache_on_field_delete(sender, instance, **kwargs):
"""
Clear the model cache when a CustomObjectTypeField is deleted.
This is in addition to the manual clear in the delete() method.
"""
if instance.custom_object_type_id:
CustomObjectType.clear_model_cache(instance.custom_object_type_id)
7 changes: 7 additions & 0 deletions netbox_custom_objects/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def setUp(self):
self.client = Client()
self.client.force_login(self.user)

def tearDown(self):
"""Clean up after each test."""
# Clear the model cache to ensure test isolation
# This prevents cached models with deleted fields from affecting other tests
CustomObjectType.clear_model_cache()
super().tearDown()

@classmethod
def create_custom_object_type(cls, **kwargs):
"""Helper method to create a custom object type."""
Expand Down