Fixed #20946 -- model inheritance + m2m failure

Cleaned up the internal implementation of m2m fields by removing
related.py _get_fk_val(). The _get_fk_val() was doing the wrong thing
if asked for the foreign key value on foreign key to parent model's
primary key when child model had different primary key field.
This commit is contained in:
Anssi Kääriäinen 2013-08-20 17:13:41 +03:00 committed by Andrew Godwin
parent 7775ced938
commit 244e2b71f5
3 changed files with 36 additions and 24 deletions

View File

@ -501,8 +501,6 @@ def create_many_related_manager(superclass, rel):
self.through = through self.through = through
self.prefetch_cache_name = prefetch_cache_name self.prefetch_cache_name = prefetch_cache_name
self.related_val = source_field.get_foreign_related_value(instance) self.related_val = source_field.get_foreign_related_value(instance)
# Used for single column related auto created models
self._fk_val = self.related_val[0]
if None in self.related_val: if None in self.related_val:
raise ValueError('"%r" needs to have a value for field "%s" before ' raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' % 'this many-to-many relationship can be used.' %
@ -515,18 +513,6 @@ def create_many_related_manager(superclass, rel):
"a many-to-many relationship can be used." % "a many-to-many relationship can be used." %
instance.__class__.__name__) instance.__class__.__name__)
def _get_fk_val(self, obj, field_name):
"""
Returns the correct value for this relationship's foreign key. This
might be something else than pk value when to_field is used.
"""
fk = self.through._meta.get_field(field_name)
if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
attname = fk.rel.get_related_field().get_attname()
return fk.get_prep_lookup('exact', getattr(obj, attname))
else:
return obj.pk
def get_queryset(self): def get_queryset(self):
try: try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name] return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
@ -624,11 +610,12 @@ def create_many_related_manager(superclass, rel):
if not router.allow_relation(obj, self.instance): if not router.allow_relation(obj, self.instance):
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' % raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
(obj, self.instance._state.db, obj._state.db)) (obj, self.instance._state.db, obj._state.db))
fk_val = self._get_fk_val(obj, target_field_name) fk_val = self.through._meta.get_field(
target_field_name).get_foreign_related_value(obj)[0]
if fk_val is None: if fk_val is None:
raise ValueError('Cannot add "%r": the value for field "%s" is None' % raise ValueError('Cannot add "%r": the value for field "%s" is None' %
(obj, target_field_name)) (obj, target_field_name))
new_ids.add(self._get_fk_val(obj, target_field_name)) new_ids.add(fk_val)
elif isinstance(obj, Model): elif isinstance(obj, Model):
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
else: else:
@ -636,7 +623,7 @@ def create_many_related_manager(superclass, rel):
db = router.db_for_write(self.through, instance=self.instance) db = router.db_for_write(self.through, instance=self.instance)
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True) vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
vals = vals.filter(**{ vals = vals.filter(**{
source_field_name: self._fk_val, source_field_name: self.related_val[0],
'%s__in' % target_field_name: new_ids, '%s__in' % target_field_name: new_ids,
}) })
new_ids = new_ids - set(vals) new_ids = new_ids - set(vals)
@ -650,7 +637,7 @@ def create_many_related_manager(superclass, rel):
# Add the ones that aren't there already # Add the ones that aren't there already
self.through._default_manager.using(db).bulk_create([ self.through._default_manager.using(db).bulk_create([
self.through(**{ self.through(**{
'%s_id' % source_field_name: self._fk_val, '%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: obj_id, '%s_id' % target_field_name: obj_id,
}) })
for obj_id in new_ids for obj_id in new_ids
@ -674,7 +661,9 @@ def create_many_related_manager(superclass, rel):
old_ids = set() old_ids = set()
for obj in objs: for obj in objs:
if isinstance(obj, self.model): if isinstance(obj, self.model):
old_ids.add(self._get_fk_val(obj, target_field_name)) fk_val = self.through._meta.get_field(
target_field_name).get_foreign_related_value(obj)[0]
old_ids.add(fk_val)
else: else:
old_ids.add(obj) old_ids.add(obj)
# Work out what DB we're operating on # Work out what DB we're operating on
@ -688,7 +677,7 @@ def create_many_related_manager(superclass, rel):
model=self.model, pk_set=old_ids, using=db) model=self.model, pk_set=old_ids, using=db)
# Remove the specified objects from the join table # Remove the specified objects from the join table
self.through._default_manager.using(db).filter(**{ self.through._default_manager.using(db).filter(**{
source_field_name: self._fk_val, source_field_name: self.related_val[0],
'%s__in' % target_field_name: old_ids '%s__in' % target_field_name: old_ids
}).delete() }).delete()
if self.reverse or source_field_name == self.source_field_name: if self.reverse or source_field_name == self.source_field_name:
@ -994,9 +983,12 @@ class ForeignObject(RelatedField):
# Gotcha: in some cases (like fixture loading) a model can have # Gotcha: in some cases (like fixture loading) a model can have
# different values in parent_ptr_id and parent's id. So, use # different values in parent_ptr_id and parent's id. So, use
# instance.pk (that is, parent_ptr_id) when asked for instance.id. # instance.pk (that is, parent_ptr_id) when asked for instance.id.
opts = instance._meta
if field.primary_key: if field.primary_key:
possible_parent_link = opts.get_ancestor_link(field.model)
if not possible_parent_link or possible_parent_link.primary_key:
ret.append(instance.pk) ret.append(instance.pk)
else: continue
ret.append(getattr(instance, field.attname)) ret.append(getattr(instance, field.attname))
return tuple(ret) return tuple(ret)

View File

@ -162,3 +162,9 @@ class Mixin(object):
class MixinModel(models.Model, Mixin): class MixinModel(models.Model, Mixin):
pass pass
class Base(models.Model):
titles = models.ManyToManyField(Title)
class SubBase(Base):
sub_id = models.IntegerField(primary_key=True)

View File

@ -10,7 +10,8 @@ from django.utils import six
from .models import ( from .models import (
Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post, Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post,
Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel) Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel,
Title, Base, SubBase)
class ModelInheritanceTests(TestCase): class ModelInheritanceTests(TestCase):
@ -357,3 +358,16 @@ class ModelInheritanceTests(TestCase):
[Place.objects.get(pk=s.pk)], [Place.objects.get(pk=s.pk)],
lambda x: x lambda x: x
) )
def test_custompk_m2m(self):
b = Base.objects.create()
b.titles.add(Title.objects.create(title="foof"))
s = SubBase.objects.create(sub_id=b.id)
b = Base.objects.get(pk=s.id)
self.assertNotEqual(b.pk, s.pk)
# Low-level test for related_val
self.assertEqual(s.titles.related_val, (s.id,))
# Higher level test for correct query values (title foof not
# accidentally found).
self.assertQuerysetEqual(
s.titles.all(), [])