diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
index 6f1f4b9618..a82d6f56a1 100644
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -470,6 +470,14 @@ class BaseDatabaseOperations(object):
         """
         return None
 
+    def bulk_batch_size(self, fields, objs):
+        """
+        Returns the maximum allowed batch size for the backend. The fields
+        are the fields going to be inserted in the batch; the objs contains
+        all the objects to be inserted.
+        """
+        return len(objs)
+
     def date_extract_sql(self, lookup_type, field_name):
         """
         Given a lookup_type of 'year', 'month' or 'day', returns the SQL that
@@ -507,6 +515,17 @@ class BaseDatabaseOperations(object):
         """
         return ''
 
+    def distinct_sql(self, fields):
+        """
+        Returns an SQL DISTINCT clause which removes duplicate rows from the
+        result set. If any fields are given, only the given fields are being
+        checked for duplicates.
+        """
+        if fields:
+            raise NotImplementedError('DISTINCT ON fields is not supported by this database backend')
+        else:
+            return 'DISTINCT'
+
     def drop_foreignkey_sql(self):
         """
         Returns the SQL command that drops a foreign key.
@@ -562,17 +581,6 @@ class BaseDatabaseOperations(object):
         """
         raise NotImplementedError('Full-text search is not implemented for this database backend')
 
-    def distinct_sql(self, fields):
-        """
-        Returns an SQL DISTINCT clause which removes duplicate rows from the
-        result set. If any fields are given, only the given fields are being
-        checked for duplicates.
-        """
-        if fields:
-            raise NotImplementedError('DISTINCT ON fields is not supported by this database backend')
-        else:
-            return 'DISTINCT'
-
     def last_executed_query(self, cursor, sql, params):
         """
         Returns a string of the query last executed by the given cursor, with
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
index 0b19442e78..2146a7fa8a 100644
--- a/django/db/backends/sqlite3/base.py
+++ b/django/db/backends/sqlite3/base.py
@@ -83,7 +83,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_1000_query_parameters = False
     supports_mixed_date_datetime_comparisons = False
     has_bulk_insert = True
-    can_combine_inserts_with_and_without_auto_increment_pk = True
+    can_combine_inserts_with_and_without_auto_increment_pk = False
 
     def _supports_stddev(self):
         """Confirm support for STDDEV and related stats functions
@@ -104,6 +104,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
         return has_support
 
 class DatabaseOperations(BaseDatabaseOperations):
+    def bulk_batch_size(self, fields, objs):
+        """
+        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
+        999 variables per query.
+        """
+        return (999 // len(fields)) if len(fields) > 0 else len(objs)
+
     def date_extract_sql(self, lookup_type, field_name):
         # sqlite doesn't support extract, so we fake it with the user-defined
         # function django_extract that's registered in connect(). Note that
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 44acadf037..82378fb5a2 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -377,7 +377,7 @@ class QuerySet(object):
             obj.save(force_insert=True, using=self.db)
         return obj
 
-    def bulk_create(self, objs):
+    def bulk_create(self, objs, batch_size=None):
         """
         Inserts each of the instances into the database. This does *not* call
         save() on each of the instances, does not send any pre/post save
@@ -390,8 +390,10 @@
         # this could be implemented if you didn't have an autoincrement pk,
         # and 2) you could do it by doing O(n) normal inserts into the parent
         # tables to get the primary keys back, and then doing a single bulk
-        # insert into the childmost table. We're punting on these for now
-        # because they are relatively rare cases.
+        # insert into the childmost table. Some databases might allow doing
+        # this by using the RETURNING clause for the insert query. We're
+        # punting on these for now because they are relatively rare cases.
+        assert batch_size is None or batch_size > 0
         if self.model._meta.parents:
             raise ValueError("Can't bulk create an inherited model")
         if not objs:
@@ -407,13 +409,14 @@
         try:
             if (connection.features.can_combine_inserts_with_and_without_auto_increment_pk
                 and self.model._meta.has_auto_field):
-                self.model._base_manager._insert(objs, fields=fields, using=self.db)
+                self._batched_insert(objs, fields, batch_size)
             else:
                 objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
                 if objs_with_pk:
-                    self.model._base_manager._insert(objs_with_pk, fields=fields, using=self.db)
+                    self._batched_insert(objs_with_pk, fields, batch_size)
                 if objs_without_pk:
-                    self.model._base_manager._insert(objs_without_pk, fields=[f for f in fields if not isinstance(f, AutoField)], using=self.db)
+                    fields = [f for f in fields if not isinstance(f, AutoField)]
+                    self._batched_insert(objs_without_pk, fields, batch_size)
             if forced_managed:
                 transaction.commit(using=self.db)
             else:
@@ -849,6 +852,20 @@
     ###################
     # PRIVATE METHODS #
     ###################
+    def _batched_insert(self, objs, fields, batch_size):
+        """
+        A little helper method for bulk_create() to insert objs one batch at
+        a time. Slices a batch off the front of objs, inserts it, and repeats
+        until all the objects have been inserted.
+        """
+        if not objs:
+            return
+        ops = connections[self.db].ops
+        batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
+        for batch in [objs[i:i+batch_size]
+                      for i in range(0, len(objs), batch_size)]:
+            self.model._base_manager._insert(batch, fields=fields,
+                                             using=self.db)
 
     def _clone(self, klass=None, setup=False, **kwargs):
         if klass is None:
diff --git a/tests/regressiontests/bulk_create/models.py b/tests/regressiontests/bulk_create/models.py
index a4c611d537..bc685bbbe4 100644
--- a/tests/regressiontests/bulk_create/models.py
+++ b/tests/regressiontests/bulk_create/models.py
@@ -18,4 +18,8 @@ class Pizzeria(Restaurant):
     pass
 
 class State(models.Model):
-    two_letter_code = models.CharField(max_length=2, primary_key=True)
\ No newline at end of file
+    two_letter_code = models.CharField(max_length=2, primary_key=True)
+
+class TwoFields(models.Model):
+    f1 = models.IntegerField(unique=True)
+    f2 = models.IntegerField(unique=True)
diff --git a/tests/regressiontests/bulk_create/tests.py b/tests/regressiontests/bulk_create/tests.py
index 0fa142b795..b4c3e7f17f 100644
--- a/tests/regressiontests/bulk_create/tests.py
+++ b/tests/regressiontests/bulk_create/tests.py
@@ -2,9 +2,11 @@ from __future__ import with_statement, absolute_import
 
 from operator import attrgetter
 
-from django.test import TestCase, skipUnlessDBFeature
+from django.db import connection
+from django.test import TestCase, skipIfDBFeature
+from django.test.utils import override_settings
 
-from .models import Country, Restaurant, Pizzeria, State
+from .models import Country, Restaurant, Pizzeria, State, TwoFields
 
 
 class BulkCreateTests(TestCase):
@@ -27,7 +29,6 @@
         self.assertEqual(created, [])
         self.assertEqual(Country.objects.count(), 4)
 
-    @skipUnlessDBFeature("has_bulk_insert")
     def test_efficiency(self):
         with self.assertNumQueries(1):
             Country.objects.bulk_create(self.data)
@@ -56,4 +57,43 @@
         ])
         self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [
             "CA", "IL", "ME", "NY",
-        ], attrgetter("two_letter_code"))
\ No newline at end of file
+        ], attrgetter("two_letter_code"))
+
+    def test_large_batch(self):
+        with override_settings(DEBUG=True):
+            connection.queries = []
+            TwoFields.objects.bulk_create([
+                TwoFields(f1=i, f2=i+1) for i in range(0, 1001)
+            ])
+        self.assertTrue(len(connection.queries) < 10)
+        self.assertEqual(TwoFields.objects.count(), 1001)
+        self.assertEqual(
+            TwoFields.objects.filter(f1__gte=450, f1__lte=550).count(),
+            101)
+        self.assertEqual(TwoFields.objects.filter(f2__gte=901).count(), 101)
+
+    def test_large_batch_mixed(self):
+        """
+        Test inserting a large batch with objects having primary key set
+        mixed together with objects without PK set.
+        """
+        with override_settings(DEBUG=True):
+            connection.queries = []
+            TwoFields.objects.bulk_create([
+                TwoFields(id=i if i % 2 == 0 else None, f1=i, f2=i+1)
+                for i in range(100000, 101000)])
+        self.assertTrue(len(connection.queries) < 10)
+        self.assertEqual(TwoFields.objects.count(), 1000)
+        # We can't assume much about the IDs created, except that the above
+        # created IDs must exist.
+        id_range = range(100000, 101000, 2)
+        self.assertEqual(TwoFields.objects.filter(id__in=id_range).count(), 500)
+        self.assertEqual(TwoFields.objects.exclude(id__in=id_range).count(), 500)
+
+    def test_explicit_batch_size(self):
+        objs = [TwoFields(f1=i, f2=i) for i in range(0, 100)]
+        with self.assertNumQueries(2):
+            TwoFields.objects.bulk_create(objs, 50)
+        TwoFields.objects.all().delete()
+        with self.assertNumQueries(1):
+            TwoFields.objects.bulk_create(objs, len(objs))
diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py
index ded3e8ffa7..ed71be8392 100644
--- a/tests/regressiontests/queries/tests.py
+++ b/tests/regressiontests/queries/tests.py
@@ -1807,8 +1807,7 @@ class ConditionalTests(BaseQuerysetTest):
         # Test that the "in" lookup works with lists of 1000 items or more.
         Number.objects.all().delete()
         numbers = range(2500)
-        for num in numbers:
-            _ = Number.objects.create(num=num)
+        Number.objects.bulk_create(Number(num=num) for num in numbers)
         self.assertEqual(
             Number.objects.filter(num__in=numbers[:1000]).count(),
             1000
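
For reference, a minimal usage sketch of the new batch_size argument, assuming the TwoFields test model added by this patch (the import path below is only illustrative):

    from regressiontests.bulk_create.models import TwoFields

    # 1000 rows, inserted 100 per INSERT statement. Leaving batch_size as None
    # lets the backend choose via DatabaseOperations.bulk_batch_size():
    # 999 // len(fields) on SQLite, len(objs) (a single batch) on the base backend.
    objs = [TwoFields(f1=i, f2=i + 1) for i in range(1000)]
    TwoFields.objects.bulk_create(objs, batch_size=100)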