Fixed #16737 -- Support non-ascii column names in inspectdb

Thanks moof at metamoof.net for the report.
This commit is contained in:
Claude Paroz 2013-04-01 19:51:53 +02:00
parent 2817a29d90
commit 8c41bd93c2
4 changed files with 19 additions and 5 deletions

View File

@ -2,6 +2,7 @@ import re
from .base import FIELD_TYPE from .base import FIELD_TYPE
from django.db.backends import BaseDatabaseIntrospection, FieldInfo from django.db.backends import BaseDatabaseIntrospection, FieldInfo
from django.utils.encoding import force_text
foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)") foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")
@ -55,7 +56,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
numeric_map = dict([(line[0], tuple([int(n) for n in line[1:]])) for line in cursor.fetchall()]) numeric_map = dict([(line[0], tuple([int(n) for n in line[1:]])) for line in cursor.fetchall()])
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return [FieldInfo(*(line[:3] + (length_map.get(line[0], line[3]),) return [FieldInfo(*((force_text(line[0]),)
+ line[1:3]
+ (length_map.get(line[0], line[3]),)
+ numeric_map.get(line[0], line[4:6]) + numeric_map.get(line[0], line[4:6])
+ (line[6],))) + (line[6],)))
for line in cursor.description] for line in cursor.description]

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db.backends import BaseDatabaseIntrospection, FieldInfo from django.db.backends import BaseDatabaseIntrospection, FieldInfo
from django.utils.encoding import force_text
class DatabaseIntrospection(BaseDatabaseIntrospection): class DatabaseIntrospection(BaseDatabaseIntrospection):
@ -46,7 +47,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
WHERE table_name = %s""", [table_name]) WHERE table_name = %s""", [table_name])
null_map = dict(cursor.fetchall()) null_map = dict(cursor.fetchall())
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return [FieldInfo(*(line[:6] + (null_map[line[0]]=='YES',))) return [FieldInfo(*((force_text(line[0]),) + line[1:6] + (null_map[force_text(line[0])]=='YES',)))
for line in cursor.description] for line in cursor.description]
def get_relations(self, cursor, table_name): def get_relations(self, cursor, table_name):

View File

@ -1,3 +1,6 @@
# -*- encoding: utf-8 -*-
from __future__ import unicode_literals
from django.db import models from django.db import models
@ -29,6 +32,7 @@ class SpecialColumnName(models.Model):
field_field_2 = models.IntegerField(db_column='__field') field_field_2 = models.IntegerField(db_column='__field')
# Other chars # Other chars
prc_x = models.IntegerField(db_column='prc(%) x') prc_x = models.IntegerField(db_column='prc(%) x')
non_ascii = models.IntegerField(db_column='tamaño')
class ColumnTypes(models.Model): class ColumnTypes(models.Model):
id = models.AutoField(primary_key=True) id = models.AutoField(primary_key=True)

View File

@ -1,3 +1,4 @@
# -*- encoding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
import re import re
@ -6,7 +7,7 @@ from django.core.management import call_command
from django.db import connection from django.db import connection
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.utils.unittest import expectedFailure from django.utils.unittest import expectedFailure
from django.utils.six import StringIO from django.utils.six import PY3, StringIO
if connection.vendor == 'oracle': if connection.vendor == 'oracle':
expectedFailureOnOracle = expectedFailure expectedFailureOnOracle = expectedFailure
@ -146,6 +147,11 @@ class InspectDBTestCase(TestCase):
self.assertIn("field_field_0 = models.IntegerField(db_column='%s__')" % base_name, output) self.assertIn("field_field_0 = models.IntegerField(db_column='%s__')" % base_name, output)
self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output) self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output)
self.assertIn("prc_x = models.IntegerField(db_column='prc(%) x')", output) self.assertIn("prc_x = models.IntegerField(db_column='prc(%) x')", output)
if PY3:
# Python 3 allows non-ascii identifiers
self.assertIn("tamaño = models.IntegerField()", output)
else:
self.assertIn("tama_o = models.IntegerField(db_column='tama\\xf1o')", output)
def test_managed_models(self): def test_managed_models(self):
"""Test that by default the command generates models with `Meta.managed = False` (#14305)""" """Test that by default the command generates models with `Meta.managed = False` (#14305)"""