Fixed #10790 -- Refactored sql.Query.setup_joins()

This is a rather large refactoring. The "lookup traversal" code was
split out of setup_joins() into a new names_to_path() method, which
does the lookup traversal; the actual work of setup_joins() is now
calling names_to_path() and then adding the joins found to the query.
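
In short, the new setup_joins() reduces to the following (condensed
from the patch below; docstring and blank lines omitted):

    def setup_joins(self, names, opts, alias, can_reuse, allow_many=True,
                    allow_explicit_fk=False):
        joins = [alias]
        # First, generate the path for the names.
        path, final_field, target = self.names_to_path(
            names, opts, allow_many, allow_explicit_fk)
        # Then, add one join per PathInfo in the path. Joins aren't
        # trimmed here; trimming happens later in trim_joins().
        for from_field, to_field, from_opts, opts, join_field in path:
            direct = join_field == from_field
            nullable = self.is_nullable(from_field) if direct else True
            connection = alias, opts.db_table, from_field.column, to_field.column
            alias = self.join(connection, reuse=can_reuse, nullable=nullable,
                              join_field=join_field)
            joins.append(alias)
        return final_field, target, opts, joins, path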

As a side effect it was possible to remove the "process_extras"
functionality used by generic relations, which never worked for left
joins. The extra restriction is now appended directly to the join
condition instead of the where clause.
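
For a generic relation the extra condition comes from the new
GenericRelation.get_extra_join_sql() (shown in the patch below), and
the compiler appends it to the join's ON clause:

    def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
        # Restricts the joined rows to the right content type; the join
        # then renders roughly as "... JOIN tagged_item ON
        # (tagged_item.object_id = parent.id AND
        #  tagged_item.content_type_id = %s)" (illustrative names only).
        extra_col = self.rel.to._meta.get_field_by_name(
            self.content_type_field_name)[0].column
        contenttype = self.get_content_type().pk
        return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype]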

To generate the extra condition we need to have the join field
available in the compiler. This has the side effect that we need more
ugly code in Query.__getstate__() and __setstate__(), as Field objects
aren't pickleable.
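
Concretely, __getstate__() replaces each join_field in alias_map with
an (app_label, object_name, field_name) triple and __setstate__()
resolves it back to a Field (condensed from the patch below):

    # __getstate__(): Field instances aren't pickleable, store an id.
    model = join_info.join_field.model._meta
    field_id = (model.app_label, model.object_name, join_info.join_field.name)
    new_alias_map[alias] = join_info._replace(join_field=field_id)

    # __setstate__(): turn the stored id back into the Field instance.
    app_label, object_name, field_name = join_info.join_field
    field = get_model(app_label, object_name)._meta.get_field(field_name)
    new_alias_map[alias] = join_info._replace(join_field=field)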

The join trimming code got a big change: we now trim all direct joins
and never trim reverse joins. This also fixes the problem in #10790,
which was caused by join trimming in null filter cases.
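
The new tests below (Queries1Tests.test_ticket_10790_* and
ReverseJoinTrimmingTest) exercise this behaviour; for example:

    # Direct foreign key: the join is trimmed, even for isnull lookups.
    q = Tag.objects.filter(parent__isnull=True)
    assert 'JOIN' not in str(q.query)

    # Reverse foreign key: we can't know whether anything exists on the
    # other side of the join, so it is never trimmed.
    t = Tag.objects.create()
    qs = Tag.objects.filter(annotation__tag=t.pk)
    assert 'INNER JOIN' in str(qs.query)
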
Anssi Kääriäinen 2012-08-25 16:33:07 +03:00
parent f811649710
commit 69597e5bcc
8 changed files with 562 additions and 266 deletions

View File

@ -205,17 +205,16 @@ class GenericRelation(RelatedField, Field):
# same db_type as well.
return None
def extra_filters(self, pieces, pos, negate):
def get_content_type(self):
"""
Return an extra filter to the queryset so that the results are filtered
on the appropriate content type.
Returns the content type associated with this field's model.
"""
if negate:
return []
content_type = ContentType.objects.get_for_model(self.model)
prefix = "__".join(pieces[:pos + 1])
return [("%s__%s" % (prefix, self.content_type_field_name),
content_type)]
return ContentType.objects.get_for_model(self.model)
def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias):
extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column
contenttype = self.get_content_type().pk
return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype]
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
"""
@ -246,9 +245,6 @@ class ReverseGenericRelatedObjectsDescriptor(object):
if instance is None:
return self
# This import is done here to avoid circular import importing this module
from django.contrib.contenttypes.models import ContentType
# Dynamically create a class that subclasses the related model's
# default manager.
rel_model = self.field.rel.to
@ -379,8 +375,6 @@ class BaseGenericInlineFormSet(BaseModelFormSet):
def __init__(self, data=None, files=None, instance=None, save_as_new=None,
prefix=None, queryset=None):
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
opts = self.model._meta
self.instance = instance
self.rel_name = '-'.join((
@ -409,8 +403,6 @@ class BaseGenericInlineFormSet(BaseModelFormSet):
))
def save_new(self, form, commit=True):
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
kwargs = {
self.ct_field.get_attname(): ContentType.objects.get_for_model(self.instance).pk,
self.ct_fk_field.get_attname(): self.instance.pk,
@ -432,8 +424,6 @@ def generic_inlineformset_factory(model, form=ModelForm,
defaults ``content_type`` and ``object_id`` respectively.
"""
opts = model._meta
# Avoid a circular import.
from django.contrib.contenttypes.models import ContentType
# if there is no field called `ct_field` let the exception propagate
ct_field = opts.get_field(ct_field)
if not isinstance(ct_field, models.ForeignKey) or ct_field.rel.to != ContentType:

View File

@ -274,7 +274,8 @@ class SQLCompiler(object):
except KeyError:
link_field = opts.get_ancestor_link(model)
alias = self.query.join((start_alias, model._meta.db_table,
link_field.column, model._meta.pk.column))
link_field.column, model._meta.pk.column),
join_field=link_field)
seen[model] = alias
else:
# If we're starting from the base model of the queryset, the
@ -448,8 +449,8 @@ class SQLCompiler(object):
"""
if not alias:
alias = self.query.get_initial_alias()
field, target, opts, joins, _, _ = self.query.setup_joins(pieces,
opts, alias, REUSE_ALL)
field, target, opts, joins, _ = self.query.setup_joins(
pieces, opts, alias, REUSE_ALL)
# We will later on need to promote those joins that were added to the
# query afresh above.
joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2]
@ -501,20 +502,27 @@ class SQLCompiler(object):
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
first = True
from_params = []
for alias in self.query.tables:
if not self.query.alias_refcount[alias]:
continue
try:
name, alias, join_type, lhs, lhs_col, col, nullable = self.query.alias_map[alias]
name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias]
except KeyError:
# Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them.
continue
alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first:
result.append('%s %s%s ON (%s.%s = %s.%s)'
% (join_type, qn(name), alias_str, qn(lhs),
qn2(lhs_col), qn(alias), qn2(col)))
if join_field and hasattr(join_field, 'get_extra_join_sql'):
extra_cond, extra_params = join_field.get_extra_join_sql(
self.connection, qn, lhs, alias)
from_params.extend(extra_params)
else:
extra_cond = ""
result.append('%s %s%s ON (%s.%s = %s.%s%s)' %
(join_type, qn(name), alias_str, qn(lhs),
qn2(lhs_col), qn(alias), qn2(col), extra_cond))
else:
connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str))
@ -528,7 +536,7 @@ class SQLCompiler(object):
connector = not first and ', ' or ''
result.append('%s%s' % (connector, qn(alias)))
first = False
return result, []
return result, from_params
def get_grouping(self, ordering_group_by):
"""
@ -638,7 +646,7 @@ class SQLCompiler(object):
alias = self.query.join((alias, table, f.column,
f.rel.get_related_field().column),
promote=promote)
promote=promote, join_field=f)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend(
@ -685,7 +693,7 @@ class SQLCompiler(object):
alias_chain.append(alias)
alias = self.query.join(
(alias, table, f.rel.get_related_field().column, f.column),
promote=True
promote=True, join_field=f
)
from_parent = (opts.model if issubclass(model, opts.model)
else None)

View File

@ -18,12 +18,19 @@ QUERY_TERMS = set([
# Larger values are slightly faster at the expense of more storage space.
GET_ITERATOR_CHUNK_SIZE = 100
# Constants to make looking up tuple values clearer.
# Namedtuples for sql.* internal use.
# Join lists (indexes into the tuples that are values in the alias_map
# dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo',
'table_name rhs_alias join_type lhs_alias '
'lhs_join_col rhs_join_col nullable')
'lhs_join_col rhs_join_col nullable join_field')
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the join in Model terms (model Options and Fields for both
# sides of the join). The join_field is the field we are joining along.
PathInfo = namedtuple('PathInfo',
'from_field to_field from_opts to_opts join_field')
# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')

View File

@ -50,10 +50,10 @@ class SQLEvaluator(object):
self.cols.append((node, query.aggregate_select[node.name]))
else:
try:
field, source, opts, join_list, last, _ = query.setup_joins(
field, source, opts, join_list, path = query.setup_joins(
field_list, query.get_meta(),
query.get_initial_alias(), self.reuse)
col, _, join_list = query.trim_joins(source, join_list, last, False)
col, _, join_list = query.trim_joins(source, join_list, path)
if self.reuse is not None and self.reuse != REUSE_ALL:
self.reuse.update(join_list)
self.cols.append((node, (join_list[-1], col)))

View File

@ -14,13 +14,13 @@ from django.utils.encoding import force_text
from django.utils.tree import Node
from django.utils import six
from django.db import connections, DEFAULT_DB_ALIAS
from django.db.models import signals
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist
from django.db.models.loading import get_model
from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo)
ORDER_PATTERN, REUSE_ALL, JoinInfo, SelectInfo, PathInfo)
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@ -119,7 +119,7 @@ class Query(object):
self.filter_is_sticky = False
self.included_inherited_models = {}
# SQL-related attributes
# SQL-related attributes
# Select and related select clauses as SelectInfo instances.
# The select is used for cases where we want to set up the select
# clause to contain other than default fields (values(), annotate(),
@ -201,6 +201,16 @@ class Query(object):
(s.col, s.field is not None and s.field.name or None)
for s in obj_dict['select']
]
# alias_map can also contain references to fields.
new_alias_map = {}
for alias, join_info in obj_dict['alias_map'].items():
if join_info.join_field is None:
new_alias_map[alias] = join_info
else:
model = join_info.join_field.model._meta
field_id = (model.app_label, model.object_name, join_info.join_field.name)
new_alias_map[alias] = join_info._replace(join_field=field_id)
obj_dict['alias_map'] = new_alias_map
return obj_dict
def __setstate__(self, obj_dict):
@ -213,6 +223,15 @@ class Query(object):
SelectInfo(tpl[0], tpl[1] is not None and opts.get_field(tpl[1]) or None)
for tpl in obj_dict['select']
]
new_alias_map = {}
for alias, join_info in obj_dict['alias_map'].items():
if join_info.join_field is None:
new_alias_map[alias] = join_info
else:
field_id = join_info.join_field
new_alias_map[alias] = join_info._replace(
join_field=get_model(field_id[0], field_id[1])._meta.get_field(field_id[2]))
obj_dict['alias_map'] = new_alias_map
self.__dict__.update(obj_dict)
@ -479,21 +498,26 @@ class Query(object):
# Now, add the joins from rhs query into the new query (skipping base
# table).
for alias in rhs.tables[1:]:
if not rhs.alias_refcount[alias]:
continue
table, _, join_type, lhs, lhs_col, col, nullable = rhs.alias_map[alias]
table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias]
promote = (join_type == self.LOUTER)
# If the left side of the join was already relabeled, use the
# updated alias.
lhs = change_map.get(lhs, lhs)
new_alias = self.join(
(lhs, table, lhs_col, col), reuse=reuse, promote=promote,
outer_if_first=not conjunction, nullable=nullable)
outer_if_first=not conjunction, nullable=nullable,
join_field=join_field)
# We can't reuse the same join again in the query. If we have two
# distinct joins for the same connection in rhs query, then the
# combined query must have two joins, too.
reuse.discard(new_alias)
change_map[alias] = new_alias
if not rhs.alias_refcount[alias]:
# The alias was unused in the rhs query. Unref it so that it
# will be unused in the new query, too. We have to add and
# unref the alias so that join promotion has information about
# the join type for the unused alias.
self.unref_alias(new_alias)
# So that we don't exclude valid results in an "or" query combination,
# all joins exclusive to either the lhs or the rhs must be converted
@ -868,7 +892,7 @@ class Query(object):
return len([1 for count in self.alias_refcount.values() if count])
def join(self, connection, reuse=REUSE_ALL, promote=False,
outer_if_first=False, nullable=False):
outer_if_first=False, nullable=False, join_field=None):
"""
Returns an alias for the join in 'connection', either reusing an
existing alias for that join or creating a new one. 'connection' is a
@ -897,6 +921,8 @@ class Query(object):
If 'nullable' is True, the join can potentially involve NULL values and
is a candidate for promotion (to "left outer") when combining querysets.
The 'join_field' is the field we are joining along (if any).
"""
lhs, table, lhs_col, col = connection
existing = self.join_map.get(connection, ())
@ -906,8 +932,13 @@ class Query(object):
reuse = set()
else:
reuse = [a for a in existing if a in reuse]
if reuse:
alias = reuse[0]
for alias in reuse:
if join_field and self.alias_map[alias].join_field != join_field:
# The join_map doesn't contain join_field (mainly because
# fields in Query structs are problematic in pickling), so
# check that the existing join was created using the same
# join_field as the join currently being constructed.
continue
self.ref_alias(alias)
if promote or (lhs and self.alias_map[lhs].join_type == self.LOUTER):
self.promote_joins([alias])
@ -926,7 +957,8 @@ class Query(object):
join_type = self.LOUTER
else:
join_type = self.INNER
join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable)
join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable,
join_field)
self.alias_map[alias] = join
if connection in self.join_map:
self.join_map[connection] += (alias,)
@ -1007,11 +1039,11 @@ class Query(object):
# - this is an annotation over a model field
# then we need to explore the joins that are required.
field, source, opts, join_list, last, _ = self.setup_joins(
field, source, opts, join_list, path = self.setup_joins(
field_list, opts, self.get_initial_alias(), REUSE_ALL)
# Process the join chain to see if it can be trimmed
col, _, join_list = self.trim_joins(source, join_list, last, False)
col, _, join_list = self.trim_joins(source, join_list, path)
# If the aggregate references a model or field that requires a join,
# those joins must be LEFT OUTER - empty join rows must be returned
@ -1030,7 +1062,7 @@ class Query(object):
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
def add_filter(self, filter_expr, connector=AND, negate=False,
can_reuse=None, process_extras=True, force_having=False):
can_reuse=None, force_having=False):
"""
Add a single filter to the query. The 'filter_expr' is a pair:
(filter_string, value). E.g. ('name__contains', 'fred')
@ -1047,10 +1079,6 @@ class Query(object):
will be a set of table aliases that can be reused in this filter, even
if we would otherwise force the creation of new aliases for a join
(needed for nested Q-filters). The set is updated by this method.
If 'process_extras' is set, any extra filters returned from the table
joining process will be processed. This parameter is set to False
during the processing of extra filters to avoid infinite recursion.
"""
arg, value = filter_expr
parts = arg.split(LOOKUP_SEP)
@ -1115,10 +1143,11 @@ class Query(object):
allow_many = not negate
try:
field, target, opts, join_list, last, extra_filters = self.setup_joins(
field, target, opts, join_list, path = self.setup_joins(
parts, opts, alias, can_reuse, allow_many,
allow_explicit_fk=True, negate=negate,
process_extras=process_extras)
allow_explicit_fk=True)
if can_reuse is not None:
can_reuse.update(join_list)
except MultiJoin as e:
self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]),
can_reuse)
@ -1136,10 +1165,10 @@ class Query(object):
join_promote = True
# Process the join list to see if we can remove any inner joins from
# the far end (fewer tables in a query is better).
nonnull_comparison = (lookup_type == 'isnull' and value is False)
col, alias, join_list = self.trim_joins(target, join_list, last,
nonnull_comparison)
# the far end (fewer tables in a query is better). Note that join
# promotion must happen before join trimming to have the join type
# information available when reusing joins.
col, alias, join_list = self.trim_joins(target, join_list, path)
if connector == OR:
# Some joins may need to be promoted when adding a new filter to a
@ -1212,12 +1241,6 @@ class Query(object):
# is added in upper layers of the code.
self.where.add((Constraint(alias, col, None), 'isnull', False), AND)
if can_reuse is not None:
can_reuse.update(join_list)
if process_extras:
for filter in extra_filters:
self.add_filter(filter, negate=negate, can_reuse=can_reuse,
process_extras=False)
def add_q(self, q_object, used_aliases=None, force_having=False):
"""
@ -1270,37 +1293,24 @@ class Query(object):
if self.filter_is_sticky:
self.used_aliases = used_aliases
def setup_joins(self, names, opts, alias, can_reuse, allow_many=True,
allow_explicit_fk=False, negate=False, process_extras=True):
def names_to_path(self, names, opts, allow_many=False,
allow_explicit_fk=True):
"""
Compute the necessary table joins for the passage through the fields
given in 'names'. 'opts' is the Options class for the current model
(which gives the table we are joining to), 'alias' is the alias for the
table we are joining to.
Walks the names path and turns them into PathInfo tuples. Note that a
single name in 'names' can generate multiple PathInfos (m2m for
example).
The 'can_reuse' defines the reverse foreign key joins we can reuse. It
can be either sql.constants.REUSE_ALL in which case all joins are
reusable or a set of aliases that can be reused. Non-reverse foreign
keys are always reusable.
'names' is the path of names to travel, 'opts' is the model Options we
start the name resolving from, 'allow_many' and 'allow_explicit_fk'
are as for setup_joins().
The 'allow_explicit_fk' controls if field.attname is allowed in the
lookups.
Finally, 'negate' is used in the same sense as for add_filter()
-- it indicates an exclude() filter, or something similar. It is only
passed in here so that it can be passed to a field's extra_filter() for
customized behavior.
Returns the final field involved in the join, the target database
column (used for any 'where' constraint), the final 'opts' value and the
list of tables joined.
Returns a list of PathInfo tuples. In addition returns the final field
(the last used join field), and target (which is a field guaranteed to
contain the same value as the final field).
"""
joins = [alias]
last = [0]
extra_filters = []
int_alias = None
path = []
multijoin_pos = None
for pos, name in enumerate(names):
last.append(len(joins))
if name == 'pk':
name = opts.pk.name
try:
@ -1314,14 +1324,12 @@ class Query(object):
field, model, direct, m2m = opts.get_field_by_name(f.name)
break
else:
names = opts.get_all_field_names() + list(self.aggregate_select)
available = opts.get_all_field_names() + list(self.aggregate_select)
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
if not allow_many and (m2m or not direct):
for alias in joins:
self.unref_alias(alias)
raise MultiJoin(pos + 1)
"Choices are: %s" % (name, ", ".join(available)))
# Check if we need any joins for concrete inheritance cases (the
# field lives in parent, but we are currently in one of its
# children)
if model:
# The field lives on a base class of the current model.
# Skip the chain of proxy to the concrete proxied model
@ -1331,172 +1339,179 @@ class Query(object):
if int_model is proxied_model:
opts = int_model._meta
else:
lhs_col = opts.parents[int_model].column
final_field = opts.parents[int_model]
target = final_field.rel.get_related_field()
opts = int_model._meta
alias = self.join((alias, opts.db_table, lhs_col,
opts.pk.column))
joins.append(alias)
cached_data = opts._join_cache.get(name)
orig_opts = opts
if process_extras and hasattr(field, 'extra_filters'):
extra_filters.extend(field.extra_filters(names, pos, negate))
if direct:
if m2m:
# Many-to-many field defined on the current model.
if cached_data:
(table1, from_col1, to_col1, table2, from_col2,
to_col2, opts, target) = cached_data
else:
table1 = field.m2m_db_table()
from_col1 = opts.get_field_by_name(
field.m2m_target_field_name())[0].column
to_col1 = field.m2m_column_name()
opts = field.rel.to._meta
table2 = opts.db_table
from_col2 = field.m2m_reverse_name()
to_col2 = opts.get_field_by_name(
field.m2m_reverse_target_field_name())[0].column
target = opts.pk
orig_opts._join_cache[name] = (table1, from_col1,
to_col1, table2, from_col2, to_col2, opts,
target)
int_alias = self.join((alias, table1, from_col1, to_col1),
reuse=can_reuse, nullable=True)
if int_alias == table2 and from_col2 == to_col2:
joins.append(int_alias)
alias = int_alias
else:
alias = self.join(
(int_alias, table2, from_col2, to_col2),
reuse=can_reuse, nullable=True)
joins.extend([int_alias, alias])
elif field.rel:
# One-to-one or many-to-one field
if cached_data:
(table, from_col, to_col, opts, target) = cached_data
else:
opts = field.rel.to._meta
target = field.rel.get_related_field()
table = opts.db_table
from_col = field.column
to_col = target.column
orig_opts._join_cache[name] = (table, from_col, to_col,
opts, target)
alias = self.join((alias, table, from_col, to_col),
nullable=self.is_nullable(field))
joins.append(alias)
path.append(PathInfo(final_field, target, final_field.model._meta,
opts, final_field))
# We have five different cases to solve: foreign keys, reverse
# foreign keys, m2m fields (also reverse) and non-relational
# fields. We are mostly just using the related field API to
# fetch the from and to fields. The m2m fields are handled as
# two foreign keys, first one reverse, the second one direct.
if direct and not field.rel and not m2m:
# Local non-relational field.
final_field = target = field
break
elif direct and not m2m:
# Foreign Key
opts = field.rel.to._meta
target = field.rel.get_related_field()
final_field = field
from_opts = field.model._meta
path.append(PathInfo(field, target, from_opts, opts, field))
elif not direct and not m2m:
# Reverse foreign key
final_field = to_field = field.field
opts = to_field.model._meta
from_field = to_field.rel.get_related_field()
from_opts = from_field.model._meta
path.append(
PathInfo(from_field, to_field, from_opts, opts, to_field))
if from_field.model is to_field.model:
# Recursive foreign key to self.
target = opts.get_field_by_name(
field.field.rel.field_name)[0]
else:
# Non-relation fields.
target = field
break
else:
orig_field = field
target = opts.pk
elif direct and m2m:
if not field.rel.through:
# Gotcha! This is just a fake m2m field - a generic relation
# field.
from_field = opts.pk
opts = field.rel.to._meta
target = opts.get_field_by_name(field.object_id_field_name)[0]
final_field = field
# Note that we are using a different field for the join_field
# than from_field or to_field. This is a hack, but we need the
# GenericRelation to generate the extra SQL.
path.append(PathInfo(from_field, target, field.model._meta, opts,
field))
else:
# m2m field. We are travelling first to the m2m table along a
# reverse relation, then from m2m table to the target table.
from_field1 = opts.get_field_by_name(
field.m2m_target_field_name())[0]
opts = field.rel.through._meta
to_field1 = opts.get_field_by_name(field.m2m_field_name())[0]
path.append(
PathInfo(from_field1, to_field1, from_field1.model._meta,
opts, to_field1))
final_field = from_field2 = opts.get_field_by_name(
field.m2m_reverse_field_name())[0]
opts = field.rel.to._meta
target = to_field2 = opts.get_field_by_name(
field.m2m_reverse_target_field_name())[0]
path.append(
PathInfo(from_field2, to_field2, from_field2.model._meta,
opts, from_field2))
elif not direct and m2m:
# This one is just like above, except we are travelling the
# fields in opposite direction.
field = field.field
if m2m:
# Many-to-many field defined on the target model.
if cached_data:
(table1, from_col1, to_col1, table2, from_col2,
to_col2, opts, target) = cached_data
else:
table1 = field.m2m_db_table()
from_col1 = opts.get_field_by_name(
field.m2m_reverse_target_field_name())[0].column
to_col1 = field.m2m_reverse_name()
opts = orig_field.opts
table2 = opts.db_table
from_col2 = field.m2m_column_name()
to_col2 = opts.get_field_by_name(
field.m2m_target_field_name())[0].column
target = opts.pk
orig_opts._join_cache[name] = (table1, from_col1,
to_col1, table2, from_col2, to_col2, opts,
target)
from_field1 = opts.get_field_by_name(
field.m2m_reverse_target_field_name())[0]
int_opts = field.rel.through._meta
to_field1 = int_opts.get_field_by_name(
field.m2m_reverse_field_name())[0]
path.append(
PathInfo(from_field1, to_field1, from_field1.model._meta,
int_opts, to_field1))
final_field = from_field2 = int_opts.get_field_by_name(
field.m2m_field_name())[0]
opts = field.opts
target = to_field2 = opts.get_field_by_name(
field.m2m_target_field_name())[0]
path.append(PathInfo(from_field2, to_field2, from_field2.model._meta,
opts, from_field2))
int_alias = self.join((alias, table1, from_col1, to_col1),
reuse=can_reuse, nullable=True)
alias = self.join((int_alias, table2, from_col2, to_col2),
reuse=can_reuse, nullable=True)
joins.extend([int_alias, alias])
else:
# One-to-many field (ForeignKey defined on the target model)
if cached_data:
(table, from_col, to_col, opts, target) = cached_data
else:
local_field = opts.get_field_by_name(
field.rel.field_name)[0]
opts = orig_field.opts
table = opts.db_table
from_col = local_field.column
to_col = field.column
# In case of a recursive FK, use the to_field for
# reverse lookups as well
if orig_field.model is local_field.model:
target = opts.get_field_by_name(
field.rel.field_name)[0]
else:
target = opts.pk
orig_opts._join_cache[name] = (table, from_col, to_col,
opts, target)
alias = self.join((alias, table, from_col, to_col),
reuse=can_reuse, nullable=True)
joins.append(alias)
if m2m and multijoin_pos is None:
multijoin_pos = pos
if not direct and not path[-1].to_field.unique and multijoin_pos is None:
multijoin_pos = pos
if pos != len(names) - 1:
if pos == len(names) - 2:
raise FieldError("Join on field %r not permitted. Did you misspell %r for the lookup type?" % (name, names[pos + 1]))
raise FieldError(
"Join on field %r not permitted. Did you misspell %r for "
"the lookup type?" % (name, names[pos + 1]))
else:
raise FieldError("Join on field %r not permitted." % name)
if multijoin_pos is not None and len(path) >= multijoin_pos and not allow_many:
raise MultiJoin(multijoin_pos + 1)
return path, final_field, target
return field, target, opts, joins, last, extra_filters
def trim_joins(self, target, join_list, last, nonnull_check=False):
def setup_joins(self, names, opts, alias, can_reuse, allow_many=True,
allow_explicit_fk=False):
"""
Sometimes joins at the end of a multi-table sequence can be trimmed. If
the final join is against the same column as we are comparing against,
and is an inner join, we can go back one step in a join chain and
compare against the LHS of the join instead (and then repeat the
optimization). The result, potentially, involves fewer table joins.
Compute the necessary table joins for the passage through the fields
given in 'names'. 'opts' is the Options class for the current model
(which gives the table we are starting from), 'alias' is the alias for
the table to start the joining from.
The 'target' parameter is the final field being joined to, 'join_list'
is the full list of join aliases.
The 'can_reuse' defines the reverse foreign key joins we can reuse. It
can be sql.constants.REUSE_ALL in which case all joins are reusable
or a set of aliases that can be reused. Note that Non-reverse foreign
keys are always reusable.
The 'last' list contains offsets into 'join_list', corresponding to
each component of the filter. Many-to-many relations, for example, add
two tables to the join list and we want to deal with both tables the
same way, so 'last' has an entry for the first of the two tables and
then the table immediately after the second table, in that case.
If 'allow_many' is False, then any reverse foreign key seen will
generate a MultiJoin exception.
The 'nonnull_check' parameter is True when we are using inner joins
between tables explicitly to exclude NULL entries. In that case, the
tables shouldn't be trimmed, because the very action of joining to them
alters the result set.
The 'allow_explicit_fk' controls if field.attname is allowed in the
lookups.
Returns the final field involved in the joins, the target field (used
for any 'where' constraint), the final 'opts' value, the joins and the
field path travelled to generate the joins.
The target field is the field containing the concrete value. Final
field can be something different, for example a foreign key pointing to
that value. The final field is needed, for example, in some value
conversions (converting 'obj' in fk__id=obj to a pk value using the
foreign key field).
"""
joins = [alias]
# First, generate the path for the names
path, final_field, target = self.names_to_path(
names, opts, allow_many, allow_explicit_fk)
# Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type
# of the trimmed joins.
for pos, join in enumerate(path):
from_field, to_field, from_opts, opts, join_field = join
direct = join_field == from_field
if direct:
nullable = self.is_nullable(from_field)
else:
nullable = True
connection = alias, opts.db_table, from_field.column, to_field.column
alias = self.join(connection, reuse=can_reuse, nullable=nullable,
join_field=join_field)
joins.append(alias)
return final_field, target, opts, joins, path
def trim_joins(self, target, joins, path):
"""
The 'target' parameter is the final field being joined to, 'joins'
is the full list of join aliases. The 'path' contains the PathInfos
used to create the joins.
Returns the final active column and table alias and the new active
join_list.
joins.
We will always trim any direct join if we have the target column
available already in the previous table. Reverse joins can't be
trimmed as we don't know if there is anything on the other side of
the join.
"""
final = len(join_list)
penultimate = last.pop()
if penultimate == final:
penultimate = last.pop()
col = target.column
alias = join_list[-1]
while final > 1:
join = self.alias_map[alias]
if (col != join.rhs_join_col or join.join_type != self.INNER or
nonnull_check):
for info in reversed(path):
direct = info.join_field == info.from_field
if info.to_field == target and direct:
target = info.from_field
self.unref_alias(joins.pop())
else:
break
self.unref_alias(alias)
alias = join.lhs_alias
col = join.lhs_join_col
join_list.pop()
final -= 1
if final == penultimate:
penultimate = last.pop()
return col, alias, join_list
return target.column, joins[-1], joins
def split_exclude(self, filter_expr, prefix, can_reuse):
"""
@ -1627,9 +1642,9 @@ class Query(object):
try:
for name in field_names:
field, target, u2, joins, u3, u4 = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, REUSE_ALL,
allow_m2m, True)
field, target, u2, joins, u3 = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, REUSE_ALL, allow_m2m,
True)
final_alias = joins[-1]
col = target.column
if len(joins) > 1:
@ -1918,7 +1933,7 @@ class Query(object):
"""
opts = self.model._meta
alias = self.get_initial_alias()
field, col, opts, joins, last, extra = self.setup_joins(
field, col, opts, joins, extra = self.setup_joins(
start.split(LOOKUP_SEP), opts, alias, REUSE_ALL)
select_col = self.alias_map[joins[1]].lhs_join_col
select_alias = alias
@ -1975,18 +1990,6 @@ def get_order_dir(field, default='ASC'):
return field, dirn[0]
def setup_join_cache(sender, **kwargs):
"""
The information needed to join between model fields is something that is
invariant over the life of the model, so we cache it in the model's Options
class, rather than recomputing it all the time.
This method initialises the (empty) cache when the model is created.
"""
sender._meta._join_cache = {}
signals.class_prepared.connect(setup_join_cache)
def add_to_dict(data, key, value):
"""
A helper function to add "value" to the set of values for "key", whether or

View File

@ -978,3 +978,7 @@ class AggregationTests(TestCase):
('The Definitive Guide to Django: Web Development Done Right', 2)
]
)
def test_reverse_join_trimming(self):
qs = Author.objects.annotate(Count('book_contact_set__contact'))
self.assertIn(' JOIN ', str(qs.query))

View File

@ -283,6 +283,7 @@ class SingleObject(models.Model):
class RelatedObject(models.Model):
single = models.ForeignKey(SingleObject, null=True)
f = models.IntegerField(null=True)
class Meta:
ordering = ['single']
@ -311,7 +312,7 @@ class Food(models.Model):
@python_2_unicode_compatible
class Eaten(models.Model):
food = models.ForeignKey(Food, to_field="name")
food = models.ForeignKey(Food, to_field="name", null=True)
meal = models.CharField(max_length=20)
def __str__(self):
@ -400,3 +401,23 @@ class ModelA(models.Model):
name = models.TextField()
b = models.ForeignKey(ModelB, null=True)
d = models.ForeignKey(ModelD)
@python_2_unicode_compatible
class Job(models.Model):
name = models.CharField(max_length=20, unique=True)
def __str__(self):
return self.name
class JobResponsibilities(models.Model):
job = models.ForeignKey(Job, to_field='name')
responsibility = models.ForeignKey('Responsibility', to_field='description')
@python_2_unicode_compatible
class Responsibility(models.Model):
description = models.CharField(max_length=20, unique=True)
jobs = models.ManyToManyField(Job, through=JobResponsibilities,
related_name='responsibilities')
def __str__(self):
return self.description

View File

@ -23,7 +23,8 @@ from .models import (Annotation, Article, Author, Celebrity, Child, Cover,
Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten,
Node, ObjectA, ObjectB, ObjectC, CategoryItem, SimpleCategory,
SpecialCategory, OneToOneCategory, NullableName, ProxyCategory,
SingleObject, RelatedObject, ModelA, ModelD)
SingleObject, RelatedObject, ModelA, ModelD, Responsibility, Job,
JobResponsibilities)
class BaseQuerysetTest(TestCase):
@ -243,7 +244,10 @@ class Queries1Tests(BaseQuerysetTest):
q1 = Item.objects.order_by('name')
q2 = Item.objects.filter(id=self.i1.id)
list(q2)
self.assertEqual(len((q1 & q2).order_by('name').query.tables), 1)
combined_query = (q1 & q2).order_by('name').query
self.assertEqual(len([
t for t in combined_query.tables if combined_query.alias_refcount[t]
]), 1)
def test_order_by_join_unref(self):
"""
@ -883,6 +887,225 @@ class Queries1Tests(BaseQuerysetTest):
Item.objects.filter(Q(tags__name__in=['t4', 't3'])),
[repr(i) for i in Item.objects.filter(~~Q(tags__name__in=['t4', 't3']))])
def test_ticket_10790_1(self):
# Querying direct fields with isnull should trim the left outer join.
# It also should not create INNER JOIN.
q = Tag.objects.filter(parent__isnull=True)
self.assertQuerysetEqual(q, ['<Tag: t1>'])
self.assertTrue('JOIN' not in str(q.query))
q = Tag.objects.filter(parent__isnull=False)
self.assertQuerysetEqual(
q,
['<Tag: t2>', '<Tag: t3>', '<Tag: t4>', '<Tag: t5>'],
)
self.assertTrue('JOIN' not in str(q.query))
q = Tag.objects.exclude(parent__isnull=True)
self.assertQuerysetEqual(
q,
['<Tag: t2>', '<Tag: t3>', '<Tag: t4>', '<Tag: t5>'],
)
self.assertTrue('JOIN' not in str(q.query))
q = Tag.objects.exclude(parent__isnull=False)
self.assertQuerysetEqual(q, ['<Tag: t1>'])
self.assertTrue('JOIN' not in str(q.query))
q = Tag.objects.exclude(parent__parent__isnull=False)
self.assertQuerysetEqual(
q,
['<Tag: t1>', '<Tag: t2>', '<Tag: t3>'],
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 1)
self.assertTrue('INNER JOIN' not in str(q.query))
def test_ticket_10790_2(self):
# Querying across several tables should strip only the last outer join,
# while preserving the preceding inner joins.
q = Tag.objects.filter(parent__parent__isnull=False)
self.assertQuerysetEqual(
q,
['<Tag: t4>', '<Tag: t5>'],
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q.query).count('INNER JOIN') == 1)
# Querying without isnull should not convert anything to left outer join.
q = Tag.objects.filter(parent__parent=self.t1)
self.assertQuerysetEqual(
q,
['<Tag: t4>', '<Tag: t5>'],
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q.query).count('INNER JOIN') == 1)
def test_ticket_10790_3(self):
# Querying via indirect fields should populate the left outer join
q = NamedCategory.objects.filter(tag__isnull=True)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 1)
# join to dumbcategory ptr_id
self.assertTrue(str(q.query).count('INNER JOIN') == 1)
self.assertQuerysetEqual(q, [])
# Querying across several tables should strip only the last join, while
# preserving the preceding left outer joins.
q = NamedCategory.objects.filter(tag__parent__isnull=True)
self.assertTrue(str(q.query).count('INNER JOIN') == 1)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 1)
self.assertQuerysetEqual(q, ['<NamedCategory: NamedCategory object>'])
def test_ticket_10790_4(self):
# Querying across m2m field should not strip the m2m table from join.
q = Author.objects.filter(item__tags__isnull=True)
self.assertQuerysetEqual(
q,
['<Author: a2>', '<Author: a3>'],
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 2)
self.assertTrue('INNER JOIN' not in str(q.query))
q = Author.objects.filter(item__tags__parent__isnull=True)
self.assertQuerysetEqual(
q,
['<Author: a1>', '<Author: a2>', '<Author: a2>', '<Author: a3>'],
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 3)
self.assertTrue('INNER JOIN' not in str(q.query))
def test_ticket_10790_5(self):
# Querying with isnull=False across m2m field should not create outer joins
q = Author.objects.filter(item__tags__isnull=False)
self.assertQuerysetEqual(
q,
['<Author: a1>', '<Author: a1>', '<Author: a2>', '<Author: a2>', '<Author: a4>']
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q.query).count('INNER JOIN') == 2)
q = Author.objects.filter(item__tags__parent__isnull=False)
self.assertQuerysetEqual(
q,
['<Author: a1>', '<Author: a2>', '<Author: a4>']
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q.query).count('INNER JOIN') == 3)
q = Author.objects.filter(item__tags__parent__parent__isnull=False)
self.assertQuerysetEqual(
q,
['<Author: a4>']
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q.query).count('INNER JOIN') == 4)
def test_ticket_10790_6(self):
# Querying with isnull=True across m2m field should not create inner joins
# and strip last outer join
q = Author.objects.filter(item__tags__parent__parent__isnull=True)
self.assertQuerysetEqual(
q,
['<Author: a1>', '<Author: a1>', '<Author: a2>', '<Author: a2>',
'<Author: a2>', '<Author: a3>']
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 4)
self.assertTrue(str(q.query).count('INNER JOIN') == 0)
q = Author.objects.filter(item__tags__parent__isnull=True)
self.assertQuerysetEqual(
q,
['<Author: a1>', '<Author: a2>', '<Author: a2>', '<Author: a3>']
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 3)
self.assertTrue(str(q.query).count('INNER JOIN') == 0)
def test_ticket_10790_7(self):
# Reverse querying with isnull should not strip the join
q = Author.objects.filter(item__isnull=True)
self.assertQuerysetEqual(
q,
['<Author: a3>']
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 1)
self.assertTrue(str(q.query).count('INNER JOIN') == 0)
q = Author.objects.filter(item__isnull=False)
self.assertQuerysetEqual(
q,
['<Author: a1>', '<Author: a2>', '<Author: a2>', '<Author: a4>']
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q.query).count('INNER JOIN') == 1)
def test_ticket_10790_8(self):
# Querying with combined q-objects should also strip the left outer join
q = Tag.objects.filter(Q(parent__isnull=True) | Q(parent=self.t1))
self.assertQuerysetEqual(
q,
['<Tag: t1>', '<Tag: t2>', '<Tag: t3>']
)
self.assertTrue(str(q.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q.query).count('INNER JOIN') == 0)
def test_ticket_10790_combine(self):
# Combining queries should not re-populate the left outer join
q1 = Tag.objects.filter(parent__isnull=True)
q2 = Tag.objects.filter(parent__isnull=False)
q3 = q1 | q2
self.assertQuerysetEqual(
q3,
['<Tag: t1>', '<Tag: t2>', '<Tag: t3>', '<Tag: t4>', '<Tag: t5>'],
)
self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q3.query).count('INNER JOIN') == 0)
q3 = q1 & q2
self.assertQuerysetEqual(q3, [])
self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q3.query).count('INNER JOIN') == 0)
q2 = Tag.objects.filter(parent=self.t1)
q3 = q1 | q2
self.assertQuerysetEqual(
q3,
['<Tag: t1>', '<Tag: t2>', '<Tag: t3>']
)
self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q3.query).count('INNER JOIN') == 0)
q3 = q2 | q1
self.assertQuerysetEqual(
q3,
['<Tag: t1>', '<Tag: t2>', '<Tag: t3>']
)
self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 0)
self.assertTrue(str(q3.query).count('INNER JOIN') == 0)
q1 = Tag.objects.filter(parent__isnull=True)
q2 = Tag.objects.filter(parent__parent__isnull=True)
q3 = q1 | q2
self.assertQuerysetEqual(
q3,
['<Tag: t1>', '<Tag: t2>', '<Tag: t3>']
)
self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 1)
self.assertTrue(str(q3.query).count('INNER JOIN') == 0)
q3 = q2 | q1
self.assertQuerysetEqual(
q3,
['<Tag: t1>', '<Tag: t2>', '<Tag: t3>']
)
self.assertTrue(str(q3.query).count('LEFT OUTER JOIN') == 1)
self.assertTrue(str(q3.query).count('INNER JOIN') == 0)
class Queries2Tests(TestCase):
def setUp(self):
Number.objects.create(num=4)
@ -1037,6 +1260,10 @@ class Queries4Tests(BaseQuerysetTest):
Item.objects.create(name='i2', created=datetime.datetime.now(), note=n1, creator=self.a3)
def test_ticket14876(self):
# Note: when combining the query we need to have information available
# about the join type of the trimmed "creator__isnull" join. If we
# don't have that information, then the join is created as INNER JOIN
# and results will be incorrect.
q1 = Report.objects.filter(Q(creator__isnull=True) | Q(creator__extra__info='e1'))
q2 = Report.objects.filter(Q(creator__isnull=True)) | Report.objects.filter(Q(creator__extra__info='e1'))
self.assertQuerysetEqual(q1, ["<Report: r1>", "<Report: r3>"], ordered=False)
@ -1405,17 +1632,19 @@ class NullableRelOrderingTests(TestCase):
# the join type of already existing joins.
Plaything.objects.create(name="p1")
s = SingleObject.objects.create(name='s')
r = RelatedObject.objects.create(single=s)
r = RelatedObject.objects.create(single=s, f=1)
Plaything.objects.create(name="p2", others=r)
qs = Plaything.objects.all().filter(others__isnull=False).order_by('pk')
self.assertTrue('JOIN' not in str(qs.query))
qs = Plaything.objects.all().filter(others__f__isnull=False).order_by('pk')
self.assertTrue('INNER' in str(qs.query))
qs = qs.order_by('others__single__name')
# The ordering by others__single__name will add one new join (to single)
# and that join must be a LEFT join. The already existing join to related
# objects must be kept INNER. So, we have both an INNER and a LEFT join
# in the query.
self.assertTrue('LEFT' in str(qs.query))
self.assertTrue('INNER' in str(qs.query))
self.assertEquals(str(qs.query).count('LEFT'), 1)
self.assertEquals(str(qs.query).count('INNER'), 1)
self.assertQuerysetEqual(
qs,
['<Plaything: p2>']
@ -1466,6 +1695,7 @@ class Queries6Tests(TestCase):
# This next test used to cause really weird PostgreSQL behavior, but it was
# only apparent much later when the full test suite ran.
# - Yeah, it leaves global ITER_CHUNK_SIZE to 2 instead of 100...
#@unittest.expectedFailure
def test_slicing_and_cache_interaction(self):
# We can do slicing beyond what is currently in the result cache,
@ -1993,6 +2223,29 @@ class DefaultValuesInsertTest(TestCase):
except TypeError:
self.fail("Creation of an instance of a model with only the PK field shouldn't error out after bulk insert refactoring (#17056)")
class ExcludeTest(TestCase):
def setUp(self):
f1 = Food.objects.create(name='apples')
Food.objects.create(name='oranges')
Eaten.objects.create(food=f1, meal='dinner')
j1 = Job.objects.create(name='Manager')
r1 = Responsibility.objects.create(description='Playing golf')
j2 = Job.objects.create(name='Programmer')
r2 = Responsibility.objects.create(description='Programming')
JobResponsibilities.objects.create(job=j1, responsibility=r1)
JobResponsibilities.objects.create(job=j2, responsibility=r2)
def test_to_field(self):
self.assertQuerysetEqual(
Food.objects.exclude(eaten__meal='dinner'),
['<Food: oranges>'])
self.assertQuerysetEqual(
Job.objects.exclude(responsibilities__description='Playing golf'),
['<Job: Programmer>'])
self.assertQuerysetEqual(
Responsibility.objects.exclude(jobs__name='Manager'),
['<Responsibility: Programming>'])
class NullInExcludeTest(TestCase):
def setUp(self):
NullableName.objects.create(name='i1')
@ -2155,3 +2408,13 @@ class NullJoinPromotionOrTest(TestCase):
# so we can use INNER JOIN for it. However, we can NOT use INNER JOIN
# for the b->c join, as a->b is nullable.
self.assertEqual(str(qset.query).count('INNER JOIN'), 1)
class ReverseJoinTrimmingTest(TestCase):
def test_reverse_trimming(self):
# Check that we don't accidentally trim reverse joins - we can't know
# if there is anything on the other side of the join, so trimming
# reverse joins can't be done, ever.
t = Tag.objects.create()
qs = Tag.objects.filter(annotation__tag=t.pk)
self.assertIn('INNER JOIN', str(qs.query))
self.assertEquals(list(qs), [])