Fixed #28010 -- Added FOR UPDATE OF support to QuerySet.select_for_update().

This commit is contained in:
Ran Benita 2017-06-29 23:00:15 +03:00 committed by Tim Graham
parent 2d18c60fbb
commit b9f7dce84b
12 changed files with 206 additions and 23 deletions

View File

@ -36,6 +36,10 @@ class BaseDatabaseFeatures:
has_select_for_update = False
has_select_for_update_nowait = False
has_select_for_update_skip_locked = False
has_select_for_update_of = False
# Does the database's SELECT FOR UPDATE OF syntax require a column rather
# than a table?
select_for_update_of_column = False
supports_select_related = True

View File

@ -177,16 +177,15 @@ class BaseDatabaseOperations:
"""
return []
def for_update_sql(self, nowait=False, skip_locked=False):
def for_update_sql(self, nowait=False, skip_locked=False, of=()):
"""
Return the FOR UPDATE SQL clause to lock rows for an update operation.
"""
if nowait:
return 'FOR UPDATE NOWAIT'
elif skip_locked:
return 'FOR UPDATE SKIP LOCKED'
else:
return 'FOR UPDATE'
return 'FOR UPDATE%s%s%s' % (
' OF %s' % ', '.join(of) if of else '',
' NOWAIT' if nowait else '',
' SKIP LOCKED' if skip_locked else '',
)
def last_executed_query(self, cursor, sql, params):
"""

View File

@ -9,6 +9,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_skip_locked = True
has_select_for_update_of = True
select_for_update_of_column = True
can_return_id_from_insert = True
allow_sliced_subqueries = False
can_introspect_autofield = True

View File

@ -13,6 +13,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_defer_constraint_checks = True
has_select_for_update = True
has_select_for_update_nowait = True
has_select_for_update_of = True
has_bulk_insert = True
uses_savepoints = True
can_release_savepoints = True

View File

@ -839,7 +839,7 @@ class QuerySet:
return self
return self._combinator_query('difference', *other_qs)
def select_for_update(self, nowait=False, skip_locked=False):
def select_for_update(self, nowait=False, skip_locked=False, of=()):
"""
Return a new QuerySet instance that will select objects with a
FOR UPDATE lock.
@ -851,6 +851,7 @@ class QuerySet:
obj.query.select_for_update = True
obj.query.select_for_update_nowait = nowait
obj.query.select_for_update_skip_locked = skip_locked
obj.query.select_for_update_of = of
return obj
def select_related(self, *fields):

View File

@ -1,3 +1,4 @@
import collections
import re
from itertools import chain
@ -472,14 +473,21 @@ class SQLCompiler:
)
nowait = self.query.select_for_update_nowait
skip_locked = self.query.select_for_update_skip_locked
# If it's a NOWAIT/SKIP LOCKED query but the backend
# doesn't support it, raise a DatabaseError to prevent a
of = self.query.select_for_update_of
# If it's a NOWAIT/SKIP LOCKED/OF query but the backend
# doesn't support it, raise NotSupportedError to prevent a
# possible deadlock.
if nowait and not self.connection.features.has_select_for_update_nowait:
raise NotSupportedError('NOWAIT is not supported on this database backend.')
elif skip_locked and not self.connection.features.has_select_for_update_skip_locked:
raise NotSupportedError('SKIP LOCKED is not supported on this database backend.')
for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)
elif of and not self.connection.features.has_select_for_update_of:
raise NotSupportedError('FOR UPDATE OF is not supported on this database backend.')
for_update_part = self.connection.ops.for_update_sql(
nowait=nowait,
skip_locked=skip_locked,
of=self.get_select_for_update_of_arguments(),
)
if for_update_part and self.connection.features.for_update_after_from:
result.append(for_update_part)
@ -832,6 +840,59 @@ class SQLCompiler:
)
return related_klass_infos
def get_select_for_update_of_arguments(self):
"""
Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
the query.
"""
def _get_field_choices():
"""Yield all allowed field paths in breadth-first search order."""
queue = collections.deque([(None, self.klass_info)])
while queue:
parent_path, klass_info = queue.popleft()
if parent_path is None:
path = []
yield 'self'
else:
path = parent_path + [klass_info['field'].name]
yield LOOKUP_SEP.join(path)
queue.extend(
(path, klass_info)
for klass_info in klass_info.get('related_klass_infos', [])
)
result = []
invalid_names = []
for name in self.query.select_for_update_of:
parts = [] if name == 'self' else name.split(LOOKUP_SEP)
klass_info = self.klass_info
for part in parts:
for related_klass_info in klass_info.get('related_klass_infos', []):
if related_klass_info['field'].name == part:
klass_info = related_klass_info
break
else:
klass_info = None
break
if klass_info is None:
invalid_names.append(name)
continue
select_index = klass_info['select_fields'][0]
col = self.select[select_index][0]
if self.connection.features.select_for_update_of_column:
result.append(self.compile(col)[0])
else:
result.append(self.quote_name_unless_alias(col.alias))
if invalid_names:
raise FieldError(
'Invalid field name(s) given in select_for_update(of=(...)): %s. '
'Only relational fields followed in the query are allowed. '
'Choices are: %s.' % (
', '.join(invalid_names),
', '.join(_get_field_choices()),
)
)
return result
def deferred_to_columns(self):
"""
Convert the self.deferred_loading data structure to mapping of table

View File

@ -161,6 +161,7 @@ class Query:
self.select_for_update = False
self.select_for_update_nowait = False
self.select_for_update_skip_locked = False
self.select_for_update_of = ()
self.select_related = False
# Arbitrary limit for select_related to prevents infinite recursion.
@ -288,6 +289,7 @@ class Query:
obj.select_for_update = self.select_for_update
obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_for_update_skip_locked = self.select_for_update_skip_locked
obj.select_for_update_of = self.select_for_update_of
obj.select_related = self.select_related
obj.values_select = self.values_select
obj._annotations = self._annotations.copy() if self._annotations is not None else None

View File

@ -629,9 +629,9 @@ both MySQL and Django will attempt to convert the values from UTC to local time.
Row locking with ``QuerySet.select_for_update()``
-------------------------------------------------
MySQL does not support the ``NOWAIT`` and ``SKIP LOCKED`` options to the
``SELECT ... FOR UPDATE`` statement. If ``select_for_update()`` is used with
``nowait=True`` or ``skip_locked=True``, then a
MySQL does not support the ``NOWAIT``, ``SKIP LOCKED``, and ``OF`` options to
the ``SELECT ... FOR UPDATE`` statement. If ``select_for_update()`` is used
with ``nowait=True``, ``skip_locked=True``, or ``of`` then a
:exc:`~django.db.NotSupportedError` is raised.
Automatic typecasting can cause unexpected results

View File

@ -1611,7 +1611,7 @@ For example::
``select_for_update()``
~~~~~~~~~~~~~~~~~~~~~~~
.. method:: select_for_update(nowait=False, skip_locked=False)
.. method:: select_for_update(nowait=False, skip_locked=False, of=())
Returns a queryset that will lock rows until the end of the transaction,
generating a ``SELECT ... FOR UPDATE`` SQL statement on supported databases.
@ -1635,14 +1635,21 @@ queryset is evaluated. You can also ignore locked rows by using
``select_for_update()`` with both options enabled will result in a
:exc:`ValueError`.
By default, ``select_for_update()`` locks all rows that are selected by the
query. For example, rows of related objects specified in :meth:`select_related`
are locked in addition to rows of the queryset's model. If this isn't desired,
specify the related objects you want to lock in ``select_for_update(of=(...))``
using the same fields syntax as :meth:`select_related`. Use the value ``'self'``
to refer to the queryset's model.
Currently, the ``postgresql``, ``oracle``, and ``mysql`` database
backends support ``select_for_update()``. However, MySQL doesn't support the
``nowait`` and ``skip_locked`` arguments.
``nowait``, ``skip_locked``, and ``of`` arguments.
Passing ``nowait=True`` or ``skip_locked=True`` to ``select_for_update()``
using database backends that do not support these options, such as MySQL,
raises a :exc:`~django.db.NotSupportedError`. This prevents code from
unexpectedly blocking.
Passing ``nowait=True``, ``skip_locked=True``, or ``of`` to
``select_for_update()`` using database backends that do not support these
options, such as MySQL, raises a :exc:`~django.db.NotSupportedError`. This
prevents code from unexpectedly blocking.
Evaluating a queryset with ``select_for_update()`` in autocommit mode on
backends which support ``SELECT ... FOR UPDATE`` is a
@ -1670,6 +1677,10 @@ raised if ``select_for_update()`` is used in autocommit mode.
The ``skip_locked`` argument was added.
.. versionchanged:: 2.0
The ``of`` argument was added.
``raw()``
~~~~~~~~~

View File

@ -252,6 +252,12 @@ Models
:class:`~django.db.models.functions.datetime.Extract` now works with
:class:`~django.db.models.DurationField`.
* Added the ``of`` argument to :meth:`.QuerySet.select_for_update()`, supported
on PostgreSQL and Oracle, to lock only rows from specific tables rather than
all selected tables. It may be helpful particularly when
:meth:`~.QuerySet.select_for_update()` is used in conjunction with
:meth:`~.QuerySet.select_related()`.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
@ -331,6 +337,11 @@ backends.
* The first argument of ``SchemaEditor._create_index_name()`` is now
``table_name`` rather than ``model``.
* To enable ``FOR UPDATE OF`` support, set
``DatabaseFeatures.has_select_for_update_of = True``. If the database
requires that the arguments to ``OF`` be columns rather than tables, set
``DatabaseFeatures.select_for_update_of_column = True``.
Dropped support for Oracle 11.2
-------------------------------

View File

@ -1,5 +1,16 @@
from django.db import models
class Country(models.Model):
name = models.CharField(max_length=30)
class City(models.Model):
name = models.CharField(max_length=30)
country = models.ForeignKey(Country, models.CASCADE)
class Person(models.Model):
name = models.CharField(max_length=30)
born = models.ForeignKey(City, models.CASCADE, related_name='+')
died = models.ForeignKey(City, models.CASCADE, related_name='+')

View File

@ -4,6 +4,7 @@ from unittest import mock
from multiple_database.routers import TestRouter
from django.core.exceptions import FieldError
from django.db import (
DatabaseError, NotSupportedError, connection, connections, router,
transaction,
@ -14,7 +15,7 @@ from django.test import (
)
from django.test.utils import CaptureQueriesContext
from .models import Person
from .models import City, Country, Person
class SelectForUpdateTests(TransactionTestCase):
@ -24,7 +25,11 @@ class SelectForUpdateTests(TransactionTestCase):
def setUp(self):
# This is executed in autocommit mode so that code in
# run_select_for_update can see this data.
self.person = Person.objects.create(name='Reinhardt')
self.country1 = Country.objects.create(name='Belgium')
self.country2 = Country.objects.create(name='France')
self.city1 = City.objects.create(name='Liberchies', country=self.country1)
self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2)
self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2)
# We need another database connection in transaction to test that one
# connection issuing a SELECT ... FOR UPDATE will block.
@ -90,6 +95,29 @@ class SelectForUpdateTests(TransactionTestCase):
list(Person.objects.all().select_for_update(skip_locked=True))
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, skip_locked=True))
@skipUnlessDBFeature('has_select_for_update_of')
def test_for_update_sql_generated_of(self):
"""
The backend's FOR UPDATE OF variant appears in the generated SQL when
select_for_update() is invoked.
"""
with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
list(Person.objects.select_related(
'born__country',
).select_for_update(
of=('born__country',),
).select_for_update(
of=('self', 'born__country')
))
features = connections['default'].features
if features.select_for_update_of_column:
expected = ['"select_for_update_person"."id"', '"select_for_update_country"."id"']
else:
expected = ['"select_for_update_person"', '"select_for_update_country"']
if features.uppercases_column_names:
expected = [value.upper() for value in expected]
self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected))
@skipUnlessDBFeature('has_select_for_update_nowait')
def test_nowait_raises_error_on_block(self):
"""
@ -152,6 +180,58 @@ class SelectForUpdateTests(TransactionTestCase):
with transaction.atomic():
Person.objects.select_for_update(skip_locked=True).get()
@skipIfDBFeature('has_select_for_update_of')
@skipUnlessDBFeature('has_select_for_update')
def test_unsupported_of_raises_error(self):
"""
NotSupportedError is raised if a SELECT...FOR UPDATE OF... is run on
a database backend that supports FOR UPDATE but not OF.
"""
msg = 'FOR UPDATE OF is not supported on this database backend.'
with self.assertRaisesMessage(NotSupportedError, msg):
with transaction.atomic():
Person.objects.select_for_update(of=('self',)).get()
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
def test_unrelated_of_argument_raises_error(self):
"""
FieldError is raised if a non-relation field is specified in of=(...).
"""
msg = (
'Invalid field name(s) given in select_for_update(of=(...)): %s. '
'Only relational fields followed in the query are allowed. '
'Choices are: self, born, born__country.'
)
invalid_of = [
('nonexistent',),
('name',),
('born__nonexistent',),
('born__name',),
('born__nonexistent', 'born__name'),
]
for of in invalid_of:
with self.subTest(of=of):
with self.assertRaisesMessage(FieldError, msg % ', '.join(of)):
with transaction.atomic():
Person.objects.select_related('born__country').select_for_update(of=of).get()
@skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of')
def test_related_but_unselected_of_argument_raises_error(self):
"""
FieldError is raised if a relation field that is not followed in the
query is specified in of=(...).
"""
msg = (
'Invalid field name(s) given in select_for_update(of=(...)): %s. '
'Only relational fields followed in the query are allowed. '
'Choices are: self, born.'
)
for name in ['born__country', 'died', 'died__country']:
with self.subTest(name=name):
with self.assertRaisesMessage(FieldError, msg % name):
with transaction.atomic():
Person.objects.select_related('born').select_for_update(of=(name,)).get()
@skipUnlessDBFeature('has_select_for_update')
def test_for_update_after_from(self):
features_class = connections['default'].features.__class__
@ -182,7 +262,7 @@ class SelectForUpdateTests(TransactionTestCase):
@skipUnlessDBFeature('supports_select_for_update_with_limit')
def test_select_for_update_with_limit(self):
other = Person.objects.create(name='Grappeli')
other = Person.objects.create(name='Grappeli', born=self.city1, died=self.city2)
with transaction.atomic():
qs = list(Person.objects.all().order_by('pk').select_for_update()[1:2])
self.assertEqual(qs[0], other)