From 29628e0b6e5b1c6324e0c06cc56a49a5aa0747e0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 18 Feb 2013 17:12:42 +0100 Subject: [PATCH] Factored out common code in database backends. --- django/db/backends/__init__.py | 12 +++++++++++ django/db/backends/mysql/base.py | 10 +--------- django/db/backends/oracle/base.py | 20 ++++--------------- .../db/backends/postgresql_psycopg2/base.py | 8 +------- django/db/backends/sqlite3/base.py | 9 +-------- 5 files changed, 19 insertions(+), 40 deletions(-) diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 03b62f6413..7a0a577ef1 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -11,6 +11,7 @@ from contextlib import contextmanager from django.conf import settings from django.db import DEFAULT_DB_ALIAS +from django.db.backends.signals import connection_created from django.db.backends import util from django.db.transaction import TransactionManagementError from django.utils.functional import cached_property @@ -52,6 +53,17 @@ class BaseDatabaseWrapper(object): __hash__ = object.__hash__ + def _valid_connection(self): + return self.connection is not None + + def _cursor(self): + if not self._valid_connection(): + conn_params = self.get_connection_params() + self.connection = self.get_new_connection(conn_params) + self.init_connection_state() + connection_created.send(sender=self.__class__, connection=self) + return self.create_cursor() + def _commit(self): if self.connection is not None: return self.connection.commit() diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 9de2a4d62d..eb823083f4 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -33,19 +33,16 @@ from MySQLdb.constants import FIELD_TYPE, CLIENT from django.conf import settings from django.db import utils from django.db.backends import * -from django.db.backends.signals import connection_created from django.db.backends.mysql.client import DatabaseClient from django.db.backends.mysql.creation import DatabaseCreation from django.db.backends.mysql.introspection import DatabaseIntrospection from django.db.backends.mysql.validation import DatabaseValidation from django.utils.encoding import force_str -from django.utils.functional import cached_property from django.utils.safestring import SafeBytes, SafeText from django.utils import six from django.utils import timezone # Raise exceptions for database warnings if DEBUG is on -from django.conf import settings if settings.DEBUG: warnings.filterwarnings("error", category=Database.Warning) @@ -454,12 +451,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): cursor.execute('SET SQL_AUTO_IS_NULL = 0') cursor.close() - def _cursor(self): - if not self._valid_connection(): - conn_params = self.get_connection_params() - self.connection = self.get_new_connection(conn_params) - self.init_connection_state() - connection_created.send(sender=self.__class__, connection=self) + def create_cursor(self): cursor = self.connection.cursor() return CursorWrapper(cursor) diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 7bcfb46798..e329ef3191 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -48,7 +48,6 @@ except ImportError as e: from django.conf import settings from django.db import utils from django.db.backends import * -from django.db.backends.signals import connection_created from django.db.backends.oracle.client import DatabaseClient from django.db.backends.oracle.creation import DatabaseCreation from django.db.backends.oracle.introspection import DatabaseIntrospection @@ -521,9 +520,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE') self.cursor().execute('SET CONSTRAINTS ALL DEFERRED') - def _valid_connection(self): - return self.connection is not None - def _connect_string(self): settings_dict = self.settings_dict if not settings_dict['HOST'].strip(): @@ -537,8 +533,8 @@ class DatabaseWrapper(BaseDatabaseWrapper): return "%s/%s@%s" % (settings_dict['USER'], settings_dict['PASSWORD'], dsn) - def create_cursor(self, conn): - return FormatStylePlaceholderCursor(conn) + def create_cursor(self): + return FormatStylePlaceholderCursor(self.connection) def get_connection_params(self): conn_params = self.settings_dict['OPTIONS'].copy() @@ -551,7 +547,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): return Database.connect(conn_string, **conn_params) def init_connection_state(self): - cursor = self.create_cursor(self.connection) + cursor = self.create_cursor() # Set the territory first. The territory overrides NLS_DATE_FORMAT # and NLS_TIMESTAMP_FORMAT to the territory default. When all of # these are set in single statement it isn't clear what is supposed @@ -572,7 +568,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): # This check is performed only once per DatabaseWrapper # instance per thread, since subsequent connections will use # the same settings. - cursor = self.create_cursor(self.connection) + cursor = self.create_cursor() try: cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s" % self._standard_operators['contains'], @@ -602,14 +598,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): # stmtcachesize is available only in 4.3.2 and up. pass - def _cursor(self): - if not self._valid_connection(): - conn_params = self.get_connection_params() - self.connection = self.get_new_connection(conn_params) - self.init_connection_state() - connection_created.send(sender=self.__class__, connection=self) - return self.create_cursor(self.connection) - # Oracle doesn't support savepoint commits. Ignore them. def _savepoint_commit(self, sid): pass diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index b8d7fe3195..fb1ad5f991 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -8,7 +8,6 @@ import sys from django.db import utils from django.db.backends import * -from django.db.backends.signals import connection_created from django.db.backends.postgresql_psycopg2.operations import DatabaseOperations from django.db.backends.postgresql_psycopg2.client import DatabaseClient from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation @@ -205,12 +204,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): self.connection.set_isolation_level(self.isolation_level) self._get_pg_version() - def _cursor(self): - if self.connection is None: - conn_params = self.get_connection_params() - self.connection = self.get_new_connection(conn_params) - self.init_connection_state() - connection_created.send(sender=self.__class__, connection=self) + def create_cursor(self): cursor = self.connection.cursor() cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None return CursorWrapper(cursor) diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index dd87972d5b..7ddaaf8fe3 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -14,7 +14,6 @@ import sys from django.db import utils from django.db.backends import * -from django.db.backends.signals import connection_created from django.db.backends.sqlite3.client import DatabaseClient from django.db.backends.sqlite3.creation import DatabaseCreation from django.db.backends.sqlite3.introspection import DatabaseIntrospection @@ -344,13 +343,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): def init_connection_state(self): pass - def _cursor(self): - if self.connection is None: - conn_params = self.get_connection_params() - self.connection = self.get_new_connection(conn_params) - self.init_connection_state() - connection_created.send(sender=self.__class__, connection=self) - + def create_cursor(self): return self.connection.cursor(factory=SQLiteCursorWrapper) def check_constraints(self, table_names=None):