From dd32f9a3a21272e784d434a6f9ca9f07aeedb50a Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Fri, 15 Feb 2019 01:00:06 -0500 Subject: [PATCH] Refs #19544 -- Extracted ManyRelatedManager.add() missing ids logic to a method. --- .../db/models/fields/related_descriptors.py | 79 ++++++++++--------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 02b5bcb62ee..96b8a375833 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -1051,6 +1051,44 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): return obj, created update_or_create.alters_data = True + def _get_missing_target_ids(self, source_field_name, target_field_name, db, objs): + """ + Return the subset of ids of `objs` that aren't already assigned to + this relationship. + """ + from django.db.models import Model + target_ids = set() + target_field = self.through._meta.get_field(target_field_name) + for obj in objs: + if isinstance(obj, self.model): + if not router.allow_relation(obj, self.instance): + raise ValueError( + 'Cannot add "%r": instance is on database "%s", ' + 'value is on database "%s"' % + (obj, self.instance._state.db, obj._state.db) + ) + target_id = target_field.get_foreign_related_value(obj)[0] + if target_id is None: + raise ValueError( + 'Cannot add "%r": the value for field "%s" is None' % + (obj, target_field_name) + ) + target_ids.add(target_id) + elif isinstance(obj, Model): + raise TypeError( + "'%s' instance expected, got %r" % + (self.model._meta.object_name, obj) + ) + else: + target_ids.add(obj) + vals = self.through._default_manager.using(db).values_list( + target_field_name, flat=True + ).filter(**{ + source_field_name: self.related_val[0], + '%s__in' % target_field_name: target_ids, + }) + return target_ids.difference(vals) + def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None): # source_field_name: the PK fieldname in join table for the source object # target_field_name: the PK fieldname in join table for the target object @@ -1058,40 +1096,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): through_defaults = through_defaults or {} # If there aren't any objects, there is nothing to do. - from django.db.models import Model if objs: - new_ids = set() - for obj in objs: - if isinstance(obj, self.model): - if not router.allow_relation(obj, self.instance): - raise ValueError( - 'Cannot add "%r": instance is on database "%s", value is on database "%s"' % - (obj, self.instance._state.db, obj._state.db) - ) - fk_val = self.through._meta.get_field( - target_field_name).get_foreign_related_value(obj)[0] - if fk_val is None: - raise ValueError( - 'Cannot add "%r": the value for field "%s" is None' % - (obj, target_field_name) - ) - new_ids.add(fk_val) - elif isinstance(obj, Model): - raise TypeError( - "'%s' instance expected, got %r" % - (self.model._meta.object_name, obj) - ) - else: - new_ids.add(obj) - db = router.db_for_write(self.through, instance=self.instance) - vals = (self.through._default_manager.using(db) - .values_list(target_field_name, flat=True) - .filter(**{ - source_field_name: self.related_val[0], - '%s__in' % target_field_name: new_ids, - })) - new_ids.difference_update(vals) + missing_target_ids = self._get_missing_target_ids(source_field_name, target_field_name, db, objs) with transaction.atomic(using=db, savepoint=False): if self.reverse or source_field_name == self.source_field_name: @@ -1100,16 +1107,16 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): signals.m2m_changed.send( sender=self.through, action='pre_add', instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=new_ids, using=db, + model=self.model, pk_set=missing_target_ids, using=db, ) # Add the ones that aren't there already self.through._default_manager.using(db).bulk_create([ self.through(**through_defaults, **{ '%s_id' % source_field_name: self.related_val[0], - '%s_id' % target_field_name: obj_id, + '%s_id' % target_field_name: target_id, }) - for obj_id in new_ids + for target_id in missing_target_ids ]) if self.reverse or source_field_name == self.source_field_name: @@ -1118,7 +1125,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): signals.m2m_changed.send( sender=self.through, action='post_add', instance=self.instance, reverse=self.reverse, - model=self.model, pk_set=new_ids, using=db, + model=self.model, pk_set=missing_target_ids, using=db, ) def _remove_items(self, source_field_name, target_field_name, *objs):