mirror of https://github.com/django/django.git
Fixed #32442 -- Used converters on returning fields from INSERT statements.
This commit is contained in:
parent
619f26d289
commit
d9de74141e
|
@ -1405,6 +1405,7 @@ class SQLInsertCompiler(SQLCompiler):
|
||||||
returning_fields and len(self.query.objs) != 1 and
|
returning_fields and len(self.query.objs) != 1 and
|
||||||
not self.connection.features.can_return_rows_from_bulk_insert
|
not self.connection.features.can_return_rows_from_bulk_insert
|
||||||
)
|
)
|
||||||
|
opts = self.query.get_meta()
|
||||||
self.returning_fields = returning_fields
|
self.returning_fields = returning_fields
|
||||||
with self.connection.cursor() as cursor:
|
with self.connection.cursor() as cursor:
|
||||||
for sql, params in self.as_sql():
|
for sql, params in self.as_sql():
|
||||||
|
@ -1412,13 +1413,21 @@ class SQLInsertCompiler(SQLCompiler):
|
||||||
if not self.returning_fields:
|
if not self.returning_fields:
|
||||||
return []
|
return []
|
||||||
if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
|
if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
|
||||||
return self.connection.ops.fetch_returned_insert_rows(cursor)
|
rows = self.connection.ops.fetch_returned_insert_rows(cursor)
|
||||||
if self.connection.features.can_return_columns_from_insert:
|
elif self.connection.features.can_return_columns_from_insert:
|
||||||
assert len(self.query.objs) == 1
|
assert len(self.query.objs) == 1
|
||||||
return [self.connection.ops.fetch_returned_insert_columns(cursor, self.returning_params)]
|
rows = [self.connection.ops.fetch_returned_insert_columns(
|
||||||
return [(self.connection.ops.last_insert_id(
|
cursor, self.returning_params,
|
||||||
cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column
|
)]
|
||||||
|
else:
|
||||||
|
rows = [(self.connection.ops.last_insert_id(
|
||||||
|
cursor, opts.db_table, opts.pk.column,
|
||||||
),)]
|
),)]
|
||||||
|
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
|
||||||
|
converters = self.get_converters(cols)
|
||||||
|
if converters:
|
||||||
|
rows = list(self.apply_converters(rows, converters))
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
class SQLDeleteCompiler(SQLCompiler):
|
class SQLDeleteCompiler(SQLCompiler):
|
||||||
|
|
|
@ -20,7 +20,7 @@ class MyWrapper:
|
||||||
return self.value == other
|
return self.value == other
|
||||||
|
|
||||||
|
|
||||||
class MyAutoField(models.CharField):
|
class MyWrapperField(models.CharField):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
kwargs['max_length'] = 10
|
kwargs['max_length'] = 10
|
||||||
|
@ -58,3 +58,15 @@ class MyAutoField(models.CharField):
|
||||||
if isinstance(value, MyWrapper):
|
if isinstance(value, MyWrapper):
|
||||||
return str(value)
|
return str(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class MyAutoField(models.BigAutoField):
|
||||||
|
def from_db_value(self, value, expression, connection):
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return MyWrapper(value)
|
||||||
|
|
||||||
|
def get_prep_value(self, value):
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
|
|
@ -7,7 +7,7 @@ this behavior by explicitly adding ``primary_key=True`` to a field.
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
from .fields import MyAutoField
|
from .fields import MyAutoField, MyWrapperField
|
||||||
|
|
||||||
|
|
||||||
class Employee(models.Model):
|
class Employee(models.Model):
|
||||||
|
@ -31,8 +31,12 @@ class Business(models.Model):
|
||||||
|
|
||||||
|
|
||||||
class Bar(models.Model):
|
class Bar(models.Model):
|
||||||
id = MyAutoField(primary_key=True, db_index=True)
|
id = MyWrapperField(primary_key=True, db_index=True)
|
||||||
|
|
||||||
|
|
||||||
class Foo(models.Model):
|
class Foo(models.Model):
|
||||||
bar = models.ForeignKey(Bar, models.CASCADE)
|
bar = models.ForeignKey(Bar, models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomAutoFieldModel(models.Model):
|
||||||
|
id = MyAutoField(primary_key=True)
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from django.db import IntegrityError, transaction
|
from django.db import IntegrityError, transaction
|
||||||
from django.test import TestCase, skipIfDBFeature
|
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||||
|
|
||||||
from .models import Bar, Business, Employee, Foo
|
from .fields import MyWrapper
|
||||||
|
from .models import Bar, Business, CustomAutoFieldModel, Employee, Foo
|
||||||
|
|
||||||
|
|
||||||
class BasicCustomPKTests(TestCase):
|
class BasicCustomPKTests(TestCase):
|
||||||
|
@ -230,3 +231,13 @@ class CustomPKTests(TestCase):
|
||||||
with self.assertRaises(IntegrityError):
|
with self.assertRaises(IntegrityError):
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
Employee.objects.create(first_name="Tom", last_name="Smith")
|
Employee.objects.create(first_name="Tom", last_name="Smith")
|
||||||
|
|
||||||
|
def test_auto_field_subclass_create(self):
|
||||||
|
obj = CustomAutoFieldModel.objects.create()
|
||||||
|
self.assertIsInstance(obj.id, MyWrapper)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('can_return_rows_from_bulk_insert')
|
||||||
|
def test_auto_field_subclass_bulk_create(self):
|
||||||
|
obj = CustomAutoFieldModel()
|
||||||
|
CustomAutoFieldModel.objects.bulk_create([obj])
|
||||||
|
self.assertIsInstance(obj.id, MyWrapper)
|
||||||
|
|
Loading…
Reference in New Issue