Fixed #7596. Added Model.objects.bulk_create, and make use of it in several places. This provides a performance benefit when inserting multiple objects. THanks to Russ for the review, and Simon Meers for the MySQl implementation.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16739 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2011-09-09 19:22:28 +00:00
parent e55bbf4c3c
commit 7deb25b8dd
22 changed files with 331 additions and 78 deletions

View File

@ -46,17 +46,15 @@ def create_permissions(app, created_models, verbosity, **kwargs):
"content_type", "codename" "content_type", "codename"
)) ))
for ctype, (codename, name) in searched_perms: objs = [
# If the permissions exists, move on. auth_app.Permission(codename=codename, name=name, content_type=ctype)
if (ctype.pk, codename) in all_perms: for ctype, (codename, name) in searched_perms
continue if (ctype.pk, codename) not in all_perms
p = auth_app.Permission.objects.create( ]
codename=codename, auth_app.Permission.objects.bulk_create(objs)
name=name, if verbosity >= 2:
content_type=ctype for obj in objs:
) print "Adding permission '%s'" % obj
if verbosity >= 2:
print "Adding permission '%s'" % p
def create_superuser(app, created_models, verbosity, **kwargs): def create_superuser(app, created_models, verbosity, **kwargs):

View File

@ -8,25 +8,41 @@ def update_contenttypes(app, created_models, verbosity=2, **kwargs):
entries that no longer have a matching model class. entries that no longer have a matching model class.
""" """
ContentType.objects.clear_cache() ContentType.objects.clear_cache()
content_types = list(ContentType.objects.filter(app_label=app.__name__.split('.')[-2]))
app_models = get_models(app) app_models = get_models(app)
if not app_models: if not app_models:
return return
for klass in app_models: # They all have the same app_label, get the first one.
opts = klass._meta app_label = app_models[0]._meta.app_label
try: app_models = dict(
ct = ContentType.objects.get(app_label=opts.app_label, (model._meta.object_name.lower(), model)
model=opts.object_name.lower()) for model in app_models
content_types.remove(ct) )
except ContentType.DoesNotExist: # Get all the content types
ct = ContentType(name=smart_unicode(opts.verbose_name_raw), content_types = dict(
app_label=opts.app_label, model=opts.object_name.lower()) (ct.model, ct)
ct.save() for ct in ContentType.objects.filter(app_label=app_label)
if verbosity >= 2: )
print "Adding content type '%s | %s'" % (ct.app_label, ct.model) to_remove = [
# The presence of any remaining content types means the supplied app has an ct
# undefined model. Confirm that the content type is stale before deletion. for (model_name, ct) in content_types.iteritems()
if content_types: if model_name not in app_models
]
cts = ContentType.objects.bulk_create([
ContentType(
name=smart_unicode(model._meta.verbose_name_raw),
app_label=app_label,
model=model_name,
)
for (model_name, model) in app_models.iteritems()
if model_name not in content_types
])
if verbosity >= 2:
for ct in cts:
print "Adding content type '%s | %s'" % (ct.app_label, ct.model)
# Confirm that the content type is stale before deletion.
if to_remove:
if kwargs.get('interactive', False): if kwargs.get('interactive', False):
content_type_display = '\n'.join([' %s | %s' % (ct.app_label, ct.model) for ct in content_types]) content_type_display = '\n'.join([' %s | %s' % (ct.app_label, ct.model) for ct in content_types])
ok_to_delete = raw_input("""The following content types are stale and need to be deleted: ok_to_delete = raw_input("""The following content types are stale and need to be deleted:
@ -42,7 +58,7 @@ If you're unsure, answer 'no'.
ok_to_delete = False ok_to_delete = False
if ok_to_delete == 'yes': if ok_to_delete == 'yes':
for ct in content_types: for ct in to_remove:
if verbosity >= 2: if verbosity >= 2:
print "Deleting stale content type '%s | %s'" % (ct.app_label, ct.model) print "Deleting stale content type '%s | %s'" % (ct.app_label, ct.model)
ct.delete() ct.delete()

View File

@ -301,8 +301,10 @@ class BaseDatabaseFeatures(object):
can_use_chunked_reads = True can_use_chunked_reads = True
can_return_id_from_insert = False can_return_id_from_insert = False
has_bulk_insert = False
uses_autocommit = False uses_autocommit = False
uses_savepoints = False uses_savepoints = False
can_combine_inserts_with_and_without_auto_increment_pk = False
# If True, don't use integer foreign keys referring to, e.g., positive # If True, don't use integer foreign keys referring to, e.g., positive
# integer primary keys. # integer primary keys.

View File

@ -124,6 +124,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_pk = True allows_group_by_pk = True
related_fields_match_type = True related_fields_match_type = True
allow_sliced_subqueries = False allow_sliced_subqueries = False
has_bulk_insert = True
has_select_for_update = True has_select_for_update = True
has_select_for_update_nowait = False has_select_for_update_nowait = False
supports_forward_references = False supports_forward_references = False
@ -263,6 +264,10 @@ class DatabaseOperations(BaseDatabaseOperations):
def max_name_length(self): def max_name_length(self):
return 64 return 64
def bulk_insert_sql(self, fields, num_values):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values)
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql' vendor = 'mysql'
operators = { operators = {

View File

@ -74,6 +74,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_defer_constraint_checks = True can_defer_constraint_checks = True
has_select_for_update = True has_select_for_update = True
has_select_for_update_nowait = True has_select_for_update_nowait = True
has_bulk_insert = True
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):

View File

@ -180,3 +180,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def return_insert_id(self): def return_insert_id(self):
return "RETURNING %s", () return "RETURNING %s", ()
def bulk_insert_sql(self, fields, num_values):
items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
return "VALUES " + ", ".join([items_sql] * num_values)

View File

@ -58,6 +58,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_unspecified_pk = True supports_unspecified_pk = True
supports_1000_query_parameters = False supports_1000_query_parameters = False
supports_mixed_date_datetime_comparisons = False supports_mixed_date_datetime_comparisons = False
has_bulk_insert = True
can_combine_inserts_with_and_without_auto_increment_pk = True
def _supports_stddev(self): def _supports_stddev(self):
"""Confirm support for STDDEV and related stats functions """Confirm support for STDDEV and related stats functions
@ -106,7 +108,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return "" return ""
def pk_default_value(self): def pk_default_value(self):
return 'NULL' return "NULL"
def quote_name(self, name): def quote_name(self, name):
if name.startswith('"') and name.endswith('"'): if name.startswith('"') and name.endswith('"'):
@ -154,6 +156,14 @@ class DatabaseOperations(BaseDatabaseOperations):
# No field, or the field isn't known to be a decimal or integer # No field, or the field isn't known to be a decimal or integer
return value return value
def bulk_insert_sql(self, fields, num_values):
res = []
res.append("SELECT %s" % ", ".join(
"%%s AS %s" % self.quote_name(f.column) for f in fields
))
res.extend(["UNION SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1))
return " ".join(res)
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'sqlite' vendor = 'sqlite'
# SQLite requires LIKE statements to include an ESCAPE clause if the value # SQLite requires LIKE statements to include an ESCAPE clause if the value

View File

@ -540,24 +540,16 @@ class Model(object):
order_value = manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count() order_value = manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count()
self._order = order_value self._order = order_value
fields = meta.local_fields
if not pk_set: if not pk_set:
if force_update: if force_update:
raise ValueError("Cannot force an update in save() with no primary key.") raise ValueError("Cannot force an update in save() with no primary key.")
values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) fields = [f for f in fields if not isinstance(f, AutoField)]
for f in meta.local_fields if not isinstance(f, AutoField)]
else:
values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection))
for f in meta.local_fields]
record_exists = False record_exists = False
update_pk = bool(meta.has_auto_field and not pk_set) update_pk = bool(meta.has_auto_field and not pk_set)
if values: result = manager._insert([self], fields=fields, return_id=update_pk, using=using, raw=raw)
# Create a new record.
result = manager._insert(values, return_id=update_pk, using=using)
else:
# Create a new record with defaults for everything.
result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using)
if update_pk: if update_pk:
setattr(self, meta.pk.attname, result) setattr(self, meta.pk.attname, result)

View File

@ -430,7 +430,7 @@ class ForeignRelatedObjectsDescriptor(object):
add.alters_data = True add.alters_data = True
def create(self, **kwargs): def create(self, **kwargs):
kwargs.update({rel_field.name: instance}) kwargs[rel_field.name] = instance
db = router.db_for_write(rel_model, instance=instance) db = router.db_for_write(rel_model, instance=instance)
return super(RelatedManager, self.db_manager(db)).create(**kwargs) return super(RelatedManager, self.db_manager(db)).create(**kwargs)
create.alters_data = True create.alters_data = True
@ -438,7 +438,7 @@ class ForeignRelatedObjectsDescriptor(object):
def get_or_create(self, **kwargs): def get_or_create(self, **kwargs):
# Update kwargs with the related object that this # Update kwargs with the related object that this
# ForeignRelatedObjectsDescriptor knows about. # ForeignRelatedObjectsDescriptor knows about.
kwargs.update({rel_field.name: instance}) kwargs[rel_field.name] = instance
db = router.db_for_write(rel_model, instance=instance) db = router.db_for_write(rel_model, instance=instance)
return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs) return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
get_or_create.alters_data = True get_or_create.alters_data = True
@ -578,11 +578,13 @@ def create_many_related_manager(superclass, rel=False):
instance=self.instance, reverse=self.reverse, instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=new_ids, using=db) model=self.model, pk_set=new_ids, using=db)
# Add the ones that aren't there already # Add the ones that aren't there already
for obj_id in new_ids: self.through._default_manager.using(db).bulk_create([
self.through._default_manager.using(db).create(**{ self.through(**{
'%s_id' % source_field_name: self._pk_val, '%s_id' % source_field_name: self._pk_val,
'%s_id' % target_field_name: obj_id, '%s_id' % target_field_name: obj_id,
}) })
for obj_id in new_ids
])
if self.reverse or source_field_name == self.source_field_name: if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are inserting the # Don't send the signal when we are inserting the
# duplicate data row for symmetrical reverse entries. # duplicate data row for symmetrical reverse entries.
@ -701,12 +703,12 @@ class ReverseManyRelatedObjectsDescriptor(object):
def __init__(self, m2m_field): def __init__(self, m2m_field):
self.field = m2m_field self.field = m2m_field
def _through(self): @property
def through(self):
# through is provided so that you have easy access to the through # through is provided so that you have easy access to the through
# model (Book.authors.through) for inlines, etc. This is done as # model (Book.authors.through) for inlines, etc. This is done as
# a property to ensure that the fully resolved value is returned. # a property to ensure that the fully resolved value is returned.
return self.field.rel.through return self.field.rel.through
through = property(_through)
def __get__(self, instance, instance_type=None): def __get__(self, instance, instance_type=None):
if instance is None: if instance is None:

View File

@ -136,6 +136,9 @@ class Manager(object):
def create(self, **kwargs): def create(self, **kwargs):
return self.get_query_set().create(**kwargs) return self.get_query_set().create(**kwargs)
def bulk_create(self, *args, **kwargs):
return self.get_query_set().bulk_create(*args, **kwargs)
def filter(self, *args, **kwargs): def filter(self, *args, **kwargs):
return self.get_query_set().filter(*args, **kwargs) return self.get_query_set().filter(*args, **kwargs)
@ -193,8 +196,8 @@ class Manager(object):
def exists(self, *args, **kwargs): def exists(self, *args, **kwargs):
return self.get_query_set().exists(*args, **kwargs) return self.get_query_set().exists(*args, **kwargs)
def _insert(self, values, **kwargs): def _insert(self, objs, fields, **kwargs):
return insert_query(self.model, values, **kwargs) return insert_query(self.model, objs, fields, **kwargs)
def _update(self, values, **kwargs): def _update(self, values, **kwargs):
return self.get_query_set()._update(values, **kwargs) return self.get_query_set()._update(values, **kwargs)

View File

@ -5,10 +5,12 @@ The main QuerySet implementation. This provides the public API for the ORM.
import copy import copy
from django.db import connections, router, transaction, IntegrityError from django.db import connections, router, transaction, IntegrityError
from django.db.models.fields import AutoField
from django.db.models.query_utils import (Q, select_related_descend, from django.db.models.query_utils import (Q, select_related_descend,
deferred_class_factory, InvalidQuery) deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models import signals, sql from django.db.models import signals, sql
from django.utils.functional import partition
# Used to control how many objects are worked with at once in some cases (e.g. # Used to control how many objects are worked with at once in some cases (e.g.
# when deleting objects). # when deleting objects).
@ -352,6 +354,41 @@ class QuerySet(object):
obj.save(force_insert=True, using=self.db) obj.save(force_insert=True, using=self.db)
return obj return obj
def bulk_create(self, objs):
"""
Inserts each of the instances into the database. This does *not* call
save() on each of the instances, does not send any pre/post save
signals, and does not set the primary key attribute if it is an
autoincrement field.
"""
# So this case is fun. When you bulk insert you don't get the primary
# keys back (if it's an autoincrement), so you can't insert into the
# child tables which references this. There are two workarounds, 1)
# this could be implemented if you didn't have an autoincrement pk,
# and 2) you could do it by doing O(n) normal inserts into the parent
# tables to get the primary keys back, and then doing a single bulk
# insert into the childmost table. We're punting on these for now
# because they are relatively rare cases.
if self.model._meta.parents:
raise ValueError("Can't bulk create an inherited model")
if not objs:
return
self._for_write = True
connection = connections[self.db]
fields = self.model._meta.local_fields
if (connection.features.can_combine_inserts_with_and_without_auto_increment_pk
and self.model._meta.has_auto_field):
self.model._base_manager._insert(objs, fields=fields, using=self.db)
else:
objs_with_pk, objs_without_pk = partition(
lambda o: o.pk is None,
objs
)
if objs_with_pk:
self.model._base_manager._insert(objs_with_pk, fields=fields, using=self.db)
if objs_without_pk:
self.model._base_manager._insert(objs_without_pk, fields=[f for f in fields if not isinstance(f, AutoField)], using=self.db)
def get_or_create(self, **kwargs): def get_or_create(self, **kwargs):
""" """
Looks up an object with the given kwargs, creating one if necessary. Looks up an object with the given kwargs, creating one if necessary.
@ -1437,12 +1474,12 @@ class RawQuerySet(object):
self._model_fields[converter(column)] = field self._model_fields[converter(column)] = field
return self._model_fields return self._model_fields
def insert_query(model, values, return_id=False, raw_values=False, using=None): def insert_query(model, objs, fields, return_id=False, raw=False, using=None):
""" """
Inserts a new record for the given model. This provides an interface to Inserts a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented. It is not the InsertQuery class and is how Model.save() is implemented. It is not
part of the public API. part of the public API.
""" """
query = sql.InsertQuery(model) query = sql.InsertQuery(model)
query.insert_values(values, raw_values) query.insert_values(fields, objs, raw=raw)
return query.get_compiler(using=using).execute_sql(return_id) return query.get_compiler(using=using).execute_sql(return_id)

View File

@ -1,3 +1,5 @@
from itertools import izip
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connections from django.db import connections
from django.db import transaction from django.db import transaction
@ -9,6 +11,7 @@ from django.db.models.sql.query import (get_proxied_model, get_order_dir,
select_related_descend, Query) select_related_descend, Query)
from django.db.utils import DatabaseError from django.db.utils import DatabaseError
class SQLCompiler(object): class SQLCompiler(object):
def __init__(self, query, connection, using): def __init__(self, query, connection, using):
self.query = query self.query = query
@ -794,20 +797,55 @@ class SQLInsertCompiler(SQLCompiler):
qn = self.connection.ops.quote_name qn = self.connection.ops.quote_name
opts = self.query.model._meta opts = self.query.model._meta
result = ['INSERT INTO %s' % qn(opts.db_table)] result = ['INSERT INTO %s' % qn(opts.db_table)]
result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
values = [self.placeholder(*v) for v in self.query.values] has_fields = bool(self.query.fields)
result.append('VALUES (%s)' % ', '.join(values)) fields = self.query.fields if has_fields else [opts.pk]
params = self.query.params result.append('(%s)' % ', '.join([qn(f.column) for f in fields]))
if has_fields:
params = values = [
[
f.get_db_prep_save(getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True), connection=self.connection)
for f in fields
]
for obj in self.query.objs
]
else:
values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
params = [[]]
fields = [None]
can_bulk = not any(hasattr(field, "get_placeholder") for field in fields) and not self.return_id
if can_bulk:
placeholders = [["%s"] * len(fields)]
else:
placeholders = [
[self.placeholder(field, v) for field, v in izip(fields, val)]
for val in values
]
if self.return_id and self.connection.features.can_return_id_from_insert: if self.return_id and self.connection.features.can_return_id_from_insert:
params = values[0]
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
result.append("VALUES (%s)" % ", ".join(placeholders[0]))
r_fmt, r_params = self.connection.ops.return_insert_id() r_fmt, r_params = self.connection.ops.return_insert_id()
result.append(r_fmt % col) result.append(r_fmt % col)
params = params + r_params params += r_params
return ' '.join(result), params return [(" ".join(result), tuple(params))]
if can_bulk and self.connection.features.has_bulk_insert:
result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
return [(" ".join(result), tuple([v for val in values for v in val]))]
else:
return [
(" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
for p, vals in izip(placeholders, params)
]
def execute_sql(self, return_id=False): def execute_sql(self, return_id=False):
assert not (return_id and len(self.query.objs) != 1)
self.return_id = return_id self.return_id = return_id
cursor = super(SQLInsertCompiler, self).execute_sql(None) cursor = self.connection.cursor()
for sql, params in self.as_sql():
cursor.execute(sql, params)
if not (return_id and cursor): if not (return_id and cursor):
return return
if self.connection.features.can_return_id_from_insert: if self.connection.features.can_return_id_from_insert:

View File

@ -8,9 +8,10 @@ all about the internals of models in order to get the information it needs.
""" """
import copy import copy
from django.utils.tree import Node
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.encoding import force_unicode from django.utils.encoding import force_unicode
from django.utils.tree import Node
from django.db import connections, DEFAULT_DB_ALIAS from django.db import connections, DEFAULT_DB_ALIAS
from django.db.models import signals from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist

View File

@ -136,20 +136,19 @@ class InsertQuery(Query):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(InsertQuery, self).__init__(*args, **kwargs) super(InsertQuery, self).__init__(*args, **kwargs)
self.columns = [] self.fields = []
self.values = [] self.objs = []
self.params = ()
def clone(self, klass=None, **kwargs): def clone(self, klass=None, **kwargs):
extras = { extras = {
'columns': self.columns[:], 'fields': self.fields[:],
'values': self.values[:], 'objs': self.objs[:],
'params': self.params 'raw': self.raw,
} }
extras.update(kwargs) extras.update(kwargs)
return super(InsertQuery, self).clone(klass, **extras) return super(InsertQuery, self).clone(klass, **extras)
def insert_values(self, insert_values, raw_values=False): def insert_values(self, fields, objs, raw=False):
""" """
Set up the insert query from the 'insert_values' dictionary. The Set up the insert query from the 'insert_values' dictionary. The
dictionary gives the model field names and their target values. dictionary gives the model field names and their target values.
@ -159,16 +158,9 @@ class InsertQuery(Query):
parameters. This provides a way to insert NULL and DEFAULT keywords parameters. This provides a way to insert NULL and DEFAULT keywords
into the query, for example. into the query, for example.
""" """
placeholders, values = [], [] self.fields = fields
for field, val in insert_values: self.objs = objs
placeholders.append((field, val)) self.raw = raw
self.columns.append(field.column)
values.append(val)
if raw_values:
self.values.extend([(None, v) for v in values])
else:
self.params += tuple(values)
self.values.extend(placeholders)
class DateQuery(Query): class DateQuery(Query):
""" """

View File

@ -275,4 +275,17 @@ class lazy_property(property):
@wraps(fdel) @wraps(fdel)
def fdel(instance, name=fdel.__name__): def fdel(instance, name=fdel.__name__):
return getattr(instance, name)() return getattr(instance, name)()
return property(fget, fset, fdel, doc) return property(fget, fset, fdel, doc)
def partition(predicate, values):
"""
Splits the values into two sets, based on the return value of the function
(True/False). e.g.:
>>> partition(lambda: x > 3, range(5))
[1, 2, 3], [4]
"""
results = ([], [])
for item in values:
results[predicate(item)].append(item)
return results

View File

@ -1158,6 +1158,29 @@ has a side effect on your data. For more, see `Safe methods`_ in the HTTP spec.
.. _Safe methods: http://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html#sec9.1.1 .. _Safe methods: http://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html#sec9.1.1
bulk_create
~~~~~~~~~~~
.. method:: bulk_create(objs)
This method inserts the provided list of objects into the database in an
efficient manner (generally only 1 query, no matter how many objects there
are)::
>>> Entry.objects.bulk_create([
... Entry(headline="Django 1.0 Released"),
... Entry(headline="Django 1.1 Announced"),
... Entry(headline="Breaking: Django is awesome")
... ])
This has a number of caveats though:
* The model's ``save()`` method will not be called, and the ``pre_save`` and
``post_save`` signals will not be sent.
* It does not work with child models in a multi-table inheritance scenario.
* If the model's primary key is an :class:`~django.db.models.AutoField` it
does not retrieve and set the primary key attribute, as ``save()`` does.
count count
~~~~~ ~~~~~

View File

@ -252,6 +252,17 @@ filename. For example, the file ``css/styles.css`` would also be saved as
See the :class:`~django.contrib.staticfiles.storage.CachedStaticFilesStorage` See the :class:`~django.contrib.staticfiles.storage.CachedStaticFilesStorage`
docs for more information. docs for more information.
``Model.objects.bulk_create`` in the ORM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This method allows for more efficient creation of multiple objects in the ORM.
It can provide significant performance increases if you have many objects,
Django makes use of this internally, meaning some operations (such as database
setup for test suites) has seen a performance benefit as a result.
See the :meth:`~django.db.models.query.QuerySet.bulk_create` docs for more
information.
Minor features Minor features
~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~

View File

@ -268,3 +268,33 @@ instead of::
entry.blog.id entry.blog.id
Insert in bulk
==============
When creating objects, where possible, use the
:meth:`~django.db.models.query.QuerySet.bulk_create()` method to reduce the
number of SQL queries. For example::
Entry.objects.bulk_create([
Entry(headline="Python 3.0 Released"),
Entry(headline="Python 3.1 Planned")
])
Is preferable to::
Entry.objects.create(headline="Python 3.0 Released")
Entry.objects.create(headline="Python 3.1 Planned")
Note that there are a number of :meth:`caveats to this method
<django.db.models.query.QuerySet.bulk_create>`, make sure it is appropriate for
your use case. This also applies to :class:`ManyToManyFields
<django.db.models.ManyToManyField>`, doing::
my_band.members.add(me, my_friend)
Is preferable to::
my_band.members.add(me)
my_band.members.add(my_friend)
Where ``Bands`` and ``Artists`` have a many-to-many relationship.

View File

@ -0,0 +1,21 @@
from django.db import models
class Country(models.Model):
name = models.CharField(max_length=255)
iso_two_letter = models.CharField(max_length=2)
class Place(models.Model):
name = models.CharField(max_length=100)
class Meta:
abstract = True
class Restaurant(Place):
pass
class Pizzeria(Restaurant):
pass
class State(models.Model):
two_letter_code = models.CharField(max_length=2, primary_key=True)

View File

@ -0,0 +1,54 @@
from __future__ import with_statement
from operator import attrgetter
from django.test import TestCase, skipUnlessDBFeature
from models import Country, Restaurant, Pizzeria, State
class BulkCreateTests(TestCase):
def setUp(self):
self.data = [
Country(name="United States of America", iso_two_letter="US"),
Country(name="The Netherlands", iso_two_letter="NL"),
Country(name="Germany", iso_two_letter="DE"),
Country(name="Czech Republic", iso_two_letter="CZ")
]
def test_simple(self):
Country.objects.bulk_create(self.data)
self.assertQuerysetEqual(Country.objects.order_by("-name"), [
"United States of America", "The Netherlands", "Germany", "Czech Republic"
], attrgetter("name"))
@skipUnlessDBFeature("has_bulk_insert")
def test_efficiency(self):
with self.assertNumQueries(1):
Country.objects.bulk_create(self.data)
def test_inheritance(self):
Restaurant.objects.bulk_create([
Restaurant(name="Nicholas's")
])
self.assertQuerysetEqual(Restaurant.objects.all(), [
"Nicholas's",
], attrgetter("name"))
with self.assertRaises(ValueError):
Pizzeria.objects.bulk_create([
Pizzeria(name="The Art of Pizza")
])
self.assertQuerysetEqual(Pizzeria.objects.all(), [])
self.assertQuerysetEqual(Restaurant.objects.all(), [
"Nicholas's",
], attrgetter("name"))
def test_non_auto_increment_pk(self):
with self.assertNumQueries(1):
State.objects.bulk_create([
State(two_letter_code=s)
for s in ["IL", "NY", "CA", "ME"]
])
self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [
"CA", "IL", "ME", "NY",
], attrgetter("two_letter_code"))

View File

@ -53,10 +53,10 @@ TEST_CASES = {
class DBTypeCasts(unittest.TestCase): class DBTypeCasts(unittest.TestCase):
def test_typeCasts(self): def test_typeCasts(self):
for k, v in TEST_CASES.items(): for k, v in TEST_CASES.iteritems():
for inpt, expected in v: for inpt, expected in v:
got = getattr(typecasts, k)(inpt) got = getattr(typecasts, k)(inpt)
assert got == expected, "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got) self.assertEqual(got, expected, "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()