Refs #19544 -- Extracted ManyRelatedManager.add() missing ids logic to a method.
This commit is contained in:
parent
0ac4e51b2c
commit
dd32f9a3a2
|
@ -1051,6 +1051,44 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||||
return obj, created
|
return obj, created
|
||||||
update_or_create.alters_data = True
|
update_or_create.alters_data = True
|
||||||
|
|
||||||
|
def _get_missing_target_ids(self, source_field_name, target_field_name, db, objs):
|
||||||
|
"""
|
||||||
|
Return the subset of ids of `objs` that aren't already assigned to
|
||||||
|
this relationship.
|
||||||
|
"""
|
||||||
|
from django.db.models import Model
|
||||||
|
target_ids = set()
|
||||||
|
target_field = self.through._meta.get_field(target_field_name)
|
||||||
|
for obj in objs:
|
||||||
|
if isinstance(obj, self.model):
|
||||||
|
if not router.allow_relation(obj, self.instance):
|
||||||
|
raise ValueError(
|
||||||
|
'Cannot add "%r": instance is on database "%s", '
|
||||||
|
'value is on database "%s"' %
|
||||||
|
(obj, self.instance._state.db, obj._state.db)
|
||||||
|
)
|
||||||
|
target_id = target_field.get_foreign_related_value(obj)[0]
|
||||||
|
if target_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
'Cannot add "%r": the value for field "%s" is None' %
|
||||||
|
(obj, target_field_name)
|
||||||
|
)
|
||||||
|
target_ids.add(target_id)
|
||||||
|
elif isinstance(obj, Model):
|
||||||
|
raise TypeError(
|
||||||
|
"'%s' instance expected, got %r" %
|
||||||
|
(self.model._meta.object_name, obj)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
target_ids.add(obj)
|
||||||
|
vals = self.through._default_manager.using(db).values_list(
|
||||||
|
target_field_name, flat=True
|
||||||
|
).filter(**{
|
||||||
|
source_field_name: self.related_val[0],
|
||||||
|
'%s__in' % target_field_name: target_ids,
|
||||||
|
})
|
||||||
|
return target_ids.difference(vals)
|
||||||
|
|
||||||
def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
|
def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
|
||||||
# source_field_name: the PK fieldname in join table for the source object
|
# source_field_name: the PK fieldname in join table for the source object
|
||||||
# target_field_name: the PK fieldname in join table for the target object
|
# target_field_name: the PK fieldname in join table for the target object
|
||||||
|
@ -1058,40 +1096,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||||
through_defaults = through_defaults or {}
|
through_defaults = through_defaults or {}
|
||||||
|
|
||||||
# If there aren't any objects, there is nothing to do.
|
# If there aren't any objects, there is nothing to do.
|
||||||
from django.db.models import Model
|
|
||||||
if objs:
|
if objs:
|
||||||
new_ids = set()
|
|
||||||
for obj in objs:
|
|
||||||
if isinstance(obj, self.model):
|
|
||||||
if not router.allow_relation(obj, self.instance):
|
|
||||||
raise ValueError(
|
|
||||||
'Cannot add "%r": instance is on database "%s", value is on database "%s"' %
|
|
||||||
(obj, self.instance._state.db, obj._state.db)
|
|
||||||
)
|
|
||||||
fk_val = self.through._meta.get_field(
|
|
||||||
target_field_name).get_foreign_related_value(obj)[0]
|
|
||||||
if fk_val is None:
|
|
||||||
raise ValueError(
|
|
||||||
'Cannot add "%r": the value for field "%s" is None' %
|
|
||||||
(obj, target_field_name)
|
|
||||||
)
|
|
||||||
new_ids.add(fk_val)
|
|
||||||
elif isinstance(obj, Model):
|
|
||||||
raise TypeError(
|
|
||||||
"'%s' instance expected, got %r" %
|
|
||||||
(self.model._meta.object_name, obj)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
new_ids.add(obj)
|
|
||||||
|
|
||||||
db = router.db_for_write(self.through, instance=self.instance)
|
db = router.db_for_write(self.through, instance=self.instance)
|
||||||
vals = (self.through._default_manager.using(db)
|
missing_target_ids = self._get_missing_target_ids(source_field_name, target_field_name, db, objs)
|
||||||
.values_list(target_field_name, flat=True)
|
|
||||||
.filter(**{
|
|
||||||
source_field_name: self.related_val[0],
|
|
||||||
'%s__in' % target_field_name: new_ids,
|
|
||||||
}))
|
|
||||||
new_ids.difference_update(vals)
|
|
||||||
|
|
||||||
with transaction.atomic(using=db, savepoint=False):
|
with transaction.atomic(using=db, savepoint=False):
|
||||||
if self.reverse or source_field_name == self.source_field_name:
|
if self.reverse or source_field_name == self.source_field_name:
|
||||||
|
@ -1100,16 +1107,16 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||||
signals.m2m_changed.send(
|
signals.m2m_changed.send(
|
||||||
sender=self.through, action='pre_add',
|
sender=self.through, action='pre_add',
|
||||||
instance=self.instance, reverse=self.reverse,
|
instance=self.instance, reverse=self.reverse,
|
||||||
model=self.model, pk_set=new_ids, using=db,
|
model=self.model, pk_set=missing_target_ids, using=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the ones that aren't there already
|
# Add the ones that aren't there already
|
||||||
self.through._default_manager.using(db).bulk_create([
|
self.through._default_manager.using(db).bulk_create([
|
||||||
self.through(**through_defaults, **{
|
self.through(**through_defaults, **{
|
||||||
'%s_id' % source_field_name: self.related_val[0],
|
'%s_id' % source_field_name: self.related_val[0],
|
||||||
'%s_id' % target_field_name: obj_id,
|
'%s_id' % target_field_name: target_id,
|
||||||
})
|
})
|
||||||
for obj_id in new_ids
|
for target_id in missing_target_ids
|
||||||
])
|
])
|
||||||
|
|
||||||
if self.reverse or source_field_name == self.source_field_name:
|
if self.reverse or source_field_name == self.source_field_name:
|
||||||
|
@ -1118,7 +1125,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||||
signals.m2m_changed.send(
|
signals.m2m_changed.send(
|
||||||
sender=self.through, action='post_add',
|
sender=self.through, action='post_add',
|
||||||
instance=self.instance, reverse=self.reverse,
|
instance=self.instance, reverse=self.reverse,
|
||||||
model=self.model, pk_set=new_ids, using=db,
|
model=self.model, pk_set=missing_target_ids, using=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _remove_items(self, source_field_name, target_field_name, *objs):
|
def _remove_items(self, source_field_name, target_field_name, *objs):
|
||||||
|
|
Loading…
Reference in New Issue