django/tests/select_for_update/tests.py

260 lines
9.4 KiB
Python

from __future__ import unicode_literals
import threading
import time
from multiple_database.routers import TestRouter
from django.conf import settings
from django.db import connection, router, transaction
from django.db.utils import DEFAULT_DB_ALIAS, ConnectionHandler, DatabaseError
from django.test import (
TransactionTestCase, override_settings, skipIfDBFeature,
skipUnlessDBFeature,
)
from .models import Person
# We need to set settings.DEBUG to True so we can capture the output SQL
# to examine.
@override_settings(DEBUG=True)
class SelectForUpdateTests(TransactionTestCase):
available_apps = ['select_for_update']
def setUp(self):
# This is executed in autocommit mode so that code in
# run_select_for_update can see this data.
self.person = Person.objects.create(name='Reinhardt')
# We need another database connection in transaction to test that one
# connection issuing a SELECT ... FOR UPDATE will block.
new_connections = ConnectionHandler(settings.DATABASES)
self.new_connection = new_connections[DEFAULT_DB_ALIAS]
def tearDown(self):
try:
self.end_blocking_transaction()
except (DatabaseError, AttributeError):
pass
self.new_connection.close()
def start_blocking_transaction(self):
self.new_connection.set_autocommit(False)
# Start a blocking transaction. At some point,
# end_blocking_transaction() should be called.
self.cursor = self.new_connection.cursor()
sql = 'SELECT * FROM %(db_table)s %(for_update)s;' % {
'db_table': Person._meta.db_table,
'for_update': self.new_connection.ops.for_update_sql(),
}
self.cursor.execute(sql, ())
self.cursor.fetchone()
def end_blocking_transaction(self):
# Roll back the blocking transaction.
self.new_connection.rollback()
self.new_connection.set_autocommit(True)
def has_for_update_sql(self, tested_connection, nowait=False):
# Examine the SQL that was executed to determine whether it
# contains the 'SELECT..FOR UPDATE' stanza.
for_update_sql = tested_connection.ops.for_update_sql(nowait)
sql = tested_connection.queries[-1]['sql']
return bool(sql.find(for_update_sql) > -1)
@skipUnlessDBFeature('has_select_for_update')
def test_for_update_sql_generated(self):
"""
Test that the backend's FOR UPDATE variant appears in
generated SQL when select_for_update is invoked.
"""
with transaction.atomic():
list(Person.objects.all().select_for_update())
self.assertTrue(self.has_for_update_sql(connection))
@skipUnlessDBFeature('has_select_for_update_nowait')
def test_for_update_sql_generated_nowait(self):
"""
Test that the backend's FOR UPDATE NOWAIT variant appears in
generated SQL when select_for_update is invoked.
"""
with transaction.atomic():
list(Person.objects.all().select_for_update(nowait=True))
self.assertTrue(self.has_for_update_sql(connection, nowait=True))
@skipUnlessDBFeature('has_select_for_update_nowait')
def test_nowait_raises_error_on_block(self):
"""
If nowait is specified, we expect an error to be raised rather
than blocking.
"""
self.start_blocking_transaction()
status = []
thread = threading.Thread(
target=self.run_select_for_update,
args=(status,),
kwargs={'nowait': True},
)
thread.start()
time.sleep(1)
thread.join()
self.end_blocking_transaction()
self.assertIsInstance(status[-1], DatabaseError)
@skipIfDBFeature('has_select_for_update_nowait')
@skipUnlessDBFeature('has_select_for_update')
def test_unsupported_nowait_raises_error(self):
"""
If a SELECT...FOR UPDATE NOWAIT is run on a database backend
that supports FOR UPDATE but not NOWAIT, then we should find
that a DatabaseError is raised.
"""
self.assertRaises(
DatabaseError,
list,
Person.objects.all().select_for_update(nowait=True)
)
@skipUnlessDBFeature('has_select_for_update')
def test_for_update_requires_transaction(self):
"""
Test that a TransactionManagementError is raised
when a select_for_update query is executed outside of a transaction.
"""
with self.assertRaises(transaction.TransactionManagementError):
list(Person.objects.all().select_for_update())
@skipUnlessDBFeature('has_select_for_update')
def test_for_update_requires_transaction_only_in_execution(self):
"""
Test that no TransactionManagementError is raised
when select_for_update is invoked outside of a transaction -
only when the query is executed.
"""
people = Person.objects.all().select_for_update()
with self.assertRaises(transaction.TransactionManagementError):
list(people)
def run_select_for_update(self, status, nowait=False):
"""
Utility method that runs a SELECT FOR UPDATE against all
Person instances. After the select_for_update, it attempts
to update the name of the only record, save, and commit.
This function expects to run in a separate thread.
"""
status.append('started')
try:
# We need to enter transaction management again, as this is done on
# per-thread basis
with transaction.atomic():
people = list(
Person.objects.all().select_for_update(nowait=nowait)
)
people[0].name = 'Fred'
people[0].save()
except DatabaseError as e:
status.append(e)
finally:
# This method is run in a separate thread. It uses its own
# database connection. Close it without waiting for the GC.
connection.close()
@skipUnlessDBFeature('has_select_for_update')
@skipUnlessDBFeature('supports_transactions')
def test_block(self):
"""
Check that a thread running a select_for_update that
accesses rows being touched by a similar operation
on another connection blocks correctly.
"""
# First, let's start the transaction in our thread.
self.start_blocking_transaction()
# Now, try it again using the ORM's select_for_update
# facility. Do this in a separate thread.
status = []
thread = threading.Thread(
target=self.run_select_for_update, args=(status,)
)
# The thread should immediately block, but we'll sleep
# for a bit to make sure.
thread.start()
sanity_count = 0
while len(status) != 1 and sanity_count < 10:
sanity_count += 1
time.sleep(1)
if sanity_count >= 10:
raise ValueError('Thread did not run and block')
# Check the person hasn't been updated. Since this isn't
# using FOR UPDATE, it won't block.
p = Person.objects.get(pk=self.person.pk)
self.assertEqual('Reinhardt', p.name)
# When we end our blocking transaction, our thread should
# be able to continue.
self.end_blocking_transaction()
thread.join(5.0)
# Check the thread has finished. Assuming it has, we should
# find that it has updated the person's name.
self.assertFalse(thread.isAlive())
# We must commit the transaction to ensure that MySQL gets a fresh read,
# since by default it runs in REPEATABLE READ mode
transaction.commit()
p = Person.objects.get(pk=self.person.pk)
self.assertEqual('Fred', p.name)
@skipUnlessDBFeature('has_select_for_update')
def test_raw_lock_not_available(self):
"""
Check that running a raw query which can't obtain a FOR UPDATE lock
raises the correct exception
"""
self.start_blocking_transaction()
def raw(status):
try:
list(
Person.objects.raw(
'SELECT * FROM %s %s' % (
Person._meta.db_table,
connection.ops.for_update_sql(nowait=True)
)
)
)
except DatabaseError as e:
status.append(e)
finally:
# This method is run in a separate thread. It uses its own
# database connection. Close it without waiting for the GC.
connection.close()
status = []
thread = threading.Thread(target=raw, kwargs={'status': status})
thread.start()
time.sleep(1)
thread.join()
self.end_blocking_transaction()
self.assertIsInstance(status[-1], DatabaseError)
@skipUnlessDBFeature('has_select_for_update')
@override_settings(DATABASE_ROUTERS=[TestRouter()])
def test_select_for_update_on_multidb(self):
query = Person.objects.select_for_update()
self.assertEqual(router.db_for_write(Person), query.db)
@skipUnlessDBFeature('has_select_for_update')
def test_select_for_update_with_get(self):
with transaction.atomic():
person = Person.objects.select_for_update().get(name='Reinhardt')
self.assertEqual(person.name, 'Reinhardt')