Factored out common code in database backends.

This commit is contained in:
Aymeric Augustin 2013-02-18 17:12:42 +01:00
parent 64d0f89ab1
commit 29628e0b6e
5 changed files with 19 additions and 40 deletions

View File

@ -11,6 +11,7 @@ from contextlib import contextmanager
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
from django.db.backends.signals import connection_created
from django.db.backends import util
from django.db.transaction import TransactionManagementError
from django.utils.functional import cached_property
@ -52,6 +53,17 @@ class BaseDatabaseWrapper(object):
__hash__ = object.__hash__
def _valid_connection(self):
return self.connection is not None
def _cursor(self):
if not self._valid_connection():
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
return self.create_cursor()
def _commit(self):
if self.connection is not None:
return self.connection.commit()

View File

@ -33,19 +33,16 @@ from MySQLdb.constants import FIELD_TYPE, CLIENT
from django.conf import settings
from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.mysql.client import DatabaseClient
from django.db.backends.mysql.creation import DatabaseCreation
from django.db.backends.mysql.introspection import DatabaseIntrospection
from django.db.backends.mysql.validation import DatabaseValidation
from django.utils.encoding import force_str
from django.utils.functional import cached_property
from django.utils.safestring import SafeBytes, SafeText
from django.utils import six
from django.utils import timezone
# Raise exceptions for database warnings if DEBUG is on
from django.conf import settings
if settings.DEBUG:
warnings.filterwarnings("error", category=Database.Warning)
@ -454,12 +451,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
cursor.execute('SET SQL_AUTO_IS_NULL = 0')
cursor.close()
def _cursor(self):
if not self._valid_connection():
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
def create_cursor(self):
cursor = self.connection.cursor()
return CursorWrapper(cursor)

View File

@ -48,7 +48,6 @@ except ImportError as e:
from django.conf import settings
from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.oracle.client import DatabaseClient
from django.db.backends.oracle.creation import DatabaseCreation
from django.db.backends.oracle.introspection import DatabaseIntrospection
@ -521,9 +520,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def _valid_connection(self):
return self.connection is not None
def _connect_string(self):
settings_dict = self.settings_dict
if not settings_dict['HOST'].strip():
@ -537,8 +533,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return "%s/%s@%s" % (settings_dict['USER'],
settings_dict['PASSWORD'], dsn)
def create_cursor(self, conn):
return FormatStylePlaceholderCursor(conn)
def create_cursor(self):
return FormatStylePlaceholderCursor(self.connection)
def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy()
@ -551,7 +547,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return Database.connect(conn_string, **conn_params)
def init_connection_state(self):
cursor = self.create_cursor(self.connection)
cursor = self.create_cursor()
# Set the territory first. The territory overrides NLS_DATE_FORMAT
# and NLS_TIMESTAMP_FORMAT to the territory default. When all of
# these are set in single statement it isn't clear what is supposed
@ -572,7 +568,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# This check is performed only once per DatabaseWrapper
# instance per thread, since subsequent connections will use
# the same settings.
cursor = self.create_cursor(self.connection)
cursor = self.create_cursor()
try:
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
% self._standard_operators['contains'],
@ -602,14 +598,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# stmtcachesize is available only in 4.3.2 and up.
pass
def _cursor(self):
if not self._valid_connection():
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
return self.create_cursor(self.connection)
# Oracle doesn't support savepoint commits. Ignore them.
def _savepoint_commit(self, sid):
pass

View File

@ -8,7 +8,6 @@ import sys
from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations
from django.db.backends.postgresql_psycopg2.client import DatabaseClient
from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation
@ -205,12 +204,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.connection.set_isolation_level(self.isolation_level)
self._get_pg_version()
def _cursor(self):
if self.connection is None:
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
def create_cursor(self):
cursor = self.connection.cursor()
cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None
return CursorWrapper(cursor)

View File

@ -14,7 +14,6 @@ import sys
from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
from django.db.backends.sqlite3.client import DatabaseClient
from django.db.backends.sqlite3.creation import DatabaseCreation
from django.db.backends.sqlite3.introspection import DatabaseIntrospection
@ -344,13 +343,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def init_connection_state(self):
pass
def _cursor(self):
if self.connection is None:
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
def create_cursor(self):
return self.connection.cursor(factory=SQLiteCursorWrapper)
def check_constraints(self, table_names=None):