Fixed #32095 -- Made QuerySet.update_or_create() save only fields passed in defaults or with custom pre_save().
Thanks Florian Apolloner for the initial patch.
This commit is contained in:
parent
1d77b931f7
commit
6cc0f22a73
|
@ -20,7 +20,7 @@ from django.db import (
|
||||||
router,
|
router,
|
||||||
transaction,
|
transaction,
|
||||||
)
|
)
|
||||||
from django.db.models import AutoField, DateField, DateTimeField, sql
|
from django.db.models import AutoField, DateField, DateTimeField, Field, sql
|
||||||
from django.db.models.constants import LOOKUP_SEP, OnConflict
|
from django.db.models.constants import LOOKUP_SEP, OnConflict
|
||||||
from django.db.models.deletion import Collector
|
from django.db.models.deletion import Collector
|
||||||
from django.db.models.expressions import Case, F, Ref, Value, When
|
from django.db.models.expressions import Case, F, Ref, Value, When
|
||||||
|
@ -963,6 +963,24 @@ class QuerySet:
|
||||||
return obj, created
|
return obj, created
|
||||||
for k, v in resolve_callables(defaults):
|
for k, v in resolve_callables(defaults):
|
||||||
setattr(obj, k, v)
|
setattr(obj, k, v)
|
||||||
|
|
||||||
|
update_fields = set(defaults)
|
||||||
|
concrete_field_names = self.model._meta._non_pk_concrete_field_names
|
||||||
|
# update_fields does not support non-concrete fields.
|
||||||
|
if concrete_field_names.issuperset(update_fields):
|
||||||
|
# Add fields which are set on pre_save(), e.g. auto_now fields.
|
||||||
|
# This is to maintain backward compatibility as these fields
|
||||||
|
# are not updated unless explicitly specified in the
|
||||||
|
# update_fields list.
|
||||||
|
for field in self.model._meta.local_concrete_fields:
|
||||||
|
if not (
|
||||||
|
field.primary_key or field.__class__.pre_save is Field.pre_save
|
||||||
|
):
|
||||||
|
update_fields.add(field.name)
|
||||||
|
if field.name != field.attname:
|
||||||
|
update_fields.add(field.attname)
|
||||||
|
obj.save(using=self.db, update_fields=update_fields)
|
||||||
|
else:
|
||||||
obj.save(using=self.db)
|
obj.save(using=self.db)
|
||||||
return obj, False
|
return obj, False
|
||||||
|
|
||||||
|
|
|
@ -63,3 +63,4 @@ class Book(models.Model):
|
||||||
related_name="books",
|
related_name="books",
|
||||||
db_column="publisher_id_column",
|
db_column="publisher_id_column",
|
||||||
)
|
)
|
||||||
|
updated = models.DateTimeField(auto_now=True)
|
||||||
|
|
|
@ -6,6 +6,7 @@ from threading import Thread
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import DatabaseError, IntegrityError, connection
|
from django.db import DatabaseError, IntegrityError, connection
|
||||||
from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature
|
from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
from django.utils.functional import lazy
|
from django.utils.functional import lazy
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
|
@ -513,6 +514,31 @@ class UpdateOrCreateTests(TestCase):
|
||||||
self.assertIs(created, False)
|
self.assertIs(created, False)
|
||||||
self.assertEqual(journalist.name, "John")
|
self.assertEqual(journalist.name, "John")
|
||||||
|
|
||||||
|
def test_update_only_defaults_and_pre_save_fields_when_local_fields(self):
|
||||||
|
publisher = Publisher.objects.create(name="Acme Publishing")
|
||||||
|
book = Book.objects.create(publisher=publisher, name="The Book of Ed & Fred")
|
||||||
|
|
||||||
|
for defaults in [{"publisher": publisher}, {"publisher_id": publisher}]:
|
||||||
|
with self.subTest(defaults=defaults):
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
book, created = Book.objects.update_or_create(
|
||||||
|
pk=book.pk,
|
||||||
|
defaults=defaults,
|
||||||
|
)
|
||||||
|
self.assertIs(created, False)
|
||||||
|
update_sqls = [
|
||||||
|
q["sql"] for q in captured_queries if q["sql"].startswith("UPDATE")
|
||||||
|
]
|
||||||
|
self.assertEqual(len(update_sqls), 1)
|
||||||
|
update_sql = update_sqls[0]
|
||||||
|
self.assertIsNotNone(update_sql)
|
||||||
|
self.assertIn(
|
||||||
|
connection.ops.quote_name("publisher_id_column"), update_sql
|
||||||
|
)
|
||||||
|
self.assertIn(connection.ops.quote_name("updated"), update_sql)
|
||||||
|
# Name should not be updated.
|
||||||
|
self.assertNotIn(connection.ops.quote_name("name"), update_sql)
|
||||||
|
|
||||||
|
|
||||||
class UpdateOrCreateTestsWithManualPKs(TestCase):
|
class UpdateOrCreateTestsWithManualPKs(TestCase):
|
||||||
def test_create_with_duplicate_primary_key(self):
|
def test_create_with_duplicate_primary_key(self):
|
||||||
|
|
Loading…
Reference in New Issue