diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 52cb91d3a8..2b426c37a5 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -1051,10 +1051,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): return obj, created update_or_create.alters_data = True - def _get_missing_target_ids(self, source_field_name, target_field_name, db, objs): + def _get_target_ids(self, target_field_name, objs): """ - Return the subset of ids of `objs` that aren't already assigned to - this relationship. + Return the set of ids of `objs` that the target field references. """ from django.db.models import Model target_ids = set() @@ -1081,6 +1080,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): ) else: target_ids.add(obj) + return target_ids + + def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids): + """ + Return the subset of ids of `objs` that aren't already assigned to + this relationship. + """ vals = self.through._default_manager.using(db).values_list( target_field_name, flat=True ).filter(**{ @@ -1089,6 +1095,35 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): }) return target_ids.difference(vals) + def _get_add_plan(self, db, source_field_name): + """ + Return a boolean triple of the way the add should be performed. + + The first element is whether or not bulk_create(ignore_conflicts) + can be used, the second whether or not signals must be sent, and + the third element is whether or not the immediate bulk insertion + with conflicts ignored can be performed. + """ + # Conflicts can be ignored when the intermediary model is + # auto-created as the only possible collision is on the + # (source_id, target_id) tuple. The same assertion doesn't hold for + # user-defined intermediary models as they could have other fields + # causing conflicts which must be surfaced. + can_ignore_conflicts = ( + connections[db].features.supports_ignore_conflicts and + self.through._meta.auto_created is not False + ) + # Don't send the signal when inserting duplicate data row + # for symmetrical reverse entries. + must_send_signals = (self.reverse or source_field_name == self.source_field_name) and ( + signals.m2m_changed.has_listeners(self.through) + ) + # Fast addition through bulk insertion can only be performed + # if no m2m_changed listeners are connected for self.through + # as they require the added set of ids to be provided via + # pk_set. + return can_ignore_conflicts, must_send_signals, (can_ignore_conflicts and not must_send_signals) + def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None): # source_field_name: the PK fieldname in join table for the source object # target_field_name: the PK fieldname in join table for the target object @@ -1097,37 +1132,40 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): # If there aren't any objects, there is nothing to do. if objs: + target_ids = self._get_target_ids(target_field_name, objs) db = router.db_for_write(self.through, instance=self.instance) - missing_target_ids = self._get_missing_target_ids(source_field_name, target_field_name, db, objs) + can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name) + if can_fast_add: + self.through._default_manager.using(db).bulk_create([ + self.through(**{ + '%s_id' % source_field_name: self.related_val[0], + '%s_id' % target_field_name: target_id, + }) + for target_id in target_ids + ], ignore_conflicts=True) + return + missing_target_ids = self._get_missing_target_ids( + source_field_name, target_field_name, db, target_ids + ) with transaction.atomic(using=db, savepoint=False): - if self.reverse or source_field_name == self.source_field_name: - # Don't send the signal when we are inserting the - # duplicate data row for symmetrical reverse entries. + if must_send_signals: signals.m2m_changed.send( sender=self.through, action='pre_add', instance=self.instance, reverse=self.reverse, model=self.model, pk_set=missing_target_ids, using=db, ) - # Add the ones that aren't there already. Conflicts can be - # ignored when the intermediary model is auto-created as - # the only possible collision is on the (sid_id, tid_id) - # tuple. The same assertion doesn't hold for user-defined - # intermediary models as they could have other fields - # causing conflicts which must be surfaced. - ignore_conflicts = self.through._meta.auto_created is not False + # Add the ones that aren't there already. self.through._default_manager.using(db).bulk_create([ self.through(**through_defaults, **{ '%s_id' % source_field_name: self.related_val[0], '%s_id' % target_field_name: target_id, }) for target_id in missing_target_ids - ], ignore_conflicts=ignore_conflicts) + ], ignore_conflicts=can_ignore_conflicts) - if self.reverse or source_field_name == self.source_field_name: - # Don't send the signal when we are inserting the - # duplicate data row for symmetrical reverse entries. + if must_send_signals: signals.m2m_changed.send( sender=self.through, action='post_add', instance=self.instance, reverse=self.reverse, diff --git a/tests/many_to_many/tests.py b/tests/many_to_many/tests.py index adde2ac563..098cd29e46 100644 --- a/tests/many_to_many/tests.py +++ b/tests/many_to_many/tests.py @@ -118,13 +118,26 @@ class ManyToManyTests(TestCase): ) @skipUnlessDBFeature('supports_ignore_conflicts') - def test_add_ignore_conflicts(self): + def test_fast_add_ignore_conflicts(self): + """ + A single query is necessary to add auto-created through instances if + the database backend supports bulk_create(ignore_conflicts) and no + m2m_changed signals receivers are connected. + """ + with self.assertNumQueries(1): + self.a1.publications.add(self.p1, self.p2) + + @skipUnlessDBFeature('supports_ignore_conflicts') + def test_slow_add_ignore_conflicts(self): manager_cls = self.a1.publications.__class__ # Simulate a race condition between the missing ids retrieval and # the bulk insertion attempt. missing_target_ids = {self.p1.id} + # Disable fast-add to test the case where the slow add path is taken. + add_plan = (True, False, False) with mock.patch.object(manager_cls, '_get_missing_target_ids', return_value=missing_target_ids) as mocked: - self.a1.publications.add(self.p1) + with mock.patch.object(manager_cls, '_get_add_plan', return_value=add_plan): + self.a1.publications.add(self.p1) mocked.assert_called_once() def test_related_sets(self):