From 81739a45b5ae8f534910aaabc7e9b457eaa34163 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20Abac=C4=B1?= Date: Tue, 30 Nov 2021 16:50:13 +0300 Subject: [PATCH] Fixed #33319 -- Fixed crash when combining with the | operator querysets with aliases that conflict. --- AUTHORS | 1 + django/db/models/sql/query.py | 25 +++++++++++++++++-------- tests/queries/models.py | 3 ++- tests/queries/tests.py | 22 +++++++++++++++++++++- 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/AUTHORS b/AUTHORS index 738c0a04c22..e13d13d4aea 100644 --- a/AUTHORS +++ b/AUTHORS @@ -731,6 +731,7 @@ answer newbie questions, and generally made Django that much better: Oscar Ramirez Ossama M. Khayat Owen Griffiths + Ömer Faruk Abacı Pablo Martín Panos Laganakos Paolo Melchiorre diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fe8ac873b1f..b13c7b68932 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -572,6 +572,15 @@ class Query(BaseExpression): if self.distinct_fields != rhs.distinct_fields: raise TypeError('Cannot combine queries with different distinct fields.') + # If lhs and rhs shares the same alias prefix, it is possible to have + # conflicting alias changes like T4 -> T5, T5 -> T6, which might end up + # as T4 -> T6 while combining two querysets. To prevent this, change an + # alias prefix of the rhs and update current aliases accordingly, + # except if the alias is the base table since it must be present in the + # query on both sides. + initial_alias = self.get_initial_alias() + rhs.bump_prefix(self, exclude={initial_alias}) + # Work out how to relabel the rhs aliases, if necessary. change_map = {} conjunction = (connector == AND) @@ -589,9 +598,6 @@ class Query(BaseExpression): # the AND case. The results will be correct but this creates too many # joins. This is something that could be fixed later on. reuse = set() if conjunction else set(self.alias_map) - # Base table must be present in the query - this is the same - # table on both sides. - self.get_initial_alias() joinpromoter = JoinPromoter(connector, 2, False) joinpromoter.add_votes( j for j in self.alias_map if self.alias_map[j].join_type == INNER) @@ -882,12 +888,12 @@ class Query(BaseExpression): for alias, aliased in self.external_aliases.items() } - def bump_prefix(self, outer_query): + def bump_prefix(self, other_query, exclude=None): """ Change the alias prefix to the next letter in the alphabet in a way - that the outer query's aliases and this query's aliases will not + that the other query's aliases and this query's aliases will not conflict. Even tables that previously had no alias will get an alias - after this call. + after this call. To prevent changing aliases use the exclude parameter. """ def prefix_gen(): """ @@ -907,7 +913,7 @@ class Query(BaseExpression): yield ''.join(s) prefix = None - if self.alias_prefix != outer_query.alias_prefix: + if self.alias_prefix != other_query.alias_prefix: # No clashes between self and outer query should be possible. return @@ -925,10 +931,13 @@ class Query(BaseExpression): 'Maximum recursion depth exceeded: too many subqueries.' ) self.subq_aliases = self.subq_aliases.union([self.alias_prefix]) - outer_query.subq_aliases = outer_query.subq_aliases.union(self.subq_aliases) + other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases) + if exclude is None: + exclude = {} self.change_aliases({ alias: '%s%d' % (self.alias_prefix, pos) for pos, alias in enumerate(self.alias_map) + if alias not in exclude }) def get_initial_alias(self): diff --git a/tests/queries/models.py b/tests/queries/models.py index c3322c224e7..8e7ee1625ce 100644 --- a/tests/queries/models.py +++ b/tests/queries/models.py @@ -613,13 +613,14 @@ class OrderItem(models.Model): class BaseUser(models.Model): - pass + annotation = models.ForeignKey(Annotation, models.CASCADE, null=True, blank=True) class Task(models.Model): title = models.CharField(max_length=10) owner = models.ForeignKey(BaseUser, models.CASCADE, related_name='owner') creator = models.ForeignKey(BaseUser, models.CASCADE, related_name='creator') + note = models.ForeignKey(Note, on_delete=models.CASCADE, null=True, blank=True) def __str__(self): return self.title diff --git a/tests/queries/tests.py b/tests/queries/tests.py index ca982b9c6a3..f146bc89fd6 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -15,7 +15,7 @@ from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test.utils import CaptureQueriesContext from .models import ( - FK1, Annotation, Article, Author, BaseA, Book, CategoryItem, + FK1, Annotation, Article, Author, BaseA, BaseUser, Book, CategoryItem, CategoryRelationship, Celebrity, Channel, Chapter, Child, ChildObjectA, Classroom, CommonMixedCaseForeignKeys, Company, Cover, CustomPk, CustomPkTag, DateTimePK, Detail, DumbCategory, Eaten, Employment, @@ -2094,6 +2094,15 @@ class QuerySetBitwiseOperationTests(TestCase): cls.room_2 = Classroom.objects.create(school=cls.school, has_blackboard=True, name='Room 2') cls.room_3 = Classroom.objects.create(school=cls.school, has_blackboard=True, name='Room 3') cls.room_4 = Classroom.objects.create(school=cls.school, has_blackboard=False, name='Room 4') + tag = Tag.objects.create() + cls.annotation_1 = Annotation.objects.create(tag=tag) + annotation_2 = Annotation.objects.create(tag=tag) + note = cls.annotation_1.notes.create(tag=tag) + cls.base_user_1 = BaseUser.objects.create(annotation=cls.annotation_1) + cls.base_user_2 = BaseUser.objects.create(annotation=annotation_2) + cls.task = Task.objects.create( + owner=cls.base_user_2, creator=cls.base_user_2, note=note, + ) @skipUnlessDBFeature('allow_sliced_subqueries_with_in') def test_or_with_rhs_slice(self): @@ -2130,6 +2139,17 @@ class QuerySetBitwiseOperationTests(TestCase): nested_combined = School.objects.filter(pk__in=combined.values('pk')) self.assertSequenceEqual(nested_combined, [self.school]) + def test_conflicting_aliases_during_combine(self): + qs1 = self.annotation_1.baseuser_set.all() + qs2 = BaseUser.objects.filter( + Q(owner__note__in=self.annotation_1.notes.all()) | + Q(creator__note__in=self.annotation_1.notes.all()) + ) + self.assertSequenceEqual(qs1, [self.base_user_1]) + self.assertSequenceEqual(qs2, [self.base_user_2]) + self.assertCountEqual(qs2 | qs1, qs1 | qs2) + self.assertCountEqual(qs2 | qs1, [self.base_user_1, self.base_user_2]) + class CloneTests(TestCase):