Fixed #28077 -- Added support for PostgreSQL opclasses in Index.

Thanks Vinay Karanam for the initial patch.
This commit is contained in:
Ian Foote 2017-10-11 22:55:52 +05:30 committed by Tim Graham
parent b4cba4ed62
commit 38cada7c94
9 changed files with 143 additions and 16 deletions

View File

@ -907,7 +907,7 @@ class BaseDatabaseSchemaEditor:
return '' return ''
def _create_index_sql(self, model, fields, *, name=None, suffix='', using='', def _create_index_sql(self, model, fields, *, name=None, suffix='', using='',
db_tablespace=None, col_suffixes=(), sql=None): db_tablespace=None, col_suffixes=(), sql=None, opclasses=()):
""" """
Return the SQL statement to create the index for one or several fields. Return the SQL statement to create the index for one or several fields.
`sql` can be specified if the syntax differs from the standard (GIS `sql` can be specified if the syntax differs from the standard (GIS
@ -929,10 +929,13 @@ class BaseDatabaseSchemaEditor:
table=Table(table, self.quote_name), table=Table(table, self.quote_name),
name=IndexName(table, columns, suffix, create_index_name), name=IndexName(table, columns, suffix, create_index_name),
using=using, using=using,
columns=Columns(table, columns, self.quote_name, col_suffixes=col_suffixes), columns=self._index_columns(table, columns, col_suffixes, opclasses),
extra=tablespace_sql, extra=tablespace_sql,
) )
def _index_columns(self, table, columns, col_suffixes, opclasses):
return Columns(table, columns, self.quote_name, col_suffixes=col_suffixes)
def _model_indexes_sql(self, model): def _model_indexes_sql(self, model):
""" """
Return a list of all index SQL statements (field indexes, Return a list of all index SQL statements (field indexes,

View File

@ -103,6 +103,24 @@ class IndexName(TableColumns):
return self.create_index_name(self.table, self.columns, self.suffix) return self.create_index_name(self.table, self.columns, self.suffix)
class IndexColumns(Columns):
def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
self.opclasses = opclasses
super().__init__(table, columns, quote_name, col_suffixes)
def __str__(self):
def col_str(column, idx):
try:
col = self.quote_name(column) + self.col_suffixes[idx]
except IndexError:
col = self.quote_name(column)
# Index.__init__() guarantees that self.opclasses is the same
# length as self.columns.
return '{} {}'.format(col, self.opclasses[idx])
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
class ForeignKeyName(TableColumns): class ForeignKeyName(TableColumns):
"""Hold a reference to a foreign key name.""" """Hold a reference to a foreign key name."""

View File

@ -1,6 +1,7 @@
import psycopg2 import psycopg2
from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import IndexColumns
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
@ -12,8 +13,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s" sql_create_index = "CREATE INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s"
sql_create_varchar_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s varchar_pattern_ops)%(extra)s"
sql_create_text_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s text_pattern_ops)%(extra)s"
sql_delete_index = "DROP INDEX IF EXISTS %(name)s" sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
# Setting the constraint to IMMEDIATE runs any deferred checks to allow # Setting the constraint to IMMEDIATE runs any deferred checks to allow
@ -49,9 +48,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
if '[' in db_type: if '[' in db_type:
return None return None
if db_type.startswith('varchar'): if db_type.startswith('varchar'):
return self._create_index_sql(model, [field], suffix='_like', sql=self.sql_create_varchar_index) return self._create_index_sql(model, [field], suffix='_like', opclasses=['varchar_pattern_ops'])
elif db_type.startswith('text'): elif db_type.startswith('text'):
return self._create_index_sql(model, [field], suffix='_like', sql=self.sql_create_text_index) return self._create_index_sql(model, [field], suffix='_like', opclasses=['text_pattern_ops'])
return None return None
def _alter_column_type_sql(self, model, old_field, new_field, new_type): def _alter_column_type_sql(self, model, old_field, new_field, new_type):
@ -132,3 +131,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
if old_field.unique and not (new_field.db_index or new_field.unique): if old_field.unique and not (new_field.db_index or new_field.unique):
index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like') index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_constraint_sql(self.sql_delete_index, model, index_to_remove)) self.execute(self._delete_constraint_sql(self.sql_delete_index, model, index_to_remove))
def _index_columns(self, table, columns, col_suffixes, opclasses):
if opclasses:
return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses)
return super()._index_columns(table, columns, col_suffixes, opclasses)

View File

@ -12,9 +12,15 @@ class Index:
# cross-database compatibility with Oracle) # cross-database compatibility with Oracle)
max_name_length = 30 max_name_length = 30
def __init__(self, *, fields=(), name=None, db_tablespace=None): def __init__(self, *, fields=(), name=None, db_tablespace=None, opclasses=()):
if opclasses and not name:
raise ValueError('An index must be named to use opclasses.')
if not isinstance(fields, (list, tuple)): if not isinstance(fields, (list, tuple)):
raise ValueError('Index.fields must be a list or tuple.') raise ValueError('Index.fields must be a list or tuple.')
if not isinstance(opclasses, (list, tuple)):
raise ValueError('Index.opclasses must be a list or tuple.')
if opclasses and len(fields) != len(opclasses):
raise ValueError('Index.fields and Index.opclasses must have the same number of elements.')
if not fields: if not fields:
raise ValueError('At least one field is required to define an index.') raise ValueError('At least one field is required to define an index.')
self.fields = list(fields) self.fields = list(fields)
@ -31,6 +37,7 @@ class Index:
if errors: if errors:
raise ValueError(errors) raise ValueError(errors)
self.db_tablespace = db_tablespace self.db_tablespace = db_tablespace
self.opclasses = opclasses
def check_name(self): def check_name(self):
errors = [] errors = []
@ -49,7 +56,7 @@ class Index:
col_suffixes = [order[1] for order in self.fields_orders] col_suffixes = [order[1] for order in self.fields_orders]
return schema_editor._create_index_sql( return schema_editor._create_index_sql(
model, fields, name=self.name, using=using, db_tablespace=self.db_tablespace, model, fields, name=self.name, using=using, db_tablespace=self.db_tablespace,
col_suffixes=col_suffixes, col_suffixes=col_suffixes, opclasses=self.opclasses,
) )
def remove_sql(self, model, schema_editor): def remove_sql(self, model, schema_editor):
@ -65,6 +72,8 @@ class Index:
kwargs = {'fields': self.fields, 'name': self.name} kwargs = {'fields': self.fields, 'name': self.name}
if self.db_tablespace is not None: if self.db_tablespace is not None:
kwargs['db_tablespace'] = self.db_tablespace kwargs['db_tablespace'] = self.db_tablespace
if self.opclasses:
kwargs['opclasses'] = self.opclasses
return (path, (), kwargs) return (path, (), kwargs)
def clone(self): def clone(self):

View File

@ -21,7 +21,7 @@ options`_.
``Index`` options ``Index`` options
================= =================
.. class:: Index(fields=(), name=None, db_tablespace=None) .. class:: Index(fields=(), name=None, db_tablespace=None, opclasses=())
Creates an index (B-Tree) in the database. Creates an index (B-Tree) in the database.
@ -72,3 +72,23 @@ in the same tablespace as the table.
For a list of PostgreSQL-specific indexes, see For a list of PostgreSQL-specific indexes, see
:mod:`django.contrib.postgres.indexes`. :mod:`django.contrib.postgres.indexes`.
``opclasses``
-------------
.. attribute:: Index.opclasses
.. versionadded:: 2.2
The names of the `PostgreSQL operator classes
<https://www.postgresql.org/docs/current/static/indexes-opclass.html>`_ to use for
this index. If you require a custom operator class, you must provide one for
each field in the index.
For example, ``GinIndex(name='json_index', fields=['jsonfield'],
opclasses=['jsonb_path_ops'])`` creates a gin index on ``jsonfield`` using
``jsonb_path_ops``.
``opclasses`` are ignored for databases besides PostgreSQL.
:attr:`Index.name` is required when using ``opclasses``.

View File

@ -164,7 +164,7 @@ Migrations
Models Models
~~~~~~ ~~~~~~
* ... * Added support for PostgreSQL operator classes (:attr:`.Index.opclasses`).
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -52,3 +52,8 @@ if connection.vendor == 'postgresql':
headline = models.CharField(max_length=100, db_index=True) headline = models.CharField(max_length=100, db_index=True)
body = models.TextField(db_index=True) body = models.TextField(db_index=True)
slug = models.CharField(max_length=40, unique=True) slug = models.CharField(max_length=40, unique=True)
class IndexedArticle2(models.Model):
headline = models.CharField(max_length=100)
body = models.TextField()

View File

@ -1,11 +1,14 @@
from unittest import skipUnless from unittest import skipIf, skipUnless
from django.db import connection from django.db import connection
from django.db.models import Index
from django.db.models.deletion import CASCADE from django.db.models.deletion import CASCADE
from django.db.models.fields.related import ForeignKey from django.db.models.fields.related import ForeignKey
from django.test import TestCase, TransactionTestCase from django.test import TestCase, TransactionTestCase
from .models import Article, ArticleTranslation, IndexTogetherSingleList from .models import (
Article, ArticleTranslation, IndexedArticle2, IndexTogetherSingleList,
)
class SchemaIndexesTests(TestCase): class SchemaIndexesTests(TestCase):
@ -66,8 +69,33 @@ class SchemaIndexesTests(TestCase):
index_sql = connection.schema_editor()._model_indexes_sql(IndexTogetherSingleList) index_sql = connection.schema_editor()._model_indexes_sql(IndexTogetherSingleList)
self.assertEqual(len(index_sql), 1) self.assertEqual(len(index_sql), 1)
@skipUnless(connection.vendor == 'postgresql', "This is a postgresql-specific issue")
def test_postgresql_text_indexes(self): @skipIf(connection.vendor == 'postgresql', 'opclasses are PostgreSQL only')
class SchemaIndexesNotPostgreSQLTests(TransactionTestCase):
available_apps = ['indexes']
def test_create_index_ignores_opclasses(self):
index = Index(
name='test_ops_class',
fields=['headline'],
opclasses=['varchar_pattern_ops'],
)
with connection.schema_editor() as editor:
# This would error if opclasses weren't ingored.
editor.add_index(IndexedArticle2, index)
@skipUnless(connection.vendor == 'postgresql', 'PostgreSQL tests')
class SchemaIndexesPostgreSQLTests(TransactionTestCase):
available_apps = ['indexes']
get_opclass_query = '''
SELECT opcname, c.relname FROM pg_opclass AS oc
JOIN pg_index as i on oc.oid = ANY(i.indclass)
JOIN pg_class as c on c.oid = i.indexrelid
WHERE c.relname = '%s'
'''
def test_text_indexes(self):
"""Test creation of PostgreSQL-specific text indexes (#12234)""" """Test creation of PostgreSQL-specific text indexes (#12234)"""
from .models import IndexedArticle from .models import IndexedArticle
index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(IndexedArticle)] index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(IndexedArticle)]
@ -78,12 +106,39 @@ class SchemaIndexesTests(TestCase):
# index (#19441). # index (#19441).
self.assertIn('("slug" varchar_pattern_ops)', index_sql[4]) self.assertIn('("slug" varchar_pattern_ops)', index_sql[4])
@skipUnless(connection.vendor == 'postgresql', "This is a postgresql-specific issue") def test_virtual_relation_indexes(self):
def test_postgresql_virtual_relation_indexes(self):
"""Test indexes are not created for related objects""" """Test indexes are not created for related objects"""
index_sql = connection.schema_editor()._model_indexes_sql(Article) index_sql = connection.schema_editor()._model_indexes_sql(Article)
self.assertEqual(len(index_sql), 1) self.assertEqual(len(index_sql), 1)
def test_ops_class(self):
index = Index(
name='test_ops_class',
fields=['headline'],
opclasses=['varchar_pattern_ops'],
)
with connection.schema_editor() as editor:
editor.add_index(IndexedArticle2, index)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query % 'test_ops_class')
self.assertEqual(cursor.fetchall(), [('varchar_pattern_ops', 'test_ops_class')])
def test_ops_class_multiple_columns(self):
index = Index(
name='test_ops_class_multiple',
fields=['headline', 'body'],
opclasses=['varchar_pattern_ops', 'text_pattern_ops'],
)
with connection.schema_editor() as editor:
editor.add_index(IndexedArticle2, index)
with editor.connection.cursor() as cursor:
cursor.execute(self.get_opclass_query % 'test_ops_class_multiple')
expected_ops_classes = (
('varchar_pattern_ops', 'test_ops_class_multiple'),
('text_pattern_ops', 'test_ops_class_multiple'),
)
self.assertCountEqual(cursor.fetchall(), expected_ops_classes)
@skipUnless(connection.vendor == 'mysql', 'MySQL tests') @skipUnless(connection.vendor == 'mysql', 'MySQL tests')
class SchemaIndexesMySQLTests(TransactionTestCase): class SchemaIndexesMySQLTests(TransactionTestCase):

View File

@ -39,6 +39,19 @@ class IndexesTests(SimpleTestCase):
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
models.Index() models.Index()
def test_opclasses_requires_index_name(self):
with self.assertRaisesMessage(ValueError, 'An index must be named to use opclasses.'):
models.Index(opclasses=['jsonb_path_ops'])
def test_opclasses_requires_list_or_tuple(self):
with self.assertRaisesMessage(ValueError, 'Index.opclasses must be a list or tuple.'):
models.Index(name='test_opclass', fields=['field'], opclasses='jsonb_path_ops')
def test_opclasses_and_fields_same_length(self):
msg = 'Index.fields and Index.opclasses must have the same number of elements.'
with self.assertRaisesMessage(ValueError, msg):
models.Index(name='test_opclass', fields=['field', 'other'], opclasses=['jsonb_path_ops'])
def test_max_name_length(self): def test_max_name_length(self):
msg = 'Index names cannot be longer than 30 characters.' msg = 'Index names cannot be longer than 30 characters.'
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):