Fixed #7596. Added Model.objects.bulk_create, and made use of it in several places. This provides a performance benefit when inserting multiple objects. Thanks to Russ for the review, and Simon Meers for the MySQL implementation.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16739 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2011-09-09 19:22:28 +00:00
parent e55bbf4c3c
commit 7deb25b8dd
22 changed files with 331 additions and 78 deletions

View File

@ -46,17 +46,15 @@ def create_permissions(app, created_models, verbosity, **kwargs):
"content_type", "codename"
))
for ctype, (codename, name) in searched_perms:
# If the permissions exists, move on.
if (ctype.pk, codename) in all_perms:
continue
p = auth_app.Permission.objects.create(
codename=codename,
name=name,
content_type=ctype
)
objs = [
auth_app.Permission(codename=codename, name=name, content_type=ctype)
for ctype, (codename, name) in searched_perms
if (ctype.pk, codename) not in all_perms
]
auth_app.Permission.objects.bulk_create(objs)
if verbosity >= 2:
print "Adding permission '%s'" % p
for obj in objs:
print "Adding permission '%s'" % obj
def create_superuser(app, created_models, verbosity, **kwargs):

View File

@ -8,25 +8,41 @@ def update_contenttypes(app, created_models, verbosity=2, **kwargs):
entries that no longer have a matching model class.
"""
ContentType.objects.clear_cache()
content_types = list(ContentType.objects.filter(app_label=app.__name__.split('.')[-2]))
app_models = get_models(app)
if not app_models:
return
for klass in app_models:
opts = klass._meta
try:
ct = ContentType.objects.get(app_label=opts.app_label,
model=opts.object_name.lower())
content_types.remove(ct)
except ContentType.DoesNotExist:
ct = ContentType(name=smart_unicode(opts.verbose_name_raw),
app_label=opts.app_label, model=opts.object_name.lower())
ct.save()
# They all have the same app_label, get the first one.
app_label = app_models[0]._meta.app_label
app_models = dict(
(model._meta.object_name.lower(), model)
for model in app_models
)
# Get all the content types
content_types = dict(
(ct.model, ct)
for ct in ContentType.objects.filter(app_label=app_label)
)
to_remove = [
ct
for (model_name, ct) in content_types.iteritems()
if model_name not in app_models
]
cts = ContentType.objects.bulk_create([
ContentType(
name=smart_unicode(model._meta.verbose_name_raw),
app_label=app_label,
model=model_name,
)
for (model_name, model) in app_models.iteritems()
if model_name not in content_types
])
if verbosity >= 2:
for ct in cts:
print "Adding content type '%s | %s'" % (ct.app_label, ct.model)
# The presence of any remaining content types means the supplied app has an
# undefined model. Confirm that the content type is stale before deletion.
if content_types:
# Confirm that the content type is stale before deletion.
if to_remove:
if kwargs.get('interactive', False):
content_type_display = '\n'.join([' %s | %s' % (ct.app_label, ct.model) for ct in content_types])
ok_to_delete = raw_input("""The following content types are stale and need to be deleted:
@ -42,7 +58,7 @@ If you're unsure, answer 'no'.
ok_to_delete = False
if ok_to_delete == 'yes':
for ct in content_types:
for ct in to_remove:
if verbosity >= 2:
print "Deleting stale content type '%s | %s'" % (ct.app_label, ct.model)
ct.delete()

View File

@ -301,8 +301,10 @@ class BaseDatabaseFeatures(object):
can_use_chunked_reads = True
can_return_id_from_insert = False
has_bulk_insert = False
uses_autocommit = False
uses_savepoints = False
can_combine_inserts_with_and_without_auto_increment_pk = False
# If True, don't use integer foreign keys referring to, e.g., positive
# integer primary keys.

View File

@ -124,6 +124,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_pk = True
related_fields_match_type = True
allow_sliced_subqueries = False
has_bulk_insert = True
has_select_for_update = True
has_select_for_update_nowait = False
supports_forward_references = False
@ -263,6 +264,10 @@ class DatabaseOperations(BaseDatabaseOperations):
# Returns the fixed value 64 as the maximum name length for this operations class.
def max_name_length(self):
return 64
def bulk_insert_sql(self, fields, num_values):
    """
    Builds the VALUES clause of a multi-row INSERT statement.

    Each row is a parenthesised, comma-separated group holding one ``%s``
    placeholder per entry in ``fields``; the group is repeated
    ``num_values`` times, separated by commas.
    """
    row_placeholders = ", ".join(["%s"] * len(fields))
    one_row = "(" + row_placeholders + ")"
    all_rows = ", ".join(one_row for _ in range(num_values))
    return "VALUES " + all_rows
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql'
operators = {

View File

@ -74,6 +74,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_defer_constraint_checks = True
has_select_for_update = True
has_select_for_update_nowait = True
has_bulk_insert = True
class DatabaseWrapper(BaseDatabaseWrapper):

View File

@ -180,3 +180,7 @@ class DatabaseOperations(BaseDatabaseOperations):
# Returns a two-tuple: the SQL fragment "RETURNING %s" (the caller substitutes
# the column to hand back) and an empty tuple of extra query parameters.
def return_insert_id(self):
return "RETURNING %s", ()
def bulk_insert_sql(self, fields, num_values):
    """
    Returns the ``VALUES (...), (...)`` clause used for a multi-row INSERT:
    ``num_values`` groups, each with one ``%s`` placeholder per field.
    """
    placeholders = ["%s"] * len(fields)
    group = "(%s)" % ", ".join(placeholders)
    groups = [group] * num_values
    return "VALUES %s" % ", ".join(groups)

View File

@ -58,6 +58,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_unspecified_pk = True
supports_1000_query_parameters = False
supports_mixed_date_datetime_comparisons = False
has_bulk_insert = True
can_combine_inserts_with_and_without_auto_increment_pk = True
def _supports_stddev(self):
"""Confirm support for STDDEV and related stats functions
@ -106,7 +108,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return ""
def pk_default_value(self):
return 'NULL'
return "NULL"
def quote_name(self, name):
if name.startswith('"') and name.endswith('"'):
@ -154,6 +156,14 @@ class DatabaseOperations(BaseDatabaseOperations):
# No field, or the field isn't known to be a decimal or integer
return value
def bulk_insert_sql(self, fields, num_values):
    """
    Builds the row source of a multi-row INSERT as a compound SELECT.

    The first SELECT aliases every placeholder to its (quoted) column name;
    each further row is appended as another SELECT of bare ``%s``
    placeholders, for ``num_values`` rows in total.
    """
    first_row = "SELECT %s" % ", ".join(
        "%%s AS %s" % self.quote_name(f.column) for f in fields
    )
    # UNION ALL rather than UNION: a plain UNION discards duplicate rows, so
    # objects with identical column values would silently not be inserted.
    other_row = "UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))
    return " ".join([first_row] + [other_row] * (num_values - 1))
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'sqlite'
# SQLite requires LIKE statements to include an ESCAPE clause if the value

View File

@ -540,24 +540,16 @@ class Model(object):
order_value = manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()
self._order = order_value
fields = meta.local_fields
if not pk_set:
if force_update:
raise ValueError("Cannot force an update in save() with no primary key.")
values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
for f in meta.local_fields if not isinstance(f, AutoField)]
else:
values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
for f in meta.local_fields]
fields = [f for f in fields if not isinstance(f, AutoField)]
record_exists = False
update_pk = bool(meta.has_auto_field and not pk_set)
if values:
# Create a new record.
result = manager._insert(values, return_id=update_pk, using=using)
else:
# Create a new record with defaults for everything.
result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using)
result = manager._insert([self], fields=fields, return_id=update_pk, using=using, raw=raw)
if update_pk:
setattr(self, meta.pk.attname, result)

View File

@ -430,7 +430,7 @@ class ForeignRelatedObjectsDescriptor(object):
add.alters_data = True
def create(self, **kwargs):
kwargs.update({rel_field.name: instance})
kwargs[rel_field.name] = instance
db = router.db_for_write(rel_model, instance=instance)
return super(RelatedManager, self.db_manager(db)).create(**kwargs)
create.alters_data = True
@ -438,7 +438,7 @@ class ForeignRelatedObjectsDescriptor(object):
def get_or_create(self, **kwargs):
# Update kwargs with the related object that this
# ForeignRelatedObjectsDescriptor knows about.
kwargs.update({rel_field.name: instance})
kwargs[rel_field.name] = instance
db = router.db_for_write(rel_model, instance=instance)
return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
get_or_create.alters_data = True
@ -578,11 +578,13 @@ def create_many_related_manager(superclass, rel=False):
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=new_ids, using=db)
# Add the ones that aren't there already
for obj_id in new_ids:
self.through._default_manager.using(db).create(**{
self.through._default_manager.using(db).bulk_create([
self.through(**{
'%s_id' % source_field_name: self._pk_val,
'%s_id' % target_field_name: obj_id,
})
for obj_id in new_ids
])
if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are inserting the
# duplicate data row for symmetrical reverse entries.
@ -701,12 +703,12 @@ class ReverseManyRelatedObjectsDescriptor(object):
def __init__(self, m2m_field):
self.field = m2m_field
def _through(self):
@property
def through(self):
# through is provided so that you have easy access to the through
# model (Book.authors.through) for inlines, etc. This is done as
# a property to ensure that the fully resolved value is returned.
return self.field.rel.through
through = property(_through)
def __get__(self, instance, instance_type=None):
if instance is None:

View File

@ -136,6 +136,9 @@ class Manager(object):
def create(self, **kwargs):
    """Proxies to ``create()`` on the queryset from ``get_query_set()``."""
    queryset = self.get_query_set()
    return queryset.create(**kwargs)
def bulk_create(self, *args, **kwargs):
    """Proxies to ``bulk_create()`` on the queryset from ``get_query_set()``."""
    queryset = self.get_query_set()
    return queryset.bulk_create(*args, **kwargs)
def filter(self, *args, **kwargs):
    """Proxies to ``filter()`` on the queryset from ``get_query_set()``."""
    queryset = self.get_query_set()
    return queryset.filter(*args, **kwargs)
@ -193,8 +196,8 @@ class Manager(object):
def exists(self, *args, **kwargs):
    """Proxies to ``exists()`` on the queryset from ``get_query_set()``."""
    queryset = self.get_query_set()
    return queryset.exists(*args, **kwargs)
def _insert(self, values, **kwargs):
return insert_query(self.model, values, **kwargs)
def _insert(self, objs, fields, **kwargs):
return insert_query(self.model, objs, fields, **kwargs)
def _update(self, values, **kwargs):
return self.get_query_set()._update(values, **kwargs)

View File

@ -5,10 +5,12 @@ The main QuerySet implementation. This provides the public API for the ORM.
import copy
from django.db import connections, router, transaction, IntegrityError
from django.db.models.fields import AutoField
from django.db.models.query_utils import (Q, select_related_descend,
deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector
from django.db.models import signals, sql
from django.utils.functional import partition
# Used to control how many objects are worked with at once in some cases (e.g.
# when deleting objects).
@ -352,6 +354,41 @@ class QuerySet(object):
obj.save(force_insert=True, using=self.db)
return obj
def bulk_create(self, objs):
    """
    Inserts each of the instances into the database. This does *not* call
    save() on each of the instances, does not send any pre/post save
    signals, and does not set the primary key attribute if it is an
    autoincrement field.

    ``objs`` may be any iterable of model instances; it is materialised
    into a list once. Returns None. Raises ValueError if the model has
    parents (multi-table inheritance).
    """
    # When you bulk insert you don't get the primary keys back (if it's an
    # autoincrement), so you can't insert into the child tables which
    # reference this. There are two workarounds: 1) this could be
    # implemented if you didn't have an autoincrement pk, and 2) you could
    # do O(n) normal inserts into the parent tables to get the primary keys
    # back, and then a single bulk insert into the childmost table. We're
    # punting on these for now because they are relatively rare cases.
    if self.model._meta.parents:
        raise ValueError("Can't bulk create an inherited model")
    # Materialise once: a generator is always truthy (so the emptiness check
    # below would never fire for it) and can only be iterated a single time.
    objs = list(objs)
    if not objs:
        return
    self._for_write = True
    connection = connections[self.db]
    fields = self.model._meta.local_fields
    if (connection.features.can_combine_inserts_with_and_without_auto_increment_pk
            and self.model._meta.has_auto_field):
        # One statement can carry rows both with and without a pk value.
        self.model._base_manager._insert(objs, fields=fields, using=self.db)
    else:
        # partition() puts items failing the predicate first, i.e. the
        # objects whose pk is already set.
        objs_with_pk, objs_without_pk = partition(
            lambda o: o.pk is None,
            objs
        )
        if objs_with_pk:
            self.model._base_manager._insert(objs_with_pk, fields=fields, using=self.db)
        if objs_without_pk:
            # Leave the AutoField column out of the insert for these rows.
            self.model._base_manager._insert(
                objs_without_pk,
                fields=[f for f in fields if not isinstance(f, AutoField)],
                using=self.db
            )
def get_or_create(self, **kwargs):
"""
Looks up an object with the given kwargs, creating one if necessary.
@ -1437,12 +1474,12 @@ class RawQuerySet(object):
self._model_fields[converter(column)] = field
return self._model_fields
def insert_query(model, values, return_id=False, raw_values=False, using=None):
def insert_query(model, objs, fields, return_id=False, raw=False, using=None):
"""
Inserts a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented. It is not
part of the public API.
"""
query = sql.InsertQuery(model)
query.insert_values(values, raw_values)
query.insert_values(fields, objs, raw=raw)
return query.get_compiler(using=using).execute_sql(return_id)

View File

@ -1,3 +1,5 @@
from itertools import izip
from django.core.exceptions import FieldError
from django.db import connections
from django.db import transaction
@ -9,6 +11,7 @@ from django.db.models.sql.query import (get_proxied_model, get_order_dir,
select_related_descend, Query)
from django.db.utils import DatabaseError
class SQLCompiler(object):
def __init__(self, query, connection, using):
self.query = query
@ -794,20 +797,55 @@ class SQLInsertCompiler(SQLCompiler):
qn = self.connection.ops.quote_name
opts = self.query.model._meta
result = ['INSERT INTO %s' % qn(opts.db_table)]
result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
values = [self.placeholder(*v) for v in self.query.values]
result.append('VALUES (%s)' % ', '.join(values))
params = self.query.params
has_fields = bool(self.query.fields)
fields = self.query.fields if has_fields else [opts.pk]
result.append('(%s)' % ', '.join([qn(f.column) for f in fields]))
if has_fields:
params = values = [
[
f.get_db_prep_save(getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True), connection=self.connection)
for f in fields
]
for obj in self.query.objs
]
else:
values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
params = [[]]
fields = [None]
can_bulk = not any(hasattr(field, "get_placeholder") for field in fields) and not self.return_id
if can_bulk:
placeholders = [["%s"] * len(fields)]
else:
placeholders = [
[self.placeholder(field, v) for field, v in izip(fields, val)]
for val in values
]
if self.return_id and self.connection.features.can_return_id_from_insert:
params = values[0]
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
result.append("VALUES (%s)" % ", ".join(placeholders[0]))
r_fmt, r_params = self.connection.ops.return_insert_id()
result.append(r_fmt % col)
params = params + r_params
return ' '.join(result), params
params += r_params
return [(" ".join(result), tuple(params))]
if can_bulk and self.connection.features.has_bulk_insert:
result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
return [(" ".join(result), tuple([v for val in values for v in val]))]
else:
return [
(" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
for p, vals in izip(placeholders, params)
]
def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1)
self.return_id = return_id
cursor = super(SQLInsertCompiler, self).execute_sql(None)
cursor = self.connection.cursor()
for sql, params in self.as_sql():
cursor.execute(sql, params)
if not (return_id and cursor):
return
if self.connection.features.can_return_id_from_insert:

View File

@ -8,9 +8,10 @@ all about the internals of models in order to get the information it needs.
"""
import copy
from django.utils.tree import Node
from django.utils.datastructures import SortedDict
from django.utils.encoding import force_unicode
from django.utils.tree import Node
from django.db import connections, DEFAULT_DB_ALIAS
from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist

View File

@ -136,20 +136,19 @@ class InsertQuery(Query):
def __init__(self, *args, **kwargs):
super(InsertQuery, self).__init__(*args, **kwargs)
self.columns = []
self.values = []
self.params = ()
self.fields = []
self.objs = []
def clone(self, klass=None, **kwargs):
extras = {
'columns': self.columns[:],
'values': self.values[:],
'params': self.params
'fields': self.fields[:],
'objs': self.objs[:],
'raw': self.raw,
}
extras.update(kwargs)
return super(InsertQuery, self).clone(klass, **extras)
def insert_values(self, insert_values, raw_values=False):
def insert_values(self, fields, objs, raw=False):
"""
Set up the insert query from the 'insert_values' dictionary. The
dictionary gives the model field names and their target values.
@ -159,16 +158,9 @@ class InsertQuery(Query):
parameters. This provides a way to insert NULL and DEFAULT keywords
into the query, for example.
"""
placeholders, values = [], []
for field, val in insert_values:
placeholders.append((field, val))
self.columns.append(field.column)
values.append(val)
if raw_values:
self.values.extend([(None, v) for v in values])
else:
self.params += tuple(values)
self.values.extend(placeholders)
self.fields = fields
self.objs = objs
self.raw = raw
class DateQuery(Query):
"""

View File

@ -276,3 +276,16 @@ class lazy_property(property):
def fdel(instance, name=fdel.__name__):
return getattr(instance, name)()
return property(fget, fset, fdel, doc)
def partition(predicate, values):
    """
    Splits the values into two lists, based on the truth value of
    ``predicate(item)``. Items for which it is false come first, items for
    which it is true come second; input order is kept within each list. e.g.:

        >>> partition(lambda x: x > 3, range(5))
        ([0, 1, 2, 3], [4])
    """
    results = ([], [])
    for item in values:
        # Coerce to bool so that any truthy/falsy return value (not only
        # True/False or 1/0) selects a valid index.
        results[bool(predicate(item))].append(item)
    return results

View File

@ -1158,6 +1158,29 @@ has a side effect on your data. For more, see `Safe methods`_ in the HTTP spec.
.. _Safe methods: http://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html#sec9.1.1
bulk_create
~~~~~~~~~~~
.. method:: bulk_create(objs)
This method inserts the provided list of objects into the database in an
efficient manner (generally only 1 query, no matter how many objects there
are)::
>>> Entry.objects.bulk_create([
... Entry(headline="Django 1.0 Released"),
... Entry(headline="Django 1.1 Announced"),
... Entry(headline="Breaking: Django is awesome")
... ])
This has a number of caveats though:
* The model's ``save()`` method will not be called, and the ``pre_save`` and
``post_save`` signals will not be sent.
* It does not work with child models in a multi-table inheritance scenario.
* If the model's primary key is an :class:`~django.db.models.AutoField` it
does not retrieve and set the primary key attribute, as ``save()`` does.
count
~~~~~

View File

@ -252,6 +252,17 @@ filename. For example, the file ``css/styles.css`` would also be saved as
See the :class:`~django.contrib.staticfiles.storage.CachedStaticFilesStorage`
docs for more information.
``Model.objects.bulk_create`` in the ORM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This method allows for more efficient creation of multiple objects in the ORM.
It can provide significant performance increases if you have many objects.
Django makes use of this internally, meaning some operations (such as database
setup for test suites) have seen a performance benefit as a result.
See the :meth:`~django.db.models.query.QuerySet.bulk_create` docs for more
information.
Minor features
~~~~~~~~~~~~~~

View File

@ -268,3 +268,33 @@ instead of::
entry.blog.id
Insert in bulk
==============
When creating objects, where possible, use the
:meth:`~django.db.models.query.QuerySet.bulk_create()` method to reduce the
number of SQL queries. For example::
Entry.objects.bulk_create([
Entry(headline="Python 3.0 Released"),
Entry(headline="Python 3.1 Planned")
])
Is preferable to::
Entry.objects.create(headline="Python 3.0 Released")
Entry.objects.create(headline="Python 3.1 Planned")
Note that there are a number of :meth:`caveats to this method
<django.db.models.query.QuerySet.bulk_create>`, so make sure it is appropriate
for your use case. This also applies to :class:`ManyToManyFields
<django.db.models.ManyToManyField>`, doing::
my_band.members.add(me, my_friend)
Is preferable to::
my_band.members.add(me)
my_band.members.add(my_friend)
Where ``Bands`` and ``Artists`` have a many-to-many relationship.

View File

@ -0,0 +1,21 @@
from django.db import models
# Simple model with two character fields and no explicitly declared primary
# key; BulkCreateTests builds its fixture rows from it.
class Country(models.Model):
name = models.CharField(max_length=255)
iso_two_letter = models.CharField(max_length=2)
# Abstract base model (Meta.abstract = True) providing the ``name`` field to
# its subclasses.
class Place(models.Model):
name = models.CharField(max_length=100)
class Meta:
abstract = True
# Concrete model derived from the abstract Place; test_inheritance expects
# bulk_create to work for it.
class Restaurant(Place):
pass
# Derived from the concrete Restaurant model; test_inheritance expects
# bulk_create to raise ValueError for it.
class Pizzeria(Restaurant):
pass
# Model whose primary key is the caller-supplied ``two_letter_code`` field
# (primary_key=True); exercised by test_non_auto_increment_pk.
class State(models.Model):
two_letter_code = models.CharField(max_length=2, primary_key=True)

View File

@ -0,0 +1,54 @@
from __future__ import with_statement
from operator import attrgetter
from django.test import TestCase, skipUnlessDBFeature
from models import Country, Restaurant, Pizzeria, State
class BulkCreateTests(TestCase):
    """
    Tests for ``bulk_create``: plain inserts, query counts on backends that
    support bulk insert, model inheritance, and non-autoincrement pks.
    """

    def setUp(self):
        # Unsaved instances; each test inserts them itself.
        self.data = [
            Country(name="United States of America", iso_two_letter="US"),
            Country(name="The Netherlands", iso_two_letter="NL"),
            Country(name="Germany", iso_two_letter="DE"),
            Country(name="Czech Republic", iso_two_letter="CZ")
        ]

    def test_simple(self):
        Country.objects.bulk_create(self.data)
        self.assertQuerysetEqual(Country.objects.order_by("-name"), [
            "United States of America", "The Netherlands", "Germany", "Czech Republic"
        ], attrgetter("name"))

    @skipUnlessDBFeature("has_bulk_insert")
    def test_efficiency(self):
        with self.assertNumQueries(1):
            Country.objects.bulk_create(self.data)

    def test_inheritance(self):
        # A child of an abstract base can be bulk created.
        Restaurant.objects.bulk_create([
            Restaurant(name="Nicholas's")
        ])
        self.assertQuerysetEqual(Restaurant.objects.all(), [
            "Nicholas's",
        ], attrgetter("name"))
        # A child of a concrete model is rejected and nothing is inserted.
        with self.assertRaises(ValueError):
            Pizzeria.objects.bulk_create([
                Pizzeria(name="The Art of Pizza")
            ])
        self.assertQuerysetEqual(Pizzeria.objects.all(), [])
        self.assertQuerysetEqual(Restaurant.objects.all(), [
            "Nicholas's",
        ], attrgetter("name"))

    def test_non_auto_increment_pk(self):
        # The inserted rows are checked on every backend; the query count is
        # asserted separately because it only holds with bulk insert support.
        State.objects.bulk_create([
            State(two_letter_code=s)
            for s in ["IL", "NY", "CA", "ME"]
        ])
        self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [
            "CA", "IL", "ME", "NY",
        ], attrgetter("two_letter_code"))

    @skipUnlessDBFeature("has_bulk_insert")
    def test_non_auto_increment_pk_efficiency(self):
        with self.assertNumQueries(1):
            State.objects.bulk_create([
                State(two_letter_code=s)
                for s in ["IL", "NY", "CA", "ME"]
            ])
        self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [
            "CA", "IL", "ME", "NY",
        ], attrgetter("two_letter_code"))

View File

@ -53,10 +53,10 @@ TEST_CASES = {
class DBTypeCasts(unittest.TestCase):
def test_typeCasts(self):
for k, v in TEST_CASES.items():
for k, v in TEST_CASES.iteritems():
for inpt, expected in v:
got = getattr(typecasts, k)(inpt)
assert got == expected, "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got)
self.assertEqual(got, expected, "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got))
if __name__ == '__main__':
unittest.main()