Fixed #32095 -- Made QuerySet.update_or_create() save only fields passed in defaults or with custom pre_save().

Thanks Florian Apolloner for the initial patch.
This commit is contained in:
sarahboyce 2022-09-27 15:26:02 +02:00 committed by Mariusz Felisiak
parent 1d77b931f7
commit 6cc0f22a73
3 changed files with 47 additions and 2 deletions

View File

@ -20,7 +20,7 @@ from django.db import (
router, router,
transaction, transaction,
) )
from django.db.models import AutoField, DateField, DateTimeField, sql from django.db.models import AutoField, DateField, DateTimeField, Field, sql
from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models.expressions import Case, F, Ref, Value, When from django.db.models.expressions import Case, F, Ref, Value, When
@ -963,7 +963,25 @@ class QuerySet:
return obj, created return obj, created
for k, v in resolve_callables(defaults): for k, v in resolve_callables(defaults):
setattr(obj, k, v) setattr(obj, k, v)
obj.save(using=self.db)
update_fields = set(defaults)
concrete_field_names = self.model._meta._non_pk_concrete_field_names
# update_fields does not support non-concrete fields.
if concrete_field_names.issuperset(update_fields):
# Add fields which are set on pre_save(), e.g. auto_now fields.
# This is to maintain backward compatibility as these fields
# are not updated unless explicitly specified in the
# update_fields list.
for field in self.model._meta.local_concrete_fields:
if not (
field.primary_key or field.__class__.pre_save is Field.pre_save
):
update_fields.add(field.name)
if field.name != field.attname:
update_fields.add(field.attname)
obj.save(using=self.db, update_fields=update_fields)
else:
obj.save(using=self.db)
return obj, False return obj, False
async def aupdate_or_create(self, defaults=None, **kwargs): async def aupdate_or_create(self, defaults=None, **kwargs):

View File

@ -63,3 +63,4 @@ class Book(models.Model):
related_name="books", related_name="books",
db_column="publisher_id_column", db_column="publisher_id_column",
) )
updated = models.DateTimeField(auto_now=True)

View File

@ -6,6 +6,7 @@ from threading import Thread
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import DatabaseError, IntegrityError, connection from django.db import DatabaseError, IntegrityError, connection
from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
from django.utils.functional import lazy from django.utils.functional import lazy
from .models import ( from .models import (
@ -513,6 +514,31 @@ class UpdateOrCreateTests(TestCase):
self.assertIs(created, False) self.assertIs(created, False)
self.assertEqual(journalist.name, "John") self.assertEqual(journalist.name, "John")
def test_update_only_defaults_and_pre_save_fields_when_local_fields(self):
publisher = Publisher.objects.create(name="Acme Publishing")
book = Book.objects.create(publisher=publisher, name="The Book of Ed & Fred")
for defaults in [{"publisher": publisher}, {"publisher_id": publisher}]:
with self.subTest(defaults=defaults):
with CaptureQueriesContext(connection) as captured_queries:
book, created = Book.objects.update_or_create(
pk=book.pk,
defaults=defaults,
)
self.assertIs(created, False)
update_sqls = [
q["sql"] for q in captured_queries if q["sql"].startswith("UPDATE")
]
self.assertEqual(len(update_sqls), 1)
update_sql = update_sqls[0]
self.assertIsNotNone(update_sql)
self.assertIn(
connection.ops.quote_name("publisher_id_column"), update_sql
)
self.assertIn(connection.ops.quote_name("updated"), update_sql)
# Name should not be updated.
self.assertNotIn(connection.ops.quote_name("name"), update_sql)
class UpdateOrCreateTestsWithManualPKs(TestCase): class UpdateOrCreateTestsWithManualPKs(TestCase):
def test_create_with_duplicate_primary_key(self): def test_create_with_duplicate_primary_key(self):