diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fc08442193..e15e64cde4 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1613,10 +1613,26 @@ class Query(BaseExpression): self.unref_alias(joins.pop()) return targets, joins[-1], joins + @classmethod + def _gen_col_aliases(cls, exprs): + for expr in exprs: + if isinstance(expr, Col): + yield expr.alias + else: + yield from cls._gen_col_aliases(expr.get_source_expressions()) + def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False, simple_col=False): if not allow_joins and LOOKUP_SEP in name: raise FieldError("Joined field references are not permitted in this query") - if name in self.annotations: + annotation = self.annotations.get(name) + if annotation is not None: + if not allow_joins: + for alias in self._gen_col_aliases([annotation]): + if isinstance(self.alias_map[alias], Join): + raise FieldError( + 'Joined field references are not permitted in ' + 'this query' + ) if summarize: # Summarize currently means we are doing an aggregate() query # which is executed as a wrapped subquery if any of the @@ -1624,7 +1640,7 @@ class Query(BaseExpression): # that case we need to return a Ref to the subquery's annotation. return Ref(name, self.annotation_select[name]) else: - return self.annotations[name] + return annotation else: field_list = name.split(LOOKUP_SEP) join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse) diff --git a/tests/update/tests.py b/tests/update/tests.py index 8cb4d11f75..abf4db11d9 100644 --- a/tests/update/tests.py +++ b/tests/update/tests.py @@ -1,5 +1,6 @@ from django.core.exceptions import FieldError from django.db.models import Count, F, Max +from django.db.models.functions import Concat, Lower from django.test import TestCase from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint @@ -182,16 +183,19 @@ class AdvancedTests(TestCase): # Update where annotation is used for filtering qs = DataPoint.objects.annotate(related_count=Count('relatedpoint')) self.assertEqual(qs.filter(related_count=1).update(value='Foo'), 1) - # Update where annotation is used in update parameters - # #26539 - This isn't forbidden but also doesn't generate proper SQL - # qs = RelatedPoint.objects.annotate(data_name=F('data__name')) - # updated = qs.update(name=F('data_name')) - # self.assertEqual(updated, 1) # Update where aggregation annotation is used in update parameters qs = RelatedPoint.objects.annotate(max=Max('data__value')) - msg = ( - 'Aggregate functions are not allowed in this query ' - '(name=Max(Col(update_datapoint, update.DataPoint.value))).' - ) + msg = 'Joined field references are not permitted in this query' with self.assertRaisesMessage(FieldError, msg): qs.update(name=F('max')) + + def test_update_with_joined_field_annotation(self): + msg = 'Joined field references are not permitted in this query' + for annotation in ( + F('data__name'), + Lower('data__name'), + Concat('data__name', 'data__value'), + ): + with self.subTest(annotation=annotation): + with self.assertRaisesMessage(FieldError, msg): + RelatedPoint.objects.annotate(new_name=annotation).update(name=F('new_name'))