mirror of https://github.com/django/django.git
Fixed #22669 -- Fixed QuerySet.bulk_create() with empty model fields on Oracle.
This commit is contained in:
parent
965f678a39
commit
c4e2fc5d98
|
@ -9,7 +9,7 @@ from django.utils import timezone
|
||||||
from django.utils.encoding import force_bytes, force_text
|
from django.utils.encoding import force_bytes, force_text
|
||||||
|
|
||||||
from .base import Database
|
from .base import Database
|
||||||
from .utils import InsertIdVar, Oracle_datetime
|
from .utils import BulkInsertMapper, InsertIdVar, Oracle_datetime
|
||||||
|
|
||||||
|
|
||||||
class DatabaseOperations(BaseDatabaseOperations):
|
class DatabaseOperations(BaseDatabaseOperations):
|
||||||
|
@ -523,10 +523,18 @@ WHEN (new.%(col_name)s IS NULL)
|
||||||
return truncate_name(trigger_name, name_length).upper()
|
return truncate_name(trigger_name, name_length).upper()
|
||||||
|
|
||||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||||
return " UNION ALL ".join(
|
query = []
|
||||||
"SELECT %s FROM DUAL" % ", ".join(row)
|
for row in placeholder_rows:
|
||||||
for row in placeholder_rows
|
select = []
|
||||||
)
|
for i, placeholder in enumerate(row):
|
||||||
|
# A model without any fields has fields=[None].
|
||||||
|
if not fields[i]:
|
||||||
|
select.append(placeholder)
|
||||||
|
else:
|
||||||
|
internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
|
||||||
|
select.append(BulkInsertMapper.types.get(internal_type, '%s') % placeholder)
|
||||||
|
query.append('SELECT %s FROM DUAL' % ', '.join(select))
|
||||||
|
return ' UNION ALL '.join(query)
|
||||||
|
|
||||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||||
if internal_type == 'DateField':
|
if internal_type == 'DateField':
|
||||||
|
|
|
@ -29,3 +29,27 @@ class Oracle_datetime(datetime.datetime):
|
||||||
dt.year, dt.month, dt.day,
|
dt.year, dt.month, dt.day,
|
||||||
dt.hour, dt.minute, dt.second, dt.microsecond,
|
dt.hour, dt.minute, dt.second, dt.microsecond,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BulkInsertMapper:
|
||||||
|
BLOB = 'TO_BLOB(%s)'
|
||||||
|
DATE = 'TO_DATE(%s)'
|
||||||
|
INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
|
||||||
|
NUMBER = 'TO_NUMBER(%s)'
|
||||||
|
TIMESTAMP = 'TO_TIMESTAMP(%s)'
|
||||||
|
|
||||||
|
types = {
|
||||||
|
'BigIntegerField': NUMBER,
|
||||||
|
'BinaryField': BLOB,
|
||||||
|
'DateField': DATE,
|
||||||
|
'DateTimeField': TIMESTAMP,
|
||||||
|
'DecimalField': NUMBER,
|
||||||
|
'DurationField': INTERVAL,
|
||||||
|
'FloatField': NUMBER,
|
||||||
|
'IntegerField': NUMBER,
|
||||||
|
'NullBooleanField': NUMBER,
|
||||||
|
'PositiveIntegerField': NUMBER,
|
||||||
|
'PositiveSmallIntegerField': NUMBER,
|
||||||
|
'SmallIntegerField': NUMBER,
|
||||||
|
'TimeField': TIMESTAMP,
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,9 @@
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
|
||||||
class Country(models.Model):
|
class Country(models.Model):
|
||||||
|
@ -51,3 +56,32 @@ class TwoFields(models.Model):
|
||||||
|
|
||||||
class NoFields(models.Model):
|
class NoFields(models.Model):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NullableFields(models.Model):
|
||||||
|
# Fields in db.backends.oracle.BulkInsertMapper
|
||||||
|
big_int_filed = models.BigIntegerField(null=True, default=1)
|
||||||
|
binary_field = models.BinaryField(null=True, default=b'data')
|
||||||
|
date_field = models.DateField(null=True, default=timezone.now)
|
||||||
|
datetime_field = models.DateTimeField(null=True, default=timezone.now)
|
||||||
|
decimal_field = models.DecimalField(null=True, max_digits=2, decimal_places=1, default=Decimal('1.1'))
|
||||||
|
duration_field = models.DurationField(null=True, default=datetime.timedelta(1))
|
||||||
|
float_field = models.FloatField(null=True, default=3.2)
|
||||||
|
integer_field = models.IntegerField(null=True, default=2)
|
||||||
|
null_boolean_field = models.NullBooleanField(null=True, default=False)
|
||||||
|
positive_integer_field = models.PositiveIntegerField(null=True, default=3)
|
||||||
|
positive_small_integer_field = models.PositiveSmallIntegerField(null=True, default=4)
|
||||||
|
small_integer_field = models.SmallIntegerField(null=True, default=5)
|
||||||
|
time_field = models.TimeField(null=True, default=timezone.now)
|
||||||
|
# Fields not required in BulkInsertMapper
|
||||||
|
char_field = models.CharField(null=True, max_length=4, default='char')
|
||||||
|
email_field = models.EmailField(null=True, default='user@example.com')
|
||||||
|
duration_field = models.DurationField(null=True, default=datetime.timedelta(1))
|
||||||
|
file_field = models.FileField(null=True, default='file.txt')
|
||||||
|
file_path_field = models.FilePathField(path='/tmp', null=True, default='file.txt')
|
||||||
|
generic_ip_address_field = models.GenericIPAddressField(null=True, default='127.0.0.1')
|
||||||
|
image_field = models.ImageField(null=True, default='image.jpg')
|
||||||
|
slug_field = models.SlugField(null=True, default='slug')
|
||||||
|
text_field = models.TextField(null=True, default='text')
|
||||||
|
url_field = models.URLField(null=True, default='/')
|
||||||
|
uuid_field = models.UUIDField(null=True, default=uuid.uuid4)
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
|
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
from django.db.models import Value
|
from django.db.models import FileField, Value
|
||||||
from django.db.models.functions import Lower
|
from django.db.models.functions import Lower
|
||||||
from django.test import (
|
from django.test import (
|
||||||
TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
|
TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
Country, NoFields, Pizzeria, ProxyCountry, ProxyMultiCountry,
|
Country, NoFields, NullableFields, Pizzeria, ProxyCountry,
|
||||||
ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant, State, TwoFields,
|
ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant,
|
||||||
|
State, TwoFields,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -204,6 +205,19 @@ class BulkCreateTests(TestCase):
|
||||||
bbb = Restaurant.objects.filter(name="betty's beetroot bar")
|
bbb = Restaurant.objects.filter(name="betty's beetroot bar")
|
||||||
self.assertEqual(bbb.count(), 1)
|
self.assertEqual(bbb.count(), 1)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature('has_bulk_insert')
|
||||||
|
def test_bulk_insert_nullable_fields(self):
|
||||||
|
# NULL can be mixed with other values in nullable fields
|
||||||
|
nullable_fields = [field for field in NullableFields._meta.get_fields() if field.name != 'id']
|
||||||
|
NullableFields.objects.bulk_create([
|
||||||
|
NullableFields(**{field.name: None}) for field in nullable_fields
|
||||||
|
])
|
||||||
|
self.assertEqual(NullableFields.objects.count(), len(nullable_fields))
|
||||||
|
for field in nullable_fields:
|
||||||
|
with self.subTest(field=field):
|
||||||
|
field_value = '' if isinstance(field, FileField) else None
|
||||||
|
self.assertEqual(NullableFields.objects.filter(**{field.name: field_value}).count(), 1)
|
||||||
|
|
||||||
@skipUnlessDBFeature('can_return_ids_from_bulk_insert')
|
@skipUnlessDBFeature('can_return_ids_from_bulk_insert')
|
||||||
def test_set_pk_and_insert_single_item(self):
|
def test_set_pk_and_insert_single_item(self):
|
||||||
with self.assertNumQueries(1):
|
with self.assertNumQueries(1):
|
||||||
|
|
Loading…
Reference in New Issue