diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 4b38a2d6b1..41a456bfed 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -36,6 +36,10 @@ class BaseDatabaseFeatures: has_select_for_update = False has_select_for_update_nowait = False has_select_for_update_skip_locked = False + has_select_for_update_of = False + # Does the database's SELECT FOR UPDATE OF syntax require a column rather + # than a table? + select_for_update_of_column = False supports_select_related = True diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index c3df3465c3..cf6b5f9166 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -177,16 +177,15 @@ class BaseDatabaseOperations: """ return [] - def for_update_sql(self, nowait=False, skip_locked=False): + def for_update_sql(self, nowait=False, skip_locked=False, of=()): """ Return the FOR UPDATE SQL clause to lock rows for an update operation. """ - if nowait: - return 'FOR UPDATE NOWAIT' - elif skip_locked: - return 'FOR UPDATE SKIP LOCKED' - else: - return 'FOR UPDATE' + return 'FOR UPDATE%s%s%s' % ( + ' OF %s' % ', '.join(of) if of else '', + ' NOWAIT' if nowait else '', + ' SKIP LOCKED' if skip_locked else '', + ) def last_executed_query(self, cursor, sql, params): """ diff --git a/django/db/backends/oracle/features.py b/django/db/backends/oracle/features.py index fe6e30dc46..90584ff14f 100644 --- a/django/db/backends/oracle/features.py +++ b/django/db/backends/oracle/features.py @@ -9,6 +9,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): has_select_for_update = True has_select_for_update_nowait = True has_select_for_update_skip_locked = True + has_select_for_update_of = True + select_for_update_of_column = True can_return_id_from_insert = True allow_sliced_subqueries = False can_introspect_autofield = True diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 3f6cc7894d..0f291a6586 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -13,6 +13,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): can_defer_constraint_checks = True has_select_for_update = True has_select_for_update_nowait = True + has_select_for_update_of = True has_bulk_insert = True uses_savepoints = True can_release_savepoints = True diff --git a/django/db/models/query.py b/django/db/models/query.py index 38f69f22d1..e5e1c1b9f4 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -839,7 +839,7 @@ class QuerySet: return self return self._combinator_query('difference', *other_qs) - def select_for_update(self, nowait=False, skip_locked=False): + def select_for_update(self, nowait=False, skip_locked=False, of=()): """ Return a new QuerySet instance that will select objects with a FOR UPDATE lock. @@ -851,6 +851,7 @@ class QuerySet: obj.query.select_for_update = True obj.query.select_for_update_nowait = nowait obj.query.select_for_update_skip_locked = skip_locked + obj.query.select_for_update_of = of return obj def select_related(self, *fields): diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index b4b27a5b56..c705d33af8 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1,3 +1,4 @@ +import collections import re from itertools import chain @@ -472,14 +473,21 @@ class SQLCompiler: ) nowait = self.query.select_for_update_nowait skip_locked = self.query.select_for_update_skip_locked - # If it's a NOWAIT/SKIP LOCKED query but the backend - # doesn't support it, raise a DatabaseError to prevent a + of = self.query.select_for_update_of + # If it's a NOWAIT/SKIP LOCKED/OF query but the backend + # doesn't support it, raise NotSupportedError to prevent a # possible deadlock. if nowait and not self.connection.features.has_select_for_update_nowait: raise NotSupportedError('NOWAIT is not supported on this database backend.') elif skip_locked and not self.connection.features.has_select_for_update_skip_locked: raise NotSupportedError('SKIP LOCKED is not supported on this database backend.') - for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked) + elif of and not self.connection.features.has_select_for_update_of: + raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.') + for_update_part = self.connection.ops.for_update_sql( + nowait=nowait, + skip_locked=skip_locked, + of=self.get_select_for_update_of_arguments(), + ) if for_update_part and self.connection.features.for_update_after_from: result.append(for_update_part) @@ -832,6 +840,59 @@ class SQLCompiler: ) return related_klass_infos + def get_select_for_update_of_arguments(self): + """ + Return a quoted list of arguments for the SELECT FOR UPDATE OF part of + the query. + """ + def _get_field_choices(): + """Yield all allowed field paths in breadth-first search order.""" + queue = collections.deque([(None, self.klass_info)]) + while queue: + parent_path, klass_info = queue.popleft() + if parent_path is None: + path = [] + yield 'self' + else: + path = parent_path + [klass_info['field'].name] + yield LOOKUP_SEP.join(path) + queue.extend( + (path, klass_info) + for klass_info in klass_info.get('related_klass_infos', []) + ) + result = [] + invalid_names = [] + for name in self.query.select_for_update_of: + parts = [] if name == 'self' else name.split(LOOKUP_SEP) + klass_info = self.klass_info + for part in parts: + for related_klass_info in klass_info.get('related_klass_infos', []): + if related_klass_info['field'].name == part: + klass_info = related_klass_info + break + else: + klass_info = None + break + if klass_info is None: + invalid_names.append(name) + continue + select_index = klass_info['select_fields'][0] + col = self.select[select_index][0] + if self.connection.features.select_for_update_of_column: + result.append(self.compile(col)[0]) + else: + result.append(self.quote_name_unless_alias(col.alias)) + if invalid_names: + raise FieldError( + 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' + 'Only relational fields followed in the query are allowed. ' + 'Choices are: %s.' % ( + ', '.join(invalid_names), + ', '.join(_get_field_choices()), + ) + ) + return result + def deferred_to_columns(self): """ Convert the self.deferred_loading data structure to mapping of table diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b4a87938f7..70fd648c52 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -161,6 +161,7 @@ class Query: self.select_for_update = False self.select_for_update_nowait = False self.select_for_update_skip_locked = False + self.select_for_update_of = () self.select_related = False # Arbitrary limit for select_related to prevents infinite recursion. @@ -288,6 +289,7 @@ class Query: obj.select_for_update = self.select_for_update obj.select_for_update_nowait = self.select_for_update_nowait obj.select_for_update_skip_locked = self.select_for_update_skip_locked + obj.select_for_update_of = self.select_for_update_of obj.select_related = self.select_related obj.values_select = self.values_select obj._annotations = self._annotations.copy() if self._annotations is not None else None diff --git a/docs/ref/databases.txt b/docs/ref/databases.txt index 45b1772514..69921f437b 100644 --- a/docs/ref/databases.txt +++ b/docs/ref/databases.txt @@ -629,9 +629,9 @@ both MySQL and Django will attempt to convert the values from UTC to local time. Row locking with ``QuerySet.select_for_update()`` ------------------------------------------------- -MySQL does not support the ``NOWAIT`` and ``SKIP LOCKED`` options to the -``SELECT ... FOR UPDATE`` statement. If ``select_for_update()`` is used with -``nowait=True`` or ``skip_locked=True``, then a +MySQL does not support the ``NOWAIT``, ``SKIP LOCKED``, and ``OF`` options to +the ``SELECT ... FOR UPDATE`` statement. If ``select_for_update()`` is used +with ``nowait=True``, ``skip_locked=True``, or ``of`` then a :exc:`~django.db.NotSupportedError` is raised. Automatic typecasting can cause unexpected results diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index a9006a14a9..d8d063a7a5 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -1611,7 +1611,7 @@ For example:: ``select_for_update()`` ~~~~~~~~~~~~~~~~~~~~~~~ -.. method:: select_for_update(nowait=False, skip_locked=False) +.. method:: select_for_update(nowait=False, skip_locked=False, of=()) Returns a queryset that will lock rows until the end of the transaction, generating a ``SELECT ... FOR UPDATE`` SQL statement on supported databases. @@ -1635,14 +1635,21 @@ queryset is evaluated. You can also ignore locked rows by using ``select_for_update()`` with both options enabled will result in a :exc:`ValueError`. +By default, ``select_for_update()`` locks all rows that are selected by the +query. For example, rows of related objects specified in :meth:`select_related` +are locked in addition to rows of the queryset's model. If this isn't desired, +specify the related objects you want to lock in ``select_for_update(of=(...))`` +using the same fields syntax as :meth:`select_related`. Use the value ``'self'`` +to refer to the queryset's model. + Currently, the ``postgresql``, ``oracle``, and ``mysql`` database backends support ``select_for_update()``. However, MySQL doesn't support the -``nowait`` and ``skip_locked`` arguments. +``nowait``, ``skip_locked``, and ``of`` arguments. -Passing ``nowait=True`` or ``skip_locked=True`` to ``select_for_update()`` -using database backends that do not support these options, such as MySQL, -raises a :exc:`~django.db.NotSupportedError`. This prevents code from -unexpectedly blocking. +Passing ``nowait=True``, ``skip_locked=True``, or ``of`` to +``select_for_update()`` using database backends that do not support these +options, such as MySQL, raises a :exc:`~django.db.NotSupportedError`. This +prevents code from unexpectedly blocking. Evaluating a queryset with ``select_for_update()`` in autocommit mode on backends which support ``SELECT ... FOR UPDATE`` is a @@ -1670,6 +1677,10 @@ raised if ``select_for_update()`` is used in autocommit mode. The ``skip_locked`` argument was added. +.. versionchanged:: 2.0 + + The ``of`` argument was added. + ``raw()`` ~~~~~~~~~ diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt index 078cbbdf2c..8dfc43c24d 100644 --- a/docs/releases/2.0.txt +++ b/docs/releases/2.0.txt @@ -252,6 +252,12 @@ Models :class:`~django.db.models.functions.datetime.Extract` now works with :class:`~django.db.models.DurationField`. +* Added the ``of`` argument to :meth:`.QuerySet.select_for_update()`, supported + on PostgreSQL and Oracle, to lock only rows from specific tables rather than + all selected tables. It may be helpful particularly when + :meth:`~.QuerySet.select_for_update()` is used in conjunction with + :meth:`~.QuerySet.select_related()`. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ @@ -331,6 +337,11 @@ backends. * The first argument of ``SchemaEditor._create_index_name()`` is now ``table_name`` rather than ``model``. +* To enable ``FOR UPDATE OF`` support, set + ``DatabaseFeatures.has_select_for_update_of = True``. If the database + requires that the arguments to ``OF`` be columns rather than tables, set + ``DatabaseFeatures.select_for_update_of_column = True``. + Dropped support for Oracle 11.2 ------------------------------- diff --git a/tests/select_for_update/models.py b/tests/select_for_update/models.py index 48ad58faa9..b04ed31b00 100644 --- a/tests/select_for_update/models.py +++ b/tests/select_for_update/models.py @@ -1,5 +1,16 @@ from django.db import models +class Country(models.Model): + name = models.CharField(max_length=30) + + +class City(models.Model): + name = models.CharField(max_length=30) + country = models.ForeignKey(Country, models.CASCADE) + + class Person(models.Model): name = models.CharField(max_length=30) + born = models.ForeignKey(City, models.CASCADE, related_name='+') + died = models.ForeignKey(City, models.CASCADE, related_name='+') diff --git a/tests/select_for_update/tests.py b/tests/select_for_update/tests.py index 0c581f0f37..7228af6e8e 100644 --- a/tests/select_for_update/tests.py +++ b/tests/select_for_update/tests.py @@ -4,6 +4,7 @@ from unittest import mock from multiple_database.routers import TestRouter +from django.core.exceptions import FieldError from django.db import ( DatabaseError, NotSupportedError, connection, connections, router, transaction, @@ -14,7 +15,7 @@ from django.test import ( ) from django.test.utils import CaptureQueriesContext -from .models import Person +from .models import City, Country, Person class SelectForUpdateTests(TransactionTestCase): @@ -24,7 +25,11 @@ class SelectForUpdateTests(TransactionTestCase): def setUp(self): # This is executed in autocommit mode so that code in # run_select_for_update can see this data. - self.person = Person.objects.create(name='Reinhardt') + self.country1 = Country.objects.create(name='Belgium') + self.country2 = Country.objects.create(name='France') + self.city1 = City.objects.create(name='Liberchies', country=self.country1) + self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2) + self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2) # We need another database connection in transaction to test that one # connection issuing a SELECT ... FOR UPDATE will block. @@ -90,6 +95,29 @@ class SelectForUpdateTests(TransactionTestCase): list(Person.objects.all().select_for_update(skip_locked=True)) self.assertTrue(self.has_for_update_sql(ctx.captured_queries, skip_locked=True)) + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_sql_generated_of(self): + """ + The backend's FOR UPDATE OF variant appears in the generated SQL when + select_for_update() is invoked. + """ + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(Person.objects.select_related( + 'born__country', + ).select_for_update( + of=('born__country',), + ).select_for_update( + of=('self', 'born__country') + )) + features = connections['default'].features + if features.select_for_update_of_column: + expected = ['"select_for_update_person"."id"', '"select_for_update_country"."id"'] + else: + expected = ['"select_for_update_person"', '"select_for_update_country"'] + if features.uppercases_column_names: + expected = [value.upper() for value in expected] + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + @skipUnlessDBFeature('has_select_for_update_nowait') def test_nowait_raises_error_on_block(self): """ @@ -152,6 +180,58 @@ class SelectForUpdateTests(TransactionTestCase): with transaction.atomic(): Person.objects.select_for_update(skip_locked=True).get() + @skipIfDBFeature('has_select_for_update_of') + @skipUnlessDBFeature('has_select_for_update') + def test_unsupported_of_raises_error(self): + """ + NotSupportedError is raised if a SELECT...FOR UPDATE OF... is run on + a database backend that supports FOR UPDATE but not OF. + """ + msg = 'FOR UPDATE OF is not supported on this database backend.' + with self.assertRaisesMessage(NotSupportedError, msg): + with transaction.atomic(): + Person.objects.select_for_update(of=('self',)).get() + + @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of') + def test_unrelated_of_argument_raises_error(self): + """ + FieldError is raised if a non-relation field is specified in of=(...). + """ + msg = ( + 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' + 'Only relational fields followed in the query are allowed. ' + 'Choices are: self, born, born__country.' + ) + invalid_of = [ + ('nonexistent',), + ('name',), + ('born__nonexistent',), + ('born__name',), + ('born__nonexistent', 'born__name'), + ] + for of in invalid_of: + with self.subTest(of=of): + with self.assertRaisesMessage(FieldError, msg % ', '.join(of)): + with transaction.atomic(): + Person.objects.select_related('born__country').select_for_update(of=of).get() + + @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of') + def test_related_but_unselected_of_argument_raises_error(self): + """ + FieldError is raised if a relation field that is not followed in the + query is specified in of=(...). + """ + msg = ( + 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' + 'Only relational fields followed in the query are allowed. ' + 'Choices are: self, born.' + ) + for name in ['born__country', 'died', 'died__country']: + with self.subTest(name=name): + with self.assertRaisesMessage(FieldError, msg % name): + with transaction.atomic(): + Person.objects.select_related('born').select_for_update(of=(name,)).get() + @skipUnlessDBFeature('has_select_for_update') def test_for_update_after_from(self): features_class = connections['default'].features.__class__ @@ -182,7 +262,7 @@ class SelectForUpdateTests(TransactionTestCase): @skipUnlessDBFeature('supports_select_for_update_with_limit') def test_select_for_update_with_limit(self): - other = Person.objects.create(name='Grappeli') + other = Person.objects.create(name='Grappeli', born=self.city1, died=self.city2) with transaction.atomic(): qs = list(Person.objects.all().order_by('pk').select_for_update()[1:2]) self.assertEqual(qs[0], other)