mirror of https://github.com/django/django.git
Fixed #20946 -- model inheritance + m2m failure
Cleaned up the internal implementation of m2m fields by removing related.py _get_fk_val(). The _get_fk_val() was doing the wrong thing if asked for the foreign key value on foreign key to parent model's primary key when child model had different primary key field.
This commit is contained in:
parent
83e434a2c2
commit
b065aeb17f
|
@ -503,8 +503,6 @@ def create_many_related_manager(superclass, rel):
|
|||
self.through = through
|
||||
self.prefetch_cache_name = prefetch_cache_name
|
||||
self.related_val = source_field.get_foreign_related_value(instance)
|
||||
# Used for single column related auto created models
|
||||
self._fk_val = self.related_val[0]
|
||||
if None in self.related_val:
|
||||
raise ValueError('"%r" needs to have a value for field "%s" before '
|
||||
'this many-to-many relationship can be used.' %
|
||||
|
@ -517,18 +515,6 @@ def create_many_related_manager(superclass, rel):
|
|||
"a many-to-many relationship can be used." %
|
||||
instance.__class__.__name__)
|
||||
|
||||
def _get_fk_val(self, obj, field_name):
|
||||
"""
|
||||
Returns the correct value for this relationship's foreign key. This
|
||||
might be something else than pk value when to_field is used.
|
||||
"""
|
||||
fk = self.through._meta.get_field(field_name)
|
||||
if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
|
||||
attname = fk.rel.get_related_field().get_attname()
|
||||
return fk.get_prep_lookup('exact', getattr(obj, attname))
|
||||
else:
|
||||
return obj.pk
|
||||
|
||||
def get_queryset(self):
|
||||
try:
|
||||
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
|
||||
|
@ -626,11 +612,12 @@ def create_many_related_manager(superclass, rel):
|
|||
if not router.allow_relation(obj, self.instance):
|
||||
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
|
||||
(obj, self.instance._state.db, obj._state.db))
|
||||
fk_val = self._get_fk_val(obj, target_field_name)
|
||||
fk_val = self.through._meta.get_field(
|
||||
target_field_name).get_foreign_related_value(obj)[0]
|
||||
if fk_val is None:
|
||||
raise ValueError('Cannot add "%r": the value for field "%s" is None' %
|
||||
(obj, target_field_name))
|
||||
new_ids.add(self._get_fk_val(obj, target_field_name))
|
||||
new_ids.add(fk_val)
|
||||
elif isinstance(obj, Model):
|
||||
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
|
||||
else:
|
||||
|
@ -638,7 +625,7 @@ def create_many_related_manager(superclass, rel):
|
|||
db = router.db_for_write(self.through, instance=self.instance)
|
||||
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
|
||||
vals = vals.filter(**{
|
||||
source_field_name: self._fk_val,
|
||||
source_field_name: self.related_val[0],
|
||||
'%s__in' % target_field_name: new_ids,
|
||||
})
|
||||
new_ids = new_ids - set(vals)
|
||||
|
@ -652,7 +639,7 @@ def create_many_related_manager(superclass, rel):
|
|||
# Add the ones that aren't there already
|
||||
self.through._default_manager.using(db).bulk_create([
|
||||
self.through(**{
|
||||
'%s_id' % source_field_name: self._fk_val,
|
||||
'%s_id' % source_field_name: self.related_val[0],
|
||||
'%s_id' % target_field_name: obj_id,
|
||||
})
|
||||
for obj_id in new_ids
|
||||
|
@ -676,7 +663,9 @@ def create_many_related_manager(superclass, rel):
|
|||
old_ids = set()
|
||||
for obj in objs:
|
||||
if isinstance(obj, self.model):
|
||||
old_ids.add(self._get_fk_val(obj, target_field_name))
|
||||
fk_val = self.through._meta.get_field(
|
||||
target_field_name).get_foreign_related_value(obj)[0]
|
||||
old_ids.add(fk_val)
|
||||
else:
|
||||
old_ids.add(obj)
|
||||
# Work out what DB we're operating on
|
||||
|
@ -690,7 +679,7 @@ def create_many_related_manager(superclass, rel):
|
|||
model=self.model, pk_set=old_ids, using=db)
|
||||
# Remove the specified objects from the join table
|
||||
self.through._default_manager.using(db).filter(**{
|
||||
source_field_name: self._fk_val,
|
||||
source_field_name: self.related_val[0],
|
||||
'%s__in' % target_field_name: old_ids
|
||||
}).delete()
|
||||
if self.reverse or source_field_name == self.source_field_name:
|
||||
|
@ -994,10 +983,13 @@ class ForeignObject(RelatedField):
|
|||
# Gotcha: in some cases (like fixture loading) a model can have
|
||||
# different values in parent_ptr_id and parent's id. So, use
|
||||
# instance.pk (that is, parent_ptr_id) when asked for instance.id.
|
||||
opts = instance._meta
|
||||
if field.primary_key:
|
||||
ret.append(instance.pk)
|
||||
else:
|
||||
ret.append(getattr(instance, field.attname))
|
||||
possible_parent_link = opts.get_ancestor_link(field.model)
|
||||
if not possible_parent_link or possible_parent_link.primary_key:
|
||||
ret.append(instance.pk)
|
||||
continue
|
||||
ret.append(getattr(instance, field.attname))
|
||||
return tuple(ret)
|
||||
|
||||
def get_attname_column(self):
|
||||
|
|
|
@ -162,3 +162,9 @@ class Mixin(object):
|
|||
|
||||
class MixinModel(models.Model, Mixin):
|
||||
pass
|
||||
|
||||
class Base(models.Model):
|
||||
titles = models.ManyToManyField(Title)
|
||||
|
||||
class SubBase(Base):
|
||||
sub_id = models.IntegerField(primary_key=True)
|
||||
|
|
|
@ -10,7 +10,8 @@ from django.utils import six
|
|||
|
||||
from .models import (
|
||||
Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post,
|
||||
Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel)
|
||||
Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel,
|
||||
Title, Base, SubBase)
|
||||
|
||||
|
||||
class ModelInheritanceTests(TestCase):
|
||||
|
@ -357,3 +358,16 @@ class ModelInheritanceTests(TestCase):
|
|||
[Place.objects.get(pk=s.pk)],
|
||||
lambda x: x
|
||||
)
|
||||
|
||||
def test_custompk_m2m(self):
|
||||
b = Base.objects.create()
|
||||
b.titles.add(Title.objects.create(title="foof"))
|
||||
s = SubBase.objects.create(sub_id=b.id)
|
||||
b = Base.objects.get(pk=s.id)
|
||||
self.assertNotEqual(b.pk, s.pk)
|
||||
# Low-level test for related_val
|
||||
self.assertEqual(s.titles.related_val, (s.id,))
|
||||
# Higher level test for correct query values (title foof not
|
||||
# accidentally found).
|
||||
self.assertQuerysetEqual(
|
||||
s.titles.all(), [])
|
||||
|
|
Loading…
Reference in New Issue