Fixed #29497 -- Fixed loss of assigned parent when saving child with bulk_create() after parent.

This commit is contained in:
Hannes Ljungberg 2020-10-22 16:50:37 +02:00 committed by Mariusz Felisiak
parent 83a8da576d
commit 10f8b82d19
3 changed files with 65 additions and 34 deletions

View File

@ -679,38 +679,7 @@ class Model(metaclass=ModelBase):
that the "save" must be an SQL insert or update (or equivalent for that the "save" must be an SQL insert or update (or equivalent for
non-SQL backends), respectively. Normally, they should not be set. non-SQL backends), respectively. Normally, they should not be set.
""" """
# Ensure that a model instance without a PK hasn't been assigned to self._prepare_related_fields_for_save(operation_name='save')
# a ForeignKey or OneToOneField on this model. If the field is
# nullable, allowing the save() would result in silent data loss.
for field in self._meta.concrete_fields:
# If the related field isn't cached, then an instance hasn't
# been assigned and there's no need to worry about this check.
if field.is_relation and field.is_cached(self):
obj = getattr(self, field.name, None)
if not obj:
continue
# A pk may have been assigned manually to a model instance not
# saved to the database (or auto-generated in a case like
# UUIDField), but we allow the save to proceed and rely on the
# database to raise an IntegrityError if applicable. If
# constraints aren't supported by the database, there's the
# unavoidable risk of data corruption.
if obj.pk is None:
# Remove the object from a related instance cache.
if not field.remote_field.multiple:
field.remote_field.delete_cached_value(obj)
raise ValueError(
"save() prohibited to prevent data loss due to "
"unsaved related object '%s'." % field.name
)
elif getattr(self, field.attname) is None:
# Use pk from related object if it has been saved after
# an assignment.
setattr(self, field.attname, obj.pk)
# If the relationship's pk/to_field was changed, clear the
# cached relationship.
if getattr(obj, field.target_field.attname) != getattr(self, field.attname):
field.delete_cached_value(self)
using = using or router.db_for_write(self.__class__, instance=self) using = using or router.db_for_write(self.__class__, instance=self)
if force_insert and (force_update or update_fields): if force_insert and (force_update or update_fields):
@ -939,6 +908,40 @@ class Model(metaclass=ModelBase):
using=using, raw=raw, using=using, raw=raw,
) )
def _prepare_related_fields_for_save(self, operation_name):
# Ensure that a model instance without a PK hasn't been assigned to
# a ForeignKey or OneToOneField on this model. If the field is
# nullable, allowing the save would result in silent data loss.
for field in self._meta.concrete_fields:
# If the related field isn't cached, then an instance hasn't been
# assigned and there's no need to worry about this check.
if field.is_relation and field.is_cached(self):
obj = getattr(self, field.name, None)
if not obj:
continue
# A pk may have been assigned manually to a model instance not
# saved to the database (or auto-generated in a case like
# UUIDField), but we allow the save to proceed and rely on the
# database to raise an IntegrityError if applicable. If
# constraints aren't supported by the database, there's the
# unavoidable risk of data corruption.
if obj.pk is None:
# Remove the object from a related instance cache.
if not field.remote_field.multiple:
field.remote_field.delete_cached_value(obj)
raise ValueError(
"%s() prohibited to prevent data loss due to unsaved "
"related object '%s'." % (operation_name, field.name)
)
elif getattr(self, field.attname) is None:
# Use pk from related object if it has been saved after
# an assignment.
setattr(self, field.attname, obj.pk)
# If the relationship's pk/to_field was changed, clear the
# cached relationship.
if getattr(obj, field.target_field.attname) != getattr(self, field.attname):
field.delete_cached_value(self)
def delete(self, using=None, keep_parents=False): def delete(self, using=None, keep_parents=False):
using = using or router.db_for_write(self.__class__, instance=self) using = using or router.db_for_write(self.__class__, instance=self)
assert self.pk is not None, ( assert self.pk is not None, (

View File

@ -453,10 +453,12 @@ class QuerySet:
obj.save(force_insert=True, using=self.db) obj.save(force_insert=True, using=self.db)
return obj return obj
def _populate_pk_values(self, objs): def _prepare_for_bulk_create(self, objs):
for obj in objs: for obj in objs:
if obj.pk is None: if obj.pk is None:
# Populate new PK values.
obj.pk = obj._meta.pk.get_pk_value_on_save(obj) obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
obj._prepare_related_fields_for_save(operation_name='bulk_create')
def bulk_create(self, objs, batch_size=None, ignore_conflicts=False): def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
""" """
@ -493,7 +495,7 @@ class QuerySet:
opts = self.model._meta opts = self.model._meta
fields = opts.concrete_fields fields = opts.concrete_fields
objs = list(objs) objs = list(objs)
self._populate_pk_values(objs) self._prepare_for_bulk_create(objs)
with transaction.atomic(using=self.db, savepoint=False): with transaction.atomic(using=self.db, savepoint=False):
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
if objs_with_pk: if objs_with_pk:

View File

@ -321,3 +321,29 @@ class BulkCreateTests(TestCase):
# Without ignore_conflicts=True, there's a problem. # Without ignore_conflicts=True, there's a problem.
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
TwoFields.objects.bulk_create(conflicting_objects) TwoFields.objects.bulk_create(conflicting_objects)
def test_nullable_fk_after_parent(self):
parent = NoFields()
child = NullableFields(auto_field=parent, integer_field=88)
parent.save()
NullableFields.objects.bulk_create([child])
child = NullableFields.objects.get(integer_field=88)
self.assertEqual(child.auto_field, parent)
@skipUnlessDBFeature('can_return_rows_from_bulk_insert')
def test_nullable_fk_after_parent_bulk_create(self):
parent = NoFields()
child = NullableFields(auto_field=parent, integer_field=88)
NoFields.objects.bulk_create([parent])
NullableFields.objects.bulk_create([child])
child = NullableFields.objects.get(integer_field=88)
self.assertEqual(child.auto_field, parent)
def test_unsaved_parent(self):
parent = NoFields()
msg = (
"bulk_create() prohibited to prevent data loss due to unsaved "
"related object 'auto_field'."
)
with self.assertRaisesMessage(ValueError, msg):
NullableFields.objects.bulk_create([NullableFields(auto_field=parent)])