From 65ad4ade74dc9208b9d686a451cd6045df0c9c3a Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Tue, 28 Mar 2023 00:13:00 -0400 Subject: [PATCH] Refs #28900 -- Made SELECT respect the order specified by values(*selected). Previously the order was always extra_fields + model_fields + annotations with respective local ordering inferred from the insertion order of *selected. This commits introduces a new `Query.selected` propery that keeps tracks of the global select order as specified by on values assignment. This is crucial feature to allow the combination of queries mixing annotations and table references. It also allows the removal of the re-ordering shenanigans perform by ValuesListIterable in order to re-map the tuples returned from the database backend to the order specified by values_list() as they'll be in the right order at query compilation time. Refs #28553 as the initially reported issue that was only partially fixed for annotations by d6b6e5d0fd4e6b6d0183b4cf6e4bd4f9afc7bf67. Thanks Mariusz Felisiak and Sarah Boyce for review. --- django/db/models/query.py | 37 ++++++---------------- django/db/models/sql/compiler.py | 45 +++++++++++++++++--------- django/db/models/sql/query.py | 47 +++++++++++++++++++--------- docs/ref/models/querysets.txt | 10 ++++++ docs/releases/5.2.txt | 8 ++++- docs/spelling_wordlist | 1 + tests/annotations/tests.py | 10 ++++++ tests/postgres_tests/test_array.py | 4 +-- tests/queries/test_qs_combinators.py | 23 +++++++++++++- tests/queries/tests.py | 2 +- 10 files changed, 125 insertions(+), 62 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index cb5c63c0d17..3f9d4768f76 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -200,12 +200,15 @@ class ValuesIterable(BaseIterable): query = queryset.query compiler = query.get_compiler(queryset.db) - # extra(select=...) cols are always at the start of the row. - names = [ - *query.extra_select, - *query.values_select, - *query.annotation_select, - ] + if query.selected: + names = list(query.selected) + else: + # extra(select=...) cols are always at the start of the row. + names = [ + *query.extra_select, + *query.values_select, + *query.annotation_select, + ] indexes = range(len(names)) for row in compiler.results_iter( chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size @@ -223,28 +226,6 @@ class ValuesListIterable(BaseIterable): queryset = self.queryset query = queryset.query compiler = query.get_compiler(queryset.db) - - if queryset._fields: - # extra(select=...) cols are always at the start of the row. - names = [ - *query.extra_select, - *query.values_select, - *query.annotation_select, - ] - fields = [ - *queryset._fields, - *(f for f in query.annotation_select if f not in queryset._fields), - ] - if fields != names: - # Reorder according to fields. - index_map = {name: idx for idx, name in enumerate(names)} - rowfactory = operator.itemgetter(*[index_map[f] for f in fields]) - return map( - rowfactory, - compiler.results_iter( - chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size - ), - ) return compiler.results_iter( tuple_expected=True, chunked_fetch=self.chunked_fetch, diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 7377e555c3a..d606505cdf5 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -247,11 +247,6 @@ class SQLCompiler: select = [] klass_info = None annotations = {} - select_idx = 0 - for alias, (sql, params) in self.query.extra_select.items(): - annotations[alias] = select_idx - select.append((RawSQL(sql, params), alias)) - select_idx += 1 assert not (self.query.select and self.query.default_cols) select_mask = self.query.get_select_mask() if self.query.default_cols: @@ -261,19 +256,39 @@ class SQLCompiler: # any model. cols = self.query.select if cols: - select_list = [] - for col in cols: - select_list.append(select_idx) - select.append((col, None)) - select_idx += 1 klass_info = { "model": self.query.model, - "select_fields": select_list, + "select_fields": list( + range( + len(self.query.extra_select), + len(self.query.extra_select) + len(cols), + ) + ), } - for alias, annotation in self.query.annotation_select.items(): - annotations[alias] = select_idx - select.append((annotation, alias)) - select_idx += 1 + selected = [] + if self.query.selected is None: + selected = [ + *( + (alias, RawSQL(*args)) + for alias, args in self.query.extra_select.items() + ), + *((None, col) for col in cols), + *self.query.annotation_select.items(), + ] + else: + for alias, expression in self.query.selected.items(): + # Reference to an annotation. + if isinstance(expression, str): + expression = self.query.annotations[expression] + # Reference to a column. + elif isinstance(expression, int): + expression = cols[expression] + selected.append((alias, expression)) + + for select_idx, (alias, expression) in enumerate(selected): + if alias: + annotations[alias] = select_idx + select.append((expression, alias)) if self.query.select_related: related_klass_infos = self.get_related_selections(select, select_mask) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 07c3fdbd344..ce97ebe1d1d 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -26,6 +26,7 @@ from django.db.models.expressions import ( Exists, F, OuterRef, + RawSQL, Ref, ResolvedOuterRef, Value, @@ -265,6 +266,7 @@ class Query(BaseExpression): # Holds the selects defined by a call to values() or values_list() # excluding annotation_select and extra_select. values_select = () + selected = None # SQL annotation-related attributes. annotation_select_mask = None @@ -584,6 +586,7 @@ class Query(BaseExpression): else: outer_query = self self.select = () + self.selected = None self.default_cols = False self.extra = {} if self.annotations: @@ -1194,13 +1197,10 @@ class Query(BaseExpression): if select: self.append_annotation_mask([alias]) else: - annotation_mask = ( - value - for value in dict.fromkeys(self.annotation_select) - if value != alias - ) - self.set_annotation_mask(annotation_mask) + self.set_annotation_mask(set(self.annotation_select).difference({alias})) self.annotations[alias] = annotation + if self.selected: + self.selected[alias] = alias def resolve_expression(self, query, *args, **kwargs): clone = self.clone() @@ -2153,6 +2153,7 @@ class Query(BaseExpression): self.select_related = False self.set_extra_mask(()) self.set_annotation_mask(()) + self.selected = None def clear_select_fields(self): """ @@ -2162,10 +2163,12 @@ class Query(BaseExpression): """ self.select = () self.values_select = () + self.selected = None def add_select_col(self, col, name): self.select += (col,) self.values_select += (name,) + self.selected[name] = len(self.select) - 1 def set_select(self, cols): self.default_cols = False @@ -2416,12 +2419,23 @@ class Query(BaseExpression): if names is None: self.annotation_select_mask = None else: - self.annotation_select_mask = list(dict.fromkeys(names)) + self.annotation_select_mask = set(names) + if self.selected: + # Prune the masked annotations. + self.selected = { + key: value + for key, value in self.selected.items() + if not isinstance(value, str) + or value in self.annotation_select_mask + } + # Append the unmasked annotations. + for name in names: + self.selected[name] = name self._annotation_select_cache = None def append_annotation_mask(self, names): if self.annotation_select_mask is not None: - self.set_annotation_mask((*self.annotation_select_mask, *names)) + self.set_annotation_mask(self.annotation_select_mask.union(names)) def set_extra_mask(self, names): """ @@ -2440,6 +2454,7 @@ class Query(BaseExpression): self.clear_select_fields() self.has_select_fields = True + selected = {} if fields: field_names = [] extra_names = [] @@ -2448,13 +2463,16 @@ class Query(BaseExpression): # Shortcut - if there are no extra or annotations, then # the values() clause must be just field names. field_names = list(fields) + selected = dict(zip(fields, range(len(fields)))) else: self.default_cols = False for f in fields: - if f in self.extra_select: + if extra := self.extra_select.get(f): extra_names.append(f) + selected[f] = RawSQL(*extra) elif f in self.annotation_select: annotation_names.append(f) + selected[f] = f elif f in self.annotations: raise FieldError( f"Cannot select the '{f}' alias. Use annotate() to " @@ -2466,13 +2484,13 @@ class Query(BaseExpression): # `f` is not resolvable. if self.annotation_select: self.names_to_path(f.split(LOOKUP_SEP), self.model._meta) + selected[f] = len(field_names) field_names.append(f) self.set_extra_mask(extra_names) self.set_annotation_mask(annotation_names) - selected = frozenset(field_names + extra_names + annotation_names) else: field_names = [f.attname for f in self.model._meta.concrete_fields] - selected = frozenset(field_names) + selected = dict.fromkeys(field_names, None) # Selected annotations must be known before setting the GROUP BY # clause. if self.group_by is True: @@ -2495,6 +2513,7 @@ class Query(BaseExpression): self.values_select = tuple(field_names) self.add_fields(field_names, True) + self.selected = selected if fields else None @property def annotation_select(self): @@ -2508,9 +2527,9 @@ class Query(BaseExpression): return {} elif self.annotation_select_mask is not None: self._annotation_select_cache = { - k: self.annotations[k] - for k in self.annotation_select_mask - if k in self.annotations + k: v + for k, v in self.annotations.items() + if k in self.annotation_select_mask } return self._annotation_select_cache else: diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 7a0d086bfe4..d708e05a79c 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -745,6 +745,11 @@ You can also refer to fields on related models with reverse relations through ``"true"``, ``"false"``, and ``"null"`` strings for :class:`~django.db.models.JSONField` key transforms. +.. versionchanged:: 5.2 + + The ``SELECT`` clause generated when using ``values()`` was updated to + respect the order of the specified ``*fields`` and ``**expressions``. + ``values_list()`` ~~~~~~~~~~~~~~~~~ @@ -835,6 +840,11 @@ not having any author: ``"true"``, ``"false"``, and ``"null"`` strings for :class:`~django.db.models.JSONField` key transforms. +.. versionchanged:: 5.2 + + The ``SELECT`` clause generated when using ``values_list()`` was updated to + respect the order of the specified ``*fields``. + ``dates()`` ~~~~~~~~~~~ diff --git a/docs/releases/5.2.txt b/docs/releases/5.2.txt index bdc53493684..5d5887fe34a 100644 --- a/docs/releases/5.2.txt +++ b/docs/releases/5.2.txt @@ -195,7 +195,13 @@ Migrations Models ~~~~~~ -* ... +* The ``SELECT`` clause generated when using + :meth:`QuerySet.values()` and + :meth:`~django.db.models.query.QuerySet.values_list` now matches the + specified order of the referenced expressions. Previously the order was based + of a set of counterintuitive rules which made query combination through + methods such as + :meth:`QuerySet.union()` unpredictable. Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/spelling_wordlist b/docs/spelling_wordlist index 1044cd80ebd..d715e62e054 100644 --- a/docs/spelling_wordlist +++ b/docs/spelling_wordlist @@ -96,6 +96,7 @@ contenttypes contrib coroutine coroutines +counterintuitive criticals cron crontab diff --git a/tests/annotations/tests.py b/tests/annotations/tests.py index f1260b41926..703847e1dd5 100644 --- a/tests/annotations/tests.py +++ b/tests/annotations/tests.py @@ -568,6 +568,16 @@ class NonAggregateAnnotationTestCase(TestCase): self.assertEqual(book["other_rating"], 4) self.assertEqual(book["other_isbn"], "155860191") + def test_values_fields_annotations_order(self): + qs = Book.objects.annotate(other_rating=F("rating") - 1).values( + "other_rating", "rating" + ) + book = qs.get(pk=self.b1.pk) + self.assertEqual( + list(book.items()), + [("other_rating", self.b1.rating - 1), ("rating", self.b1.rating)], + ) + def test_values_with_pk_annotation(self): # annotate references a field in values() with pk publishers = Publisher.objects.values("id", "book__rating").annotate( diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 386a0afa3a9..ff0c4aabb1c 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -466,8 +466,8 @@ class TestQuerying(PostgreSQLTestCase): ], ) sql = ctx[0]["sql"] - self.assertIn("GROUP BY 2", sql) - self.assertIn("ORDER BY 2", sql) + self.assertIn("GROUP BY 1", sql) + self.assertIn("ORDER BY 1", sql) def test_order_by_arrayagg_index(self): qs = ( diff --git a/tests/queries/test_qs_combinators.py b/tests/queries/test_qs_combinators.py index 4c2dbc5b176..e1306130843 100644 --- a/tests/queries/test_qs_combinators.py +++ b/tests/queries/test_qs_combinators.py @@ -257,6 +257,23 @@ class QuerySetSetOperationTests(TestCase): ) self.assertCountEqual(qs1.union(qs2), [(1, 0), (1, 2)]) + def test_union_with_field_and_annotation_values(self): + qs1 = ( + Number.objects.filter(num=1) + .annotate( + zero=Value(0, IntegerField()), + ) + .values_list("num", "zero") + ) + qs2 = ( + Number.objects.filter(num=2) + .annotate( + zero=Value(0, IntegerField()), + ) + .values_list("zero", "num") + ) + self.assertCountEqual(qs1.union(qs2), [(1, 0), (0, 2)]) + def test_union_with_extra_and_values_list(self): qs1 = ( Number.objects.filter(num=1) @@ -265,7 +282,11 @@ class QuerySetSetOperationTests(TestCase): ) .values_list("num", "count") ) - qs2 = Number.objects.filter(num=2).extra(select={"count": 1}) + qs2 = ( + Number.objects.filter(num=2) + .extra(select={"count": 1}) + .values_list("num", "count") + ) self.assertCountEqual(qs1.union(qs2), [(1, 0), (2, 1)]) def test_union_with_values_list_on_annotated_and_unannotated(self): diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 7ac8a65d420..3621b6cb2cd 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -2200,7 +2200,7 @@ class Queries6Tests(TestCase): {"tag_per_parent__max": 2}, ) sql = captured_queries[0]["sql"] - self.assertIn("AS %s" % connection.ops.quote_name("col1"), sql) + self.assertIn("AS %s" % connection.ops.quote_name("parent"), sql) def test_xor_subquery(self): self.assertSequenceEqual(