[2.2.x] Fixed #31246 -- Fixed locking models in QuerySet.select_for_update(of=()) for related fields and parent link fields with multi-table inheritance.

Partly regression in 0107e3d105.

Backport of 1712a76b9d from master.
This commit is contained in:
Abhijeet Viswa 2020-02-08 10:22:09 +05:30 committed by Mariusz Felisiak
parent eeed073aa2
commit 32d89bf114
5 changed files with 75 additions and 23 deletions

View File

@ -9,6 +9,7 @@ answer newbie questions, and generally made Django that much better:
Aaron Swartz <http://www.aaronsw.com/> Aaron Swartz <http://www.aaronsw.com/>
Aaron T. Myers <atmyers@gmail.com> Aaron T. Myers <atmyers@gmail.com>
Abeer Upadhyay <ab.esquarer@gmail.com> Abeer Upadhyay <ab.esquarer@gmail.com>
Abhijeet Viswa <abhijeetviswa@gmail.com>
Abhinav Patil <https://github.com/ubadub/> Abhinav Patil <https://github.com/ubadub/>
Abhishek Gautam <abhishekg1128@yahoo.com> Abhishek Gautam <abhishekg1128@yahoo.com>
Adam Allred <adam.w.allred@gmail.com> Adam Allred <adam.w.allred@gmail.com>

View File

@ -948,19 +948,34 @@ class SQLCompiler:
the query. the query.
""" """
def _get_parent_klass_info(klass_info): def _get_parent_klass_info(klass_info):
return ( for parent_model, parent_link in klass_info['model']._meta.parents.items():
{ parent_list = parent_model._meta.get_parent_list()
yield {
'model': parent_model, 'model': parent_model,
'field': parent_link, 'field': parent_link,
'reverse': False, 'reverse': False,
'select_fields': [ 'select_fields': [
select_index select_index
for select_index in klass_info['select_fields'] for select_index in klass_info['select_fields']
if self.select[select_index][0].target.model == parent_model # Selected columns from a model or its parents.
if (
self.select[select_index][0].target.model == parent_model or
self.select[select_index][0].target.model in parent_list
)
], ],
} }
for parent_model, parent_link in klass_info['model']._meta.parents.items()
) def _get_first_selected_col_from_model(klass_info):
"""
Find the first selected column from a model. If it doesn't exist,
don't lock a model.
select_fields is filled recursively, so it also contains fields
from the parent models.
"""
for select_index in klass_info['select_fields']:
if self.select[select_index][0].target.model == klass_info['model']:
return self.select[select_index][0]
def _get_field_choices(): def _get_field_choices():
"""Yield all allowed field paths in breadth-first search order.""" """Yield all allowed field paths in breadth-first search order."""
@ -989,14 +1004,7 @@ class SQLCompiler:
for name in self.query.select_for_update_of: for name in self.query.select_for_update_of:
klass_info = self.klass_info klass_info = self.klass_info
if name == 'self': if name == 'self':
# Find the first selected column from a base model. If it col = _get_first_selected_col_from_model(klass_info)
# doesn't exist, don't lock a base model.
for select_index in klass_info['select_fields']:
if self.select[select_index][0].target.model == klass_info['model']:
col = self.select[select_index][0]
break
else:
col = None
else: else:
for part in name.split(LOOKUP_SEP): for part in name.split(LOOKUP_SEP):
klass_infos = ( klass_infos = (
@ -1016,8 +1024,7 @@ class SQLCompiler:
if klass_info is None: if klass_info is None:
invalid_names.append(name) invalid_names.append(name)
continue continue
select_index = klass_info['select_fields'][0] col = _get_first_selected_col_from_model(klass_info)
col = self.select[select_index][0]
if col is not None: if col is not None:
if self.connection.features.select_for_update_of_column: if self.connection.features.select_for_update_of_column:
result.append(self.compile(col)[0]) result.append(self.compile(col)[0])

View File

@ -9,4 +9,8 @@ Django 2.2.11 fixes a data loss bug in 2.2.10.
Bugfixes Bugfixes
======== ========
* ... * Fixed a data loss possibility in the
:meth:`~django.db.models.query.QuerySet.select_for_update`. When using
related fields or parent link fields with :ref:`multi-table-inheritance` in
the ``of`` argument, the corresponding models were not locked
(:ticket:`31246`).

View File

@ -1,7 +1,11 @@
from django.db import models from django.db import models
class Country(models.Model): class Entity(models.Model):
pass
class Country(Entity):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)

View File

@ -113,7 +113,10 @@ class SelectForUpdateTests(TransactionTestCase):
)) ))
features = connections['default'].features features = connections['default'].features
if features.select_for_update_of_column: if features.select_for_update_of_column:
expected = ['select_for_update_person"."id', 'select_for_update_country"."id'] expected = [
'select_for_update_person"."id',
'select_for_update_country"."entity_ptr_id',
]
else: else:
expected = ['select_for_update_person', 'select_for_update_country'] expected = ['select_for_update_person', 'select_for_update_country']
expected = [connection.ops.quote_name(value) for value in expected] expected = [connection.ops.quote_name(value) for value in expected]
@ -137,13 +140,29 @@ class SelectForUpdateTests(TransactionTestCase):
if connection.features.select_for_update_of_column: if connection.features.select_for_update_of_column:
expected = [ expected = [
'select_for_update_eucountry"."country_ptr_id', 'select_for_update_eucountry"."country_ptr_id',
'select_for_update_country"."id', 'select_for_update_country"."entity_ptr_id',
] ]
else: else:
expected = ['select_for_update_eucountry', 'select_for_update_country'] expected = ['select_for_update_eucountry', 'select_for_update_country']
expected = [connection.ops.quote_name(value) for value in expected] expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_sql_related_model_inheritance_generated_of(self):
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
list(EUCity.objects.select_related('country').select_for_update(
of=('self', 'country'),
))
if connection.features.select_for_update_of_column:
expected = [
'select_for_update_eucity"."id',
'select_for_update_eucountry"."country_ptr_id',
]
else:
expected = ['select_for_update_eucity', 'select_for_update_eucountry']
expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of') @skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self): def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
with transaction.atomic(), CaptureQueriesContext(connection) as ctx: with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
@ -153,13 +172,29 @@ class SelectForUpdateTests(TransactionTestCase):
if connection.features.select_for_update_of_column: if connection.features.select_for_update_of_column:
expected = [ expected = [
'select_for_update_eucity"."id', 'select_for_update_eucity"."id',
'select_for_update_country"."id', 'select_for_update_country"."entity_ptr_id',
] ]
else: else:
expected = ['select_for_update_eucity', 'select_for_update_country'] expected = ['select_for_update_eucity', 'select_for_update_country']
expected = [connection.ops.quote_name(value) for value in expected] expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_sql_multilevel_model_inheritance_ptr_generated_of(self):
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
list(EUCountry.objects.select_for_update(
of=('country_ptr', 'country_ptr__entity_ptr'),
))
if connection.features.select_for_update_of_column:
expected = [
'select_for_update_country"."entity_ptr_id',
'select_for_update_entity"."id',
]
else:
expected = ['select_for_update_country', 'select_for_update_entity']
expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of') @skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_of_followed_by_values(self): def test_for_update_of_followed_by_values(self):
with transaction.atomic(): with transaction.atomic():
@ -264,7 +299,8 @@ class SelectForUpdateTests(TransactionTestCase):
msg = ( msg = (
'Invalid field name(s) given in select_for_update(of=(...)): %s. ' 'Invalid field name(s) given in select_for_update(of=(...)): %s. '
'Only relational fields followed in the query are allowed. ' 'Only relational fields followed in the query are allowed. '
'Choices are: self, born, born__country.' 'Choices are: self, born, born__country, '
'born__country__entity_ptr.'
) )
invalid_of = [ invalid_of = [
('nonexistent',), ('nonexistent',),
@ -307,13 +343,13 @@ class SelectForUpdateTests(TransactionTestCase):
) )
with self.assertRaisesMessage( with self.assertRaisesMessage(
FieldError, FieldError,
msg % 'country, country__country_ptr', msg % 'country, country__country_ptr, country__country_ptr__entity_ptr',
): ):
with transaction.atomic(): with transaction.atomic():
EUCity.objects.select_related( EUCity.objects.select_related(
'country', 'country',
).select_for_update(of=('name',)).get() ).select_for_update(of=('name',)).get()
with self.assertRaisesMessage(FieldError, msg % 'country_ptr'): with self.assertRaisesMessage(FieldError, msg % 'country_ptr, country_ptr__entity_ptr'):
with transaction.atomic(): with transaction.atomic():
EUCountry.objects.select_for_update(of=('name',)).get() EUCountry.objects.select_for_update(of=('name',)).get()