This commit is contained in:
parent
06acb3445f
commit
a84344bc53
|
@ -632,6 +632,8 @@ class QuerySet(object):
|
||||||
self._for_write = True
|
self._for_write = True
|
||||||
query = self.query.clone(sql.UpdateQuery)
|
query = self.query.clone(sql.UpdateQuery)
|
||||||
query.add_update_values(kwargs)
|
query.add_update_values(kwargs)
|
||||||
|
# Clear any annotations so that they won't be present in subqueries.
|
||||||
|
query._annotations = None
|
||||||
with transaction.atomic(using=self.db, savepoint=False):
|
with transaction.atomic(using=self.db, savepoint=False):
|
||||||
rows = query.get_compiler(self.db).execute_sql(CURSOR)
|
rows = query.get_compiler(self.db).execute_sql(CURSOR)
|
||||||
self._result_cache = None
|
self._result_cache = None
|
||||||
|
|
|
@ -142,7 +142,11 @@ class UpdateQuery(Query):
|
||||||
that will be used to generate the UPDATE query. Might be more usefully
|
that will be used to generate the UPDATE query. Might be more usefully
|
||||||
called add_update_targets() to hint at the extra information here.
|
called add_update_targets() to hint at the extra information here.
|
||||||
"""
|
"""
|
||||||
self.values.extend(values_seq)
|
for field, model, val in values_seq:
|
||||||
|
if hasattr(val, 'resolve_expression'):
|
||||||
|
# Resolve expressions here so that annotations are no longer needed
|
||||||
|
val = val.resolve_expression(self, allow_joins=False, for_save=True)
|
||||||
|
self.values.append((field, model, val))
|
||||||
|
|
||||||
def add_related_update(self, model, field, value):
|
def add_related_update(self, model, field, value):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db.models import F, Max
|
from django.db.models import Count, F, Max
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint
|
from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint
|
||||||
|
@ -158,3 +158,23 @@ class AdvancedTests(TestCase):
|
||||||
qs = DataPoint.objects.annotate(max=Max('value'))
|
qs = DataPoint.objects.annotate(max=Max('value'))
|
||||||
with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
|
with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
|
||||||
qs.update(another_value=F('max'))
|
qs.update(another_value=F('max'))
|
||||||
|
|
||||||
|
def test_update_annotated_multi_table_queryset(self):
|
||||||
|
"""
|
||||||
|
Update of a queryset that's been annotated and involves multiple tables.
|
||||||
|
"""
|
||||||
|
# Trivial annotated update
|
||||||
|
qs = DataPoint.objects.annotate(related_count=Count('relatedpoint'))
|
||||||
|
self.assertEqual(qs.update(value='Foo'), 3)
|
||||||
|
# Update where annotation is used for filtering
|
||||||
|
qs = DataPoint.objects.annotate(related_count=Count('relatedpoint'))
|
||||||
|
self.assertEqual(qs.filter(related_count=1).update(value='Foo'), 1)
|
||||||
|
# Update where annotation is used in update parameters
|
||||||
|
# #26539 - This isn't forbidden but also doesn't generate proper SQL
|
||||||
|
# qs = RelatedPoint.objects.annotate(data_name=F('data__name'))
|
||||||
|
# updated = qs.update(name=F('data_name'))
|
||||||
|
# self.assertEqual(updated, 1)
|
||||||
|
# Update where aggregation annotation is used in update parameters
|
||||||
|
qs = RelatedPoint.objects.annotate(max=Max('data__value'))
|
||||||
|
with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
|
||||||
|
qs.update(name=F('max'))
|
||||||
|
|
Loading…
Reference in New Issue