diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py index 660542c095..468eb16e14 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -92,6 +92,12 @@ class BaseDatabaseWrapper: # is called? self.run_commit_hooks_on_set_autocommit_on = False + # A stack of wrappers to be invoked around execute()/executemany() + # calls. Each entry is a function taking five arguments: execute, sql, + # params, many, and context. It's the function's responsibility to + # call execute(sql, params, many, context). + self.execute_wrappers = [] + self.client = self.client_class(self) self.creation = self.creation_class(self) self.features = self.features_class(self) @@ -629,6 +635,18 @@ class BaseDatabaseWrapper: sids, func = current_run_on_commit.pop(0) func() + @contextmanager + def execute_wrapper(self, wrapper): + """ + Return a context manager under which the wrapper is applied to suitable + database query executions. + """ + self.execute_wrappers.append(wrapper) + try: + yield + finally: + self.execute_wrappers.pop() + def copy(self, alias=None, allow_thread_sharing=None): """ Return a copy of this connection. diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index 9634807a87..816164d36a 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -1,5 +1,6 @@ import datetime import decimal +import functools import hashlib import logging import re @@ -65,6 +66,18 @@ class CursorWrapper: return self.cursor.callproc(procname, params, kparams) def execute(self, sql, params=None): + return self._execute_with_wrappers(sql, params, many=False, executor=self._execute) + + def executemany(self, sql, param_list): + return self._execute_with_wrappers(sql, param_list, many=True, executor=self._executemany) + + def _execute_with_wrappers(self, sql, params, many, executor): + context = {'connection': self.db, 'cursor': self} + for wrapper in reversed(self.db.execute_wrappers): + executor = functools.partial(wrapper, executor) + return executor(sql, params, many, context) + + def _execute(self, sql, params, *ignored_wrapper_args): self.db.validate_no_broken_transaction() with self.db.wrap_database_errors: if params is None: @@ -72,7 +85,7 @@ class CursorWrapper: else: return self.cursor.execute(sql, params) - def executemany(self, sql, param_list): + def _executemany(self, sql, param_list, *ignored_wrapper_args): self.db.validate_no_broken_transaction() with self.db.wrap_database_errors: return self.cursor.executemany(sql, param_list) diff --git a/docs/releases/2.0.txt b/docs/releases/2.0.txt index 71d41b664c..77eee0c206 100644 --- a/docs/releases/2.0.txt +++ b/docs/releases/2.0.txt @@ -339,6 +339,11 @@ Models parameters, if the backend supports this feature. Of Django's built-in backends, only Oracle supports it. +* The new :meth:`connection.execute_wrapper() + ` method allows + :doc:`installing wrappers around execution of database queries + `. + * The new ``filter`` argument for built-in aggregates allows :ref:`adding different conditionals ` to multiple aggregations over the same fields or relations. diff --git a/docs/topics/db/index.txt b/docs/topics/db/index.txt index 51f60a65d7..032ec30bb5 100644 --- a/docs/topics/db/index.txt +++ b/docs/topics/db/index.txt @@ -21,4 +21,5 @@ model maps to a single database table. multi-db tablespaces optimization + instrumentation examples/index diff --git a/docs/topics/db/instrumentation.txt b/docs/topics/db/instrumentation.txt new file mode 100644 index 0000000000..9578c1224b --- /dev/null +++ b/docs/topics/db/instrumentation.txt @@ -0,0 +1,114 @@ +======================== +Database instrumentation +======================== + +.. versionadded:: 2.0 + +To help you understand and control the queries issued by your code, Django +provides a hook for installing wrapper functions around the execution of +database queries. For example, wrappers can count queries, measure query +duration, log queries, or even prevent query execution (e.g. to make sure that +no queries are issued while rendering a template with prefetched data). + +The wrappers are modeled after :doc:`middleware ` -- +they are callables which take another callable as one of their arguments. They +call that callable to invoke the (possibly wrapped) database query, and they +can do what they want around that call. They are, however, created and +installed by user code, and so don't need a separate factory like middleware do. + +Installing a wrapper is done in a context manager -- so the wrappers are +temporary and specific to some flow in your code. + +As mentioned above, an example of a wrapper is a query execution blocker. It +could look like this:: + + def blocker(*args): + raise Exception('No database access allowed here.') + +And it would be used in a view to block queries from the template like so:: + + from django.db import connection + from django.shortcuts import render + + def my_view(request): + context = {...} # Code to generate context with all data. + template_name = ... + with connection.execute_wrapper(blocker): + return render(request, template_name, context) + +The parameters sent to the wrappers are: + +* ``execute`` -- a callable, which should be invoked with the rest of the + parameters in order to execute the query. + +* ``sql`` -- a ``str``, the SQL query to be sent to the database. + +* ``params`` -- a list/tuple of parameter values for the SQL command, or a + list/tuple of lists/tuples if the wrapped call is ``executemany()``. + +* ``many`` -- a ``bool`` indicating whether the ultimately invoked call is + ``execute()`` or ``executemany()`` (and whether ``params`` is expected to be + a sequence of values, or a sequence of sequences of values). + +* ``context`` -- a dictionary with further data about the context of + invocation. This includes the connection and cursor. + +Using the parameters, a slightly more complex version of the blocker could +include the connection name in the error message:: + + def blocker(execute, sql, params, many, context): + alias = context['connection'].alias + raise Exception("Access to database '{}' blocked here".format(alias)) + +For a more complete example, a query logger could look like this:: + + import time + + class QueryLogger: + + def __init__(self): + self.queries = [] + + def __call__(self, execute, sql, params, many, context): + current_query = {'sql': sql, 'params': params, 'many': many} + start = time.time() + try: + result = execute(sql, params, many, context) + except Exception as e: + current_query['status'] = 'error' + current_query['exception'] = e + raise + else: + current_query['status'] = 'ok' + return result + finally: + duration = time.time() - start + current_query['duration'] = duration + self.queries.append(current_query) + +To use this, you would create a logger object and install it as a wrapper:: + + from django.db import connection + + ql = QueryLogger() + with connection.execute_wrapper(ql): + do_queries() + # Now we can print the log. + print(ql.queries) + +.. currentmodule:: django.db.backends.base.DatabaseWrapper + +``connection.execute_wrapper()`` +-------------------------------- + +.. method:: execute_wrapper(wrapper) + +Returns a context manager which, when entered, installs a wrapper around +database query executions, and when exited, removes the wrapper. The wrapper is +installed on the thread-local connection object. + +``wrapper`` is a callable taking five arguments. It is called for every query +execution in the scope of the context manager, with arguments ``execute``, +``sql``, ``params``, ``many``, and ``context`` as described above. It's +expected to call ``execute(sql, params, many, context)`` and return the return +value of that call. diff --git a/tests/backends/base/test_base.py b/tests/backends/base/test_base.py index 15cfbb8579..f89aec57f0 100644 --- a/tests/backends/base/test_base.py +++ b/tests/backends/base/test_base.py @@ -1,6 +1,10 @@ +from unittest.mock import MagicMock + from django.db import DEFAULT_DB_ALIAS, connection, connections from django.db.backends.base.base import BaseDatabaseWrapper -from django.test import SimpleTestCase +from django.test import SimpleTestCase, TestCase + +from ..models import Square class DatabaseWrapperTests(SimpleTestCase): @@ -30,3 +34,96 @@ class DatabaseWrapperTests(SimpleTestCase): def test_initialization_display_name(self): self.assertEqual(BaseDatabaseWrapper.display_name, 'unknown') self.assertNotEqual(connection.display_name, 'unknown') + + +class ExecuteWrapperTests(TestCase): + + @staticmethod + def call_execute(connection, params=None): + ret_val = '1' if params is None else '%s' + sql = 'SELECT ' + ret_val + connection.features.bare_select_suffix + with connection.cursor() as cursor: + cursor.execute(sql, params) + + def call_executemany(self, connection, params=None): + # executemany() must use an update query. Make sure it does nothing + # by putting a false condition in the WHERE clause. + sql = 'DELETE FROM {} WHERE 0=1 AND 0=%s'.format(Square._meta.db_table) + if params is None: + params = [(i,) for i in range(3)] + with connection.cursor() as cursor: + cursor.executemany(sql, params) + + @staticmethod + def mock_wrapper(): + return MagicMock(side_effect=lambda execute, *args: execute(*args)) + + def test_wrapper_invoked(self): + wrapper = self.mock_wrapper() + with connection.execute_wrapper(wrapper): + self.call_execute(connection) + self.assertTrue(wrapper.called) + (_, sql, params, many, context), _ = wrapper.call_args + self.assertIn('SELECT', sql) + self.assertIsNone(params) + self.assertIs(many, False) + self.assertEqual(context['connection'], connection) + + def test_wrapper_invoked_many(self): + wrapper = self.mock_wrapper() + with connection.execute_wrapper(wrapper): + self.call_executemany(connection) + self.assertTrue(wrapper.called) + (_, sql, param_list, many, context), _ = wrapper.call_args + self.assertIn('DELETE', sql) + self.assertIsInstance(param_list, (list, tuple)) + self.assertIs(many, True) + self.assertEqual(context['connection'], connection) + + def test_database_queried(self): + wrapper = self.mock_wrapper() + with connection.execute_wrapper(wrapper): + with connection.cursor() as cursor: + sql = 'SELECT 17' + connection.features.bare_select_suffix + cursor.execute(sql) + seventeen = cursor.fetchall() + self.assertEqual(list(seventeen), [(17,)]) + self.call_executemany(connection) + + def test_nested_wrapper_invoked(self): + outer_wrapper = self.mock_wrapper() + inner_wrapper = self.mock_wrapper() + with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(inner_wrapper): + self.call_execute(connection) + self.assertEqual(inner_wrapper.call_count, 1) + self.call_executemany(connection) + self.assertEqual(inner_wrapper.call_count, 2) + + def test_outer_wrapper_blocks(self): + def blocker(*args): + pass + wrapper = self.mock_wrapper() + c = connection # This alias shortens the next line. + with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(wrapper): + with c.cursor() as cursor: + cursor.execute("The database never sees this") + self.assertEqual(wrapper.call_count, 1) + cursor.executemany("The database never sees this %s", [("either",)]) + self.assertEqual(wrapper.call_count, 2) + + def test_wrapper_gets_sql(self): + wrapper = self.mock_wrapper() + sql = "SELECT 'aloha'" + connection.features.bare_select_suffix + with connection.execute_wrapper(wrapper), connection.cursor() as cursor: + cursor.execute(sql) + (_, reported_sql, _, _, _), _ = wrapper.call_args + self.assertEqual(reported_sql, sql) + + def test_wrapper_connection_specific(self): + wrapper = self.mock_wrapper() + with connections['other'].execute_wrapper(wrapper): + self.assertEqual(connections['other'].execute_wrappers, [wrapper]) + self.call_execute(connection) + self.assertFalse(wrapper.called) + self.assertEqual(connection.execute_wrappers, []) + self.assertEqual(connections['other'].execute_wrappers, [])