Fixed handling of multiple fields in a model pointing to the same related model.

Thanks to ElliotM, mk and oyvind for some excellent test cases for this. Fixed #7110, #7125.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@7778 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2008-06-29 02:36:18 +00:00
parent d800c0b031
commit bb2182453b
6 changed files with 217 additions and 12 deletions

View File

@ -692,6 +692,11 @@ class ForeignKey(RelatedField, Field):
def contribute_to_class(self, cls, name):
super(ForeignKey, self).contribute_to_class(cls, name)
setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
if isinstance(self.rel.to, basestring):
target = self.rel.to
else:
target = self.rel.to._meta.db_table
cls._meta.duplicate_targets[self.column] = (target, "o2m")
def contribute_to_related_class(self, cls, related):
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
@ -826,6 +831,12 @@ class ManyToManyField(RelatedField, Field):
# Set up the accessor for the m2m table name for the relation
self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
if isinstance(self.rel.to, basestring):
target = self.rel.to
else:
target = self.rel.to._meta.db_table
cls._meta.duplicate_targets[self.column] = (target, "m2m")
def contribute_to_related_class(self, cls, related):
# m2m relations to self do not have a ManyRelatedObjectsDescriptor,
# as it would be redundant - unless the field is non-symmetrical.

View File

@ -44,6 +44,7 @@ class Options(object):
self.one_to_one_field = None
self.abstract = False
self.parents = SortedDict()
self.duplicate_targets = {}
def contribute_to_class(self, cls, name):
from django.db import connection
@ -115,6 +116,24 @@ class Options(object):
auto_created=True)
model.add_to_class('id', auto)
# Determine any sets of fields that are pointing to the same targets
# (e.g. two ForeignKeys to the same remote model). The query
# construction code needs to know this. At the end of this,
# self.duplicate_targets will map each duplicate field column to the
# columns it duplicates.
collections = {}
for column, target in self.duplicate_targets.iteritems():
try:
collections[target].add(column)
except KeyError:
collections[target] = set([column])
self.duplicate_targets = {}
for elt in collections.itervalues():
if len(elt) == 1:
continue
for column in elt:
self.duplicate_targets[column] = elt.difference(set([column]))
def add_field(self, field):
# Insert the given field in the order in which it was created, using
# the "creation_counter" attribute of the field.

View File

@ -57,6 +57,7 @@ class Query(object):
self.start_meta = None
self.select_fields = []
self.related_select_fields = []
self.dupe_avoidance = {}
# SQL-related attributes
self.select = []
@ -165,6 +166,7 @@ class Query(object):
obj.start_meta = self.start_meta
obj.select_fields = self.select_fields[:]
obj.related_select_fields = self.related_select_fields[:]
obj.dupe_avoidance = self.dupe_avoidance.copy()
obj.select = self.select[:]
obj.tables = self.tables[:]
obj.where = deepcopy(self.where)
@ -830,8 +832,8 @@ class Query(object):
if reuse and always_create and table in self.table_map:
# Convert the 'reuse' to case to be "exclude everything but the
# reusable set for this table".
exclusions = set(self.table_map[table]).difference(reuse)
# reusable set, minus exclusions, for this table".
exclusions = set(self.table_map[table]).difference(reuse).union(set(exclusions))
always_create = False
t_ident = (lhs_table, table, lhs_col, col)
if not always_create:
@ -866,7 +868,8 @@ class Query(object):
return alias
def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used=None, requested=None, restricted=None, nullable=None):
used=None, requested=None, restricted=None, nullable=None,
dupe_set=None):
"""
Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model
@ -876,6 +879,7 @@ class Query(object):
if not restricted and self.max_depth and cur_depth > self.max_depth:
# We've recursed far enough; bail out.
return
if not opts:
opts = self.get_meta()
root_alias = self.get_initial_alias()
@ -883,6 +887,10 @@ class Query(object):
self.related_select_fields = []
if not used:
used = set()
if dupe_set is None:
dupe_set = set()
orig_dupe_set = dupe_set
orig_used = used
# Setup for the case when only particular related fields should be
# included in the related selection.
@ -897,6 +905,8 @@ class Query(object):
if (not f.rel or (restricted and f.name not in requested) or
(not restricted and f.null) or f.rel.parent_link):
continue
dupe_set = orig_dupe_set.copy()
used = orig_used.copy()
table = f.rel.to._meta.db_table
if nullable or f.null:
promote = True
@ -907,12 +917,26 @@ class Query(object):
alias = root_alias
for int_model in opts.get_base_chain(model):
lhs_col = int_opts.parents[int_model].column
dedupe = lhs_col in opts.duplicate_targets
if dedupe:
used.update(self.dupe_avoidance.get(id(opts), lhs_col),
())
dupe_set.add((opts, lhs_col))
int_opts = int_model._meta
alias = self.join((alias, int_opts.db_table, lhs_col,
int_opts.pk.column), exclusions=used,
promote=promote)
for (dupe_opts, dupe_col) in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
else:
alias = root_alias
dedupe = f.column in opts.duplicate_targets
if dupe_set or dedupe:
used.update(self.dupe_avoidance.get((id(opts), f.column), ()))
if dedupe:
dupe_set.add((opts, f.column))
alias = self.join((alias, table, f.column,
f.rel.get_related_field().column), exclusions=used,
promote=promote)
@ -928,8 +952,10 @@ class Query(object):
new_nullable = f.null
else:
new_nullable = None
for dupe_opts, dupe_col in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable)
used, next, restricted, new_nullable, dupe_set)
def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
can_reuse=None):
@ -1128,7 +1154,9 @@ class Query(object):
(which gives the table we are joining to), 'alias' is the alias for the
table we are joining to. If dupe_multis is True, any many-to-many or
many-to-one joins will always create a new alias (necessary for
disjunctive filters).
disjunctive filters). If can_reuse is not None, it's a list of aliases
that can be reused in these joins (nothing else can be reused in this
case).
Returns the final field involved in the join, the target database
column (used for any 'where' constraint), the final 'opts' value and the
@ -1136,7 +1164,14 @@ class Query(object):
"""
joins = [alias]
last = [0]
dupe_set = set()
exclusions = set()
for pos, name in enumerate(names):
try:
exclusions.add(int_alias)
except NameError:
pass
exclusions.add(alias)
last.append(len(joins))
if name == 'pk':
name = opts.pk.name
@ -1155,6 +1190,7 @@ class Query(object):
names = opts.get_all_field_names()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
if not allow_many and (m2m or not direct):
for alias in joins:
self.unref_alias(alias)
@ -1164,12 +1200,27 @@ class Query(object):
alias_list = []
for int_model in opts.get_base_chain(model):
lhs_col = opts.parents[int_model].column
dedupe = lhs_col in opts.duplicate_targets
if dedupe:
exclusions.update(self.dupe_avoidance.get(
(id(opts), lhs_col), ()))
dupe_set.add((opts, lhs_col))
opts = int_model._meta
alias = self.join((alias, opts.db_table, lhs_col,
opts.pk.column), exclusions=joins)
opts.pk.column), exclusions=exclusions)
joins.append(alias)
exclusions.add(alias)
for (dupe_opts, dupe_col) in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
cached_data = opts._join_cache.get(name)
orig_opts = opts
dupe_col = direct and field.column or field.field.column
dedupe = dupe_col in opts.duplicate_targets
if dupe_set or dedupe:
if dedupe:
dupe_set.add((opts, dupe_col))
exclusions.update(self.dupe_avoidance.get((id(opts), dupe_col),
()))
if direct:
if m2m:
@ -1191,9 +1242,11 @@ class Query(object):
target)
int_alias = self.join((alias, table1, from_col1, to_col1),
dupe_multis, joins, nullable=True, reuse=can_reuse)
dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
alias = self.join((int_alias, table2, from_col2, to_col2),
dupe_multis, joins, nullable=True, reuse=can_reuse)
dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
joins.extend([int_alias, alias])
elif field.rel:
# One-to-one or many-to-one field
@ -1209,7 +1262,7 @@ class Query(object):
opts, target)
alias = self.join((alias, table, from_col, to_col),
exclusions=joins, nullable=field.null)
exclusions=exclusions, nullable=field.null)
joins.append(alias)
else:
# Non-relation fields.
@ -1237,9 +1290,11 @@ class Query(object):
target)
int_alias = self.join((alias, table1, from_col1, to_col1),
dupe_multis, joins, nullable=True, reuse=can_reuse)
dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
alias = self.join((int_alias, table2, from_col2, to_col2),
dupe_multis, joins, nullable=True, reuse=can_reuse)
dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
joins.extend([int_alias, alias])
else:
# One-to-many field (ForeignKey defined on the target model)
@ -1257,14 +1312,34 @@ class Query(object):
opts, target)
alias = self.join((alias, table, from_col, to_col),
dupe_multis, joins, nullable=True, reuse=can_reuse)
dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
joins.append(alias)
for (dupe_opts, dupe_col) in dupe_set:
try:
self.update_dupe_avoidance(dupe_opts, dupe_col, int_alias)
except NameError:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
if pos != len(names) - 1:
raise FieldError("Join on field %r not permitted." % name)
return field, target, opts, joins, last
def update_dupe_avoidance(self, opts, col, alias):
"""
For a column that is one of multiple pointing to the same table, update
the internal data structures to note that this alias shouldn't be used
for those other columns.
"""
ident = id(opts)
for name in opts.duplicate_targets[col]:
try:
self.dupe_avoidance[ident, name].add(alias)
except KeyError:
self.dupe_avoidance[ident, name] = set([alias])
def split_exclude(self, filter_expr, prefix):
"""
When doing an exclude against any kind of N-to-many relation, we need

View File

@ -28,6 +28,24 @@ class Child(models.Model):
parent = models.ForeignKey(Parent)
# Multiple paths to the same model (#7110, #7125)
class Category(models.Model):
name = models.CharField(max_length=20)
def __unicode__(self):
return self.name
class Record(models.Model):
category = models.ForeignKey(Category)
class Relation(models.Model):
left = models.ForeignKey(Record, related_name='left_set')
right = models.ForeignKey(Record, related_name='right_set')
def __unicode__(self):
return u"%s - %s" % (self.left.category.name, self.right.category.name)
__test__ = {'API_TESTS':"""
>>> Third.objects.create(id='3', name='An example')
<Third: Third object>
@ -73,4 +91,26 @@ Traceback (most recent call last):
...
ValueError: Cannot assign "<First: First object>": "Child.parent" must be a "Parent" instance.
# Test of multiple ForeignKeys to the same model (bug #7125)
>>> c1 = Category.objects.create(name='First')
>>> c2 = Category.objects.create(name='Second')
>>> c3 = Category.objects.create(name='Third')
>>> r1 = Record.objects.create(category=c1)
>>> r2 = Record.objects.create(category=c1)
>>> r3 = Record.objects.create(category=c2)
>>> r4 = Record.objects.create(category=c2)
>>> r5 = Record.objects.create(category=c3)
>>> r = Relation.objects.create(left=r1, right=r2)
>>> r = Relation.objects.create(left=r3, right=r4)
>>> r = Relation.objects.create(left=r1, right=r3)
>>> r = Relation.objects.create(left=r5, right=r2)
>>> r = Relation.objects.create(left=r3, right=r2)
>>> Relation.objects.filter(left__category__name__in=['First'], right__category__name__in=['Second'])
[<Relation: First - Second>]
>>> Category.objects.filter(record__left_set__right__category__name='Second').order_by('name')
[<Category: First>, <Category: Second>]
"""}

View File

@ -0,0 +1,60 @@
from django.db import models
class Building(models.Model):
name = models.CharField(max_length=10)
def __unicode__(self):
return u"Building: %s" % self.name
class Device(models.Model):
building = models.ForeignKey('Building')
name = models.CharField(max_length=10)
def __unicode__(self):
return u"device '%s' in building %s" % (self.name, self.building)
class Port(models.Model):
device = models.ForeignKey('Device')
number = models.CharField(max_length=10)
def __unicode__(self):
return u"%s/%s" % (self.device.name, self.number)
class Connection(models.Model):
start = models.ForeignKey(Port, related_name='connection_start',
unique=True)
end = models.ForeignKey(Port, related_name='connection_end', unique=True)
def __unicode__(self):
return u"%s to %s" % (self.start, self.end)
__test__ = {'API_TESTS': """
Regression test for bug #7110. When using select_related(), we must query the
Device and Building tables using two different aliases (each) in order to
differentiate the start and end Connection fields. The net result is that both
the "connections = ..." queries here should give the same results.
>>> b=Building.objects.create(name='101')
>>> dev1=Device.objects.create(name="router", building=b)
>>> dev2=Device.objects.create(name="switch", building=b)
>>> dev3=Device.objects.create(name="server", building=b)
>>> port1=Port.objects.create(number='4',device=dev1)
>>> port2=Port.objects.create(number='7',device=dev2)
>>> port3=Port.objects.create(number='1',device=dev3)
>>> c1=Connection.objects.create(start=port1, end=port2)
>>> c2=Connection.objects.create(start=port2, end=port3)
>>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).order_by('id')
>>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections]
[(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')]
>>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).select_related().order_by('id')
>>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections]
[(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')]
# This final query should only join seven tables (port, device and building
# twice each, plus connection once).
>>> connections.query.count_active_tables()
7
"""}