Refs #33992 -- Refactored subquery grouping logic.
This required moving the combined queries slicing logic to the compiler in order to allow Query.exists() to be called at expression resolving time. It allowed for Query.exists() to be called at Exists() initialization time and thus ensured that get_group_by_cols() was operating on the terminal representation of the query that only has a single column selected.
This commit is contained in:
parent
04518e310d
commit
3d734c09ff
|
@ -1470,11 +1470,10 @@ class Subquery(BaseExpression, Combinable):
|
||||||
def get_external_cols(self):
|
def get_external_cols(self):
|
||||||
return self.query.get_external_cols()
|
return self.query.get_external_cols()
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, template=None, query=None, **extra_context):
|
def as_sql(self, compiler, connection, template=None, **extra_context):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
template_params = {**self.extra, **extra_context}
|
template_params = {**self.extra, **extra_context}
|
||||||
query = query or self.query
|
subquery_sql, sql_params = self.query.as_sql(compiler, connection)
|
||||||
subquery_sql, sql_params = query.as_sql(compiler, connection)
|
|
||||||
template_params["subquery"] = subquery_sql[1:-1]
|
template_params["subquery"] = subquery_sql[1:-1]
|
||||||
|
|
||||||
template = template or template_params.get("template", self.template)
|
template = template or template_params.get("template", self.template)
|
||||||
|
@ -1482,13 +1481,7 @@ class Subquery(BaseExpression, Combinable):
|
||||||
return sql, sql_params
|
return sql, sql_params
|
||||||
|
|
||||||
def get_group_by_cols(self, alias=None):
|
def get_group_by_cols(self, alias=None):
|
||||||
# If this expression is referenced by an alias for an explicit GROUP BY
|
return self.query.get_group_by_cols(alias=alias, wrapper=self)
|
||||||
# through values() a reference to this expression and not the
|
|
||||||
# underlying .query must be returned to ensure external column
|
|
||||||
# references are not grouped against as well.
|
|
||||||
if alias:
|
|
||||||
return [Ref(alias, self)]
|
|
||||||
return self.query.get_group_by_cols()
|
|
||||||
|
|
||||||
|
|
||||||
class Exists(Subquery):
|
class Exists(Subquery):
|
||||||
|
@ -1498,28 +1491,18 @@ class Exists(Subquery):
|
||||||
def __init__(self, queryset, negated=False, **kwargs):
|
def __init__(self, queryset, negated=False, **kwargs):
|
||||||
self.negated = negated
|
self.negated = negated
|
||||||
super().__init__(queryset, **kwargs)
|
super().__init__(queryset, **kwargs)
|
||||||
|
self.query = self.query.exists()
|
||||||
|
|
||||||
def __invert__(self):
|
def __invert__(self):
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
clone.negated = not self.negated
|
clone.negated = not self.negated
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def get_group_by_cols(self, alias=None):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
# self.query only gets limited to a single row in the .exists() call
|
|
||||||
# from self.as_sql() so deferring to Query.get_group_by_cols() is
|
|
||||||
# inappropriate.
|
|
||||||
if alias is None:
|
|
||||||
return [self]
|
|
||||||
return super().get_group_by_cols(alias)
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, template=None, **extra_context):
|
|
||||||
query = self.query.exists(using=connection.alias)
|
|
||||||
try:
|
try:
|
||||||
sql, params = super().as_sql(
|
sql, params = super().as_sql(
|
||||||
compiler,
|
compiler,
|
||||||
connection,
|
connection,
|
||||||
template=template,
|
|
||||||
query=query,
|
|
||||||
**extra_context,
|
**extra_context,
|
||||||
)
|
)
|
||||||
except EmptyResultSet:
|
except EmptyResultSet:
|
||||||
|
|
|
@ -535,8 +535,8 @@ class SQLCompiler:
|
||||||
if not query.is_empty()
|
if not query.is_empty()
|
||||||
]
|
]
|
||||||
if not features.supports_slicing_ordering_in_compound:
|
if not features.supports_slicing_ordering_in_compound:
|
||||||
for query, compiler in zip(self.query.combined_queries, compilers):
|
for compiler in compilers:
|
||||||
if query.low_mark or query.high_mark:
|
if compiler.query.is_sliced:
|
||||||
raise DatabaseError(
|
raise DatabaseError(
|
||||||
"LIMIT/OFFSET not allowed in subqueries of compound statements."
|
"LIMIT/OFFSET not allowed in subqueries of compound statements."
|
||||||
)
|
)
|
||||||
|
@ -544,6 +544,11 @@ class SQLCompiler:
|
||||||
raise DatabaseError(
|
raise DatabaseError(
|
||||||
"ORDER BY not allowed in subqueries of compound statements."
|
"ORDER BY not allowed in subqueries of compound statements."
|
||||||
)
|
)
|
||||||
|
elif self.query.is_sliced and combinator == "union":
|
||||||
|
limit = (self.query.low_mark, self.query.high_mark)
|
||||||
|
for compiler in compilers:
|
||||||
|
if not compiler.query.is_sliced:
|
||||||
|
compiler.query.set_limits(*limit)
|
||||||
parts = ()
|
parts = ()
|
||||||
for compiler in compilers:
|
for compiler in compilers:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -556,7 +556,7 @@ class Query(BaseExpression):
|
||||||
def has_filters(self):
|
def has_filters(self):
|
||||||
return self.where
|
return self.where
|
||||||
|
|
||||||
def exists(self, using, limit=True):
|
def exists(self, limit=True):
|
||||||
q = self.clone()
|
q = self.clone()
|
||||||
if not (q.distinct and q.is_sliced):
|
if not (q.distinct and q.is_sliced):
|
||||||
if q.group_by is True:
|
if q.group_by is True:
|
||||||
|
@ -568,11 +568,8 @@ class Query(BaseExpression):
|
||||||
q.set_group_by(allow_aliases=False)
|
q.set_group_by(allow_aliases=False)
|
||||||
q.clear_select_clause()
|
q.clear_select_clause()
|
||||||
if q.combined_queries and q.combinator == "union":
|
if q.combined_queries and q.combinator == "union":
|
||||||
limit_combined = connections[
|
|
||||||
using
|
|
||||||
].features.supports_slicing_ordering_in_compound
|
|
||||||
q.combined_queries = tuple(
|
q.combined_queries = tuple(
|
||||||
combined_query.exists(using, limit=limit_combined)
|
combined_query.exists(limit=False)
|
||||||
for combined_query in q.combined_queries
|
for combined_query in q.combined_queries
|
||||||
)
|
)
|
||||||
q.clear_ordering(force=True)
|
q.clear_ordering(force=True)
|
||||||
|
@ -1150,12 +1147,16 @@ class Query(BaseExpression):
|
||||||
if col.alias in self.external_aliases
|
if col.alias in self.external_aliases
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_group_by_cols(self, alias=None):
|
def get_group_by_cols(self, alias=None, wrapper=None):
|
||||||
|
# If wrapper is referenced by an alias for an explicit GROUP BY through
|
||||||
|
# values() a reference to this expression and not the self must be
|
||||||
|
# returned to ensure external column references are not grouped against
|
||||||
|
# as well.
|
||||||
if alias:
|
if alias:
|
||||||
return [Ref(alias, self)]
|
return [Ref(alias, wrapper or self)]
|
||||||
external_cols = self.get_external_cols()
|
external_cols = self.get_external_cols()
|
||||||
if any(col.possibly_multivalued for col in external_cols):
|
if any(col.possibly_multivalued for col in external_cols):
|
||||||
return [self]
|
return [wrapper or self]
|
||||||
return external_cols
|
return external_cols
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
|
|
@ -1440,9 +1440,7 @@ class AggregateTestCase(TestCase):
|
||||||
.annotate(cnt=Count("isbn"))
|
.annotate(cnt=Count("isbn"))
|
||||||
.filter(cnt__gt=1)
|
.filter(cnt__gt=1)
|
||||||
)
|
)
|
||||||
query = publishers_having_more_than_one_book_qs.query.exists(
|
query = publishers_having_more_than_one_book_qs.query.exists()
|
||||||
using=connection.alias
|
|
||||||
)
|
|
||||||
_, _, group_by = query.get_compiler(connection=connection).pre_sql_setup()
|
_, _, group_by = query.get_compiler(connection=connection).pre_sql_setup()
|
||||||
self.assertEqual(len(group_by), 1)
|
self.assertEqual(len(group_by), 1)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue