Fixed #8106 -- Untangled some problems with complex select_related() queries

and models that have multiple paths to them from other models.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@8559 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2008-08-25 22:43:25 +00:00
parent 6abe0460c8
commit 3deff41a32
2 changed files with 57 additions and 9 deletions

View File

@ -913,7 +913,7 @@ class Query(object):
def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used=None, requested=None, restricted=None, nullable=None, used=None, requested=None, restricted=None, nullable=None,
dupe_set=None): dupe_set=None, avoid_set=None):
""" """
Fill in the information needed for a select_related query. The current Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model depth is measured as the number of connections away from the root model
@ -933,8 +933,9 @@ class Query(object):
used = set() used = set()
if dupe_set is None: if dupe_set is None:
dupe_set = set() dupe_set = set()
if avoid_set is None:
avoid_set = set()
orig_dupe_set = dupe_set orig_dupe_set = dupe_set
orig_used = used
# Setup for the case when only particular related fields should be # Setup for the case when only particular related fields should be
# included in the related selection. # included in the related selection.
@ -948,8 +949,12 @@ class Query(object):
for f, model in opts.get_fields_with_model(): for f, model in opts.get_fields_with_model():
if not select_related_descend(f, restricted, requested): if not select_related_descend(f, restricted, requested):
continue continue
# The "avoid" set is aliases we want to avoid just for this
# particular branch of the recursion. They aren't permanently
# forbidden from reuse in the related selection tables (which is
# what "used" specifies).
avoid = avoid_set.copy()
dupe_set = orig_dupe_set.copy() dupe_set = orig_dupe_set.copy()
used = orig_used.copy()
table = f.rel.to._meta.db_table table = f.rel.to._meta.db_table
if nullable or f.null: if nullable or f.null:
promote = True promote = True
@ -962,7 +967,7 @@ class Query(object):
lhs_col = int_opts.parents[int_model].column lhs_col = int_opts.parents[int_model].column
dedupe = lhs_col in opts.duplicate_targets dedupe = lhs_col in opts.duplicate_targets
if dedupe: if dedupe:
used.update(self.dupe_avoidance.get(id(opts), lhs_col), avoid.update(self.dupe_avoidance.get(id(opts), lhs_col),
()) ())
dupe_set.add((opts, lhs_col)) dupe_set.add((opts, lhs_col))
int_opts = int_model._meta int_opts = int_model._meta
@ -976,13 +981,13 @@ class Query(object):
dedupe = f.column in opts.duplicate_targets dedupe = f.column in opts.duplicate_targets
if dupe_set or dedupe: if dupe_set or dedupe:
used.update(self.dupe_avoidance.get((id(opts), f.column), ())) avoid.update(self.dupe_avoidance.get((id(opts), f.column), ()))
if dedupe: if dedupe:
dupe_set.add((opts, f.column)) dupe_set.add((opts, f.column))
alias = self.join((alias, table, f.column, alias = self.join((alias, table, f.column,
f.rel.get_related_field().column), exclusions=used, f.rel.get_related_field().column),
promote=promote) exclusions=used.union(avoid), promote=promote)
used.add(alias) used.add(alias)
self.related_select_cols.extend(self.get_default_columns( self.related_select_cols.extend(self.get_default_columns(
start_alias=alias, opts=f.rel.to._meta, as_pairs=True)[0]) start_alias=alias, opts=f.rel.to._meta, as_pairs=True)[0])
@ -998,7 +1003,7 @@ class Query(object):
for dupe_opts, dupe_col in dupe_set: for dupe_opts, dupe_col in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias) self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable, dupe_set) used, next, restricted, new_nullable, dupe_set, avoid)
def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
can_reuse=None): can_reuse=None):

View File

@ -28,11 +28,35 @@ class Connection(models.Model):
def __unicode__(self): def __unicode__(self):
return u"%s to %s" % (self.start, self.end) return u"%s to %s" % (self.start, self.end)
# Another non-tree hierarchy that exercises code paths similar to the above
# example, but in a slightly different configuration.
class TUser(models.Model):
name = models.CharField(max_length=200)
class Person(models.Model):
user = models.ForeignKey(TUser, unique=True)
class Organizer(models.Model):
person = models.ForeignKey(Person)
class Student(models.Model):
person = models.ForeignKey(Person)
class Class(models.Model):
org = models.ForeignKey(Organizer)
class Enrollment(models.Model):
std = models.ForeignKey(Student)
cls = models.ForeignKey(Class)
__test__ = {'API_TESTS': """ __test__ = {'API_TESTS': """
Regression test for bug #7110. When using select_related(), we must query the Regression test for bug #7110. When using select_related(), we must query the
Device and Building tables using two different aliases (each) in order to Device and Building tables using two different aliases (each) in order to
differentiate the start and end Connection fields. The net result is that both differentiate the start and end Connection fields. The net result is that both
the "connections = ..." queries here should give the same results. the "connections = ..." queries here should give the same results without
pulling in more than the absolute minimum number of tables (history has
shown that it's easy to make a mistake in the implementation and include some
unnecessary bonus joins).
>>> b=Building.objects.create(name='101') >>> b=Building.objects.create(name='101')
>>> dev1=Device.objects.create(name="router", building=b) >>> dev1=Device.objects.create(name="router", building=b)
@ -57,4 +81,23 @@ the "connections = ..." queries here should give the same results.
>>> connections.query.count_active_tables() >>> connections.query.count_active_tables()
7 7
Regression test for bug #8106. Same sort of problem as the previous test, but
this time there are more extra tables to pull in as part of the
select_related() and some of them could potentially clash (so need to be kept
separate).
>>> us = TUser.objects.create(name="std")
>>> usp = Person.objects.create(user=us)
>>> uo = TUser.objects.create(name="org")
>>> uop = Person.objects.create(user=uo)
>>> s = Student.objects.create(person = usp)
>>> o = Organizer.objects.create(person = uop)
>>> c = Class.objects.create(org=o)
>>> e = Enrollment.objects.create(std=s, cls=c)
>>> e_related = Enrollment.objects.all().select_related()[0]
>>> e_related.std.person.user.name
u"std"
>>> e_related.cls.org.person.user.name
u"org"
"""} """}