diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index e3677dd35a..65612896cb 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -229,6 +229,9 @@ class BaseDatabaseFeatures(object): # be equal? ignores_quoted_identifier_case = False + # Place FOR UPDATE right after FROM clause. Used on MSSQL. + for_update_after_from = False + def __init__(self, connection): self.connection = connection diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 5eaac1fc74..7c40fbeb38 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -402,6 +402,25 @@ class SQLCompiler(object): result.extend(from_) params.extend(f_params) + for_update_part = None + if self.query.select_for_update and self.connection.features.has_select_for_update: + if self.connection.get_autocommit(): + raise TransactionManagementError("select_for_update cannot be used outside of a transaction.") + + nowait = self.query.select_for_update_nowait + skip_locked = self.query.select_for_update_skip_locked + # If it's a NOWAIT/SKIP LOCKED query but the backend doesn't + # support it, raise a DatabaseError to prevent a possible + # deadlock. + if nowait and not self.connection.features.has_select_for_update_nowait: + raise DatabaseError('NOWAIT is not supported on this database backend.') + elif skip_locked and not self.connection.features.has_select_for_update_skip_locked: + raise DatabaseError('SKIP LOCKED is not supported on this database backend.') + for_update_part = self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked) + + if for_update_part and self.connection.features.for_update_after_from: + result.append(for_update_part) + if where: result.append('WHERE %s' % where) params.extend(w_params) @@ -439,22 +458,8 @@ class SQLCompiler(object): result.append('LIMIT %d' % val) result.append('OFFSET %d' % self.query.low_mark) - if self.query.select_for_update and self.connection.features.has_select_for_update: - if self.connection.get_autocommit(): - raise TransactionManagementError( - "select_for_update cannot be used outside of a transaction." - ) - - nowait = self.query.select_for_update_nowait - skip_locked = self.query.select_for_update_skip_locked - # If we've been asked for a NOWAIT/SKIP LOCKED query but the - # backend does not support it, raise a DatabaseError otherwise - # we could get an unexpected deadlock. - if nowait and not self.connection.features.has_select_for_update_nowait: - raise DatabaseError('NOWAIT is not supported on this database backend.') - elif skip_locked and not self.connection.features.has_select_for_update_skip_locked: - raise DatabaseError('SKIP LOCKED is not supported on this database backend.') - result.append(self.connection.ops.for_update_sql(nowait=nowait, skip_locked=skip_locked)) + if for_update_part and not self.connection.features.for_update_after_from: + result.append(for_update_part) return ' '.join(result), tuple(params) finally: diff --git a/tests/select_for_update/tests.py b/tests/select_for_update/tests.py index dcdd742ad4..e6e1155ee2 100644 --- a/tests/select_for_update/tests.py +++ b/tests/select_for_update/tests.py @@ -5,9 +5,11 @@ import time from multiple_database.routers import TestRouter -from django.db import DatabaseError, connection, router, transaction +from django.db import ( + DatabaseError, connection, connections, router, transaction, +) from django.test import ( - TransactionTestCase, override_settings, skipIfDBFeature, + TransactionTestCase, mock, override_settings, skipIfDBFeature, skipUnlessDBFeature, ) from django.test.utils import CaptureQueriesContext @@ -150,6 +152,14 @@ class SelectForUpdateTests(TransactionTestCase): with transaction.atomic(): Person.objects.select_for_update(skip_locked=True).get() + @skipUnlessDBFeature('has_select_for_update') + def test_for_update_after_from(self): + features_class = connections['default'].features.__class__ + attribute_to_patch = "%s.%s.for_update_after_from" % (features_class.__module__, features_class.__name__) + with mock.patch(attribute_to_patch, return_value=True): + with transaction.atomic(): + self.assertIn('FOR UPDATE WHERE', str(Person.objects.filter(name='foo').select_for_update().query)) + @skipUnlessDBFeature('has_select_for_update') def test_for_update_requires_transaction(self): """