Fixed #27077 -- Made SQLite's in-memory database checks DRYer.

This commit is contained in:
Chris Jerdonek 2016-08-17 17:34:18 -07:00 committed by Tim Graham
parent 35ea6d83c8
commit 3d0a3c5fff
3 changed files with 15 additions and 10 deletions

View File

@ -233,7 +233,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# If database is in memory, closing the connection destroys the # If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on # database. To prevent accidental data loss, ignore close requests on
# an in-memory db. # an in-memory db.
if not self.is_in_memory_db(self.settings_dict['NAME']): if not self.is_in_memory_db():
BaseDatabaseWrapper.close(self) BaseDatabaseWrapper.close(self)
def _savepoint_allowed(self): def _savepoint_allowed(self):
@ -319,8 +319,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
""" """
self.cursor().execute("BEGIN") self.cursor().execute("BEGIN")
def is_in_memory_db(self, name): def is_in_memory_db(self):
return name == ":memory:" or "mode=memory" in force_text(name) return self.creation.is_in_memory_db(self.settings_dict['NAME'])
FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s') FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s')

View File

@ -4,11 +4,16 @@ import sys
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.creation import BaseDatabaseCreation from django.db.backends.base.creation import BaseDatabaseCreation
from django.utils.encoding import force_text
from django.utils.six.moves import input from django.utils.six.moves import input
class DatabaseCreation(BaseDatabaseCreation): class DatabaseCreation(BaseDatabaseCreation):
@staticmethod
def is_in_memory_db(database_name):
return database_name == ':memory:' or 'mode=memory' in force_text(database_name)
def _get_test_db_name(self): def _get_test_db_name(self):
test_database_name = self.connection.settings_dict['TEST']['NAME'] test_database_name = self.connection.settings_dict['TEST']['NAME']
can_share_in_memory_db = self.connection.features.can_share_in_memory_db can_share_in_memory_db = self.connection.features.can_share_in_memory_db
@ -30,7 +35,7 @@ class DatabaseCreation(BaseDatabaseCreation):
if keepdb: if keepdb:
return test_database_name return test_database_name
if not self.connection.is_in_memory_db(test_database_name): if not self.is_in_memory_db(test_database_name):
# Erase the old test database # Erase the old test database
if verbosity >= 1: if verbosity >= 1:
print("Destroying old test database for alias %s..." % ( print("Destroying old test database for alias %s..." % (
@ -56,7 +61,7 @@ class DatabaseCreation(BaseDatabaseCreation):
def get_test_db_clone_settings(self, number): def get_test_db_clone_settings(self, number):
orig_settings_dict = self.connection.settings_dict orig_settings_dict = self.connection.settings_dict
source_database_name = orig_settings_dict['NAME'] source_database_name = orig_settings_dict['NAME']
if self.connection.is_in_memory_db(source_database_name): if self.is_in_memory_db(source_database_name):
return orig_settings_dict return orig_settings_dict
else: else:
new_settings_dict = orig_settings_dict.copy() new_settings_dict = orig_settings_dict.copy()
@ -68,7 +73,7 @@ class DatabaseCreation(BaseDatabaseCreation):
source_database_name = self.connection.settings_dict['NAME'] source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(number)['NAME'] target_database_name = self.get_test_db_clone_settings(number)['NAME']
# Forking automatically makes a copy of an in-memory database. # Forking automatically makes a copy of an in-memory database.
if not self.connection.is_in_memory_db(source_database_name): if not self.is_in_memory_db(source_database_name):
# Erase the old test database # Erase the old test database
if os.access(target_database_name, os.F_OK): if os.access(target_database_name, os.F_OK):
if keepdb: if keepdb:
@ -89,7 +94,7 @@ class DatabaseCreation(BaseDatabaseCreation):
sys.exit(2) sys.exit(2)
def _destroy_test_db(self, test_database_name, verbosity): def _destroy_test_db(self, test_database_name, verbosity):
if test_database_name and not self.connection.is_in_memory_db(test_database_name): if test_database_name and not self.is_in_memory_db(test_database_name):
# Remove the SQLite database file # Remove the SQLite database file
os.remove(test_database_name) os.remove(test_database_name)
@ -103,6 +108,6 @@ class DatabaseCreation(BaseDatabaseCreation):
""" """
test_database_name = self._get_test_db_name() test_database_name = self._get_test_db_name()
sig = [self.connection.settings_dict['NAME']] sig = [self.connection.settings_dict['NAME']]
if self.connection.is_in_memory_db(test_database_name): if self.is_in_memory_db(test_database_name):
sig.append(self.connection.alias) sig.append(self.connection.alias)
return tuple(sig) return tuple(sig)

View File

@ -1299,7 +1299,7 @@ class LiveServerTestCase(TransactionTestCase):
for conn in connections.all(): for conn in connections.all():
# If using in-memory sqlite databases, pass the connections to # If using in-memory sqlite databases, pass the connections to
# the server thread. # the server thread.
if conn.vendor == 'sqlite' and conn.is_in_memory_db(conn.settings_dict['NAME']): if conn.vendor == 'sqlite' and conn.is_in_memory_db():
# Explicitly enable thread-shareability for this connection # Explicitly enable thread-shareability for this connection
conn.allow_thread_sharing = True conn.allow_thread_sharing = True
connections_override[conn.alias] = conn connections_override[conn.alias] = conn
@ -1339,7 +1339,7 @@ class LiveServerTestCase(TransactionTestCase):
# Restore sqlite in-memory database connections' non-shareability # Restore sqlite in-memory database connections' non-shareability
for conn in connections.all(): for conn in connections.all():
if conn.vendor == 'sqlite' and conn.is_in_memory_db(conn.settings_dict['NAME']): if conn.vendor == 'sqlite' and conn.is_in_memory_db():
conn.allow_thread_sharing = False conn.allow_thread_sharing = False
@classmethod @classmethod