Fixed #12460 -- Improved inspectdb handling of special field names

Thanks mihail lukin for the report and elijahr and kgibula for their
work on the patch.
This commit is contained in:
Claude Paroz 2012-08-23 21:07:56 +02:00
parent 10d32072af
commit 395c6083af
3 changed files with 109 additions and 51 deletions

View File

@ -1,3 +1,5 @@
from __future__ import unicode_literals
import keyword import keyword
from optparse import make_option from optparse import make_option
@ -31,6 +33,7 @@ class Command(NoArgsCommand):
table_name_filter = options.get('table_name_filter') table_name_filter = options.get('table_name_filter')
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '') table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
strip_prefix = lambda s: s.startswith("u'") and s[1:] or s
cursor = connection.cursor() cursor = connection.cursor()
yield "# This is an auto-generated Django model module." yield "# This is an auto-generated Django model module."
@ -41,6 +44,7 @@ class Command(NoArgsCommand):
yield "#" yield "#"
yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'" yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'"
yield "# into your database." yield "# into your database."
yield "from __future__ import unicode_literals"
yield '' yield ''
yield 'from %s import models' % self.db_module yield 'from %s import models' % self.db_module
yield '' yield ''
@ -59,16 +63,19 @@ class Command(NoArgsCommand):
indexes = connection.introspection.get_indexes(cursor, table_name) indexes = connection.introspection.get_indexes(cursor, table_name)
except NotImplementedError: except NotImplementedError:
indexes = {} indexes = {}
used_column_names = [] # Holds column names used in the table so far
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)): for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
column_name = row[0]
att_name = column_name.lower()
comment_notes = [] # Holds Field notes, to be displayed in a Python comment. comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
extra_params = {} # Holds Field parameters such as 'db_column'. extra_params = {} # Holds Field parameters such as 'db_column'.
column_name = row[0]
is_relation = i in relations
# If the column name can't be used verbatim as a Python att_name, params, notes = self.normalize_col_name(
# attribute, set the "db_column" for this Field. column_name, used_column_names, is_relation)
if ' ' in att_name or '-' in att_name or keyword.iskeyword(att_name) or column_name != att_name: extra_params.update(params)
extra_params['db_column'] = column_name comment_notes.extend(notes)
used_column_names.append(att_name)
# Add primary_key and unique, if necessary. # Add primary_key and unique, if necessary.
if column_name in indexes: if column_name in indexes:
@ -77,30 +84,12 @@ class Command(NoArgsCommand):
elif indexes[column_name]['unique']: elif indexes[column_name]['unique']:
extra_params['unique'] = True extra_params['unique'] = True
# Modify the field name to make it Python-compatible. if is_relation:
if ' ' in att_name:
att_name = att_name.replace(' ', '_')
comment_notes.append('Field renamed to remove spaces.')
if '-' in att_name:
att_name = att_name.replace('-', '_')
comment_notes.append('Field renamed to remove dashes.')
if column_name != att_name:
comment_notes.append('Field name made lowercase.')
if i in relations:
rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1]) rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1])
if rel_to in known_models: if rel_to in known_models:
field_type = 'ForeignKey(%s' % rel_to field_type = 'ForeignKey(%s' % rel_to
else: else:
field_type = "ForeignKey('%s'" % rel_to field_type = "ForeignKey('%s'" % rel_to
if att_name.endswith('_id'):
att_name = att_name[:-3]
else:
extra_params['db_column'] = column_name
else: else:
# Calling `get_field_type` to get the field type string and any # Calling `get_field_type` to get the field type string and any
# additional parameters and notes. # additional parameters and notes.
@ -110,16 +99,6 @@ class Command(NoArgsCommand):
field_type += '(' field_type += '('
if keyword.iskeyword(att_name):
att_name += '_field'
comment_notes.append('Field renamed because it was a Python reserved word.')
if att_name[0].isdigit():
att_name = 'number_%s' % att_name
extra_params['db_column'] = six.text_type(column_name)
comment_notes.append("Field renamed because it wasn't a "
"valid Python identifier.")
# Don't output 'id = meta.AutoField(primary_key=True)', because # Don't output 'id = meta.AutoField(primary_key=True)', because
# that's assumed if it doesn't exist. # that's assumed if it doesn't exist.
if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}: if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}:
@ -136,7 +115,9 @@ class Command(NoArgsCommand):
if extra_params: if extra_params:
if not field_desc.endswith('('): if not field_desc.endswith('('):
field_desc += ', ' field_desc += ', '
field_desc += ', '.join(['%s=%r' % (k, v) for k, v in extra_params.items()]) field_desc += ', '.join([
'%s=%s' % (k, strip_prefix(repr(v)))
for k, v in extra_params.items()])
field_desc += ')' field_desc += ')'
if comment_notes: if comment_notes:
field_desc += ' # ' + ' '.join(comment_notes) field_desc += ' # ' + ' '.join(comment_notes)
@ -144,6 +125,64 @@ class Command(NoArgsCommand):
for meta_line in self.get_meta(table_name): for meta_line in self.get_meta(table_name):
yield meta_line yield meta_line
def normalize_col_name(self, col_name, used_column_names, is_relation):
    """
    Modify the column name to make it Python-compatible as a field name.

    ``col_name`` is the column name as reported by the database,
    ``used_column_names`` is a collection of field names already taken in
    the current table (only read here, never modified) and ``is_relation``
    tells whether the column is a foreign key.

    Return a ``(new_name, field_params, field_notes)`` tuple where
    ``new_name`` is the attribute name to use, ``field_params`` is a dict
    of extra field keyword arguments (at most ``db_column``) and
    ``field_notes`` is a list of human-readable remarks explaining every
    renaming step that was applied.
    """
    field_params = {}
    field_notes = []

    new_name = col_name.lower()
    if new_name != col_name:
        field_notes.append('Field name made lowercase.')

    if is_relation:
        if new_name.endswith('_id'):
            # Relation columns lose their '_id' suffix silently (no note).
            new_name = new_name[:-3]
        else:
            field_params['db_column'] = col_name

    if ' ' in new_name:
        new_name = new_name.replace(' ', '_')
        field_notes.append('Field renamed to remove spaces.')

    if '-' in new_name:
        new_name = new_name.replace('-', '_')
        field_notes.append('Field renamed to remove dashes.')

    if '__' in new_name:
        # A single replace() pass can leave '__' behind (e.g. '____'),
        # so keep collapsing until no double underscore remains.
        while '__' in new_name:
            new_name = new_name.replace('__', '_')
        field_notes.append("Field renamed because it contained more than one '_' in a row.")

    if new_name.startswith('_'):
        new_name = 'field%s' % new_name
        field_notes.append("Field renamed because it started with '_'.")

    if new_name.endswith('_'):
        new_name = '%sfield' % new_name
        field_notes.append("Field renamed because it ended with '_'.")

    if not new_name:
        # Happens for an empty column name or for a relation column named
        # exactly '_id'. Without a placeholder the digit check below would
        # raise IndexError and no usable attribute name could be produced.
        new_name = 'field'
        field_notes.append("Field renamed because its name would have been empty.")

    if keyword.iskeyword(new_name):
        new_name += '_field'
        field_notes.append('Field renamed because it was a Python reserved word.')

    if new_name[0].isdigit():
        new_name = 'number_%s' % new_name
        field_notes.append("Field renamed because it wasn't a valid Python identifier.")

    if new_name in used_column_names:
        # Append the first numeric suffix that is still free.
        num = 0
        while '%s_%d' % (new_name, num) in used_column_names:
            num += 1
        new_name = '%s_%d' % (new_name, num)
        field_notes.append('Field renamed because of name conflict.')

    # Only pin db_column when a noted renaming happened; a bare '_id' strip
    # on a relation produces no note and therefore no db_column.
    if col_name != new_name and field_notes:
        field_params['db_column'] = col_name

    return new_name, field_params, field_notes
def get_field_type(self, connection, table_name, row): def get_field_type(self, connection, table_name, row):
""" """
Given the database connection, the table name, and the cursor row Given the database connection, the table name, and the cursor row
@ -181,6 +220,6 @@ class Command(NoArgsCommand):
to construct the inner Meta class for the model corresponding to construct the inner Meta class for the model corresponding
to the given database table name. to the given database table name.
""" """
return [' class Meta:', return [" class Meta:",
' db_table = %r' % table_name, " db_table = '%s'" % table_name,
''] ""]

View File

@ -19,3 +19,9 @@ class DigitsInColumnName(models.Model):
all_digits = models.CharField(max_length=11, db_column='123') all_digits = models.CharField(max_length=11, db_column='123')
leading_digit = models.CharField(max_length=11, db_column='4extra') leading_digit = models.CharField(max_length=11, db_column='4extra')
leading_digits = models.CharField(max_length=11, db_column='45extra') leading_digits = models.CharField(max_length=11, db_column='45extra')
class UnderscoresInColumnName(models.Model):
    # Fixture model whose database column names differ only by letter case
    # and by trailing, doubled or leading underscores. The db_column values
    # below are the ones test_underscores_column_name_introspection looks
    # for in the inspectdb output; the attribute names on this model are
    # just unique placeholders.
    field = models.IntegerField(db_column='field')
    # Upper-case letter plus a trailing underscore.
    field_field_0 = models.IntegerField(db_column='Field_')
    # Upper-case letter plus a trailing double underscore.
    field_field_1 = models.IntegerField(db_column='Field__')
    # Leading double underscore.
    field_field_2 = models.IntegerField(db_column='__field')

View File

@ -1,3 +1,5 @@
from __future__ import unicode_literals
from django.core.management import call_command from django.core.management import call_command
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.utils.six import StringIO from django.utils.six import StringIO
@ -17,7 +19,6 @@ class InspectDBTestCase(TestCase):
# the Django test suite, check that one of its tables hasn't been # the Django test suite, check that one of its tables hasn't been
# inspected # inspected
self.assertNotIn("class DjangoContentType(models.Model):", out.getvalue(), msg=error_message) self.assertNotIn("class DjangoContentType(models.Model):", out.getvalue(), msg=error_message)
out.close()
@skipUnlessDBFeature('can_introspect_foreign_keys') @skipUnlessDBFeature('can_introspect_foreign_keys')
def test_attribute_name_not_python_keyword(self): def test_attribute_name_not_python_keyword(self):
@ -27,15 +28,16 @@ class InspectDBTestCase(TestCase):
call_command('inspectdb', call_command('inspectdb',
table_name_filter=lambda tn:tn.startswith('inspectdb_'), table_name_filter=lambda tn:tn.startswith('inspectdb_'),
stdout=out) stdout=out)
output = out.getvalue()
error_message = "inspectdb generated an attribute name which is a python keyword" error_message = "inspectdb generated an attribute name which is a python keyword"
self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", out.getvalue(), msg=error_message) self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", output, msg=error_message)
# As InspectdbPeople model is defined after InspectdbMessage, it should be quoted # As InspectdbPeople model is defined after InspectdbMessage, it should be quoted
self.assertIn("from_field = models.ForeignKey('InspectdbPeople')", out.getvalue()) self.assertIn("from_field = models.ForeignKey('InspectdbPeople', db_column='from_id')",
output)
self.assertIn("people_pk = models.ForeignKey(InspectdbPeople, primary_key=True)", self.assertIn("people_pk = models.ForeignKey(InspectdbPeople, primary_key=True)",
out.getvalue()) output)
self.assertIn("people_unique = models.ForeignKey(InspectdbPeople, unique=True)", self.assertIn("people_unique = models.ForeignKey(InspectdbPeople, unique=True)",
out.getvalue()) output)
out.close()
def test_digits_column_name_introspection(self): def test_digits_column_name_introspection(self):
"""Introspection of column names consist/start with digits (#16536/#17676)""" """Introspection of column names consist/start with digits (#16536/#17676)"""
@ -45,13 +47,24 @@ class InspectDBTestCase(TestCase):
call_command('inspectdb', call_command('inspectdb',
table_name_filter=lambda tn:tn.startswith('inspectdb_'), table_name_filter=lambda tn:tn.startswith('inspectdb_'),
stdout=out) stdout=out)
output = out.getvalue()
error_message = "inspectdb generated a model field name which is a number" error_message = "inspectdb generated a model field name which is a number"
self.assertNotIn(" 123 = models.CharField", out.getvalue(), msg=error_message) self.assertNotIn(" 123 = models.CharField", output, msg=error_message)
self.assertIn("number_123 = models.CharField", out.getvalue()) self.assertIn("number_123 = models.CharField", output)
error_message = "inspectdb generated a model field name which starts with a digit" error_message = "inspectdb generated a model field name which starts with a digit"
self.assertNotIn(" 4extra = models.CharField", out.getvalue(), msg=error_message) self.assertNotIn(" 4extra = models.CharField", output, msg=error_message)
self.assertIn("number_4extra = models.CharField", out.getvalue()) self.assertIn("number_4extra = models.CharField", output)
self.assertNotIn(" 45extra = models.CharField", out.getvalue(), msg=error_message) self.assertNotIn(" 45extra = models.CharField", output, msg=error_message)
self.assertIn("number_45extra = models.CharField", out.getvalue()) self.assertIn("number_45extra = models.CharField", output)
def test_underscores_column_name_introspection(self):
    """Column names made of or edged with underscores are introspected sanely (#12460)"""
    stream = StringIO()
    call_command('inspectdb', stdout=stream)
    generated = stream.getvalue()
    # Each entry is a field definition that must appear in the generated
    # module; the numeric suffixes come from name-conflict resolution.
    expected_definitions = (
        "field = models.IntegerField()",
        "field_field = models.IntegerField(db_column='Field_')",
        "field_field_0 = models.IntegerField(db_column='Field__')",
        "field_field_1 = models.IntegerField(db_column='__field')",
    )
    for definition in expected_definitions:
        self.assertIn(definition, generated)