diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 5c99736f223..6b935083836 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -10,6 +10,7 @@ class BaseDatabaseFeatures: allows_group_by_lob = True allows_group_by_pk = False allows_group_by_selected_pks = False + allows_group_by_refs = True empty_fetchmany_value = [] update_can_self_select = True diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index 9a98616dc21..a85cf45560f 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -8,6 +8,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): # Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got # BLOB" when grouping by LOBs (#24096). allows_group_by_lob = False + allows_group_by_refs = False interprets_empty_strings_as_nulls = True has_select_for_update = True has_select_for_update_nowait = True diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index cbf4fd8296a..c292504d0ea 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -419,6 +419,8 @@ class BaseExpression: def get_group_by_cols(self, alias=None): if not self.contains_aggregate: + if alias: + return [Ref(alias, self)] return [self] cols = [] for source in self.get_source_expressions(): @@ -1243,7 +1245,7 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression): return expression.get_group_by_cols(alias=alias) # For non-expressions e.g. an SQL WHERE clause, the entire # `expression` must be included in the GROUP BY clause. - return super().get_group_by_cols() + return super().get_group_by_cols(alias=alias) def as_sql(self, compiler, connection): return compiler.compile(self.expression) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index f566546307d..e4605918d9c 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -131,9 +131,12 @@ class SQLCompiler: # Converts string references to expressions. for expr in self.query.group_by: if not hasattr(expr, "as_sql"): - expressions.append(self.query.resolve_ref(expr)) - else: - expressions.append(expr) + expr = self.query.resolve_ref(expr) + if not self.connection.features.allows_group_by_refs and isinstance( + expr, Ref + ): + expr = expr.source + expressions.append(expr) # Note that even if the group_by is set, it is only the minimal # set to group by. So, we need to add cols in select, order_by, and # having into the select in any case. @@ -344,7 +347,13 @@ class SQLCompiler: if not self.query.standard_ordering: field = field.copy() field.reverse_ordering() - yield field, False + if isinstance(field.expression, F) and ( + annotation := self.query.annotation_select.get( + field.expression.name + ) + ): + field.expression = Ref(field.expression.name, annotation) + yield field, isinstance(field.expression, Ref) continue if field == "?": # random yield OrderBy(Random()), False @@ -432,6 +441,10 @@ class SQLCompiler: """ result = [] seen = set() + replacements = { + expr: Ref(alias, expr) + for alias, expr in self.query.annotation_select.items() + } for expr, is_ref in self._order_by_pairs(): resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None) @@ -461,7 +474,7 @@ class SQLCompiler: q.add_annotation(expr_src, col_name) self.query.add_select_col(resolved, col_name) resolved.set_source_expressions([RawSQL(f"{order_by_idx}", ())]) - sql, params = self.compile(resolved) + sql, params = self.compile(resolved.replace_expressions(replacements)) # Don't add the same column twice, but the order direction is # not taken into account so we strip it. When this entire method # is refactored into expressions, then we can check each part as we diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 61e39b5153e..e454a6e8689 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -2220,8 +2220,8 @@ class Query(BaseExpression): primary key, and the query would be equivalent, the optimization will be made automatically. """ - # Column names from JOINs to check collisions with aliases. if allow_aliases: + # Column names from JOINs to check collisions with aliases. column_names = set() seen_models = set() for join in list(self.alias_map.values())[1:]: # Skip base table. @@ -2231,7 +2231,20 @@ class Query(BaseExpression): {field.column for field in model._meta.local_concrete_fields} ) seen_models.add(model) - + if self.values_select: + # If grouping by aliases is allowed assign selected values + # aliases by moving them to annotations. + group_by_annotations = {} + values_select = {} + for alias, expr in zip(self.values_select, self.select): + if isinstance(expr, Col): + values_select[alias] = expr + else: + group_by_annotations[alias] = expr + self.annotations = {**group_by_annotations, **self.annotations} + self.append_annotation_mask(group_by_annotations) + self.select = tuple(values_select.values()) + self.values_select = tuple(values_select) group_by = list(self.select) if self.annotation_select: for alias, annotation in self.annotation_select.items(): diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index 6d6902b61af..1fe86aad3ee 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -178,10 +178,19 @@ class AggregationTests(TestCase): ) .annotate(sum_discount=Sum("discount_price")) ) - self.assertSequenceEqual( - values, - [{"discount_price": Decimal("59.38"), "sum_discount": Decimal("59.38")}], - ) + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + values, + [ + { + "discount_price": Decimal("59.38"), + "sum_discount": Decimal("59.38"), + } + ], + ) + if connection.features.allows_group_by_refs: + alias = connection.ops.quote_name("discount_price") + self.assertIn(f"GROUP BY {alias}", ctx[0]["sql"]) def test_aggregates_in_where_clause(self): """ diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 436838dd090..dff9a651c3c 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -12,7 +12,12 @@ from django.core.management import call_command from django.db import IntegrityError, connection, models from django.db.models.expressions import Exists, OuterRef, RawSQL, Value from django.db.models.functions import Cast, JSONObject, Upper -from django.test import TransactionTestCase, modify_settings, override_settings +from django.test import ( + TransactionTestCase, + modify_settings, + override_settings, + skipUnlessDBFeature, +) from django.test.utils import isolate_apps from django.utils import timezone @@ -405,6 +410,27 @@ class TestQuerying(PostgreSQLTestCase): expected, ) + @skipUnlessDBFeature("allows_group_by_refs") + def test_group_by_order_by_aliases(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field__0__isnull=False, + ) + .values("field__0") + .annotate(arrayagg=ArrayAgg("id")) + .order_by("field__0"), + [ + {"field__0": 1, "arrayagg": [1]}, + {"field__0": 2, "arrayagg": [2, 3]}, + {"field__0": 20, "arrayagg": [4]}, + ], + ) + alias = connection.ops.quote_name("field__0") + sql = ctx[0]["sql"] + self.assertIn(f"GROUP BY {alias}", sql) + self.assertIn(f"ORDER BY {alias}", sql) + def test_index(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]