diff --git a/django/contrib/auth/management/__init__.py b/django/contrib/auth/management/__init__.py index f82060e81c..532654f92d 100644 --- a/django/contrib/auth/management/__init__.py +++ b/django/contrib/auth/management/__init__.py @@ -46,17 +46,15 @@ def create_permissions(app, created_models, verbosity, **kwargs): "content_type", "codename" )) - for ctype, (codename, name) in searched_perms: - # If the permissions exists, move on. - if (ctype.pk, codename) in all_perms: - continue - p = auth_app.Permission.objects.create( - codename=codename, - name=name, - content_type=ctype - ) - if verbosity >= 2: - print "Adding permission '%s'" % p + objs = [ + auth_app.Permission(codename=codename, name=name, content_type=ctype) + for ctype, (codename, name) in searched_perms + if (ctype.pk, codename) not in all_perms + ] + auth_app.Permission.objects.bulk_create(objs) + if verbosity >= 2: + for obj in objs: + print "Adding permission '%s'" % obj def create_superuser(app, created_models, verbosity, **kwargs): diff --git a/django/contrib/contenttypes/management.py b/django/contrib/contenttypes/management.py index 27d12751f2..5c7e12a847 100644 --- a/django/contrib/contenttypes/management.py +++ b/django/contrib/contenttypes/management.py @@ -8,25 +8,41 @@ def update_contenttypes(app, created_models, verbosity=2, **kwargs): entries that no longer have a matching model class. """ ContentType.objects.clear_cache() - content_types = list(ContentType.objects.filter(app_label=app.__name__.split('.')[-2])) app_models = get_models(app) if not app_models: return - for klass in app_models: - opts = klass._meta - try: - ct = ContentType.objects.get(app_label=opts.app_label, - model=opts.object_name.lower()) - content_types.remove(ct) - except ContentType.DoesNotExist: - ct = ContentType(name=smart_unicode(opts.verbose_name_raw), - app_label=opts.app_label, model=opts.object_name.lower()) - ct.save() - if verbosity >= 2: - print "Adding content type '%s | %s'" % (ct.app_label, ct.model) - # The presence of any remaining content types means the supplied app has an - # undefined model. Confirm that the content type is stale before deletion. - if content_types: + # They all have the same app_label, get the first one. + app_label = app_models[0]._meta.app_label + app_models = dict( + (model._meta.object_name.lower(), model) + for model in app_models + ) + # Get all the content types + content_types = dict( + (ct.model, ct) + for ct in ContentType.objects.filter(app_label=app_label) + ) + to_remove = [ + ct + for (model_name, ct) in content_types.iteritems() + if model_name not in app_models + ] + + cts = ContentType.objects.bulk_create([ + ContentType( + name=smart_unicode(model._meta.verbose_name_raw), + app_label=app_label, + model=model_name, + ) + for (model_name, model) in app_models.iteritems() + if model_name not in content_types + ]) + if verbosity >= 2: + for ct in cts: + print "Adding content type '%s | %s'" % (ct.app_label, ct.model) + + # Confirm that the content type is stale before deletion. + if to_remove: if kwargs.get('interactive', False): content_type_display = '\n'.join([' %s | %s' % (ct.app_label, ct.model) for ct in content_types]) ok_to_delete = raw_input("""The following content types are stale and need to be deleted: @@ -42,7 +58,7 @@ If you're unsure, answer 'no'. ok_to_delete = False if ok_to_delete == 'yes': - for ct in content_types: + for ct in to_remove: if verbosity >= 2: print "Deleting stale content type '%s | %s'" % (ct.app_label, ct.model) ct.delete() diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 23ddedb4c6..06ce9deb72 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -301,8 +301,10 @@ class BaseDatabaseFeatures(object): can_use_chunked_reads = True can_return_id_from_insert = False + has_bulk_insert = False uses_autocommit = False uses_savepoints = False + can_combine_inserts_with_and_without_auto_increment_pk = False # If True, don't use integer foreign keys referring to, e.g., positive # integer primary keys. diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 2dd9e84ebe..a22951a3b7 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -124,6 +124,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): allows_group_by_pk = True related_fields_match_type = True allow_sliced_subqueries = False + has_bulk_insert = True has_select_for_update = True has_select_for_update_nowait = False supports_forward_references = False @@ -263,6 +264,10 @@ class DatabaseOperations(BaseDatabaseOperations): def max_name_length(self): return 64 + def bulk_insert_sql(self, fields, num_values): + items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) + return "VALUES " + ", ".join([items_sql] * num_values) + class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'mysql' operators = { diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 7efa6e67b2..f0a89e50a8 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -74,6 +74,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): can_defer_constraint_checks = True has_select_for_update = True has_select_for_update_nowait = True + has_bulk_insert = True class DatabaseWrapper(BaseDatabaseWrapper): diff --git a/django/db/backends/postgresql_psycopg2/operations.py b/django/db/backends/postgresql_psycopg2/operations.py index d535ee316f..c3a23c2073 100644 --- a/django/db/backends/postgresql_psycopg2/operations.py +++ b/django/db/backends/postgresql_psycopg2/operations.py @@ -180,3 +180,7 @@ class DatabaseOperations(BaseDatabaseOperations): def return_insert_id(self): return "RETURNING %s", () + + def bulk_insert_sql(self, fields, num_values): + items_sql = "(%s)" % ", ".join(["%s"] * len(fields)) + return "VALUES " + ", ".join([items_sql] * num_values) diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index e3c12f7583..b45c0fb935 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -58,6 +58,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_unspecified_pk = True supports_1000_query_parameters = False supports_mixed_date_datetime_comparisons = False + has_bulk_insert = True + can_combine_inserts_with_and_without_auto_increment_pk = True def _supports_stddev(self): """Confirm support for STDDEV and related stats functions @@ -106,7 +108,7 @@ class DatabaseOperations(BaseDatabaseOperations): return "" def pk_default_value(self): - return 'NULL' + return "NULL" def quote_name(self, name): if name.startswith('"') and name.endswith('"'): @@ -154,6 +156,14 @@ class DatabaseOperations(BaseDatabaseOperations): # No field, or the field isn't known to be a decimal or integer return value + def bulk_insert_sql(self, fields, num_values): + res = [] + res.append("SELECT %s" % ", ".join( + "%%s AS %s" % self.quote_name(f.column) for f in fields + )) + res.extend(["UNION SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1)) + return " ".join(res) + class DatabaseWrapper(BaseDatabaseWrapper): vendor = 'sqlite' # SQLite requires LIKE statements to include an ESCAPE clause if the value diff --git a/django/db/models/base.py b/django/db/models/base.py index 71fd1f78bb..4b3220b514 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -540,24 +540,16 @@ class Model(object): order_value = manager.using(using).filter(**{field.name: getattr(self, field.attname)}).count() self._order = order_value + fields = meta.local_fields if not pk_set: if force_update: raise ValueError("Cannot force an update in save() with no primary key.") - values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) - for f in meta.local_fields if not isinstance(f, AutoField)] - else: - values = [(f, f.get_db_prep_save(raw and getattr(self, f.attname) or f.pre_save(self, True), connection=connection)) - for f in meta.local_fields] + fields = [f for f in fields if not isinstance(f, AutoField)] record_exists = False update_pk = bool(meta.has_auto_field and not pk_set) - if values: - # Create a new record. - result = manager._insert(values, return_id=update_pk, using=using) - else: - # Create a new record with defaults for everything. - result = manager._insert([(meta.pk, connection.ops.pk_default_value())], return_id=update_pk, raw_values=True, using=using) + result = manager._insert([self], fields=fields, return_id=update_pk, using=using, raw=raw) if update_pk: setattr(self, meta.pk.attname, result) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index d904f56cd6..fa7b482f24 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -430,7 +430,7 @@ class ForeignRelatedObjectsDescriptor(object): add.alters_data = True def create(self, **kwargs): - kwargs.update({rel_field.name: instance}) + kwargs[rel_field.name] = instance db = router.db_for_write(rel_model, instance=instance) return super(RelatedManager, self.db_manager(db)).create(**kwargs) create.alters_data = True @@ -438,7 +438,7 @@ class ForeignRelatedObjectsDescriptor(object): def get_or_create(self, **kwargs): # Update kwargs with the related object that this # ForeignRelatedObjectsDescriptor knows about. - kwargs.update({rel_field.name: instance}) + kwargs[rel_field.name] = instance db = router.db_for_write(rel_model, instance=instance) return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs) get_or_create.alters_data = True @@ -578,11 +578,13 @@ def create_many_related_manager(superclass, rel=False): instance=self.instance, reverse=self.reverse, model=self.model, pk_set=new_ids, using=db) # Add the ones that aren't there already - for obj_id in new_ids: - self.through._default_manager.using(db).create(**{ + self.through._default_manager.using(db).bulk_create([ + self.through(**{ '%s_id' % source_field_name: self._pk_val, '%s_id' % target_field_name: obj_id, }) + for obj_id in new_ids + ]) if self.reverse or source_field_name == self.source_field_name: # Don't send the signal when we are inserting the # duplicate data row for symmetrical reverse entries. @@ -701,12 +703,12 @@ class ReverseManyRelatedObjectsDescriptor(object): def __init__(self, m2m_field): self.field = m2m_field - def _through(self): + @property + def through(self): # through is provided so that you have easy access to the through # model (Book.authors.through) for inlines, etc. This is done as # a property to ensure that the fully resolved value is returned. return self.field.rel.through - through = property(_through) def __get__(self, instance, instance_type=None): if instance is None: diff --git a/django/db/models/manager.py b/django/db/models/manager.py index bdd86bbd45..baf701f6dd 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -136,6 +136,9 @@ class Manager(object): def create(self, **kwargs): return self.get_query_set().create(**kwargs) + def bulk_create(self, *args, **kwargs): + return self.get_query_set().bulk_create(*args, **kwargs) + def filter(self, *args, **kwargs): return self.get_query_set().filter(*args, **kwargs) @@ -193,8 +196,8 @@ class Manager(object): def exists(self, *args, **kwargs): return self.get_query_set().exists(*args, **kwargs) - def _insert(self, values, **kwargs): - return insert_query(self.model, values, **kwargs) + def _insert(self, objs, fields, **kwargs): + return insert_query(self.model, objs, fields, **kwargs) def _update(self, values, **kwargs): return self.get_query_set()._update(values, **kwargs) diff --git a/django/db/models/query.py b/django/db/models/query.py index ff5289c89c..4b6645569f 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -5,10 +5,12 @@ The main QuerySet implementation. This provides the public API for the ORM. import copy from django.db import connections, router, transaction, IntegrityError +from django.db.models.fields import AutoField from django.db.models.query_utils import (Q, select_related_descend, deferred_class_factory, InvalidQuery) from django.db.models.deletion import Collector from django.db.models import signals, sql +from django.utils.functional import partition # Used to control how many objects are worked with at once in some cases (e.g. # when deleting objects). @@ -352,6 +354,41 @@ class QuerySet(object): obj.save(force_insert=True, using=self.db) return obj + def bulk_create(self, objs): + """ + Inserts each of the instances into the database. This does *not* call + save() on each of the instances, does not send any pre/post save + signals, and does not set the primary key attribute if it is an + autoincrement field. + """ + # So this case is fun. When you bulk insert you don't get the primary + # keys back (if it's an autoincrement), so you can't insert into the + # child tables which references this. There are two workarounds, 1) + # this could be implemented if you didn't have an autoincrement pk, + # and 2) you could do it by doing O(n) normal inserts into the parent + # tables to get the primary keys back, and then doing a single bulk + # insert into the childmost table. We're punting on these for now + # because they are relatively rare cases. + if self.model._meta.parents: + raise ValueError("Can't bulk create an inherited model") + if not objs: + return + self._for_write = True + connection = connections[self.db] + fields = self.model._meta.local_fields + if (connection.features.can_combine_inserts_with_and_without_auto_increment_pk + and self.model._meta.has_auto_field): + self.model._base_manager._insert(objs, fields=fields, using=self.db) + else: + objs_with_pk, objs_without_pk = partition( + lambda o: o.pk is None, + objs + ) + if objs_with_pk: + self.model._base_manager._insert(objs_with_pk, fields=fields, using=self.db) + if objs_without_pk: + self.model._base_manager._insert(objs_without_pk, fields=[f for f in fields if not isinstance(f, AutoField)], using=self.db) + def get_or_create(self, **kwargs): """ Looks up an object with the given kwargs, creating one if necessary. @@ -1437,12 +1474,12 @@ class RawQuerySet(object): self._model_fields[converter(column)] = field return self._model_fields -def insert_query(model, values, return_id=False, raw_values=False, using=None): +def insert_query(model, objs, fields, return_id=False, raw=False, using=None): """ Inserts a new record for the given model. This provides an interface to the InsertQuery class and is how Model.save() is implemented. It is not part of the public API. """ query = sql.InsertQuery(model) - query.insert_values(values, raw_values) + query.insert_values(fields, objs, raw=raw) return query.get_compiler(using=using).execute_sql(return_id) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 05c19f33d5..b8bba4b013 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1,3 +1,5 @@ +from itertools import izip + from django.core.exceptions import FieldError from django.db import connections from django.db import transaction @@ -9,6 +11,7 @@ from django.db.models.sql.query import (get_proxied_model, get_order_dir, select_related_descend, Query) from django.db.utils import DatabaseError + class SQLCompiler(object): def __init__(self, query, connection, using): self.query = query @@ -794,20 +797,55 @@ class SQLInsertCompiler(SQLCompiler): qn = self.connection.ops.quote_name opts = self.query.model._meta result = ['INSERT INTO %s' % qn(opts.db_table)] - result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns])) - values = [self.placeholder(*v) for v in self.query.values] - result.append('VALUES (%s)' % ', '.join(values)) - params = self.query.params + + has_fields = bool(self.query.fields) + fields = self.query.fields if has_fields else [opts.pk] + result.append('(%s)' % ', '.join([qn(f.column) for f in fields])) + + if has_fields: + params = values = [ + [ + f.get_db_prep_save(getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True), connection=self.connection) + for f in fields + ] + for obj in self.query.objs + ] + else: + values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs] + params = [[]] + fields = [None] + can_bulk = not any(hasattr(field, "get_placeholder") for field in fields) and not self.return_id + + if can_bulk: + placeholders = [["%s"] * len(fields)] + else: + placeholders = [ + [self.placeholder(field, v) for field, v in izip(fields, val)] + for val in values + ] if self.return_id and self.connection.features.can_return_id_from_insert: + params = values[0] col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) + result.append("VALUES (%s)" % ", ".join(placeholders[0])) r_fmt, r_params = self.connection.ops.return_insert_id() result.append(r_fmt % col) - params = params + r_params - return ' '.join(result), params + params += r_params + return [(" ".join(result), tuple(params))] + if can_bulk and self.connection.features.has_bulk_insert: + result.append(self.connection.ops.bulk_insert_sql(fields, len(values))) + return [(" ".join(result), tuple([v for val in values for v in val]))] + else: + return [ + (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals) + for p, vals in izip(placeholders, params) + ] def execute_sql(self, return_id=False): + assert not (return_id and len(self.query.objs) != 1) self.return_id = return_id - cursor = super(SQLInsertCompiler, self).execute_sql(None) + cursor = self.connection.cursor() + for sql, params in self.as_sql(): + cursor.execute(sql, params) if not (return_id and cursor): return if self.connection.features.can_return_id_from_insert: diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b55de3d766..101d4ac7a5 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -8,9 +8,10 @@ all about the internals of models in order to get the information it needs. """ import copy -from django.utils.tree import Node + from django.utils.datastructures import SortedDict from django.utils.encoding import force_unicode +from django.utils.tree import Node from django.db import connections, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.fields import FieldDoesNotExist diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index ecde857a9b..1b03647595 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -136,20 +136,19 @@ class InsertQuery(Query): def __init__(self, *args, **kwargs): super(InsertQuery, self).__init__(*args, **kwargs) - self.columns = [] - self.values = [] - self.params = () + self.fields = [] + self.objs = [] def clone(self, klass=None, **kwargs): extras = { - 'columns': self.columns[:], - 'values': self.values[:], - 'params': self.params + 'fields': self.fields[:], + 'objs': self.objs[:], + 'raw': self.raw, } extras.update(kwargs) return super(InsertQuery, self).clone(klass, **extras) - def insert_values(self, insert_values, raw_values=False): + def insert_values(self, fields, objs, raw=False): """ Set up the insert query from the 'insert_values' dictionary. The dictionary gives the model field names and their target values. @@ -159,16 +158,9 @@ class InsertQuery(Query): parameters. This provides a way to insert NULL and DEFAULT keywords into the query, for example. """ - placeholders, values = [], [] - for field, val in insert_values: - placeholders.append((field, val)) - self.columns.append(field.column) - values.append(val) - if raw_values: - self.values.extend([(None, v) for v in values]) - else: - self.params += tuple(values) - self.values.extend(placeholders) + self.fields = fields + self.objs = objs + self.raw = raw class DateQuery(Query): """ diff --git a/django/utils/functional.py b/django/utils/functional.py index c8f8ee33c7..1345d3b005 100644 --- a/django/utils/functional.py +++ b/django/utils/functional.py @@ -275,4 +275,17 @@ class lazy_property(property): @wraps(fdel) def fdel(instance, name=fdel.__name__): return getattr(instance, name)() - return property(fget, fset, fdel, doc) \ No newline at end of file + return property(fget, fset, fdel, doc) + +def partition(predicate, values): + """ + Splits the values into two sets, based on the return value of the function + (True/False). e.g.: + + >>> partition(lambda: x > 3, range(5)) + [1, 2, 3], [4] + """ + results = ([], []) + for item in values: + results[predicate(item)].append(item) + return results \ No newline at end of file diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index b2c521e645..898306d67d 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -1158,6 +1158,29 @@ has a side effect on your data. For more, see `Safe methods`_ in the HTTP spec. .. _Safe methods: http://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html#sec9.1.1 +bulk_create +~~~~~~~~~~~ + +.. method:: bulk_create(objs) + +This method inserts the provided list of objects into the database in an +efficient manner (generally only 1 query, no matter how many objects there +are):: + + >>> Entry.objects.bulk_create([ + ... Entry(headline="Django 1.0 Released"), + ... Entry(headline="Django 1.1 Announced"), + ... Entry(headline="Breaking: Django is awesome") + ... ]) + +This has a number of caveats though: + + * The model's ``save()`` method will not be called, and the ``pre_save`` and + ``post_save`` signals will not be sent. + * It does not work with child models in a multi-table inheritance scenario. + * If the model's primary key is an :class:`~django.db.models.AutoField` it + does not retrieve and set the primary key attribute, as ``save()`` does. + count ~~~~~ diff --git a/docs/releases/1.4.txt b/docs/releases/1.4.txt index 81e1b72d69..ba05f2805f 100644 --- a/docs/releases/1.4.txt +++ b/docs/releases/1.4.txt @@ -252,6 +252,17 @@ filename. For example, the file ``css/styles.css`` would also be saved as See the :class:`~django.contrib.staticfiles.storage.CachedStaticFilesStorage` docs for more information. +``Model.objects.bulk_create`` in the ORM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This method allows for more efficient creation of multiple objects in the ORM. +It can provide significant performance increases if you have many objects, +Django makes use of this internally, meaning some operations (such as database +setup for test suites) has seen a performance benefit as a result. + +See the :meth:`~django.db.models.query.QuerySet.bulk_create` docs for more +information. + Minor features ~~~~~~~~~~~~~~ diff --git a/docs/topics/db/optimization.txt b/docs/topics/db/optimization.txt index 265ef55fae..5093917b61 100644 --- a/docs/topics/db/optimization.txt +++ b/docs/topics/db/optimization.txt @@ -268,3 +268,33 @@ instead of:: entry.blog.id +Insert in bulk +============== + +When creating objects, where possible, use the +:meth:`~django.db.models.query.QuerySet.bulk_create()` method to reduce the +number of SQL queries. For example:: + + Entry.objects.bulk_create([ + Entry(headline="Python 3.0 Released"), + Entry(headline="Python 3.1 Planned") + ]) + +Is preferable to:: + + Entry.objects.create(headline="Python 3.0 Released") + Entry.objects.create(headline="Python 3.1 Planned") + +Note that there are a number of :meth:`caveats to this method +`, make sure it is appropriate for +your use case. This also applies to :class:`ManyToManyFields +`, doing:: + + my_band.members.add(me, my_friend) + +Is preferable to:: + + my_band.members.add(me) + my_band.members.add(my_friend) + +Where ``Bands`` and ``Artists`` have a many-to-many relationship. diff --git a/tests/regressiontests/bulk_create/__init__.py b/tests/regressiontests/bulk_create/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/regressiontests/bulk_create/models.py b/tests/regressiontests/bulk_create/models.py new file mode 100644 index 0000000000..a4c611d537 --- /dev/null +++ b/tests/regressiontests/bulk_create/models.py @@ -0,0 +1,21 @@ +from django.db import models + + +class Country(models.Model): + name = models.CharField(max_length=255) + iso_two_letter = models.CharField(max_length=2) + +class Place(models.Model): + name = models.CharField(max_length=100) + + class Meta: + abstract = True + +class Restaurant(Place): + pass + +class Pizzeria(Restaurant): + pass + +class State(models.Model): + two_letter_code = models.CharField(max_length=2, primary_key=True) \ No newline at end of file diff --git a/tests/regressiontests/bulk_create/tests.py b/tests/regressiontests/bulk_create/tests.py new file mode 100644 index 0000000000..8d3faa2379 --- /dev/null +++ b/tests/regressiontests/bulk_create/tests.py @@ -0,0 +1,54 @@ +from __future__ import with_statement + +from operator import attrgetter + +from django.test import TestCase, skipUnlessDBFeature + +from models import Country, Restaurant, Pizzeria, State + + +class BulkCreateTests(TestCase): + def setUp(self): + self.data = [ + Country(name="United States of America", iso_two_letter="US"), + Country(name="The Netherlands", iso_two_letter="NL"), + Country(name="Germany", iso_two_letter="DE"), + Country(name="Czech Republic", iso_two_letter="CZ") + ] + + def test_simple(self): + Country.objects.bulk_create(self.data) + self.assertQuerysetEqual(Country.objects.order_by("-name"), [ + "United States of America", "The Netherlands", "Germany", "Czech Republic" + ], attrgetter("name")) + + @skipUnlessDBFeature("has_bulk_insert") + def test_efficiency(self): + with self.assertNumQueries(1): + Country.objects.bulk_create(self.data) + + def test_inheritance(self): + Restaurant.objects.bulk_create([ + Restaurant(name="Nicholas's") + ]) + self.assertQuerysetEqual(Restaurant.objects.all(), [ + "Nicholas's", + ], attrgetter("name")) + with self.assertRaises(ValueError): + Pizzeria.objects.bulk_create([ + Pizzeria(name="The Art of Pizza") + ]) + self.assertQuerysetEqual(Pizzeria.objects.all(), []) + self.assertQuerysetEqual(Restaurant.objects.all(), [ + "Nicholas's", + ], attrgetter("name")) + + def test_non_auto_increment_pk(self): + with self.assertNumQueries(1): + State.objects.bulk_create([ + State(two_letter_code=s) + for s in ["IL", "NY", "CA", "ME"] + ]) + self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [ + "CA", "IL", "ME", "NY", + ], attrgetter("two_letter_code")) \ No newline at end of file diff --git a/tests/regressiontests/db_typecasts/tests.py b/tests/regressiontests/db_typecasts/tests.py index 8c71c8f809..1d3bbfa101 100644 --- a/tests/regressiontests/db_typecasts/tests.py +++ b/tests/regressiontests/db_typecasts/tests.py @@ -53,10 +53,10 @@ TEST_CASES = { class DBTypeCasts(unittest.TestCase): def test_typeCasts(self): - for k, v in TEST_CASES.items(): + for k, v in TEST_CASES.iteritems(): for inpt, expected in v: got = getattr(typecasts, k)(inpt) - assert got == expected, "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got) + self.assertEqual(got, expected, "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got)) if __name__ == '__main__': unittest.main()