Fixed #18823 -- Ensured m2m.clear() works when using through+to_field
There was a potential data-loss issue involved -- when clearing instance's m2m assignments it was possible some other instance's m2m data was deleted instead. This commit also improved None handling for to_field cases.
This commit is contained in:
parent
98032f67c7
commit
611c4d6f1c
|
@ -573,9 +573,31 @@ def create_many_related_manager(superclass, rel):
|
||||||
self.reverse = reverse
|
self.reverse = reverse
|
||||||
self.through = through
|
self.through = through
|
||||||
self.prefetch_cache_name = prefetch_cache_name
|
self.prefetch_cache_name = prefetch_cache_name
|
||||||
self._pk_val = self.instance.pk
|
self._fk_val = self._get_fk_val(instance, source_field_name)
|
||||||
if self._pk_val is None:
|
if self._fk_val is None:
|
||||||
raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__)
|
raise ValueError('"%r" needs to have a value for field "%s" before '
|
||||||
|
'this many-to-many relationship can be used.' %
|
||||||
|
(instance, source_field_name))
|
||||||
|
# Even if this relation is not to pk, we require still pk value.
|
||||||
|
# The wish is that the instance has been already saved to DB,
|
||||||
|
# although having a pk value isn't a guarantee of that.
|
||||||
|
if instance.pk is None:
|
||||||
|
raise ValueError("%r instance needs to have a primary key value before "
|
||||||
|
"a many-to-many relationship can be used." %
|
||||||
|
instance.__class__.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fk_val(self, obj, field_name):
|
||||||
|
"""
|
||||||
|
Returns the correct value for this relationship's foreign key. This
|
||||||
|
might be something else than pk value when to_field is used.
|
||||||
|
"""
|
||||||
|
fk = self.through._meta.get_field(field_name)
|
||||||
|
if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
|
||||||
|
attname = fk.rel.get_related_field().get_attname()
|
||||||
|
return fk.get_prep_lookup('exact', getattr(obj, attname))
|
||||||
|
else:
|
||||||
|
return obj.pk
|
||||||
|
|
||||||
def get_query_set(self):
|
def get_query_set(self):
|
||||||
try:
|
try:
|
||||||
|
@ -677,7 +699,11 @@ def create_many_related_manager(superclass, rel):
|
||||||
if not router.allow_relation(obj, self.instance):
|
if not router.allow_relation(obj, self.instance):
|
||||||
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
|
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
|
||||||
(obj, self.instance._state.db, obj._state.db))
|
(obj, self.instance._state.db, obj._state.db))
|
||||||
new_ids.add(obj.pk)
|
fk_val = self._get_fk_val(obj, target_field_name)
|
||||||
|
if fk_val is None:
|
||||||
|
raise ValueError('Cannot add "%r": the value for field "%s" is None' %
|
||||||
|
(obj, target_field_name))
|
||||||
|
new_ids.add(self._get_fk_val(obj, target_field_name))
|
||||||
elif isinstance(obj, Model):
|
elif isinstance(obj, Model):
|
||||||
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
|
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
|
||||||
else:
|
else:
|
||||||
|
@ -685,7 +711,7 @@ def create_many_related_manager(superclass, rel):
|
||||||
db = router.db_for_write(self.through, instance=self.instance)
|
db = router.db_for_write(self.through, instance=self.instance)
|
||||||
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
|
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
|
||||||
vals = vals.filter(**{
|
vals = vals.filter(**{
|
||||||
source_field_name: self._pk_val,
|
source_field_name: self._fk_val,
|
||||||
'%s__in' % target_field_name: new_ids,
|
'%s__in' % target_field_name: new_ids,
|
||||||
})
|
})
|
||||||
new_ids = new_ids - set(vals)
|
new_ids = new_ids - set(vals)
|
||||||
|
@ -699,11 +725,12 @@ def create_many_related_manager(superclass, rel):
|
||||||
# Add the ones that aren't there already
|
# Add the ones that aren't there already
|
||||||
self.through._default_manager.using(db).bulk_create([
|
self.through._default_manager.using(db).bulk_create([
|
||||||
self.through(**{
|
self.through(**{
|
||||||
'%s_id' % source_field_name: self._pk_val,
|
'%s_id' % source_field_name: self._fk_val,
|
||||||
'%s_id' % target_field_name: obj_id,
|
'%s_id' % target_field_name: obj_id,
|
||||||
})
|
})
|
||||||
for obj_id in new_ids
|
for obj_id in new_ids
|
||||||
])
|
])
|
||||||
|
|
||||||
if self.reverse or source_field_name == self.source_field_name:
|
if self.reverse or source_field_name == self.source_field_name:
|
||||||
# Don't send the signal when we are inserting the
|
# Don't send the signal when we are inserting the
|
||||||
# duplicate data row for symmetrical reverse entries.
|
# duplicate data row for symmetrical reverse entries.
|
||||||
|
@ -722,7 +749,7 @@ def create_many_related_manager(superclass, rel):
|
||||||
old_ids = set()
|
old_ids = set()
|
||||||
for obj in objs:
|
for obj in objs:
|
||||||
if isinstance(obj, self.model):
|
if isinstance(obj, self.model):
|
||||||
old_ids.add(obj.pk)
|
old_ids.add(self._get_fk_val(obj, target_field_name))
|
||||||
else:
|
else:
|
||||||
old_ids.add(obj)
|
old_ids.add(obj)
|
||||||
# Work out what DB we're operating on
|
# Work out what DB we're operating on
|
||||||
|
@ -736,7 +763,7 @@ def create_many_related_manager(superclass, rel):
|
||||||
model=self.model, pk_set=old_ids, using=db)
|
model=self.model, pk_set=old_ids, using=db)
|
||||||
# Remove the specified objects from the join table
|
# Remove the specified objects from the join table
|
||||||
self.through._default_manager.using(db).filter(**{
|
self.through._default_manager.using(db).filter(**{
|
||||||
source_field_name: self._pk_val,
|
source_field_name: self._fk_val,
|
||||||
'%s__in' % target_field_name: old_ids
|
'%s__in' % target_field_name: old_ids
|
||||||
}).delete()
|
}).delete()
|
||||||
if self.reverse or source_field_name == self.source_field_name:
|
if self.reverse or source_field_name == self.source_field_name:
|
||||||
|
@ -756,7 +783,7 @@ def create_many_related_manager(superclass, rel):
|
||||||
instance=self.instance, reverse=self.reverse,
|
instance=self.instance, reverse=self.reverse,
|
||||||
model=self.model, pk_set=None, using=db)
|
model=self.model, pk_set=None, using=db)
|
||||||
self.through._default_manager.using(db).filter(**{
|
self.through._default_manager.using(db).filter(**{
|
||||||
source_field_name: self._pk_val
|
source_field_name: self._fk_val
|
||||||
}).delete()
|
}).delete()
|
||||||
if self.reverse or source_field_name == self.source_field_name:
|
if self.reverse or source_field_name == self.source_field_name:
|
||||||
# Don't send the signal when we are clearing the
|
# Don't send the signal when we are clearing the
|
||||||
|
|
|
@ -62,18 +62,18 @@ class B(models.Model):
|
||||||
# Using to_field on the through model
|
# Using to_field on the through model
|
||||||
@python_2_unicode_compatible
|
@python_2_unicode_compatible
|
||||||
class Car(models.Model):
|
class Car(models.Model):
|
||||||
make = models.CharField(max_length=20, unique=True)
|
make = models.CharField(max_length=20, unique=True, null=True)
|
||||||
drivers = models.ManyToManyField('Driver', through='CarDriver')
|
drivers = models.ManyToManyField('Driver', through='CarDriver')
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.make
|
return "%s" % self.make
|
||||||
|
|
||||||
@python_2_unicode_compatible
|
@python_2_unicode_compatible
|
||||||
class Driver(models.Model):
|
class Driver(models.Model):
|
||||||
name = models.CharField(max_length=20, unique=True)
|
name = models.CharField(max_length=20, unique=True, null=True)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return "%s" % self.name
|
||||||
|
|
||||||
@python_2_unicode_compatible
|
@python_2_unicode_compatible
|
||||||
class CarDriver(models.Model):
|
class CarDriver(models.Model):
|
||||||
|
|
|
@ -123,18 +123,104 @@ class ToFieldThroughTests(TestCase):
|
||||||
self.car = Car.objects.create(make="Toyota")
|
self.car = Car.objects.create(make="Toyota")
|
||||||
self.driver = Driver.objects.create(name="Ryan Briscoe")
|
self.driver = Driver.objects.create(name="Ryan Briscoe")
|
||||||
CarDriver.objects.create(car=self.car, driver=self.driver)
|
CarDriver.objects.create(car=self.car, driver=self.driver)
|
||||||
|
# We are testing if wrong objects get deleted due to using wrong
|
||||||
|
# field value in m2m queries. So, it is essential that the pk
|
||||||
|
# numberings do not match.
|
||||||
|
# Create one intentionally unused driver to mix up the autonumbering
|
||||||
|
self.unused_driver = Driver.objects.create(name="Barney Gumble")
|
||||||
|
# And two intentionally unused cars.
|
||||||
|
self.unused_car1 = Car.objects.create(make="Trabant")
|
||||||
|
self.unused_car2 = Car.objects.create(make="Wartburg")
|
||||||
|
|
||||||
def test_to_field(self):
|
def test_to_field(self):
|
||||||
self.assertQuerysetEqual(
|
self.assertQuerysetEqual(
|
||||||
self.car.drivers.all(),
|
self.car.drivers.all(),
|
||||||
["<Driver: Ryan Briscoe>"]
|
["<Driver: Ryan Briscoe>"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_to_field_reverse(self):
|
def test_to_field_reverse(self):
|
||||||
self.assertQuerysetEqual(
|
self.assertQuerysetEqual(
|
||||||
self.driver.car_set.all(),
|
self.driver.car_set.all(),
|
||||||
["<Car: Toyota>"]
|
["<Car: Toyota>"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_to_field_clear_reverse(self):
|
||||||
|
self.driver.car_set.clear()
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.driver.car_set.all(),[])
|
||||||
|
|
||||||
|
def test_to_field_clear(self):
|
||||||
|
self.car.drivers.clear()
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.car.drivers.all(),[])
|
||||||
|
|
||||||
|
# Low level tests for _add_items and _remove_items. We test these methods
|
||||||
|
# because .add/.remove aren't available for m2m fields with through, but
|
||||||
|
# through is the only way to set to_field currently. We do want to make
|
||||||
|
# sure these methods are ready if the ability to use .add or .remove with
|
||||||
|
# to_field relations is added some day.
|
||||||
|
def test_add(self):
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.car.drivers.all(),
|
||||||
|
["<Driver: Ryan Briscoe>"]
|
||||||
|
)
|
||||||
|
# Yikes - barney is going to drive...
|
||||||
|
self.car.drivers._add_items('car', 'driver', self.unused_driver)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.car.drivers.all(),
|
||||||
|
["<Driver: Ryan Briscoe>", "<Driver: Barney Gumble>"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_add_null(self):
|
||||||
|
nullcar = Car.objects.create(make=None)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
nullcar.drivers._add_items('car', 'driver', self.unused_driver)
|
||||||
|
|
||||||
|
def test_add_related_null(self):
|
||||||
|
nulldriver = Driver.objects.create(name=None)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.car.drivers._add_items('car', 'driver', nulldriver)
|
||||||
|
|
||||||
|
def test_add_reverse(self):
|
||||||
|
car2 = Car.objects.create(make="Honda")
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.driver.car_set.all(),
|
||||||
|
["<Car: Toyota>"]
|
||||||
|
)
|
||||||
|
self.driver.car_set._add_items('driver', 'car', car2)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.driver.car_set.all(),
|
||||||
|
["<Car: Toyota>", "<Car: Honda>"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_add_null_reverse(self):
|
||||||
|
nullcar = Car.objects.create(make=None)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.driver.car_set._add_items('driver', 'car', nullcar)
|
||||||
|
|
||||||
|
def test_add_null_reverse_related(self):
|
||||||
|
nulldriver = Driver.objects.create(name=None)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
nulldriver.car_set._add_items('driver', 'car', self.car)
|
||||||
|
|
||||||
|
def test_remove(self):
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.car.drivers.all(),
|
||||||
|
["<Driver: Ryan Briscoe>"]
|
||||||
|
)
|
||||||
|
self.car.drivers._remove_items('car', 'driver', self.driver)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.car.drivers.all(),[])
|
||||||
|
|
||||||
|
def test_remove_reverse(self):
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.driver.car_set.all(),
|
||||||
|
["<Car: Toyota>"]
|
||||||
|
)
|
||||||
|
self.driver.car_set._remove_items('driver', 'car', self.car)
|
||||||
|
self.assertQuerysetEqual(
|
||||||
|
self.driver.car_set.all(),[])
|
||||||
|
|
||||||
|
|
||||||
class ThroughLoadDataTestCase(TestCase):
|
class ThroughLoadDataTestCase(TestCase):
|
||||||
fixtures = ["m2m_through"]
|
fixtures = ["m2m_through"]
|
||||||
|
|
Loading…
Reference in New Issue