Fixed #22669 -- Fixed QuerySet.bulk_create() with empty model fields on Oracle.

This commit is contained in:
Mikhail Nacharov 2017-02-07 09:49:31 +05:00 committed by Tim Graham
parent 965f678a39
commit c4e2fc5d98
4 changed files with 88 additions and 8 deletions

View File

@ -9,7 +9,7 @@ from django.utils import timezone
from django.utils.encoding import force_bytes, force_text from django.utils.encoding import force_bytes, force_text
from .base import Database from .base import Database
from .utils import InsertIdVar, Oracle_datetime from .utils import BulkInsertMapper, InsertIdVar, Oracle_datetime
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
@ -523,10 +523,18 @@ WHEN (new.%(col_name)s IS NULL)
return truncate_name(trigger_name, name_length).upper() return truncate_name(trigger_name, name_length).upper()
def bulk_insert_sql(self, fields, placeholder_rows): def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join( query = []
"SELECT %s FROM DUAL" % ", ".join(row) for row in placeholder_rows:
for row in placeholder_rows select = []
) for i, placeholder in enumerate(row):
# A model without any fields has fields=[None].
if not fields[i]:
select.append(placeholder)
else:
internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
select.append(BulkInsertMapper.types.get(internal_type, '%s') % placeholder)
query.append('SELECT %s FROM DUAL' % ', '.join(select))
return ' UNION ALL '.join(query)
def subtract_temporals(self, internal_type, lhs, rhs): def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == 'DateField': if internal_type == 'DateField':

View File

@ -29,3 +29,27 @@ class Oracle_datetime(datetime.datetime):
dt.year, dt.month, dt.day, dt.year, dt.month, dt.day,
dt.hour, dt.minute, dt.second, dt.microsecond, dt.hour, dt.minute, dt.second, dt.microsecond,
) )
class BulkInsertMapper:
BLOB = 'TO_BLOB(%s)'
DATE = 'TO_DATE(%s)'
INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
NUMBER = 'TO_NUMBER(%s)'
TIMESTAMP = 'TO_TIMESTAMP(%s)'
types = {
'BigIntegerField': NUMBER,
'BinaryField': BLOB,
'DateField': DATE,
'DateTimeField': TIMESTAMP,
'DecimalField': NUMBER,
'DurationField': INTERVAL,
'FloatField': NUMBER,
'IntegerField': NUMBER,
'NullBooleanField': NUMBER,
'PositiveIntegerField': NUMBER,
'PositiveSmallIntegerField': NUMBER,
'SmallIntegerField': NUMBER,
'TimeField': TIMESTAMP,
}

View File

@ -1,4 +1,9 @@
import datetime
import uuid
from decimal import Decimal
from django.db import models from django.db import models
from django.utils import timezone
class Country(models.Model): class Country(models.Model):
@ -51,3 +56,32 @@ class TwoFields(models.Model):
class NoFields(models.Model): class NoFields(models.Model):
pass pass
class NullableFields(models.Model):
# Fields in db.backends.oracle.BulkInsertMapper
big_int_filed = models.BigIntegerField(null=True, default=1)
binary_field = models.BinaryField(null=True, default=b'data')
date_field = models.DateField(null=True, default=timezone.now)
datetime_field = models.DateTimeField(null=True, default=timezone.now)
decimal_field = models.DecimalField(null=True, max_digits=2, decimal_places=1, default=Decimal('1.1'))
duration_field = models.DurationField(null=True, default=datetime.timedelta(1))
float_field = models.FloatField(null=True, default=3.2)
integer_field = models.IntegerField(null=True, default=2)
null_boolean_field = models.NullBooleanField(null=True, default=False)
positive_integer_field = models.PositiveIntegerField(null=True, default=3)
positive_small_integer_field = models.PositiveSmallIntegerField(null=True, default=4)
small_integer_field = models.SmallIntegerField(null=True, default=5)
time_field = models.TimeField(null=True, default=timezone.now)
# Fields not required in BulkInsertMapper
char_field = models.CharField(null=True, max_length=4, default='char')
email_field = models.EmailField(null=True, default='user@example.com')
duration_field = models.DurationField(null=True, default=datetime.timedelta(1))
file_field = models.FileField(null=True, default='file.txt')
file_path_field = models.FilePathField(path='/tmp', null=True, default='file.txt')
generic_ip_address_field = models.GenericIPAddressField(null=True, default='127.0.0.1')
image_field = models.ImageField(null=True, default='image.jpg')
slug_field = models.SlugField(null=True, default='slug')
text_field = models.TextField(null=True, default='text')
url_field = models.URLField(null=True, default='/')
uuid_field = models.UUIDField(null=True, default=uuid.uuid4)

View File

@ -1,15 +1,16 @@
from operator import attrgetter from operator import attrgetter
from django.db import connection from django.db import connection
from django.db.models import Value from django.db.models import FileField, Value
from django.db.models.functions import Lower from django.db.models.functions import Lower
from django.test import ( from django.test import (
TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature, TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
) )
from .models import ( from .models import (
Country, NoFields, Pizzeria, ProxyCountry, ProxyMultiCountry, Country, NoFields, NullableFields, Pizzeria, ProxyCountry,
ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant, State, TwoFields, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant,
State, TwoFields,
) )
@ -204,6 +205,19 @@ class BulkCreateTests(TestCase):
bbb = Restaurant.objects.filter(name="betty's beetroot bar") bbb = Restaurant.objects.filter(name="betty's beetroot bar")
self.assertEqual(bbb.count(), 1) self.assertEqual(bbb.count(), 1)
@skipUnlessDBFeature('has_bulk_insert')
def test_bulk_insert_nullable_fields(self):
# NULL can be mixed with other values in nullable fields
nullable_fields = [field for field in NullableFields._meta.get_fields() if field.name != 'id']
NullableFields.objects.bulk_create([
NullableFields(**{field.name: None}) for field in nullable_fields
])
self.assertEqual(NullableFields.objects.count(), len(nullable_fields))
for field in nullable_fields:
with self.subTest(field=field):
field_value = '' if isinstance(field, FileField) else None
self.assertEqual(NullableFields.objects.filter(**{field.name: field_value}).count(), 1)
@skipUnlessDBFeature('can_return_ids_from_bulk_insert') @skipUnlessDBFeature('can_return_ids_from_bulk_insert')
def test_set_pk_and_insert_single_item(self): def test_set_pk_and_insert_single_item(self):
with self.assertNumQueries(1): with self.assertNumQueries(1):