diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index dd9fef34d52..80c62f85c47 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -573,9 +573,31 @@ def create_many_related_manager(superclass, rel): self.reverse = reverse self.through = through self.prefetch_cache_name = prefetch_cache_name - self._pk_val = self.instance.pk - if self._pk_val is None: - raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__) + self._fk_val = self._get_fk_val(instance, source_field_name) + if self._fk_val is None: + raise ValueError('"%r" needs to have a value for field "%s" before ' + 'this many-to-many relationship can be used.' % + (instance, source_field_name)) + # Even if this relation is not to pk, we require still pk value. + # The wish is that the instance has been already saved to DB, + # although having a pk value isn't a guarantee of that. + if instance.pk is None: + raise ValueError("%r instance needs to have a primary key value before " + "a many-to-many relationship can be used." % + instance.__class__.__name__) + + + def _get_fk_val(self, obj, field_name): + """ + Returns the correct value for this relationship's foreign key. This + might be something else than pk value when to_field is used. + """ + fk = self.through._meta.get_field(field_name) + if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname: + attname = fk.rel.get_related_field().get_attname() + return fk.get_prep_lookup('exact', getattr(obj, attname)) + else: + return obj.pk def get_query_set(self): try: @@ -677,7 +699,11 @@ def create_many_related_manager(superclass, rel): if not router.allow_relation(obj, self.instance): raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' % (obj, self.instance._state.db, obj._state.db)) - new_ids.add(obj.pk) + fk_val = self._get_fk_val(obj, target_field_name) + if fk_val is None: + raise ValueError('Cannot add "%r": the value for field "%s" is None' % + (obj, target_field_name)) + new_ids.add(self._get_fk_val(obj, target_field_name)) elif isinstance(obj, Model): raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) else: @@ -685,7 +711,7 @@ def create_many_related_manager(superclass, rel): db = router.db_for_write(self.through, instance=self.instance) vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True) vals = vals.filter(**{ - source_field_name: self._pk_val, + source_field_name: self._fk_val, '%s__in' % target_field_name: new_ids, }) new_ids = new_ids - set(vals) @@ -699,11 +725,12 @@ def create_many_related_manager(superclass, rel): # Add the ones that aren't there already self.through._default_manager.using(db).bulk_create([ self.through(**{ - '%s_id' % source_field_name: self._pk_val, + '%s_id' % source_field_name: self._fk_val, '%s_id' % target_field_name: obj_id, }) for obj_id in new_ids ]) + if self.reverse or source_field_name == self.source_field_name: # Don't send the signal when we are inserting the # duplicate data row for symmetrical reverse entries. @@ -722,7 +749,7 @@ def create_many_related_manager(superclass, rel): old_ids = set() for obj in objs: if isinstance(obj, self.model): - old_ids.add(obj.pk) + old_ids.add(self._get_fk_val(obj, target_field_name)) else: old_ids.add(obj) # Work out what DB we're operating on @@ -736,7 +763,7 @@ def create_many_related_manager(superclass, rel): model=self.model, pk_set=old_ids, using=db) # Remove the specified objects from the join table self.through._default_manager.using(db).filter(**{ - source_field_name: self._pk_val, + source_field_name: self._fk_val, '%s__in' % target_field_name: old_ids }).delete() if self.reverse or source_field_name == self.source_field_name: @@ -756,7 +783,7 @@ def create_many_related_manager(superclass, rel): instance=self.instance, reverse=self.reverse, model=self.model, pk_set=None, using=db) self.through._default_manager.using(db).filter(**{ - source_field_name: self._pk_val + source_field_name: self._fk_val }).delete() if self.reverse or source_field_name == self.source_field_name: # Don't send the signal when we are clearing the diff --git a/tests/regressiontests/m2m_through_regress/models.py b/tests/regressiontests/m2m_through_regress/models.py index 47c24ed5b2c..23d3366f228 100644 --- a/tests/regressiontests/m2m_through_regress/models.py +++ b/tests/regressiontests/m2m_through_regress/models.py @@ -62,18 +62,18 @@ class B(models.Model): # Using to_field on the through model @python_2_unicode_compatible class Car(models.Model): - make = models.CharField(max_length=20, unique=True) + make = models.CharField(max_length=20, unique=True, null=True) drivers = models.ManyToManyField('Driver', through='CarDriver') def __str__(self): - return self.make + return "%s" % self.make @python_2_unicode_compatible class Driver(models.Model): - name = models.CharField(max_length=20, unique=True) + name = models.CharField(max_length=20, unique=True, null=True) def __str__(self): - return self.name + return "%s" % self.name @python_2_unicode_compatible class CarDriver(models.Model): diff --git a/tests/regressiontests/m2m_through_regress/tests.py b/tests/regressiontests/m2m_through_regress/tests.py index 458c194f891..828ec3618c8 100644 --- a/tests/regressiontests/m2m_through_regress/tests.py +++ b/tests/regressiontests/m2m_through_regress/tests.py @@ -123,18 +123,104 @@ class ToFieldThroughTests(TestCase): self.car = Car.objects.create(make="Toyota") self.driver = Driver.objects.create(name="Ryan Briscoe") CarDriver.objects.create(car=self.car, driver=self.driver) + # We are testing if wrong objects get deleted due to using wrong + # field value in m2m queries. So, it is essential that the pk + # numberings do not match. + # Create one intentionally unused driver to mix up the autonumbering + self.unused_driver = Driver.objects.create(name="Barney Gumble") + # And two intentionally unused cars. + self.unused_car1 = Car.objects.create(make="Trabant") + self.unused_car2 = Car.objects.create(make="Wartburg") def test_to_field(self): self.assertQuerysetEqual( self.car.drivers.all(), [""] - ) + ) def test_to_field_reverse(self): self.assertQuerysetEqual( self.driver.car_set.all(), [""] - ) + ) + + def test_to_field_clear_reverse(self): + self.driver.car_set.clear() + self.assertQuerysetEqual( + self.driver.car_set.all(),[]) + + def test_to_field_clear(self): + self.car.drivers.clear() + self.assertQuerysetEqual( + self.car.drivers.all(),[]) + + # Low level tests for _add_items and _remove_items. We test these methods + # because .add/.remove aren't available for m2m fields with through, but + # through is the only way to set to_field currently. We do want to make + # sure these methods are ready if the ability to use .add or .remove with + # to_field relations is added some day. + def test_add(self): + self.assertQuerysetEqual( + self.car.drivers.all(), + [""] + ) + # Yikes - barney is going to drive... + self.car.drivers._add_items('car', 'driver', self.unused_driver) + self.assertQuerysetEqual( + self.car.drivers.all(), + ["", ""] + ) + + def test_add_null(self): + nullcar = Car.objects.create(make=None) + with self.assertRaises(ValueError): + nullcar.drivers._add_items('car', 'driver', self.unused_driver) + + def test_add_related_null(self): + nulldriver = Driver.objects.create(name=None) + with self.assertRaises(ValueError): + self.car.drivers._add_items('car', 'driver', nulldriver) + + def test_add_reverse(self): + car2 = Car.objects.create(make="Honda") + self.assertQuerysetEqual( + self.driver.car_set.all(), + [""] + ) + self.driver.car_set._add_items('driver', 'car', car2) + self.assertQuerysetEqual( + self.driver.car_set.all(), + ["", ""] + ) + + def test_add_null_reverse(self): + nullcar = Car.objects.create(make=None) + with self.assertRaises(ValueError): + self.driver.car_set._add_items('driver', 'car', nullcar) + + def test_add_null_reverse_related(self): + nulldriver = Driver.objects.create(name=None) + with self.assertRaises(ValueError): + nulldriver.car_set._add_items('driver', 'car', self.car) + + def test_remove(self): + self.assertQuerysetEqual( + self.car.drivers.all(), + [""] + ) + self.car.drivers._remove_items('car', 'driver', self.driver) + self.assertQuerysetEqual( + self.car.drivers.all(),[]) + + def test_remove_reverse(self): + self.assertQuerysetEqual( + self.driver.car_set.all(), + [""] + ) + self.driver.car_set._remove_items('driver', 'car', self.car) + self.assertQuerysetEqual( + self.driver.car_set.all(),[]) + class ThroughLoadDataTestCase(TestCase): fixtures = ["m2m_through"]