diff --git a/django/contrib/gis/db/backends/mysql/base.py b/django/contrib/gis/db/backends/mysql/base.py index d252b38804..9cf61e40d2 100644 --- a/django/contrib/gis/db/backends/mysql/base.py +++ b/django/contrib/gis/db/backends/mysql/base.py @@ -9,9 +9,7 @@ from .schema import MySQLGISSchemaEditor class DatabaseWrapper(MySQLDatabaseWrapper): SchemaEditorClass = MySQLGISSchemaEditor - - def __init__(self, *args, **kwargs): - super(DatabaseWrapper, self).__init__(*args, **kwargs) - self.features = DatabaseFeatures(self) - self.ops = MySQLOperations(self) - self.introspection = MySQLIntrospection(self) + # Classes instantiated in __init__(). + features_class = DatabaseFeatures + introspection_class = MySQLIntrospection + ops_class = MySQLOperations diff --git a/django/contrib/gis/db/backends/oracle/base.py b/django/contrib/gis/db/backends/oracle/base.py index 167b61d01b..a4f6684f6d 100644 --- a/django/contrib/gis/db/backends/oracle/base.py +++ b/django/contrib/gis/db/backends/oracle/base.py @@ -9,9 +9,7 @@ from .schema import OracleGISSchemaEditor class DatabaseWrapper(OracleDatabaseWrapper): SchemaEditorClass = OracleGISSchemaEditor - - def __init__(self, *args, **kwargs): - super(DatabaseWrapper, self).__init__(*args, **kwargs) - self.features = DatabaseFeatures(self) - self.ops = OracleOperations(self) - self.introspection = OracleIntrospection(self) + # Classes instantiated in __init__(). + features_class = DatabaseFeatures + introspection_class = OracleIntrospection + ops_class = OracleOperations diff --git a/django/contrib/gis/db/backends/spatialite/base.py b/django/contrib/gis/db/backends/spatialite/base.py index 44514a87d4..1f03868945 100644 --- a/django/contrib/gis/db/backends/spatialite/base.py +++ b/django/contrib/gis/db/backends/spatialite/base.py @@ -17,6 +17,11 @@ from .schema import SpatialiteSchemaEditor class DatabaseWrapper(SQLiteDatabaseWrapper): SchemaEditorClass = SpatialiteSchemaEditor + # Classes instantiated in __init__(). + client_class = SpatiaLiteClient + features_class = DatabaseFeatures + introspection_class = SpatiaLiteIntrospection + ops_class = SpatiaLiteOperations def __init__(self, *args, **kwargs): # Before we get too far, make sure pysqlite 2.5+ is installed. @@ -37,10 +42,6 @@ class DatabaseWrapper(SQLiteDatabaseWrapper): 'SPATIALITE_LIBRARY_PATH in your settings.' ) super(DatabaseWrapper, self).__init__(*args, **kwargs) - self.features = DatabaseFeatures(self) - self.ops = SpatiaLiteOperations(self) - self.client = SpatiaLiteClient(self) - self.introspection = SpatiaLiteIntrospection(self) def get_new_connection(self, conn_params): conn = super(DatabaseWrapper, self).get_new_connection(conn_params) diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py index fcafa5f71a..595d191dce 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -8,6 +8,7 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db import DEFAULT_DB_ALIAS from django.db.backends import utils +from django.db.backends.base.validation import BaseDatabaseValidation from django.db.backends.signals import connection_created from django.db.transaction import TransactionManagementError from django.db.utils import DatabaseError, DatabaseErrorWrapper @@ -36,6 +37,13 @@ class BaseDatabaseWrapper(object): ops = None vendor = 'unknown' SchemaEditorClass = None + # Classes instantiated in __init__(). + client_class = None + creation_class = None + features_class = None + introspection_class = None + ops_class = None + validation_class = BaseDatabaseValidation queries_limit = 9000 @@ -88,6 +96,13 @@ class BaseDatabaseWrapper(object): # is called? self.run_commit_hooks_on_set_autocommit_on = False + self.client = self.client_class(self) + self.creation = self.creation_class(self) + self.features = self.features_class(self) + self.introspection = self.introspection_class(self) + self.ops = self.ops_class(self) + self.validation = self.validation_class(self) + def ensure_timezone(self): """ Ensure the connection's timezone is set to `self.timezone_name` and diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index f45d913bef..602bfb8fe6 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -13,7 +13,6 @@ from django.db.backends.base.client import BaseDatabaseClient from django.db.backends.base.creation import BaseDatabaseCreation from django.db.backends.base.introspection import BaseDatabaseIntrospection from django.db.backends.base.operations import BaseDatabaseOperations -from django.db.backends.base.validation import BaseDatabaseValidation from django.db.backends.dummy.features import DummyDatabaseFeatures @@ -71,16 +70,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): _savepoint_commit = complain _savepoint_rollback = ignore _set_autocommit = complain - - def __init__(self, *args, **kwargs): - super(DatabaseWrapper, self).__init__(*args, **kwargs) - - self.features = DummyDatabaseFeatures(self) - self.ops = DatabaseOperations(self) - self.client = DatabaseClient(self) - self.creation = DatabaseCreation(self) - self.introspection = DatabaseIntrospection(self) - self.validation = BaseDatabaseValidation(self) + # Classes instantiated in __init__(). + client_class = DatabaseClient + creation_class = DatabaseCreation + features_class = DummyDatabaseFeatures + introspection_class = DatabaseIntrospection + ops_class = DatabaseOperations def is_usable(self): return True diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 94415139cc..472b2b98df 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -227,16 +227,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): Database = Database SchemaEditorClass = DatabaseSchemaEditor - - def __init__(self, *args, **kwargs): - super(DatabaseWrapper, self).__init__(*args, **kwargs) - - self.features = DatabaseFeatures(self) - self.ops = DatabaseOperations(self) - self.client = DatabaseClient(self) - self.creation = DatabaseCreation(self) - self.introspection = DatabaseIntrospection(self) - self.validation = DatabaseValidation(self) + # Classes instantiated in __init__(). + client_class = DatabaseClient + creation_class = DatabaseCreation + features_class = DatabaseFeatures + introspection_class = DatabaseIntrospection + ops_class = DatabaseOperations + validation_class = DatabaseValidation def get_connection_params(self): kwargs = { diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 91af08acb6..c4dc075355 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -15,7 +15,6 @@ import warnings from django.conf import settings from django.db import utils from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.base.validation import BaseDatabaseValidation from django.utils import six, timezone from django.utils.deprecation import RemovedInDjango20Warning from django.utils.duration import duration_string @@ -179,18 +178,17 @@ class DatabaseWrapper(BaseDatabaseWrapper): Database = Database SchemaEditorClass = DatabaseSchemaEditor + # Classes instantiated in __init__(). + client_class = DatabaseClient + creation_class = DatabaseCreation + features_class = DatabaseFeatures + introspection_class = DatabaseIntrospection + ops_class = DatabaseOperations def __init__(self, *args, **kwargs): super(DatabaseWrapper, self).__init__(*args, **kwargs) - - self.features = DatabaseFeatures(self) use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True) self.features.can_return_id_from_insert = use_returning_into - self.ops = DatabaseOperations(self) - self.client = DatabaseClient(self) - self.creation = DatabaseCreation(self) - self.introspection = DatabaseIntrospection(self) - self.validation = BaseDatabaseValidation(self) def _connect_string(self): settings_dict = self.settings_dict diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 19afa48047..1ab05d81f6 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -10,7 +10,6 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db import DEFAULT_DB_ALIAS from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.base.validation import BaseDatabaseValidation from django.db.utils import DatabaseError as WrappedDatabaseError from django.utils import six from django.utils.encoding import force_str @@ -141,16 +140,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): Database = Database SchemaEditorClass = DatabaseSchemaEditor - - def __init__(self, *args, **kwargs): - super(DatabaseWrapper, self).__init__(*args, **kwargs) - - self.features = DatabaseFeatures(self) - self.ops = DatabaseOperations(self) - self.client = DatabaseClient(self) - self.creation = DatabaseCreation(self) - self.introspection = DatabaseIntrospection(self) - self.validation = BaseDatabaseValidation(self) + # Classes instantiated in __init__(). + client_class = DatabaseClient + creation_class = DatabaseCreation + features_class = DatabaseFeatures + introspection_class = DatabaseIntrospection + ops_class = DatabaseOperations def get_connection_params(self): settings_dict = self.settings_dict diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 0d6b4e3d00..e8b707861f 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -15,7 +15,6 @@ from django.conf import settings from django.db import utils from django.db.backends import utils as backend_utils from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.base.validation import BaseDatabaseValidation from django.utils import six, timezone from django.utils.dateparse import ( parse_date, parse_datetime, parse_duration, parse_time, @@ -163,16 +162,12 @@ class DatabaseWrapper(BaseDatabaseWrapper): Database = Database SchemaEditorClass = DatabaseSchemaEditor - - def __init__(self, *args, **kwargs): - super(DatabaseWrapper, self).__init__(*args, **kwargs) - - self.features = DatabaseFeatures(self) - self.ops = DatabaseOperations(self) - self.client = DatabaseClient(self) - self.creation = DatabaseCreation(self) - self.introspection = DatabaseIntrospection(self) - self.validation = BaseDatabaseValidation(self) + # Classes instantiated in __init__(). + client_class = DatabaseClient + creation_class = DatabaseCreation + features_class = DatabaseFeatures + introspection_class = DatabaseIntrospection + ops_class = DatabaseOperations def get_connection_params(self): settings_dict = self.settings_dict diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 4851995149..106fd7f856 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -32,6 +32,31 @@ from django.utils.six.moves import range from . import models +class DatabaseWrapperTests(SimpleTestCase): + + def test_initialization_class_attributes(self): + """ + The "initialization" class attributes like client_class and + creation_class should be set on the class and reflected in the + corresponding instance attributes of the instantiated backend. + """ + conn = connections[DEFAULT_DB_ALIAS] + conn_class = type(conn) + attr_names = [ + ('client_class', 'client'), + ('creation_class', 'creation'), + ('features_class', 'features'), + ('introspection_class', 'introspection'), + ('ops_class', 'ops'), + ('validation_class', 'validation'), + ] + for class_attr_name, instance_attr_name in attr_names: + class_attr_value = getattr(conn_class, class_attr_name) + self.assertIsNotNone(class_attr_value) + instance_attr_value = getattr(conn, instance_attr_name) + self.assertIsInstance(instance_attr_value, class_attr_value) + + class DummyBackendTest(SimpleTestCase): def test_no_databases(self): diff --git a/tests/test_runner/tests.py b/tests/test_runner/tests.py index c2ad07013c..15417db645 100644 --- a/tests/test_runner/tests.py +++ b/tests/test_runner/tests.py @@ -281,7 +281,7 @@ class SetupDatabasesTests(unittest.TestCase): } }) - with mock.patch('django.db.backends.dummy.base.DatabaseCreation') as mocked_db_creation: + with mock.patch('django.db.backends.dummy.base.DatabaseWrapper.creation_class') as mocked_db_creation: with mock.patch('django.test.utils.connections', new=tested_connections): old_config = self.runner_instance.setup_databases() self.runner_instance.teardown_databases(old_config) @@ -306,7 +306,7 @@ class SetupDatabasesTests(unittest.TestCase): 'ENGINE': 'django.db.backends.dummy', }, }) - with mock.patch('django.db.backends.dummy.base.DatabaseCreation') as mocked_db_creation: + with mock.patch('django.db.backends.dummy.base.DatabaseWrapper.creation_class') as mocked_db_creation: with mock.patch('django.test.utils.connections', new=tested_connections): self.runner_instance.setup_databases() mocked_db_creation.return_value.create_test_db.assert_called_once_with( @@ -320,7 +320,7 @@ class SetupDatabasesTests(unittest.TestCase): 'TEST': {'SERIALIZE': False}, }, }) - with mock.patch('django.db.backends.dummy.base.DatabaseCreation') as mocked_db_creation: + with mock.patch('django.db.backends.dummy.base.DatabaseWrapper.creation_class') as mocked_db_creation: with mock.patch('django.test.utils.connections', new=tested_connections): self.runner_instance.setup_databases() mocked_db_creation.return_value.create_test_db.assert_called_once_with(