Refs #19544 -- Added a fast path for through additions if supported.

The single query insertion path is taken if the backend supports inserts
that ignore conflicts and m2m_changed signals don't have to be sent.
This commit is contained in:
Simon Charette 2019-02-15 22:02:33 -05:00 committed by Tim Graham
parent 28712d8acf
commit de7f6b51b2
2 changed files with 71 additions and 20 deletions

View File

@ -1051,10 +1051,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
return obj, created return obj, created
update_or_create.alters_data = True update_or_create.alters_data = True
def _get_missing_target_ids(self, source_field_name, target_field_name, db, objs): def _get_target_ids(self, target_field_name, objs):
""" """
Return the subset of ids of `objs` that aren't already assigned to Return the set of ids of `objs` that the target field references.
this relationship.
""" """
from django.db.models import Model from django.db.models import Model
target_ids = set() target_ids = set()
@ -1081,6 +1080,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
) )
else: else:
target_ids.add(obj) target_ids.add(obj)
return target_ids
def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids):
"""
Return the subset of ids of `objs` that aren't already assigned to
this relationship.
"""
vals = self.through._default_manager.using(db).values_list( vals = self.through._default_manager.using(db).values_list(
target_field_name, flat=True target_field_name, flat=True
).filter(**{ ).filter(**{
@ -1089,6 +1095,35 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
}) })
return target_ids.difference(vals) return target_ids.difference(vals)
def _get_add_plan(self, db, source_field_name):
    """
    Work out how an addition to this relation should be carried out.

    Return a triple of flags, in this order:

    * whether bulk_create(ignore_conflicts) can be used,
    * whether m2m_changed signals must be sent,
    * whether the rows can be inserted right away in a single bulk
      insertion with conflicts ignored (the "fast add" path).
    """
    # An auto-created intermediary model can only ever collide on its
    # (source_id, target_id) pair, so such conflicts are safe to ignore.
    # A user-defined intermediary model may carry other fields whose
    # conflicts have to be surfaced, so they can't be ignored there.
    backend_features = connections[db].features
    can_ignore_conflicts = (
        backend_features.supports_ignore_conflicts and
        self.through._meta.auto_created is not False
    )

    # The duplicate data row inserted for symmetrical reverse entries
    # must not trigger a signal; listeners are only looked up when the
    # addition is one that would be signalled at all.
    is_signalled_addition = self.reverse or source_field_name == self.source_field_name
    must_send_signals = is_signalled_addition and signals.m2m_changed.has_listeners(self.through)

    # m2m_changed receivers need the set of ids that were actually added
    # (passed as pk_set), which a blind bulk insertion can't provide, so
    # the fast path is only available when nothing listens on
    # self.through.
    can_fast_add = can_ignore_conflicts and not must_send_signals
    return can_ignore_conflicts, must_send_signals, can_fast_add
def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None): def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
# source_field_name: the PK fieldname in join table for the source object # source_field_name: the PK fieldname in join table for the source object
# target_field_name: the PK fieldname in join table for the target object # target_field_name: the PK fieldname in join table for the target object
@ -1097,37 +1132,40 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
# If there aren't any objects, there is nothing to do. # If there aren't any objects, there is nothing to do.
if objs: if objs:
target_ids = self._get_target_ids(target_field_name, objs)
db = router.db_for_write(self.through, instance=self.instance) db = router.db_for_write(self.through, instance=self.instance)
missing_target_ids = self._get_missing_target_ids(source_field_name, target_field_name, db, objs) can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name)
if can_fast_add:
self.through._default_manager.using(db).bulk_create([
self.through(**{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in target_ids
], ignore_conflicts=True)
return
missing_target_ids = self._get_missing_target_ids(
source_field_name, target_field_name, db, target_ids
)
with transaction.atomic(using=db, savepoint=False): with transaction.atomic(using=db, savepoint=False):
if self.reverse or source_field_name == self.source_field_name: if must_send_signals:
# Don't send the signal when we are inserting the
# duplicate data row for symmetrical reverse entries.
signals.m2m_changed.send( signals.m2m_changed.send(
sender=self.through, action='pre_add', sender=self.through, action='pre_add',
instance=self.instance, reverse=self.reverse, instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=missing_target_ids, using=db, model=self.model, pk_set=missing_target_ids, using=db,
) )
# Add the ones that aren't there already. Conflicts can be # Add the ones that aren't there already.
# ignored when the intermediary model is auto-created as
# the only possible collision is on the (sid_id, tid_id)
# tuple. The same assertion doesn't hold for user-defined
# intermediary models as they could have other fields
# causing conflicts which must be surfaced.
ignore_conflicts = self.through._meta.auto_created is not False
self.through._default_manager.using(db).bulk_create([ self.through._default_manager.using(db).bulk_create([
self.through(**through_defaults, **{ self.through(**through_defaults, **{
'%s_id' % source_field_name: self.related_val[0], '%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id, '%s_id' % target_field_name: target_id,
}) })
for target_id in missing_target_ids for target_id in missing_target_ids
], ignore_conflicts=ignore_conflicts) ], ignore_conflicts=can_ignore_conflicts)
if self.reverse or source_field_name == self.source_field_name: if must_send_signals:
# Don't send the signal when we are inserting the
# duplicate data row for symmetrical reverse entries.
signals.m2m_changed.send( signals.m2m_changed.send(
sender=self.through, action='post_add', sender=self.through, action='post_add',
instance=self.instance, reverse=self.reverse, instance=self.instance, reverse=self.reverse,

View File

@ -118,13 +118,26 @@ class ManyToManyTests(TestCase):
) )
@skipUnlessDBFeature('supports_ignore_conflicts') @skipUnlessDBFeature('supports_ignore_conflicts')
def test_add_ignore_conflicts(self): def test_fast_add_ignore_conflicts(self):
"""
A single query is necessary to add auto-created through instances if
the database backend supports bulk_create(ignore_conflicts) and no
m2m_changed signals receivers are connected.
"""
with self.assertNumQueries(1):
self.a1.publications.add(self.p1, self.p2)
@skipUnlessDBFeature('supports_ignore_conflicts')
def test_slow_add_ignore_conflicts(self):
manager_cls = self.a1.publications.__class__ manager_cls = self.a1.publications.__class__
# Simulate a race condition between the missing ids retrieval and # Simulate a race condition between the missing ids retrieval and
# the bulk insertion attempt. # the bulk insertion attempt.
missing_target_ids = {self.p1.id} missing_target_ids = {self.p1.id}
# Disable fast-add to test the case where the slow add path is taken.
add_plan = (True, False, False)
with mock.patch.object(manager_cls, '_get_missing_target_ids', return_value=missing_target_ids) as mocked: with mock.patch.object(manager_cls, '_get_missing_target_ids', return_value=missing_target_ids) as mocked:
self.a1.publications.add(self.p1) with mock.patch.object(manager_cls, '_get_add_plan', return_value=add_plan):
self.a1.publications.add(self.p1)
mocked.assert_called_once() mocked.assert_called_once()
def test_related_sets(self): def test_related_sets(self):