Fixed #28334 -- Added caching for hstore/citext OIDs.
This commit is contained in:
parent
cb362a6750
commit
86a18dc46a
|
@ -1,4 +1,6 @@
|
|||
from django.contrib.postgres.signals import register_type_handlers
|
||||
from django.contrib.postgres.signals import (
|
||||
get_citext_oids, get_hstore_oids, register_type_handlers,
|
||||
)
|
||||
from django.db.migrations.operations.base import Operation
|
||||
|
||||
|
||||
|
@ -15,6 +17,9 @@ class CreateExtension(Operation):
|
|||
if schema_editor.connection.vendor != 'postgresql':
|
||||
return
|
||||
schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name))
|
||||
# Clear cached, stale oids.
|
||||
get_hstore_oids.cache_clear()
|
||||
get_citext_oids.cache_clear()
|
||||
# Registering new type handlers cannot be done before the extension is
|
||||
# installed, otherwise a subsequent data migration would use the same
|
||||
# connection.
|
||||
|
@ -22,6 +27,9 @@ class CreateExtension(Operation):
|
|||
|
||||
def database_backwards(self, app_label, schema_editor, from_state, to_state):
|
||||
schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name))
|
||||
# Clear cached, stale oids.
|
||||
get_hstore_oids.cache_clear()
|
||||
get_citext_oids.cache_clear()
|
||||
|
||||
def describe(self):
|
||||
return "Creates extension %s" % self.name
|
||||
|
|
|
@ -1,14 +1,45 @@
|
|||
import functools
|
||||
|
||||
import psycopg2
|
||||
from psycopg2 import ProgrammingError
|
||||
from psycopg2.extras import register_hstore
|
||||
|
||||
from django.db import connections
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_hstore_oids(connection_alias):
|
||||
"""Return hstore and hstore array OIDs."""
|
||||
with connections[connection_alias].cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT t.oid, typarray "
|
||||
"FROM pg_type t "
|
||||
"JOIN pg_namespace ns ON typnamespace = ns.oid "
|
||||
"WHERE typname = 'hstore'"
|
||||
)
|
||||
oids = []
|
||||
array_oids = []
|
||||
for row in cursor:
|
||||
oids.append(row[0])
|
||||
array_oids.append(row[1])
|
||||
return tuple(oids), tuple(array_oids)
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_citext_oids(connection_alias):
|
||||
"""Return citext array OIDs."""
|
||||
with connections[connection_alias].cursor() as cursor:
|
||||
cursor.execute("SELECT typarray FROM pg_type WHERE typname = 'citext'")
|
||||
return tuple(row[0] for row in cursor)
|
||||
|
||||
|
||||
def register_type_handlers(connection, **kwargs):
|
||||
if connection.vendor != 'postgresql':
|
||||
return
|
||||
|
||||
try:
|
||||
register_hstore(connection.connection, globally=True)
|
||||
oids, array_oids = get_hstore_oids(connection.alias)
|
||||
register_hstore(connection.connection, globally=True, oid=oids, array_oid=array_oids)
|
||||
except ProgrammingError:
|
||||
# Hstore is not available on the database.
|
||||
#
|
||||
|
@ -21,11 +52,8 @@ def register_type_handlers(connection, **kwargs):
|
|||
pass
|
||||
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
# Retrieve oids of citext arrays.
|
||||
cursor.execute("SELECT typarray FROM pg_type WHERE typname = 'citext'")
|
||||
oids = tuple(row[0] for row in cursor)
|
||||
array_type = psycopg2.extensions.new_array_type(oids, 'citext[]', psycopg2.STRING)
|
||||
citext_oids = get_citext_oids(connection.alias)
|
||||
array_type = psycopg2.extensions.new_array_type(citext_oids, 'citext[]', psycopg2.STRING)
|
||||
psycopg2.extensions.register_type(array_type, None)
|
||||
except ProgrammingError:
|
||||
# citext is not available on the database.
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
from django.db import connection
|
||||
|
||||
from . import PostgreSQLTestCase
|
||||
|
||||
try:
|
||||
from django.contrib.postgres.signals import get_hstore_oids, get_citext_oids
|
||||
except ImportError:
|
||||
pass # pyscogp2 isn't installed.
|
||||
|
||||
|
||||
class OIDTests(PostgreSQLTestCase):
|
||||
|
||||
def assertOIDs(self, oids):
|
||||
self.assertIsInstance(oids, tuple)
|
||||
self.assertGreater(len(oids), 0)
|
||||
self.assertTrue(all(isinstance(oid, int) for oid in oids))
|
||||
|
||||
def test_hstore_cache(self):
|
||||
with self.assertNumQueries(0):
|
||||
get_hstore_oids(connection.alias)
|
||||
|
||||
def test_citext_cache(self):
|
||||
with self.assertNumQueries(0):
|
||||
get_citext_oids(connection.alias)
|
||||
|
||||
def test_hstore_values(self):
|
||||
oids, array_oids = get_hstore_oids(connection.alias)
|
||||
self.assertOIDs(oids)
|
||||
self.assertOIDs(array_oids)
|
||||
|
||||
def test_citext_values(self):
|
||||
oids = get_citext_oids(connection.alias)
|
||||
self.assertOIDs(oids)
|
Loading…
Reference in New Issue