Refs #33308 -- Enabled explicit GROUP BY and ORDER BY aliases.

This ensures explicit grouping from using values() before annotating an
aggregate function groups by selected aliases if supported.

The GROUP BY feature is disabled on Oracle because it doesn't support it.
This commit is contained in:
Simon Charette 2022-09-22 23:57:06 -04:00 committed by Mariusz Felisiak
parent c58a8acd41
commit 04518e310d
7 changed files with 78 additions and 13 deletions

View File

@ -10,6 +10,7 @@ class BaseDatabaseFeatures:
allows_group_by_lob = True
allows_group_by_pk = False
allows_group_by_selected_pks = False
allows_group_by_refs = True
empty_fetchmany_value = []
update_can_self_select = True

View File

@ -8,6 +8,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
# BLOB" when grouping by LOBs (#24096).
allows_group_by_lob = False
allows_group_by_refs = False
interprets_empty_strings_as_nulls = True
has_select_for_update = True
has_select_for_update_nowait = True

View File

@ -419,6 +419,8 @@ class BaseExpression:
def get_group_by_cols(self, alias=None):
if not self.contains_aggregate:
if alias:
return [Ref(alias, self)]
return [self]
cols = []
for source in self.get_source_expressions():
@ -1243,7 +1245,7 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
return expression.get_group_by_cols(alias=alias)
# For non-expressions e.g. an SQL WHERE clause, the entire
# `expression` must be included in the GROUP BY clause.
return super().get_group_by_cols()
return super().get_group_by_cols(alias=alias)
def as_sql(self, compiler, connection):
return compiler.compile(self.expression)

View File

@ -131,8 +131,11 @@ class SQLCompiler:
# Converts string references to expressions.
for expr in self.query.group_by:
if not hasattr(expr, "as_sql"):
expressions.append(self.query.resolve_ref(expr))
else:
expr = self.query.resolve_ref(expr)
if not self.connection.features.allows_group_by_refs and isinstance(
expr, Ref
):
expr = expr.source
expressions.append(expr)
# Note that even if the group_by is set, it is only the minimal
# set to group by. So, we need to add cols in select, order_by, and
@ -344,7 +347,13 @@ class SQLCompiler:
if not self.query.standard_ordering:
field = field.copy()
field.reverse_ordering()
yield field, False
if isinstance(field.expression, F) and (
annotation := self.query.annotation_select.get(
field.expression.name
)
):
field.expression = Ref(field.expression.name, annotation)
yield field, isinstance(field.expression, Ref)
continue
if field == "?": # random
yield OrderBy(Random()), False
@ -432,6 +441,10 @@ class SQLCompiler:
"""
result = []
seen = set()
replacements = {
expr: Ref(alias, expr)
for alias, expr in self.query.annotation_select.items()
}
for expr, is_ref in self._order_by_pairs():
resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
@ -461,7 +474,7 @@ class SQLCompiler:
q.add_annotation(expr_src, col_name)
self.query.add_select_col(resolved, col_name)
resolved.set_source_expressions([RawSQL(f"{order_by_idx}", ())])
sql, params = self.compile(resolved)
sql, params = self.compile(resolved.replace_expressions(replacements))
# Don't add the same column twice, but the order direction is
# not taken into account so we strip it. When this entire method
# is refactored into expressions, then we can check each part as we

View File

@ -2220,8 +2220,8 @@ class Query(BaseExpression):
primary key, and the query would be equivalent, the optimization
will be made automatically.
"""
# Column names from JOINs to check collisions with aliases.
if allow_aliases:
# Column names from JOINs to check collisions with aliases.
column_names = set()
seen_models = set()
for join in list(self.alias_map.values())[1:]: # Skip base table.
@ -2231,7 +2231,20 @@ class Query(BaseExpression):
{field.column for field in model._meta.local_concrete_fields}
)
seen_models.add(model)
if self.values_select:
# If grouping by aliases is allowed assign selected values
# aliases by moving them to annotations.
group_by_annotations = {}
values_select = {}
for alias, expr in zip(self.values_select, self.select):
if isinstance(expr, Col):
values_select[alias] = expr
else:
group_by_annotations[alias] = expr
self.annotations = {**group_by_annotations, **self.annotations}
self.append_annotation_mask(group_by_annotations)
self.select = tuple(values_select.values())
self.values_select = tuple(values_select)
group_by = list(self.select)
if self.annotation_select:
for alias, annotation in self.annotation_select.items():

View File

@ -178,10 +178,19 @@ class AggregationTests(TestCase):
)
.annotate(sum_discount=Sum("discount_price"))
)
with self.assertNumQueries(1) as ctx:
self.assertSequenceEqual(
values,
[{"discount_price": Decimal("59.38"), "sum_discount": Decimal("59.38")}],
[
{
"discount_price": Decimal("59.38"),
"sum_discount": Decimal("59.38"),
}
],
)
if connection.features.allows_group_by_refs:
alias = connection.ops.quote_name("discount_price")
self.assertIn(f"GROUP BY {alias}", ctx[0]["sql"])
def test_aggregates_in_where_clause(self):
"""

View File

@ -12,7 +12,12 @@ from django.core.management import call_command
from django.db import IntegrityError, connection, models
from django.db.models.expressions import Exists, OuterRef, RawSQL, Value
from django.db.models.functions import Cast, JSONObject, Upper
from django.test import TransactionTestCase, modify_settings, override_settings
from django.test import (
TransactionTestCase,
modify_settings,
override_settings,
skipUnlessDBFeature,
)
from django.test.utils import isolate_apps
from django.utils import timezone
@ -405,6 +410,27 @@ class TestQuerying(PostgreSQLTestCase):
expected,
)
@skipUnlessDBFeature("allows_group_by_refs")
def test_group_by_order_by_aliases(self):
with self.assertNumQueries(1) as ctx:
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(
field__0__isnull=False,
)
.values("field__0")
.annotate(arrayagg=ArrayAgg("id"))
.order_by("field__0"),
[
{"field__0": 1, "arrayagg": [1]},
{"field__0": 2, "arrayagg": [2, 3]},
{"field__0": 20, "arrayagg": [4]},
],
)
alias = connection.ops.quote_name("field__0")
sql = ctx[0]["sql"]
self.assertIn(f"GROUP BY {alias}", sql)
self.assertIn(f"ORDER BY {alias}", sql)
def test_index(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]