From 38cada7c94f5f73d2d47a0a730ea5d71d266fa2c Mon Sep 17 00:00:00 2001 From: Ian Foote Date: Wed, 11 Oct 2017 22:55:52 +0530 Subject: [PATCH] Fixed #28077 -- Added support for PostgreSQL opclasses in Index. Thanks Vinay Karanam for the initial patch. --- django/db/backends/base/schema.py | 7 ++- django/db/backends/ddl_references.py | 18 +++++++ django/db/backends/postgresql/schema.py | 12 +++-- django/db/models/indexes.py | 13 ++++- docs/ref/models/indexes.txt | 22 +++++++- docs/releases/2.2.txt | 2 +- tests/indexes/models.py | 5 ++ tests/indexes/tests.py | 67 ++++++++++++++++++++++--- tests/model_indexes/tests.py | 13 +++++ 9 files changed, 143 insertions(+), 16 deletions(-) diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index a722e497c3..ec2cf0e5c7 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -907,7 +907,7 @@ class BaseDatabaseSchemaEditor: return '' def _create_index_sql(self, model, fields, *, name=None, suffix='', using='', - db_tablespace=None, col_suffixes=(), sql=None): + db_tablespace=None, col_suffixes=(), sql=None, opclasses=()): """ Return the SQL statement to create the index for one or several fields. `sql` can be specified if the syntax differs from the standard (GIS @@ -929,10 +929,13 @@ class BaseDatabaseSchemaEditor: table=Table(table, self.quote_name), name=IndexName(table, columns, suffix, create_index_name), using=using, - columns=Columns(table, columns, self.quote_name, col_suffixes=col_suffixes), + columns=self._index_columns(table, columns, col_suffixes, opclasses), extra=tablespace_sql, ) + def _index_columns(self, table, columns, col_suffixes, opclasses): + return Columns(table, columns, self.quote_name, col_suffixes=col_suffixes) + def _model_indexes_sql(self, model): """ Return a list of all index SQL statements (field indexes, diff --git a/django/db/backends/ddl_references.py b/django/db/backends/ddl_references.py index b894d58793..d71f6169ea 100644 --- a/django/db/backends/ddl_references.py +++ b/django/db/backends/ddl_references.py @@ -103,6 +103,24 @@ class IndexName(TableColumns): return self.create_index_name(self.table, self.columns, self.suffix) +class IndexColumns(Columns): + def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()): + self.opclasses = opclasses + super().__init__(table, columns, quote_name, col_suffixes) + + def __str__(self): + def col_str(column, idx): + try: + col = self.quote_name(column) + self.col_suffixes[idx] + except IndexError: + col = self.quote_name(column) + # Index.__init__() guarantees that self.opclasses is the same + # length as self.columns. + return '{} {}'.format(col, self.opclasses[idx]) + + return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns)) + + class ForeignKeyName(TableColumns): """Hold a reference to a foreign key name.""" diff --git a/django/db/backends/postgresql/schema.py b/django/db/backends/postgresql/schema.py index 18388cc523..feaddfab52 100644 --- a/django/db/backends/postgresql/schema.py +++ b/django/db/backends/postgresql/schema.py @@ -1,6 +1,7 @@ import psycopg2 from django.db.backends.base.schema import BaseDatabaseSchemaEditor +from django.db.backends.ddl_references import IndexColumns class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): @@ -12,8 +13,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s" sql_create_index = "CREATE INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s" - sql_create_varchar_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s varchar_pattern_ops)%(extra)s" - sql_create_text_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s text_pattern_ops)%(extra)s" sql_delete_index = "DROP INDEX IF EXISTS %(name)s" # Setting the constraint to IMMEDIATE runs any deferred checks to allow @@ -49,9 +48,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): if '[' in db_type: return None if db_type.startswith('varchar'): - return self._create_index_sql(model, [field], suffix='_like', sql=self.sql_create_varchar_index) + return self._create_index_sql(model, [field], suffix='_like', opclasses=['varchar_pattern_ops']) elif db_type.startswith('text'): - return self._create_index_sql(model, [field], suffix='_like', sql=self.sql_create_text_index) + return self._create_index_sql(model, [field], suffix='_like', opclasses=['text_pattern_ops']) return None def _alter_column_type_sql(self, model, old_field, new_field, new_type): @@ -132,3 +131,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): if old_field.unique and not (new_field.db_index or new_field.unique): index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like') self.execute(self._delete_constraint_sql(self.sql_delete_index, model, index_to_remove)) + + def _index_columns(self, table, columns, col_suffixes, opclasses): + if opclasses: + return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses) + return super()._index_columns(table, columns, col_suffixes, opclasses) diff --git a/django/db/models/indexes.py b/django/db/models/indexes.py index 9bfb9e0558..c378b13a5c 100644 --- a/django/db/models/indexes.py +++ b/django/db/models/indexes.py @@ -12,9 +12,15 @@ class Index: # cross-database compatibility with Oracle) max_name_length = 30 - def __init__(self, *, fields=(), name=None, db_tablespace=None): + def __init__(self, *, fields=(), name=None, db_tablespace=None, opclasses=()): + if opclasses and not name: + raise ValueError('An index must be named to use opclasses.') if not isinstance(fields, (list, tuple)): raise ValueError('Index.fields must be a list or tuple.') + if not isinstance(opclasses, (list, tuple)): + raise ValueError('Index.opclasses must be a list or tuple.') + if opclasses and len(fields) != len(opclasses): + raise ValueError('Index.fields and Index.opclasses must have the same number of elements.') if not fields: raise ValueError('At least one field is required to define an index.') self.fields = list(fields) @@ -31,6 +37,7 @@ class Index: if errors: raise ValueError(errors) self.db_tablespace = db_tablespace + self.opclasses = opclasses def check_name(self): errors = [] @@ -49,7 +56,7 @@ class Index: col_suffixes = [order[1] for order in self.fields_orders] return schema_editor._create_index_sql( model, fields, name=self.name, using=using, db_tablespace=self.db_tablespace, - col_suffixes=col_suffixes, + col_suffixes=col_suffixes, opclasses=self.opclasses, ) def remove_sql(self, model, schema_editor): @@ -65,6 +72,8 @@ class Index: kwargs = {'fields': self.fields, 'name': self.name} if self.db_tablespace is not None: kwargs['db_tablespace'] = self.db_tablespace + if self.opclasses: + kwargs['opclasses'] = self.opclasses return (path, (), kwargs) def clone(self): diff --git a/docs/ref/models/indexes.txt b/docs/ref/models/indexes.txt index a78423bd36..e585a6f824 100644 --- a/docs/ref/models/indexes.txt +++ b/docs/ref/models/indexes.txt @@ -21,7 +21,7 @@ options`_. ``Index`` options ================= -.. class:: Index(fields=(), name=None, db_tablespace=None) +.. class:: Index(fields=(), name=None, db_tablespace=None, opclasses=()) Creates an index (B-Tree) in the database. @@ -72,3 +72,23 @@ in the same tablespace as the table. For a list of PostgreSQL-specific indexes, see :mod:`django.contrib.postgres.indexes`. + +``opclasses`` +------------- + +.. attribute:: Index.opclasses + +.. versionadded:: 2.2 + +The names of the `PostgreSQL operator classes +`_ to use for +this index. If you require a custom operator class, you must provide one for +each field in the index. + +For example, ``GinIndex(name='json_index', fields=['jsonfield'], +opclasses=['jsonb_path_ops'])`` creates a gin index on ``jsonfield`` using +``jsonb_path_ops``. + +``opclasses`` are ignored for databases besides PostgreSQL. + +:attr:`Index.name` is required when using ``opclasses``. diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 742f4893be..a68f3135e4 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -164,7 +164,7 @@ Migrations Models ~~~~~~ -* ... +* Added support for PostgreSQL operator classes (:attr:`.Index.opclasses`). Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/indexes/models.py b/tests/indexes/models.py index 208da32c6e..27bafb9cda 100644 --- a/tests/indexes/models.py +++ b/tests/indexes/models.py @@ -52,3 +52,8 @@ if connection.vendor == 'postgresql': headline = models.CharField(max_length=100, db_index=True) body = models.TextField(db_index=True) slug = models.CharField(max_length=40, unique=True) + + +class IndexedArticle2(models.Model): + headline = models.CharField(max_length=100) + body = models.TextField() diff --git a/tests/indexes/tests.py b/tests/indexes/tests.py index ee2cbd1564..219dfe67b1 100644 --- a/tests/indexes/tests.py +++ b/tests/indexes/tests.py @@ -1,11 +1,14 @@ -from unittest import skipUnless +from unittest import skipIf, skipUnless from django.db import connection +from django.db.models import Index from django.db.models.deletion import CASCADE from django.db.models.fields.related import ForeignKey from django.test import TestCase, TransactionTestCase -from .models import Article, ArticleTranslation, IndexTogetherSingleList +from .models import ( + Article, ArticleTranslation, IndexedArticle2, IndexTogetherSingleList, +) class SchemaIndexesTests(TestCase): @@ -66,8 +69,33 @@ class SchemaIndexesTests(TestCase): index_sql = connection.schema_editor()._model_indexes_sql(IndexTogetherSingleList) self.assertEqual(len(index_sql), 1) - @skipUnless(connection.vendor == 'postgresql', "This is a postgresql-specific issue") - def test_postgresql_text_indexes(self): + +@skipIf(connection.vendor == 'postgresql', 'opclasses are PostgreSQL only') +class SchemaIndexesNotPostgreSQLTests(TransactionTestCase): + available_apps = ['indexes'] + + def test_create_index_ignores_opclasses(self): + index = Index( + name='test_ops_class', + fields=['headline'], + opclasses=['varchar_pattern_ops'], + ) + with connection.schema_editor() as editor: + # This would error if opclasses weren't ingored. + editor.add_index(IndexedArticle2, index) + + +@skipUnless(connection.vendor == 'postgresql', 'PostgreSQL tests') +class SchemaIndexesPostgreSQLTests(TransactionTestCase): + available_apps = ['indexes'] + get_opclass_query = ''' + SELECT opcname, c.relname FROM pg_opclass AS oc + JOIN pg_index as i on oc.oid = ANY(i.indclass) + JOIN pg_class as c on c.oid = i.indexrelid + WHERE c.relname = '%s' + ''' + + def test_text_indexes(self): """Test creation of PostgreSQL-specific text indexes (#12234)""" from .models import IndexedArticle index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(IndexedArticle)] @@ -78,12 +106,39 @@ class SchemaIndexesTests(TestCase): # index (#19441). self.assertIn('("slug" varchar_pattern_ops)', index_sql[4]) - @skipUnless(connection.vendor == 'postgresql', "This is a postgresql-specific issue") - def test_postgresql_virtual_relation_indexes(self): + def test_virtual_relation_indexes(self): """Test indexes are not created for related objects""" index_sql = connection.schema_editor()._model_indexes_sql(Article) self.assertEqual(len(index_sql), 1) + def test_ops_class(self): + index = Index( + name='test_ops_class', + fields=['headline'], + opclasses=['varchar_pattern_ops'], + ) + with connection.schema_editor() as editor: + editor.add_index(IndexedArticle2, index) + with editor.connection.cursor() as cursor: + cursor.execute(self.get_opclass_query % 'test_ops_class') + self.assertEqual(cursor.fetchall(), [('varchar_pattern_ops', 'test_ops_class')]) + + def test_ops_class_multiple_columns(self): + index = Index( + name='test_ops_class_multiple', + fields=['headline', 'body'], + opclasses=['varchar_pattern_ops', 'text_pattern_ops'], + ) + with connection.schema_editor() as editor: + editor.add_index(IndexedArticle2, index) + with editor.connection.cursor() as cursor: + cursor.execute(self.get_opclass_query % 'test_ops_class_multiple') + expected_ops_classes = ( + ('varchar_pattern_ops', 'test_ops_class_multiple'), + ('text_pattern_ops', 'test_ops_class_multiple'), + ) + self.assertCountEqual(cursor.fetchall(), expected_ops_classes) + @skipUnless(connection.vendor == 'mysql', 'MySQL tests') class SchemaIndexesMySQLTests(TransactionTestCase): diff --git a/tests/model_indexes/tests.py b/tests/model_indexes/tests.py index c75c8e8473..36c217982e 100644 --- a/tests/model_indexes/tests.py +++ b/tests/model_indexes/tests.py @@ -39,6 +39,19 @@ class IndexesTests(SimpleTestCase): with self.assertRaisesMessage(ValueError, msg): models.Index() + def test_opclasses_requires_index_name(self): + with self.assertRaisesMessage(ValueError, 'An index must be named to use opclasses.'): + models.Index(opclasses=['jsonb_path_ops']) + + def test_opclasses_requires_list_or_tuple(self): + with self.assertRaisesMessage(ValueError, 'Index.opclasses must be a list or tuple.'): + models.Index(name='test_opclass', fields=['field'], opclasses='jsonb_path_ops') + + def test_opclasses_and_fields_same_length(self): + msg = 'Index.fields and Index.opclasses must have the same number of elements.' + with self.assertRaisesMessage(ValueError, msg): + models.Index(name='test_opclass', fields=['field', 'other'], opclasses=['jsonb_path_ops']) + def test_max_name_length(self): msg = 'Index names cannot be longer than 30 characters.' with self.assertRaisesMessage(ValueError, msg):