Fixed #32442 -- Used converters on returning fields from INSERT statements.

This commit is contained in:
Adam Johnson 2021-02-13 08:58:24 +00:00 committed by Mariusz Felisiak
parent 619f26d289
commit d9de74141e
4 changed files with 47 additions and 11 deletions

View File

@ -1405,6 +1405,7 @@ class SQLInsertCompiler(SQLCompiler):
returning_fields and len(self.query.objs) != 1 and returning_fields and len(self.query.objs) != 1 and
not self.connection.features.can_return_rows_from_bulk_insert not self.connection.features.can_return_rows_from_bulk_insert
) )
opts = self.query.get_meta()
self.returning_fields = returning_fields self.returning_fields = returning_fields
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
for sql, params in self.as_sql(): for sql, params in self.as_sql():
@ -1412,13 +1413,21 @@ class SQLInsertCompiler(SQLCompiler):
if not self.returning_fields: if not self.returning_fields:
return [] return []
if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1: if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
return self.connection.ops.fetch_returned_insert_rows(cursor) rows = self.connection.ops.fetch_returned_insert_rows(cursor)
if self.connection.features.can_return_columns_from_insert: elif self.connection.features.can_return_columns_from_insert:
assert len(self.query.objs) == 1 assert len(self.query.objs) == 1
return [self.connection.ops.fetch_returned_insert_columns(cursor, self.returning_params)] rows = [self.connection.ops.fetch_returned_insert_columns(
return [(self.connection.ops.last_insert_id( cursor, self.returning_params,
cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column )]
else:
rows = [(self.connection.ops.last_insert_id(
cursor, opts.db_table, opts.pk.column,
),)] ),)]
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
converters = self.get_converters(cols)
if converters:
rows = list(self.apply_converters(rows, converters))
return rows
class SQLDeleteCompiler(SQLCompiler): class SQLDeleteCompiler(SQLCompiler):

View File

@ -20,7 +20,7 @@ class MyWrapper:
return self.value == other return self.value == other
class MyAutoField(models.CharField): class MyWrapperField(models.CharField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['max_length'] = 10 kwargs['max_length'] = 10
@ -58,3 +58,15 @@ class MyAutoField(models.CharField):
if isinstance(value, MyWrapper): if isinstance(value, MyWrapper):
return str(value) return str(value)
return value return value
class MyAutoField(models.BigAutoField):
def from_db_value(self, value, expression, connection):
if value is None:
return None
return MyWrapper(value)
def get_prep_value(self, value):
if value is None:
return None
return int(value)

View File

@ -7,7 +7,7 @@ this behavior by explicitly adding ``primary_key=True`` to a field.
from django.db import models from django.db import models
from .fields import MyAutoField from .fields import MyAutoField, MyWrapperField
class Employee(models.Model): class Employee(models.Model):
@ -31,8 +31,12 @@ class Business(models.Model):
class Bar(models.Model): class Bar(models.Model):
id = MyAutoField(primary_key=True, db_index=True) id = MyWrapperField(primary_key=True, db_index=True)
class Foo(models.Model): class Foo(models.Model):
bar = models.ForeignKey(Bar, models.CASCADE) bar = models.ForeignKey(Bar, models.CASCADE)
class CustomAutoFieldModel(models.Model):
id = MyAutoField(primary_key=True)

View File

@ -1,7 +1,8 @@
from django.db import IntegrityError, transaction from django.db import IntegrityError, transaction
from django.test import TestCase, skipIfDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from .models import Bar, Business, Employee, Foo from .fields import MyWrapper
from .models import Bar, Business, CustomAutoFieldModel, Employee, Foo
class BasicCustomPKTests(TestCase): class BasicCustomPKTests(TestCase):
@ -230,3 +231,13 @@ class CustomPKTests(TestCase):
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
with transaction.atomic(): with transaction.atomic():
Employee.objects.create(first_name="Tom", last_name="Smith") Employee.objects.create(first_name="Tom", last_name="Smith")
def test_auto_field_subclass_create(self):
obj = CustomAutoFieldModel.objects.create()
self.assertIsInstance(obj.id, MyWrapper)
@skipUnlessDBFeature('can_return_rows_from_bulk_insert')
def test_auto_field_subclass_bulk_create(self):
obj = CustomAutoFieldModel()
CustomAutoFieldModel.objects.bulk_create([obj])
self.assertIsInstance(obj.id, MyWrapper)