Refs #19544 -- Extracted ManyRelatedManager.add() missing ids logic to a method.

This commit is contained in:
Simon Charette 2019-02-15 01:00:06 -05:00 committed by Tim Graham
parent 0ac4e51b2c
commit dd32f9a3a2
1 changed files with 43 additions and 36 deletions

View File

@ -1051,6 +1051,44 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
return obj, created
update_or_create.alters_data = True
def _get_missing_target_ids(self, source_field_name, target_field_name, db, objs):
"""
Return the subset of ids of `objs` that aren't already assigned to
this relationship.
"""
from django.db.models import Model
target_ids = set()
target_field = self.through._meta.get_field(target_field_name)
for obj in objs:
if isinstance(obj, self.model):
if not router.allow_relation(obj, self.instance):
raise ValueError(
'Cannot add "%r": instance is on database "%s", '
'value is on database "%s"' %
(obj, self.instance._state.db, obj._state.db)
)
target_id = target_field.get_foreign_related_value(obj)[0]
if target_id is None:
raise ValueError(
'Cannot add "%r": the value for field "%s" is None' %
(obj, target_field_name)
)
target_ids.add(target_id)
elif isinstance(obj, Model):
raise TypeError(
"'%s' instance expected, got %r" %
(self.model._meta.object_name, obj)
)
else:
target_ids.add(obj)
vals = self.through._default_manager.using(db).values_list(
target_field_name, flat=True
).filter(**{
source_field_name: self.related_val[0],
'%s__in' % target_field_name: target_ids,
})
return target_ids.difference(vals)
def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
# source_field_name: the PK fieldname in join table for the source object
# target_field_name: the PK fieldname in join table for the target object
@ -1058,40 +1096,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
through_defaults = through_defaults or {}
# If there aren't any objects, there is nothing to do.
from django.db.models import Model
if objs:
new_ids = set()
for obj in objs:
if isinstance(obj, self.model):
if not router.allow_relation(obj, self.instance):
raise ValueError(
'Cannot add "%r": instance is on database "%s", value is on database "%s"' %
(obj, self.instance._state.db, obj._state.db)
)
fk_val = self.through._meta.get_field(
target_field_name).get_foreign_related_value(obj)[0]
if fk_val is None:
raise ValueError(
'Cannot add "%r": the value for field "%s" is None' %
(obj, target_field_name)
)
new_ids.add(fk_val)
elif isinstance(obj, Model):
raise TypeError(
"'%s' instance expected, got %r" %
(self.model._meta.object_name, obj)
)
else:
new_ids.add(obj)
db = router.db_for_write(self.through, instance=self.instance)
vals = (self.through._default_manager.using(db)
.values_list(target_field_name, flat=True)
.filter(**{
source_field_name: self.related_val[0],
'%s__in' % target_field_name: new_ids,
}))
new_ids.difference_update(vals)
missing_target_ids = self._get_missing_target_ids(source_field_name, target_field_name, db, objs)
with transaction.atomic(using=db, savepoint=False):
if self.reverse or source_field_name == self.source_field_name:
@ -1100,16 +1107,16 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
signals.m2m_changed.send(
sender=self.through, action='pre_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=new_ids, using=db,
model=self.model, pk_set=missing_target_ids, using=db,
)
# Add the ones that aren't there already
self.through._default_manager.using(db).bulk_create([
self.through(**through_defaults, **{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: obj_id,
'%s_id' % target_field_name: target_id,
})
for obj_id in new_ids
for target_id in missing_target_ids
])
if self.reverse or source_field_name == self.source_field_name:
@ -1118,7 +1125,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
signals.m2m_changed.send(
sender=self.through, action='post_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=new_ids, using=db,
model=self.model, pk_set=missing_target_ids, using=db,
)
def _remove_items(self, source_field_name, target_field_name, *objs):