From 3afe409d0ea9d2e3e1dc70abd4bc1cc1a916daf9 Mon Sep 17 00:00:00 2001 From: Russell Keith-Magee Date: Mon, 22 Aug 2011 07:40:12 +0000 Subject: [PATCH] Fixed #14876 -- Ensure that join promotion works correctly when there are nullable related fields. Thanks to simonpercivall for the report, oinopion and Aleksandra Sendecka for the original patch, and to Malcolm for helping me wrestle the edge cases to the ground. git-svn-id: http://code.djangoproject.com/svn/django/trunk@16648 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/sql/query.py | 50 ++++++++++++++++++-------- tests/regressiontests/queries/tests.py | 34 +++++++++++++++--- 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 110e3179d5..d906cb1132 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -445,8 +445,6 @@ class Query(object): "Cannot combine a unique query with a non-unique query." self.remove_inherited_models() - l_tables = set([a for a in self.tables if self.alias_refcount[a]]) - r_tables = set([a for a in rhs.tables if rhs.alias_refcount[a]]) # Work out how to relabel the rhs aliases, if necessary. change_map = {} used = set() @@ -471,16 +469,27 @@ class Query(object): # all joins exclusive to either the lhs or the rhs must be converted # to an outer join. if not conjunction: + l_tables = set(self.tables) + r_tables = set(rhs.tables) # Update r_tables aliases. for alias in change_map: if alias in r_tables: - r_tables.remove(alias) - r_tables.add(change_map[alias]) + # r_tables may contain entries that have a refcount of 0 + # if the query has references to a table that can be + # trimmed because only the foreign key is used. + # We only need to fix the aliases for the tables that + # actually have aliases. + if rhs.alias_refcount[alias]: + r_tables.remove(alias) + r_tables.add(change_map[alias]) # Find aliases that are exclusive to rhs or lhs. # These are promoted to outer joins. - outer_aliases = (l_tables | r_tables) - (l_tables & r_tables) - for alias in outer_aliases: - self.promote_alias(alias, True) + outer_tables = (l_tables | r_tables) - (l_tables & r_tables) + for alias in outer_tables: + # Again, some of the tables won't have aliases due to + # the trimming of unnecessary tables. + if self.alias_refcount.get(alias) or rhs.alias_refcount.get(alias): + self.promote_alias(alias, True) # Now relabel a copy of the rhs where-clause and add it to the current # one. @@ -668,7 +677,7 @@ class Query(object): False, the join is only promoted if it is nullable, otherwise it is always promoted. - Returns True if the join was promoted. + Returns True if the join was promoted by this call. """ if ((unconditional or self.alias_map[alias][NULLABLE]) and self.alias_map[alias][JOIN_TYPE] != self.LOUTER): @@ -1076,17 +1085,20 @@ class Query(object): can_reuse) return + table_promote = False + join_promote = False + if (lookup_type == 'isnull' and value is True and not negate and len(join_list) > 1): # If the comparison is against NULL, we may need to use some left # outer joins when creating the join chain. This is only done when # needed, as it's less efficient at the database level. self.promote_alias_chain(join_list) + join_promote = True # Process the join list to see if we can remove any inner joins from # the far end (fewer tables in a query is better). col, alias, join_list = self.trim_joins(target, join_list, last, trim) - if connector == OR: # Some joins may need to be promoted when adding a new filter to a # disjunction. We walk the list of new joins and where it diverges @@ -1096,19 +1108,29 @@ class Query(object): join_it = iter(join_list) table_it = iter(self.tables) join_it.next(), table_it.next() - table_promote = False - join_promote = False + unconditional = False for join in join_it: table = table_it.next() + # Once we hit an outer join, all subsequent joins must + # also be promoted, regardless of whether they have been + # promoted as a result of this pass through the tables. + unconditional = (unconditional or + self.alias_map[join][JOIN_TYPE] == self.LOUTER) if join == table and self.alias_refcount[join] > 1: + # We have more than one reference to this join table. + # This means that we are dealing with two different query + # subtrees, so we don't need to do any join promotion. continue - join_promote = self.promote_alias(join) + join_promote = join_promote or self.promote_alias(join, unconditional) if table != join: table_promote = self.promote_alias(table) + # We only get here if we have found a table that exists + # in the join list, but isn't on the original tables list. + # This means we've reached the point where we only have + # new tables, so we can break out of this promotion loop. break self.promote_alias_chain(join_it, join_promote) - self.promote_alias_chain(table_it, table_promote) - + self.promote_alias_chain(table_it, table_promote or join_promote) if having_clause or force_having: if (alias, col) not in self.group_by: diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py index 231a336649..4181ea5170 100644 --- a/tests/regressiontests/queries/tests.py +++ b/tests/regressiontests/queries/tests.py @@ -959,12 +959,36 @@ class Queries4Tests(BaseQuerysetTest): e1 = ExtraInfo.objects.create(info='e1', note=n1) e2 = ExtraInfo.objects.create(info='e2', note=n2) - a1 = Author.objects.create(name='a1', num=1001, extra=e1) - a3 = Author.objects.create(name='a3', num=3003, extra=e2) + self.a1 = Author.objects.create(name='a1', num=1001, extra=e1) + self.a3 = Author.objects.create(name='a3', num=3003, extra=e2) - Report.objects.create(name='r1', creator=a1) - Report.objects.create(name='r2', creator=a3) - Report.objects.create(name='r3') + self.r1 = Report.objects.create(name='r1', creator=self.a1) + self.r2 = Report.objects.create(name='r2', creator=self.a3) + self.r3 = Report.objects.create(name='r3') + + Item.objects.create(name='i1', created=datetime.datetime.now(), note=n1, creator=self.a1) + Item.objects.create(name='i2', created=datetime.datetime.now(), note=n1, creator=self.a3) + + def test_ticket14876(self): + q1 = Report.objects.filter(Q(creator__isnull=True) | Q(creator__extra__info='e1')) + q2 = Report.objects.filter(Q(creator__isnull=True)) | Report.objects.filter(Q(creator__extra__info='e1')) + self.assertQuerysetEqual(q1, ["", ""]) + self.assertEqual(str(q1.query), str(q2.query)) + + q1 = Report.objects.filter(Q(creator__extra__info='e1') | Q(creator__isnull=True)) + q2 = Report.objects.filter(Q(creator__extra__info='e1')) | Report.objects.filter(Q(creator__isnull=True)) + self.assertQuerysetEqual(q1, ["", ""]) + self.assertEqual(str(q1.query), str(q2.query)) + + q1 = Item.objects.filter(Q(creator=self.a1) | Q(creator__report__name='r1')).order_by() + q2 = Item.objects.filter(Q(creator=self.a1)).order_by() | Item.objects.filter(Q(creator__report__name='r1')).order_by() + self.assertQuerysetEqual(q1, [""]) + self.assertEqual(str(q1.query), str(q2.query)) + + q1 = Item.objects.filter(Q(creator__report__name='e1') | Q(creator=self.a1)).order_by() + q2 = Item.objects.filter(Q(creator__report__name='e1')).order_by() | Item.objects.filter(Q(creator=self.a1)).order_by() + self.assertQuerysetEqual(q1, [""]) + self.assertEqual(str(q1.query), str(q2.query)) def test_ticket7095(self): # Updates that are filtered on the model being updated are somewhat