Fixed #22669 -- Fixed QuerySet.bulk_create() with empty model fields on Oracle.

This commit is contained in:
Mikhail Nacharov 2017-02-07 09:49:31 +05:00 committed by Tim Graham
parent 965f678a39
commit c4e2fc5d98
4 changed files with 88 additions and 8 deletions

View File

@ -9,7 +9,7 @@ from django.utils import timezone
from django.utils.encoding import force_bytes, force_text
from .base import Database
from .utils import InsertIdVar, Oracle_datetime
from .utils import BulkInsertMapper, InsertIdVar, Oracle_datetime
class DatabaseOperations(BaseDatabaseOperations):
@ -523,10 +523,18 @@ WHEN (new.%(col_name)s IS NULL)
return truncate_name(trigger_name, name_length).upper()
def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join(
"SELECT %s FROM DUAL" % ", ".join(row)
for row in placeholder_rows
)
query = []
for row in placeholder_rows:
select = []
for i, placeholder in enumerate(row):
# A model without any fields has fields=[None].
if not fields[i]:
select.append(placeholder)
else:
internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
select.append(BulkInsertMapper.types.get(internal_type, '%s') % placeholder)
query.append('SELECT %s FROM DUAL' % ', '.join(select))
return ' UNION ALL '.join(query)
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == 'DateField':

View File

@ -29,3 +29,27 @@ class Oracle_datetime(datetime.datetime):
dt.year, dt.month, dt.day,
dt.hour, dt.minute, dt.second, dt.microsecond,
)
class BulkInsertMapper:
BLOB = 'TO_BLOB(%s)'
DATE = 'TO_DATE(%s)'
INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
NUMBER = 'TO_NUMBER(%s)'
TIMESTAMP = 'TO_TIMESTAMP(%s)'
types = {
'BigIntegerField': NUMBER,
'BinaryField': BLOB,
'DateField': DATE,
'DateTimeField': TIMESTAMP,
'DecimalField': NUMBER,
'DurationField': INTERVAL,
'FloatField': NUMBER,
'IntegerField': NUMBER,
'NullBooleanField': NUMBER,
'PositiveIntegerField': NUMBER,
'PositiveSmallIntegerField': NUMBER,
'SmallIntegerField': NUMBER,
'TimeField': TIMESTAMP,
}

View File

@ -1,4 +1,9 @@
import datetime
import uuid
from decimal import Decimal
from django.db import models
from django.utils import timezone
class Country(models.Model):
@ -51,3 +56,32 @@ class TwoFields(models.Model):
class NoFields(models.Model):
pass
class NullableFields(models.Model):
# Fields in db.backends.oracle.BulkInsertMapper
big_int_filed = models.BigIntegerField(null=True, default=1)
binary_field = models.BinaryField(null=True, default=b'data')
date_field = models.DateField(null=True, default=timezone.now)
datetime_field = models.DateTimeField(null=True, default=timezone.now)
decimal_field = models.DecimalField(null=True, max_digits=2, decimal_places=1, default=Decimal('1.1'))
duration_field = models.DurationField(null=True, default=datetime.timedelta(1))
float_field = models.FloatField(null=True, default=3.2)
integer_field = models.IntegerField(null=True, default=2)
null_boolean_field = models.NullBooleanField(null=True, default=False)
positive_integer_field = models.PositiveIntegerField(null=True, default=3)
positive_small_integer_field = models.PositiveSmallIntegerField(null=True, default=4)
small_integer_field = models.SmallIntegerField(null=True, default=5)
time_field = models.TimeField(null=True, default=timezone.now)
# Fields not required in BulkInsertMapper
char_field = models.CharField(null=True, max_length=4, default='char')
email_field = models.EmailField(null=True, default='user@example.com')
duration_field = models.DurationField(null=True, default=datetime.timedelta(1))
file_field = models.FileField(null=True, default='file.txt')
file_path_field = models.FilePathField(path='/tmp', null=True, default='file.txt')
generic_ip_address_field = models.GenericIPAddressField(null=True, default='127.0.0.1')
image_field = models.ImageField(null=True, default='image.jpg')
slug_field = models.SlugField(null=True, default='slug')
text_field = models.TextField(null=True, default='text')
url_field = models.URLField(null=True, default='/')
uuid_field = models.UUIDField(null=True, default=uuid.uuid4)

View File

@ -1,15 +1,16 @@
from operator import attrgetter
from django.db import connection
from django.db.models import Value
from django.db.models import FileField, Value
from django.db.models.functions import Lower
from django.test import (
TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
)
from .models import (
Country, NoFields, Pizzeria, ProxyCountry, ProxyMultiCountry,
ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant, State, TwoFields,
Country, NoFields, NullableFields, Pizzeria, ProxyCountry,
ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant,
State, TwoFields,
)
@ -204,6 +205,19 @@ class BulkCreateTests(TestCase):
bbb = Restaurant.objects.filter(name="betty's beetroot bar")
self.assertEqual(bbb.count(), 1)
@skipUnlessDBFeature('has_bulk_insert')
def test_bulk_insert_nullable_fields(self):
# NULL can be mixed with other values in nullable fields
nullable_fields = [field for field in NullableFields._meta.get_fields() if field.name != 'id']
NullableFields.objects.bulk_create([
NullableFields(**{field.name: None}) for field in nullable_fields
])
self.assertEqual(NullableFields.objects.count(), len(nullable_fields))
for field in nullable_fields:
with self.subTest(field=field):
field_value = '' if isinstance(field, FileField) else None
self.assertEqual(NullableFields.objects.filter(**{field.name: field_value}).count(), 1)
@skipUnlessDBFeature('can_return_ids_from_bulk_insert')
def test_set_pk_and_insert_single_item(self):
with self.assertNumQueries(1):