Make get_constraints return columns in order

This commit is contained in:
Andrew Godwin 2013-07-02 18:02:20 +01:00
parent 61ff46cf8b
commit 3a6580e485
5 changed files with 32 additions and 31 deletions

View File

@ -1,6 +1,6 @@
import re import re
from .base import FIELD_TYPE from .base import FIELD_TYPE
from django.utils.datastructures import SortedSet
from django.db.backends import BaseDatabaseIntrospection, FieldInfo from django.db.backends import BaseDatabaseIntrospection, FieldInfo
from django.utils.encoding import force_text from django.utils.encoding import force_text
@ -141,7 +141,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for constraint, column, ref_table, ref_column in cursor.fetchall(): for constraint, column, ref_table, ref_column in cursor.fetchall():
if constraint not in constraints: if constraint not in constraints:
constraints[constraint] = { constraints[constraint] = {
'columns': set(), 'columns': SortedSet(),
'primary_key': False, 'primary_key': False,
'unique': False, 'unique': False,
'index': False, 'index': False,
@ -169,7 +169,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for table, non_unique, index, colseq, column in [x[:5] for x in cursor.fetchall()]: for table, non_unique, index, colseq, column in [x[:5] for x in cursor.fetchall()]:
if index not in constraints: if index not in constraints:
constraints[index] = { constraints[index] = {
'columns': set(), 'columns': SortedSet(),
'primary_key': False, 'primary_key': False,
'unique': False, 'unique': False,
'index': True, 'index': True,
@ -178,5 +178,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
} }
constraints[index]['index'] = True constraints[index]['index'] = True
constraints[index]['columns'].add(column) constraints[index]['columns'].add(column)
# Convert the sorted sets to lists
for constraint in constraints.values():
constraint['columns'] = list(constraint['columns'])
# Return # Return
return constraints return constraints

View File

@ -140,7 +140,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# If we're the first column, make the record # If we're the first column, make the record
if constraint not in constraints: if constraint not in constraints:
constraints[constraint] = { constraints[constraint] = {
"columns": set(), "columns": [],
"primary_key": kind.lower() == "primary key", "primary_key": kind.lower() == "primary key",
"unique": kind.lower() in ["primary key", "unique"], "unique": kind.lower() in ["primary key", "unique"],
"foreign_key": tuple(used_cols[0].split(".", 1)) if kind.lower() == "foreign key" else None, "foreign_key": tuple(used_cols[0].split(".", 1)) if kind.lower() == "foreign key" else None,
@ -148,7 +148,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"index": False, "index": False,
} }
# Record the details # Record the details
constraints[constraint]['columns'].add(column) constraints[constraint]['columns'].append(column)
# Now get CHECK constraint columns # Now get CHECK constraint columns
cursor.execute(""" cursor.execute("""
SELECT kc.constraint_name, kc.column_name SELECT kc.constraint_name, kc.column_name
@ -166,7 +166,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# If we're the first column, make the record # If we're the first column, make the record
if constraint not in constraints: if constraint not in constraints:
constraints[constraint] = { constraints[constraint] = {
"columns": set(), "columns": [],
"primary_key": False, "primary_key": False,
"unique": False, "unique": False,
"foreign_key": False, "foreign_key": False,
@ -174,17 +174,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"index": False, "index": False,
} }
# Record the details # Record the details
constraints[constraint]['columns'].add(column) constraints[constraint]['columns'].append(column)
# Now get indexes # Now get indexes
cursor.execute(""" cursor.execute("""
SELECT SELECT
c2.relname, c2.relname,
ARRAY( ARRAY(
SELECT attr.attname SELECT (SELECT attname FROM pg_catalog.pg_attribute WHERE attnum = i AND attrelid = c.oid)
FROM unnest(idx.indkey) i, pg_catalog.pg_attribute attr FROM unnest(idx.indkey) i
WHERE
attr.attnum = i AND
attr.attrelid = c.oid
), ),
idx.indisunique, idx.indisunique,
idx.indisprimary idx.indisprimary
@ -197,7 +194,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for index, columns, unique, primary in cursor.fetchall(): for index, columns, unique, primary in cursor.fetchall():
if index not in constraints: if index not in constraints:
constraints[index] = { constraints[index] = {
"columns": set(columns), "columns": list(columns),
"primary_key": primary, "primary_key": primary,
"unique": unique, "unique": unique,
"foreign_key": False, "foreign_key": False,

View File

@ -87,6 +87,7 @@ class BaseDatabaseSchemaEditor(object):
cursor = self.connection.cursor() cursor = self.connection.cursor()
# Log the command we're running, then run it # Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params)) logger.debug("%s; (params %r)" % (sql, params))
#print("%s; (params %r)" % (sql, params))
cursor.execute(sql, params) cursor.execute(sql, params)
def quote_name(self, name): def quote_name(self, name):
@ -228,12 +229,12 @@ class BaseDatabaseSchemaEditor(object):
Note: The input unique_togethers must be doubly-nested, not the single- Note: The input unique_togethers must be doubly-nested, not the single-
nested ["foo", "bar"] format. nested ["foo", "bar"] format.
""" """
olds = set(frozenset(fields) for fields in old_unique_together) olds = set(tuple(fields) for fields in old_unique_together)
news = set(frozenset(fields) for fields in new_unique_together) news = set(tuple(fields) for fields in new_unique_together)
# Deleted uniques # Deleted uniques
for fields in olds.difference(news): for fields in olds.difference(news):
columns = [model._meta.get_field_by_name(field)[0].column for field in fields] columns = [model._meta.get_field_by_name(field)[0].column for field in fields]
constraint_names = self._constraint_names(model, list(columns), unique=True) constraint_names = self._constraint_names(model, columns, unique=True)
if len(constraint_names) != 1: if len(constraint_names) != 1:
raise ValueError("Found wrong number (%s) of constraints for %s(%s)" % ( raise ValueError("Found wrong number (%s) of constraints for %s(%s)" % (
len(constraint_names), len(constraint_names),
@ -261,8 +262,8 @@ class BaseDatabaseSchemaEditor(object):
Note: The input index_togethers must be doubly-nested, not the single- Note: The input index_togethers must be doubly-nested, not the single-
nested ["foo", "bar"] format. nested ["foo", "bar"] format.
""" """
olds = set(frozenset(fields) for fields in old_index_together) olds = set(tuple(fields) for fields in old_index_together)
news = set(frozenset(fields) for fields in new_index_together) news = set(tuple(fields) for fields in new_index_together)
# Deleted indexes # Deleted indexes
for fields in olds.difference(news): for fields in olds.difference(news):
columns = [model._meta.get_field_by_name(field)[0].column for field in fields] columns = [model._meta.get_field_by_name(field)[0].column for field in fields]
@ -646,7 +647,7 @@ class BaseDatabaseSchemaEditor(object):
""" """
Returns all constraint names matching the columns and conditions Returns all constraint names matching the columns and conditions
""" """
column_names = set(column_names) if column_names else None column_names = list(column_names) if column_names else None
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
result = [] result = []
for name, infodict in constraints.items(): for name, infodict in constraints.items():

View File

@ -197,14 +197,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for index_rank, column_rank, column in cursor.fetchall(): for index_rank, column_rank, column in cursor.fetchall():
if index not in constraints: if index not in constraints:
constraints[index] = { constraints[index] = {
"columns": set(), "columns": [],
"primary_key": False, "primary_key": False,
"unique": bool(unique), "unique": bool(unique),
"foreign_key": False, "foreign_key": False,
"check": False, "check": False,
"index": True, "index": True,
} }
constraints[index]['columns'].add(column) constraints[index]['columns'].append(column)
# Get the PK # Get the PK
pk_column = self.get_primary_key_column(cursor, table_name) pk_column = self.get_primary_key_column(cursor, table_name)
if pk_column: if pk_column:
@ -213,7 +213,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# deletes PK constraints by name, as you can't delete constraints # deletes PK constraints by name, as you can't delete constraints
# in SQLite; we remake the table with a new PK instead. # in SQLite; we remake the table with a new PK instead.
constraints["__primary__"] = { constraints["__primary__"] = {
"columns": set([pk_column]), "columns": [pk_column],
"primary_key": True, "primary_key": True,
"unique": False, # It's not actually a unique constraint. "unique": False, # It's not actually a unique constraint.
"foreign_key": False, "foreign_key": False,

View File

@ -128,7 +128,7 @@ class SchemaTests(TransactionTestCase):
# Make sure the new FK constraint is present # Make sure the new FK constraint is present
constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table) constraints = connection.introspection.get_constraints(connection.cursor(), Book._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == set(["author_id"]) and details['foreign_key']: if details['columns'] == ["author_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id')) self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
break break
else: else:
@ -285,7 +285,7 @@ class SchemaTests(TransactionTestCase):
constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table) constraints = connection.introspection.get_constraints(connection.cursor(), BookWithM2M._meta.get_field_by_name("tags")[0].rel.through._meta.db_table)
if connection.features.supports_foreign_keys: if connection.features.supports_foreign_keys:
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == set(["tag_id"]) and details['foreign_key']: if details['columns'] == ["tag_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_tag', 'id')) self.assertEqual(details['foreign_key'], ('schema_tag', 'id'))
break break
else: else:
@ -306,7 +306,7 @@ class SchemaTests(TransactionTestCase):
constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table) constraints = connection.introspection.get_constraints(connection.cursor(), new_field.rel.through._meta.db_table)
if connection.features.supports_foreign_keys: if connection.features.supports_foreign_keys:
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == set(["uniquetest_id"]) and details['foreign_key']: if details['columns'] == ["uniquetest_id"] and details['foreign_key']:
self.assertEqual(details['foreign_key'], ('schema_uniquetest', 'id')) self.assertEqual(details['foreign_key'], ('schema_uniquetest', 'id'))
break break
else: else:
@ -327,7 +327,7 @@ class SchemaTests(TransactionTestCase):
# Ensure the constraint exists # Ensure the constraint exists
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == set(["height"]) and details['check']: if details['columns'] == ["height"] and details['check']:
break break
else: else:
self.fail("No check constraint for height found") self.fail("No check constraint for height found")
@ -343,7 +343,7 @@ class SchemaTests(TransactionTestCase):
) )
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == set(["height"]) and details['check']: if details['columns'] == ["height"] and details['check']:
self.fail("Check constraint for height found") self.fail("Check constraint for height found")
# Alter the column to re-add it # Alter the column to re-add it
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
@ -355,7 +355,7 @@ class SchemaTests(TransactionTestCase):
) )
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table) constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
for name, details in constraints.items(): for name, details in constraints.items():
if details['columns'] == set(["height"]) and details['check']: if details['columns'] == ["height"] and details['check']:
break break
else: else:
self.fail("No check constraint for height found") self.fail("No check constraint for height found")
@ -465,7 +465,7 @@ class SchemaTests(TransactionTestCase):
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
if c['columns'] == set(["slug", "title"]) if c['columns'] == ["slug", "title"]
), ),
) )
# Alter the model to add an index # Alter the model to add an index
@ -481,7 +481,7 @@ class SchemaTests(TransactionTestCase):
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
if c['columns'] == set(["slug", "title"]) if c['columns'] == ["slug", "title"]
), ),
) )
# Alter it back # Alter it back
@ -499,7 +499,7 @@ class SchemaTests(TransactionTestCase):
any( any(
c["index"] c["index"]
for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values() for c in connection.introspection.get_constraints(connection.cursor(), "schema_tag").values()
if c['columns'] == set(["slug", "title"]) if c['columns'] == ["slug", "title"]
), ),
) )