From edb3b901a0efe04027e0b61ceb33dbd9a9219ddf Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Tue, 13 Jan 2026 21:39:11 +0000 Subject: [PATCH] fix(sql): apply converters to returning --- django/db/models/sql/compiler.py | 22 +++++++++++++++------- tests/queries/models.py | 29 +++++++++++++++++++++++++++++ tests/queries/test_db_returning.py | 30 ++++++++++++++++++++++++++++-- 3 files changed, 72 insertions(+), 9 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index f02199d97ca1..407616f6c837 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -6,7 +6,8 @@ from django.core.exceptions import EmptyResultSet, FieldError from django.db import DatabaseError, NotSupportedError from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value +from django.db.models.expressions import Col, F, OrderBy, RawSQL, Ref, Value +from django.db.models.fields import Field from django.db.models.functions import Cast, Random from django.db.models.query_utils import Q, select_related_descend from django.db.models.sql.constants import ( @@ -1101,6 +1102,8 @@ def get_converters(self, expressions): converters = {} for i, expression in enumerate(expressions): if expression: + if isinstance(expression, Field): + expression = Col(None, expression) backend_converters = self.connection.ops.get_db_converters(expression) field_converters = expression.get_db_converters(self.connection) if backend_converters or field_converters: @@ -1412,13 +1415,18 @@ def execute_sql(self, returning_fields=None): if not self.returning_fields: return [] if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1: - return self.connection.ops.fetch_returned_insert_rows(cursor) - if self.connection.features.can_return_columns_from_insert: + rows = list(self.connection.ops.fetch_returned_insert_rows(cursor)) + elif self.connection.features.can_return_columns_from_insert: assert len(self.query.objs) == 1 - return [self.connection.ops.fetch_returned_insert_columns(cursor, self.returning_params)] - return [(self.connection.ops.last_insert_id( - cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column - ),)] + rows = [self.connection.ops.fetch_returned_insert_columns(cursor, self.returning_params)] + else: + rows = [(self.connection.ops.last_insert_id( + cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column + ),)] + converters = self.get_converters(self.returning_fields) + if converters: + rows = [tuple(row) for row in self.apply_converters(rows, converters)] + return rows class SQLDeleteCompiler(SQLCompiler): diff --git a/tests/queries/models.py b/tests/queries/models.py index 383f633be9a7..017e417ac1f4 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -735,6 +735,31 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +class WrappedId(int): + pass + + +class WrappedBigAutoField(models.BigAutoField): + + def from_db_value(self, value, expression, connection): + if value is None: + return value + return WrappedId(value) + + def to_python(self, value): + if value is None or isinstance(value, WrappedId): + return value + value = super().to_python(value) + if value is None: + return value + return WrappedId(value) + + def get_prep_value(self, value): + if isinstance(value, WrappedId): + value = int(value) + return super().get_prep_value(value) + + class ReturningModel(models.Model): created = CreatedField(editable=False) @@ -743,6 +768,10 @@ class NonIntegerPKReturningModel(models.Model): created = CreatedField(editable=False, primary_key=True) +class ConvertedPKReturningModel(models.Model): + id = WrappedBigAutoField(primary_key=True) + + class JSONFieldNullable(models.Model): json_field = models.JSONField(blank=True, null=True) diff --git a/tests/queries/test_db_returning.py b/tests/queries/test_db_returning.py index 9ba352a7ab7f..4208f19d792f 100644 --- a/tests/queries/test_db_returning.py +++ b/tests/queries/test_db_returning.py @@ -1,10 +1,16 @@ import datetime from django.db import connection -from django.test import TestCase, skipUnlessDBFeature +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext -from .models import DumbCategory, NonIntegerPKReturningModel, ReturningModel +from .models import ( + ConvertedPKReturningModel, + DumbCategory, + NonIntegerPKReturningModel, + ReturningModel, + WrappedId, +) @skipUnlessDBFeature('can_return_columns_from_insert') @@ -41,6 +47,10 @@ def test_insert_returning_multiple(self): self.assertTrue(obj.pk) self.assertIsInstance(obj.created, datetime.datetime) + def test_insert_returning_runs_field_converters(self): + obj = ConvertedPKReturningModel.objects.create() + self.assertIsInstance(obj.id, WrappedId) + @skipUnlessDBFeature('can_return_rows_from_bulk_insert') def test_bulk_insert(self): objs = [ReturningModel(), ReturningModel(pk=2 ** 11), ReturningModel()] @@ -49,3 +59,19 @@ def test_bulk_insert(self): with self.subTest(obj=obj): self.assertTrue(obj.pk) self.assertIsInstance(obj.created, datetime.datetime) + + @skipUnlessDBFeature('can_return_rows_from_bulk_insert') + def test_bulk_insert_runs_field_converters(self): + objs = [ConvertedPKReturningModel(), ConvertedPKReturningModel()] + ConvertedPKReturningModel.objects.bulk_create(objs) + for obj in objs: + with self.subTest(obj=obj): + self.assertIsInstance(obj.id, WrappedId) + + +@skipIfDBFeature('can_return_columns_from_insert') +class ReturningLastInsertIdTests(TestCase): + + def test_insert_uses_field_converters(self): + obj = ConvertedPKReturningModel.objects.create() + self.assertIsInstance(obj.id, WrappedId)