diff --git a/django/contrib/contenttypes/models.py b/django/contrib/contenttypes/models.py index 4823e2251a7..26a3db17c90 100644 --- a/django/contrib/contenttypes/models.py +++ b/django/contrib/contenttypes/models.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals +from collections import defaultdict + from django.apps import apps from django.db import models from django.utils.encoding import force_text, python_2_unicode_compatible @@ -65,12 +67,12 @@ class ContentTypeManager(models.Manager): Given *models, returns a dictionary mapping {model: content_type}. """ for_concrete_models = kwargs.pop('for_concrete_models', True) - # Final results results = {} - # models that aren't already in the cache + # Models that aren't already in the cache. needed_app_labels = set() needed_models = set() - needed_opts = set() + # Mapping of opts to the list of models requiring it. + needed_opts = defaultdict(list) for model in models: opts = self._get_opts(model, for_concrete_models) try: @@ -78,28 +80,30 @@ class ContentTypeManager(models.Manager): except KeyError: needed_app_labels.add(opts.app_label) needed_models.add(opts.model_name) - needed_opts.add(opts) + needed_opts[opts].append(model) else: results[model] = ct if needed_opts: + # Lookup required content types from the DB. cts = self.filter( app_label__in=needed_app_labels, model__in=needed_models ) for ct in cts: model = ct.model_class() - if model._meta in needed_opts: + opts_models = needed_opts.pop(ct.model_class()._meta, []) + for model in opts_models: results[model] = ct - needed_opts.remove(model._meta) self._add_to_cache(self.db, ct) - for opts in needed_opts: - # These weren't in the cache, or the DB, create them. + # Create content types that weren't in the cache or DB. + for opts, opts_models in needed_opts.items(): ct = self.create( app_label=opts.app_label, model=opts.model_name, ) self._add_to_cache(self.db, ct) - results[ct.model_class()] = ct + for model in opts_models: + results[model] = ct return results def get_for_id(self, id): diff --git a/tests/contenttypes_tests/test_models.py b/tests/contenttypes_tests/test_models.py index cf0c188fa81..a51a343c839 100644 --- a/tests/contenttypes_tests/test_models.py +++ b/tests/contenttypes_tests/test_models.py @@ -54,13 +54,26 @@ class ContentTypesTests(TestCase): with self.assertNumQueries(0): ContentType.objects.get_by_natural_key('contenttypes', 'contenttype') - def test_get_for_models_empty_cache(self): - # Empty cache. - with self.assertNumQueries(1): - cts = ContentType.objects.get_for_models(ContentType, FooWithUrl) + def test_get_for_models_creation(self): + ContentType.objects.all().delete() + with self.assertNumQueries(4): + cts = ContentType.objects.get_for_models(ContentType, FooWithUrl, ProxyModel, ConcreteModel) self.assertEqual(cts, { ContentType: ContentType.objects.get_for_model(ContentType), FooWithUrl: ContentType.objects.get_for_model(FooWithUrl), + ProxyModel: ContentType.objects.get_for_model(ProxyModel), + ConcreteModel: ContentType.objects.get_for_model(ConcreteModel), + }) + + def test_get_for_models_empty_cache(self): + # Empty cache. + with self.assertNumQueries(1): + cts = ContentType.objects.get_for_models(ContentType, FooWithUrl, ProxyModel, ConcreteModel) + self.assertEqual(cts, { + ContentType: ContentType.objects.get_for_model(ContentType), + FooWithUrl: ContentType.objects.get_for_model(FooWithUrl), + ProxyModel: ContentType.objects.get_for_model(ProxyModel), + ConcreteModel: ContentType.objects.get_for_model(ConcreteModel), }) def test_get_for_models_partial_cache(self):