diff --git a/django/db/backends/base/base.py b/django/db/backends/base/base.py index 38712bd4bf..9c35ca7e47 100644 --- a/django/db/backends/base/base.py +++ b/django/db/backends/base/base.py @@ -165,7 +165,7 @@ class BaseDatabaseWrapper(object): """Initializes the database connection settings.""" raise NotImplementedError('subclasses of BaseDatabaseWrapper may require an init_connection_state() method') - def create_cursor(self): + def create_cursor(self, name=None): """Creates a cursor. Assumes that a connection is established.""" raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a create_cursor() method') @@ -214,10 +214,21 @@ class BaseDatabaseWrapper(object): # ##### Backend-specific wrappers for PEP-249 connection methods ##### - def _cursor(self): + def _prepare_cursor(self, cursor): + """ + Validate the connection is usable and perform database cursor wrapping. + """ + self.validate_thread_sharing() + if self.queries_logged: + wrapped_cursor = self.make_debug_cursor(cursor) + else: + wrapped_cursor = self.make_cursor(cursor) + return wrapped_cursor + + def _cursor(self, name=None): self.ensure_connection() with self.wrap_database_errors: - return self.create_cursor() + return self.create_cursor(name) def _commit(self): if self.connection is not None: @@ -240,12 +251,7 @@ class BaseDatabaseWrapper(object): """ Creates a cursor, opening a connection if necessary. """ - self.validate_thread_sharing() - if self.queries_logged: - cursor = self.make_debug_cursor(self._cursor()) - else: - cursor = self.make_cursor(self._cursor()) - return cursor + return self._prepare_cursor(self._cursor()) def commit(self): """ @@ -553,6 +559,13 @@ class BaseDatabaseWrapper(object): """ return DatabaseErrorWrapper(self) + def chunked_cursor(self): + """ + Return a cursor that tries to avoid caching in the database (if + supported by the database), otherwise return a regular cursor. 
+ """ + return self.cursor() + def make_debug_cursor(self, cursor): """ Creates a cursor that logs all queries in self.queries_log. diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 9681bb4f33..a6fa920f1e 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -271,7 +271,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): # with SQL standards. cursor.execute('SET SQL_AUTO_IS_NULL = 0') - def create_cursor(self): + def create_cursor(self, name=None): cursor = self.connection.cursor() return CursorWrapper(cursor) diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index c928e2dc6e..3488769bf7 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -249,7 +249,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): if not self.get_autocommit(): self.commit() - def create_cursor(self): + def create_cursor(self, name=None): return FormatStylePlaceholderCursor(self.connection) def _commit(self): diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index b4008532ee..91dee6a722 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -4,6 +4,7 @@ PostgreSQL database backend for Django. 
Requires psycopg 2: http://initd.org/projects/psycopg2 """ +import threading import warnings from django.conf import settings @@ -145,6 +146,10 @@ class DatabaseWrapper(BaseDatabaseWrapper): introspection_class = DatabaseIntrospection ops_class = DatabaseOperations + def __init__(self, *args, **kwargs): + super(DatabaseWrapper, self).__init__(*args, **kwargs) + self._named_cursor_idx = 0 + def get_connection_params(self): settings_dict = self.settings_dict # None may be used to connect to the default 'postgres' db @@ -206,11 +211,27 @@ class DatabaseWrapper(BaseDatabaseWrapper): if not self.get_autocommit(): self.connection.commit() - def create_cursor(self): - cursor = self.connection.cursor() + def create_cursor(self, name=None): + if name: + # In autocommit mode, the cursor will be used outside of a + # transaction, hence use a holdable cursor. + cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit) + else: + cursor = self.connection.cursor() cursor.tzinfo_factory = utc_tzinfo_factory if settings.USE_TZ else None return cursor + def chunked_cursor(self): + self._named_cursor_idx += 1 + db_cursor = self._cursor( + name='_django_curs_%d_%d' % ( + # Avoid reusing name in other threads + threading.current_thread().ident, + self._named_cursor_idx, + ) + ) + return self._prepare_cursor(db_cursor) + def _set_autocommit(self, autocommit): with self.wrap_database_errors: self.connection.autocommit = autocommit diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index ac7799bfa5..7bbbd6c2f8 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -215,7 +215,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): def init_connection_state(self): pass - def create_cursor(self): + def create_cursor(self, name=None): return self.connection.cursor(factory=SQLiteCursorWrapper) def close(self): diff --git a/django/db/models/query.py b/django/db/models/query.py index 
7db076cc58..e059c68f13 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -34,8 +34,9 @@ EmptyResultSet = sql.EmptyResultSet class BaseIterable(object): - def __init__(self, queryset): + def __init__(self, queryset, chunked_fetch=False): self.queryset = queryset + self.chunked_fetch = chunked_fetch class ModelIterable(BaseIterable): @@ -49,7 +50,7 @@ class ModelIterable(BaseIterable): compiler = queryset.query.get_compiler(using=db) # Execute the query. This will also fill compiler.select, klass_info, # and annotations. - results = compiler.execute_sql() + results = compiler.execute_sql(chunked_fetch=self.chunked_fetch) select, klass_info, annotation_col_map = (compiler.select, compiler.klass_info, compiler.annotation_col_map) model_cls = klass_info['model'] @@ -318,7 +319,7 @@ class QuerySet(object): An iterator over the results from applying this QuerySet to the database. """ - return iter(self._iterable_class(self)) + return iter(self._iterable_class(self, chunked_fetch=True)) def aggregate(self, *args, **kwargs): """ @@ -1071,7 +1072,7 @@ class QuerySet(object): def _fetch_all(self): if self._result_cache is None: - self._result_cache = list(self.iterator()) + self._result_cache = list(self._iterable_class(self)) if self._prefetch_related_lookups and not self._prefetch_done: self._prefetch_related_objects() diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index c37264ed0b..0778d38f4a 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -799,7 +799,7 @@ class SQLCompiler(object): self.query.set_extra_mask(['a']) return bool(self.execute_sql(SINGLE)) - def execute_sql(self, result_type=MULTI): + def execute_sql(self, result_type=MULTI, chunked_fetch=False): """ Run the query against the database and returns the result(s). 
The return value is a single data item if result_type is SINGLE, or an @@ -823,12 +823,16 @@ class SQLCompiler(object): return iter([]) else: return - - cursor = self.connection.cursor() + if chunked_fetch: + cursor = self.connection.chunked_cursor() + else: + cursor = self.connection.cursor() try: cursor.execute(sql, params) except Exception: - cursor.close() + with self.connection.wrap_database_errors: + # Closing a server-side cursor could yield an error + cursor.close() raise if result_type == CURSOR: @@ -852,11 +856,11 @@ class SQLCompiler(object): cursor, self.connection.features.empty_fetchmany_value, self.col_count ) - if not self.connection.features.can_use_chunked_reads: + if not chunked_fetch and not self.connection.features.can_use_chunked_reads: try: # If we are using non-chunked reads, we return the same data # structure as normally, but ensure it is all read into memory - # before going any further. + # before going any further. Use chunked_fetch if requested. return list(result) finally: # done with the cursor diff --git a/django/test/testcases.py b/django/test/testcases.py index 4538716ce7..8203a3837d 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -175,6 +175,7 @@ class SimpleTestCase(unittest.TestCase): for alias in connections: connection = connections[alias] connection.cursor = _CursorFailure(cls.__name__, connection.cursor) + connection.chunked_cursor = _CursorFailure(cls.__name__, connection.chunked_cursor) @classmethod def tearDownClass(cls): @@ -182,6 +183,7 @@ class SimpleTestCase(unittest.TestCase): for alias in connections: connection = connections[alias] connection.cursor = connection.cursor.wrapped + connection.chunked_cursor = connection.chunked_cursor.wrapped if hasattr(cls, '_cls_modified_context'): cls._cls_modified_context.disable() delattr(cls, '_cls_modified_context') diff --git a/docs/ref/databases.txt b/docs/ref/databases.txt index 51feceea04..16addc52bc 100644 --- a/docs/ref/databases.txt +++ 
b/docs/ref/databases.txt @@ -171,6 +171,24 @@ If you need to add a PostgreSQL extension (like ``hstore``, ``postgis``, etc.) using a migration, use the :class:`~django.contrib.postgres.operations.CreateExtension` operation. +.. _postgresql-server-side-cursors: + +Server-side cursors +------------------- + +.. versionadded:: 1.11 + +When using :meth:`QuerySet.iterator() +<django.db.models.query.QuerySet.iterator>`, Django opens a :ref:`server-side +cursor <psycopg2:server-side-cursors>`. By default, PostgreSQL assumes that +only the first 10% of the results of cursor queries will be fetched. The query +planner spends less time planning the query and starts returning results +faster, but this could diminish performance if more than 10% of the results are +retrieved. PostgreSQL's assumption about the number of rows retrieved for a +cursor query is controlled with the `cursor_tuple_fraction`_ option. + +.. _cursor_tuple_fraction: https://www.postgresql.org/docs/current/static/runtime-config-query.html#GUC-CURSOR-TUPLE-FRACTION + Test database templates ----------------------- diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 3090251659..5fe783514d 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -1981,15 +1981,15 @@ evaluated will force it to evaluate again, repeating the query. Also, use of ``iterator()`` causes previous ``prefetch_related()`` calls to be ignored since these two optimizations do not make sense together. -.. warning:: +Some Python database drivers still load the entire result set into memory, but +won't cache results after iterating over them. Oracle and :ref:`PostgreSQL +<postgresql-server-side-cursors>` use server-side cursors to stream results +from the database without loading the entire result set into memory. - Some Python database drivers like ``psycopg2`` perform caching if using - client side cursors (instantiated with ``connection.cursor()`` and what - Django's ORM uses). Using ``iterator()`` does not affect caching at the - database driver level. 
To disable this caching, look at `server side - cursors`_. +.. versionchanged:: 1.11 + + PostgreSQL support for server-side cursors was added. -.. _server side cursors: http://initd.org/psycopg/docs/usage.html#server-side-cursors ``latest()`` ~~~~~~~~~~~~ diff --git a/docs/releases/1.11.txt b/docs/releases/1.11.txt index 9be1f4c53d..9bb9a2e903 100644 --- a/docs/releases/1.11.txt +++ b/docs/releases/1.11.txt @@ -273,6 +273,11 @@ Database backends * Added the :setting:`TEST['TEMPLATE'] <TEST_TEMPLATE>` setting to let PostgreSQL users specify a template for creating the test database. +* :meth:`.QuerySet.iterator()` now uses :ref:`server-side cursors + <postgresql-server-side-cursors>` on PostgreSQL. This feature transfers some of + the worker memory load (used to hold query results) to the database and might + increase database memory usage. + Email ~~~~~ @@ -527,6 +532,10 @@ Database backend API * Renamed the ``ignores_quoted_identifier_case`` feature to ``ignores_table_name_case`` to more accurately reflect how it is used. +* The ``name`` keyword argument is added to the + ``DatabaseWrapper.create_cursor(self, name=None)`` method to allow usage of + server-side cursors on backends that support it. 
+ Dropped support for PostgreSQL 9.2 and PostGIS 2.0 -------------------------------------------------- diff --git a/tests/backends/test_postgresql.py b/tests/backends/test_postgresql.py new file mode 100644 index 0000000000..024fc1add3 --- /dev/null +++ b/tests/backends/test_postgresql.py @@ -0,0 +1,54 @@ +import unittest +from collections import namedtuple + +from django.db import connection +from django.test import TestCase + +from .models import Person + + +@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") +class ServerSideCursorsPostgres(TestCase): + cursor_fields = 'name, statement, is_holdable, is_binary, is_scrollable, creation_time' + PostgresCursor = namedtuple('PostgresCursor', cursor_fields) + + @classmethod + def setUpTestData(cls): + Person.objects.create(first_name='a', last_name='a') + Person.objects.create(first_name='b', last_name='b') + + def inspect_cursors(self): + with connection.cursor() as cursor: + cursor.execute('SELECT {fields} FROM pg_cursors;'.format(fields=self.cursor_fields)) + cursors = cursor.fetchall() + return [self.PostgresCursor._make(cursor) for cursor in cursors] + + def test_server_side_cursor(self): + persons = Person.objects.iterator() + next(persons) # Open a server-side cursor + cursors = self.inspect_cursors() + self.assertEqual(len(cursors), 1) + self.assertIn('_django_curs_', cursors[0].name) + self.assertFalse(cursors[0].is_scrollable) + self.assertFalse(cursors[0].is_holdable) + self.assertFalse(cursors[0].is_binary) + + def test_server_side_cursor_many_cursors(self): + persons = Person.objects.iterator() + persons2 = Person.objects.iterator() + next(persons) # Open a server-side cursor + next(persons2) # Open a second server-side cursor + cursors = self.inspect_cursors() + self.assertEqual(len(cursors), 2) + for cursor in cursors: + self.assertIn('_django_curs_', cursor.name) + self.assertFalse(cursor.is_scrollable) + self.assertFalse(cursor.is_holdable) + 
self.assertFalse(cursor.is_binary) + + def test_closed_server_side_cursor(self): + persons = Person.objects.iterator() + next(persons) # Open a server-side cursor + del persons + cursors = self.inspect_cursors() + self.assertEqual(len(cursors), 0) diff --git a/tests/test_utils/tests.py b/tests/test_utils/tests.py index 0013452cac..54b83f524b 100644 --- a/tests/test_utils/tests.py +++ b/tests/test_utils/tests.py @@ -1069,6 +1069,18 @@ class DisallowedDatabaseQueriesTests(SimpleTestCase): Car.objects.first() +class DisallowedDatabaseQueriesChunkedCursorsTests(SimpleTestCase): + def test_disallowed_database_queries(self): + expected_message = ( + "Database queries aren't allowed in SimpleTestCase. Either use " + "TestCase or TransactionTestCase to ensure proper test isolation or " + "set DisallowedDatabaseQueriesChunkedCursorsTests.allow_database_queries " + "to True to silence this failure." + ) + with self.assertRaisesMessage(AssertionError, expected_message): + next(Car.objects.iterator()) + + class AllowedDatabaseQueriesTests(SimpleTestCase): allow_database_queries = True