diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index a030112e75..c71bc634aa 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -772,17 +772,37 @@ class Query(object): unref_amount = cur_refcount - to_counts.get(alias, 0) self.unref_alias(alias, unref_amount) - def promote_unused_aliases(self, initial_refcounts, used_aliases): + def promote_disjunction(self, aliases_before, alias_usage_counts, + num_childs): """ - Given a "before" copy of the alias_refcounts dictionary (as - 'initial_refcounts') and a collection of aliases that may have been - changed or created, works out which aliases have been created since - then and which ones haven't been used and promotes all of those - aliases, plus any children of theirs in the alias tree, to outer joins. + This method is to be used for promoting joins in ORed filters. + + The principle for promotion is: any alias which is used (it is in + alias_usage_counts), is not used by every child of the ORed filter, + and isn't pre-existing needs to be promoted to LOUTER join. + + Some examples (assume all joins used are nullable): + - existing filter: a__f1=foo + - add filter: b__f1=foo|b__f2=foo + In this case we should not promote either of the joins (using INNER + doesn't remove results). We correctly avoid join promotion, because + a is not used in this branch, and b is used two times. + + - add filter a__f1=foo|b__f2=foo + In this case we should promote both a and b, otherwise they will + remove results. We will also correctly do that as both aliases are + used, and in addition both are used only once while there are two + filters. + + - existing: a__f1=bar + - add filter: a__f2=foo|b__f2=foo + We will not promote a as it is previously used. If the join results + in null, the existing filter can't succeed. 
+ + The above (and some more) are tested in queries.DisjunctionPromotionTests """ - for alias in self.tables: - if alias in used_aliases and (alias not in initial_refcounts or - self.alias_refcount[alias] == initial_refcounts[alias]): + for alias, use_count in alias_usage_counts.items(): + if use_count < num_childs and alias not in aliases_before: self.promote_joins([alias]) def change_aliases(self, change_map): @@ -1150,16 +1170,12 @@ class Query(object): can_reuse) return - table_promote = False - join_promote = False - if (lookup_type == 'isnull' and value is True and not negate and len(join_list) > 1): # If the comparison is against NULL, we may need to use some left # outer joins when creating the join chain. This is only done when # needed, as it's less efficient at the database level. self.promote_joins(join_list) - join_promote = True # Process the join list to see if we can remove any inner joins from # the far end (fewer tables in a query is better). Note that join @@ -1167,39 +1183,6 @@ class Query(object): # information available when reusing joins. col, alias, join_list = self.trim_joins(target, join_list, path) - if connector == OR: - # Some joins may need to be promoted when adding a new filter to a - # disjunction. We walk the list of new joins and where it diverges - # from any previous joins (ref count is 1 in the table list), we - # make the new additions (and any existing ones not used in the new - # join list) an outer join. - join_it = iter(join_list) - table_it = iter(self.tables) - next(join_it), next(table_it) - unconditional = False - for join in join_it: - table = next(table_it) - # Once we hit an outer join, all subsequent joins must - # also be promoted, regardless of whether they have been - # promoted as a result of this pass through the tables. 
- unconditional = (unconditional or - self.alias_map[join].join_type == self.LOUTER) - if join == table and self.alias_refcount[join] > 1: - # We have more than one reference to this join table. - # This means that we are dealing with two different query - # subtrees, so we don't need to do any join promotion. - continue - join_promote = join_promote or self.promote_joins([join], unconditional) - if table != join: - table_promote = self.promote_joins([table]) - # We only get here if we have found a table that exists - # in the join list, but isn't on the original tables list. - # This means we've reached the point where we only have - # new tables, so we can break out of this promotion loop. - break - self.promote_joins(join_it, join_promote) - self.promote_joins(table_it, table_promote or join_promote) - if having_clause or force_having: if (alias, col) not in self.group_by: self.group_by.append((alias, col)) @@ -1256,33 +1239,36 @@ class Query(object): subtree = True else: subtree = False - connector = AND + connector = q_object.connector + if connector == OR: + alias_usage_counts = dict() + aliases_before = set(self.tables) if q_object.connector == OR and not force_having: force_having = self.need_force_having(q_object) for child in q_object.children: - if connector == OR: - refcounts_before = self.alias_refcount.copy() if force_having: self.having.start_subtree(connector) else: self.where.start_subtree(connector) + if connector == OR: + refcounts_before = self.alias_refcount.copy() if isinstance(child, Node): self.add_q(child, used_aliases, force_having=force_having) else: self.add_filter(child, connector, q_object.negated, can_reuse=used_aliases, force_having=force_having) + if connector == OR: + used = alias_diff(refcounts_before, self.alias_refcount) + for alias in used: + alias_usage_counts[alias] = alias_usage_counts.get(alias, 0) + 1 if force_having: self.having.end_subtree() else: self.where.end_subtree() - if connector == OR: - # Aliases that were newly 
added or not used at all need to - # be promoted to outer joins if they are nullable relations. - # (they shouldn't turn the whole conditional into the empty - # set just because they don't match anything). - self.promote_unused_aliases(refcounts_before, used_aliases) - connector = q_object.connector + if connector == OR: + self.promote_disjunction(aliases_before, alias_usage_counts, + len(q_object.children)) if q_object.negated: self.where.negate() if subtree: @@ -2005,3 +1991,11 @@ def is_reverse_o2o(field): expected to be some sort of relation field or related object. """ return not hasattr(field, 'rel') and field.field.unique + +def alias_diff(refcounts_before, refcounts_after): + """ + Given the before and after copies of refcounts, works out which aliases + gained references (were added or re-used) in the after copy. + """ + return set(t for t in refcounts_after + if refcounts_after[t] > refcounts_before.get(t, 0)) diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index 73b9762150..16583e891c 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -421,3 +421,21 @@ class Responsibility(models.Model): def __str__(self): return self.description + +# Models for disjunction join promotion low level testing.
+class FK1(models.Model): + f1 = models.TextField() + f2 = models.TextField() + +class FK2(models.Model): + f1 = models.TextField() + f2 = models.TextField() + +class FK3(models.Model): + f1 = models.TextField() + f2 = models.TextField() + +class BaseA(models.Model): + a = models.ForeignKey(FK1, null=True) + b = models.ForeignKey(FK2, null=True) + c = models.ForeignKey(FK3, null=True) diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py index 9270955877..e3e515025c 100644 --- a/tests/regressiontests/queries/tests.py +++ b/tests/regressiontests/queries/tests.py @@ -8,8 +8,8 @@ import sys from django.conf import settings from django.core.exceptions import FieldError from django.db import DatabaseError, connection, connections, DEFAULT_DB_ALIAS -from django.db.models import Count -from django.db.models.query import Q, ITER_CHUNK_SIZE, EmptyQuerySet +from django.db.models import Count, F, Q +from django.db.models.query import ITER_CHUNK_SIZE, EmptyQuerySet from django.db.models.sql.where import WhereNode, EverythingNode, NothingNode from django.db.models.sql.datastructures import EmptyResultSet from django.test import TestCase, skipUnlessDBFeature @@ -24,7 +24,7 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover, Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job, - JobResponsibilities) + JobResponsibilities, BaseA) class BaseQuerysetTest(TestCase): @@ -2451,3 +2451,127 @@ class JoinReuseTest(TestCase): def test_revfk_noreuse(self): qs = Author.objects.filter(report__name='r4').filter(report__name='r1') self.assertEqual(str(qs.query).count('JOIN'), 2) + +class DisjunctionPromotionTests(TestCase): + def test_disjunction_promotion1(self): + # Pre-existing join, add two ORed filters to the same join, + # all joins can be INNER JOINS. 
+ qs = BaseA.objects.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = qs.filter(Q(b__f1='foo') | Q(b__f2='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + # Reverse the order of AND and OR filters. + qs = BaseA.objects.filter(Q(b__f1='foo') | Q(b__f2='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = qs.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + + def test_disjunction_promotion2(self): + qs = BaseA.objects.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + # Now we have two different joins in an ORed condition, these + # must be OUTER joins. The pre-existing join should remain INNER. + qs = qs.filter(Q(b__f1='foo') | Q(c__f2='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + # Reverse case. + qs = BaseA.objects.filter(Q(b__f1='foo') | Q(c__f2='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + qs = qs.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + + def test_disjunction_promotion3(self): + qs = BaseA.objects.filter(a__f2='bar') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + # The ANDed a__f2 filter allows us to keep using INNER JOIN + # even inside the ORed case. If the join to a__ returns nothing, + # the ANDed filter for a__f2 can't be true. + qs = qs.filter(Q(a__f1='foo') | Q(b__f2='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + + @unittest.expectedFailure + def test_disjunction_promotion3_failing(self): + # Now the ORed filter creates LOUTER join, but we do not have + # logic to unpromote it for the AND filter after it. The query + # results will be correct, but we have one LOUTER JOIN too many + # currently.
+ qs = BaseA.objects.filter( + Q(a__f1='foo') | Q(b__f2='foo')).filter(a__f2='bar') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + + def test_disjunction_promotion4(self): + qs = BaseA.objects.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('JOIN'), 0) + qs = qs.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = BaseA.objects.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = qs.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + + def test_disjunction_promotion5(self): + qs = BaseA.objects.filter(Q(a=1) | Q(a=2)) + # Note that the above filters on a force the join to an + # inner join even if it is trimmed. + self.assertEqual(str(qs.query).count('JOIN'), 0) + qs = qs.filter(Q(a__f1='foo') | Q(b__f1='foo')) + # So, now the a__f1 join doesn't need promotion. + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + + @unittest.expectedFailure + def test_disjunction_promotion5_failing(self): + qs = BaseA.objects.filter(Q(a__f1='foo') | Q(b__f1='foo')) + # Now the joins to both a and b are created as LOUTER + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + # The filter below should force a to be inner joined. But, + # this is failing as we do not have join unpromotion logic.
+ qs = qs.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + + def test_disjunction_promotion6(self): + qs = BaseA.objects.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('JOIN'), 0) + qs = BaseA.objects.filter(Q(a__f1='foo') & Q(b__f1='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 0) + + qs = BaseA.objects.filter(Q(a__f1='foo') & Q(b__f1='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + qs = qs.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 0) + + def test_disjunction_promotion7(self): + qs = BaseA.objects.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('JOIN'), 0) + qs = BaseA.objects.filter(Q(a__f1='foo') | (Q(b__f1='foo') & Q(a__f1='bar'))) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + qs = BaseA.objects.filter( + (Q(a__f1='foo') | Q(b__f1='foo')) & (Q(a__f1='bar') | Q(c__f1='foo')) + ) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 3) + self.assertEqual(str(qs.query).count('INNER JOIN'), 0) + qs = BaseA.objects.filter( + (Q(a__f1='foo') | (Q(a__f1='bar')) & (Q(b__f1='bar') | Q(c__f1='foo'))) + ) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + + def test_disjunction_promotion_fexpression(self): + qs = BaseA.objects.filter(Q(a__f1=F('b__f1')) | Q(b__f1='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = BaseA.objects.filter(Q(a__f1=F('c__f1')) | Q(b__f1='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 3) + qs = 
BaseA.objects.filter(Q(a__f1=F('b__f1')) | Q(a__f2=F('b__f2')) | Q(c__f1='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 3) + qs = BaseA.objects.filter(Q(a__f1=F('c__f1')) | (Q(pk=1) & Q(pk=2))) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + self.assertEqual(str(qs.query).count('INNER JOIN'), 0)