Factored out common code in database backends.

This commit is contained in:
Aymeric Augustin 2013-02-18 17:12:42 +01:00
parent 64d0f89ab1
commit 29628e0b6e
5 changed files with 19 additions and 40 deletions

View File

@ -11,6 +11,7 @@ from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS
from django.db.backends.signals import connection_created
from django.db.backends import util from django.db.backends import util
from django.db.transaction import TransactionManagementError from django.db.transaction import TransactionManagementError
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -52,6 +53,17 @@ class BaseDatabaseWrapper(object):
__hash__ = object.__hash__ __hash__ = object.__hash__
def _valid_connection(self):
return self.connection is not None
def _cursor(self):
if not self._valid_connection():
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
return self.create_cursor()
def _commit(self): def _commit(self):
if self.connection is not None: if self.connection is not None:
return self.connection.commit() return self.connection.commit()

View File

@ -33,19 +33,16 @@ from MySQLdb.constants import FIELD_TYPE, CLIENT
from django.conf import settings from django.conf import settings
from django.db import utils from django.db import utils
from django.db.backends import * from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.mysql.client import DatabaseClient from django.db.backends.mysql.client import DatabaseClient
from django.db.backends.mysql.creation import DatabaseCreation from django.db.backends.mysql.creation import DatabaseCreation
from django.db.backends.mysql.introspection import DatabaseIntrospection from django.db.backends.mysql.introspection import DatabaseIntrospection
from django.db.backends.mysql.validation import DatabaseValidation from django.db.backends.mysql.validation import DatabaseValidation
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.functional import cached_property
from django.utils.safestring import SafeBytes, SafeText from django.utils.safestring import SafeBytes, SafeText
from django.utils import six from django.utils import six
from django.utils import timezone from django.utils import timezone
# Raise exceptions for database warnings if DEBUG is on # Raise exceptions for database warnings if DEBUG is on
from django.conf import settings
if settings.DEBUG: if settings.DEBUG:
warnings.filterwarnings("error", category=Database.Warning) warnings.filterwarnings("error", category=Database.Warning)
@ -454,12 +451,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor.execute('SET SQL_AUTO_IS_NULL = 0') cursor.execute('SET SQL_AUTO_IS_NULL = 0')
cursor.close() cursor.close()
def _cursor(self): def create_cursor(self):
if not self._valid_connection():
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
cursor = self.connection.cursor() cursor = self.connection.cursor()
return CursorWrapper(cursor) return CursorWrapper(cursor)

View File

@ -48,7 +48,6 @@ except ImportError as e:
from django.conf import settings from django.conf import settings
from django.db import utils from django.db import utils
from django.db.backends import * from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.oracle.client import DatabaseClient from django.db.backends.oracle.client import DatabaseClient
from django.db.backends.oracle.creation import DatabaseCreation from django.db.backends.oracle.creation import DatabaseCreation
from django.db.backends.oracle.introspection import DatabaseIntrospection from django.db.backends.oracle.introspection import DatabaseIntrospection
@ -521,9 +520,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def _valid_connection(self):
return self.connection is not None
def _connect_string(self): def _connect_string(self):
settings_dict = self.settings_dict settings_dict = self.settings_dict
if not settings_dict['HOST'].strip(): if not settings_dict['HOST'].strip():
@ -537,8 +533,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return "%s/%s@%s" % (settings_dict['USER'], return "%s/%s@%s" % (settings_dict['USER'],
settings_dict['PASSWORD'], dsn) settings_dict['PASSWORD'], dsn)
def create_cursor(self, conn): def create_cursor(self):
return FormatStylePlaceholderCursor(conn) return FormatStylePlaceholderCursor(self.connection)
def get_connection_params(self): def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy() conn_params = self.settings_dict['OPTIONS'].copy()
@ -551,7 +547,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return Database.connect(conn_string, **conn_params) return Database.connect(conn_string, **conn_params)
def init_connection_state(self): def init_connection_state(self):
cursor = self.create_cursor(self.connection) cursor = self.create_cursor()
# Set the territory first. The territory overrides NLS_DATE_FORMAT # Set the territory first. The territory overrides NLS_DATE_FORMAT
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of # and NLS_TIMESTAMP_FORMAT to the territory default. When all of
# these are set in single statement it isn't clear what is supposed # these are set in single statement it isn't clear what is supposed
@ -572,7 +568,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# This check is performed only once per DatabaseWrapper # This check is performed only once per DatabaseWrapper
# instance per thread, since subsequent connections will use # instance per thread, since subsequent connections will use
# the same settings. # the same settings.
cursor = self.create_cursor(self.connection) cursor = self.create_cursor()
try: try:
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s" cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
% self._standard_operators['contains'], % self._standard_operators['contains'],
@ -602,14 +598,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# stmtcachesize is available only in 4.3.2 and up. # stmtcachesize is available only in 4.3.2 and up.
pass pass
def _cursor(self):
if not self._valid_connection():
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
return self.create_cursor(self.connection)
# Oracle doesn't support savepoint commits. Ignore them. # Oracle doesn't support savepoint commits. Ignore them.
def _savepoint_commit(self, sid): def _savepoint_commit(self, sid):
pass pass

View File

@ -8,7 +8,6 @@ import sys
from django.db import utils from django.db import utils
from django.db.backends import * from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations
from django.db.backends.postgresql_psycopg2.client import DatabaseClient from django.db.backends.postgresql_psycopg2.client import DatabaseClient
from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation
@ -205,12 +204,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.connection.set_isolation_level(self.isolation_level) self.connection.set_isolation_level(self.isolation_level)
self._get_pg_version() self._get_pg_version()
def _cursor(self): def create_cursor(self):
if self.connection is None:
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
cursor = self.connection.cursor() cursor = self.connection.cursor()
cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
return CursorWrapper(cursor) return CursorWrapper(cursor)

View File

@ -14,7 +14,6 @@ import sys
from django.db import utils from django.db import utils
from django.db.backends import * from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.sqlite3.client import DatabaseClient from django.db.backends.sqlite3.client import DatabaseClient
from django.db.backends.sqlite3.creation import DatabaseCreation from django.db.backends.sqlite3.creation import DatabaseCreation
from django.db.backends.sqlite3.introspection import DatabaseIntrospection from django.db.backends.sqlite3.introspection import DatabaseIntrospection
@ -344,13 +343,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def init_connection_state(self): def init_connection_state(self):
pass pass
def _cursor(self): def create_cursor(self):
if self.connection is None:
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
return self.connection.cursor(factory=SQLiteCursorWrapper) return self.connection.cursor(factory=SQLiteCursorWrapper)
def check_constraints(self, table_names=None): def check_constraints(self, table_names=None):