Adding 'sqlmigrate' command and quote_parameter to support it.

This commit is contained in:
Andrew Godwin 2013-09-06 15:27:51 -05:00
parent 5ca290f5db
commit efd1e6096e
13 changed files with 169 additions and 22 deletions

View File

@ -0,0 +1,52 @@
# encoding: utf8
from __future__ import unicode_literals
from optparse import make_option
from django.core.management.base import BaseCommand, CommandError
from django.db import connections, DEFAULT_DB_ALIAS
from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.loader import AmbiguityError
class Command(BaseCommand):
option_list = BaseCommand.option_list + (
make_option('--database', action='store', dest='database',
default=DEFAULT_DB_ALIAS, help='Nominates a database to create SQL for. '
'Defaults to the "default" database.'),
make_option('--backwards', action='store_true', dest='backwards',
default=False, help='Creates SQL to unapply the migration, rather than to apply it'),
)
help = "Prints the SQL statements for the named migration."
def handle(self, *args, **options):
# Get the database we're operating from
db = options.get('database')
connection = connections[db]
# Load up an executor to get all the migration data
executor = MigrationExecutor(connection)
# Resolve command-line arguments into a migration
if len(args) != 2:
raise CommandError("Wrong number of arguments (expecting 'sqlmigrate appname migrationname')")
else:
app_label, migration_name = args
if app_label not in executor.loader.migrated_apps:
raise CommandError("App '%s' does not have migrations" % app_label)
try:
migration = executor.loader.get_migration_by_prefix(app_label, migration_name)
except AmbiguityError:
raise CommandError("More than one migration matches '%s' in app '%s'. Please be more specific." % (app_label, migration_name))
except KeyError:
raise CommandError("Cannot find a migration matching '%s' from app '%s'. Is it in INSTALLED_APPS?" % (app_label, migration_name))
targets = [(app_label, migration.name)]
# Make a plan that represents just the requested migrations and show SQL
# for it
plan = [(executor.loader.graph.nodes[targets[0]], options.get("backwards", False))]
sql_statements = executor.collect_sql(plan)
for statement in sql_statements:
self.stdout.write(statement)

View File

@ -521,7 +521,7 @@ class BaseDatabaseWrapper(object):
""" """
raise NotImplementedError raise NotImplementedError
def schema_editor(self): def schema_editor(self, *args, **kwargs):
"Returns a new instance of this backend's SchemaEditor" "Returns a new instance of this backend's SchemaEditor"
raise NotImplementedError() raise NotImplementedError()
@ -958,6 +958,15 @@ class BaseDatabaseOperations(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def quote_parameter(self, value):
"""
Returns a quoted version of the value so it's safe to use in an SQL
string. This should NOT be used to prepare SQL statements to send to
the database; it is meant for outputting SQL statements to a file
or the console for later execution by a developer/DBA.
"""
raise NotImplementedError()
def random_function_sql(self): def random_function_sql(self):
""" """
Returns an SQL expression that returns a random value. Returns an SQL expression that returns a random value.

View File

@ -305,6 +305,11 @@ class DatabaseOperations(BaseDatabaseOperations):
return name # Quoting once is enough. return name # Quoting once is enough.
return "`%s`" % name return "`%s`" % name
def quote_parameter(self, value):
# Inner import to allow module to fail to load gracefully
import MySQLdb.converters
return MySQLdb.escape(value, MySQLdb.converters.conversions)
def random_function_sql(self): def random_function_sql(self):
return 'RAND()' return 'RAND()'
@ -518,9 +523,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
table_name, column_name, bad_row[1], table_name, column_name, bad_row[1],
referenced_table_name, referenced_column_name)) referenced_table_name, referenced_column_name))
def schema_editor(self): def schema_editor(self, *args, **kwargs):
"Returns a new instance of this backend's SchemaEditor" "Returns a new instance of this backend's SchemaEditor"
return DatabaseSchemaEditor(self) return DatabaseSchemaEditor(self, *args, **kwargs)
def is_usable(self): def is_usable(self):
try: try:

View File

@ -320,6 +320,16 @@ WHEN (new.%(col_name)s IS NULL)
name = name.replace('%', '%%') name = name.replace('%', '%%')
return name.upper() return name.upper()
def quote_parameter(self, value):
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
return "'%s'" % value
elif isinstance(value, six.string_types):
return repr(value)
elif isinstance(value, bool):
return "1" if value else "0"
else:
return str(value)
def random_function_sql(self): def random_function_sql(self):
return "DBMS_RANDOM.RANDOM" return "DBMS_RANDOM.RANDOM"
@ -628,9 +638,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2]) six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2])
raise raise
def schema_editor(self): def schema_editor(self, *args, **kwargs):
"Returns a new instance of this backend's SchemaEditor" "Returns a new instance of this backend's SchemaEditor"
return DatabaseSchemaEditor(self) return DatabaseSchemaEditor(self, *args, **kwargs)
# Oracle doesn't support savepoint commits. Ignore them. # Oracle doesn't support savepoint commits. Ignore them.
def _savepoint_commit(self, sid): def _savepoint_commit(self, sid):

View File

@ -93,11 +93,4 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
return self.normalize_name(for_name + "_" + suffix) return self.normalize_name(for_name + "_" + suffix)
def prepare_default(self, value): def prepare_default(self, value):
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)): return self.connection.ops.quote_parameter(value)
return "'%s'" % value
elif isinstance(value, six.string_types):
return repr(value)
elif isinstance(value, bool):
return "1" if value else "0"
else:
return str(value)

View File

@ -205,9 +205,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else: else:
return True return True
def schema_editor(self): def schema_editor(self, *args, **kwargs):
"Returns a new instance of this backend's SchemaEditor" "Returns a new instance of this backend's SchemaEditor"
return DatabaseSchemaEditor(self) return DatabaseSchemaEditor(self, *args, **kwargs)
@cached_property @cached_property
def psycopg2_version(self): def psycopg2_version(self):

View File

@ -98,6 +98,11 @@ class DatabaseOperations(BaseDatabaseOperations):
return name # Quoting once is enough. return name # Quoting once is enough.
return '"%s"' % name return '"%s"' % name
def quote_parameter(self, value):
# Inner import so backend fails nicely if it's not present
import psycopg2
return psycopg2.extensions.adapt(value)
def set_time_zone_sql(self): def set_time_zone_sql(self):
return "SET TIME ZONE %s" return "SET TIME ZONE %s"

View File

@ -54,14 +54,17 @@ class BaseDatabaseSchemaEditor(object):
sql_create_fk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED" sql_create_fk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
sql_delete_fk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_fk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s;" sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
sql_delete_index = "DROP INDEX %(name)s" sql_delete_index = "DROP INDEX %(name)s"
sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)" sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
sql_delete_pk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_pk = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
def __init__(self, connection): def __init__(self, connection, collect_sql=False):
self.connection = connection self.connection = connection
self.collect_sql = collect_sql
if self.collect_sql:
self.collected_sql = []
# State-managing methods # State-managing methods
@ -86,7 +89,10 @@ class BaseDatabaseSchemaEditor(object):
cursor = self.connection.cursor() cursor = self.connection.cursor()
# Log the command we're running, then run it # Log the command we're running, then run it
logger.debug("%s; (params %r)" % (sql, params)) logger.debug("%s; (params %r)" % (sql, params))
cursor.execute(sql, params) if self.collect_sql:
self.collected_sql.append((sql % map(self.connection.ops.quote_parameter, params)) + ";")
else:
cursor.execute(sql, params)
def quote_name(self, name): def quote_name(self, name):
return self.connection.ops.quote_name(name) return self.connection.ops.quote_name(name)

View File

@ -214,6 +214,25 @@ class DatabaseOperations(BaseDatabaseOperations):
return name # Quoting once is enough. return name # Quoting once is enough.
return '"%s"' % name return '"%s"' % name
def quote_parameter(self, value):
# Inner import to allow nice failure for backend if not present
import _sqlite3
try:
value = _sqlite3.adapt(value)
except _sqlite3.ProgrammingError:
pass
# Manual emulation of SQLite parameter quoting
if isinstance(value, six.integer_types):
return str(value)
elif isinstance(value, six.string_types):
return six.text_type(value)
elif isinstance(value, type(True)):
return str(int(value))
elif value is None:
return "NULL"
else:
raise ValueError("Cannot quote parameter value %r" % value)
def no_limit_value(self): def no_limit_value(self):
return -1 return -1
@ -437,9 +456,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
""" """
self.cursor().execute("BEGIN") self.cursor().execute("BEGIN")
def schema_editor(self): def schema_editor(self, *args, **kwargs):
"Returns a new instance of this backend's SchemaEditor" "Returns a new instance of this backend's SchemaEditor"
return DatabaseSchemaEditor(self) return DatabaseSchemaEditor(self, *args, **kwargs)
FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s') FORMAT_QMARK_REGEX = re.compile(r'(?<!%)%s')

View File

@ -55,7 +55,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
self.create_model(temp_model) self.create_model(temp_model)
# Copy data from the old table # Copy data from the old table
field_maps = list(mapping.items()) field_maps = list(mapping.items())
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % ( self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
self.quote_name(temp_model._meta.db_table), self.quote_name(temp_model._meta.db_table),
', '.join(x for x, y in field_maps), ', '.join(x for x, y in field_maps),
', '.join(y for x, y in field_maps), ', '.join(y for x, y in field_maps),
@ -137,7 +137,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Make a new through table # Make a new through table
self.create_model(new_field.rel.through) self.create_model(new_field.rel.through)
# Copy the data across # Copy the data across
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % ( self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
self.quote_name(new_field.rel.through._meta.db_table), self.quote_name(new_field.rel.through._meta.db_table),
', '.join([ ', '.join([
"id", "id",

View File

@ -61,6 +61,22 @@ class MigrationExecutor(object):
else: else:
self.unapply_migration(migration, fake=fake) self.unapply_migration(migration, fake=fake)
def collect_sql(self, plan):
"""
Takes a migration plan and returns a list of collected SQL
statements that represent the best-efforts version of that plan.
"""
statements = []
for migration, backwards in plan:
with self.connection.schema_editor(collect_sql=True) as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
if not backwards:
migration.apply(project_state, schema_editor)
else:
migration.unapply(project_state, schema_editor)
statements.extend(schema_editor.collected_sql)
return statements
def apply_migration(self, migration, fake=False): def apply_migration(self, migration, fake=False):
""" """
Runs a migration forwards. Runs a migration forwards.

View File

@ -993,6 +993,24 @@ Prints the CREATE INDEX SQL statements for the given app name(s).
The :djadminopt:`--database` option can be used to specify the database for The :djadminopt:`--database` option can be used to specify the database for
which to print the SQL. which to print the SQL.
sqlmigrate <appname> <migrationname>
------------------------------------
.. django-admin:: sqlmigrate
Prints the SQL for the named migration. This requires an active database
connection, which it will use to resolve constraint names; this means you must
generate the SQL against a copy of the database you wish to later apply it on.
The :djadminopt:`--database` option can be used to specify the database for
which to generate the SQL.
.. django-admin-option:: --backwards
By default, the SQL created is for running the migration in the forwards
direction. Pass ``--backwards`` to generate the SQL for
un-applying the migration instead.
sqlsequencereset <appname appname ...> sqlsequencereset <appname appname ...>
-------------------------------------- --------------------------------------

View File

@ -48,6 +48,20 @@ class MigrateTests(MigrationTestBase):
self.assertTableNotExists("migrations_tribble") self.assertTableNotExists("migrations_tribble")
self.assertTableNotExists("migrations_book") self.assertTableNotExists("migrations_book")
@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations"})
def test_sqlmigrate(self):
"""
Makes sure that sqlmigrate does something.
"""
# Test forwards. All the databases agree on CREATE TABLE, at least.
stdout = six.StringIO()
call_command("sqlmigrate", "migrations", "0001", stdout=stdout)
self.assertIn("create table", stdout.getvalue().lower())
# And backwards is a DROP TABLE
stdout = six.StringIO()
call_command("sqlmigrate", "migrations", "0001", stdout=stdout, backwards=True)
self.assertIn("drop table", stdout.getvalue().lower())
class MakeMigrationsTests(MigrationTestBase): class MakeMigrationsTests(MigrationTestBase):
""" """