Fixed handling of multiple fields in a model pointing to the same related model.

Thanks to ElliotM, mk and oyvind for some excellent test cases for this. Fixed #7110, #7125.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@7778 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Malcolm Tredinnick 2008-06-29 02:36:18 +00:00
parent d800c0b031
commit bb2182453b
6 changed files with 217 additions and 12 deletions

View File

@ -692,6 +692,11 @@ class ForeignKey(RelatedField, Field):
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
super(ForeignKey, self).contribute_to_class(cls, name) super(ForeignKey, self).contribute_to_class(cls, name)
setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self)) setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
if isinstance(self.rel.to, basestring):
target = self.rel.to
else:
target = self.rel.to._meta.db_table
cls._meta.duplicate_targets[self.column] = (target, "o2m")
def contribute_to_related_class(self, cls, related): def contribute_to_related_class(self, cls, related):
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
@ -826,6 +831,12 @@ class ManyToManyField(RelatedField, Field):
# Set up the accessor for the m2m table name for the relation # Set up the accessor for the m2m table name for the relation
self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
if isinstance(self.rel.to, basestring):
target = self.rel.to
else:
target = self.rel.to._meta.db_table
cls._meta.duplicate_targets[self.column] = (target, "m2m")
def contribute_to_related_class(self, cls, related): def contribute_to_related_class(self, cls, related):
# m2m relations to self do not have a ManyRelatedObjectsDescriptor, # m2m relations to self do not have a ManyRelatedObjectsDescriptor,
# as it would be redundant - unless the field is non-symmetrical. # as it would be redundant - unless the field is non-symmetrical.

View File

@ -44,6 +44,7 @@ class Options(object):
self.one_to_one_field = None self.one_to_one_field = None
self.abstract = False self.abstract = False
self.parents = SortedDict() self.parents = SortedDict()
self.duplicate_targets = {}
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
from django.db import connection from django.db import connection
@ -115,6 +116,24 @@ class Options(object):
auto_created=True) auto_created=True)
model.add_to_class('id', auto) model.add_to_class('id', auto)
# Determine any sets of fields that are pointing to the same targets
# (e.g. two ForeignKeys to the same remote model). The query
# construction code needs to know this. At the end of this,
# self.duplicate_targets will map each duplicate field column to the
# columns it duplicates.
collections = {}
for column, target in self.duplicate_targets.iteritems():
try:
collections[target].add(column)
except KeyError:
collections[target] = set([column])
self.duplicate_targets = {}
for elt in collections.itervalues():
if len(elt) == 1:
continue
for column in elt:
self.duplicate_targets[column] = elt.difference(set([column]))
def add_field(self, field): def add_field(self, field):
# Insert the given field in the order in which it was created, using # Insert the given field in the order in which it was created, using
# the "creation_counter" attribute of the field. # the "creation_counter" attribute of the field.

View File

@ -57,6 +57,7 @@ class Query(object):
self.start_meta = None self.start_meta = None
self.select_fields = [] self.select_fields = []
self.related_select_fields = [] self.related_select_fields = []
self.dupe_avoidance = {}
# SQL-related attributes # SQL-related attributes
self.select = [] self.select = []
@ -165,6 +166,7 @@ class Query(object):
obj.start_meta = self.start_meta obj.start_meta = self.start_meta
obj.select_fields = self.select_fields[:] obj.select_fields = self.select_fields[:]
obj.related_select_fields = self.related_select_fields[:] obj.related_select_fields = self.related_select_fields[:]
obj.dupe_avoidance = self.dupe_avoidance.copy()
obj.select = self.select[:] obj.select = self.select[:]
obj.tables = self.tables[:] obj.tables = self.tables[:]
obj.where = deepcopy(self.where) obj.where = deepcopy(self.where)
@ -830,8 +832,8 @@ class Query(object):
if reuse and always_create and table in self.table_map: if reuse and always_create and table in self.table_map:
# Convert the 'reuse' to case to be "exclude everything but the # Convert the 'reuse' to case to be "exclude everything but the
# reusable set for this table". # reusable set, minus exclusions, for this table".
exclusions = set(self.table_map[table]).difference(reuse) exclusions = set(self.table_map[table]).difference(reuse).union(set(exclusions))
always_create = False always_create = False
t_ident = (lhs_table, table, lhs_col, col) t_ident = (lhs_table, table, lhs_col, col)
if not always_create: if not always_create:
@ -866,7 +868,8 @@ class Query(object):
return alias return alias
def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used=None, requested=None, restricted=None, nullable=None): used=None, requested=None, restricted=None, nullable=None,
dupe_set=None):
""" """
Fill in the information needed for a select_related query. The current Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model depth is measured as the number of connections away from the root model
@ -876,6 +879,7 @@ class Query(object):
if not restricted and self.max_depth and cur_depth > self.max_depth: if not restricted and self.max_depth and cur_depth > self.max_depth:
# We've recursed far enough; bail out. # We've recursed far enough; bail out.
return return
if not opts: if not opts:
opts = self.get_meta() opts = self.get_meta()
root_alias = self.get_initial_alias() root_alias = self.get_initial_alias()
@ -883,6 +887,10 @@ class Query(object):
self.related_select_fields = [] self.related_select_fields = []
if not used: if not used:
used = set() used = set()
if dupe_set is None:
dupe_set = set()
orig_dupe_set = dupe_set
orig_used = used
# Setup for the case when only particular related fields should be # Setup for the case when only particular related fields should be
# included in the related selection. # included in the related selection.
@ -897,6 +905,8 @@ class Query(object):
if (not f.rel or (restricted and f.name not in requested) or if (not f.rel or (restricted and f.name not in requested) or
(not restricted and f.null) or f.rel.parent_link): (not restricted and f.null) or f.rel.parent_link):
continue continue
dupe_set = orig_dupe_set.copy()
used = orig_used.copy()
table = f.rel.to._meta.db_table table = f.rel.to._meta.db_table
if nullable or f.null: if nullable or f.null:
promote = True promote = True
@ -907,12 +917,26 @@ class Query(object):
alias = root_alias alias = root_alias
for int_model in opts.get_base_chain(model): for int_model in opts.get_base_chain(model):
lhs_col = int_opts.parents[int_model].column lhs_col = int_opts.parents[int_model].column
dedupe = lhs_col in opts.duplicate_targets
if dedupe:
    # dupe_avoidance is keyed by (id(opts), column) tuples (see
    # update_dupe_avoidance), so the key must be passed as one tuple.
    # Passing id(opts) and lhs_col as two separate arguments made
    # lhs_col the .get() default, so the characters of the column name
    # were added to "used" instead of the aliases to avoid.
    used.update(self.dupe_avoidance.get((id(opts), lhs_col), ()))
    dupe_set.add((opts, lhs_col))
int_opts = int_model._meta int_opts = int_model._meta
alias = self.join((alias, int_opts.db_table, lhs_col, alias = self.join((alias, int_opts.db_table, lhs_col,
int_opts.pk.column), exclusions=used, int_opts.pk.column), exclusions=used,
promote=promote) promote=promote)
for (dupe_opts, dupe_col) in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
else: else:
alias = root_alias alias = root_alias
dedupe = f.column in opts.duplicate_targets
if dupe_set or dedupe:
used.update(self.dupe_avoidance.get((id(opts), f.column), ()))
if dedupe:
dupe_set.add((opts, f.column))
alias = self.join((alias, table, f.column, alias = self.join((alias, table, f.column,
f.rel.get_related_field().column), exclusions=used, f.rel.get_related_field().column), exclusions=used,
promote=promote) promote=promote)
@ -928,8 +952,10 @@ class Query(object):
new_nullable = f.null new_nullable = f.null
else: else:
new_nullable = None new_nullable = None
for dupe_opts, dupe_col in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable) used, next, restricted, new_nullable, dupe_set)
def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
can_reuse=None): can_reuse=None):
@ -1128,7 +1154,9 @@ class Query(object):
(which gives the table we are joining to), 'alias' is the alias for the (which gives the table we are joining to), 'alias' is the alias for the
table we are joining to. If dupe_multis is True, any many-to-many or table we are joining to. If dupe_multis is True, any many-to-many or
many-to-one joins will always create a new alias (necessary for many-to-one joins will always create a new alias (necessary for
disjunctive filters). disjunctive filters). If can_reuse is not None, it's a list of aliases
that can be reused in these joins (nothing else can be reused in this
case).
Returns the final field involved in the join, the target database Returns the final field involved in the join, the target database
column (used for any 'where' constraint), the final 'opts' value and the column (used for any 'where' constraint), the final 'opts' value and the
@ -1136,7 +1164,14 @@ class Query(object):
""" """
joins = [alias] joins = [alias]
last = [0] last = [0]
dupe_set = set()
exclusions = set()
for pos, name in enumerate(names): for pos, name in enumerate(names):
try:
exclusions.add(int_alias)
except NameError:
pass
exclusions.add(alias)
last.append(len(joins)) last.append(len(joins))
if name == 'pk': if name == 'pk':
name = opts.pk.name name = opts.pk.name
@ -1155,6 +1190,7 @@ class Query(object):
names = opts.get_all_field_names() names = opts.get_all_field_names()
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names))) "Choices are: %s" % (name, ", ".join(names)))
if not allow_many and (m2m or not direct): if not allow_many and (m2m or not direct):
for alias in joins: for alias in joins:
self.unref_alias(alias) self.unref_alias(alias)
@ -1164,12 +1200,27 @@ class Query(object):
alias_list = [] alias_list = []
for int_model in opts.get_base_chain(model): for int_model in opts.get_base_chain(model):
lhs_col = opts.parents[int_model].column lhs_col = opts.parents[int_model].column
dedupe = lhs_col in opts.duplicate_targets
if dedupe:
exclusions.update(self.dupe_avoidance.get(
(id(opts), lhs_col), ()))
dupe_set.add((opts, lhs_col))
opts = int_model._meta opts = int_model._meta
alias = self.join((alias, opts.db_table, lhs_col, alias = self.join((alias, opts.db_table, lhs_col,
opts.pk.column), exclusions=joins) opts.pk.column), exclusions=exclusions)
joins.append(alias) joins.append(alias)
exclusions.add(alias)
for (dupe_opts, dupe_col) in dupe_set:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
cached_data = opts._join_cache.get(name) cached_data = opts._join_cache.get(name)
orig_opts = opts orig_opts = opts
dupe_col = direct and field.column or field.field.column
dedupe = dupe_col in opts.duplicate_targets
if dupe_set or dedupe:
if dedupe:
dupe_set.add((opts, dupe_col))
exclusions.update(self.dupe_avoidance.get((id(opts), dupe_col),
()))
if direct: if direct:
if m2m: if m2m:
@ -1191,9 +1242,11 @@ class Query(object):
target) target)
int_alias = self.join((alias, table1, from_col1, to_col1), int_alias = self.join((alias, table1, from_col1, to_col1),
dupe_multis, joins, nullable=True, reuse=can_reuse) dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
alias = self.join((int_alias, table2, from_col2, to_col2), alias = self.join((int_alias, table2, from_col2, to_col2),
dupe_multis, joins, nullable=True, reuse=can_reuse) dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
joins.extend([int_alias, alias]) joins.extend([int_alias, alias])
elif field.rel: elif field.rel:
# One-to-one or many-to-one field # One-to-one or many-to-one field
@ -1209,7 +1262,7 @@ class Query(object):
opts, target) opts, target)
alias = self.join((alias, table, from_col, to_col), alias = self.join((alias, table, from_col, to_col),
exclusions=joins, nullable=field.null) exclusions=exclusions, nullable=field.null)
joins.append(alias) joins.append(alias)
else: else:
# Non-relation fields. # Non-relation fields.
@ -1237,9 +1290,11 @@ class Query(object):
target) target)
int_alias = self.join((alias, table1, from_col1, to_col1), int_alias = self.join((alias, table1, from_col1, to_col1),
dupe_multis, joins, nullable=True, reuse=can_reuse) dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
alias = self.join((int_alias, table2, from_col2, to_col2), alias = self.join((int_alias, table2, from_col2, to_col2),
dupe_multis, joins, nullable=True, reuse=can_reuse) dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
joins.extend([int_alias, alias]) joins.extend([int_alias, alias])
else: else:
# One-to-many field (ForeignKey defined on the target model) # One-to-many field (ForeignKey defined on the target model)
@ -1257,14 +1312,34 @@ class Query(object):
opts, target) opts, target)
alias = self.join((alias, table, from_col, to_col), alias = self.join((alias, table, from_col, to_col),
dupe_multis, joins, nullable=True, reuse=can_reuse) dupe_multis, exclusions, nullable=True,
reuse=can_reuse)
joins.append(alias) joins.append(alias)
for (dupe_opts, dupe_col) in dupe_set:
try:
self.update_dupe_avoidance(dupe_opts, dupe_col, int_alias)
except NameError:
self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
if pos != len(names) - 1: if pos != len(names) - 1:
raise FieldError("Join on field %r not permitted." % name) raise FieldError("Join on field %r not permitted." % name)
return field, target, opts, joins, last return field, target, opts, joins, last
def update_dupe_avoidance(self, opts, col, alias):
    """
    Record that 'alias' must not be reused when joining via any of the
    other columns that point to the same table as 'col'.

    'opts' is the options object owning 'col'; the columns to mark are
    read from opts.duplicate_targets[col]. Entries in self.dupe_avoidance
    are keyed by (id(opts), column name) and hold the set of aliases to
    avoid for that column.
    """
    opts_key = id(opts)
    for other_col in opts.duplicate_targets[col]:
        avoid = self.dupe_avoidance.setdefault((opts_key, other_col), set())
        avoid.add(alias)
def split_exclude(self, filter_expr, prefix): def split_exclude(self, filter_expr, prefix):
""" """
When doing an exclude against any kind of N-to-many relation, we need When doing an exclude against any kind of N-to-many relation, we need

View File

@ -28,6 +28,24 @@ class Child(models.Model):
parent = models.ForeignKey(Parent) parent = models.ForeignKey(Parent)
# Multiple paths to the same model (#7110, #7125)
class Category(models.Model):
    """Named category; its unicode representation is the name itself."""
    name = models.CharField(max_length=20)
    def __unicode__(self):
        return self.name
class Record(models.Model):
    """Record that belongs to a single Category."""
    category = models.ForeignKey(Category)
class Relation(models.Model):
    """
    Pairs two Records. Both foreign keys target the same model (Record),
    which is the "multiple fields pointing to the same related model"
    situation these regression tests exercise.
    """
    left = models.ForeignKey(Record, related_name='left_set')
    right = models.ForeignKey(Record, related_name='right_set')
    def __unicode__(self):
        # Displayed as "<left category name> - <right category name>".
        return u"%s - %s" % (self.left.category.name, self.right.category.name)
__test__ = {'API_TESTS':""" __test__ = {'API_TESTS':"""
>>> Third.objects.create(id='3', name='An example') >>> Third.objects.create(id='3', name='An example')
<Third: Third object> <Third: Third object>
@ -73,4 +91,26 @@ Traceback (most recent call last):
... ...
ValueError: Cannot assign "<First: First object>": "Child.parent" must be a "Parent" instance. ValueError: Cannot assign "<First: First object>": "Child.parent" must be a "Parent" instance.
# Test of multiple ForeignKeys to the same model (bug #7125)
>>> c1 = Category.objects.create(name='First')
>>> c2 = Category.objects.create(name='Second')
>>> c3 = Category.objects.create(name='Third')
>>> r1 = Record.objects.create(category=c1)
>>> r2 = Record.objects.create(category=c1)
>>> r3 = Record.objects.create(category=c2)
>>> r4 = Record.objects.create(category=c2)
>>> r5 = Record.objects.create(category=c3)
>>> r = Relation.objects.create(left=r1, right=r2)
>>> r = Relation.objects.create(left=r3, right=r4)
>>> r = Relation.objects.create(left=r1, right=r3)
>>> r = Relation.objects.create(left=r5, right=r2)
>>> r = Relation.objects.create(left=r3, right=r2)
>>> Relation.objects.filter(left__category__name__in=['First'], right__category__name__in=['Second'])
[<Relation: First - Second>]
>>> Category.objects.filter(record__left_set__right__category__name='Second').order_by('name')
[<Category: First>, <Category: Second>]
"""} """}

View File

@ -0,0 +1,60 @@
from django.db import models
class Building(models.Model):
    """Named building; the end of the Connection -> Port -> Device chain."""
    name = models.CharField(max_length=10)
    def __unicode__(self):
        return u"Building: %s" % self.name
class Device(models.Model):
    """Named device located in a Building."""
    # The related model is referenced by name (a string) rather than by class.
    building = models.ForeignKey('Building')
    name = models.CharField(max_length=10)
    def __unicode__(self):
        return u"device '%s' in building %s" % (self.name, self.building)
class Port(models.Model):
    """Numbered port on a Device."""
    # The related model is referenced by name (a string) rather than by class.
    device = models.ForeignKey('Device')
    number = models.CharField(max_length=10)
    def __unicode__(self):
        # Displayed as "<device name>/<port number>".
        return u"%s/%s" % (self.device.name, self.number)
class Connection(models.Model):
    """
    Connection between two Ports. Both foreign keys target the same model
    (Port), so queries following both need separate aliases for the Port,
    Device and Building tables (regression test for #7110).
    """
    start = models.ForeignKey(Port, related_name='connection_start',
            unique=True)
    end = models.ForeignKey(Port, related_name='connection_end', unique=True)
    def __unicode__(self):
        return u"%s to %s" % (self.start, self.end)
__test__ = {'API_TESTS': """
Regression test for bug #7110. When using select_related(), we must query the
Device and Building tables using two different aliases (each) in order to
differentiate the start and end Connection fields. The net result is that both
the "connections = ..." queries here should give the same results.
>>> b=Building.objects.create(name='101')
>>> dev1=Device.objects.create(name="router", building=b)
>>> dev2=Device.objects.create(name="switch", building=b)
>>> dev3=Device.objects.create(name="server", building=b)
>>> port1=Port.objects.create(number='4',device=dev1)
>>> port2=Port.objects.create(number='7',device=dev2)
>>> port3=Port.objects.create(number='1',device=dev3)
>>> c1=Connection.objects.create(start=port1, end=port2)
>>> c2=Connection.objects.create(start=port2, end=port3)
>>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).order_by('id')
>>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections]
[(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')]
>>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).select_related().order_by('id')
>>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections]
[(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')]
# This final query should only join seven tables (port, device and building
# twice each, plus connection once).
>>> connections.query.count_active_tables()
7
"""}