From 395c6083af93cc37c7d16ef4db451091841cefdc Mon Sep 17 00:00:00 2001 From: Claude Paroz Date: Thu, 23 Aug 2012 21:07:56 +0200 Subject: [PATCH] Fixed #12460 -- Improved inspectdb handling of special field names Thanks mihail lukin for the report and elijahr and kgibula for their work on the patch. --- django/core/management/commands/inspectdb.py | 117 ++++++++++++------- tests/regressiontests/inspectdb/models.py | 6 + tests/regressiontests/inspectdb/tests.py | 37 ++++-- 3 files changed, 109 insertions(+), 51 deletions(-) diff --git a/django/core/management/commands/inspectdb.py b/django/core/management/commands/inspectdb.py index 7c868e4b60..c3c0776273 100644 --- a/django/core/management/commands/inspectdb.py +++ b/django/core/management/commands/inspectdb.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import keyword from optparse import make_option @@ -31,6 +33,7 @@ class Command(NoArgsCommand): table_name_filter = options.get('table_name_filter') table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '') + strip_prefix = lambda s: s.startswith("u'") and s[1:] or s cursor = connection.cursor() yield "# This is an auto-generated Django model module." @@ -41,6 +44,7 @@ class Command(NoArgsCommand): yield "#" yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'" yield "# into your database." + yield "from __future__ import unicode_literals" yield '' yield 'from %s import models' % self.db_module yield '' @@ -59,16 +63,19 @@ class Command(NoArgsCommand): indexes = connection.introspection.get_indexes(cursor, table_name) except NotImplementedError: indexes = {} + used_column_names = [] # Holds column names used in the table so far for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)): - column_name = row[0] - att_name = column_name.lower() comment_notes = [] # Holds Field notes, to be displayed in a Python comment. extra_params = {} # Holds Field parameters such as 'db_column'. + column_name = row[0] + is_relation = i in relations - # If the column name can't be used verbatim as a Python - # attribute, set the "db_column" for this Field. - if ' ' in att_name or '-' in att_name or keyword.iskeyword(att_name) or column_name != att_name: - extra_params['db_column'] = column_name + att_name, params, notes = self.normalize_col_name( + column_name, used_column_names, is_relation) + extra_params.update(params) + comment_notes.extend(notes) + + used_column_names.append(att_name) # Add primary_key and unique, if necessary. if column_name in indexes: @@ -77,30 +84,12 @@ class Command(NoArgsCommand): elif indexes[column_name]['unique']: extra_params['unique'] = True - # Modify the field name to make it Python-compatible. - if ' ' in att_name: - att_name = att_name.replace(' ', '_') - comment_notes.append('Field renamed to remove spaces.') - - if '-' in att_name: - att_name = att_name.replace('-', '_') - comment_notes.append('Field renamed to remove dashes.') - - if column_name != att_name: - comment_notes.append('Field name made lowercase.') - - if i in relations: + if is_relation: rel_to = relations[i][1] == table_name and "'self'" or table2model(relations[i][1]) - if rel_to in known_models: field_type = 'ForeignKey(%s' % rel_to else: field_type = "ForeignKey('%s'" % rel_to - - if att_name.endswith('_id'): - att_name = att_name[:-3] - else: - extra_params['db_column'] = column_name else: # Calling `get_field_type` to get the field type string and any # additional paramters and notes. @@ -110,16 +99,6 @@ class Command(NoArgsCommand): field_type += '(' - if keyword.iskeyword(att_name): - att_name += '_field' - comment_notes.append('Field renamed because it was a Python reserved word.') - - if att_name[0].isdigit(): - att_name = 'number_%s' % att_name - extra_params['db_column'] = six.text_type(column_name) - comment_notes.append("Field renamed because it wasn't a " - "valid Python identifier.") - # Don't output 'id = meta.AutoField(primary_key=True)', because # that's assumed if it doesn't exist. if att_name == 'id' and field_type == 'AutoField(' and extra_params == {'primary_key': True}: @@ -136,7 +115,9 @@ class Command(NoArgsCommand): if extra_params: if not field_desc.endswith('('): field_desc += ', ' - field_desc += ', '.join(['%s=%r' % (k, v) for k, v in extra_params.items()]) + field_desc += ', '.join([ + '%s=%s' % (k, strip_prefix(repr(v))) + for k, v in extra_params.items()]) field_desc += ')' if comment_notes: field_desc += ' # ' + ' '.join(comment_notes) @@ -144,6 +125,64 @@ class Command(NoArgsCommand): for meta_line in self.get_meta(table_name): yield meta_line + def normalize_col_name(self, col_name, used_column_names, is_relation): + """ + Modify the column name to make it Python-compatible as a field name + """ + field_params = {} + field_notes = [] + + new_name = col_name.lower() + if new_name != col_name: + field_notes.append('Field name made lowercase.') + + if is_relation: + if new_name.endswith('_id'): + new_name = new_name[:-3] + else: + field_params['db_column'] = col_name + + if ' ' in new_name: + new_name = new_name.replace(' ', '_') + field_notes.append('Field renamed to remove spaces.') + + if '-' in new_name: + new_name = new_name.replace('-', '_') + field_notes.append('Field renamed to remove dashes.') + + if new_name.find('__') >= 0: + while new_name.find('__') >= 0: + new_name = new_name.replace('__', '_') + field_notes.append("Field renamed because it contained more than one '_' in a row.") + + if new_name.startswith('_'): + new_name = 'field%s' % new_name + field_notes.append("Field renamed because it started with '_'.") + + if new_name.endswith('_'): + new_name = '%sfield' % new_name + field_notes.append("Field renamed because it ended with '_'.") + + if keyword.iskeyword(new_name): + new_name += '_field' + field_notes.append('Field renamed because it was a Python reserved word.') + + if new_name[0].isdigit(): + new_name = 'number_%s' % new_name + field_notes.append("Field renamed because it wasn't a valid Python identifier.") + + if new_name in used_column_names: + num = 0 + while '%s_%d' % (new_name, num) in used_column_names: + num += 1 + new_name = '%s_%d' % (new_name, num) + field_notes.append('Field renamed because of name conflict.') + + if col_name != new_name and field_notes: + field_params['db_column'] = col_name + + return new_name, field_params, field_notes + def get_field_type(self, connection, table_name, row): """ Given the database connection, the table name, and the cursor row @@ -181,6 +220,6 @@ class Command(NoArgsCommand): to construct the inner Meta class for the model corresponding to the given database table name. """ - return [' class Meta:', - ' db_table = %r' % table_name, - ''] + return [" class Meta:", + " db_table = '%s'" % table_name, + ""] diff --git a/tests/regressiontests/inspectdb/models.py b/tests/regressiontests/inspectdb/models.py index 9f815855b6..352053aafe 100644 --- a/tests/regressiontests/inspectdb/models.py +++ b/tests/regressiontests/inspectdb/models.py @@ -19,3 +19,9 @@ class DigitsInColumnName(models.Model): all_digits = models.CharField(max_length=11, db_column='123') leading_digit = models.CharField(max_length=11, db_column='4extra') leading_digits = models.CharField(max_length=11, db_column='45extra') + +class UnderscoresInColumnName(models.Model): + field = models.IntegerField(db_column='field') + field_field_0 = models.IntegerField(db_column='Field_') + field_field_1 = models.IntegerField(db_column='Field__') + field_field_2 = models.IntegerField(db_column='__field') diff --git a/tests/regressiontests/inspectdb/tests.py b/tests/regressiontests/inspectdb/tests.py index aae7bc5cc7..b5647e9e38 100644 --- a/tests/regressiontests/inspectdb/tests.py +++ b/tests/regressiontests/inspectdb/tests.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + from django.core.management import call_command from django.test import TestCase, skipUnlessDBFeature from django.utils.six import StringIO @@ -17,7 +19,6 @@ class InspectDBTestCase(TestCase): # the Django test suite, check that one of its tables hasn't been # inspected self.assertNotIn("class DjangoContentType(models.Model):", out.getvalue(), msg=error_message) - out.close() @skipUnlessDBFeature('can_introspect_foreign_keys') def test_attribute_name_not_python_keyword(self): @@ -27,15 +28,16 @@ class InspectDBTestCase(TestCase): call_command('inspectdb', table_name_filter=lambda tn:tn.startswith('inspectdb_'), stdout=out) + output = out.getvalue() error_message = "inspectdb generated an attribute name which is a python keyword" - self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", out.getvalue(), msg=error_message) + self.assertNotIn("from = models.ForeignKey(InspectdbPeople)", output, msg=error_message) # As InspectdbPeople model is defined after InspectdbMessage, it should be quoted - self.assertIn("from_field = models.ForeignKey('InspectdbPeople')", out.getvalue()) + self.assertIn("from_field = models.ForeignKey('InspectdbPeople', db_column='from_id')", + output) self.assertIn("people_pk = models.ForeignKey(InspectdbPeople, primary_key=True)", - out.getvalue()) + output) self.assertIn("people_unique = models.ForeignKey(InspectdbPeople, unique=True)", - out.getvalue()) - out.close() + output) def test_digits_column_name_introspection(self): """Introspection of column names consist/start with digits (#16536/#17676)""" @@ -45,13 +47,24 @@ class InspectDBTestCase(TestCase): call_command('inspectdb', table_name_filter=lambda tn:tn.startswith('inspectdb_'), stdout=out) + output = out.getvalue() error_message = "inspectdb generated a model field name which is a number" - self.assertNotIn(" 123 = models.CharField", out.getvalue(), msg=error_message) - self.assertIn("number_123 = models.CharField", out.getvalue()) + self.assertNotIn(" 123 = models.CharField", output, msg=error_message) + self.assertIn("number_123 = models.CharField", output) error_message = "inspectdb generated a model field name which starts with a digit" - self.assertNotIn(" 4extra = models.CharField", out.getvalue(), msg=error_message) - self.assertIn("number_4extra = models.CharField", out.getvalue()) + self.assertNotIn(" 4extra = models.CharField", output, msg=error_message) + self.assertIn("number_4extra = models.CharField", output) - self.assertNotIn(" 45extra = models.CharField", out.getvalue(), msg=error_message) - self.assertIn("number_45extra = models.CharField", out.getvalue()) + self.assertNotIn(" 45extra = models.CharField", output, msg=error_message) + self.assertIn("number_45extra = models.CharField", output) + + def test_underscores_column_name_introspection(self): + """Introspection of column names containing underscores (#12460)""" + out = StringIO() + call_command('inspectdb', stdout=out) + output = out.getvalue() + self.assertIn("field = models.IntegerField()", output) + self.assertIn("field_field = models.IntegerField(db_column='Field_')", output) + self.assertIn("field_field_0 = models.IntegerField(db_column='Field__')", output) + self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output)