From 8faaaf4e719531f2fd7f390a8af33ef2458f5427 Mon Sep 17 00:00:00 2001 From: Abhijeet Viswa Date: Sat, 8 Feb 2020 10:22:09 +0530 Subject: [PATCH] [3.0.x] Fixed #31246 -- Fixed locking models in QuerySet.select_for_update(of=()) for related fields and parent link fields with multi-table inheritance. Partly regression in 0107e3d1058f653f66032f7fd3a0bd61e96bf782. Backport of 1712a76b9dfda1ef220395e62ea87079da8c9f6c from master --- AUTHORS | 1 + django/db/models/sql/compiler.py | 37 ++++++++++++++---------- docs/releases/2.2.11.txt | 6 +++- docs/releases/3.0.4.txt | 6 ++++ tests/select_for_update/models.py | 6 +++- tests/select_for_update/tests.py | 48 +++++++++++++++++++++++++++---- 6 files changed, 81 insertions(+), 23 deletions(-) diff --git a/AUTHORS b/AUTHORS index 85ae591fc9c..2431fe61e14 100644 --- a/AUTHORS +++ b/AUTHORS @@ -9,6 +9,7 @@ answer newbie questions, and generally made Django that much better: Aaron Swartz Aaron T. Myers Abeer Upadhyay + Abhijeet Viswa Abhinav Patil Abhishek Gautam Adam Allred diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index be2d590d846..18365f1d752 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -972,19 +972,34 @@ class SQLCompiler: the query. """ def _get_parent_klass_info(klass_info): - return ( - { + for parent_model, parent_link in klass_info['model']._meta.parents.items(): + parent_list = parent_model._meta.get_parent_list() + yield { 'model': parent_model, 'field': parent_link, 'reverse': False, 'select_fields': [ select_index for select_index in klass_info['select_fields'] - if self.select[select_index][0].target.model == parent_model + # Selected columns from a model or its parents. + if ( + self.select[select_index][0].target.model == parent_model or + self.select[select_index][0].target.model in parent_list + ) ], } - for parent_model, parent_link in klass_info['model']._meta.parents.items() - ) + + def _get_first_selected_col_from_model(klass_info): + """ + Find the first selected column from a model. If it doesn't exist, + don't lock a model. + + select_fields is filled recursively, so it also contains fields + from the parent models. + """ + for select_index in klass_info['select_fields']: + if self.select[select_index][0].target.model == klass_info['model']: + return self.select[select_index][0] def _get_field_choices(): """Yield all allowed field paths in breadth-first search order.""" @@ -1013,14 +1028,7 @@ class SQLCompiler: for name in self.query.select_for_update_of: klass_info = self.klass_info if name == 'self': - # Find the first selected column from a base model. If it - # doesn't exist, don't lock a base model. - for select_index in klass_info['select_fields']: - if self.select[select_index][0].target.model == klass_info['model']: - col = self.select[select_index][0] - break - else: - col = None + col = _get_first_selected_col_from_model(klass_info) else: for part in name.split(LOOKUP_SEP): klass_infos = ( @@ -1040,8 +1048,7 @@ class SQLCompiler: if klass_info is None: invalid_names.append(name) continue - select_index = klass_info['select_fields'][0] - col = self.select[select_index][0] + col = _get_first_selected_col_from_model(klass_info) if col is not None: if self.connection.features.select_for_update_of_column: result.append(self.compile(col)[0]) diff --git a/docs/releases/2.2.11.txt b/docs/releases/2.2.11.txt index 5aaa5deab0c..b14d961ac36 100644 --- a/docs/releases/2.2.11.txt +++ b/docs/releases/2.2.11.txt @@ -9,4 +9,8 @@ Django 2.2.11 fixes a data loss bug in 2.2.10. Bugfixes ======== -* ... +* Fixed a data loss possibility in the + :meth:`~django.db.models.query.QuerySet.select_for_update`. When using + related fields or parent link fields with :ref:`multi-table-inheritance` in + the ``of`` argument, the corresponding models were not locked + (:ticket:`31246`). diff --git a/docs/releases/3.0.4.txt b/docs/releases/3.0.4.txt index c24d8f7a6a8..7be1ed15cee 100644 --- a/docs/releases/3.0.4.txt +++ b/docs/releases/3.0.4.txt @@ -14,3 +14,9 @@ Bugfixes * Fixed a regression in Django 3.0 that caused a file response using a temporary file to be closed incorrectly (:ticket:`31240`). + +* Fixed a data loss possibility in the + :meth:`~django.db.models.query.QuerySet.select_for_update`. When using + related fields or parent link fields with :ref:`multi-table-inheritance` in + the ``of`` argument, the corresponding models were not locked + (:ticket:`31246`). diff --git a/tests/select_for_update/models.py b/tests/select_for_update/models.py index c84f9ad6b29..305e8cac490 100644 --- a/tests/select_for_update/models.py +++ b/tests/select_for_update/models.py @@ -1,7 +1,11 @@ from django.db import models -class Country(models.Model): +class Entity(models.Model): + pass + + +class Country(Entity): name = models.CharField(max_length=30) diff --git a/tests/select_for_update/tests.py b/tests/select_for_update/tests.py index 0bb21972d10..3622a95c11a 100644 --- a/tests/select_for_update/tests.py +++ b/tests/select_for_update/tests.py @@ -113,7 +113,10 @@ class SelectForUpdateTests(TransactionTestCase): )) features = connections['default'].features if features.select_for_update_of_column: - expected = ['select_for_update_person"."id', 'select_for_update_country"."id'] + expected = [ + 'select_for_update_person"."id', + 'select_for_update_country"."entity_ptr_id', + ] else: expected = ['select_for_update_person', 'select_for_update_country'] expected = [connection.ops.quote_name(value) for value in expected] @@ -137,13 +140,29 @@ class SelectForUpdateTests(TransactionTestCase): if connection.features.select_for_update_of_column: expected = [ 'select_for_update_eucountry"."country_ptr_id', - 'select_for_update_country"."id', + 'select_for_update_country"."entity_ptr_id', ] else: expected = ['select_for_update_eucountry', 'select_for_update_country'] expected = [connection.ops.quote_name(value) for value in expected] self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_sql_related_model_inheritance_generated_of(self): + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(EUCity.objects.select_related('country').select_for_update( + of=('self', 'country'), + )) + if connection.features.select_for_update_of_column: + expected = [ + 'select_for_update_eucity"."id', + 'select_for_update_eucountry"."country_ptr_id', + ] + else: + expected = ['select_for_update_eucity', 'select_for_update_eucountry'] + expected = [connection.ops.quote_name(value) for value in expected] + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + @skipUnlessDBFeature('has_select_for_update_of') def test_for_update_sql_model_inheritance_nested_ptr_generated_of(self): with transaction.atomic(), CaptureQueriesContext(connection) as ctx: @@ -153,13 +172,29 @@ class SelectForUpdateTests(TransactionTestCase): if connection.features.select_for_update_of_column: expected = [ 'select_for_update_eucity"."id', - 'select_for_update_country"."id', + 'select_for_update_country"."entity_ptr_id', ] else: expected = ['select_for_update_eucity', 'select_for_update_country'] expected = [connection.ops.quote_name(value) for value in expected] self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_sql_multilevel_model_inheritance_ptr_generated_of(self): + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(EUCountry.objects.select_for_update( + of=('country_ptr', 'country_ptr__entity_ptr'), + )) + if connection.features.select_for_update_of_column: + expected = [ + 'select_for_update_country"."entity_ptr_id', + 'select_for_update_entity"."id', + ] + else: + expected = ['select_for_update_country', 'select_for_update_entity'] + expected = [connection.ops.quote_name(value) for value in expected] + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + @skipUnlessDBFeature('has_select_for_update_of') def test_for_update_of_followed_by_values(self): with transaction.atomic(): @@ -264,7 +299,8 @@ class SelectForUpdateTests(TransactionTestCase): msg = ( 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' 'Only relational fields followed in the query are allowed. ' - 'Choices are: self, born, born__country.' + 'Choices are: self, born, born__country, ' + 'born__country__entity_ptr.' ) invalid_of = [ ('nonexistent',), @@ -307,13 +343,13 @@ class SelectForUpdateTests(TransactionTestCase): ) with self.assertRaisesMessage( FieldError, - msg % 'country, country__country_ptr', + msg % 'country, country__country_ptr, country__country_ptr__entity_ptr', ): with transaction.atomic(): EUCity.objects.select_related( 'country', ).select_for_update(of=('name',)).get() - with self.assertRaisesMessage(FieldError, msg % 'country_ptr'): + with self.assertRaisesMessage(FieldError, msg % 'country_ptr, country_ptr__entity_ptr'): with transaction.atomic(): EUCountry.objects.select_for_update(of=('name',)).get()