[2.1.x] Fixed #30953 -- Made select_for_update() lock queryset's model when using "self" with multi-table inheritance.

Thanks Abhijeet Viswa for the report and initial patch.

Backport of 0107e3d105 from master.
This commit is contained in:
Mariusz Felisiak 2019-12-02 07:57:19 +01:00
parent ed50f6c424
commit 015fab76ad
4 changed files with 129 additions and 19 deletions

View File

@ -912,6 +912,21 @@ class SQLCompiler:
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
the query.
"""
def _get_parent_klass_info(klass_info):
return (
{
'model': parent_model,
'field': parent_link,
'reverse': False,
'select_fields': [
select_index
for select_index in klass_info['select_fields']
if self.select[select_index][0].target.model == parent_model
],
}
for parent_model, parent_link in klass_info['model']._meta.parents.items()
)
def _get_field_choices():
"""Yield all allowed field paths in breadth-first search order."""
queue = collections.deque([(None, self.klass_info)])
@ -926,6 +941,10 @@ class SQLCompiler:
field = field.remote_field
path = parent_path + [field.name]
yield LOOKUP_SEP.join(path)
queue.extend(
(path, klass_info)
for klass_info in _get_parent_klass_info(klass_info)
)
queue.extend(
(path, klass_info)
for klass_info in klass_info.get('related_klass_infos', [])
@ -933,28 +952,42 @@ class SQLCompiler:
result = []
invalid_names = []
for name in self.query.select_for_update_of:
parts = [] if name == 'self' else name.split(LOOKUP_SEP)
klass_info = self.klass_info
for part in parts:
for related_klass_info in klass_info.get('related_klass_infos', []):
field = related_klass_info['field']
if related_klass_info['reverse']:
field = field.remote_field
if field.name == part:
klass_info = related_klass_info
if name == 'self':
# Find the first selected column from a base model. If it
# doesn't exist, don't lock a base model.
for select_index in klass_info['select_fields']:
if self.select[select_index][0].target.model == klass_info['model']:
col = self.select[select_index][0]
break
else:
klass_info = None
break
if klass_info is None:
invalid_names.append(name)
continue
select_index = klass_info['select_fields'][0]
col = self.select[select_index][0]
if self.connection.features.select_for_update_of_column:
result.append(self.compile(col)[0])
col = None
else:
result.append(self.quote_name_unless_alias(col.alias))
for part in name.split(LOOKUP_SEP):
klass_infos = (
*klass_info.get('related_klass_infos', []),
*_get_parent_klass_info(klass_info),
)
for related_klass_info in klass_infos:
field = related_klass_info['field']
if related_klass_info['reverse']:
field = field.remote_field
if field.name == part:
klass_info = related_klass_info
break
else:
klass_info = None
break
if klass_info is None:
invalid_names.append(name)
continue
select_index = klass_info['select_fields'][0]
col = self.select[select_index][0]
if col is not None:
if self.connection.features.select_for_update_of_column:
result.append(self.compile(col)[0])
else:
result.append(self.quote_name_unless_alias(col.alias))
if invalid_names:
raise FieldError(
'Invalid field name(s) given in select_for_update(of=(...)): %s. '

View File

@ -1694,6 +1694,14 @@ specify the related objects you want to lock in ``select_for_update(of=(...))``
using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
to refer to the queryset's model.
.. admonition:: Lock parents models in ``select_for_update(of=(...))``
If you want to lock parents models when using :ref:`multi-table inheritance
<multi-table-inheritance>`, you must specify parent link fields (by default
``<parent_model_name>_ptr``) in the ``of`` argument. For example::
Restaurant.objects.select_for_update(of=('self', 'place_ptr'))
You can't use ``select_for_update()`` on nullable relations::
>>> Person.objects.select_related('hometown').select_for_update()

View File

@ -5,11 +5,20 @@ class Country(models.Model):
name = models.CharField(max_length=30)
class EUCountry(Country):
join_date = models.DateField()
class City(models.Model):
name = models.CharField(max_length=30)
country = models.ForeignKey(Country, models.CASCADE)
class EUCity(models.Model):
name = models.CharField(max_length=30)
country = models.ForeignKey(EUCountry, models.CASCADE)
class Person(models.Model):
name = models.CharField(max_length=30)
born = models.ForeignKey(City, models.CASCADE, related_name='+')

View File

@ -15,7 +15,7 @@ from django.test import (
)
from django.test.utils import CaptureQueriesContext
from .models import City, Country, Person, PersonProfile
from .models import City, Country, EUCity, EUCountry, Person, PersonProfile
class SelectForUpdateTests(TransactionTestCase):
@ -120,6 +120,47 @@ class SelectForUpdateTests(TransactionTestCase):
expected = [value.upper() for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_sql_model_inheritance_generated_of(self):
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
list(EUCountry.objects.select_for_update(of=('self',)))
if connection.features.select_for_update_of_column:
expected = ['select_for_update_eucountry"."country_ptr_id']
else:
expected = ['select_for_update_eucountry']
expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_sql_model_inheritance_ptr_generated_of(self):
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
list(EUCountry.objects.select_for_update(of=('self', 'country_ptr',)))
if connection.features.select_for_update_of_column:
expected = [
'select_for_update_eucountry"."country_ptr_id',
'select_for_update_country"."id',
]
else:
expected = ['select_for_update_eucountry', 'select_for_update_country']
expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
list(EUCity.objects.select_related('country').select_for_update(
of=('self', 'country__country_ptr',),
))
if connection.features.select_for_update_of_column:
expected = [
'select_for_update_eucity"."id',
'select_for_update_country"."id',
]
else:
expected = ['select_for_update_eucity', 'select_for_update_country']
expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_of_followed_by_values(self):
with transaction.atomic():
@ -258,6 +299,25 @@ class SelectForUpdateTests(TransactionTestCase):
'born', 'profile',
).exclude(profile=None).select_for_update(of=(name,)).get()
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
def test_model_inheritance_of_argument_raises_error_ptr_in_choices(self):
msg = (
'Invalid field name(s) given in select_for_update(of=(...)): '
'name. Only relational fields followed in the query are allowed. '
'Choices are: self, %s.'
)
with self.assertRaisesMessage(
FieldError,
msg % 'country, country__country_ptr',
):
with transaction.atomic():
EUCity.objects.select_related(
'country',
).select_for_update(of=('name',)).get()
with self.assertRaisesMessage(FieldError, msg % 'country_ptr'):
with transaction.atomic():
EUCountry.objects.select_for_update(of=('name',)).get()
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
def test_reverse_one_to_one_of_arguments(self):
"""