diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index 9f66a01e74..1a7625bfa9 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -113,6 +113,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): "related fields.": { "update.tests.AdvancedTests." "test_update_ordered_by_inline_m2m_annotation", + "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation", }, } if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode: diff --git a/django/db/models/query.py b/django/db/models/query.py index 308073d4de..a169d0c235 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1169,6 +1169,20 @@ class QuerySet: self._for_write = True query = self.query.chain(sql.UpdateQuery) query.add_update_values(kwargs) + + # Inline annotations in order_by(), if possible. + new_order_by = [] + for col in query.order_by: + if annotation := query.annotations.get(col): + if getattr(annotation, "contains_aggregate", False): + raise exceptions.FieldError( + f"Cannot update when ordering by an aggregate: {annotation}" + ) + new_order_by.append(annotation) + else: + new_order_by.append(col) + query.order_by = tuple(new_order_by) + # Clear any annotations so that they won't be present in subqueries. query.annotations = {} with transaction.mark_for_rollback_on_error(using=self.db): diff --git a/tests/update/tests.py b/tests/update/tests.py index 0c3a399514..2162f5164d 100644 --- a/tests/update/tests.py +++ b/tests/update/tests.py @@ -225,6 +225,16 @@ class AdvancedTests(TestCase): new_name=annotation, ).update(name=F("new_name")) + def test_update_ordered_by_m2m_aggregation_annotation(self): + msg = ( + "Cannot update when ordering by an aggregate: " + "Count(Col(update_bar_m2m_foo, update.Bar_m2m_foo.foo))" + ) + with self.assertRaisesMessage(FieldError, msg): + Bar.objects.annotate(m2m_count=Count("m2m_foo")).order_by( + "m2m_count" + ).update(x=2) + def test_update_ordered_by_inline_m2m_annotation(self): foo = Foo.objects.create(target="test") Bar.objects.create(foo=foo) @@ -232,6 +242,13 @@ class AdvancedTests(TestCase): Bar.objects.order_by(Abs("m2m_foo")).update(x=2) self.assertEqual(Bar.objects.get().x, 2) + def test_update_ordered_by_m2m_annotation(self): + foo = Foo.objects.create(target="test") + Bar.objects.create(foo=foo) + + Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3) + self.assertEqual(Bar.objects.get().x, 3) + @unittest.skipUnless( connection.vendor == "mysql", @@ -259,14 +276,12 @@ class MySQLUpdateOrderByTest(TestCase): self.assertEqual(updated, 2) def test_order_by_update_on_unique_constraint_annotation(self): - # Ordering by annotations is omitted because they cannot be resolved in - # .update(). - with self.assertRaises(IntegrityError): - UniqueNumber.objects.annotate(number_inverse=F("number").desc(),).order_by( - "number_inverse" - ).update( - number=F("number") + 1, - ) + updated = ( + UniqueNumber.objects.annotate(number_inverse=F("number").desc()) + .order_by("number_inverse") + .update(number=F("number") + 1) + ) + self.assertEqual(updated, 2) def test_order_by_update_on_parent_unique_constraint(self): # Ordering by inherited fields is omitted because joined fields cannot