Fixed #19513, #18580 -- Fixed crash on QuerySet.update() after annotate().

This commit is contained in:
David Sanders 2016-04-17 10:03:08 -07:00 committed by Tim Graham
parent 06acb3445f
commit a84344bc53
3 changed files with 28 additions and 2 deletions

View File

@ -632,6 +632,8 @@ class QuerySet(object):
self._for_write = True self._for_write = True
query = self.query.clone(sql.UpdateQuery) query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs) query.add_update_values(kwargs)
# Clear any annotations so that they won't be present in subqueries.
query._annotations = None
with transaction.atomic(using=self.db, savepoint=False): with transaction.atomic(using=self.db, savepoint=False):
rows = query.get_compiler(self.db).execute_sql(CURSOR) rows = query.get_compiler(self.db).execute_sql(CURSOR)
self._result_cache = None self._result_cache = None

View File

@ -142,7 +142,11 @@ class UpdateQuery(Query):
that will be used to generate the UPDATE query. Might be more usefully that will be used to generate the UPDATE query. Might be more usefully
called add_update_targets() to hint at the extra information here. called add_update_targets() to hint at the extra information here.
""" """
self.values.extend(values_seq) for field, model, val in values_seq:
if hasattr(val, 'resolve_expression'):
# Resolve expressions here so that annotations are no longer needed
val = val.resolve_expression(self, allow_joins=False, for_save=True)
self.values.append((field, model, val))
def add_related_update(self, model, field, value): def add_related_update(self, model, field, value):
""" """

View File

@ -1,7 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models import F, Max from django.db.models import Count, F, Max
from django.test import TestCase from django.test import TestCase
from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint
@ -158,3 +158,23 @@ class AdvancedTests(TestCase):
qs = DataPoint.objects.annotate(max=Max('value')) qs = DataPoint.objects.annotate(max=Max('value'))
with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'): with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
qs.update(another_value=F('max')) qs.update(another_value=F('max'))
def test_update_annotated_multi_table_queryset(self):
"""
Update of a queryset that's been annotated and involves multiple tables.
"""
# Trivial annotated update
qs = DataPoint.objects.annotate(related_count=Count('relatedpoint'))
self.assertEqual(qs.update(value='Foo'), 3)
# Update where annotation is used for filtering
qs = DataPoint.objects.annotate(related_count=Count('relatedpoint'))
self.assertEqual(qs.filter(related_count=1).update(value='Foo'), 1)
# Update where annotation is used in update parameters
# #26539 - This isn't forbidden but also doesn't generate proper SQL
# qs = RelatedPoint.objects.annotate(data_name=F('data__name'))
# updated = qs.update(name=F('data_name'))
# self.assertEqual(updated, 1)
# Update where aggregation annotation is used in update parameters
qs = RelatedPoint.objects.annotate(max=Max('data__value'))
with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
qs.update(name=F('max'))