diff --git a/AUTHORS b/AUTHORS
index 2fd68834d85..de8731b6aac 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -9,6 +9,7 @@ answer newbie questions, and generally made Django that much better:
Aaron Swartz
Aaron T. Myers
Abeer Upadhyay
+ Abhijeet Viswa
Abhinav Patil
Abhishek Gautam
Adam Allred
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index aee5cb81cd7..92213a4e670 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -961,19 +961,34 @@ class SQLCompiler:
the query.
"""
def _get_parent_klass_info(klass_info):
- return (
- {
+ for parent_model, parent_link in klass_info['model']._meta.parents.items():
+ parent_list = parent_model._meta.get_parent_list()
+ yield {
'model': parent_model,
'field': parent_link,
'reverse': False,
'select_fields': [
select_index
for select_index in klass_info['select_fields']
- if self.select[select_index][0].target.model == parent_model
+ # Selected columns from a model or its parents.
+ if (
+ self.select[select_index][0].target.model == parent_model or
+ self.select[select_index][0].target.model in parent_list
+ )
],
}
- for parent_model, parent_link in klass_info['model']._meta.parents.items()
- )
+
+ def _get_first_selected_col_from_model(klass_info):
+ """
+ Find the first selected column from a model. If it doesn't exist,
+ don't lock a model.
+
+ select_fields is filled recursively, so it also contains fields
+ from the parent models.
+ """
+ for select_index in klass_info['select_fields']:
+ if self.select[select_index][0].target.model == klass_info['model']:
+ return self.select[select_index][0]
def _get_field_choices():
"""Yield all allowed field paths in breadth-first search order."""
@@ -1002,14 +1017,7 @@ class SQLCompiler:
for name in self.query.select_for_update_of:
klass_info = self.klass_info
if name == 'self':
- # Find the first selected column from a base model. If it
- # doesn't exist, don't lock a base model.
- for select_index in klass_info['select_fields']:
- if self.select[select_index][0].target.model == klass_info['model']:
- col = self.select[select_index][0]
- break
- else:
- col = None
+ col = _get_first_selected_col_from_model(klass_info)
else:
for part in name.split(LOOKUP_SEP):
klass_infos = (
@@ -1029,8 +1037,7 @@ class SQLCompiler:
if klass_info is None:
invalid_names.append(name)
continue
- select_index = klass_info['select_fields'][0]
- col = self.select[select_index][0]
+ col = _get_first_selected_col_from_model(klass_info)
if col is not None:
if self.connection.features.select_for_update_of_column:
result.append(self.compile(col)[0])
diff --git a/docs/releases/2.2.11.txt b/docs/releases/2.2.11.txt
index 5aaa5deab0c..b14d961ac36 100644
--- a/docs/releases/2.2.11.txt
+++ b/docs/releases/2.2.11.txt
@@ -9,4 +9,8 @@ Django 2.2.11 fixes a data loss bug in 2.2.10.
Bugfixes
========
-* ...
+* Fixed a data loss possibility in the
+ :meth:`~django.db.models.query.QuerySet.select_for_update`. When using
+ related fields or parent link fields with :ref:`multi-table-inheritance` in
+ the ``of`` argument, the corresponding models were not locked
+ (:ticket:`31246`).
diff --git a/docs/releases/3.0.4.txt b/docs/releases/3.0.4.txt
index c24d8f7a6a8..7be1ed15cee 100644
--- a/docs/releases/3.0.4.txt
+++ b/docs/releases/3.0.4.txt
@@ -14,3 +14,9 @@ Bugfixes
* Fixed a regression in Django 3.0 that caused a file response using a
temporary file to be closed incorrectly (:ticket:`31240`).
+
+* Fixed a data loss possibility in the
+ :meth:`~django.db.models.query.QuerySet.select_for_update`. When using
+ related fields or parent link fields with :ref:`multi-table-inheritance` in
+ the ``of`` argument, the corresponding models were not locked
+ (:ticket:`31246`).
diff --git a/tests/select_for_update/models.py b/tests/select_for_update/models.py
index c84f9ad6b29..305e8cac490 100644
--- a/tests/select_for_update/models.py
+++ b/tests/select_for_update/models.py
@@ -1,7 +1,11 @@
from django.db import models
-class Country(models.Model):
+class Entity(models.Model):
+ pass
+
+
+class Country(Entity):
name = models.CharField(max_length=30)
diff --git a/tests/select_for_update/tests.py b/tests/select_for_update/tests.py
index 0bb21972d10..3622a95c11a 100644
--- a/tests/select_for_update/tests.py
+++ b/tests/select_for_update/tests.py
@@ -113,7 +113,10 @@ class SelectForUpdateTests(TransactionTestCase):
))
features = connections['default'].features
if features.select_for_update_of_column:
- expected = ['select_for_update_person"."id', 'select_for_update_country"."id']
+ expected = [
+ 'select_for_update_person"."id',
+ 'select_for_update_country"."entity_ptr_id',
+ ]
else:
expected = ['select_for_update_person', 'select_for_update_country']
expected = [connection.ops.quote_name(value) for value in expected]
@@ -137,13 +140,29 @@ class SelectForUpdateTests(TransactionTestCase):
if connection.features.select_for_update_of_column:
expected = [
'select_for_update_eucountry"."country_ptr_id',
- 'select_for_update_country"."id',
+ 'select_for_update_country"."entity_ptr_id',
]
else:
expected = ['select_for_update_eucountry', 'select_for_update_country']
expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+ @skipUnlessDBFeature('has_select_for_update_of')
+ def test_for_update_sql_related_model_inheritance_generated_of(self):
+ with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+ list(EUCity.objects.select_related('country').select_for_update(
+ of=('self', 'country'),
+ ))
+ if connection.features.select_for_update_of_column:
+ expected = [
+ 'select_for_update_eucity"."id',
+ 'select_for_update_eucountry"."country_ptr_id',
+ ]
+ else:
+ expected = ['select_for_update_eucity', 'select_for_update_eucountry']
+ expected = [connection.ops.quote_name(value) for value in expected]
+ self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self):
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
@@ -153,13 +172,29 @@ class SelectForUpdateTests(TransactionTestCase):
if connection.features.select_for_update_of_column:
expected = [
'select_for_update_eucity"."id',
- 'select_for_update_country"."id',
+ 'select_for_update_country"."entity_ptr_id',
]
else:
expected = ['select_for_update_eucity', 'select_for_update_country']
expected = [connection.ops.quote_name(value) for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+ @skipUnlessDBFeature('has_select_for_update_of')
+ def test_for_update_sql_multilevel_model_inheritance_ptr_generated_of(self):
+ with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
+ list(EUCountry.objects.select_for_update(
+ of=('country_ptr', 'country_ptr__entity_ptr'),
+ ))
+ if connection.features.select_for_update_of_column:
+ expected = [
+ 'select_for_update_country"."entity_ptr_id',
+ 'select_for_update_entity"."id',
+ ]
+ else:
+ expected = ['select_for_update_country', 'select_for_update_entity']
+ expected = [connection.ops.quote_name(value) for value in expected]
+ self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
+
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_of_followed_by_values(self):
with transaction.atomic():
@@ -264,7 +299,8 @@ class SelectForUpdateTests(TransactionTestCase):
msg = (
'Invalid field name(s) given in select_for_update(of=(...)): %s. '
'Only relational fields followed in the query are allowed. '
- 'Choices are: self, born, born__country.'
+ 'Choices are: self, born, born__country, '
+ 'born__country__entity_ptr.'
)
invalid_of = [
('nonexistent',),
@@ -307,13 +343,13 @@ class SelectForUpdateTests(TransactionTestCase):
)
with self.assertRaisesMessage(
FieldError,
- msg % 'country, country__country_ptr',
+ msg % 'country, country__country_ptr, country__country_ptr__entity_ptr',
):
with transaction.atomic():
EUCity.objects.select_related(
'country',
).select_for_update(of=('name',)).get()
- with self.assertRaisesMessage(FieldError, msg % 'country_ptr'):
+ with self.assertRaisesMessage(FieldError, msg % 'country_ptr, country_ptr__entity_ptr'):
with transaction.atomic():
EUCountry.objects.select_for_update(of=('name',)).get()