From f92d73fbd48aa22c4f0d6b155e795b65ecc6355c Mon Sep 17 00:00:00 2001 From: Russell Keith-Magee Date: Sat, 3 Apr 2010 11:45:31 +0000 Subject: [PATCH] Fixed #12247 -- Corrected the way update queries are processed when the update only refers to attributes on a base class. Thanks to jsmullyan for the report, and matiasb for the fix. git-svn-id: http://code.djangoproject.com/svn/django/trunk@12910 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/models/sql/subqueries.py | 2 +- tests/modeltests/update/models.py | 13 ++++++++ tests/modeltests/update/tests.py | 49 ++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/modeltests/update/tests.py diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 6b52729f68..a066dfeca8 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -131,7 +131,7 @@ class UpdateQuery(Query): for model, values in self.related_updates.iteritems(): query = UpdateQuery(model) query.values = values - if self.related_ids: + if self.related_ids is not None: query.add_filter(('pk__in', self.related_ids)) result.append(query) return result diff --git a/tests/modeltests/update/models.py b/tests/modeltests/update/models.py index 0ffd029437..9ce672f2b3 100644 --- a/tests/modeltests/update/models.py +++ b/tests/modeltests/update/models.py @@ -21,6 +21,19 @@ class RelatedPoint(models.Model): return unicode(self.name) +class A(models.Model): + x = models.IntegerField(default=10) + +class B(models.Model): + a = models.ForeignKey(A) + y = models.IntegerField(default=10) + +class C(models.Model): + y = models.IntegerField(default=10) + +class D(C): + a = models.ForeignKey(A) + __test__ = {'API_TESTS': """ >>> DataPoint(name="d0", value="apple").save() >>> DataPoint(name="d2", value="banana").save() diff --git a/tests/modeltests/update/tests.py b/tests/modeltests/update/tests.py new file mode 100644 index 0000000000..05397f8306 --- /dev/null +++ b/tests/modeltests/update/tests.py @@ -0,0 +1,49 @@ +from django.test import TestCase + +from models import A, B, D + +class SimpleTest(TestCase): + def setUp(self): + self.a1 = A.objects.create() + self.a2 = A.objects.create() + for x in range(20): + B.objects.create(a=self.a1) + D.objects.create(a=self.a1) + + def test_nonempty_update(self): + """ + Test that update changes the right number of rows for a nonempty queryset + """ + num_updated = self.a1.b_set.update(y=100) + self.failUnlessEqual(num_updated, 20) + cnt = B.objects.filter(y=100).count() + self.failUnlessEqual(cnt, 20) + + def test_empty_update(self): + """ + Test that update changes the right number of rows for an empty queryset + """ + num_updated = self.a2.b_set.update(y=100) + self.failUnlessEqual(num_updated, 0) + cnt = B.objects.filter(y=100).count() + self.failUnlessEqual(cnt, 0) + + def test_nonempty_update_with_inheritance(self): + """ + Test that update changes the right number of rows for an empty queryset + when the update affects only a base table + """ + num_updated = self.a1.d_set.update(y=100) + self.failUnlessEqual(num_updated, 20) + cnt = D.objects.filter(y=100).count() + self.failUnlessEqual(cnt, 20) + + def test_empty_update_with_inheritance(self): + """ + Test that update changes the right number of rows for an empty queryset + when the update affects only a base table + """ + num_updated = self.a2.d_set.update(y=100) + self.failUnlessEqual(num_updated, 0) + cnt = D.objects.filter(y=100).count() + self.failUnlessEqual(cnt, 0)