mirror of https://github.com/django/django.git
Fixed a regression with get_or_create and virtual fields.
refs #20429 Thanks Simon Charette for the report and review.
This commit is contained in:
parent
a5cf5da50d
commit
f7290581fe
|
@ -411,7 +411,7 @@ class QuerySet(object):
|
||||||
Returns a tuple of (object, created), where created is a boolean
|
Returns a tuple of (object, created), where created is a boolean
|
||||||
specifying whether an object was created.
|
specifying whether an object was created.
|
||||||
"""
|
"""
|
||||||
lookup, params, _ = self._extract_model_params(defaults, **kwargs)
|
lookup, params = self._extract_model_params(defaults, **kwargs)
|
||||||
self._for_write = True
|
self._for_write = True
|
||||||
try:
|
try:
|
||||||
return self.get(**lookup), False
|
return self.get(**lookup), False
|
||||||
|
@ -425,7 +425,8 @@ class QuerySet(object):
|
||||||
Returns a tuple (object, created), where created is a boolean
|
Returns a tuple (object, created), where created is a boolean
|
||||||
specifying whether an object was created.
|
specifying whether an object was created.
|
||||||
"""
|
"""
|
||||||
lookup, params, filtered_defaults = self._extract_model_params(defaults, **kwargs)
|
defaults = defaults or {}
|
||||||
|
lookup, params = self._extract_model_params(defaults, **kwargs)
|
||||||
self._for_write = True
|
self._for_write = True
|
||||||
try:
|
try:
|
||||||
obj = self.get(**lookup)
|
obj = self.get(**lookup)
|
||||||
|
@ -433,12 +434,12 @@ class QuerySet(object):
|
||||||
obj, created = self._create_object_from_params(lookup, params)
|
obj, created = self._create_object_from_params(lookup, params)
|
||||||
if created:
|
if created:
|
||||||
return obj, created
|
return obj, created
|
||||||
for k, v in six.iteritems(filtered_defaults):
|
for k, v in six.iteritems(defaults):
|
||||||
setattr(obj, k, v)
|
setattr(obj, k, v)
|
||||||
|
|
||||||
sid = transaction.savepoint(using=self.db)
|
sid = transaction.savepoint(using=self.db)
|
||||||
try:
|
try:
|
||||||
obj.save(update_fields=filtered_defaults.keys(), using=self.db)
|
obj.save(using=self.db)
|
||||||
transaction.savepoint_commit(sid, using=self.db)
|
transaction.savepoint_commit(sid, using=self.db)
|
||||||
return obj, False
|
return obj, False
|
||||||
except DatabaseError:
|
except DatabaseError:
|
||||||
|
@ -469,22 +470,17 @@ class QuerySet(object):
|
||||||
def _extract_model_params(self, defaults, **kwargs):
|
def _extract_model_params(self, defaults, **kwargs):
|
||||||
"""
|
"""
|
||||||
Prepares `lookup` (kwargs that are valid model attributes), `params`
|
Prepares `lookup` (kwargs that are valid model attributes), `params`
|
||||||
(for creating a model instance) and `filtered_defaults` (defaults
|
(for creating a model instance) based on given kwargs; for use by
|
||||||
that are valid model attributes) based on given kwargs; for use by
|
|
||||||
get_or_create and update_or_create.
|
get_or_create and update_or_create.
|
||||||
"""
|
"""
|
||||||
defaults = defaults or {}
|
defaults = defaults or {}
|
||||||
filtered_defaults = {}
|
|
||||||
lookup = kwargs.copy()
|
lookup = kwargs.copy()
|
||||||
for f in self.model._meta.fields:
|
for f in self.model._meta.fields:
|
||||||
# Filter out fields that don't belongs to the model.
|
|
||||||
if f.attname in lookup:
|
if f.attname in lookup:
|
||||||
lookup[f.name] = lookup.pop(f.attname)
|
lookup[f.name] = lookup.pop(f.attname)
|
||||||
if f.attname in defaults:
|
|
||||||
filtered_defaults[f.name] = defaults.pop(f.attname)
|
|
||||||
params = dict((k, v) for k, v in kwargs.items() if LOOKUP_SEP not in k)
|
params = dict((k, v) for k, v in kwargs.items() if LOOKUP_SEP not in k)
|
||||||
params.update(filtered_defaults)
|
params.update(defaults)
|
||||||
return lookup, params, filtered_defaults
|
return lookup, params
|
||||||
|
|
||||||
def _earliest_or_latest(self, field_name=None, direction="-"):
|
def _earliest_or_latest(self, field_name=None, direction="-"):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -263,6 +263,29 @@ class GenericRelationsTests(TestCase):
|
||||||
formset = GenericFormSet(initial=initial_data)
|
formset = GenericFormSet(initial=initial_data)
|
||||||
self.assertEqual(formset.forms[0].initial, initial_data[0])
|
self.assertEqual(formset.forms[0].initial, initial_data[0])
|
||||||
|
|
||||||
|
def test_get_or_create(self):
|
||||||
|
# get_or_create should work with virtual fields (content_object)
|
||||||
|
quartz = Mineral.objects.create(name="Quartz", hardness=7)
|
||||||
|
tag, created = TaggedItem.objects.get_or_create(tag="shiny",
|
||||||
|
defaults={'content_object': quartz})
|
||||||
|
self.assertTrue(created)
|
||||||
|
self.assertEqual(tag.tag, "shiny")
|
||||||
|
self.assertEqual(tag.content_object.id, quartz.id)
|
||||||
|
|
||||||
|
def test_update_or_create_defaults(self):
|
||||||
|
# update_or_create should work with virtual fields (content_object)
|
||||||
|
quartz = Mineral.objects.create(name="Quartz", hardness=7)
|
||||||
|
diamond = Mineral.objects.create(name="Diamond", hardness=7)
|
||||||
|
tag, created = TaggedItem.objects.update_or_create(tag="shiny",
|
||||||
|
defaults={'content_object': quartz})
|
||||||
|
self.assertTrue(created)
|
||||||
|
self.assertEqual(tag.content_object.id, quartz.id)
|
||||||
|
|
||||||
|
tag, created = TaggedItem.objects.update_or_create(tag="shiny",
|
||||||
|
defaults={'content_object': diamond})
|
||||||
|
self.assertFalse(created)
|
||||||
|
self.assertEqual(tag.content_object.id, diamond.id)
|
||||||
|
|
||||||
|
|
||||||
class CustomWidget(forms.TextInput):
|
class CustomWidget(forms.TextInput):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue