Fixed #29725 -- Removed unnecessary join in QuerySet.count() and exists() on a many to many relation.

Co-Authored-By: Shiwei Chen <april.chen.0615@gmail.com>
This commit is contained in:
ontowhee 2023-05-16 19:12:53 -07:00 committed by Mariusz Felisiak
parent 0d8fbe2ade
commit 66e47ac69a
3 changed files with 151 additions and 10 deletions

View File

@ -75,7 +75,7 @@ from django.db import (
router, router,
transaction, transaction,
) )
from django.db.models import Q, Window, signals from django.db.models import Manager, Q, Window, signals
from django.db.models.functions import RowNumber from django.db.models.functions import RowNumber
from django.db.models.lookups import GreaterThan, LessThanOrEqual from django.db.models.lookups import GreaterThan, LessThanOrEqual
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
@ -1121,6 +1121,12 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
queryset._defer_next_filter = True queryset._defer_next_filter = True
return queryset._next_is_sticky().filter(**self.core_filters) return queryset._next_is_sticky().filter(**self.core_filters)
def get_prefetch_cache(self):
try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
except (AttributeError, KeyError):
return None
def _remove_prefetched_objects(self): def _remove_prefetched_objects(self):
try: try:
self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name) self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
@ -1128,9 +1134,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
pass # nothing to clear from cache pass # nothing to clear from cache
def get_queryset(self): def get_queryset(self):
try: if (cache := self.get_prefetch_cache()) is not None:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name] return cache
except (AttributeError, KeyError): else:
queryset = super().get_queryset() queryset = super().get_queryset()
return self._apply_rel_filters(queryset) return self._apply_rel_filters(queryset)
@ -1195,6 +1201,45 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
False, False,
) )
@property
def constrained_target(self):
# If the through relation's target field's foreign integrity is
# enforced, the query can be performed solely against the through
# table as the INNER JOIN'ing against target table is unnecessary.
if not self.target_field.db_constraint:
return None
db = router.db_for_read(self.through, instance=self.instance)
if not connections[db].features.supports_foreign_keys:
return None
hints = {"instance": self.instance}
manager = self.through._base_manager.db_manager(db, hints=hints)
filters = {self.source_field_name: self.instance.pk}
# Nullable target rows must be excluded as well as they would have
# been filtered out from an INNER JOIN.
if self.target_field.null:
filters["%s__isnull" % self.target_field_name] = False
return manager.filter(**filters)
def exists(self):
if (
superclass is Manager
and self.get_prefetch_cache() is None
and (constrained_target := self.constrained_target) is not None
):
return constrained_target.exists()
else:
return super().exists()
def count(self):
if (
superclass is Manager
and self.get_prefetch_cache() is None
and (constrained_target := self.constrained_target) is not None
):
return constrained_target.count()
else:
return super().count()
def add(self, *objs, through_defaults=None): def add(self, *objs, through_defaults=None):
self._remove_prefetched_objects() self._remove_prefetched_objects()
db = router.db_for_write(self.through, instance=self.instance) db = router.db_for_write(self.through, instance=self.instance)

View File

@ -78,3 +78,15 @@ class InheritedArticleA(AbstractArticle):
class InheritedArticleB(AbstractArticle): class InheritedArticleB(AbstractArticle):
pass pass
class NullableTargetArticle(models.Model):
headline = models.CharField(max_length=100)
publications = models.ManyToManyField(
Publication, through="NullablePublicationThrough"
)
class NullablePublicationThrough(models.Model):
article = models.ForeignKey(NullableTargetArticle, models.CASCADE)
publication = models.ForeignKey(Publication, models.CASCADE, null=True)

View File

@ -1,10 +1,18 @@
from unittest import mock from unittest import mock
from django.db import transaction from django.db import connection, transaction
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.utils.deprecation import RemovedInDjango60Warning from django.utils.deprecation import RemovedInDjango60Warning
from .models import Article, InheritedArticleA, InheritedArticleB, Publication, User from .models import (
Article,
InheritedArticleA,
InheritedArticleB,
NullablePublicationThrough,
NullableTargetArticle,
Publication,
User,
)
class ManyToManyTests(TestCase): class ManyToManyTests(TestCase):
@ -558,10 +566,16 @@ class ManyToManyTests(TestCase):
def test_custom_default_manager_exists_count(self): def test_custom_default_manager_exists_count(self):
a5 = Article.objects.create(headline="deleted") a5 = Article.objects.create(headline="deleted")
a5.publications.add(self.p2) a5.publications.add(self.p2)
self.assertEqual(self.p2.article_set.count(), self.p2.article_set.all().count()) with self.assertNumQueries(2) as ctx:
self.assertEqual(
self.p2.article_set.count(), self.p2.article_set.all().count()
)
self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
with self.assertNumQueries(2) as ctx:
self.assertEqual( self.assertEqual(
self.p3.article_set.exists(), self.p3.article_set.all().exists() self.p3.article_set.exists(), self.p3.article_set.all().exists()
) )
self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
def test_get_prefetch_queryset_warning(self): def test_get_prefetch_queryset_warning(self):
articles = Article.objects.all() articles = Article.objects.all()
@ -582,3 +596,73 @@ class ManyToManyTests(TestCase):
instances=articles, instances=articles,
querysets=[Publication.objects.all(), Publication.objects.all()], querysets=[Publication.objects.all(), Publication.objects.all()],
) )
class ManyToManyQueryTests(TestCase):
"""
SQL is optimized to reference the through table without joining against the
related table when using count() and exists() functions on a queryset for
many to many relations. The optimization applies to the case where there
are no filters.
"""
@classmethod
def setUpTestData(cls):
cls.article = Article.objects.create(
headline="Django lets you build Web apps easily"
)
cls.nullable_target_article = NullableTargetArticle.objects.create(
headline="The python is good"
)
NullablePublicationThrough.objects.create(
article=cls.nullable_target_article, publication=None
)
@skipUnlessDBFeature("supports_foreign_keys")
def test_count_join_optimization(self):
with self.assertNumQueries(1) as ctx:
self.article.publications.count()
self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
with self.assertNumQueries(1) as ctx:
self.article.publications.count()
self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
self.assertEqual(self.nullable_target_article.publications.count(), 0)
def test_count_join_optimization_disabled(self):
with (
mock.patch.object(connection.features, "supports_foreign_keys", False),
self.assertNumQueries(1) as ctx,
):
self.article.publications.count()
self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
@skipUnlessDBFeature("supports_foreign_keys")
def test_exists_join_optimization(self):
with self.assertNumQueries(1) as ctx:
self.article.publications.exists()
self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
self.article.publications.prefetch_related()
with self.assertNumQueries(1) as ctx:
self.article.publications.exists()
self.assertNotIn("JOIN", ctx.captured_queries[0]["sql"])
self.assertIs(self.nullable_target_article.publications.exists(), False)
def test_exists_join_optimization_disabled(self):
with (
mock.patch.object(connection.features, "supports_foreign_keys", False),
self.assertNumQueries(1) as ctx,
):
self.article.publications.exists()
self.assertIn("JOIN", ctx.captured_queries[0]["sql"])
def test_prefetch_related_no_queries_optimization_disabled(self):
qs = Article.objects.prefetch_related("publications")
article = qs.get()
with self.assertNumQueries(0):
article.publications.count()
with self.assertNumQueries(0):
article.publications.exists()