[2.0.x] Fixed #28792 -- Fixed index name truncation of namespaced tables.

Refs #27458, #27843.

Thanks Tim and Mariusz for the review.

Backport of ee85ef8315 from master
This commit is contained in:
Simon Charette 2017-11-11 19:17:20 -05:00 committed by Tim Graham
parent 022aebc550
commit 0696edbc6a
5 changed files with 53 additions and 13 deletions

View File

@ -5,7 +5,7 @@ from datetime import datetime
from django.db.backends.ddl_references import ( from django.db.backends.ddl_references import (
Columns, ForeignKeyName, IndexName, Statement, Table, Columns, ForeignKeyName, IndexName, Statement, Table,
) )
from django.db.backends.utils import strip_quotes from django.db.backends.utils import split_identifier
from django.db.models import Index from django.db.models import Index
from django.db.transaction import TransactionManagementError, atomic from django.db.transaction import TransactionManagementError, atomic
from django.utils import timezone from django.utils import timezone
@ -858,7 +858,7 @@ class BaseDatabaseSchemaEditor:
The name is divided into 3 parts: the table name, the column names, The name is divided into 3 parts: the table name, the column names,
and a unique digest and suffix. and a unique digest and suffix.
""" """
table_name = strip_quotes(table_name) _, table_name = split_identifier(table_name)
hash_data = [table_name] + list(column_names) hash_data = [table_name] + list(column_names)
hash_suffix_part = '%s%s' % (self._digest(*hash_data), suffix) hash_suffix_part = '%s%s' % (self._digest(*hash_data), suffix)
max_length = self.connection.ops.max_name_length() or 200 max_length = self.connection.ops.max_name_length() or 200

View File

@ -3,7 +3,6 @@ import decimal
import functools import functools
import hashlib import hashlib
import logging import logging
import re
from time import time from time import time
from django.conf import settings from django.conf import settings
@ -194,20 +193,35 @@ def rev_typecast_decimal(d):
return str(d) return str(d)
def truncate_name(name, length=None, hash_len=4): def split_identifier(identifier):
""" """
Shorten a string to a repeatable mangled version with the given length. Split a SQL identifier into a two element tuple of (namespace, name).
If a quote stripped name contains a username, e.g. USERNAME"."TABLE,
The identifier could be a table, column, or sequence name might be prefixed
by a namespace.
"""
try:
namespace, name = identifier.split('"."')
except ValueError:
namespace, name = '', identifier
return namespace.strip('"'), name.strip('"')
def truncate_name(identifier, length=None, hash_len=4):
"""
Shorten a SQL identifier to a repeatable mangled version with the given
length.
If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
truncate the table portion only. truncate the table portion only.
""" """
match = re.match(r'([^"]+)"\."([^"]+)', name) namespace, name = split_identifier(identifier)
table_name = match.group(2) if match else name
if length is None or len(table_name) <= length: if length is None or len(name) <= length:
return name return identifier
hsh = hashlib.md5(force_bytes(table_name)).hexdigest()[:hash_len] digest = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len]
return '%s%s%s' % (match.group(1) + '"."' if match else '', table_name[:length - hash_len], hsh) return '%s%s%s' % ('%s"."' % namespace if namespace else '', name[:length - hash_len], digest)
def format_number(value, max_digits, decimal_places): def format_number(value, max_digits, decimal_places):

View File

@ -15,3 +15,6 @@ Bugfixes
* Added support for ``QuerySet.values()`` and ``values_list()`` for * Added support for ``QuerySet.values()`` and ``values_list()`` for
``union()``, ``difference()``, and ``intersection()`` queries ``union()``, ``difference()``, and ``intersection()`` queries
(:ticket:`28781`). (:ticket:`28781`).
* Fixed incorrect index name truncation when using a namespaced ``db_table``
(:ticket:`28792`).

View File

@ -2,7 +2,9 @@
from decimal import Decimal, Rounded from decimal import Decimal, Rounded
from django.db import connection from django.db import connection
from django.db.backends.utils import format_number, truncate_name from django.db.backends.utils import (
format_number, split_identifier, truncate_name,
)
from django.db.utils import NotSupportedError from django.db.utils import NotSupportedError
from django.test import ( from django.test import (
SimpleTestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, SimpleTestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature,
@ -21,6 +23,12 @@ class TestUtils(SimpleTestCase):
self.assertEqual(truncate_name('username"."some_long_table', 10), 'username"."some_la38a') self.assertEqual(truncate_name('username"."some_long_table', 10), 'username"."some_la38a')
self.assertEqual(truncate_name('username"."some_long_table', 10, 3), 'username"."some_loa38') self.assertEqual(truncate_name('username"."some_long_table', 10, 3), 'username"."some_loa38')
def test_split_identifier(self):
self.assertEqual(split_identifier('some_table'), ('', 'some_table'))
self.assertEqual(split_identifier('"some_table"'), ('', 'some_table'))
self.assertEqual(split_identifier('namespace"."some_table'), ('namespace', 'some_table'))
self.assertEqual(split_identifier('"namespace"."some_table"'), ('namespace', 'some_table'))
def test_format_number(self): def test_format_number(self):
def equal(value, max_d, places, result): def equal(value, max_d, places, result):
self.assertEqual(format_number(Decimal(value), max_d, places), result) self.assertEqual(format_number(Decimal(value), max_d, places), result)

View File

@ -2372,6 +2372,21 @@ class SchemaTests(TransactionTestCase):
cast_function=lambda x: x.time(), cast_function=lambda x: x.time(),
) )
def test_namespaced_db_table_create_index_name(self):
"""
Table names are stripped of their namespace/schema before being used to
generate index names.
"""
with connection.schema_editor() as editor:
max_name_length = connection.ops.max_name_length() or 200
namespace = 'n' * max_name_length
table_name = 't' * max_name_length
namespaced_table_name = '"%s"."%s"' % (namespace, table_name)
self.assertEqual(
editor._create_index_name(table_name, []),
editor._create_index_name(namespaced_table_name, []),
)
@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle specific db_table syntax') @unittest.skipUnless(connection.vendor == 'oracle', 'Oracle specific db_table syntax')
def test_creation_with_db_table_double_quotes(self): def test_creation_with_db_table_double_quotes(self):
oracle_user = connection.creation._test_database_user() oracle_user = connection.creation._test_database_user()