Refs #27149, #29542 -- Simplified subquery parentheses wrapping logic.

This commit is contained in:
Simon Charette 2019-03-06 01:24:41 -05:00 committed by Tim Graham
parent 3543129822
commit 3a505c70e7
4 changed files with 8 additions and 21 deletions

View File

@ -1023,23 +1023,13 @@ class Subquery(Expression):
def as_sql(self, compiler, connection, template=None, **extra_context): def as_sql(self, compiler, connection, template=None, **extra_context):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
template_params = {**self.extra, **extra_context} template_params = {**self.extra, **extra_context}
template_params['subquery'], sql_params = self.query.as_sql(compiler, connection) subquery_sql, sql_params = self.query.as_sql(compiler, connection)
template_params['subquery'] = subquery_sql[1:-1]
template = template or template_params.get('template', self.template) template = template or template_params.get('template', self.template)
sql = template % template_params sql = template % template_params
return sql, sql_params return sql, sql_params
def _prepare(self, output_field):
# This method will only be called if this instance is the "rhs" in an
# expression: the wrapping () must be removed (as the expression that
# contains this will provide them). SQLite evaluates ((subquery))
# differently than the other databases.
if self.template == '(%(subquery)s)':
clone = self.copy()
clone.template = '%(subquery)s'
return clone
return self
def get_group_by_cols(self, alias=None): def get_group_by_cols(self, alias=None):
if alias: if alias:
return [Ref(alias, self)] return [Ref(alias, self)]

View File

@ -89,8 +89,7 @@ class Lookup:
value = self.apply_bilateral_transforms(value) value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query) value = value.resolve_expression(compiler.query)
if hasattr(value, 'as_sql'): if hasattr(value, 'as_sql'):
sql, params = compiler.compile(value) return compiler.compile(value)
return '(' + sql + ')', params
else: else:
return self.get_db_prep_lookup(value, connection) return self.get_db_prep_lookup(value, connection)

View File

@ -5,7 +5,7 @@ from itertools import chain
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import OrderBy, Random, RawSQL, Ref, Subquery from django.db.models.expressions import OrderBy, Random, RawSQL, Ref
from django.db.models.query_utils import QueryWrapper, select_related_descend from django.db.models.query_utils import QueryWrapper, select_related_descend
from django.db.models.sql.constants import ( from django.db.models.sql.constants import (
CURSOR, GET_ITERATOR_CHUNK_SIZE, MULTI, NO_RESULTS, ORDER_DIR, SINGLE, CURSOR, GET_ITERATOR_CHUNK_SIZE, MULTI, NO_RESULTS, ORDER_DIR, SINGLE,
@ -126,11 +126,6 @@ class SQLCompiler:
for expr in expressions: for expr in expressions:
sql, params = self.compile(expr) sql, params = self.compile(expr)
if isinstance(expr, Subquery) and not sql.startswith('('):
# Subquery expression from HAVING clause may not contain
# wrapping () because they could be removed when a subquery is
# the "rhs" in an expression (see Subquery._prepare()).
sql = '(%s)' % sql
if (sql, tuple(params)) not in seen: if (sql, tuple(params)) not in seen:
result.append((sql, params)) result.append((sql, params))
seen.add((sql, tuple(params))) seen.add((sql, tuple(params)))

View File

@ -1022,7 +1022,10 @@ class Query(BaseExpression):
return clone return clone
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return self.get_compiler(connection=connection).as_sql() sql, params = self.get_compiler(connection=connection).as_sql()
if self.subquery:
sql = '(%s)' % sql
return sql, params
def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col): def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col):
if hasattr(value, 'resolve_expression'): if hasattr(value, 'resolve_expression'):