django/tests/backends/postgresql/test_server_side_cursors.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

102 lines
3.6 KiB
Python
Raw Normal View History

import operator
import unittest
from collections import namedtuple
from contextlib import contextmanager
from django.db import connection, models
from django.test import TestCase
2017-02-12 04:37:49 +08:00
from ..models import Person
2017-02-12 04:37:49 +08:00
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests")
class ServerSideCursorsPostgres(TestCase):
cursor_fields = (
"name, statement, is_holdable, is_binary, is_scrollable, creation_time"
)
PostgresCursor = namedtuple("PostgresCursor", cursor_fields)
@classmethod
def setUpTestData(cls):
Person.objects.create(first_name="a", last_name="a")
Person.objects.create(first_name="b", last_name="b")
def inspect_cursors(self):
with connection.cursor() as cursor:
cursor.execute(
"SELECT {fields} FROM pg_cursors;".format(fields=self.cursor_fields)
)
cursors = cursor.fetchall()
return [self.PostgresCursor._make(cursor) for cursor in cursors]
@contextmanager
def override_db_setting(self, **kwargs):
for setting in kwargs:
original_value = connection.settings_dict.get(setting)
if setting in connection.settings_dict:
self.addCleanup(
operator.setitem, connection.settings_dict, setting, original_value
)
else:
self.addCleanup(operator.delitem, connection.settings_dict, setting)
connection.settings_dict[setting] = kwargs[setting]
yield
def assertUsesCursor(self, queryset, num_expected=1):
next(queryset) # Open a server-side cursor
cursors = self.inspect_cursors()
self.assertEqual(len(cursors), num_expected)
for cursor in cursors:
self.assertIn("_django_curs_", cursor.name)
self.assertFalse(cursor.is_scrollable)
self.assertFalse(cursor.is_holdable)
self.assertFalse(cursor.is_binary)
def asserNotUsesCursor(self, queryset):
self.assertUsesCursor(queryset, num_expected=0)
def test_server_side_cursor(self):
self.assertUsesCursor(Person.objects.iterator())
def test_values(self):
self.assertUsesCursor(Person.objects.values("first_name").iterator())
def test_values_list(self):
self.assertUsesCursor(Person.objects.values_list("first_name").iterator())
def test_values_list_flat(self):
self.assertUsesCursor(
Person.objects.values_list("first_name", flat=True).iterator()
)
def test_values_list_fields_not_equal_to_names(self):
expr = models.Count("id")
self.assertUsesCursor(
Person.objects.annotate(id__count=expr)
.values_list(expr, "id__count")
.iterator()
)
def test_server_side_cursor_many_cursors(self):
persons = Person.objects.iterator()
persons2 = Person.objects.iterator()
next(persons) # Open a server-side cursor
self.assertUsesCursor(persons2, num_expected=2)
def test_closed_server_side_cursor(self):
persons = Person.objects.iterator()
next(persons) # Open a server-side cursor
del persons
cursors = self.inspect_cursors()
self.assertEqual(len(cursors), 0)
def test_server_side_cursors_setting(self):
with self.override_db_setting(DISABLE_SERVER_SIDE_CURSORS=False):
persons = Person.objects.iterator()
self.assertUsesCursor(persons)
del persons # Close server-side cursor
with self.override_db_setting(DISABLE_SERVER_SIDE_CURSORS=True):
self.asserNotUsesCursor(Person.objects.iterator())