diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 7cc33af5373..bf9834d1f88 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -9,7 +9,7 @@ from django.utils import timezone from django.utils.encoding import force_bytes, force_text from .base import Database -from .utils import InsertIdVar, Oracle_datetime +from .utils import BulkInsertMapper, InsertIdVar, Oracle_datetime class DatabaseOperations(BaseDatabaseOperations): @@ -523,10 +523,18 @@ WHEN (new.%(col_name)s IS NULL) return truncate_name(trigger_name, name_length).upper() def bulk_insert_sql(self, fields, placeholder_rows): - return " UNION ALL ".join( - "SELECT %s FROM DUAL" % ", ".join(row) - for row in placeholder_rows - ) + query = [] + for row in placeholder_rows: + select = [] + for i, placeholder in enumerate(row): + # A model without any fields has fields=[None]. + if not fields[i]: + select.append(placeholder) + else: + internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type() + select.append(BulkInsertMapper.types.get(internal_type, '%s') % placeholder) + query.append('SELECT %s FROM DUAL' % ', '.join(select)) + return ' UNION ALL '.join(query) def subtract_temporals(self, internal_type, lhs, rhs): if internal_type == 'DateField': diff --git a/django/db/backends/oracle/utils.py b/django/db/backends/oracle/utils.py index f958655f940..6c81d5cd7ba 100644 --- a/django/db/backends/oracle/utils.py +++ b/django/db/backends/oracle/utils.py @@ -29,3 +29,27 @@ class Oracle_datetime(datetime.datetime): dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, dt.microsecond, ) + + +class BulkInsertMapper: + BLOB = 'TO_BLOB(%s)' + DATE = 'TO_DATE(%s)' + INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))' + NUMBER = 'TO_NUMBER(%s)' + TIMESTAMP = 'TO_TIMESTAMP(%s)' + + types = { + 'BigIntegerField': NUMBER, + 'BinaryField': BLOB, + 'DateField': DATE, + 'DateTimeField': TIMESTAMP, + 'DecimalField': NUMBER, + 'DurationField': INTERVAL, + 'FloatField': NUMBER, + 'IntegerField': NUMBER, + 'NullBooleanField': NUMBER, + 'PositiveIntegerField': NUMBER, + 'PositiveSmallIntegerField': NUMBER, + 'SmallIntegerField': NUMBER, + 'TimeField': TIMESTAMP, + } diff --git a/tests/bulk_create/models.py b/tests/bulk_create/models.py index c302a70b1b2..75d4a3cbdc7 100644 --- a/tests/bulk_create/models.py +++ b/tests/bulk_create/models.py @@ -1,4 +1,9 @@ +import datetime +import uuid +from decimal import Decimal + from django.db import models +from django.utils import timezone class Country(models.Model): @@ -51,3 +56,32 @@ class TwoFields(models.Model): class NoFields(models.Model): pass + + +class NullableFields(models.Model): + # Fields in db.backends.oracle.BulkInsertMapper + big_int_filed = models.BigIntegerField(null=True, default=1) + binary_field = models.BinaryField(null=True, default=b'data') + date_field = models.DateField(null=True, default=timezone.now) + datetime_field = models.DateTimeField(null=True, default=timezone.now) + decimal_field = models.DecimalField(null=True, max_digits=2, decimal_places=1, default=Decimal('1.1')) + duration_field = models.DurationField(null=True, default=datetime.timedelta(1)) + float_field = models.FloatField(null=True, default=3.2) + integer_field = models.IntegerField(null=True, default=2) + null_boolean_field = models.NullBooleanField(null=True, default=False) + positive_integer_field = models.PositiveIntegerField(null=True, default=3) + positive_small_integer_field = models.PositiveSmallIntegerField(null=True, default=4) + small_integer_field = models.SmallIntegerField(null=True, default=5) + time_field = models.TimeField(null=True, default=timezone.now) + # Fields not required in BulkInsertMapper + char_field = models.CharField(null=True, max_length=4, default='char') + email_field = models.EmailField(null=True, default='user@example.com') + duration_field = models.DurationField(null=True, default=datetime.timedelta(1)) + file_field = models.FileField(null=True, default='file.txt') + file_path_field = models.FilePathField(path='/tmp', null=True, default='file.txt') + generic_ip_address_field = models.GenericIPAddressField(null=True, default='127.0.0.1') + image_field = models.ImageField(null=True, default='image.jpg') + slug_field = models.SlugField(null=True, default='slug') + text_field = models.TextField(null=True, default='text') + url_field = models.URLField(null=True, default='/') + uuid_field = models.UUIDField(null=True, default=uuid.uuid4) diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index 97a8f821171..f401d4e32e3 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -1,15 +1,16 @@ from operator import attrgetter from django.db import connection -from django.db.models import Value +from django.db.models import FileField, Value from django.db.models.functions import Lower from django.test import ( TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature, ) from .models import ( - Country, NoFields, Pizzeria, ProxyCountry, ProxyMultiCountry, - ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant, State, TwoFields, + Country, NoFields, NullableFields, Pizzeria, ProxyCountry, + ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant, + State, TwoFields, ) @@ -204,6 +205,19 @@ class BulkCreateTests(TestCase): bbb = Restaurant.objects.filter(name="betty's beetroot bar") self.assertEqual(bbb.count(), 1) + @skipUnlessDBFeature('has_bulk_insert') + def test_bulk_insert_nullable_fields(self): + # NULL can be mixed with other values in nullable fields + nullable_fields = [field for field in NullableFields._meta.get_fields() if field.name != 'id'] + NullableFields.objects.bulk_create([ + NullableFields(**{field.name: None}) for field in nullable_fields + ]) + self.assertEqual(NullableFields.objects.count(), len(nullable_fields)) + for field in nullable_fields: + with self.subTest(field=field): + field_value = '' if isinstance(field, FileField) else None + self.assertEqual(NullableFields.objects.filter(**{field.name: field_value}).count(), 1) + @skipUnlessDBFeature('can_return_ids_from_bulk_insert') def test_set_pk_and_insert_single_item(self): with self.assertNumQueries(1):