From d9de74141e8a920940f1b91ed0a3ccb835b55729 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Sat, 13 Feb 2021 08:58:24 +0000 Subject: [PATCH] Fixed #32442 -- Used converters on returning fields from INSERT statements. --- django/db/models/sql/compiler.py | 21 +++++++++++++++------ tests/custom_pk/fields.py | 14 +++++++++++++- tests/custom_pk/models.py | 8 ++++++-- tests/custom_pk/tests.py | 15 +++++++++++++-- 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index f02199d97c..11ad4fde90 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1405,6 +1405,7 @@ class SQLInsertCompiler(SQLCompiler): returning_fields and len(self.query.objs) != 1 and not self.connection.features.can_return_rows_from_bulk_insert ) + opts = self.query.get_meta() self.returning_fields = returning_fields with self.connection.cursor() as cursor: for sql, params in self.as_sql(): @@ -1412,13 +1413,21 @@ class SQLInsertCompiler(SQLCompiler): if not self.returning_fields: return [] if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1: - return self.connection.ops.fetch_returned_insert_rows(cursor) - if self.connection.features.can_return_columns_from_insert: + rows = self.connection.ops.fetch_returned_insert_rows(cursor) + elif self.connection.features.can_return_columns_from_insert: assert len(self.query.objs) == 1 - return [self.connection.ops.fetch_returned_insert_columns(cursor, self.returning_params)] - return [(self.connection.ops.last_insert_id( - cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column - ),)] + rows = [self.connection.ops.fetch_returned_insert_columns( + cursor, self.returning_params, + )] + else: + rows = [(self.connection.ops.last_insert_id( + cursor, opts.db_table, opts.pk.column, + ),)] + cols = [field.get_col(opts.db_table) for field in self.returning_fields] + converters = self.get_converters(cols) + if converters: + rows = list(self.apply_converters(rows, converters)) + return rows class SQLDeleteCompiler(SQLCompiler): diff --git a/tests/custom_pk/fields.py b/tests/custom_pk/fields.py index 5bd249df3c..bc7259300b 100644 --- a/tests/custom_pk/fields.py +++ b/tests/custom_pk/fields.py @@ -20,7 +20,7 @@ class MyWrapper: return self.value == other -class MyAutoField(models.CharField): +class MyWrapperField(models.CharField): def __init__(self, *args, **kwargs): kwargs['max_length'] = 10 @@ -58,3 +58,15 @@ class MyAutoField(models.CharField): if isinstance(value, MyWrapper): return str(value) return value + + +class MyAutoField(models.BigAutoField): + def from_db_value(self, value, expression, connection): + if value is None: + return None + return MyWrapper(value) + + def get_prep_value(self, value): + if value is None: + return None + return int(value) diff --git a/tests/custom_pk/models.py b/tests/custom_pk/models.py index edfc6712f3..d9a73885f2 100644 --- a/tests/custom_pk/models.py +++ b/tests/custom_pk/models.py @@ -7,7 +7,7 @@ this behavior by explicitly adding ``primary_key=True`` to a field. from django.db import models -from .fields import MyAutoField +from .fields import MyAutoField, MyWrapperField class Employee(models.Model): @@ -31,8 +31,12 @@ class Business(models.Model): class Bar(models.Model): - id = MyAutoField(primary_key=True, db_index=True) + id = MyWrapperField(primary_key=True, db_index=True) class Foo(models.Model): bar = models.ForeignKey(Bar, models.CASCADE) + + +class CustomAutoFieldModel(models.Model): + id = MyAutoField(primary_key=True) diff --git a/tests/custom_pk/tests.py b/tests/custom_pk/tests.py index abb4ccd90b..cbf1fd2cb6 100644 --- a/tests/custom_pk/tests.py +++ b/tests/custom_pk/tests.py @@ -1,7 +1,8 @@ from django.db import IntegrityError, transaction -from django.test import TestCase, skipIfDBFeature +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature -from .models import Bar, Business, Employee, Foo +from .fields import MyWrapper +from .models import Bar, Business, CustomAutoFieldModel, Employee, Foo class BasicCustomPKTests(TestCase): @@ -230,3 +231,13 @@ class CustomPKTests(TestCase): with self.assertRaises(IntegrityError): with transaction.atomic(): Employee.objects.create(first_name="Tom", last_name="Smith") + + def test_auto_field_subclass_create(self): + obj = CustomAutoFieldModel.objects.create() + self.assertIsInstance(obj.id, MyWrapper) + + @skipUnlessDBFeature('can_return_rows_from_bulk_insert') + def test_auto_field_subclass_bulk_create(self): + obj = CustomAutoFieldModel() + CustomAutoFieldModel.objects.bulk_create([obj]) + self.assertIsInstance(obj.id, MyWrapper)