diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 6c11df4cbd..4ff93e701f 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -501,8 +501,6 @@ def create_many_related_manager(superclass, rel): self.through = through self.prefetch_cache_name = prefetch_cache_name self.related_val = source_field.get_foreign_related_value(instance) - # Used for single column related auto created models - self._fk_val = self.related_val[0] if None in self.related_val: raise ValueError('"%r" needs to have a value for field "%s" before ' 'this many-to-many relationship can be used.' % @@ -515,18 +513,6 @@ def create_many_related_manager(superclass, rel): "a many-to-many relationship can be used." % instance.__class__.__name__) - def _get_fk_val(self, obj, field_name): - """ - Returns the correct value for this relationship's foreign key. This - might be something else than pk value when to_field is used. - """ - fk = self.through._meta.get_field(field_name) - if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname: - attname = fk.rel.get_related_field().get_attname() - return fk.get_prep_lookup('exact', getattr(obj, attname)) - else: - return obj.pk - def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.prefetch_cache_name] @@ -624,11 +610,12 @@ def create_many_related_manager(superclass, rel): if not router.allow_relation(obj, self.instance): raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' % (obj, self.instance._state.db, obj._state.db)) - fk_val = self._get_fk_val(obj, target_field_name) + fk_val = self.through._meta.get_field( + target_field_name).get_foreign_related_value(obj)[0] if fk_val is None: raise ValueError('Cannot add "%r": the value for field "%s" is None' % (obj, target_field_name)) - new_ids.add(self._get_fk_val(obj, target_field_name)) + new_ids.add(fk_val) elif isinstance(obj, Model): raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) else: @@ -636,7 +623,7 @@ def create_many_related_manager(superclass, rel): db = router.db_for_write(self.through, instance=self.instance) vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True) vals = vals.filter(**{ - source_field_name: self._fk_val, + source_field_name: self.related_val[0], '%s__in' % target_field_name: new_ids, }) new_ids = new_ids - set(vals) @@ -650,7 +637,7 @@ def create_many_related_manager(superclass, rel): # Add the ones that aren't there already self.through._default_manager.using(db).bulk_create([ self.through(**{ - '%s_id' % source_field_name: self._fk_val, + '%s_id' % source_field_name: self.related_val[0], '%s_id' % target_field_name: obj_id, }) for obj_id in new_ids @@ -674,7 +661,9 @@ def create_many_related_manager(superclass, rel): old_ids = set() for obj in objs: if isinstance(obj, self.model): - old_ids.add(self._get_fk_val(obj, target_field_name)) + fk_val = self.through._meta.get_field( + target_field_name).get_foreign_related_value(obj)[0] + old_ids.add(fk_val) else: old_ids.add(obj) # Work out what DB we're operating on @@ -688,7 +677,7 @@ def create_many_related_manager(superclass, rel): model=self.model, pk_set=old_ids, using=db) # Remove the specified objects from the join table self.through._default_manager.using(db).filter(**{ - source_field_name: self._fk_val, + source_field_name: self.related_val[0], '%s__in' % target_field_name: old_ids }).delete() if self.reverse or source_field_name == self.source_field_name: @@ -994,10 +983,13 @@ class ForeignObject(RelatedField): # Gotcha: in some cases (like fixture loading) a model can have # different values in parent_ptr_id and parent's id. So, use # instance.pk (that is, parent_ptr_id) when asked for instance.id. + opts = instance._meta if field.primary_key: - ret.append(instance.pk) - else: - ret.append(getattr(instance, field.attname)) + possible_parent_link = opts.get_ancestor_link(field.model) + if not possible_parent_link or possible_parent_link.primary_key: + ret.append(instance.pk) + continue + ret.append(getattr(instance, field.attname)) return tuple(ret) def get_attname_column(self): diff --git a/tests/model_inheritance/models.py b/tests/model_inheritance/models.py index 106645d23c..020bb35bc7 100644 --- a/tests/model_inheritance/models.py +++ b/tests/model_inheritance/models.py @@ -162,3 +162,9 @@ class Mixin(object): class MixinModel(models.Model, Mixin): pass + +class Base(models.Model): + titles = models.ManyToManyField(Title) + +class SubBase(Base): + sub_id = models.IntegerField(primary_key=True) diff --git a/tests/model_inheritance/tests.py b/tests/model_inheritance/tests.py index b8ab0c8581..dab3088a41 100644 --- a/tests/model_inheritance/tests.py +++ b/tests/model_inheritance/tests.py @@ -10,7 +10,8 @@ from django.utils import six from .models import ( Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post, - Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel) + Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel, + Title, Base, SubBase) class ModelInheritanceTests(TestCase): @@ -357,3 +358,16 @@ class ModelInheritanceTests(TestCase): [Place.objects.get(pk=s.pk)], lambda x: x ) + + def test_custompk_m2m(self): + b = Base.objects.create() + b.titles.add(Title.objects.create(title="foof")) + s = SubBase.objects.create(sub_id=b.id) + b = Base.objects.get(pk=s.id) + self.assertNotEqual(b.pk, s.pk) + # Low-level test for related_val + self.assertEqual(s.titles.related_val, (s.id,)) + # Higher level test for correct query values (title foof not + # accidentally found). + self.assertQuerysetEqual( + s.titles.all(), [])