diff --git a/django/contrib/gis/db/models/sql/compiler.py b/django/contrib/gis/db/models/sql/compiler.py index cf6a8ad047..0dcf50d32a 100644 --- a/django/contrib/gis/db/models/sql/compiler.py +++ b/django/contrib/gis/db/models/sql/compiler.py @@ -39,7 +39,7 @@ class GeoSQLCompiler(compiler.SQLCompiler): if self.query.select: only_load = self.deferred_to_columns() # This loop customized for GeoQuery. - for col, field in zip(self.query.select, self.query.select_fields): + for col, field in self.query.select: if isinstance(col, (list, tuple)): alias, column = col table = self.query.alias_map[alias].table_name @@ -85,7 +85,7 @@ class GeoSQLCompiler(compiler.SQLCompiler): ]) # This loop customized for GeoQuery. - for (table, col), field in zip(self.query.related_select_cols, self.query.related_select_fields): + for (table, col), field in self.query.related_select_cols: r = self.get_field_select(field, table, col) if with_aliases and col in col_aliases: c_alias = 'Col%d' % len(col_aliases) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index b9095e503a..7461f5f31d 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -6,7 +6,7 @@ from django.db.backends.util import truncate_name from django.db.models.constants import LOOKUP_SEP from django.db.models.query_utils import select_related_descend from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR, - GET_ITERATOR_CHUNK_SIZE) + GET_ITERATOR_CHUNK_SIZE, SelectInfo) from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.query import get_order_dir, Query @@ -188,7 +188,7 @@ class SQLCompiler(object): col_aliases = set() if self.query.select: only_load = self.deferred_to_columns() - for col in self.query.select: + for col, _ in self.query.select: if isinstance(col, (list, tuple)): alias, column = col table = self.query.alias_map[alias].table_name @@ -233,7 +233,7 @@ class SQLCompiler(object): for alias, aggregate in self.query.aggregate_select.items() ]) - for table, col in self.query.related_select_cols: + for (table, col), _ in self.query.related_select_cols: r = '%s.%s' % (qn(table), qn(col)) if with_aliases and col in col_aliases: c_alias = 'Col%d' % len(col_aliases) @@ -557,8 +557,9 @@ class SQLCompiler(object): for extra_select, extra_params in six.itervalues(self.query.extra_select): extra_selects.append(extra_select) params.extend(extra_params) - cols = (group_by + self.query.select + - self.query.related_select_cols + extra_selects) + select_cols = [s.col for s in self.query.select] + related_select_cols = [s.col for s in self.query.related_select_cols] + cols = (group_by + select_cols + related_select_cols + extra_selects) seen = set() for col in cols: if col in seen: @@ -589,7 +590,6 @@ class SQLCompiler(object): opts = self.query.get_meta() root_alias = self.query.get_initial_alias() self.query.related_select_cols = [] - self.query.related_select_fields = [] if not used: used = set() if dupe_set is None: @@ -664,8 +664,8 @@ class SQLCompiler(object): used.add(alias) columns, aliases = self.get_default_columns(start_alias=alias, opts=f.rel.to._meta, as_pairs=True) - self.query.related_select_cols.extend(columns) - self.query.related_select_fields.extend(f.rel.to._meta.fields) + self.query.related_select_cols.extend( + SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields)) if restricted: next = requested.get(f.name, {}) else: @@ -734,8 +734,8 @@ class SQLCompiler(object): used.add(alias) columns, aliases = self.get_default_columns(start_alias=alias, opts=model._meta, as_pairs=True, local_only=True) - self.query.related_select_cols.extend(columns) - self.query.related_select_fields.extend(model._meta.fields) + self.query.related_select_cols.extend( + SelectInfo(col, field) for col, field in zip(columns, model._meta.fields)) next = requested.get(f.related_query_name(), {}) # Use True here because we are looking at the _reverse_ side of @@ -772,7 +772,7 @@ class SQLCompiler(object): if resolve_columns: if fields is None: # We only set this up here because - # related_select_fields isn't populated until + # related_select_cols isn't populated until # execute_sql() has been called. # We also include types of fields of related models that @@ -782,11 +782,11 @@ class SQLCompiler(object): # This code duplicates the logic for the order of fields # found in get_columns(). It would be nice to clean this up. - if self.query.select_fields: - fields = self.query.select_fields + if self.query.select: + fields = [f.field for f in self.query.select] else: fields = self.query.model._meta.fields - fields = fields + self.query.related_select_fields + fields = fields + [f.field for f in self.query.related_select_cols] # If the field was deferred, exclude it from being passed # into `resolve_columns` because it wasn't selected. diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index f750310624..7e34047e1d 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -25,6 +25,9 @@ JoinInfo = namedtuple('JoinInfo', 'table_name rhs_alias join_type lhs_alias ' 'lhs_join_col rhs_join_col nullable') +# Pairs of column clauses to select, and (possibly None) field for the clause. +SelectInfo = namedtuple('SelectInfo', 'col field') + # How many results to expect from a cursor.execute call MULTI = 'multi' SINGLE = 'single' diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index cef01c48ab..de7e5904a3 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -20,7 +20,7 @@ from django.db.models.expressions import ExpressionNode from django.db.models.fields import FieldDoesNotExist from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, - ORDER_PATTERN, JoinInfo) + ORDER_PATTERN, JoinInfo, SelectInfo) from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, @@ -115,17 +115,20 @@ class Query(object): self.default_ordering = True self.standard_ordering = True self.ordering_aliases = [] - self.related_select_fields = [] self.dupe_avoidance = {} self.used_aliases = set() self.filter_is_sticky = False self.included_inherited_models = {} - # SQL-related attributes + # SQL-related attributes + # Select and related select clauses as SelectInfo instances. + # The select is used for cases where we want to set up the select + # clause to contain other than default fields (values(), annotate(), + # subqueries...) self.select = [] - # For each to-be-selected field in self.select there must be a - # corresponding entry in self.select - git seems to need this. - self.select_fields = [] + # The related_select_cols is used for columns needed for + # select_related - this is populated in compile stage. + self.related_select_cols = [] self.tables = [] # Aliases in the order they are created. self.where = where() self.where_class = where @@ -138,7 +141,6 @@ class Query(object): self.select_for_update = False self.select_for_update_nowait = False self.select_related = False - self.related_select_cols = [] # SQL aggregate-related attributes self.aggregates = SortedDict() # Maps alias -> SQL aggregate function @@ -191,15 +193,14 @@ class Query(object): Pickling support. """ obj_dict = self.__dict__.copy() - obj_dict['related_select_fields'] = [] obj_dict['related_select_cols'] = [] # Fields can't be pickled, so if a field list has been # specified, we pickle the list of field names instead. # None is also a possible value; that can pass as-is - obj_dict['select_fields'] = [ - f is not None and f.name or None - for f in obj_dict['select_fields'] + obj_dict['select'] = [ + (s.col, s.field is not None and s.field.name or None) + for s in obj_dict['select'] ] return obj_dict @@ -209,9 +210,9 @@ class Query(object): """ # Rebuild list of field instances opts = obj_dict['model']._meta - obj_dict['select_fields'] = [ - name is not None and opts.get_field(name) or None - for name in obj_dict['select_fields'] + obj_dict['select'] = [ + SelectInfo(tpl[0], tpl[1] is not None and opts.get_field(tpl[1]) or None) + for tpl in obj_dict['select'] ] self.__dict__.update(obj_dict) @@ -256,10 +257,9 @@ class Query(object): obj.standard_ordering = self.standard_ordering obj.included_inherited_models = self.included_inherited_models.copy() obj.ordering_aliases = [] - obj.select_fields = self.select_fields[:] - obj.related_select_fields = self.related_select_fields[:] obj.dupe_avoidance = self.dupe_avoidance.copy() obj.select = self.select[:] + obj.related_select_cols = [] obj.tables = self.tables[:] obj.where = copy.deepcopy(self.where, memo=memo) obj.where_class = self.where_class @@ -275,7 +275,6 @@ class Query(object): obj.select_for_update = self.select_for_update obj.select_for_update_nowait = self.select_for_update_nowait obj.select_related = self.select_related - obj.related_select_cols = [] obj.aggregates = copy.deepcopy(self.aggregates, memo=memo) if self.aggregate_select_mask is None: obj.aggregate_select_mask = None @@ -384,7 +383,6 @@ class Query(object): query.select_for_update = False query.select_related = False query.related_select_cols = [] - query.related_select_fields = [] result = query.get_compiler(using).execute_sql(SINGLE) if result is None: @@ -527,14 +525,14 @@ class Query(object): # Selection columns and extra extensions are those provided by 'rhs'. self.select = [] - for col in rhs.select: + for col, field in rhs.select: if isinstance(col, (list, tuple)): - self.select.append((change_map.get(col[0], col[0]), col[1])) + new_col = change_map.get(col[0], col[0]), col[1] + self.select.append(SelectInfo(new_col, field)) else: item = copy.deepcopy(col) item.relabel_aliases(change_map) - self.select.append(item) - self.select_fields = rhs.select_fields[:] + self.select.append(SelectInfo(item, field)) if connector == OR: # It would be nice to be able to handle this, but the queries don't @@ -750,24 +748,23 @@ class Query(object): """ assert set(change_map.keys()).intersection(set(change_map.values())) == set() + def relabel_column(col): + if isinstance(col, (list, tuple)): + old_alias = col[0] + return (change_map.get(old_alias, old_alias), col[1]) + else: + col.relabel_aliases(change_map) + return col # 1. Update references in "select" (normal columns plus aliases), # "group by", "where" and "having". self.where.relabel_aliases(change_map) self.having.relabel_aliases(change_map) - for columns in [self.select, self.group_by or []]: - for pos, col in enumerate(columns): - if isinstance(col, (list, tuple)): - old_alias = col[0] - columns[pos] = (change_map.get(old_alias, old_alias), col[1]) - else: - col.relabel_aliases(change_map) - for mapping in [self.aggregates]: - for key, col in mapping.items(): - if isinstance(col, (list, tuple)): - old_alias = col[0] - mapping[key] = (change_map.get(old_alias, old_alias), col[1]) - else: - col.relabel_aliases(change_map) + if self.group_by: + self.group_by = [relabel_column(col) for col in self.group_by] + self.select = [SelectInfo(relabel_column(s.col), s.field) + for s in self.select] + self.aggregates = SortedDict( + (key, relabel_column(col)) for key, col in self.aggregates.items()) # 2. Rename the alias in the internal table/alias datastructures. for k, aliases in self.join_map.items(): @@ -1570,7 +1567,7 @@ class Query(object): # since we are adding a IN clause. This prevents the # database from tripping over IN (...,NULL,...) selects and returning # nothing - alias, col = query.select[0] + alias, col = query.select[0].col query.where.add((Constraint(alias, col, None), 'isnull', False), AND) self.add_filter(('%s__in' % prefix, query), negate=True, trim=True, @@ -1629,7 +1626,6 @@ class Query(object): Removes all fields from SELECT clause. """ self.select = [] - self.select_fields = [] self.default_cols = False self.select_related = False self.set_extra_mask(()) @@ -1642,7 +1638,6 @@ class Query(object): columns. """ self.select = [] - self.select_fields = [] def add_distinct_fields(self, *field_names): """ @@ -1674,8 +1669,7 @@ class Query(object): col = join.lhs_join_col joins = joins[:-1] self.promote_joins(joins[1:]) - self.select.append((final_alias, col)) - self.select_fields.append(field) + self.select.append(SelectInfo((final_alias, col), field)) except MultiJoin: raise FieldError("Invalid field name: '%s'" % name) except FieldError: @@ -1731,8 +1725,8 @@ class Query(object): """ self.group_by = [] - for sel in self.select: - self.group_by.append(sel) + for col, _ in self.select: + self.group_by.append(col) def add_count_column(self): """ @@ -1745,7 +1739,7 @@ class Query(object): else: assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select': %r" % self.select - count = self.aggregates_module.Count(self.select[0]) + count = self.aggregates_module.Count(self.select[0].col) else: opts = self.model._meta if not self.select: @@ -1757,7 +1751,7 @@ class Query(object): assert len(self.select) == 1, \ "Cannot add count col with multiple cols in 'select'." - count = self.aggregates_module.Count(self.select[0], distinct=True) + count = self.aggregates_module.Count(self.select[0].col, distinct=True) # Distinct handling is done in Count(), so don't do it at this # level. self.distinct = False @@ -1781,7 +1775,6 @@ class Query(object): d = d.setdefault(part, {}) self.select_related = field_dict self.related_select_cols = [] - self.related_select_fields = [] def add_extra(self, select, select_params, where, params, tables, order_by): """ @@ -1975,7 +1968,7 @@ class Query(object): self.unref_alias(select_alias) select_alias = join_info.rhs_alias select_col = join_info.rhs_join_col - self.select = [(select_alias, select_col)] + self.select = [SelectInfo((select_alias, select_col), None)] self.remove_inherited_models() def is_nullable(self, field): diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 24ac957cbf..39d1ee0116 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -76,7 +76,7 @@ class DeleteQuery(Query): return else: innerq.clear_select_clause() - innerq.select, innerq.select_fields = [(self.get_initial_alias(), pk.column)], [None] + innerq.select = [SelectInfo((self.get_initial_alias(), pk.column), None)] values = innerq where = self.where_class() where.add((Constraint(None, pk.column, pk), 'in', values), AND) @@ -244,7 +244,7 @@ class DateQuery(Query): alias = result[3][-1] select = Date((alias, field.column), lookup_type) self.clear_select_clause() - self.select, self.select_fields = [select], [None] + self.select = [SelectInfo(select, None)] self.distinct = True self.order_by = order == 'ASC' and [1] or [-1]