From 2e56066a5b93b528fbce37285bac591b44bc6ed7 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Tue, 23 Aug 2011 03:38:42 +0000 Subject: [PATCH] Fixed an isnull=False filtering edge-case. Fixes #15316. The bulk of this patch is due to some fine analysis from Aleksandra Sendecka. git-svn-id: http://code.djangoproject.com/svn/django/trunk@16656 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- AUTHORS | 1 + django/db/models/sql/query.py | 19 +++- tests/regressiontests/queries/models.py | 26 +++++ tests/regressiontests/queries/tests.py | 130 +++++++++++++++++++++++- 4 files changed, 168 insertions(+), 8 deletions(-) diff --git a/AUTHORS b/AUTHORS index 5635f2141d..38c7601579 100644 --- a/AUTHORS +++ b/AUTHORS @@ -448,6 +448,7 @@ answer newbie questions, and generally made Django that much better: schwank@gmail.com scott@staplefish.com Ilya Semenov + Aleksandra Sendecka serbaut@gmail.com John Shaffer Pete Shinners diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fd2e7a6d9a..e5e11f472a 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1105,7 +1105,10 @@ class Query(object): # Process the join list to see if we can remove any inner joins from # the far end (fewer tables in a query is better). - col, alias, join_list = self.trim_joins(target, join_list, last, trim) + nonnull_comparison = (lookup_type == 'isnull' and value is False) + col, alias, join_list = self.trim_joins(target, join_list, last, trim, + nonnull_comparison) + if connector == OR: # Some joins may need to be promoted when adding a new filter to a # disjunction. We walk the list of new joins and where it diverges @@ -1442,7 +1445,7 @@ class Query(object): return field, target, opts, joins, last, extra_filters - def trim_joins(self, target, join_list, last, trim): + def trim_joins(self, target, join_list, last, trim, nonnull_check=False): """ Sometimes joins at the end of a multi-table sequence can be trimmed. If the final join is against the same column as we are comparing against, @@ -1463,6 +1466,11 @@ class Query(object): trimmed before anything. See the documentation of add_filter() for details about this. + The 'nonnull_check' parameter is True when we are using inner joins + between tables explicitly to exclude NULL entries. In that case, the + tables shouldn't be trimmed, because the very action of joining to them + alters the result set. + Returns the final active column and table alias and the new active join_list. """ @@ -1470,7 +1478,7 @@ class Query(object): penultimate = last.pop() if penultimate == final: penultimate = last.pop() - if trim and len(join_list) > 1: + if trim and final > 1: extra = join_list[penultimate:] join_list = join_list[:penultimate] final = penultimate @@ -1483,12 +1491,13 @@ class Query(object): alias = join_list[-1] while final > 1: join = self.alias_map[alias] - if col != join[RHS_JOIN_COL] or join[JOIN_TYPE] != self.INNER: + if (col != join[RHS_JOIN_COL] or join[JOIN_TYPE] != self.INNER or + nonnull_check): break self.unref_alias(alias) alias = join[LHS_ALIAS] col = join[LHS_JOIN_COL] - join_list = join_list[:-1] + join_list.pop() final -= 1 if final == penultimate: penultimate = last.pop() diff --git a/tests/regressiontests/queries/models.py b/tests/regressiontests/queries/models.py index d1e5e6ea39..bb8099f60d 100644 --- a/tests/regressiontests/queries/models.py +++ b/tests/regressiontests/queries/models.py @@ -317,3 +317,29 @@ class ObjectC(models.Model): def __unicode__(self): return self.name + +class SimpleCategory(models.Model): + name = models.CharField(max_length=10) + + def __unicode__(self): + return self.name + +class SpecialCategory(SimpleCategory): + special_name = models.CharField(max_length=10) + + def __unicode__(self): + return self.name + " " + self.special_name + +class CategoryItem(models.Model): + category = models.ForeignKey(SimpleCategory) + + def __unicode__(self): + return "category item: " + str(self.category) + +class OneToOneCategory(models.Model): + new_name = models.CharField(max_length=10) + category = models.OneToOneField(SimpleCategory) + + def __unicode__(self): + return "one2one " + self.new_name + \ No newline at end of file diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py index 4181ea5170..a66752ab82 100644 --- a/tests/regressiontests/queries/tests.py +++ b/tests/regressiontests/queries/tests.py @@ -15,7 +15,7 @@ from models import (Annotation, Article, Author, Celebrity, Child, Cover, Detail DumbCategory, ExtraInfo, Fan, Item, LeafA, LoopX, LoopZ, ManagedModel, Member, NamedCategory, Note, Number, Plaything, PointerA, Ranking, Related, Report, ReservedName, Tag, TvChef, Valid, X, Food, Eaten, Node, ObjectA, ObjectB, - ObjectC) + ObjectC, CategoryItem, SimpleCategory, SpecialCategory, OneToOneCategory) class BaseQuerysetTest(TestCase): @@ -1043,11 +1043,135 @@ class Queries4Tests(BaseQuerysetTest): [] ) + def test_ticket15316_filter_false(self): + c1 = SimpleCategory.objects.create(name="category1") + c2 = SpecialCategory.objects.create(name="named category1", + special_name="special1") + c3 = SpecialCategory.objects.create(name="named category2", + special_name="special2") + + ci1 = CategoryItem.objects.create(category=c1) + ci2 = CategoryItem.objects.create(category=c2) + ci3 = CategoryItem.objects.create(category=c3) + + qs = CategoryItem.objects.filter(category__specialcategory__isnull=False) + self.assertEqual(qs.count(), 2) + self.assertQuerysetEqual(qs, [ci2.pk, ci3.pk], lambda x: x.pk, False) + + def test_ticket15316_exclude_false(self): + c1 = SimpleCategory.objects.create(name="category1") + c2 = SpecialCategory.objects.create(name="named category1", + special_name="special1") + c3 = SpecialCategory.objects.create(name="named category2", + special_name="special2") + + ci1 = CategoryItem.objects.create(category=c1) + ci2 = CategoryItem.objects.create(category=c2) + ci3 = CategoryItem.objects.create(category=c3) + + qs = CategoryItem.objects.exclude(category__specialcategory__isnull=False) + self.assertEqual(qs.count(), 1) + self.assertQuerysetEqual(qs, [ci1.pk], lambda x: x.pk) + + def test_ticket15316_filter_true(self): + c1 = SimpleCategory.objects.create(name="category1") + c2 = SpecialCategory.objects.create(name="named category1", + special_name="special1") + c3 = SpecialCategory.objects.create(name="named category2", + special_name="special2") + + ci1 = CategoryItem.objects.create(category=c1) + ci2 = CategoryItem.objects.create(category=c2) + ci3 = CategoryItem.objects.create(category=c3) + + qs = CategoryItem.objects.filter(category__specialcategory__isnull=True) + self.assertEqual(qs.count(), 1) + self.assertQuerysetEqual(qs, [ci1.pk], lambda x: x.pk) + + def test_ticket15316_exclude_true(self): + c1 = SimpleCategory.objects.create(name="category1") + c2 = SpecialCategory.objects.create(name="named category1", + special_name="special1") + c3 = SpecialCategory.objects.create(name="named category2", + special_name="special2") + + ci1 = CategoryItem.objects.create(category=c1) + ci2 = CategoryItem.objects.create(category=c2) + ci3 = CategoryItem.objects.create(category=c3) + + qs = CategoryItem.objects.exclude(category__specialcategory__isnull=True) + self.assertEqual(qs.count(), 2) + self.assertQuerysetEqual(qs, [ci2.pk, ci3.pk], lambda x: x.pk, False) + + def test_ticket15316_one2one_filter_false(self): + c = SimpleCategory.objects.create(name="cat") + c0 = SimpleCategory.objects.create(name="cat0") + c1 = SimpleCategory.objects.create(name="category1") + + c2 = OneToOneCategory.objects.create(category = c1, new_name="new1") + c3 = OneToOneCategory.objects.create(category = c0, new_name="new2") + + ci1 = CategoryItem.objects.create(category=c) + ci2 = CategoryItem.objects.create(category=c0) + ci3 = CategoryItem.objects.create(category=c1) + + qs = CategoryItem.objects.filter(category__onetoonecategory__isnull=False) + self.assertEqual(qs.count(), 2) + self.assertQuerysetEqual(qs, [ci2.pk, ci3.pk], lambda x: x.pk, False) + + def test_ticket15316_one2one_exclude_false(self): + c = SimpleCategory.objects.create(name="cat") + c0 = SimpleCategory.objects.create(name="cat0") + c1 = SimpleCategory.objects.create(name="category1") + + c2 = OneToOneCategory.objects.create(category = c1, new_name="new1") + c3 = OneToOneCategory.objects.create(category = c0, new_name="new2") + + ci1 = CategoryItem.objects.create(category=c) + ci2 = CategoryItem.objects.create(category=c0) + ci3 = CategoryItem.objects.create(category=c1) + + qs = CategoryItem.objects.exclude(category__onetoonecategory__isnull=False) + self.assertEqual(qs.count(), 1) + self.assertQuerysetEqual(qs, [ci1.pk], lambda x: x.pk) + + def test_ticket15316_one2one_filter_true(self): + c = SimpleCategory.objects.create(name="cat") + c0 = SimpleCategory.objects.create(name="cat0") + c1 = SimpleCategory.objects.create(name="category1") + + c2 = OneToOneCategory.objects.create(category = c1, new_name="new1") + c3 = OneToOneCategory.objects.create(category = c0, new_name="new2") + + ci1 = CategoryItem.objects.create(category=c) + ci2 = CategoryItem.objects.create(category=c0) + ci3 = CategoryItem.objects.create(category=c1) + + qs = CategoryItem.objects.filter(category__onetoonecategory__isnull=True) + self.assertEqual(qs.count(), 1) + self.assertQuerysetEqual(qs, [ci1.pk], lambda x: x.pk) + + def test_ticket15316_one2one_exclude_true(self): + c = SimpleCategory.objects.create(name="cat") + c0 = SimpleCategory.objects.create(name="cat0") + c1 = SimpleCategory.objects.create(name="category1") + + c2 = OneToOneCategory.objects.create(category = c1, new_name="new1") + c3 = OneToOneCategory.objects.create(category = c0, new_name="new2") + + ci1 = CategoryItem.objects.create(category=c) + ci2 = CategoryItem.objects.create(category=c0) + ci3 = CategoryItem.objects.create(category=c1) + + qs = CategoryItem.objects.exclude(category__onetoonecategory__isnull=True) + self.assertEqual(qs.count(), 2) + self.assertQuerysetEqual(qs, [ci2.pk, ci3.pk], lambda x: x.pk, False) + class Queries5Tests(TestCase): def setUp(self): - # Ordering by 'rank' gives us rank2, rank1, rank3. Ordering by the Meta.ordering - # will be rank3, rank2, rank1. + # Ordering by 'rank' gives us rank2, rank1, rank3. Ordering by the + # Meta.ordering will be rank3, rank2, rank1. n1 = Note.objects.create(note='n1', misc='foo', id=1) n2 = Note.objects.create(note='n2', misc='bar', id=2) e1 = ExtraInfo.objects.create(info='e1', note=n1)