From 32d1bf2bdbf9006bad8bdd94ff74333da84cae9c Mon Sep 17 00:00:00 2001 From: Matthew Wilkes Date: Tue, 1 Aug 2017 15:37:02 +0100 Subject: [PATCH] Fixed #28454 -- Simplifed use of Query.setup_joins() by returning a named tuple. --- django/db/models/sql/compiler.py | 4 +-- django/db/models/sql/query.py | 53 +++++++++++++++++--------------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 2abfd2c289c..5ae5213b147 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -818,8 +818,8 @@ class SQLCompiler: related_field_name = f.related_query_name() fields_found.add(related_field_name) - _, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias) - alias = joins[-1] + join_info = self.query.setup_joins([related_field_name], opts, root_alias) + alias = join_info.joins[-1] from_parent = issubclass(model, opts.model) and model is not opts.model klass_info = { 'model': model, diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index ccbbe0cd5fd..bf21c66f202 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -6,7 +6,7 @@ themselves do not have to (and could be backed by things other than SQL databases). The abstraction barrier only works one way: this module has to know all about the internals of models in order to get the information it needs. """ -from collections import Counter, Iterator, Mapping, OrderedDict +from collections import Counter, Iterator, Mapping, OrderedDict, namedtuple from contextlib import suppress from itertools import chain, count, product from string import ascii_uppercase @@ -44,6 +44,12 @@ def get_field_names_from_opts(opts): )) +JoinInfo = namedtuple( + 'JoinInfo', + ('final_field', 'targets', 'opts', 'joins', 'path') +) + + class RawQuery: """A single raw SQL query.""" @@ -935,10 +941,9 @@ class Query: curr_opts = int_model._meta continue link_field = curr_opts.get_ancestor_link(int_model) - _, _, _, joins, _ = self.setup_joins( - [link_field.name], curr_opts, alias) + join_info = self.setup_joins([link_field.name], curr_opts, alias) curr_opts = int_model._meta - alias = seen[int_model] = joins[-1] + alias = seen[int_model] = join_info.joins[-1] return alias or seen[None] def add_annotation(self, annotation, alias, is_summary=False): @@ -1146,39 +1151,38 @@ class Query: allow_many = not branch_negated or not split_subq try: - field, sources, opts, join_list, path = self.setup_joins( - parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many) + join_info = self.setup_joins(parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many) # Prevent iterator from being consumed by check_related_objects() if isinstance(value, Iterator): value = list(value) - self.check_related_objects(field, value, opts) + self.check_related_objects(join_info.final_field, value, join_info.opts) # split_exclude() needs to know which joins were generated for the # lookup parts - self._lookup_joins = join_list + self._lookup_joins = join_info.joins except MultiJoin as e: return self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level]), can_reuse, e.names_with_path) # Update used_joins before trimming since they are reused to determine # which joins could be later promoted to INNER. - used_joins.update(join_list) - targets, alias, join_list = self.trim_joins(sources, join_list, path) + used_joins.update(join_info.joins) + targets, alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) if can_reuse is not None: can_reuse.update(join_list) - if field.is_relation: + if join_info.final_field.is_relation: # No support for transforms for relational fields num_lookups = len(lookups) if num_lookups > 1: raise FieldError('Related Field got invalid lookup: {}'.format(lookups[0])) if len(targets) == 1: - col = targets[0].get_col(alias, field) + col = targets[0].get_col(alias, join_info.final_field) else: - col = MultiColSource(alias, targets, sources, field) + col = MultiColSource(alias, targets, join_info.targets, join_info.final_field) else: - col = targets[0].get_col(alias, field) + col = targets[0].get_col(alias, join_info.final_field) condition = self.build_lookup(lookups, col, value) lookup_type = condition.lookup_name @@ -1200,7 +1204,7 @@ class Query: # <=> # NOT (col IS NOT NULL AND col = someval). lookup_class = targets[0].get_lookup('isnull') - clause.add(lookup_class(targets[0].get_col(alias, sources[0]), False), AND) + clause.add(lookup_class(targets[0].get_col(alias, join_info.targets[0]), False), AND) return clause, used_joins if not require_outer else () def add_filter(self, filter_clause): @@ -1383,7 +1387,7 @@ class Query: reuse = can_reuse if join.m2m else None alias = self.join(connection, reuse=reuse) joins.append(alias) - return final_field, targets, opts, joins, path + return JoinInfo(final_field, targets, opts, joins, path) def trim_joins(self, targets, joins, path): """ @@ -1425,16 +1429,14 @@ class Query: return self.annotation_select[name] else: field_list = name.split(LOOKUP_SEP) - field, sources, opts, join_list, path = self.setup_joins( - field_list, self.get_meta(), - self.get_initial_alias(), reuse) - targets, _, join_list = self.trim_joins(sources, join_list, path) + join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), reuse) + targets, _, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) if len(targets) > 1: raise FieldError("Referencing multicolumn fields with F() objects " "isn't supported") if reuse is not None: reuse.update(join_list) - col = targets[0].get_col(join_list[-1], sources[0]) + col = targets[0].get_col(join_list[-1], join_info.targets[0]) return col def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): @@ -1586,9 +1588,12 @@ class Query: for name in field_names: # Join promotion note - we must not remove any rows here, so # if there is no existing joins, use outer join. - _, targets, _, joins, path = self.setup_joins( - name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m) - targets, final_alias, joins = self.trim_joins(targets, joins, path) + join_info = self.setup_joins(name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m) + targets, final_alias, joins = self.trim_joins( + join_info.targets, + join_info.joins, + join_info.path, + ) for target in targets: cols.append(target.get_col(final_alias)) if cols: