diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index bfce9063e5e..b0e3e48c156 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -37,7 +37,7 @@ class SQLCompiler(object): # cleaned. We are not using a clone() of the query here. """ if not self.query.tables: - self.query.join((None, self.query.get_meta().db_table, None)) + self.query.get_initial_alias() if (not self.query.select and self.query.default_cols and not self.query.included_inherited_models): self.query.setup_inherited_models() @@ -171,7 +171,6 @@ class SQLCompiler(object): # Finally do cleanup - get rid of the joins we created above. self.query.reset_refcounts(refcounts_before) - return ' '.join(result), tuple(params) def as_nested_sql(self): @@ -511,51 +510,27 @@ class SQLCompiler(object): ordering and distinct must be done first. """ result = [] - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - first = True - from_params = [] + params = [] for alias in self.query.tables: if not self.query.alias_refcount[alias]: continue try: - name, alias, join_type, lhs, join_cols, _, join_field = self.query.alias_map[alias] + from_clause = self.query.alias_map[alias] except KeyError: # Extra tables can end up in self.tables, but not in the # alias_map if they aren't in a join. That's OK. We skip them. continue - alias_str = '' if alias == name else (' %s' % alias) - if join_type and not first: - extra_cond = join_field.get_extra_restriction( - self.query.where_class, alias, lhs) - if extra_cond: - extra_sql, extra_params = self.compile(extra_cond) - extra_sql = 'AND (%s)' % extra_sql - from_params.extend(extra_params) - else: - extra_sql = "" - result.append('%s %s%s ON (' - % (join_type, qn(name), alias_str)) - for index, (lhs_col, rhs_col) in enumerate(join_cols): - if index != 0: - result.append(' AND ') - result.append('%s.%s = %s.%s' % - (qn(lhs), qn2(lhs_col), qn(alias), qn2(rhs_col))) - result.append('%s)' % extra_sql) - else: - connector = '' if first else ', ' - result.append('%s%s%s' % (connector, qn(name), alias_str)) - first = False + clause_sql, clause_params = self.compile(from_clause) + result.append(clause_sql) + params.extend(clause_params) for t in self.query.extra_tables: alias, _ = self.query.table_alias(t) # Only add the alias if it's not already present (the table_alias() - # calls increments the refcount, so an alias refcount of one means - # this is the only reference. + # call increments the refcount, so an alias refcount of one means + # this is the only reference). if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1: - connector = '' if first else ', ' - result.append('%s%s' % (connector, qn(alias))) - first = False - return result, from_params + result.append(', %s' % self.quote_name_unless_alias(alias)) + return result, params def get_grouping(self, having_group_by, ordering_group_by): """ diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index 49fdf114b31..e0e3f101008 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -21,12 +21,6 @@ GET_ITERATOR_CHUNK_SIZE = 100 # Namedtuples for sql.* internal use. -# Join lists (indexes into the tuples that are values in the alias_map -# dictionary in the Query class). -JoinInfo = namedtuple('JoinInfo', - 'table_name rhs_alias join_type lhs_alias ' - 'join_cols nullable join_field') - # Pairs of column clauses to select, and (possibly None) field for the clause. SelectInfo = namedtuple('SelectInfo', 'col field') @@ -41,3 +35,7 @@ ORDER_DIR = { 'ASC': ('ASC', 'DESC'), 'DESC': ('DESC', 'ASC'), } + +# SQL join types. +INNER = 'INNER JOIN' +LOUTER = 'LEFT OUTER JOIN' diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 321451ac42b..fc5ffc17901 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -2,6 +2,7 @@ Useful auxiliary data structures for query construction. Not useful outside the SQL domain. """ +from django.db.models.sql.constants import INNER, LOUTER class EmptyResultSet(Exception): @@ -22,3 +23,119 @@ class MultiJoin(Exception): class Empty(object): pass + + +class Join(object): + """ + Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the + FROM entry. For example, the SQL generated could be + LEFT OUTER JOIN "sometable" T1 ON ("othertable"."sometable_id" = "sometable"."id") + + This class is primarily used in Query.alias_map. All entries in alias_map + must be Join compatible by providing the following attributes and methods: + - table_name (string) + - table_alias (possible alias for the table, can be None) + - join_type (can be None for those entries that aren't joined from + anything) + - parent_alias (which table is this join's parent, can be None similarly + to join_type) + - as_sql() + - relabeled_clone() + + """ + def __init__(self, table_name, parent_alias, table_alias, join_type, + join_field, nullable): + # Join table + self.table_name = table_name + self.parent_alias = parent_alias + # Note: table_alias is not necessarily known at instantiation time. + self.table_alias = table_alias + # LOUTER or INNER + self.join_type = join_type + # A list of 2-tuples to use in the ON clause of the JOIN. + # Each 2-tuple will create one join condition in the ON clause. + self.join_cols = join_field.get_joining_columns() + # Along which field (or RelatedObject in the reverse join case) + self.join_field = join_field + # Is this join nullabled? + self.nullable = nullable + + def as_sql(self, compiler, connection): + """ + Generates the full + LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params + clause for this join. + """ + params = [] + sql = [] + alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias) + qn = compiler.quote_name_unless_alias + qn2 = connection.ops.quote_name + sql.append('%s %s%s ON (' % (self.join_type, qn(self.table_name), alias_str)) + for index, (lhs_col, rhs_col) in enumerate(self.join_cols): + if index != 0: + sql.append(' AND ') + sql.append('%s.%s = %s.%s' % ( + qn(self.parent_alias), + qn2(lhs_col), + qn(self.table_alias), + qn2(rhs_col), + )) + extra_cond = self.join_field.get_extra_restriction( + compiler.query.where_class, self.table_alias, self.parent_alias) + if extra_cond: + extra_sql, extra_params = compiler.compile(extra_cond) + extra_sql = 'AND (%s)' % extra_sql + params.extend(extra_params) + sql.append('%s' % extra_sql) + sql.append(')') + return ' '.join(sql), params + + def relabeled_clone(self, change_map): + new_parent_alias = change_map.get(self.parent_alias, self.parent_alias) + new_table_alias = change_map.get(self.table_alias, self.table_alias) + return self.__class__( + self.table_name, new_parent_alias, new_table_alias, self.join_type, + self.join_field, self.nullable) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return ( + self.table_name == other.table_name and + self.parent_alias == other.parent_alias and + self.join_field == other.join_field + ) + return False + + def demote(self): + new = self.relabeled_clone({}) + new.join_type = INNER + return new + + def promote(self): + new = self.relabeled_clone({}) + new.join_type = LOUTER + return new + + +class BaseTable(object): + """ + The BaseTable class is used for base table references in FROM clause. For + example, the SQL "foo" in + SELECT * FROM "foo" WHERE somecond + could be generated by this class. + """ + join_type = None + parent_alias = None + + def __init__(self, table_name, alias): + self.table_name = table_name + self.table_alias = alias + + def as_sql(self, compiler, connection): + alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias) + base_sql = compiler.quote_name_unless_alias(self.table_name) + return base_sql + alias_str, [] + + def relabeled_clone(self, change_map): + return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 9f5ca0dc506..dadca181294 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -20,8 +20,9 @@ from django.db.models.query_utils import Q, refs_aggregate from django.db.models.related import PathInfo from django.db.models.aggregates import Count from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, - ORDER_PATTERN, JoinInfo, SelectInfo) -from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin + ORDER_PATTERN, SelectInfo, INNER, LOUTER) +from django.db.models.sql.datastructures import ( + EmptyResultSet, Empty, MultiJoin, Join, BaseTable) from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, ExtraWhere, AND, OR, EmptyWhere) from django.utils import six @@ -87,10 +88,6 @@ class Query(object): """ A single SQL query. """ - # SQL join types. These are part of the class because their string forms - # vary from database to database and can be customised by a subclass. - INNER = 'INNER JOIN' - LOUTER = 'LEFT OUTER JOIN' alias_prefix = 'T' subq_aliases = frozenset([alias_prefix]) @@ -103,15 +100,15 @@ class Query(object): self.alias_refcount = {} # alias_map is the most important data structure regarding joins. # It's used for recording which joins exist in the query and what - # type they are. The key is the alias of the joined table (possibly - # the table name) and the value is JoinInfo from constants.py. + # types they are. The key is the alias of the joined table (possibly + # the table name) and the value is a Join-like object (see + # sql.datastructures.Join for more information). self.alias_map = {} # Sometimes the query contains references to aliases in outer queries (as # a result of split_exclude). Correct alias quoting needs to know these # aliases too. self.external_aliases = set() self.table_map = {} # Maps table names to list of aliases. - self.join_map = {} self.default_cols = True self.default_ordering = True self.standard_ordering = True @@ -246,7 +243,6 @@ class Query(object): obj.alias_map = self.alias_map.copy() obj.external_aliases = self.external_aliases.copy() obj.table_map = self.table_map.copy() - obj.join_map = self.join_map.copy() obj.default_cols = self.default_cols obj.default_ordering = self.default_ordering obj.standard_ordering = self.standard_ordering @@ -495,19 +491,17 @@ class Query(object): self.get_initial_alias() joinpromoter = JoinPromoter(connector, 2, False) joinpromoter.add_votes( - j for j in self.alias_map if self.alias_map[j].join_type == self.INNER) + j for j in self.alias_map if self.alias_map[j].join_type == INNER) rhs_votes = set() # Now, add the joins from rhs query into the new query (skipping base # table). for alias in rhs.tables[1:]: - table, _, join_type, lhs, join_cols, nullable, join_field = rhs.alias_map[alias] + join = rhs.alias_map[alias] # If the left side of the join was already relabeled, use the # updated alias. - lhs = change_map.get(lhs, lhs) - new_alias = self.join( - (lhs, table, join_cols), reuse=reuse, - nullable=nullable, join_field=join_field) - if join_type == self.INNER: + join = join.relabeled_clone(change_map) + new_alias = self.join(join, reuse=reuse) + if join.join_type == INNER: rhs_votes.add(new_alias) # We can't reuse the same join again in the query. If we have two # distinct joins for the same connection in rhs query, then the @@ -714,27 +708,26 @@ class Query(object): aliases = list(aliases) while aliases: alias = aliases.pop(0) - if self.alias_map[alias].join_cols[0][1] is None: + if self.alias_map[alias].join_type is None: # This is the base table (first FROM entry) - this table # isn't really joined at all in the query, so we should not # alter its join type. continue # Only the first alias (skipped above) should have None join_type assert self.alias_map[alias].join_type is not None - parent_alias = self.alias_map[alias].lhs_alias + parent_alias = self.alias_map[alias].parent_alias parent_louter = ( parent_alias - and self.alias_map[parent_alias].join_type == self.LOUTER) - already_louter = self.alias_map[alias].join_type == self.LOUTER + and self.alias_map[parent_alias].join_type == LOUTER) + already_louter = self.alias_map[alias].join_type == LOUTER if ((self.alias_map[alias].nullable or parent_louter) and not already_louter): - data = self.alias_map[alias]._replace(join_type=self.LOUTER) - self.alias_map[alias] = data + self.alias_map[alias] = self.alias_map[alias].promote() # Join type of 'alias' changed, so re-examine all aliases that # refer to this one. aliases.extend( join for join in self.alias_map.keys() - if (self.alias_map[join].lhs_alias == alias + if (self.alias_map[join].parent_alias == alias and join not in aliases)) def demote_joins(self, aliases): @@ -750,10 +743,10 @@ class Query(object): aliases = list(aliases) while aliases: alias = aliases.pop(0) - if self.alias_map[alias].join_type == self.LOUTER: - self.alias_map[alias] = self.alias_map[alias]._replace(join_type=self.INNER) - parent_alias = self.alias_map[alias].lhs_alias - if self.alias_map[parent_alias].join_type == self.INNER: + if self.alias_map[alias].join_type == LOUTER: + self.alias_map[alias] = self.alias_map[alias].demote() + parent_alias = self.alias_map[alias].parent_alias + if self.alias_map[parent_alias].join_type == INNER: aliases.append(parent_alias) def reset_refcounts(self, to_counts): @@ -792,19 +785,13 @@ class Query(object): (key, relabel_column(col)) for key, col in self._annotations.items()) # 2. Rename the alias in the internal table/alias datastructures. - for ident, aliases in self.join_map.items(): - del self.join_map[ident] - aliases = tuple(change_map.get(a, a) for a in aliases) - ident = (change_map.get(ident[0], ident[0]),) + ident[1:] - self.join_map[ident] = aliases for old_alias, new_alias in six.iteritems(change_map): - alias_data = self.alias_map.get(old_alias) - if alias_data is None: + if old_alias not in self.alias_map: continue - alias_data = alias_data._replace(rhs_alias=new_alias) + alias_data = self.alias_map[old_alias].relabeled_clone(change_map) + self.alias_map[new_alias] = alias_data self.alias_refcount[new_alias] = self.alias_refcount[old_alias] del self.alias_refcount[old_alias] - self.alias_map[new_alias] = alias_data del self.alias_map[old_alias] table_aliases = self.table_map[alias_data.table_name] @@ -819,14 +806,6 @@ class Query(object): for key, alias in self.included_inherited_models.items(): if alias in change_map: self.included_inherited_models[key] = change_map[alias] - - # 3. Update any joins that refer to the old alias. - for alias, data in six.iteritems(self.alias_map): - lhs = data.lhs_alias - if lhs in change_map: - data = data._replace(lhs_alias=change_map[lhs]) - self.alias_map[alias] = data - self.external_aliases = {change_map.get(alias, alias) for alias in self.external_aliases} @@ -862,7 +841,7 @@ class Query(object): alias = self.tables[0] self.ref_alias(alias) else: - alias = self.join((None, self.get_meta().db_table, None)) + alias = self.join(BaseTable(self.get_meta().db_table, None)) return alias def count_active_tables(self): @@ -873,7 +852,7 @@ class Query(object): """ return len([1 for count in self.alias_refcount.values() if count]) - def join(self, connection, reuse=None, nullable=False, join_field=None): + def join(self, join, reuse=None): """ Returns an alias for the join in 'connection', either reusing an existing alias for that join or creating a new one. 'connection' is a @@ -897,40 +876,22 @@ class Query(object): The 'join_field' is the field we are joining along (if any). """ - lhs, table, join_cols = connection - assert lhs is None or join_field is not None - existing = self.join_map.get(connection, ()) - if reuse is None: - reuse = existing - else: - reuse = [a for a in existing if a in reuse] - for alias in reuse: - if join_field and self.alias_map[alias].join_field != join_field: - # The join_map doesn't contain join_field (mainly because - # fields in Query structs are problematic in pickling), so - # check that the existing join is created using the same - # join_field used for the under work join. - continue - self.ref_alias(alias) - return alias + reuse = [a for a, j in self.alias_map.items() + if (reuse is None or a in reuse) and j == join] + if reuse: + self.ref_alias(reuse[0]) + return reuse[0] # No reuse is possible, so we need a new alias. - alias, _ = self.table_alias(table, create=True) - if not lhs: - # Not all tables need to be joined to anything. No join type - # means the later columns are ignored. - join_type = None - elif self.alias_map[lhs].join_type == self.LOUTER or nullable: - join_type = self.LOUTER - else: - join_type = self.INNER - join = JoinInfo(table, alias, join_type, lhs, join_cols or ((None, None),), nullable, - join_field) + alias, _ = self.table_alias(join.table_name, create=True) + if join.join_type: + if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable: + join_type = LOUTER + else: + join_type = INNER + join.join_type = join_type + join.table_alias = alias self.alias_map[alias] = join - if connection in self.join_map: - self.join_map[connection] += (alias,) - else: - self.join_map[connection] = (alias,) return alias def setup_inherited_models(self): @@ -1249,7 +1210,7 @@ class Query(object): require_outer = True if (lookup_type != 'isnull' and ( self.is_nullable(targets[0]) or - self.alias_map[join_list[-1]].join_type == self.LOUTER)): + self.alias_map[join_list[-1]].join_type == LOUTER)): # The condition added here will be SQL like this: # NOT (col IS NOT NULL), where the first NOT is added in # upper layers of code. The reason for addition is that if col @@ -1326,7 +1287,7 @@ class Query(object): # rel_a doesn't produce any rows, then the whole condition must fail. # So, demotion is OK. existing_inner = set( - (a for a in self.alias_map if self.alias_map[a].join_type == self.INNER)) + (a for a in self.alias_map if self.alias_map[a].join_type == INNER)) clause, require_inner = self._add_q(where_part, self.used_aliases) self.where.add(clause, AND) for hp in having_parts: @@ -1490,10 +1451,9 @@ class Query(object): nullable = self.is_nullable(join.join_field) else: nullable = True - connection = alias, opts.db_table, join.join_field.get_joining_columns() + connection = Join(opts.db_table, alias, None, INNER, join.join_field, nullable) reuse = can_reuse if join.m2m else None - alias = self.join( - connection, reuse=reuse, nullable=nullable, join_field=join.join_field) + alias = self.join(connection, reuse=reuse) joins.append(alias) if hasattr(final_field, 'field'): final_field = final_field.field @@ -1991,9 +1951,10 @@ class Query(object): for trimmed_paths, path in enumerate(all_paths): if path.m2m: break - if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == self.LOUTER: + if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == LOUTER: contains_louter = True - self.unref_alias(lookup_tables[trimmed_paths]) + alias = lookup_tables[trimmed_paths] + self.unref_alias(alias) # The path.join_field is a Rel, lets get the other side's field join_field = path.join_field.field # Build the filter prefix. @@ -2010,7 +1971,7 @@ class Query(object): # Lets still see if we can trim the first join from the inner query # (that is, self). We can't do this for LEFT JOINs because we would # miss those rows that have nothing on the outer side. - if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type != self.LOUTER: + if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type != LOUTER: select_fields = [r[0] for r in join_field.related_fields] select_alias = lookup_tables[trimmed_paths + 1] self.unref_alias(lookup_tables[trimmed_paths]) @@ -2024,6 +1985,12 @@ class Query(object): # values in select_fields. Lets punt this one for now. select_fields = [r[1] for r in join_field.related_fields] select_alias = lookup_tables[trimmed_paths] + # The found starting point is likely a Join instead of a BaseTable reference. + # But the first entry in the query's FROM clause must not be a JOIN. + for table in self.tables: + if self.alias_refcount[table] > 0: + self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table) + break self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields] return trimmed_prefix, contains_louter diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 6aa3770897e..1c2906e7a39 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -11,6 +11,7 @@ from django.core.exceptions import FieldError from django.db import connection, DEFAULT_DB_ALIAS from django.db.models import Count, F, Q from django.db.models.sql.where import WhereNode, EverythingNode, NothingNode +from django.db.models.sql.constants import LOUTER from django.db.models.sql.datastructures import EmptyResultSet from django.test import TestCase, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext @@ -128,7 +129,7 @@ class Queries1Tests(BaseQuerysetTest): def test_ticket2306(self): # Checking that no join types are "left outer" joins. query = Item.objects.filter(tags=self.t2).query - self.assertNotIn(query.LOUTER, [x[2] for x in query.alias_map.values()]) + self.assertNotIn(LOUTER, [x.join_type for x in query.alias_map.values()]) self.assertQuerysetEqual( Item.objects.filter(Q(tags=self.t1)).order_by('name'), @@ -336,7 +337,7 @@ class Queries1Tests(BaseQuerysetTest): # Excluding from a relation that cannot be NULL should not use outer joins. query = Item.objects.exclude(creator__in=[self.a1, self.a2]).query - self.assertNotIn(query.LOUTER, [x[2] for x in query.alias_map.values()]) + self.assertNotIn(LOUTER, [x.join_type for x in query.alias_map.values()]) # Similarly, when one of the joins cannot possibly, ever, involve NULL # values (Author -> ExtraInfo, in the following), it should never be @@ -344,7 +345,7 @@ class Queries1Tests(BaseQuerysetTest): # involve one "left outer" join (Author -> Item is 0-to-many). qs = Author.objects.filter(id=self.a1.id).filter(Q(extra__note=self.n1) | Q(item__note=self.n3)) self.assertEqual( - len([x[2] for x in qs.query.alias_map.values() if x[2] == query.LOUTER and qs.query.alias_refcount[x[1]]]), + len([x for x in qs.query.alias_map.values() if x.join_type == LOUTER and qs.query.alias_refcount[x.table_alias]]), 1 ) @@ -855,7 +856,7 @@ class Queries1Tests(BaseQuerysetTest): ) q = Note.objects.filter(Q(extrainfo__author=self.a1) | Q(extrainfo=xx)).query self.assertEqual( - len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]]), + len([x for x in q.alias_map.values() if x.join_type == LOUTER and q.alias_refcount[x.table_alias]]), 1 )