Refs #33308 -- Enabled explicit GROUP BY and ORDER BY aliases.
This ensures explicit grouping from using values() before annotating an aggregate function groups by selected aliases if supported. The GROUP BY feature is disabled on Oracle because it doesn't support it.
This commit is contained in:
parent
c58a8acd41
commit
04518e310d
|
@ -10,6 +10,7 @@ class BaseDatabaseFeatures:
|
||||||
allows_group_by_lob = True
|
allows_group_by_lob = True
|
||||||
allows_group_by_pk = False
|
allows_group_by_pk = False
|
||||||
allows_group_by_selected_pks = False
|
allows_group_by_selected_pks = False
|
||||||
|
allows_group_by_refs = True
|
||||||
empty_fetchmany_value = []
|
empty_fetchmany_value = []
|
||||||
update_can_self_select = True
|
update_can_self_select = True
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
# Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
|
# Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
|
||||||
# BLOB" when grouping by LOBs (#24096).
|
# BLOB" when grouping by LOBs (#24096).
|
||||||
allows_group_by_lob = False
|
allows_group_by_lob = False
|
||||||
|
allows_group_by_refs = False
|
||||||
interprets_empty_strings_as_nulls = True
|
interprets_empty_strings_as_nulls = True
|
||||||
has_select_for_update = True
|
has_select_for_update = True
|
||||||
has_select_for_update_nowait = True
|
has_select_for_update_nowait = True
|
||||||
|
|
|
@ -419,6 +419,8 @@ class BaseExpression:
|
||||||
|
|
||||||
def get_group_by_cols(self, alias=None):
|
def get_group_by_cols(self, alias=None):
|
||||||
if not self.contains_aggregate:
|
if not self.contains_aggregate:
|
||||||
|
if alias:
|
||||||
|
return [Ref(alias, self)]
|
||||||
return [self]
|
return [self]
|
||||||
cols = []
|
cols = []
|
||||||
for source in self.get_source_expressions():
|
for source in self.get_source_expressions():
|
||||||
|
@ -1243,7 +1245,7 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
|
||||||
return expression.get_group_by_cols(alias=alias)
|
return expression.get_group_by_cols(alias=alias)
|
||||||
# For non-expressions e.g. an SQL WHERE clause, the entire
|
# For non-expressions e.g. an SQL WHERE clause, the entire
|
||||||
# `expression` must be included in the GROUP BY clause.
|
# `expression` must be included in the GROUP BY clause.
|
||||||
return super().get_group_by_cols()
|
return super().get_group_by_cols(alias=alias)
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
return compiler.compile(self.expression)
|
return compiler.compile(self.expression)
|
||||||
|
|
|
@ -131,9 +131,12 @@ class SQLCompiler:
|
||||||
# Converts string references to expressions.
|
# Converts string references to expressions.
|
||||||
for expr in self.query.group_by:
|
for expr in self.query.group_by:
|
||||||
if not hasattr(expr, "as_sql"):
|
if not hasattr(expr, "as_sql"):
|
||||||
expressions.append(self.query.resolve_ref(expr))
|
expr = self.query.resolve_ref(expr)
|
||||||
else:
|
if not self.connection.features.allows_group_by_refs and isinstance(
|
||||||
expressions.append(expr)
|
expr, Ref
|
||||||
|
):
|
||||||
|
expr = expr.source
|
||||||
|
expressions.append(expr)
|
||||||
# Note that even if the group_by is set, it is only the minimal
|
# Note that even if the group_by is set, it is only the minimal
|
||||||
# set to group by. So, we need to add cols in select, order_by, and
|
# set to group by. So, we need to add cols in select, order_by, and
|
||||||
# having into the select in any case.
|
# having into the select in any case.
|
||||||
|
@ -344,7 +347,13 @@ class SQLCompiler:
|
||||||
if not self.query.standard_ordering:
|
if not self.query.standard_ordering:
|
||||||
field = field.copy()
|
field = field.copy()
|
||||||
field.reverse_ordering()
|
field.reverse_ordering()
|
||||||
yield field, False
|
if isinstance(field.expression, F) and (
|
||||||
|
annotation := self.query.annotation_select.get(
|
||||||
|
field.expression.name
|
||||||
|
)
|
||||||
|
):
|
||||||
|
field.expression = Ref(field.expression.name, annotation)
|
||||||
|
yield field, isinstance(field.expression, Ref)
|
||||||
continue
|
continue
|
||||||
if field == "?": # random
|
if field == "?": # random
|
||||||
yield OrderBy(Random()), False
|
yield OrderBy(Random()), False
|
||||||
|
@ -432,6 +441,10 @@ class SQLCompiler:
|
||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
seen = set()
|
seen = set()
|
||||||
|
replacements = {
|
||||||
|
expr: Ref(alias, expr)
|
||||||
|
for alias, expr in self.query.annotation_select.items()
|
||||||
|
}
|
||||||
|
|
||||||
for expr, is_ref in self._order_by_pairs():
|
for expr, is_ref in self._order_by_pairs():
|
||||||
resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
|
resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
|
||||||
|
@ -461,7 +474,7 @@ class SQLCompiler:
|
||||||
q.add_annotation(expr_src, col_name)
|
q.add_annotation(expr_src, col_name)
|
||||||
self.query.add_select_col(resolved, col_name)
|
self.query.add_select_col(resolved, col_name)
|
||||||
resolved.set_source_expressions([RawSQL(f"{order_by_idx}", ())])
|
resolved.set_source_expressions([RawSQL(f"{order_by_idx}", ())])
|
||||||
sql, params = self.compile(resolved)
|
sql, params = self.compile(resolved.replace_expressions(replacements))
|
||||||
# Don't add the same column twice, but the order direction is
|
# Don't add the same column twice, but the order direction is
|
||||||
# not taken into account so we strip it. When this entire method
|
# not taken into account so we strip it. When this entire method
|
||||||
# is refactored into expressions, then we can check each part as we
|
# is refactored into expressions, then we can check each part as we
|
||||||
|
|
|
@ -2220,8 +2220,8 @@ class Query(BaseExpression):
|
||||||
primary key, and the query would be equivalent, the optimization
|
primary key, and the query would be equivalent, the optimization
|
||||||
will be made automatically.
|
will be made automatically.
|
||||||
"""
|
"""
|
||||||
# Column names from JOINs to check collisions with aliases.
|
|
||||||
if allow_aliases:
|
if allow_aliases:
|
||||||
|
# Column names from JOINs to check collisions with aliases.
|
||||||
column_names = set()
|
column_names = set()
|
||||||
seen_models = set()
|
seen_models = set()
|
||||||
for join in list(self.alias_map.values())[1:]: # Skip base table.
|
for join in list(self.alias_map.values())[1:]: # Skip base table.
|
||||||
|
@ -2231,7 +2231,20 @@ class Query(BaseExpression):
|
||||||
{field.column for field in model._meta.local_concrete_fields}
|
{field.column for field in model._meta.local_concrete_fields}
|
||||||
)
|
)
|
||||||
seen_models.add(model)
|
seen_models.add(model)
|
||||||
|
if self.values_select:
|
||||||
|
# If grouping by aliases is allowed assign selected values
|
||||||
|
# aliases by moving them to annotations.
|
||||||
|
group_by_annotations = {}
|
||||||
|
values_select = {}
|
||||||
|
for alias, expr in zip(self.values_select, self.select):
|
||||||
|
if isinstance(expr, Col):
|
||||||
|
values_select[alias] = expr
|
||||||
|
else:
|
||||||
|
group_by_annotations[alias] = expr
|
||||||
|
self.annotations = {**group_by_annotations, **self.annotations}
|
||||||
|
self.append_annotation_mask(group_by_annotations)
|
||||||
|
self.select = tuple(values_select.values())
|
||||||
|
self.values_select = tuple(values_select)
|
||||||
group_by = list(self.select)
|
group_by = list(self.select)
|
||||||
if self.annotation_select:
|
if self.annotation_select:
|
||||||
for alias, annotation in self.annotation_select.items():
|
for alias, annotation in self.annotation_select.items():
|
||||||
|
|
|
@ -178,10 +178,19 @@ class AggregationTests(TestCase):
|
||||||
)
|
)
|
||||||
.annotate(sum_discount=Sum("discount_price"))
|
.annotate(sum_discount=Sum("discount_price"))
|
||||||
)
|
)
|
||||||
self.assertSequenceEqual(
|
with self.assertNumQueries(1) as ctx:
|
||||||
values,
|
self.assertSequenceEqual(
|
||||||
[{"discount_price": Decimal("59.38"), "sum_discount": Decimal("59.38")}],
|
values,
|
||||||
)
|
[
|
||||||
|
{
|
||||||
|
"discount_price": Decimal("59.38"),
|
||||||
|
"sum_discount": Decimal("59.38"),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if connection.features.allows_group_by_refs:
|
||||||
|
alias = connection.ops.quote_name("discount_price")
|
||||||
|
self.assertIn(f"GROUP BY {alias}", ctx[0]["sql"])
|
||||||
|
|
||||||
def test_aggregates_in_where_clause(self):
|
def test_aggregates_in_where_clause(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,7 +12,12 @@ from django.core.management import call_command
|
||||||
from django.db import IntegrityError, connection, models
|
from django.db import IntegrityError, connection, models
|
||||||
from django.db.models.expressions import Exists, OuterRef, RawSQL, Value
|
from django.db.models.expressions import Exists, OuterRef, RawSQL, Value
|
||||||
from django.db.models.functions import Cast, JSONObject, Upper
|
from django.db.models.functions import Cast, JSONObject, Upper
|
||||||
from django.test import TransactionTestCase, modify_settings, override_settings
|
from django.test import (
|
||||||
|
TransactionTestCase,
|
||||||
|
modify_settings,
|
||||||
|
override_settings,
|
||||||
|
skipUnlessDBFeature,
|
||||||
|
)
|
||||||
from django.test.utils import isolate_apps
|
from django.test.utils import isolate_apps
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
|
@ -405,6 +410,27 @@ class TestQuerying(PostgreSQLTestCase):
|
||||||
expected,
|
expected,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("allows_group_by_refs")
|
||||||
|
def test_group_by_order_by_aliases(self):
|
||||||
|
with self.assertNumQueries(1) as ctx:
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
NullableIntegerArrayModel.objects.filter(
|
||||||
|
field__0__isnull=False,
|
||||||
|
)
|
||||||
|
.values("field__0")
|
||||||
|
.annotate(arrayagg=ArrayAgg("id"))
|
||||||
|
.order_by("field__0"),
|
||||||
|
[
|
||||||
|
{"field__0": 1, "arrayagg": [1]},
|
||||||
|
{"field__0": 2, "arrayagg": [2, 3]},
|
||||||
|
{"field__0": 20, "arrayagg": [4]},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
alias = connection.ops.quote_name("field__0")
|
||||||
|
sql = ctx[0]["sql"]
|
||||||
|
self.assertIn(f"GROUP BY {alias}", sql)
|
||||||
|
self.assertIn(f"ORDER BY {alias}", sql)
|
||||||
|
|
||||||
def test_index(self):
|
def test_index(self):
|
||||||
self.assertSequenceEqual(
|
self.assertSequenceEqual(
|
||||||
NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
|
NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
|
||||||
|
|
Loading…
Reference in New Issue