Simplified RelatedManager._add_items() a bit.

Added early return in RelatedManager._add_items() to decrease an
indentation level.
This commit is contained in:
Baptiste Mispelon 2019-11-29 18:19:29 +01:00 committed by Mariusz Felisiak
parent 6c0341f127
commit c50839fccf
1 changed files with 39 additions and 40 deletions

View File

@ -1113,49 +1113,48 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
# source_field_name: the PK fieldname in join table for the source object # source_field_name: the PK fieldname in join table for the source object
# target_field_name: the PK fieldname in join table for the target object # target_field_name: the PK fieldname in join table for the target object
# *objs - objects to add. Either object instances, or primary keys of object instances. # *objs - objects to add. Either object instances, or primary keys of object instances.
if not objs:
return
through_defaults = through_defaults or {} through_defaults = through_defaults or {}
target_ids = self._get_target_ids(target_field_name, objs)
db = router.db_for_write(self.through, instance=self.instance)
can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name)
if can_fast_add:
self.through._default_manager.using(db).bulk_create([
self.through(**{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in target_ids
], ignore_conflicts=True)
return
# If there aren't any objects, there is nothing to do. missing_target_ids = self._get_missing_target_ids(
if objs: source_field_name, target_field_name, db, target_ids
target_ids = self._get_target_ids(target_field_name, objs) )
db = router.db_for_write(self.through, instance=self.instance) with transaction.atomic(using=db, savepoint=False):
can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name) if must_send_signals:
if can_fast_add: signals.m2m_changed.send(
self.through._default_manager.using(db).bulk_create([ sender=self.through, action='pre_add',
self.through(**{ instance=self.instance, reverse=self.reverse,
'%s_id' % source_field_name: self.related_val[0], model=self.model, pk_set=missing_target_ids, using=db,
'%s_id' % target_field_name: target_id, )
}) # Add the ones that aren't there already.
for target_id in target_ids self.through._default_manager.using(db).bulk_create([
], ignore_conflicts=True) self.through(**through_defaults, **{
return '%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in missing_target_ids
], ignore_conflicts=can_ignore_conflicts)
missing_target_ids = self._get_missing_target_ids( if must_send_signals:
source_field_name, target_field_name, db, target_ids signals.m2m_changed.send(
) sender=self.through, action='post_add',
with transaction.atomic(using=db, savepoint=False): instance=self.instance, reverse=self.reverse,
if must_send_signals: model=self.model, pk_set=missing_target_ids, using=db,
signals.m2m_changed.send( )
sender=self.through, action='pre_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=missing_target_ids, using=db,
)
# Add the ones that aren't there already.
self.through._default_manager.using(db).bulk_create([
self.through(**through_defaults, **{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in missing_target_ids
], ignore_conflicts=can_ignore_conflicts)
if must_send_signals:
signals.m2m_changed.send(
sender=self.through, action='post_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=missing_target_ids, using=db,
)
def _remove_items(self, source_field_name, target_field_name, *objs): def _remove_items(self, source_field_name, target_field_name, *objs):
# source_field_name: the PK colname in join table for the source object # source_field_name: the PK colname in join table for the source object