From 0107e3d1058f653f66032f7fd3a0bd61e96bf782 Mon Sep 17 00:00:00 2001 From: Mariusz Felisiak Date: Mon, 2 Dec 2019 07:57:19 +0100 Subject: [PATCH] Fixed #30953 -- Made select_for_update() lock queryset's model when using "self" with multi-table inheritance. Thanks Abhijeet Viswa for the report and initial patch. --- django/db/models/sql/compiler.py | 69 +++++++++++++++++++++++-------- docs/ref/models/querysets.txt | 8 ++++ docs/releases/2.2.8.txt | 6 +++ tests/select_for_update/models.py | 9 ++++ tests/select_for_update/tests.py | 62 ++++++++++++++++++++++++++- 5 files changed, 135 insertions(+), 19 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index b3d0e0ac2d..3d434d5909 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -953,6 +953,21 @@ class SQLCompiler: Return a quoted list of arguments for the SELECT FOR UPDATE OF part of the query. """ + def _get_parent_klass_info(klass_info): + return ( + { + 'model': parent_model, + 'field': parent_link, + 'reverse': False, + 'select_fields': [ + select_index + for select_index in klass_info['select_fields'] + if self.select[select_index][0].target.model == parent_model + ], + } + for parent_model, parent_link in klass_info['model']._meta.parents.items() + ) + def _get_field_choices(): """Yield all allowed field paths in breadth-first search order.""" queue = collections.deque([(None, self.klass_info)]) @@ -967,6 +982,10 @@ class SQLCompiler: field = field.remote_field path = parent_path + [field.name] yield LOOKUP_SEP.join(path) + queue.extend( + (path, klass_info) + for klass_info in _get_parent_klass_info(klass_info) + ) queue.extend( (path, klass_info) for klass_info in klass_info.get('related_klass_infos', []) @@ -974,28 +993,42 @@ class SQLCompiler: result = [] invalid_names = [] for name in self.query.select_for_update_of: - parts = [] if name == 'self' else name.split(LOOKUP_SEP) klass_info = self.klass_info - for part in parts: - for related_klass_info in klass_info.get('related_klass_infos', []): - field = related_klass_info['field'] - if related_klass_info['reverse']: - field = field.remote_field - if field.name == part: - klass_info = related_klass_info + if name == 'self': + # Find the first selected column from a base model. If it + # doesn't exist, don't lock a base model. + for select_index in klass_info['select_fields']: + if self.select[select_index][0].target.model == klass_info['model']: + col = self.select[select_index][0] break else: - klass_info = None - break - if klass_info is None: - invalid_names.append(name) - continue - select_index = klass_info['select_fields'][0] - col = self.select[select_index][0] - if self.connection.features.select_for_update_of_column: - result.append(self.compile(col)[0]) + col = None else: - result.append(self.quote_name_unless_alias(col.alias)) + for part in name.split(LOOKUP_SEP): + klass_infos = ( + *klass_info.get('related_klass_infos', []), + *_get_parent_klass_info(klass_info), + ) + for related_klass_info in klass_infos: + field = related_klass_info['field'] + if related_klass_info['reverse']: + field = field.remote_field + if field.name == part: + klass_info = related_klass_info + break + else: + klass_info = None + break + if klass_info is None: + invalid_names.append(name) + continue + select_index = klass_info['select_fields'][0] + col = self.select[select_index][0] + if col is not None: + if self.connection.features.select_for_update_of_column: + result.append(self.compile(col)[0]) + else: + result.append(self.quote_name_unless_alias(col.alias)) if invalid_names: raise FieldError( 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index b8dcb90fec..ba20c7516a 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -1692,6 +1692,14 @@ specify the related objects you want to lock in ``select_for_update(of=(...))`` using the same fields syntax as :meth:`select_related`. Use the value ``'self'`` to refer to the queryset's model. +.. admonition:: Lock parents models in ``select_for_update(of=(...))`` + + If you want to lock parents models when using :ref:`multi-table inheritance + `, you must specify parent link fields (by default + ``_ptr``) in the ``of`` argument. For example:: + + Restaurant.objects.select_for_update(of=('self', 'place_ptr')) + You can't use ``select_for_update()`` on nullable relations:: >>> Person.objects.select_related('hometown').select_for_update() diff --git a/docs/releases/2.2.8.txt b/docs/releases/2.2.8.txt index 4d8f9869c5..3c5eb5c754 100644 --- a/docs/releases/2.2.8.txt +++ b/docs/releases/2.2.8.txt @@ -17,3 +17,9 @@ Bugfixes * Fixed a regression in Django 2.2.1 that caused a crash when migrating permissions for proxy models with a multiple database setup if the ``default`` entry was empty (:ticket:`31021`). + +* Fixed a data loss possibility in the + :meth:`~django.db.models.query.QuerySet.select_for_update()`. When using + ``'self'`` in the ``of`` argument with :ref:`multi-table inheritance + `, a parent model was locked instead of the + queryset's model (:ticket:`30953`). diff --git a/tests/select_for_update/models.py b/tests/select_for_update/models.py index b8154af3df..c84f9ad6b2 100644 --- a/tests/select_for_update/models.py +++ b/tests/select_for_update/models.py @@ -5,11 +5,20 @@ class Country(models.Model): name = models.CharField(max_length=30) +class EUCountry(Country): + join_date = models.DateField() + + class City(models.Model): name = models.CharField(max_length=30) country = models.ForeignKey(Country, models.CASCADE) +class EUCity(models.Model): + name = models.CharField(max_length=30) + country = models.ForeignKey(EUCountry, models.CASCADE) + + class Person(models.Model): name = models.CharField(max_length=30) born = models.ForeignKey(City, models.CASCADE, related_name='+') diff --git a/tests/select_for_update/tests.py b/tests/select_for_update/tests.py index 7859775cff..0bb21972d1 100644 --- a/tests/select_for_update/tests.py +++ b/tests/select_for_update/tests.py @@ -15,7 +15,7 @@ from django.test import ( ) from django.test.utils import CaptureQueriesContext -from .models import City, Country, Person, PersonProfile +from .models import City, Country, EUCity, EUCountry, Person, PersonProfile class SelectForUpdateTests(TransactionTestCase): @@ -119,6 +119,47 @@ class SelectForUpdateTests(TransactionTestCase): expected = [connection.ops.quote_name(value) for value in expected] self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_sql_model_inheritance_generated_of(self): + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(EUCountry.objects.select_for_update(of=('self',))) + if connection.features.select_for_update_of_column: + expected = ['select_for_update_eucountry"."country_ptr_id'] + else: + expected = ['select_for_update_eucountry'] + expected = [connection.ops.quote_name(value) for value in expected] + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_sql_model_inheritance_ptr_generated_of(self): + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(EUCountry.objects.select_for_update(of=('self', 'country_ptr',))) + if connection.features.select_for_update_of_column: + expected = [ + 'select_for_update_eucountry"."country_ptr_id', + 'select_for_update_country"."id', + ] + else: + expected = ['select_for_update_eucountry', 'select_for_update_country'] + expected = [connection.ops.quote_name(value) for value in expected] + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self): + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(EUCity.objects.select_related('country').select_for_update( + of=('self', 'country__country_ptr',), + )) + if connection.features.select_for_update_of_column: + expected = [ + 'select_for_update_eucity"."id', + 'select_for_update_country"."id', + ] + else: + expected = ['select_for_update_eucity', 'select_for_update_country'] + expected = [connection.ops.quote_name(value) for value in expected] + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + @skipUnlessDBFeature('has_select_for_update_of') def test_for_update_of_followed_by_values(self): with transaction.atomic(): @@ -257,6 +298,25 @@ class SelectForUpdateTests(TransactionTestCase): 'born', 'profile', ).exclude(profile=None).select_for_update(of=(name,)).get() + @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of') + def test_model_inheritance_of_argument_raises_error_ptr_in_choices(self): + msg = ( + 'Invalid field name(s) given in select_for_update(of=(...)): ' + 'name. Only relational fields followed in the query are allowed. ' + 'Choices are: self, %s.' + ) + with self.assertRaisesMessage( + FieldError, + msg % 'country, country__country_ptr', + ): + with transaction.atomic(): + EUCity.objects.select_related( + 'country', + ).select_for_update(of=('name',)).get() + with self.assertRaisesMessage(FieldError, msg % 'country_ptr'): + with transaction.atomic(): + EUCountry.objects.select_for_update(of=('name',)).get() + @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of') def test_reverse_one_to_one_of_arguments(self): """