from unittest.mock import MagicMock

from django.db import DEFAULT_DB_ALIAS, connection, connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.test import SimpleTestCase, TestCase

from ..models import Square


class DatabaseWrapperTests(SimpleTestCase):
    """Checks on the class-level configuration of the database wrapper."""

    def test_initialization_class_attributes(self):
        """
        The "initialization" class attributes like client_class and
        creation_class should be set on the class and reflected in the
        corresponding instance attributes of the instantiated backend.
        """
        conn = connections[DEFAULT_DB_ALIAS]
        conn_class = type(conn)
        # Pairs of (attribute on the wrapper class, attribute on the instance).
        attr_names = [
            ('client_class', 'client'),
            ('creation_class', 'creation'),
            ('features_class', 'features'),
            ('introspection_class', 'introspection'),
            ('ops_class', 'ops'),
            ('validation_class', 'validation'),
        ]
        for class_attr_name, instance_attr_name in attr_names:
            class_attr_value = getattr(conn_class, class_attr_name)
            self.assertIsNotNone(class_attr_value)
            instance_attr_value = getattr(conn, instance_attr_name)
            self.assertIsInstance(instance_attr_value, class_attr_value)

    def test_initialization_display_name(self):
        """
        The base wrapper reports 'unknown' as its display name, while the
        configured default connection reports something else.
        """
        self.assertEqual(BaseDatabaseWrapper.display_name, 'unknown')
        self.assertNotEqual(connection.display_name, 'unknown')


class ExecuteWrapperTests(TestCase):
    """Tests for wrappers installed with connection.execute_wrapper()."""

    @staticmethod
    def call_execute(connection, params=None):
        """
        Run a trivial SELECT through a cursor of the given connection.

        Without params the literal 1 is selected; otherwise a single '%s'
        placeholder is selected and params is passed to execute().
        """
        ret_val = '1' if params is None else '%s'
        sql = 'SELECT ' + ret_val + connection.features.bare_select_suffix
        with connection.cursor() as cursor:
            cursor.execute(sql, params)

    def call_executemany(self, connection, params=None):
        """
        Run a no-op DELETE through executemany() on the given connection.

        params defaults to three one-element tuples: (0,), (1,) and (2,).
        """
        # executemany() must use an update query. Make sure it does nothing
        # by putting a false condition in the WHERE clause.
        sql = 'DELETE FROM {} WHERE 0=1 AND 0=%s'.format(Square._meta.db_table)
        if params is None:
            params = [(i,) for i in range(3)]
        with connection.cursor() as cursor:
            cursor.executemany(sql, params)

    @staticmethod
    def mock_wrapper():
        """
        Return a MagicMock usable as an execute wrapper: it records each call
        and forwards the remaining arguments to the `execute` callable it
        receives as its first argument.
        """
        return MagicMock(side_effect=lambda execute, *args: execute(*args))

    def test_wrapper_invoked(self):
        """A wrapper is called for execute() with many=False."""
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            self.call_execute(connection)
        self.assertTrue(wrapper.called)
        # Positional call args are (execute, sql, params, many, context).
        (_, sql, params, many, context), _ = wrapper.call_args
        self.assertIn('SELECT', sql)
        self.assertIsNone(params)
        self.assertIs(many, False)
        self.assertEqual(context['connection'], connection)

    def test_wrapper_invoked_many(self):
        """A wrapper is called for executemany() with many=True."""
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            self.call_executemany(connection)
        self.assertTrue(wrapper.called)
        # Positional call args are (execute, sql, param_list, many, context).
        (_, sql, param_list, many, context), _ = wrapper.call_args
        self.assertIn('DELETE', sql)
        self.assertIsInstance(param_list, (list, tuple))
        self.assertIs(many, True)
        self.assertEqual(context['connection'], connection)

    def test_database_queried(self):
        """With a forwarding wrapper installed, queries still return rows."""
        wrapper = self.mock_wrapper()
        with connection.execute_wrapper(wrapper):
            with connection.cursor() as cursor:
                sql = 'SELECT 17' + connection.features.bare_select_suffix
                cursor.execute(sql)
                seventeen = cursor.fetchall()
                self.assertEqual(list(seventeen), [(17,)])
            self.call_executemany(connection)

    def test_nested_wrapper_invoked(self):
        """Both of two nested forwarding wrappers are called for each query."""
        outer_wrapper = self.mock_wrapper()
        inner_wrapper = self.mock_wrapper()
        with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(inner_wrapper):
            self.call_execute(connection)
            self.assertEqual(inner_wrapper.call_count, 1)
            # The outer wrapper must have been invoked too, not only the
            # innermost one.
            self.assertEqual(outer_wrapper.call_count, 1)
            self.call_executemany(connection)
            self.assertEqual(inner_wrapper.call_count, 2)
            self.assertEqual(outer_wrapper.call_count, 2)

    def test_outer_wrapper_blocks(self):
        """
        A wrapper that never calls `execute` stops the chain: the same mock
        is installed both outside and inside the blocker, yet it is counted
        only once per statement.
        """
        def blocker(*args):
            # Deliberately never calls the `execute` callable it is given.
            pass

        wrapper = self.mock_wrapper()
        c = connection  # This alias shortens the next line.
        with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(wrapper):
            with c.cursor() as cursor:
                cursor.execute("The database never sees this")
                self.assertEqual(wrapper.call_count, 1)
                cursor.executemany("The database never sees this %s", [("either",)])
                self.assertEqual(wrapper.call_count, 2)

    def test_wrapper_gets_sql(self):
        """The wrapper receives exactly the SQL string passed to execute()."""
        wrapper = self.mock_wrapper()
        sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
        with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
            cursor.execute(sql)
        (_, reported_sql, _, _, _), _ = wrapper.call_args
        self.assertEqual(reported_sql, sql)

    def test_wrapper_connection_specific(self):
        """
        A wrapper installed on the 'other' connection isn't called for
        queries on the default connection, and it is removed again when the
        context manager exits.
        """
        wrapper = self.mock_wrapper()
        with connections['other'].execute_wrapper(wrapper):
            self.assertEqual(connections['other'].execute_wrappers, [wrapper])
            self.call_execute(connection)
        self.assertFalse(wrapper.called)
        self.assertEqual(connection.execute_wrappers, [])
        self.assertEqual(connections['other'].execute_wrappers, [])