From 10f8b82d195caa3745ba37d9424893763f89653e Mon Sep 17 00:00:00 2001 From: Hannes Ljungberg Date: Thu, 22 Oct 2020 16:50:37 +0200 Subject: [PATCH] Fixed #29497 -- Fixed loss of assigned parent when saving child with bulk_create() after parent. --- django/db/models/base.py | 67 ++++++++++++++++++++------------------ django/db/models/query.py | 6 ++-- tests/bulk_create/tests.py | 26 +++++++++++++++ 3 files changed, 65 insertions(+), 34 deletions(-) diff --git a/django/db/models/base.py b/django/db/models/base.py index 97d1eecbc8f..de044886c3e 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -679,38 +679,7 @@ class Model(metaclass=ModelBase): that the "save" must be an SQL insert or update (or equivalent for non-SQL backends), respectively. Normally, they should not be set. """ - # Ensure that a model instance without a PK hasn't been assigned to - # a ForeignKey or OneToOneField on this model. If the field is - # nullable, allowing the save() would result in silent data loss. - for field in self._meta.concrete_fields: - # If the related field isn't cached, then an instance hasn't - # been assigned and there's no need to worry about this check. - if field.is_relation and field.is_cached(self): - obj = getattr(self, field.name, None) - if not obj: - continue - # A pk may have been assigned manually to a model instance not - # saved to the database (or auto-generated in a case like - # UUIDField), but we allow the save to proceed and rely on the - # database to raise an IntegrityError if applicable. If - # constraints aren't supported by the database, there's the - # unavoidable risk of data corruption. - if obj.pk is None: - # Remove the object from a related instance cache. - if not field.remote_field.multiple: - field.remote_field.delete_cached_value(obj) - raise ValueError( - "save() prohibited to prevent data loss due to " - "unsaved related object '%s'." % field.name - ) - elif getattr(self, field.attname) is None: - # Use pk from related object if it has been saved after - # an assignment. - setattr(self, field.attname, obj.pk) - # If the relationship's pk/to_field was changed, clear the - # cached relationship. - if getattr(obj, field.target_field.attname) != getattr(self, field.attname): - field.delete_cached_value(self) + self._prepare_related_fields_for_save(operation_name='save') using = using or router.db_for_write(self.__class__, instance=self) if force_insert and (force_update or update_fields): @@ -939,6 +908,40 @@ class Model(metaclass=ModelBase): using=using, raw=raw, ) + def _prepare_related_fields_for_save(self, operation_name): + # Ensure that a model instance without a PK hasn't been assigned to + # a ForeignKey or OneToOneField on this model. If the field is + # nullable, allowing the save would result in silent data loss. + for field in self._meta.concrete_fields: + # If the related field isn't cached, then an instance hasn't been + # assigned and there's no need to worry about this check. + if field.is_relation and field.is_cached(self): + obj = getattr(self, field.name, None) + if not obj: + continue + # A pk may have been assigned manually to a model instance not + # saved to the database (or auto-generated in a case like + # UUIDField), but we allow the save to proceed and rely on the + # database to raise an IntegrityError if applicable. If + # constraints aren't supported by the database, there's the + # unavoidable risk of data corruption. + if obj.pk is None: + # Remove the object from a related instance cache. + if not field.remote_field.multiple: + field.remote_field.delete_cached_value(obj) + raise ValueError( + "%s() prohibited to prevent data loss due to unsaved " + "related object '%s'." % (operation_name, field.name) + ) + elif getattr(self, field.attname) is None: + # Use pk from related object if it has been saved after + # an assignment. + setattr(self, field.attname, obj.pk) + # If the relationship's pk/to_field was changed, clear the + # cached relationship. + if getattr(obj, field.target_field.attname) != getattr(self, field.attname): + field.delete_cached_value(self) + def delete(self, using=None, keep_parents=False): using = using or router.db_for_write(self.__class__, instance=self) assert self.pk is not None, ( diff --git a/django/db/models/query.py b/django/db/models/query.py index 53238ed60be..2c2b5d08837 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -453,10 +453,12 @@ class QuerySet: obj.save(force_insert=True, using=self.db) return obj - def _populate_pk_values(self, objs): + def _prepare_for_bulk_create(self, objs): for obj in objs: if obj.pk is None: + # Populate new PK values. obj.pk = obj._meta.pk.get_pk_value_on_save(obj) + obj._prepare_related_fields_for_save(operation_name='bulk_create') def bulk_create(self, objs, batch_size=None, ignore_conflicts=False): """ @@ -493,7 +495,7 @@ class QuerySet: opts = self.model._meta fields = opts.concrete_fields objs = list(objs) - self._populate_pk_values(objs) + self._prepare_for_bulk_create(objs) with transaction.atomic(using=self.db, savepoint=False): objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) if objs_with_pk: diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index 2b1d901e312..df764f945fc 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -321,3 +321,29 @@ class BulkCreateTests(TestCase): # Without ignore_conflicts=True, there's a problem. with self.assertRaises(IntegrityError): TwoFields.objects.bulk_create(conflicting_objects) + + def test_nullable_fk_after_parent(self): + parent = NoFields() + child = NullableFields(auto_field=parent, integer_field=88) + parent.save() + NullableFields.objects.bulk_create([child]) + child = NullableFields.objects.get(integer_field=88) + self.assertEqual(child.auto_field, parent) + + @skipUnlessDBFeature('can_return_rows_from_bulk_insert') + def test_nullable_fk_after_parent_bulk_create(self): + parent = NoFields() + child = NullableFields(auto_field=parent, integer_field=88) + NoFields.objects.bulk_create([parent]) + NullableFields.objects.bulk_create([child]) + child = NullableFields.objects.get(integer_field=88) + self.assertEqual(child.auto_field, parent) + + def test_unsaved_parent(self): + parent = NoFields() + msg = ( + "bulk_create() prohibited to prevent data loss due to unsaved " + "related object 'auto_field'." + ) + with self.assertRaisesMessage(ValueError, msg): + NullableFields.objects.bulk_create([NullableFields(auto_field=parent)])