Fixed #30382 -- Allowed specifying parent classes in force_insert of Model.save().

This commit is contained in:
Akash Kumar Sen 2023-06-22 18:23:11 +05:30 committed by Mariusz Felisiak
parent 601ffb0da3
commit a40b0103bc
6 changed files with 138 additions and 5 deletions

View File

@ -832,6 +832,26 @@ class Model(AltersData, metaclass=ModelBase):
asave.alters_data = True
@classmethod
def _validate_force_insert(cls, force_insert):
if force_insert is False:
return ()
if force_insert is True:
return (cls,)
if not isinstance(force_insert, tuple):
raise TypeError("force_insert must be a bool or tuple.")
for member in force_insert:
if not isinstance(member, ModelBase):
raise TypeError(
f"Invalid force_insert member. {member!r} must be a model subclass."
)
if not issubclass(cls, member):
raise TypeError(
f"Invalid force_insert member. {member.__qualname__} must be a "
f"base of {cls.__qualname__}."
)
return force_insert
def save_base(
self,
raw=False,
@ -873,7 +893,11 @@ class Model(AltersData, metaclass=ModelBase):
with context_manager:
parent_inserted = False
if not raw:
parent_inserted = self._save_parents(cls, using, update_fields)
# Validate force insert only when parents are inserted.
force_insert = self._validate_force_insert(force_insert)
parent_inserted = self._save_parents(
cls, using, update_fields, force_insert
)
updated = self._save_table(
raw,
cls,
@ -900,7 +924,9 @@ class Model(AltersData, metaclass=ModelBase):
save_base.alters_data = True
def _save_parents(self, cls, using, update_fields, updated_parents=None):
def _save_parents(
self, cls, using, update_fields, force_insert, updated_parents=None
):
"""Save all the parents of cls using values from self."""
meta = cls._meta
inserted = False
@ -919,13 +945,14 @@ class Model(AltersData, metaclass=ModelBase):
cls=parent,
using=using,
update_fields=update_fields,
force_insert=force_insert,
updated_parents=updated_parents,
)
updated = self._save_table(
cls=parent,
using=using,
update_fields=update_fields,
force_insert=parent_inserted,
force_insert=parent_inserted or issubclass(parent, force_insert),
)
if not updated:
inserted = True

View File

@ -589,6 +589,18 @@ row. In these cases you can pass the ``force_insert=True`` or
Passing both parameters is an error: you cannot both insert *and* update at the
same time!
When using :ref:`multi-table inheritance <multi-table-inheritance>`, it's also
possible to provide a tuple of parent classes to ``force_insert`` in order to
force ``INSERT`` statements for each base. For example::
Restaurant(pk=1, name="Bob's Cafe").save(force_insert=(Place,))
Restaurant(pk=1, name="Bob's Cafe", rating=4).save(force_insert=(Place, Rating))
You can pass ``force_insert=(models.Model,)`` to force an ``INSERT`` statement
for all parents. By default, ``force_insert=True`` only forces the insertion of
a new row for the current model.
It should be very rare that you'll need to use these parameters. Django will
almost always do the right thing and trying to override that will lead to
errors that are difficult to track down. This feature is for advanced use
@ -596,6 +608,11 @@ only.
Using ``update_fields`` will force an update similarly to ``force_update``.
.. versionchanged:: 5.0
Support for passing a tuple of parent classes to ``force_insert`` was
added.
.. _ref-models-field-updates-using-f-expressions:
Updating attributes based on existing fields

View File

@ -335,6 +335,10 @@ Models
:ref:`Choices classes <field-choices-enum-types>` directly instead of
requiring expansion with the ``choices`` attribute.
* The :ref:`force_insert <ref-models-force-insert>` argument of
:meth:`.Model.save` now allows specifying a tuple of parent classes that must
be forced to be inserted.
Pagination
~~~~~~~~~~

View File

@ -10,7 +10,7 @@ class RevisionableModel(models.Model):
title = models.CharField(blank=True, max_length=255)
when = models.DateTimeField(default=datetime.datetime.now)
def save(self, *args, force_insert=None, force_update=None, **kwargs):
def save(self, *args, force_insert=False, force_update=False, **kwargs):
super().save(
*args, force_insert=force_insert, force_update=force_update, **kwargs
)

View File

@ -30,3 +30,13 @@ class SubSubCounter(SubCounter):
class WithCustomPK(models.Model):
name = models.IntegerField(primary_key=True)
value = models.IntegerField()
class OtherSubCounter(Counter):
other_counter_ptr = models.OneToOneField(
Counter, primary_key=True, parent_link=True, on_delete=models.CASCADE
)
class DiamondSubSubCounter(SubCounter, OtherSubCounter):
pass

View File

@ -1,9 +1,11 @@
from django.db import DatabaseError, IntegrityError, transaction
from django.db import DatabaseError, IntegrityError, models, transaction
from django.test import TestCase
from .models import (
Counter,
DiamondSubSubCounter,
InheritedCounter,
OtherSubCounter,
ProxyCounter,
SubCounter,
SubSubCounter,
@ -76,6 +78,29 @@ class InheritanceTests(TestCase):
class ForceInsertInheritanceTests(TestCase):
def test_force_insert_not_bool_or_tuple(self):
msg = "force_insert must be a bool or tuple."
with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
Counter().save(force_insert=1)
with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
Counter().save(force_insert="test")
with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
Counter().save(force_insert=[])
def test_force_insert_not_model(self):
msg = f"Invalid force_insert member. {object!r} must be a model subclass."
with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
Counter().save(force_insert=(object,))
instance = Counter()
msg = f"Invalid force_insert member. {instance!r} must be a model subclass."
with self.assertRaisesMessage(TypeError, msg), transaction.atomic():
Counter().save(force_insert=(instance,))
def test_force_insert_not_base(self):
msg = "Invalid force_insert member. SubCounter must be a base of Counter."
with self.assertRaisesMessage(TypeError, msg):
Counter().save(force_insert=(SubCounter,))
def test_force_insert_false(self):
with self.assertNumQueries(3):
obj = SubCounter.objects.create(pk=1, value=0)
@ -87,6 +112,10 @@ class ForceInsertInheritanceTests(TestCase):
SubCounter(pk=obj.pk, value=2).save(force_insert=False)
obj.refresh_from_db()
self.assertEqual(obj.value, 2)
with self.assertNumQueries(2):
SubCounter(pk=obj.pk, value=3).save(force_insert=())
obj.refresh_from_db()
self.assertEqual(obj.value, 3)
def test_force_insert_false_with_existing_parent(self):
parent = Counter.objects.create(pk=1, value=1)
@ -96,13 +125,59 @@ class ForceInsertInheritanceTests(TestCase):
def test_force_insert_parent(self):
with self.assertNumQueries(3):
SubCounter(pk=1, value=1).save(force_insert=True)
# Force insert a new parent and don't UPDATE first.
with self.assertNumQueries(2):
SubCounter(pk=2, value=1).save(force_insert=(Counter,))
with self.assertNumQueries(2):
SubCounter(pk=3, value=1).save(force_insert=(models.Model,))
def test_force_insert_with_grandparent(self):
with self.assertNumQueries(4):
SubSubCounter(pk=1, value=1).save(force_insert=True)
# Force insert parents on all levels and don't UPDATE first.
with self.assertNumQueries(3):
SubSubCounter(pk=2, value=1).save(force_insert=(models.Model,))
with self.assertNumQueries(3):
SubSubCounter(pk=3, value=1).save(force_insert=(Counter,))
# Force insert only the last parent.
with self.assertNumQueries(4):
SubSubCounter(pk=4, value=1).save(force_insert=(SubCounter,))
def test_force_insert_with_existing_grandparent(self):
# Force insert only the last child.
grandparent = Counter.objects.create(pk=1, value=1)
with self.assertNumQueries(4):
SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=True)
# Force insert a parent, and don't force insert a grandparent.
grandparent = Counter.objects.create(pk=2, value=1)
with self.assertNumQueries(3):
SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=(SubCounter,))
# Force insert parents on all levels, grandparent conflicts.
grandparent = Counter.objects.create(pk=3, value=1)
with self.assertRaises(IntegrityError), transaction.atomic():
SubSubCounter(pk=grandparent.pk, value=1).save(force_insert=(Counter,))
def test_force_insert_diamond_mti(self):
# Force insert all parents.
with self.assertNumQueries(4):
DiamondSubSubCounter(pk=1, value=1).save(
force_insert=(Counter, SubCounter, OtherSubCounter)
)
with self.assertNumQueries(4):
DiamondSubSubCounter(pk=2, value=1).save(force_insert=(models.Model,))
# Force insert parents, and don't force insert a common grandparent.
with self.assertNumQueries(5):
DiamondSubSubCounter(pk=3, value=1).save(
force_insert=(SubCounter, OtherSubCounter)
)
grandparent = Counter.objects.create(pk=4, value=1)
with self.assertNumQueries(4):
DiamondSubSubCounter(pk=grandparent.pk, value=1).save(
force_insert=(SubCounter, OtherSubCounter),
)
# Force insert all parents, grandparent conflicts.
grandparent = Counter.objects.create(pk=5, value=1)
with self.assertRaises(IntegrityError), transaction.atomic():
DiamondSubSubCounter(pk=grandparent.pk, value=1).save(
force_insert=(models.Model,)
)