diff --git a/django/db/backends/base/schema.py b/django/db/backends/base/schema.py index f3aade3916..35b82dc1e5 100644 --- a/django/db/backends/base/schema.py +++ b/django/db/backends/base/schema.py @@ -5,7 +5,7 @@ from datetime import datetime from django.db.backends.ddl_references import ( Columns, ForeignKeyName, IndexName, Statement, Table, ) -from django.db.backends.utils import strip_quotes +from django.db.backends.utils import split_identifier from django.db.models import Index from django.db.transaction import TransactionManagementError, atomic from django.utils import timezone @@ -858,7 +858,7 @@ class BaseDatabaseSchemaEditor: The name is divided into 3 parts: the table name, the column names, and a unique digest and suffix. """ - table_name = strip_quotes(table_name) + _, table_name = split_identifier(table_name) hash_data = [table_name] + list(column_names) hash_suffix_part = '%s%s' % (self._digest(*hash_data), suffix) max_length = self.connection.ops.max_name_length() or 200 diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index 816164d36a..f4641e3db8 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -3,7 +3,6 @@ import decimal import functools import hashlib import logging -import re from time import time from django.conf import settings @@ -194,20 +193,35 @@ def rev_typecast_decimal(d): return str(d) -def truncate_name(name, length=None, hash_len=4): +def split_identifier(identifier): """ - Shorten a string to a repeatable mangled version with the given length. - If a quote stripped name contains a username, e.g. USERNAME"."TABLE, + Split a SQL identifier into a two element tuple of (namespace, name). + + The identifier could be a table, column, or sequence name might be prefixed + by a namespace. + """ + try: + namespace, name = identifier.split('"."') + except ValueError: + namespace, name = '', identifier + return namespace.strip('"'), name.strip('"') + + +def truncate_name(identifier, length=None, hash_len=4): + """ + Shorten a SQL identifier to a repeatable mangled version with the given + length. + + If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE, truncate the table portion only. """ - match = re.match(r'([^"]+)"\."([^"]+)', name) - table_name = match.group(2) if match else name + namespace, name = split_identifier(identifier) - if length is None or len(table_name) <= length: - return name + if length is None or len(name) <= length: + return identifier - hsh = hashlib.md5(force_bytes(table_name)).hexdigest()[:hash_len] - return '%s%s%s' % (match.group(1) + '"."' if match else '', table_name[:length - hash_len], hsh) + digest = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len] + return '%s%s%s' % ('%s"."' % namespace if namespace else '', name[:length - hash_len], digest) def format_number(value, max_digits, decimal_places): diff --git a/docs/releases/1.11.8.txt b/docs/releases/1.11.8.txt index 7e4963f713..426b6d92b2 100644 --- a/docs/releases/1.11.8.txt +++ b/docs/releases/1.11.8.txt @@ -15,3 +15,6 @@ Bugfixes * Added support for ``QuerySet.values()`` and ``values_list()`` for ``union()``, ``difference()``, and ``intersection()`` queries (:ticket:`28781`). + +* Fixed incorrect index name truncation when using a namespaced ``db_table`` + (:ticket:`28792`). diff --git a/tests/backends/test_utils.py b/tests/backends/test_utils.py index be9aeaf698..cd4911fd1a 100644 --- a/tests/backends/test_utils.py +++ b/tests/backends/test_utils.py @@ -2,7 +2,9 @@ from decimal import Decimal, Rounded from django.db import connection -from django.db.backends.utils import format_number, truncate_name +from django.db.backends.utils import ( + format_number, split_identifier, truncate_name, +) from django.db.utils import NotSupportedError from django.test import ( SimpleTestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, @@ -21,6 +23,12 @@ class TestUtils(SimpleTestCase): self.assertEqual(truncate_name('username"."some_long_table', 10), 'username"."some_la38a') self.assertEqual(truncate_name('username"."some_long_table', 10, 3), 'username"."some_loa38') + def test_split_identifier(self): + self.assertEqual(split_identifier('some_table'), ('', 'some_table')) + self.assertEqual(split_identifier('"some_table"'), ('', 'some_table')) + self.assertEqual(split_identifier('namespace"."some_table'), ('namespace', 'some_table')) + self.assertEqual(split_identifier('"namespace"."some_table"'), ('namespace', 'some_table')) + def test_format_number(self): def equal(value, max_d, places, result): self.assertEqual(format_number(Decimal(value), max_d, places), result) diff --git a/tests/schema/tests.py b/tests/schema/tests.py index f2a96294cc..66a54b1ce4 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -2370,6 +2370,21 @@ class SchemaTests(TransactionTestCase): cast_function=lambda x: x.time(), ) + def test_namespaced_db_table_create_index_name(self): + """ + Table names are stripped of their namespace/schema before being used to + generate index names. + """ + with connection.schema_editor() as editor: + max_name_length = connection.ops.max_name_length() or 200 + namespace = 'n' * max_name_length + table_name = 't' * max_name_length + namespaced_table_name = '"%s"."%s"' % (namespace, table_name) + self.assertEqual( + editor._create_index_name(table_name, []), + editor._create_index_name(namespaced_table_name, []), + ) + @unittest.skipUnless(connection.vendor == 'oracle', 'Oracle specific db_table syntax') def test_creation_with_db_table_double_quotes(self): oracle_user = connection.creation._test_database_user()