[1.11.x] Fixed #28161 -- Fixed return type of ArrayField(CITextField()).

Thanks Tim for the review.

Backport of b91868507a from master.
This commit is contained in:
Simon Charette 2017-05-03 01:25:30 -04:00
parent f3217ab596
commit 246166cfe4
8 changed files with 40 additions and 15 deletions

View File

@ -5,7 +5,7 @@ from django.db.models import CharField, TextField
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from .lookups import SearchLookup, TrigramSimilar, Unaccent from .lookups import SearchLookup, TrigramSimilar, Unaccent
from .signals import register_hstore_handler from .signals import register_type_handlers
class PostgresConfig(AppConfig): class PostgresConfig(AppConfig):
@ -16,8 +16,8 @@ class PostgresConfig(AppConfig):
# Connections may already exist before we are called. # Connections may already exist before we are called.
for conn in connections.all(): for conn in connections.all():
if conn.connection is not None: if conn.connection is not None:
register_hstore_handler(conn) register_type_handlers(conn)
connection_created.connect(register_hstore_handler) connection_created.connect(register_type_handlers)
CharField.register_lookup(Unaccent) CharField.register_lookup(Unaccent)
TextField.register_lookup(Unaccent) TextField.register_lookup(Unaccent)
CharField.register_lookup(SearchLookup) CharField.register_lookup(SearchLookup)

View File

@ -1,4 +1,4 @@
from django.contrib.postgres.signals import register_hstore_handler from django.contrib.postgres.signals import register_type_handlers
from django.db.migrations.operations.base import Operation from django.db.migrations.operations.base import Operation
@ -15,6 +15,10 @@ class CreateExtension(Operation):
if schema_editor.connection.vendor != 'postgresql': if schema_editor.connection.vendor != 'postgresql':
return return
schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name)) schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name))
# Registering new type handlers cannot be done before the extension is
# installed, otherwise a subsequent data migration would use the same
# connection.
register_type_handlers(schema_editor.connection)
def database_backwards(self, app_label, schema_editor, from_state, to_state): def database_backwards(self, app_label, schema_editor, from_state, to_state):
schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name)) schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name))
@ -40,13 +44,6 @@ class HStoreExtension(CreateExtension):
def __init__(self): def __init__(self):
self.name = 'hstore' self.name = 'hstore'
def database_forwards(self, app_label, schema_editor, from_state, to_state):
super(HStoreExtension, self).database_forwards(app_label, schema_editor, from_state, to_state)
# Register hstore straight away as it cannot be done before the
# extension is installed, a subsequent data migration would use the
# same connection
register_hstore_handler(schema_editor.connection)
class TrigramExtension(CreateExtension): class TrigramExtension(CreateExtension):

View File

@ -1,10 +1,11 @@
import psycopg2
from psycopg2 import ProgrammingError from psycopg2 import ProgrammingError
from psycopg2.extras import register_hstore from psycopg2.extras import register_hstore
from django.utils import six from django.utils import six
def register_hstore_handler(connection, **kwargs): def register_type_handlers(connection, **kwargs):
if connection.vendor != 'postgresql': if connection.vendor != 'postgresql':
return return
@ -23,3 +24,17 @@ def register_hstore_handler(connection, **kwargs):
# This is also needed in order to create the connection in order to # This is also needed in order to create the connection in order to
# install the hstore extension. # install the hstore extension.
pass pass
try:
with connection.cursor() as cursor:
# Retrieve oids of citext arrays.
cursor.execute("SELECT typarray FROM pg_type WHERE typname = 'citext'")
oids = tuple(row[0] for row in cursor)
array_type = psycopg2.extensions.new_array_type(oids, 'citext[]', psycopg2.STRING)
psycopg2.extensions.register_type(array_type, None)
except ProgrammingError:
# citext is not available on the database.
#
# The same comments in the except block of the above call to
# register_hstore() also apply here.
pass

View File

@ -88,3 +88,6 @@ Bugfixes
* Fixed a regression where ``Model._state.db`` wasn't set correctly on * Fixed a regression where ``Model._state.db`` wasn't set correctly on
multi-table inheritance parent models after saving a child model multi-table inheritance parent models after saving a child model
(:ticket:`28166`). (:ticket:`28166`).
* Corrected the return type of ``ArrayField(CITextField())`` values retrieved
from the database (:ticket:`28161`).

View File

@ -12,14 +12,14 @@ class PostgreSQLTestCase(TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
# No need to keep that signal overhead for non PostgreSQL-related tests. # No need to keep that signal overhead for non PostgreSQL-related tests.
from django.contrib.postgres.signals import register_hstore_handler from django.contrib.postgres.signals import register_type_handlers
connection_created.disconnect(register_hstore_handler) connection_created.disconnect(register_type_handlers)
super(PostgreSQLTestCase, cls).tearDownClass() super(PostgreSQLTestCase, cls).tearDownClass()
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
# To locate the widget's template. # To locate the widget's template.
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
class PostgreSQLWidgetTestCase(WidgetTest): class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLTestCase):
pass pass

View File

@ -144,6 +144,7 @@ class Migration(migrations.Migration):
('name', CICharField(primary_key=True, max_length=255)), ('name', CICharField(primary_key=True, max_length=255)),
('email', CIEmailField()), ('email', CIEmailField()),
('description', CITextField()), ('description', CITextField()),
('array_field', ArrayField(CITextField(), null=True)),
], ],
options={ options={
'required_db_vendor': 'postgresql', 'required_db_vendor': 'postgresql',

View File

@ -105,6 +105,7 @@ class CITestModel(PostgreSQLModel):
name = CICharField(primary_key=True, max_length=255) name = CICharField(primary_key=True, max_length=255)
email = CIEmailField() email = CIEmailField()
description = CITextField() description = CITextField()
array_field = ArrayField(CITextField(), null=True)
def __str__(self): def __str__(self):
return self.name return self.name

View File

@ -4,11 +4,13 @@ strings and thus eliminates the need for operations such as iexact and other
modifiers to enforce use of an index. modifiers to enforce use of an index.
""" """
from django.db import IntegrityError from django.db import IntegrityError
from django.test.utils import modify_settings
from . import PostgreSQLTestCase from . import PostgreSQLTestCase
from .models import CITestModel from .models import CITestModel
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
class CITextTestCase(PostgreSQLTestCase): class CITextTestCase(PostgreSQLTestCase):
@classmethod @classmethod
@ -17,6 +19,7 @@ class CITextTestCase(PostgreSQLTestCase):
name='JoHn', name='JoHn',
email='joHn@johN.com', email='joHn@johN.com',
description='Average Joe named JoHn', description='Average Joe named JoHn',
array_field=['JoE', 'jOhn'],
) )
def test_equal_lowercase(self): def test_equal_lowercase(self):
@ -34,3 +37,8 @@ class CITextTestCase(PostgreSQLTestCase):
""" """
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
CITestModel.objects.create(name='John') CITestModel.objects.create(name='John')
def test_array_field(self):
instance = CITestModel.objects.get()
self.assertEqual(instance.array_field, self.john.array_field)
self.assertTrue(CITestModel.objects.filter(array_field__contains=['joe']).exists())