Fixed #28454 -- Simplifed use of Query.setup_joins() by returning a named tuple.

This commit is contained in:
Matthew Wilkes 2017-08-01 15:37:02 +01:00 committed by Tim Graham
parent 8df7681d0e
commit 32d1bf2bdb
2 changed files with 31 additions and 26 deletions

View File

@ -818,8 +818,8 @@ class SQLCompiler:
related_field_name = f.related_query_name() related_field_name = f.related_query_name()
fields_found.add(related_field_name) fields_found.add(related_field_name)
_, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias) join_info = self.query.setup_joins([related_field_name], opts, root_alias)
alias = joins[-1] alias = join_info.joins[-1]
from_parent = issubclass(model, opts.model) and model is not opts.model from_parent = issubclass(model, opts.model) and model is not opts.model
klass_info = { klass_info = {
'model': model, 'model': model,

View File

@ -6,7 +6,7 @@ themselves do not have to (and could be backed by things other than SQL
databases). The abstraction barrier only works one way: this module has to know databases). The abstraction barrier only works one way: this module has to know
all about the internals of models in order to get the information it needs. all about the internals of models in order to get the information it needs.
""" """
from collections import Counter, Iterator, Mapping, OrderedDict from collections import Counter, Iterator, Mapping, OrderedDict, namedtuple
from contextlib import suppress from contextlib import suppress
from itertools import chain, count, product from itertools import chain, count, product
from string import ascii_uppercase from string import ascii_uppercase
@ -44,6 +44,12 @@ def get_field_names_from_opts(opts):
)) ))
JoinInfo = namedtuple(
'JoinInfo',
('final_field', 'targets', 'opts', 'joins', 'path')
)
class RawQuery: class RawQuery:
"""A single raw SQL query.""" """A single raw SQL query."""
@ -935,10 +941,9 @@ class Query:
curr_opts = int_model._meta curr_opts = int_model._meta
continue continue
link_field = curr_opts.get_ancestor_link(int_model) link_field = curr_opts.get_ancestor_link(int_model)
_, _, _, joins, _ = self.setup_joins( join_info = self.setup_joins([link_field.name], curr_opts, alias)
[link_field.name], curr_opts, alias)
curr_opts = int_model._meta curr_opts = int_model._meta
alias = seen[int_model] = joins[-1] alias = seen[int_model] = join_info.joins[-1]
return alias or seen[None] return alias or seen[None]
def add_annotation(self, annotation, alias, is_summary=False): def add_annotation(self, annotation, alias, is_summary=False):
@ -1146,39 +1151,38 @@ class Query:
allow_many = not branch_negated or not split_subq allow_many = not branch_negated or not split_subq
try: try:
field, sources, opts, join_list, path = self.setup_joins( join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many)
parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many)
# Prevent iterator from being consumed by check_related_objects() # Prevent iterator from being consumed by check_related_objects()
if isinstance(value, Iterator): if isinstance(value, Iterator):
value = list(value) value = list(value)
self.check_related_objects(field, value, opts) self.check_related_objects(join_info.final_field, value, join_info.opts)
# split_exclude() needs to know which joins were generated for the # split_exclude() needs to know which joins were generated for the
# lookup parts # lookup parts
self._lookup_joins = join_list self._lookup_joins = join_info.joins
except MultiJoin as e: except MultiJoin as e:
return self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), return self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]),
can_reuse, e.names_with_path) can_reuse, e.names_with_path)
# Update used_joins before trimming since they are reused to determine # Update used_joins before trimming since they are reused to determine
# which joins could be later promoted to INNER. # which joins could be later promoted to INNER.
used_joins.update(join_list) used_joins.update(join_info.joins)
targets, alias, join_list = self.trim_joins(sources, join_list, path) targets, alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
if can_reuse is not None: if can_reuse is not None:
can_reuse.update(join_list) can_reuse.update(join_list)
if field.is_relation: if join_info.final_field.is_relation:
# No support for transforms for relational fields # No support for transforms for relational fields
num_lookups = len(lookups) num_lookups = len(lookups)
if num_lookups > 1: if num_lookups > 1:
raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0]))
if len(targets) == 1: if len(targets) == 1:
col = targets[0].get_col(alias, field) col = targets[0].get_col(alias, join_info.final_field)
else: else:
col = MultiColSource(alias, targets, sources, field) col = MultiColSource(alias, targets, join_info.targets, join_info.final_field)
else: else:
col = targets[0].get_col(alias, field) col = targets[0].get_col(alias, join_info.final_field)
condition = self.build_lookup(lookups, col, value) condition = self.build_lookup(lookups, col, value)
lookup_type = condition.lookup_name lookup_type = condition.lookup_name
@ -1200,7 +1204,7 @@ class Query:
# <=> # <=>
# NOT (col IS NOT NULL AND col = someval). # NOT (col IS NOT NULL AND col = someval).
lookup_class = targets[0].get_lookup('isnull') lookup_class = targets[0].get_lookup('isnull')
clause.add(lookup_class(targets[0].get_col(alias, sources[0]), False), AND) clause.add(lookup_class(targets[0].get_col(alias, join_info.targets[0]), False), AND)
return clause, used_joins if not require_outer else () return clause, used_joins if not require_outer else ()
def add_filter(self, filter_clause): def add_filter(self, filter_clause):
@ -1383,7 +1387,7 @@ class Query:
reuse = can_reuse if join.m2m else None reuse = can_reuse if join.m2m else None
alias = self.join(connection, reuse=reuse) alias = self.join(connection, reuse=reuse)
joins.append(alias) joins.append(alias)
return final_field, targets, opts, joins, path return JoinInfo(final_field, targets, opts, joins, path)
def trim_joins(self, targets, joins, path): def trim_joins(self, targets, joins, path):
""" """
@ -1425,16 +1429,14 @@ class Query:
return self.annotation_select[name] return self.annotation_select[name]
else: else:
field_list = name.split(LOOKUP_SEP) field_list = name.split(LOOKUP_SEP)
field, sources, opts, join_list, path = self.setup_joins( join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse)
field_list, self.get_meta(), targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path)
self.get_initial_alias(), reuse)
targets, _, join_list = self.trim_joins(sources, join_list, path)
if len(targets) > 1: if len(targets) > 1:
raise FieldError("Referencing multicolumn fields with F() objects " raise FieldError("Referencing multicolumn fields with F() objects "
"isn't supported") "isn't supported")
if reuse is not None: if reuse is not None:
reuse.update(join_list) reuse.update(join_list)
col = targets[0].get_col(join_list[-1], sources[0]) col = targets[0].get_col(join_list[-1], join_info.targets[0])
return col return col
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
@ -1586,9 +1588,12 @@ class Query:
for name in field_names: for name in field_names:
# Join promotion note - we must not remove any rows here, so # Join promotion note - we must not remove any rows here, so
# if there is no existing joins, use outer join. # if there is no existing joins, use outer join.
_, targets, _, joins, path = self.setup_joins( join_info = self.setup_joins(name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m)
name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m) targets, final_alias, joins = self.trim_joins(
targets, final_alias, joins = self.trim_joins(targets, joins, path) join_info.targets,
join_info.joins,
join_info.path,
)
for target in targets: for target in targets:
cols.append(target.get_col(final_alias)) cols.append(target.get_col(final_alias))
if cols: if cols: