mirror of https://github.com/django/django.git
Fixed #12460 -- Improved inspectdb handling of special field names
Thanks mihail lukin for the report and elijahr and kgibula for their work on the patch.
This commit is contained in:
parent
10d32072af
commit
395c6083af
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import keyword
|
import keyword
|
||||||
from optparse import make_option
|
from optparse import make_option
|
||||||
|
|
||||||
|
@ -31,6 +33,7 @@ class Command(NoArgsCommand):
|
||||||
table_name_filter = options.get('table_name_filter')
|
table_name_filter = options.get('table_name_filter')
|
||||||
|
|
||||||
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
|
table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
|
||||||
|
strip_prefix = lambda s: s.startswith("u'") and s[1:] or s
|
||||||
|
|
||||||
cursor = connection.cursor()
|
cursor = connection.cursor()
|
||||||
yield "# This is an auto-generated Django model module."
|
yield "# This is an auto-generated Django model module."
|
||||||
|
@ -41,6 +44,7 @@ class Command(NoArgsCommand):
|
||||||
yield "#"
|
yield "#"
|
||||||
yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'"
|
yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'"
|
||||||
yield "# into your database."
|
yield "# into your database."
|
||||||
|
yield "from __future__ import unicode_literals"
|
||||||
yield ''
|
yield ''
|
||||||
yield 'from %s import models' % self.db_module
|
yield 'from %s import models' % self.db_module
|
||||||
yield ''
|
yield ''
|
||||||
|
@ -59,16 +63,19 @@ class Command(NoArgsCommand):
|
||||||
indexes = connection.introspection.get_indexes(cursor, table_name)
|
indexes = connection.introspection.get_indexes(cursor, table_name)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
indexes = {}
|
indexes = {}
|
||||||
|
used_column_names = [] # Holds column names used in the table so far
|
||||||
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
|
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
|
||||||
column_name = row[0]
|
|
||||||
att_name = column_name.lower()
|
|
||||||
comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
|
comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
|
||||||
extra_params = {} # Holds Field parameters such as 'db_column'.
|
extra_params = {} # Holds Field parameters such as 'db_column'.
|
||||||
|
column_name = row[0]
|
||||||
|
is_relation = i in relations
|
||||||
|
|
||||||
# If the column name can't be used verbatim as a Python
|
att_name, params, notes = self.normalize_col_name(
|
||||||
# attribute, set the "db_column" for this Field.
|
column_name, used_column_names, is_relation)
|
||||||
if ' ' in att_name or '-' in att_name or keyword.iskeyword(att_name) or column_name != att_name:
|
extra_params.update(params)
|
||||||
extra_params['db_column'] = column_name
|
comment_notes.extend(notes)
|
||||||
|
|
||||||
|
used_column_names.append(att_name)
|
||||||
|
|
||||||
# Add primary_key and unique, if necessary.
|
# Add primary_key and unique, if necessary.
|
||||||
if column_name in indexes:
|
if column_name in indexes:
|
||||||
|
@ -77,30 +84,12 @@ class Command(NoArgsCommand):
|
||||||
elif indexes[column_name]['unique']:
|
elif indexes[column_name]['unique']:
|
||||||
extra_params['unique'] = True
|
extra_params['unique'] = True
|
||||||
|
|
||||||
# Modify the field name to make it Python-compatible.
|
if is_relation:
|
||||||
if ' ' in att_name:
|
|
||||||
att_name = att_name.replace(' ', '_')
|
|
||||||
comment_notes.append('Field renamed to remove spaces.')
|
|
||||||
|
|
||||||
if '-' in att_name:
|
|
||||||
att_name = att_name.replace('-', '_')
|
|
||||||
comment_notes.append('Field renamed to remove dashes.')
|
|
||||||
|
|
||||||
if column_name != att_name:
|
|
||||||
comment_notes.append('Field name made lowercase.')
|
|
||||||
|
|
||||||
if i in relations:
|
|
||||||
rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1])
|
rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1])
|
||||||
|
|
||||||
if rel_to in known_models:
|
if rel_to in known_models:
|
||||||
field_type = 'ForeignKey(%s' % rel_to
|
field_type = 'ForeignKey(%s' % rel_to
|
||||||
else:
|
else:
|
||||||
field_type = "ForeignKey('%s'" % rel_to
|
field_type = "ForeignKey('%s'" % rel_to
|
||||||
|
|
||||||
if att_name.endswith('_id'):
|
|
||||||
att_name = att_name[:-3]
|
|
||||||
else:
|
|
||||||
extra_params['db_column'] = column_name
|
|
||||||
else:
|
else:
|
||||||
# Calling `get_field_type` to get the field type string and any
|
# Calling `get_field_type` to get the field type string and any
|
||||||
# additional paramters and notes.
|
# additional paramters and notes.
|
||||||
|
@ -110,16 +99,6 @@ class Command(NoArgsCommand):
|
||||||
|
|
||||||
field_type += '('
|
field_type += '('
|
||||||
|
|
||||||
if keyword.iskeyword(att_name):
|
|
||||||
att_name += '_field'
|
|
||||||
comment_notes.append('Field renamed because it was a Python reserved word.')
|
|
||||||
|
|
||||||
if att_name[0].isdigit():
|
|
||||||
att_name = 'number_%s' % att_name
|
|
||||||
extra_params['db_column'] = six.text_type(column_name)
|
|
||||||
comment_notes.append("Field renamed because it wasn't a "
|
|
||||||
"valid Python identifier.")
|
|
||||||
|
|
||||||
# Don't output 'id = meta.AutoField(primary_key=True)', because
|
# Don't output 'id = meta.AutoField(primary_key=True)', because
|
||||||
# that's assumed if it doesn't exist.
|
# that's assumed if it doesn't exist.
|
||||||
if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}:
|
if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}:
|
||||||
|
@ -136,7 +115,9 @@ class Command(NoArgsCommand):
|
||||||
if extra_params:
|
if extra_params:
|
||||||
if not field_desc.endswith('('):
|
if not field_desc.endswith('('):
|
||||||
field_desc += ', '
|
field_desc += ', '
|
||||||
field_desc += ', '.join(['%s=%r' % (k, v) for k, v in extra_params.items()])
|
field_desc += ', '.join([
|
||||||
|
'%s=%s' % (k, strip_prefix(repr(v)))
|
||||||
|
for k, v in extra_params.items()])
|
||||||
field_desc += ')'
|
field_desc += ')'
|
||||||
if comment_notes:
|
if comment_notes:
|
||||||
field_desc += ' # ' + ' '.join(comment_notes)
|
field_desc += ' # ' + ' '.join(comment_notes)
|
||||||
|
@ -144,6 +125,64 @@ class Command(NoArgsCommand):
|
||||||
for meta_line in self.get_meta(table_name):
|
for meta_line in self.get_meta(table_name):
|
||||||
yield meta_line
|
yield meta_line
|
||||||
|
|
||||||
|
def normalize_col_name(self, col_name, used_column_names, is_relation):
|
||||||
|
"""
|
||||||
|
Modify the column name to make it Python-compatible as a field name
|
||||||
|
"""
|
||||||
|
field_params = {}
|
||||||
|
field_notes = []
|
||||||
|
|
||||||
|
new_name = col_name.lower()
|
||||||
|
if new_name != col_name:
|
||||||
|
field_notes.append('Field name made lowercase.')
|
||||||
|
|
||||||
|
if is_relation:
|
||||||
|
if new_name.endswith('_id'):
|
||||||
|
new_name = new_name[:-3]
|
||||||
|
else:
|
||||||
|
field_params['db_column'] = col_name
|
||||||
|
|
||||||
|
if ' ' in new_name:
|
||||||
|
new_name = new_name.replace(' ', '_')
|
||||||
|
field_notes.append('Field renamed to remove spaces.')
|
||||||
|
|
||||||
|
if '-' in new_name:
|
||||||
|
new_name = new_name.replace('-', '_')
|
||||||
|
field_notes.append('Field renamed to remove dashes.')
|
||||||
|
|
||||||
|
if new_name.find('__') >= 0:
|
||||||
|
while new_name.find('__') >= 0:
|
||||||
|
new_name = new_name.replace('__', '_')
|
||||||
|
field_notes.append("Field renamed because it contained more than one '_' in a row.")
|
||||||
|
|
||||||
|
if new_name.startswith('_'):
|
||||||
|
new_name = 'field%s' % new_name
|
||||||
|
field_notes.append("Field renamed because it started with '_'.")
|
||||||
|
|
||||||
|
if new_name.endswith('_'):
|
||||||
|
new_name = '%sfield' % new_name
|
||||||
|
field_notes.append("Field renamed because it ended with '_'.")
|
||||||
|
|
||||||
|
if keyword.iskeyword(new_name):
|
||||||
|
new_name += '_field'
|
||||||
|
field_notes.append('Field renamed because it was a Python reserved word.')
|
||||||
|
|
||||||
|
if new_name[0].isdigit():
|
||||||
|
new_name = 'number_%s' % new_name
|
||||||
|
field_notes.append("Field renamed because it wasn't a valid Python identifier.")
|
||||||
|
|
||||||
|
if new_name in used_column_names:
|
||||||
|
num = 0
|
||||||
|
while '%s_%d' % (new_name, num) in used_column_names:
|
||||||
|
num += 1
|
||||||
|
new_name = '%s_%d' % (new_name, num)
|
||||||
|
field_notes.append('Field renamed because of name conflict.')
|
||||||
|
|
||||||
|
if col_name != new_name and field_notes:
|
||||||
|
field_params['db_column'] = col_name
|
||||||
|
|
||||||
|
return new_name, field_params, field_notes
|
||||||
|
|
||||||
def get_field_type(self, connection, table_name, row):
|
def get_field_type(self, connection, table_name, row):
|
||||||
"""
|
"""
|
||||||
Given the database connection, the table name, and the cursor row
|
Given the database connection, the table name, and the cursor row
|
||||||
|
@ -181,6 +220,6 @@ class Command(NoArgsCommand):
|
||||||
to construct the inner Meta class for the model corresponding
|
to construct the inner Meta class for the model corresponding
|
||||||
to the given database table name.
|
to the given database table name.
|
||||||
"""
|
"""
|
||||||
return [' class Meta:',
|
return [" class Meta:",
|
||||||
' db_table = %r' % table_name,
|
" db_table = '%s'" % table_name,
|
||||||
'']
|
""]
|
||||||
|
|
|
@ -19,3 +19,9 @@ class DigitsInColumnName(models.Model):
|
||||||
all_digits = models.CharField(max_length=11, db_column='123')
|
all_digits = models.CharField(max_length=11, db_column='123')
|
||||||
leading_digit = models.CharField(max_length=11, db_column='4extra')
|
leading_digit = models.CharField(max_length=11, db_column='4extra')
|
||||||
leading_digits = models.CharField(max_length=11, db_column='45extra')
|
leading_digits = models.CharField(max_length=11, db_column='45extra')
|
||||||
|
|
||||||
|
class UnderscoresInColumnName(models.Model):
|
||||||
|
field = models.IntegerField(db_column='field')
|
||||||
|
field_field_0 = models.IntegerField(db_column='Field_')
|
||||||
|
field_field_1 = models.IntegerField(db_column='Field__')
|
||||||
|
field_field_2 = models.IntegerField(db_column='__field')
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
from django.test import TestCase, skipUnlessDBFeature
|
from django.test import TestCase, skipUnlessDBFeature
|
||||||
from django.utils.six import StringIO
|
from django.utils.six import StringIO
|
||||||
|
@ -17,7 +19,6 @@ class InspectDBTestCase(TestCase):
|
||||||
# the Django test suite, check that one of its tables hasn't been
|
# the Django test suite, check that one of its tables hasn't been
|
||||||
# inspected
|
# inspected
|
||||||
self.assertNotIn("class DjangoContentType(models.Model):", out.getvalue(), msg=error_message)
|
self.assertNotIn("class DjangoContentType(models.Model):", out.getvalue(), msg=error_message)
|
||||||
out.close()
|
|
||||||
|
|
||||||
@skipUnlessDBFeature('can_introspect_foreign_keys')
|
@skipUnlessDBFeature('can_introspect_foreign_keys')
|
||||||
def test_attribute_name_not_python_keyword(self):
|
def test_attribute_name_not_python_keyword(self):
|
||||||
|
@ -27,15 +28,16 @@ class InspectDBTestCase(TestCase):
|
||||||
call_command('inspectdb',
|
call_command('inspectdb',
|
||||||
table_name_filter=lambda tn:tn.startswith('inspectdb_'),
|
table_name_filter=lambda tn:tn.startswith('inspectdb_'),
|
||||||
stdout=out)
|
stdout=out)
|
||||||
|
output = out.getvalue()
|
||||||
error_message = "inspectdb generated an attribute name which is a python keyword"
|
error_message = "inspectdb generated an attribute name which is a python keyword"
|
||||||
self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", out.getvalue(), msg=error_message)
|
self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", output, msg=error_message)
|
||||||
# As InspectdbPeople model is defined after InspectdbMessage, it should be quoted
|
# As InspectdbPeople model is defined after InspectdbMessage, it should be quoted
|
||||||
self.assertIn("from_field = models.ForeignKey('InspectdbPeople')", out.getvalue())
|
self.assertIn("from_field = models.ForeignKey('InspectdbPeople', db_column='from_id')",
|
||||||
|
output)
|
||||||
self.assertIn("people_pk = models.ForeignKey(InspectdbPeople, primary_key=True)",
|
self.assertIn("people_pk = models.ForeignKey(InspectdbPeople, primary_key=True)",
|
||||||
out.getvalue())
|
output)
|
||||||
self.assertIn("people_unique = models.ForeignKey(InspectdbPeople, unique=True)",
|
self.assertIn("people_unique = models.ForeignKey(InspectdbPeople, unique=True)",
|
||||||
out.getvalue())
|
output)
|
||||||
out.close()
|
|
||||||
|
|
||||||
def test_digits_column_name_introspection(self):
|
def test_digits_column_name_introspection(self):
|
||||||
"""Introspection of column names consist/start with digits (#16536/#17676)"""
|
"""Introspection of column names consist/start with digits (#16536/#17676)"""
|
||||||
|
@ -45,13 +47,24 @@ class InspectDBTestCase(TestCase):
|
||||||
call_command('inspectdb',
|
call_command('inspectdb',
|
||||||
table_name_filter=lambda tn:tn.startswith('inspectdb_'),
|
table_name_filter=lambda tn:tn.startswith('inspectdb_'),
|
||||||
stdout=out)
|
stdout=out)
|
||||||
|
output = out.getvalue()
|
||||||
error_message = "inspectdb generated a model field name which is a number"
|
error_message = "inspectdb generated a model field name which is a number"
|
||||||
self.assertNotIn(" 123 = models.CharField", out.getvalue(), msg=error_message)
|
self.assertNotIn(" 123 = models.CharField", output, msg=error_message)
|
||||||
self.assertIn("number_123 = models.CharField", out.getvalue())
|
self.assertIn("number_123 = models.CharField", output)
|
||||||
|
|
||||||
error_message = "inspectdb generated a model field name which starts with a digit"
|
error_message = "inspectdb generated a model field name which starts with a digit"
|
||||||
self.assertNotIn(" 4extra = models.CharField", out.getvalue(), msg=error_message)
|
self.assertNotIn(" 4extra = models.CharField", output, msg=error_message)
|
||||||
self.assertIn("number_4extra = models.CharField", out.getvalue())
|
self.assertIn("number_4extra = models.CharField", output)
|
||||||
|
|
||||||
self.assertNotIn(" 45extra = models.CharField", out.getvalue(), msg=error_message)
|
self.assertNotIn(" 45extra = models.CharField", output, msg=error_message)
|
||||||
self.assertIn("number_45extra = models.CharField", out.getvalue())
|
self.assertIn("number_45extra = models.CharField", output)
|
||||||
|
|
||||||
|
def test_underscores_column_name_introspection(self):
|
||||||
|
"""Introspection of column names containing underscores (#12460)"""
|
||||||
|
out = StringIO()
|
||||||
|
call_command('inspectdb', stdout=out)
|
||||||
|
output = out.getvalue()
|
||||||
|
self.assertIn("field = models.IntegerField()", output)
|
||||||
|
self.assertIn("field_field = models.IntegerField(db_column='Field_')", output)
|
||||||
|
self.assertIn("field_field_0 = models.IntegerField(db_column='Field__')", output)
|
||||||
|
self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output)
|
||||||
|
|
Loading…
Reference in New Issue