Fixed #19837 -- Refactored split_exclude() join generation

The refactoring mainly concentrates on making sure the inner and outer
query agree about the split position. The split position is where the
multijoin happens, and thus the split position also determines the
columns used in the "WHERE col1 IN (SELECT col2 from ...)" condition.

This commit fixes a regression caused by #10790 and commit
69597e5bcc. The regression was caused
by wrong cols in the split position.
This commit is contained in:
Anssi Kääriäinen 2013-02-18 01:56:24 +02:00
parent ffcfb19f47
commit b4492a8ca4
5 changed files with 103 additions and 74 deletions

View File

@ -12,8 +12,10 @@ class MultiJoin(Exception):
multi-valued join was attempted (if the caller wants to treat that multi-valued join was attempted (if the caller wants to treat that
exceptionally). exceptionally).
""" """
def __init__(self, level): def __init__(self, names_pos, path_with_names):
self.level = level self.level = names_pos
# The path travelled, this includes the path to the multijoin.
self.names_with_path = path_with_names
class Empty(object): class Empty(object):
pass pass

View File

@ -1200,7 +1200,7 @@ class Query(object):
can_reuse.update(join_list) can_reuse.update(join_list)
except MultiJoin as e: except MultiJoin as e:
self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]),
can_reuse) can_reuse, e.names_with_path)
return return
if (lookup_type == 'isnull' and value is True and not negate and if (lookup_type == 'isnull' and value is True and not negate and
@ -1324,7 +1324,7 @@ class Query(object):
(the last used join field), and target (which is a field guaranteed to (the last used join field), and target (which is a field guaranteed to
contain the same value as the final field). contain the same value as the final field).
""" """
path = [] path, names_with_path = [], []
for pos, name in enumerate(names): for pos, name in enumerate(names):
if name == 'pk': if name == 'pk':
name = opts.pk.name name = opts.pk.name
@ -1361,16 +1361,17 @@ class Query(object):
opts, final_field, False, True)) opts, final_field, False, True))
if hasattr(field, 'get_path_info'): if hasattr(field, 'get_path_info'):
pathinfos, opts, target, final_field = field.get_path_info() pathinfos, opts, target, final_field = field.get_path_info()
if not allow_many:
for inner_pos, p in enumerate(pathinfos):
if p.m2m:
names_with_path.append((name, pathinfos[0:inner_pos + 1]))
raise MultiJoin(pos + 1, names_with_path)
path.extend(pathinfos) path.extend(pathinfos)
names_with_path.append((name, pathinfos))
else: else:
# Local non-relational field. # Local non-relational field.
final_field = target = field final_field = target = field
break break
multijoin_pos = None
for m2mpos, pathinfo in enumerate(path):
if pathinfo.m2m:
multijoin_pos = m2mpos
break
if pos != len(names) - 1: if pos != len(names) - 1:
if pos == len(names) - 2: if pos == len(names) - 2:
@ -1379,8 +1380,6 @@ class Query(object):
"the lookup type?" % (name, names[pos + 1])) "the lookup type?" % (name, names[pos + 1]))
else: else:
raise FieldError("Join on field %r not permitted." % name) raise FieldError("Join on field %r not permitted." % name)
if multijoin_pos is not None and len(path) >= multijoin_pos and not allow_many:
raise MultiJoin(multijoin_pos + 1)
return path, final_field, target return path, final_field, target
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
@ -1454,7 +1453,7 @@ class Query(object):
break break
return target.column, joins[-1], joins return target.column, joins[-1], joins
def split_exclude(self, filter_expr, prefix, can_reuse): def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
""" """
When doing an exclude against any kind of N-to-many relation, we need When doing an exclude against any kind of N-to-many relation, we need
to use a subquery. This method constructs the nested query, given the to use a subquery. This method constructs the nested query, given the
@ -1462,11 +1461,10 @@ class Query(object):
N-to-many relation field. N-to-many relation field.
As an example we could have original filter ~Q(child__name='foo'). As an example we could have original filter ~Q(child__name='foo').
We would get here with filter_expr = child_name, prefix = child and We would get here with filter_expr = child__name, prefix = child and
can_reuse is a set of joins we can reuse for filtering in the original can_reuse is a set of joins usable for filters in the original query.
query.
We will turn this into We will turn this into equivalent of:
WHERE pk NOT IN (SELECT parent_id FROM thetable WHERE pk NOT IN (SELECT parent_id FROM thetable
WHERE name = 'foo' AND parent_id IS NOT NULL) WHERE name = 'foo' AND parent_id IS NOT NULL)
@ -1474,44 +1472,48 @@ class Query(object):
saner null handling, and is easier for the backend's optimizer to saner null handling, and is easier for the backend's optimizer to
handle. handle.
""" """
# Generate the inner query.
query = Query(self.model) query = Query(self.model)
query.add_filter(filter_expr) query.add_filter(filter_expr)
query.bump_prefix() query.bump_prefix()
query.clear_ordering(True) query.clear_ordering(True)
query.set_start(prefix) # Try to have as simple as possible subquery -> trim leading joins from
# Adding extra check to make sure the selected field will not be null # the subquery.
trimmed_joins = query.trim_start(names_with_path)
# Add extra check to make sure the selected field will not be null
# since we are adding a IN <subquery> clause. This prevents the # since we are adding a IN <subquery> clause. This prevents the
# database from tripping over IN (...,NULL,...) selects and returning # database from tripping over IN (...,NULL,...) selects and returning
# nothing # nothing
alias, col = query.select[0].col alias, col = query.select[0].col
query.where.add((Constraint(alias, col, None), 'isnull', False), AND) query.where.add((Constraint(alias, col, None), 'isnull', False), AND)
# We need to trim the last part from the prefix.
trimmed_prefix = LOOKUP_SEP.join(prefix.split(LOOKUP_SEP)[0:-1]) # Still make sure that the trimmed parts in the inner query and
if not trimmed_prefix: # trimmed prefix are in sync. So, use the trimmed_joins to make sure
rel, _, direct, m2m = self.model._meta.get_field_by_name(prefix) # as many path elements are in the prefix as there were trimmed joins.
if not m2m: # In addition, convert the path elements back to names so that
trimmed_prefix = rel.field.rel.field_name # add_filter() can handle them.
trimmed_prefix = []
paths_in_prefix = trimmed_joins
for name, path in names_with_path:
if paths_in_prefix - len(path) > 0:
trimmed_prefix.append(name)
paths_in_prefix -= len(path)
else: else:
if direct: trimmed_prefix.append(
trimmed_prefix = rel.m2m_target_field_name() path[paths_in_prefix - len(path)].from_field.name)
else: break
trimmed_prefix = rel.field.m2m_reverse_target_field_name() trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
self.add_filter(('%s__in' % trimmed_prefix, query), negate=True, self.add_filter(('%s__in' % trimmed_prefix, query), negate=True,
can_reuse=can_reuse) can_reuse=can_reuse)
# If there's more than one join in the inner query (before any initial # If there's more than one join in the inner query, we need to also
# bits were trimmed -- which means the last active table is more than # handle the possibility that the earlier joins don't match anything
# two places into the alias list), we need to also handle the # by adding a comparison to NULL (e.g. in
# possibility that the earlier joins don't match anything by adding a # Tag.objects.exclude(parent__parent__name='t1')
# comparison to NULL (e.g. in # a tag with no parent would otherwise be overlooked).
# Tag.objects.exclude(parent__parent__name='t1'), a tag with no parent if trimmed_joins > 1:
# would otherwise be overlooked).
active_positions = len([count for count
in query.alias_refcount.items() if count])
if active_positions > 1:
self.add_filter(('%s__isnull' % trimmed_prefix, False), negate=True, self.add_filter(('%s__isnull' % trimmed_prefix, False), negate=True,
can_reuse=can_reuse) can_reuse=can_reuse)
def set_empty(self): def set_empty(self):
self.where = EmptyWhere() self.where = EmptyWhere()
@ -1869,42 +1871,33 @@ class Query(object):
return self.extra return self.extra
extra_select = property(_extra_select) extra_select = property(_extra_select)
def set_start(self, start): def trim_start(self, names_with_path):
""" """
Sets the table from which to start joining. The start position is Trims joins from the start of the join path. The candidates for trim
specified by the related attribute from the base model. This will are the PathInfos in names_with_path structure. Outer joins are not
automatically set to the select column to be the column linked from the eligible for removal. Also sets the select column so the start
previous table. matches the join.
This method is primarily for internal use and the error checking isn't This method is mostly useful for generating the subquery joins & col
as friendly as add_filter(). Mostly useful for querying directly in "WHERE somecol IN (subquery)". This construct is needed by
against the join table of many-to-many relation in a subquery. split_exclude().
""" _"""
opts = self.model._meta join_pos = 0
alias = self.get_initial_alias() for _, paths in names_with_path:
field, col, opts, joins, extra = self.setup_joins( for path in paths:
start.split(LOOKUP_SEP), opts, alias) peek = self.tables[join_pos + 1]
select_col = self.alias_map[joins[1]].lhs_join_col if self.alias_map[peek].join_type == self.LOUTER:
select_alias = alias # Back up one level and break
select_alias = self.tables[join_pos]
# The call to setup_joins added an extra reference to everything in select_col = path.from_field.column
# joins. Reverse that. break
for alias in joins: select_alias = self.tables[join_pos + 1]
self.unref_alias(alias) select_col = path.to_field.column
self.unref_alias(self.tables[join_pos])
# We might be able to trim some joins from the front of this query, join_pos += 1
# providing that we only traverse "always equal" connections (i.e. rhs
# is *always* the same value as lhs).
for alias in joins[1:]:
join_info = self.alias_map[alias]
if (join_info.lhs_join_col != select_col
or join_info.join_type != self.INNER):
break
self.unref_alias(select_alias)
select_alias = join_info.rhs_alias
select_col = join_info.rhs_join_col
self.select = [SelectInfo((select_alias, select_col), None)] self.select = [SelectInfo((select_alias, select_col), None)]
self.remove_inherited_models() self.remove_inherited_models()
return join_pos
def is_nullable(self, field): def is_nullable(self, field):
""" """

View File

@ -439,3 +439,17 @@ class BaseA(models.Model):
a = models.ForeignKey(FK1, null=True) a = models.ForeignKey(FK1, null=True)
b = models.ForeignKey(FK2, null=True) b = models.ForeignKey(FK2, null=True)
c = models.ForeignKey(FK3, null=True) c = models.ForeignKey(FK3, null=True)
@python_2_unicode_compatible
class Identifier(models.Model):
name = models.CharField(max_length=100)
def __str__(self):
return self.name
class Program(models.Model):
identifier = models.OneToOneField(Identifier)
class Channel(models.Model):
programs = models.ManyToManyField(Program)
identifier = models.OneToOneField(Identifier)

View File

@ -24,7 +24,7 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover,
Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory, Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory,
SpecialCategory, OneToOneCategory, NullableName, ProxyCategory, SpecialCategory, OneToOneCategory, NullableName, ProxyCategory,
SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job, SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job,
JobResponsibilities, BaseA) JobResponsibilities, BaseA, Identifier, Program, Channel)
class BaseQuerysetTest(TestCase): class BaseQuerysetTest(TestCase):
@ -2612,3 +2612,22 @@ class DisjunctionPromotionTests(TestCase):
qs = BaseA.objects.filter(Q(a__f1=F('c__f1')) | (Q(pk=1) & Q(pk=2))) qs = BaseA.objects.filter(Q(a__f1=F('c__f1')) | (Q(pk=1) & Q(pk=2)))
self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2)
self.assertEqual(str(qs.query).count('INNER JOIN'), 0) self.assertEqual(str(qs.query).count('INNER JOIN'), 0)
class ManyToManyExcludeTest(TestCase):
def test_exclude_many_to_many(self):
Identifier.objects.create(name='extra')
program = Program.objects.create(identifier=Identifier.objects.create(name='program'))
channel = Channel.objects.create(identifier=Identifier.objects.create(name='channel'))
channel.programs.add(program)
# channel contains 'program1', so all Identifiers except that one
# should be returned
self.assertQuerysetEqual(
Identifier.objects.exclude(program__channel=channel).order_by('name'),
['<Identifier: channel>', '<Identifier: extra>']
)
self.assertQuerysetEqual(
Identifier.objects.exclude(program__channel=None).order_by('name'),
['<Identifier: program>']
)

1
tests/tmp.txt Normal file
View File

@ -0,0 +1 @@
SELECT "queries_tag"."id", "queries_tag"."name", "queries_tag"."parent_id", "queries_tag"."category_id" FROM "queries_tag" WHERE NOT (("queries_tag"."id" IN (SELECT U0."id" FROM "queries_tag" U0 LEFT OUTER JOIN "queries_tag" U1 ON (U0."id" = U1."parent_id") WHERE (U1."id" IS NULL AND U0."id" IS NOT NULL)) AND "queries_tag"."id" IS NOT NULL)) ORDER BY "queries_tag"."name" ASC