From 9dc4ba875f21d5690f6ad5995123a67a3c44bafe Mon Sep 17 00:00:00 2001 From: Russell Keith-Magee Date: Mon, 11 Aug 2008 12:11:25 +0000 Subject: [PATCH] Fixed #5461 -- Refactored the database backend code to use classes for the creation and introspection modules. Introduces a new validation module for DB-specific validation. This is a backwards incompatible change; see the wiki for details. git-svn-id: http://code.djangoproject.com/svn/django/trunk@8296 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- .../contrib/gis/db/backend/mysql/creation.py | 4 +- .../contrib/gis/db/backend/oracle/creation.py | 4 +- .../gis/db/backend/postgis/creation.py | 32 +- .../gis/management/commands/inspectdb.py | 26 +- django/core/management/commands/dbshell.py | 4 +- django/core/management/commands/inspectdb.py | 14 +- django/core/management/commands/syncdb.py | 27 +- django/core/management/commands/testserver.py | 4 +- django/core/management/sql.py | 330 +---------- django/core/management/validation.py | 7 +- django/db/__init__.py | 22 +- django/db/backends/__init__.py | 101 +++- django/db/backends/creation.py | 395 ++++++++++++- django/db/backends/dummy/base.py | 27 +- django/db/backends/dummy/client.py | 3 - django/db/backends/dummy/creation.py | 1 - django/db/backends/dummy/introspection.py | 8 - django/db/backends/mysql/base.py | 18 +- django/db/backends/mysql/client.py | 46 +- django/db/backends/mysql/creation.py | 96 +++- django/db/backends/mysql/introspection.py | 171 +++--- django/db/backends/mysql/validation.py | 13 + django/db/backends/oracle/base.py | 22 +- django/db/backends/oracle/client.py | 18 +- django/db/backends/oracle/creation.py | 526 +++++++++--------- django/db/backends/oracle/introspection.py | 181 +++--- django/db/backends/postgresql/base.py | 23 +- django/db/backends/postgresql/client.py | 26 +- django/db/backends/postgresql/creation.py | 66 ++- .../db/backends/postgresql/introspection.py | 162 +++--- .../db/backends/postgresql_psycopg2/base.py | 18 +- .../db/backends/postgresql_psycopg2/client.py | 1 - .../backends/postgresql_psycopg2/creation.py | 1 - .../postgresql_psycopg2/introspection.py | 100 +--- django/db/backends/sqlite3/base.py | 22 +- django/db/backends/sqlite3/client.py | 8 +- django/db/backends/sqlite3/creation.py | 100 +++- django/db/backends/sqlite3/introspection.py | 148 ++--- django/db/models/fields/__init__.py | 6 +- django/test/simple.py | 6 +- django/test/utils.py | 151 +---- docs/testing.txt | 5 +- tests/regressiontests/backends/models.py | 8 +- 43 files changed, 1528 insertions(+), 1423 deletions(-) create mode 100644 django/db/backends/mysql/validation.py diff --git a/django/contrib/gis/db/backend/mysql/creation.py b/django/contrib/gis/db/backend/mysql/creation.py index e8f471df81..3da21a0cdd 100644 --- a/django/contrib/gis/db/backend/mysql/creation.py +++ b/django/contrib/gis/db/backend/mysql/creation.py @@ -1,5 +1,5 @@ -from django.test.utils import create_test_db def create_spatial_db(test=True, verbosity=1, autoclobber=False): if not test: raise NotImplementedError('This uses `create_test_db` from test/utils.py') - create_test_db(verbosity, autoclobber) + from django.db import connection + connection.creation.create_test_db(verbosity, autoclobber) diff --git a/django/contrib/gis/db/backend/oracle/creation.py b/django/contrib/gis/db/backend/oracle/creation.py index 4a05da7ec1..d9b53d2049 100644 --- a/django/contrib/gis/db/backend/oracle/creation.py +++ b/django/contrib/gis/db/backend/oracle/creation.py @@ -1,8 +1,6 @@ -from django.db.backends.oracle.creation import create_test_db def create_spatial_db(test=True, verbosity=1, autoclobber=False): "A wrapper over the Oracle `create_test_db` routine." if not test: raise NotImplementedError('This uses `create_test_db` from db/backends/oracle/creation.py') - from django.conf import settings from django.db import connection - create_test_db(settings, connection, verbosity, autoclobber) + connection.creation.create_test_db(verbosity, autoclobber) diff --git a/django/contrib/gis/db/backend/postgis/creation.py b/django/contrib/gis/db/backend/postgis/creation.py index a3884db187..daaeae6508 100644 --- a/django/contrib/gis/db/backend/postgis/creation.py +++ b/django/contrib/gis/db/backend/postgis/creation.py @@ -1,7 +1,7 @@ from django.conf import settings from django.core.management import call_command from django.db import connection -from django.test.utils import _set_autocommit, TEST_DATABASE_PREFIX +from django.db.backends.creation import TEST_DATABASE_PREFIX import os, re, sys def getstatusoutput(cmd): @@ -38,9 +38,9 @@ def _create_with_cursor(db_name, verbosity=1, autoclobber=False): create_sql = 'CREATE DATABASE %s' % connection.ops.quote_name(db_name) if settings.DATABASE_USER: create_sql += ' OWNER %s' % settings.DATABASE_USER - + cursor = connection.cursor() - _set_autocommit(connection) + connection.creation.set_autocommit(connection) try: # Trying to create the database first. @@ -58,12 +58,12 @@ def _create_with_cursor(db_name, verbosity=1, autoclobber=False): else: raise Exception('Spatial Database Creation canceled.') foo = _create_with_cursor - + created_regex = re.compile(r'^createdb: database creation failed: ERROR: database ".+" already exists') def _create_with_shell(db_name, verbosity=1, autoclobber=False): """ - If no spatial database already exists, then using a cursor will not work. - Thus, a `createdb` command will be issued through the shell to bootstrap + If no spatial database already exists, then using a cursor will not work. + Thus, a `createdb` command will be issued through the shell to bootstrap creation of the spatial database. """ @@ -83,7 +83,7 @@ def _create_with_shell(db_name, verbosity=1, autoclobber=False): if verbosity >= 1: print 'Destroying old spatial database...' drop_cmd = 'dropdb %s%s' % (options, db_name) status, output = getstatusoutput(drop_cmd) - if status != 0: + if status != 0: raise Exception('Could not drop database %s: %s' % (db_name, output)) if verbosity >= 1: print 'Creating new spatial database...' status, output = getstatusoutput(create_cmd) @@ -102,10 +102,10 @@ def create_spatial_db(test=False, verbosity=1, autoclobber=False, interactive=Fa raise Exception('Spatial database creation only supported postgresql_psycopg2 platform.') # Getting the spatial database name - if test: + if test: db_name = get_spatial_db(test=True) _create_with_cursor(db_name, verbosity=verbosity, autoclobber=autoclobber) - else: + else: db_name = get_spatial_db() _create_with_shell(db_name, verbosity=verbosity, autoclobber=autoclobber) @@ -125,7 +125,7 @@ def create_spatial_db(test=False, verbosity=1, autoclobber=False, interactive=Fa # Syncing the database call_command('syncdb', verbosity=verbosity, interactive=interactive) - + def drop_db(db_name=False, test=False): """ Drops the given database (defaults to what is returned from @@ -151,7 +151,7 @@ def get_cmd_options(db_name): def get_spatial_db(test=False): """ - Returns the name of the spatial database. The 'test' keyword may be set + Returns the name of the spatial database. The 'test' keyword may be set to return the test spatial database name. """ if test: @@ -167,13 +167,13 @@ def get_spatial_db(test=False): def load_postgis_sql(db_name, verbosity=1): """ - This routine loads up the PostGIS SQL files lwpostgis.sql and + This routine loads up the PostGIS SQL files lwpostgis.sql and spatial_ref_sys.sql. """ # Getting the path to the PostGIS SQL try: - # POSTGIS_SQL_PATH may be placed in settings to tell GeoDjango where the + # POSTGIS_SQL_PATH may be placed in settings to tell GeoDjango where the # PostGIS SQL files are located. This is especially useful on Win32 # platforms since the output of pg_config looks like "C:/PROGRA~1/..". sql_path = settings.POSTGIS_SQL_PATH @@ -193,7 +193,7 @@ def load_postgis_sql(db_name, verbosity=1): # Getting the psql command-line options, and command format. options = get_cmd_options(db_name) cmd_fmt = 'psql %s-f "%%s"' % options - + # Now trying to load up the PostGIS functions cmd = cmd_fmt % lwpostgis_file if verbosity >= 1: print cmd @@ -211,8 +211,8 @@ def load_postgis_sql(db_name, verbosity=1): # Setting the permissions because on Windows platforms the owner # of the spatial_ref_sys and geometry_columns tables is always # the postgres user, regardless of how the db is created. - if os.name == 'nt': set_permissions(db_name) - + if os.name == 'nt': set_permissions(db_name) + def set_permissions(db_name): """ Sets the permissions on the given database to that of the user specified diff --git a/django/contrib/gis/management/commands/inspectdb.py b/django/contrib/gis/management/commands/inspectdb.py index 05e205353c..d4fe210953 100644 --- a/django/contrib/gis/management/commands/inspectdb.py +++ b/django/contrib/gis/management/commands/inspectdb.py @@ -7,7 +7,7 @@ from django.core.management.commands.inspectdb import Command as InspectCommand from django.contrib.gis.db.backend import SpatialBackend class Command(InspectCommand): - + # Mapping from lower-case OGC type to the corresponding GeoDjango field. geofield_mapping = {'point' : 'PointField', 'linestring' : 'LineStringField', @@ -21,11 +21,11 @@ class Command(InspectCommand): def geometry_columns(self): """ - Returns a datastructure of metadata information associated with the + Returns a datastructure of metadata information associated with the `geometry_columns` (or equivalent) table. """ # The `geo_cols` is a dictionary data structure that holds information - # about any geographic columns in the database. + # about any geographic columns in the database. geo_cols = {} def add_col(table, column, coldata): if table in geo_cols: @@ -47,7 +47,7 @@ class Command(InspectCommand): elif SpatialBackend.name == 'mysql': # On MySQL have to get all table metadata before hand; this means walking through # each table and seeing if any column types are spatial. Can't detect this with - # `cursor.description` (what the introspection module does) because all spatial types + # `cursor.description` (what the introspection module does) because all spatial types # have the same integer type (255 for GEOMETRY). from django.db import connection cursor = connection.cursor() @@ -67,13 +67,11 @@ class Command(InspectCommand): def handle_inspection(self): "Overloaded from Django's version to handle geographic database tables." - from django.db import connection, get_introspection_module + from django.db import connection import keyword - introspection_module = get_introspection_module() - geo_cols = self.geometry_columns() - + table2model = lambda table_name: table_name.title().replace('_', '') cursor = connection.cursor() @@ -88,20 +86,20 @@ class Command(InspectCommand): yield '' yield 'from django.contrib.gis.db import models' yield '' - for table_name in introspection_module.get_table_list(cursor): + for table_name in connection.introspection.get_table_list(cursor): # Getting the geographic table dictionary. geo_table = geo_cols.get(table_name, {}) yield 'class %s(models.Model):' % table2model(table_name) try: - relations = introspection_module.get_relations(cursor, table_name) + relations = connection.introspection.get_relations(cursor, table_name) except NotImplementedError: relations = {} try: - indexes = introspection_module.get_indexes(cursor, table_name) + indexes = connection.introspection.get_indexes(cursor, table_name) except NotImplementedError: indexes = {} - for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)): + for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)): att_name, iatt_name = row[0].lower(), row[0] comment_notes = [] # Holds Field notes, to be displayed in a Python comment. extra_params = {} # Holds Field parameters such as 'db_column'. @@ -133,12 +131,12 @@ class Command(InspectCommand): if srid != 4326: extra_params['srid'] = srid else: try: - field_type = introspection_module.DATA_TYPES_REVERSE[row[1]] + field_type = connection.introspection.data_types_reverse[row[1]] except KeyError: field_type = 'TextField' comment_notes.append('This field type is a guess.') - # This is a hook for DATA_TYPES_REVERSE to return a tuple of + # This is a hook for data_types_reverse to return a tuple of # (field_type, extra_params_dict). if type(field_type) is tuple: field_type, new_params = field_type diff --git a/django/core/management/commands/dbshell.py b/django/core/management/commands/dbshell.py index ec2a961530..18faa6a130 100644 --- a/django/core/management/commands/dbshell.py +++ b/django/core/management/commands/dbshell.py @@ -6,5 +6,5 @@ class Command(NoArgsCommand): requires_model_validation = False def handle_noargs(self, **options): - from django.db import runshell - runshell() + from django.db import connection + connection.client.runshell() diff --git a/django/core/management/commands/inspectdb.py b/django/core/management/commands/inspectdb.py index 11bc390289..d7d17fd0f3 100644 --- a/django/core/management/commands/inspectdb.py +++ b/django/core/management/commands/inspectdb.py @@ -13,11 +13,9 @@ class Command(NoArgsCommand): raise CommandError("Database inspection isn't supported for the currently selected database backend.") def handle_inspection(self): - from django.db import connection, get_introspection_module + from django.db import connection import keyword - introspection_module = get_introspection_module() - table2model = lambda table_name: table_name.title().replace('_', '') cursor = connection.cursor() @@ -32,17 +30,17 @@ class Command(NoArgsCommand): yield '' yield 'from django.db import models' yield '' - for table_name in introspection_module.get_table_list(cursor): + for table_name in connection.introspection.get_table_list(cursor): yield 'class %s(models.Model):' % table2model(table_name) try: - relations = introspection_module.get_relations(cursor, table_name) + relations = connection.introspection.get_relations(cursor, table_name) except NotImplementedError: relations = {} try: - indexes = introspection_module.get_indexes(cursor, table_name) + indexes = connection.introspection.get_indexes(cursor, table_name) except NotImplementedError: indexes = {} - for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)): + for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)): att_name = row[0].lower() comment_notes = [] # Holds Field notes, to be displayed in a Python comment. extra_params = {} # Holds Field parameters such as 'db_column'. @@ -65,7 +63,7 @@ class Command(NoArgsCommand): extra_params['db_column'] = att_name else: try: - field_type = introspection_module.DATA_TYPES_REVERSE[row[1]] + field_type = connection.introspection.data_types_reverse[row[1]] except KeyError: field_type = 'TextField' comment_notes.append('This field type is a guess.') diff --git a/django/core/management/commands/syncdb.py b/django/core/management/commands/syncdb.py index 38d1c91abd..7aeed4971e 100644 --- a/django/core/management/commands/syncdb.py +++ b/django/core/management/commands/syncdb.py @@ -21,7 +21,7 @@ class Command(NoArgsCommand): def handle_noargs(self, **options): from django.db import connection, transaction, models from django.conf import settings - from django.core.management.sql import table_names, installed_models, sql_model_create, sql_for_pending_references, many_to_many_sql_for_model, custom_sql_for_model, sql_indexes_for_model, emit_post_sync_signal + from django.core.management.sql import custom_sql_for_model, emit_post_sync_signal verbosity = int(options.get('verbosity', 1)) interactive = options.get('interactive') @@ -50,16 +50,9 @@ class Command(NoArgsCommand): cursor = connection.cursor() - if connection.features.uses_case_insensitive_names: - table_name_converter = lambda x: x.upper() - else: - table_name_converter = lambda x: x - # Get a list of all existing database tables, so we know what needs to - # be added. - tables = [table_name_converter(name) for name in table_names()] - # Get a list of already installed *models* so that references work right. - seen_models = installed_models(tables) + tables = connection.introspection.table_names() + seen_models = connection.introspection.installed_models(tables) created_models = set() pending_references = {} @@ -71,21 +64,21 @@ class Command(NoArgsCommand): # Create the model's database table, if it doesn't already exist. if verbosity >= 2: print "Processing %s.%s model" % (app_name, model._meta.object_name) - if table_name_converter(model._meta.db_table) in tables: + if connection.introspection.table_name_converter(model._meta.db_table) in tables: continue - sql, references = sql_model_create(model, self.style, seen_models) + sql, references = connection.creation.sql_create_model(model, self.style, seen_models) seen_models.add(model) created_models.add(model) for refto, refs in references.items(): pending_references.setdefault(refto, []).extend(refs) if refto in seen_models: - sql.extend(sql_for_pending_references(refto, self.style, pending_references)) - sql.extend(sql_for_pending_references(model, self.style, pending_references)) + sql.extend(connection.creation.sql_for_pending_references(refto, self.style, pending_references)) + sql.extend(connection.creation.sql_for_pending_references(model, self.style, pending_references)) if verbosity >= 1: print "Creating table %s" % model._meta.db_table for statement in sql: cursor.execute(statement) - tables.append(table_name_converter(model._meta.db_table)) + tables.append(connection.introspection.table_name_converter(model._meta.db_table)) # Create the m2m tables. This must be done after all tables have been created # to ensure that all referred tables will exist. @@ -94,7 +87,7 @@ class Command(NoArgsCommand): model_list = models.get_models(app) for model in model_list: if model in created_models: - sql = many_to_many_sql_for_model(model, self.style) + sql = connection.creation.sql_for_many_to_many(model, self.style) if sql: if verbosity >= 2: print "Creating many-to-many tables for %s.%s model" % (app_name, model._meta.object_name) @@ -140,7 +133,7 @@ class Command(NoArgsCommand): app_name = app.__name__.split('.')[-2] for model in models.get_models(app): if model in created_models: - index_sql = sql_indexes_for_model(model, self.style) + index_sql = connection.creation.sql_indexes_for_model(model, self.style) if index_sql: if verbosity >= 1: print "Installing index for %s.%s model" % (app_name, model._meta.object_name) diff --git a/django/core/management/commands/testserver.py b/django/core/management/commands/testserver.py index b409bc91d1..78983e73d6 100644 --- a/django/core/management/commands/testserver.py +++ b/django/core/management/commands/testserver.py @@ -18,13 +18,13 @@ class Command(BaseCommand): def handle(self, *fixture_labels, **options): from django.core.management import call_command - from django.test.utils import create_test_db + from django.db import connection verbosity = int(options.get('verbosity', 1)) addrport = options.get('addrport') # Create a test database. - db_name = create_test_db(verbosity=verbosity) + db_name = connection.creation.create_test_db(verbosity=verbosity) # Import the fixture data into the test database. call_command('loaddata', *fixture_labels, **{'verbosity': verbosity}) diff --git a/django/core/management/sql.py b/django/core/management/sql.py index 2cca3c3469..da63f7e640 100644 --- a/django/core/management/sql.py +++ b/django/core/management/sql.py @@ -7,65 +7,9 @@ try: except NameError: from sets import Set as set # Python 2.3 fallback -def table_names(): - "Returns a list of all table names that exist in the database." - from django.db import connection, get_introspection_module - cursor = connection.cursor() - return set(get_introspection_module().get_table_list(cursor)) - -def django_table_names(only_existing=False): - """ - Returns a list of all table names that have associated Django models and - are in INSTALLED_APPS. - - If only_existing is True, the resulting list will only include the tables - that actually exist in the database. - """ - from django.db import models - tables = set() - for app in models.get_apps(): - for model in models.get_models(app): - tables.add(model._meta.db_table) - tables.update([f.m2m_db_table() for f in model._meta.local_many_to_many]) - if only_existing: - tables = [t for t in tables if t in table_names()] - return tables - -def installed_models(table_list): - "Returns a set of all models that are installed, given a list of existing table names." - from django.db import connection, models - all_models = [] - for app in models.get_apps(): - for model in models.get_models(app): - all_models.append(model) - if connection.features.uses_case_insensitive_names: - converter = lambda x: x.upper() - else: - converter = lambda x: x - return set([m for m in all_models if converter(m._meta.db_table) in map(converter, table_list)]) - -def sequence_list(): - "Returns a list of information about all DB sequences for all models in all apps." - from django.db import models - - apps = models.get_apps() - sequence_list = [] - - for app in apps: - for model in models.get_models(app): - for f in model._meta.local_fields: - if isinstance(f, models.AutoField): - sequence_list.append({'table': model._meta.db_table, 'column': f.column}) - break # Only one AutoField is allowed per model, so don't bother continuing. - - for f in model._meta.local_many_to_many: - sequence_list.append({'table': f.m2m_db_table(), 'column': None}) - - return sequence_list - def sql_create(app, style): "Returns a list of the CREATE TABLE SQL statements for the given app." - from django.db import models + from django.db import connection, models from django.conf import settings if settings.DATABASE_ENGINE == 'dummy': @@ -81,23 +25,24 @@ def sql_create(app, style): # we can be conservative). app_models = models.get_models(app) final_output = [] - known_models = set([model for model in installed_models(table_names()) if model not in app_models]) + tables = connection.introspection.table_names() + known_models = set([model for model in connection.introspection.installed_models(tables) if model not in app_models]) pending_references = {} for model in app_models: - output, references = sql_model_create(model, style, known_models) + output, references = connection.creation.sql_create_model(model, style, known_models) final_output.extend(output) for refto, refs in references.items(): pending_references.setdefault(refto, []).extend(refs) if refto in known_models: - final_output.extend(sql_for_pending_references(refto, style, pending_references)) - final_output.extend(sql_for_pending_references(model, style, pending_references)) + final_output.extend(connection.creation.sql_for_pending_references(refto, style, pending_references)) + final_output.extend(connection.creation.sql_for_pending_references(model, style, pending_references)) # Keep track of the fact that we've created the table for this model. known_models.add(model) # Create the many-to-many join tables. for model in app_models: - final_output.extend(many_to_many_sql_for_model(model, style)) + final_output.extend(connection.creation.sql_for_many_to_many(model, style)) # Handle references to tables that are from other apps # but don't exist physically. @@ -106,7 +51,7 @@ def sql_create(app, style): alter_sql = [] for model in not_installed_models: alter_sql.extend(['-- ' + sql for sql in - sql_for_pending_references(model, style, pending_references)]) + connection.creation.sql_for_pending_references(model, style, pending_references)]) if alter_sql: final_output.append('-- The following references should be added but depend on non-existent tables:') final_output.extend(alter_sql) @@ -115,10 +60,9 @@ def sql_create(app, style): def sql_delete(app, style): "Returns a list of the DROP TABLE SQL statements for the given app." - from django.db import connection, models, get_introspection_module + from django.db import connection, models from django.db.backends.util import truncate_name from django.contrib.contenttypes import generic - introspection = get_introspection_module() # This should work even if a connection isn't available try: @@ -128,16 +72,11 @@ def sql_delete(app, style): # Figure out which tables already exist if cursor: - table_names = introspection.get_table_list(cursor) + table_names = connection.introspection.get_table_list(cursor) else: table_names = [] - if connection.features.uses_case_insensitive_names: - table_name_converter = lambda x: x.upper() - else: - table_name_converter = lambda x: x output = [] - qn = connection.ops.quote_name # Output DROP TABLE statements for standard application tables. to_delete = set() @@ -145,7 +84,7 @@ def sql_delete(app, style): references_to_delete = {} app_models = models.get_models(app) for model in app_models: - if cursor and table_name_converter(model._meta.db_table) in table_names: + if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names: # The table exists, so it needs to be dropped opts = model._meta for f in opts.local_fields: @@ -155,40 +94,15 @@ def sql_delete(app, style): to_delete.add(model) for model in app_models: - if cursor and table_name_converter(model._meta.db_table) in table_names: - # Drop the table now - output.append('%s %s;' % (style.SQL_KEYWORD('DROP TABLE'), - style.SQL_TABLE(qn(model._meta.db_table)))) - if connection.features.supports_constraints and model in references_to_delete: - for rel_class, f in references_to_delete[model]: - table = rel_class._meta.db_table - col = f.column - r_table = model._meta.db_table - r_col = model._meta.get_field(f.rel.field_name).column - r_name = '%s_refs_%s_%x' % (col, r_col, abs(hash((table, r_table)))) - output.append('%s %s %s %s;' % \ - (style.SQL_KEYWORD('ALTER TABLE'), - style.SQL_TABLE(qn(table)), - style.SQL_KEYWORD(connection.ops.drop_foreignkey_sql()), - style.SQL_FIELD(truncate_name(r_name, connection.ops.max_name_length())))) - del references_to_delete[model] - if model._meta.has_auto_field: - ds = connection.ops.drop_sequence_sql(model._meta.db_table) - if ds: - output.append(ds) + if connection.introspection.table_name_converter(model._meta.db_table) in table_names: + output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style)) # Output DROP TABLE statements for many-to-many tables. for model in app_models: opts = model._meta for f in opts.local_many_to_many: - if not f.creates_table: - continue - if cursor and table_name_converter(f.m2m_db_table()) in table_names: - output.append("%s %s;" % (style.SQL_KEYWORD('DROP TABLE'), - style.SQL_TABLE(qn(f.m2m_db_table())))) - ds = connection.ops.drop_sequence_sql("%s_%s" % (model._meta.db_table, f.column)) - if ds: - output.append(ds) + if cursor and connection.introspection.table_name_converter(f.m2m_db_table()) in table_names: + output.extend(connection.creation.sql_destroy_many_to_many(model, f, style)) app_label = app_models[0]._meta.app_label @@ -213,10 +127,10 @@ def sql_flush(style, only_django=False): """ from django.db import connection if only_django: - tables = django_table_names() + tables = connection.introspection.django_table_names() else: - tables = table_names() - statements = connection.ops.sql_flush(style, tables, sequence_list()) + tables = connection.introspection.table_names() + statements = connection.ops.sql_flush(style, tables, connection.introspection.sequence_list()) return statements def sql_custom(app, style): @@ -234,198 +148,16 @@ def sql_custom(app, style): def sql_indexes(app, style): "Returns a list of the CREATE INDEX SQL statements for all models in the given app." - from django.db import models + from django.db import connection, models output = [] for model in models.get_models(app): - output.extend(sql_indexes_for_model(model, style)) + output.extend(connection.creation.sql_indexes_for_model(model, style)) return output def sql_all(app, style): "Returns a list of CREATE TABLE SQL, initial-data inserts, and CREATE INDEX SQL for the given module." return sql_create(app, style) + sql_custom(app, style) + sql_indexes(app, style) -def sql_model_create(model, style, known_models=set()): - """ - Returns the SQL required to create a single model, as a tuple of: - (list_of_sql, pending_references_dict) - """ - from django.db import connection, models - - opts = model._meta - final_output = [] - table_output = [] - pending_references = {} - qn = connection.ops.quote_name - inline_references = connection.features.inline_fk_references - for f in opts.local_fields: - col_type = f.db_type() - tablespace = f.db_tablespace or opts.db_tablespace - if col_type is None: - # Skip ManyToManyFields, because they're not represented as - # database columns in this table. - continue - # Make the definition (e.g. 'foo VARCHAR(30)') for this field. - field_output = [style.SQL_FIELD(qn(f.column)), - style.SQL_COLTYPE(col_type)] - field_output.append(style.SQL_KEYWORD('%sNULL' % (not f.null and 'NOT ' or ''))) - if f.primary_key: - field_output.append(style.SQL_KEYWORD('PRIMARY KEY')) - elif f.unique: - field_output.append(style.SQL_KEYWORD('UNIQUE')) - if tablespace and connection.features.supports_tablespaces and f.unique: - # We must specify the index tablespace inline, because we - # won't be generating a CREATE INDEX statement for this field. - field_output.append(connection.ops.tablespace_sql(tablespace, inline=True)) - if f.rel: - if inline_references and f.rel.to in known_models: - field_output.append(style.SQL_KEYWORD('REFERENCES') + ' ' + \ - style.SQL_TABLE(qn(f.rel.to._meta.db_table)) + ' (' + \ - style.SQL_FIELD(qn(f.rel.to._meta.get_field(f.rel.field_name).column)) + ')' + - connection.ops.deferrable_sql() - ) - else: - # We haven't yet created the table to which this field - # is related, so save it for later. - pr = pending_references.setdefault(f.rel.to, []).append((model, f)) - table_output.append(' '.join(field_output)) - if opts.order_with_respect_to: - table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \ - style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \ - style.SQL_KEYWORD('NULL')) - for field_constraints in opts.unique_together: - table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \ - ", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints])) - - full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' ('] - for i, line in enumerate(table_output): # Combine and add commas. - full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or '')) - full_statement.append(')') - if opts.db_tablespace and connection.features.supports_tablespaces: - full_statement.append(connection.ops.tablespace_sql(opts.db_tablespace)) - full_statement.append(';') - final_output.append('\n'.join(full_statement)) - - if opts.has_auto_field: - # Add any extra SQL needed to support auto-incrementing primary keys. - auto_column = opts.auto_field.db_column or opts.auto_field.name - autoinc_sql = connection.ops.autoinc_sql(opts.db_table, auto_column) - if autoinc_sql: - for stmt in autoinc_sql: - final_output.append(stmt) - - return final_output, pending_references - -def sql_for_pending_references(model, style, pending_references): - """ - Returns any ALTER TABLE statements to add constraints after the fact. - """ - from django.db import connection - from django.db.backends.util import truncate_name - - qn = connection.ops.quote_name - final_output = [] - if connection.features.supports_constraints: - opts = model._meta - if model in pending_references: - for rel_class, f in pending_references[model]: - rel_opts = rel_class._meta - r_table = rel_opts.db_table - r_col = f.column - table = opts.db_table - col = opts.get_field(f.rel.field_name).column - # For MySQL, r_name must be unique in the first 64 characters. - # So we are careful with character usage here. - r_name = '%s_refs_%s_%x' % (r_col, col, abs(hash((r_table, table)))) - final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \ - (qn(r_table), truncate_name(r_name, connection.ops.max_name_length()), - qn(r_col), qn(table), qn(col), - connection.ops.deferrable_sql())) - del pending_references[model] - return final_output - -def many_to_many_sql_for_model(model, style): - from django.db import connection, models - from django.contrib.contenttypes import generic - from django.db.backends.util import truncate_name - - opts = model._meta - final_output = [] - qn = connection.ops.quote_name - inline_references = connection.features.inline_fk_references - for f in opts.local_many_to_many: - if f.creates_table: - tablespace = f.db_tablespace or opts.db_tablespace - if tablespace and connection.features.supports_tablespaces: - tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True) - else: - tablespace_sql = '' - table_output = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + \ - style.SQL_TABLE(qn(f.m2m_db_table())) + ' ('] - table_output.append(' %s %s %s%s,' % - (style.SQL_FIELD(qn('id')), - style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type()), - style.SQL_KEYWORD('NOT NULL PRIMARY KEY'), - tablespace_sql)) - if inline_references: - deferred = [] - table_output.append(' %s %s %s %s (%s)%s,' % - (style.SQL_FIELD(qn(f.m2m_column_name())), - style.SQL_COLTYPE(models.ForeignKey(model).db_type()), - style.SQL_KEYWORD('NOT NULL REFERENCES'), - style.SQL_TABLE(qn(opts.db_table)), - style.SQL_FIELD(qn(opts.pk.column)), - connection.ops.deferrable_sql())) - table_output.append(' %s %s %s %s (%s)%s,' % - (style.SQL_FIELD(qn(f.m2m_reverse_name())), - style.SQL_COLTYPE(models.ForeignKey(f.rel.to).db_type()), - style.SQL_KEYWORD('NOT NULL REFERENCES'), - style.SQL_TABLE(qn(f.rel.to._meta.db_table)), - style.SQL_FIELD(qn(f.rel.to._meta.pk.column)), - connection.ops.deferrable_sql())) - else: - table_output.append(' %s %s %s,' % - (style.SQL_FIELD(qn(f.m2m_column_name())), - style.SQL_COLTYPE(models.ForeignKey(model).db_type()), - style.SQL_KEYWORD('NOT NULL'))) - table_output.append(' %s %s %s,' % - (style.SQL_FIELD(qn(f.m2m_reverse_name())), - style.SQL_COLTYPE(models.ForeignKey(f.rel.to).db_type()), - style.SQL_KEYWORD('NOT NULL'))) - deferred = [ - (f.m2m_db_table(), f.m2m_column_name(), opts.db_table, - opts.pk.column), - ( f.m2m_db_table(), f.m2m_reverse_name(), - f.rel.to._meta.db_table, f.rel.to._meta.pk.column) - ] - table_output.append(' %s (%s, %s)%s' % - (style.SQL_KEYWORD('UNIQUE'), - style.SQL_FIELD(qn(f.m2m_column_name())), - style.SQL_FIELD(qn(f.m2m_reverse_name())), - tablespace_sql)) - table_output.append(')') - if opts.db_tablespace and connection.features.supports_tablespaces: - # f.db_tablespace is only for indices, so ignore its value here. - table_output.append(connection.ops.tablespace_sql(opts.db_tablespace)) - table_output.append(';') - final_output.append('\n'.join(table_output)) - - for r_table, r_col, table, col in deferred: - r_name = '%s_refs_%s_%x' % (r_col, col, - abs(hash((r_table, table)))) - final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % - (qn(r_table), - truncate_name(r_name, connection.ops.max_name_length()), - qn(r_col), qn(table), qn(col), - connection.ops.deferrable_sql())) - - # Add any extra SQL needed to support auto-incrementing PKs - autoinc_sql = connection.ops.autoinc_sql(f.m2m_db_table(), 'id') - if autoinc_sql: - for stmt in autoinc_sql: - final_output.append(stmt) - - return final_output - def custom_sql_for_model(model, style): from django.db import models from django.conf import settings @@ -461,28 +193,6 @@ def custom_sql_for_model(model, style): return output -def sql_indexes_for_model(model, style): - "Returns the CREATE INDEX SQL statements for a single model" - from django.db import connection - output = [] - - qn = connection.ops.quote_name - for f in model._meta.local_fields: - if f.db_index and not f.unique: - tablespace = f.db_tablespace or model._meta.db_tablespace - if tablespace and connection.features.supports_tablespaces: - tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace) - else: - tablespace_sql = '' - output.append( - style.SQL_KEYWORD('CREATE INDEX') + ' ' + \ - style.SQL_TABLE(qn('%s_%s' % (model._meta.db_table, f.column))) + ' ' + \ - style.SQL_KEYWORD('ON') + ' ' + \ - style.SQL_TABLE(qn(model._meta.db_table)) + ' ' + \ - "(%s)" % style.SQL_FIELD(qn(f.column)) + \ - "%s;" % tablespace_sql - ) - return output def emit_post_sync_signal(created_models, verbosity, interactive): from django.db import models diff --git a/django/core/management/validation.py b/django/core/management/validation.py index e9d7b53027..e561483af0 100644 --- a/django/core/management/validation.py +++ b/django/core/management/validation.py @@ -61,11 +61,8 @@ def get_validation_errors(outfile, app=None): if f.db_index not in (None, True, False): e.add(opts, '"%s": "db_index" should be either None, True or False.' % f.name) - # Check that max_length <= 255 if using older MySQL versions. - if settings.DATABASE_ENGINE == 'mysql': - db_version = connection.get_server_version() - if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255: - e.add(opts, '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' % (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]]))) + # Perform any backend-specific field validation. + connection.validation.validate_field(e, opts, f) # Check to see if the related field will clash with any existing # fields, m2m fields, m2m related objects or related objects diff --git a/django/db/__init__.py b/django/db/__init__.py index 37c8837b8b..73e97a481f 100644 --- a/django/db/__init__.py +++ b/django/db/__init__.py @@ -14,14 +14,12 @@ try: # backends that ships with Django, so look there first. _import_path = 'django.db.backends.' backend = __import__('%s%s.base' % (_import_path, settings.DATABASE_ENGINE), {}, {}, ['']) - creation = __import__('%s%s.creation' % (_import_path, settings.DATABASE_ENGINE), {}, {}, ['']) except ImportError, e: # If the import failed, we might be looking for a database backend # distributed external to Django. So we'll try that next. try: _import_path = '' backend = __import__('%s.base' % settings.DATABASE_ENGINE, {}, {}, ['']) - creation = __import__('%s.creation' % settings.DATABASE_ENGINE, {}, {}, ['']) except ImportError, e_user: # The database backend wasn't found. Display a helpful error message # listing all possible (built-in) database backends. @@ -29,27 +27,11 @@ except ImportError, e: available_backends = [f for f in os.listdir(backend_dir) if not f.startswith('_') and not f.startswith('.') and not f.endswith('.py') and not f.endswith('.pyc')] available_backends.sort() if settings.DATABASE_ENGINE not in available_backends: - raise ImproperlyConfigured, "%r isn't an available database backend. Available options are: %s" % \ - (settings.DATABASE_ENGINE, ", ".join(map(repr, available_backends))) + raise ImproperlyConfigured, "%r isn't an available database backend. Available options are: %s\nError was: %s" % \ + (settings.DATABASE_ENGINE, ", ".join(map(repr, available_backends, e_user))) else: raise # If there's some other error, this must be an error in Django itself. -def _import_database_module(import_path='', module_name=''): - """Lazily import a database module when requested.""" - return __import__('%s%s.%s' % (import_path, settings.DATABASE_ENGINE, module_name), {}, {}, ['']) - -# We don't want to import the introspect module unless someone asks for it, so -# lazily load it on demmand. -get_introspection_module = curry(_import_database_module, _import_path, 'introspection') - -def get_creation_module(): - return creation - -# We want runshell() to work the same way, but we have to treat it a -# little differently (since it just runs instead of returning a module like -# the above) and wrap the lazily-loaded runshell() method. -runshell = lambda: _import_database_module(_import_path, "client").runshell() - # Convenient aliases for backend bits. connection = backend.DatabaseWrapper(**settings.DATABASE_OPTIONS) DatabaseError = backend.DatabaseError diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index d65eacd042..3748fb4b4f 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -42,14 +42,9 @@ class BaseDatabaseWrapper(local): return util.CursorDebugWrapper(cursor, self) class BaseDatabaseFeatures(object): - allows_group_by_ordinal = True - inline_fk_references = True # True if django.db.backend.utils.typecast_timestamp is used on values # returned from dates() calls. needs_datetime_string_cast = True - supports_constraints = True - supports_tablespaces = False - uses_case_insensitive_names = False uses_custom_query_class = False empty_fetchmany_value = [] update_can_self_select = True @@ -253,13 +248,13 @@ class BaseDatabaseOperations(object): """ return "BEGIN;" - def tablespace_sql(self, tablespace, inline=False): + def sql_for_tablespace(self, tablespace, inline=False): """ - Returns the tablespace SQL, or None if the backend doesn't use - tablespaces. + Returns the SQL that will be appended to tables or rows to define + a tablespace. Returns '' if the backend doesn't use tablespaces. """ - return None - + return '' + def prep_for_like_query(self, x): """Prepares a value for use in a LIKE query.""" from django.utils.encoding import smart_unicode @@ -325,3 +320,89 @@ class BaseDatabaseOperations(object): """ return self.year_lookup_bounds(value) +class BaseDatabaseIntrospection(object): + """ + This class encapsulates all backend-specific introspection utilities + """ + data_types_reverse = {} + + def __init__(self, connection): + self.connection = connection + + def table_name_converter(self, name): + """Apply a conversion to the name for the purposes of comparison. + + The default table name converter is for case sensitive comparison. + """ + return name + + def table_names(self): + "Returns a list of names of all tables that exist in the database." + cursor = self.connection.cursor() + return self.get_table_list(cursor) + + def django_table_names(self, only_existing=False): + """ + Returns a list of all table names that have associated Django models and + are in INSTALLED_APPS. + + If only_existing is True, the resulting list will only include the tables + that actually exist in the database. + """ + from django.db import models + tables = set() + for app in models.get_apps(): + for model in models.get_models(app): + tables.add(model._meta.db_table) + tables.update([f.m2m_db_table() for f in model._meta.local_many_to_many]) + if only_existing: + tables = [t for t in tables if t in self.table_names()] + return tables + + def installed_models(self, tables): + "Returns a set of all models represented by the provided list of table names." + from django.db import models + all_models = [] + for app in models.get_apps(): + for model in models.get_models(app): + all_models.append(model) + return set([m for m in all_models + if self.table_name_converter(m._meta.db_table) in map(self.table_name_converter, tables) + ]) + + def sequence_list(self): + "Returns a list of information about all DB sequences for all models in all apps." + from django.db import models + + apps = models.get_apps() + sequence_list = [] + + for app in apps: + for model in models.get_models(app): + for f in model._meta.local_fields: + if isinstance(f, models.AutoField): + sequence_list.append({'table': model._meta.db_table, 'column': f.column}) + break # Only one AutoField is allowed per model, so don't bother continuing. + + for f in model._meta.local_many_to_many: + sequence_list.append({'table': f.m2m_db_table(), 'column': None}) + + return sequence_list + + +class BaseDatabaseClient(object): + """ + This class encapsualtes all backend-specific methods for opening a + client shell + """ + def runshell(self): + raise NotImplementedError() + +class BaseDatabaseValidation(object): + """ + This class encapsualtes all backend-specific model validation. + """ + def validate_field(self, errors, opts, f): + "By default, there is no backend-specific validation" + pass + diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index 4071cef6aa..a462be8251 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -1,7 +1,396 @@ -class BaseCreation(object): +import sys +import time + +from django.conf import settings +from django.core.management import call_command + +# The prefix to put on the default database name when creating +# the test database. +TEST_DATABASE_PREFIX = 'test_' + +class BaseDatabaseCreation(object): """ This class encapsulates all backend-specific differences that pertain to database *creation*, such as the column types to use for particular Django - Fields. + Fields, the SQL used to create and destroy tables, and the creation and + destruction of test databases. """ - pass + data_types = {} + + def __init__(self, connection): + self.connection = connection + + def sql_create_model(self, model, style, known_models=set()): + """ + Returns the SQL required to create a single model, as a tuple of: + (list_of_sql, pending_references_dict) + """ + from django.db import models + + opts = model._meta + final_output = [] + table_output = [] + pending_references = {} + qn = self.connection.ops.quote_name + for f in opts.local_fields: + col_type = f.db_type() + tablespace = f.db_tablespace or opts.db_tablespace + if col_type is None: + # Skip ManyToManyFields, because they're not represented as + # database columns in this table. + continue + # Make the definition (e.g. 'foo VARCHAR(30)') for this field. + field_output = [style.SQL_FIELD(qn(f.column)), + style.SQL_COLTYPE(col_type)] + field_output.append(style.SQL_KEYWORD('%sNULL' % (not f.null and 'NOT ' or ''))) + if f.primary_key: + field_output.append(style.SQL_KEYWORD('PRIMARY KEY')) + elif f.unique: + field_output.append(style.SQL_KEYWORD('UNIQUE')) + if tablespace and f.unique: + # We must specify the index tablespace inline, because we + # won't be generating a CREATE INDEX statement for this field. + field_output.append(self.connection.ops.tablespace_sql(tablespace, inline=True)) + if f.rel: + ref_output, pending = self.sql_for_inline_foreign_key_references(f, known_models, style) + if pending: + pr = pending_references.setdefault(f.rel.to, []).append((model, f)) + else: + field_output.extend(ref_output) + table_output.append(' '.join(field_output)) + if opts.order_with_respect_to: + table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \ + style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \ + style.SQL_KEYWORD('NULL')) + for field_constraints in opts.unique_together: + table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \ + ", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints])) + + full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' ('] + for i, line in enumerate(table_output): # Combine and add commas. + full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or '')) + full_statement.append(')') + if opts.db_tablespace: + full_statement.append(self.connection.ops.tablespace_sql(opts.db_tablespace)) + full_statement.append(';') + final_output.append('\n'.join(full_statement)) + + if opts.has_auto_field: + # Add any extra SQL needed to support auto-incrementing primary keys. + auto_column = opts.auto_field.db_column or opts.auto_field.name + autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, auto_column) + if autoinc_sql: + for stmt in autoinc_sql: + final_output.append(stmt) + + return final_output, pending_references + + def sql_for_inline_foreign_key_references(self, field, known_models, style): + "Return the SQL snippet defining the foreign key reference for a field" + qn = self.connection.ops.quote_name + if field.rel.to in known_models: + output = [style.SQL_KEYWORD('REFERENCES') + ' ' + \ + style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' + \ + style.SQL_FIELD(qn(field.rel.to._meta.get_field(field.rel.field_name).column)) + ')' + + self.connection.ops.deferrable_sql() + ] + pending = False + else: + # We haven't yet created the table to which this field + # is related, so save it for later. + output = [] + pending = True + + return output, pending + + def sql_for_pending_references(self, model, style, pending_references): + "Returns any ALTER TABLE statements to add constraints after the fact." + from django.db.backends.util import truncate_name + + qn = self.connection.ops.quote_name + final_output = [] + opts = model._meta + if model in pending_references: + for rel_class, f in pending_references[model]: + rel_opts = rel_class._meta + r_table = rel_opts.db_table + r_col = f.column + table = opts.db_table + col = opts.get_field(f.rel.field_name).column + # For MySQL, r_name must be unique in the first 64 characters. + # So we are careful with character usage here. + r_name = '%s_refs_%s_%x' % (r_col, col, abs(hash((r_table, table)))) + final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \ + (qn(r_table), truncate_name(r_name, self.connection.ops.max_name_length()), + qn(r_col), qn(table), qn(col), + self.connection.ops.deferrable_sql())) + del pending_references[model] + return final_output + + def sql_for_many_to_many(self, model, style): + "Return the CREATE TABLE statments for all the many-to-many tables defined on a model" + output = [] + for f in model._meta.local_many_to_many: + output.extend(self.sql_for_many_to_many_field(model, f, style)) + return output + + def sql_for_many_to_many_field(self, model, f, style): + "Return the CREATE TABLE statements for a single m2m field" + from django.db import models + from django.db.backends.util import truncate_name + + output = [] + if f.creates_table: + opts = model._meta + qn = self.connection.ops.quote_name + tablespace = f.db_tablespace or opts.db_tablespace + if tablespace: + sql = self.connection.ops.tablespace_sql(tablespace, inline=True) + if sql: + tablespace_sql = ' ' + sql + else: + tablespace_sql = '' + else: + tablespace_sql = '' + table_output = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + \ + style.SQL_TABLE(qn(f.m2m_db_table())) + ' ('] + table_output.append(' %s %s %s%s,' % + (style.SQL_FIELD(qn('id')), + style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type()), + style.SQL_KEYWORD('NOT NULL PRIMARY KEY'), + tablespace_sql)) + + deferred = [] + inline_output, deferred = self.sql_for_inline_many_to_many_references(model, f, style) + table_output.extend(inline_output) + + table_output.append(' %s (%s, %s)%s' % + (style.SQL_KEYWORD('UNIQUE'), + style.SQL_FIELD(qn(f.m2m_column_name())), + style.SQL_FIELD(qn(f.m2m_reverse_name())), + tablespace_sql)) + table_output.append(')') + if opts.db_tablespace: + # f.db_tablespace is only for indices, so ignore its value here. + table_output.append(self.connection.ops.tablespace_sql(opts.db_tablespace)) + table_output.append(';') + output.append('\n'.join(table_output)) + + for r_table, r_col, table, col in deferred: + r_name = '%s_refs_%s_%x' % (r_col, col, + abs(hash((r_table, table)))) + output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % + (qn(r_table), + truncate_name(r_name, self.connection.ops.max_name_length()), + qn(r_col), qn(table), qn(col), + self.connection.ops.deferrable_sql())) + + # Add any extra SQL needed to support auto-incrementing PKs + autoinc_sql = self.connection.ops.autoinc_sql(f.m2m_db_table(), 'id') + if autoinc_sql: + for stmt in autoinc_sql: + output.append(stmt) + return output + + def sql_for_inline_many_to_many_references(self, model, field, style): + "Create the references to other tables required by a many-to-many table" + from django.db import models + opts = model._meta + qn = self.connection.ops.quote_name + + table_output = [ + ' %s %s %s %s (%s)%s,' % + (style.SQL_FIELD(qn(field.m2m_column_name())), + style.SQL_COLTYPE(models.ForeignKey(model).db_type()), + style.SQL_KEYWORD('NOT NULL REFERENCES'), + style.SQL_TABLE(qn(opts.db_table)), + style.SQL_FIELD(qn(opts.pk.column)), + self.connection.ops.deferrable_sql()), + ' %s %s %s %s (%s)%s,' % + (style.SQL_FIELD(qn(field.m2m_reverse_name())), + style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type()), + style.SQL_KEYWORD('NOT NULL REFERENCES'), + style.SQL_TABLE(qn(field.rel.to._meta.db_table)), + style.SQL_FIELD(qn(field.rel.to._meta.pk.column)), + self.connection.ops.deferrable_sql()) + ] + deferred = [] + + return table_output, deferred + + def sql_indexes_for_model(self, model, style): + "Returns the CREATE INDEX SQL statements for a single model" + output = [] + for f in model._meta.local_fields: + output.extend(self.sql_indexes_for_field(model, f, style)) + return output + + def sql_indexes_for_field(self, model, f, style): + "Return the CREATE INDEX SQL statements for a single model field" + if f.db_index and not f.unique: + qn = self.connection.ops.quote_name + tablespace = f.db_tablespace or model._meta.db_tablespace + if tablespace: + sql = self.connection.ops.tablespace_sql(tablespace) + if sql: + tablespace_sql = ' ' + sql + else: + tablespace_sql = '' + else: + tablespace_sql = '' + output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' + + style.SQL_TABLE(qn('%s_%s' % (model._meta.db_table, f.column))) + ' ' + + style.SQL_KEYWORD('ON') + ' ' + + style.SQL_TABLE(qn(model._meta.db_table)) + ' ' + + "(%s)" % style.SQL_FIELD(qn(f.column)) + + "%s;" % tablespace_sql] + else: + output = [] + return output + + def sql_destroy_model(self, model, references_to_delete, style): + "Return the DROP TABLE and restraint dropping statements for a single model" + # Drop the table now + qn = self.connection.ops.quote_name + output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'), + style.SQL_TABLE(qn(model._meta.db_table)))] + if model in references_to_delete: + output.extend(self.sql_remove_table_constraints(model, references_to_delete)) + + if model._meta.has_auto_field: + ds = self.connection.ops.drop_sequence_sql(model._meta.db_table) + if ds: + output.append(ds) + return output + + def sql_remove_table_constraints(self, model, references_to_delete): + output = [] + for rel_class, f in references_to_delete[model]: + table = rel_class._meta.db_table + col = f.column + r_table = model._meta.db_table + r_col = model._meta.get_field(f.rel.field_name).column + r_name = '%s_refs_%s_%x' % (col, r_col, abs(hash((table, r_table)))) + output.append('%s %s %s %s;' % \ + (style.SQL_KEYWORD('ALTER TABLE'), + style.SQL_TABLE(qn(table)), + style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()), + style.SQL_FIELD(truncate_name(r_name, self.connection.ops.max_name_length())))) + del references_to_delete[model] + return output + + def sql_destroy_many_to_many(self, model, f, style): + "Returns the DROP TABLE statements for a single m2m field" + qn = self.connection.ops.quote_name + output = [] + if f.creates_table: + output.append("%s %s;" % (style.SQL_KEYWORD('DROP TABLE'), + style.SQL_TABLE(qn(f.m2m_db_table())))) + ds = self.connection.ops.drop_sequence_sql("%s_%s" % (model._meta.db_table, f.column)) + if ds: + output.append(ds) + return output + + def create_test_db(self, verbosity=1, autoclobber=False): + """ + Creates a test database, prompting the user for confirmation if the + database already exists. Returns the name of the test database created. + """ + if verbosity >= 1: + print "Creating test database..." + + test_database_name = self._create_test_db(verbosity, autoclobber) + + self.connection.close() + settings.DATABASE_NAME = test_database_name + + call_command('syncdb', verbosity=verbosity, interactive=False) + + if settings.CACHE_BACKEND.startswith('db://'): + cache_name = settings.CACHE_BACKEND[len('db://'):] + call_command('createcachetable', cache_name) + + # Get a cursor (even though we don't need one yet). This has + # the side effect of initializing the test database. + cursor = self.connection.cursor() + + return test_database_name + + def _create_test_db(self, verbosity, autoclobber): + "Internal implementation - creates the test db tables." + suffix = self.sql_table_creation_suffix() + + if settings.TEST_DATABASE_NAME: + test_database_name = settings.TEST_DATABASE_NAME + else: + test_database_name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + + qn = self.connection.ops.quote_name + + # Create the test database and connect to it. We need to autocommit + # if the database supports it because PostgreSQL doesn't allow + # CREATE/DROP DATABASE statements within transactions. + cursor = self.connection.cursor() + self.set_autocommit() + try: + cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) + except Exception, e: + sys.stderr.write("Got an error creating the test database: %s\n" % e) + if not autoclobber: + confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name) + if autoclobber or confirm == 'yes': + try: + if verbosity >= 1: + print "Destroying old test database..." + cursor.execute("DROP DATABASE %s" % qn(test_database_name)) + if verbosity >= 1: + print "Creating test database..." + cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix)) + except Exception, e: + sys.stderr.write("Got an error recreating the test database: %s\n" % e) + sys.exit(2) + else: + print "Tests cancelled." + sys.exit(1) + + return test_database_name + + def destroy_test_db(self, old_database_name, verbosity=1): + """ + Destroy a test database, prompting the user for confirmation if the + database already exists. Returns the name of the test database created. + """ + if verbosity >= 1: + print "Destroying test database..." + self.connection.close() + test_database_name = settings.DATABASE_NAME + settings.DATABASE_NAME = old_database_name + + self._destroy_test_db(test_database_name, verbosity) + + def _destroy_test_db(self, test_database_name, verbosity): + "Internal implementation - remove the test db tables." + # Remove the test database to clean up after + # ourselves. Connect to the previous database (not the test database) + # to do so, because it's not allowed to delete a database while being + # connected to it. + cursor = self.connection.cursor() + self.set_autocommit() + time.sleep(1) # To avoid "database is being accessed by other users" errors. + cursor.execute("DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)) + self.connection.close() + + def set_autocommit(self): + "Make sure a connection is in autocommit mode." + if hasattr(self.connection.connection, "autocommit"): + if callable(self.connection.connection.autocommit): + self.connection.connection.autocommit(True) + else: + self.connection.connection.autocommit = True + elif hasattr(self.connection.connection, "set_isolation_level"): + self.connection.connection.set_isolation_level(0) + + def sql_table_creation_suffix(self): + "SQL to append to the end of the test table creation statements" + return '' + diff --git a/django/db/backends/dummy/base.py b/django/db/backends/dummy/base.py index fd25d3038f..530ea9c519 100644 --- a/django/db/backends/dummy/base.py +++ b/django/db/backends/dummy/base.py @@ -8,7 +8,8 @@ ImproperlyConfigured. """ from django.core.exceptions import ImproperlyConfigured -from django.db.backends import BaseDatabaseFeatures, BaseDatabaseOperations +from django.db.backends import * +from django.db.backends.creation import BaseDatabaseCreation def complain(*args, **kwargs): raise ImproperlyConfigured, "You haven't set the DATABASE_ENGINE setting yet." @@ -25,16 +26,30 @@ class IntegrityError(DatabaseError): class DatabaseOperations(BaseDatabaseOperations): quote_name = complain -class DatabaseWrapper(object): - features = BaseDatabaseFeatures() - ops = DatabaseOperations() +class DatabaseClient(BaseDatabaseClient): + runshell = complain + +class DatabaseIntrospection(BaseDatabaseIntrospection): + get_table_list = complain + get_table_description = complain + get_relations = complain + get_indexes = complain + +class DatabaseWrapper(object): operators = {} cursor = complain _commit = complain _rollback = ignore - def __init__(self, **kwargs): - pass + def __init__(self, *args, **kwargs): + super(DatabaseWrapper, self).__init__(*args, **kwargs) + + self.features = BaseDatabaseFeatures() + self.ops = DatabaseOperations() + self.client = DatabaseClient() + self.creation = BaseDatabaseCreation(self) + self.introspection = DatabaseIntrospection(self) + self.validation = BaseDatabaseValidation() def close(self): pass diff --git a/django/db/backends/dummy/client.py b/django/db/backends/dummy/client.py index e332987aa8..e69de29bb2 100644 --- a/django/db/backends/dummy/client.py +++ b/django/db/backends/dummy/client.py @@ -1,3 +0,0 @@ -from django.db.backends.dummy.base import complain - -runshell = complain diff --git a/django/db/backends/dummy/creation.py b/django/db/backends/dummy/creation.py index b82c4fe568..e69de29bb2 100644 --- a/django/db/backends/dummy/creation.py +++ b/django/db/backends/dummy/creation.py @@ -1 +0,0 @@ -DATA_TYPES = {} diff --git a/django/db/backends/dummy/introspection.py b/django/db/backends/dummy/introspection.py index c52a812046..e69de29bb2 100644 --- a/django/db/backends/dummy/introspection.py +++ b/django/db/backends/dummy/introspection.py @@ -1,8 +0,0 @@ -from django.db.backends.dummy.base import complain - -get_table_list = complain -get_table_description = complain -get_relations = complain -get_indexes = complain - -DATA_TYPES_REVERSE = {} diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 3b8d897925..a73c740fdf 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -4,7 +4,12 @@ MySQL database backend for Django. Requires MySQLdb: http://sourceforge.net/projects/mysql-python """ -from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util +from django.db.backends import * +from django.db.backends.mysql.client import DatabaseClient +from django.db.backends.mysql.creation import DatabaseCreation +from django.db.backends.mysql.introspection import DatabaseIntrospection +from django.db.backends.mysql.validation import DatabaseValidation + try: import MySQLdb as Database except ImportError, e: @@ -60,7 +65,6 @@ server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})') # TRADITIONAL will automatically cause most warnings to be treated as errors. class DatabaseFeatures(BaseDatabaseFeatures): - inline_fk_references = False empty_fetchmany_value = () update_can_self_select = False @@ -142,8 +146,7 @@ class DatabaseOperations(BaseDatabaseOperations): return [first % value, second % value] class DatabaseWrapper(BaseDatabaseWrapper): - features = DatabaseFeatures() - ops = DatabaseOperations() + operators = { 'exact': '= BINARY %s', 'iexact': 'LIKE %s', @@ -164,6 +167,13 @@ class DatabaseWrapper(BaseDatabaseWrapper): def __init__(self, **kwargs): super(DatabaseWrapper, self).__init__(**kwargs) self.server_version = None + + self.features = DatabaseFeatures() + self.ops = DatabaseOperations() + self.client = DatabaseClient() + self.creation = DatabaseCreation(self) + self.introspection = DatabaseIntrospection(self) + self.validation = DatabaseValidation() def _valid_connection(self): if self.connection is not None: diff --git a/django/db/backends/mysql/client.py b/django/db/backends/mysql/client.py index 116074a9ce..24758867af 100644 --- a/django/db/backends/mysql/client.py +++ b/django/db/backends/mysql/client.py @@ -1,27 +1,29 @@ +from django.db.backends import BaseDatabaseClient from django.conf import settings import os -def runshell(): - args = [''] - db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME) - user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER) - passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD) - host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST) - port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT) - defaults_file = settings.DATABASE_OPTIONS.get('read_default_file') - # Seems to be no good way to set sql_mode with CLI +class DatabaseClient(BaseDatabaseClient): + def runshell(self): + args = [''] + db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME) + user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER) + passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD) + host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST) + port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT) + defaults_file = settings.DATABASE_OPTIONS.get('read_default_file') + # Seems to be no good way to set sql_mode with CLI - if defaults_file: - args += ["--defaults-file=%s" % defaults_file] - if user: - args += ["--user=%s" % user] - if passwd: - args += ["--password=%s" % passwd] - if host: - args += ["--host=%s" % host] - if port: - args += ["--port=%s" % port] - if db: - args += [db] + if defaults_file: + args += ["--defaults-file=%s" % defaults_file] + if user: + args += ["--user=%s" % user] + if passwd: + args += ["--password=%s" % passwd] + if host: + args += ["--host=%s" % host] + if port: + args += ["--port=%s" % port] + if db: + args += [db] - os.execvp('mysql', args) + os.execvp('mysql', args) diff --git a/django/db/backends/mysql/creation.py b/django/db/backends/mysql/creation.py index 698b07548d..96faaacb75 100644 --- a/django/db/backends/mysql/creation.py +++ b/django/db/backends/mysql/creation.py @@ -1,28 +1,68 @@ -# This dictionary maps Field objects to their associated MySQL column -# types, as strings. Column-type strings can contain format strings; they'll -# be interpolated against the values of Field.__dict__ before being output. -# If a column type is set to None, it won't be included in the output. -DATA_TYPES = { - 'AutoField': 'integer AUTO_INCREMENT', - 'BooleanField': 'bool', - 'CharField': 'varchar(%(max_length)s)', - 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'datetime', - 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'double precision', - 'IntegerField': 'integer', - 'IPAddressField': 'char(15)', - 'NullBooleanField': 'bool', - 'OneToOneField': 'integer', - 'PhoneNumberField': 'varchar(20)', - 'PositiveIntegerField': 'integer UNSIGNED', - 'PositiveSmallIntegerField': 'smallint UNSIGNED', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallIntegerField': 'smallint', - 'TextField': 'longtext', - 'TimeField': 'time', - 'USStateField': 'varchar(2)', -} +from django.conf import settings +from django.db.backends.creation import BaseDatabaseCreation + +class DatabaseCreation(BaseDatabaseCreation): + # This dictionary maps Field objects to their associated MySQL column + # types, as strings. Column-type strings can contain format strings; they'll + # be interpolated against the values of Field.__dict__ before being output. + # If a column type is set to None, it won't be included in the output. + data_types = { + 'AutoField': 'integer AUTO_INCREMENT', + 'BooleanField': 'bool', + 'CharField': 'varchar(%(max_length)s)', + 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', + 'DateField': 'date', + 'DateTimeField': 'datetime', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FileField': 'varchar(%(max_length)s)', + 'FilePathField': 'varchar(%(max_length)s)', + 'FloatField': 'double precision', + 'IntegerField': 'integer', + 'IPAddressField': 'char(15)', + 'NullBooleanField': 'bool', + 'OneToOneField': 'integer', + 'PhoneNumberField': 'varchar(20)', + 'PositiveIntegerField': 'integer UNSIGNED', + 'PositiveSmallIntegerField': 'smallint UNSIGNED', + 'SlugField': 'varchar(%(max_length)s)', + 'SmallIntegerField': 'smallint', + 'TextField': 'longtext', + 'TimeField': 'time', + 'USStateField': 'varchar(2)', + } + + def sql_table_creation_suffix(self): + suffix = [] + if settings.TEST_DATABASE_CHARSET: + suffix.append('CHARACTER SET %s' % settings.TEST_DATABASE_CHARSET) + if settings.TEST_DATABASE_COLLATION: + suffix.append('COLLATE %s' % settings.TEST_DATABASE_COLLATION) + return ' '.join(suffix) + + def sql_for_inline_foreign_key_references(self, field, known_models, style): + "All inline references are pending under MySQL" + return [], True + + def sql_for_inline_many_to_many_references(self, model, field, style): + from django.db import models + opts = model._meta + qn = self.connection.ops.quote_name + + table_output = [ + ' %s %s %s,' % + (style.SQL_FIELD(qn(field.m2m_column_name())), + style.SQL_COLTYPE(models.ForeignKey(model).db_type()), + style.SQL_KEYWORD('NOT NULL')), + ' %s %s %s,' % + (style.SQL_FIELD(qn(field.m2m_reverse_name())), + style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type()), + style.SQL_KEYWORD('NOT NULL')) + ] + deferred = [ + (field.m2m_db_table(), field.m2m_column_name(), opts.db_table, + opts.pk.column), + (field.m2m_db_table(), field.m2m_reverse_name(), + field.rel.to._meta.db_table, field.rel.to._meta.pk.column) + ] + return table_output, deferred + \ No newline at end of file diff --git a/django/db/backends/mysql/introspection.py b/django/db/backends/mysql/introspection.py index 68fe5d99e6..0aa635a8d7 100644 --- a/django/db/backends/mysql/introspection.py +++ b/django/db/backends/mysql/introspection.py @@ -1,96 +1,97 @@ -from django.db.backends.mysql.base import DatabaseOperations +from django.db.backends import BaseDatabaseIntrospection from MySQLdb import ProgrammingError, OperationalError from MySQLdb.constants import FIELD_TYPE import re -quote_name = DatabaseOperations().quote_name foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)") -def get_table_list(cursor): - "Returns a list of table names in the current database." - cursor.execute("SHOW TABLES") - return [row[0] for row in cursor.fetchall()] +class DatabaseIntrospection(BaseDatabaseIntrospection): + data_types_reverse = { + FIELD_TYPE.BLOB: 'TextField', + FIELD_TYPE.CHAR: 'CharField', + FIELD_TYPE.DECIMAL: 'DecimalField', + FIELD_TYPE.DATE: 'DateField', + FIELD_TYPE.DATETIME: 'DateTimeField', + FIELD_TYPE.DOUBLE: 'FloatField', + FIELD_TYPE.FLOAT: 'FloatField', + FIELD_TYPE.INT24: 'IntegerField', + FIELD_TYPE.LONG: 'IntegerField', + FIELD_TYPE.LONGLONG: 'IntegerField', + FIELD_TYPE.SHORT: 'IntegerField', + FIELD_TYPE.STRING: 'CharField', + FIELD_TYPE.TIMESTAMP: 'DateTimeField', + FIELD_TYPE.TINY: 'IntegerField', + FIELD_TYPE.TINY_BLOB: 'TextField', + FIELD_TYPE.MEDIUM_BLOB: 'TextField', + FIELD_TYPE.LONG_BLOB: 'TextField', + FIELD_TYPE.VAR_STRING: 'CharField', + } -def get_table_description(cursor, table_name): - "Returns a description of the table, with the DB-API cursor.description interface." - cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name)) - return cursor.description + def get_table_list(self, cursor): + "Returns a list of table names in the current database." + cursor.execute("SHOW TABLES") + return [row[0] for row in cursor.fetchall()] -def _name_to_index(cursor, table_name): - """ - Returns a dictionary of {field_name: field_index} for the given table. - Indexes are 0-based. - """ - return dict([(d[0], i) for i, d in enumerate(get_table_description(cursor, table_name))]) + def get_table_description(self, cursor, table_name): + "Returns a description of the table, with the DB-API cursor.description interface." + cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) + return cursor.description -def get_relations(cursor, table_name): - """ - Returns a dictionary of {field_index: (field_index_other_table, other_table)} - representing all relationships to the given table. Indexes are 0-based. - """ - my_field_dict = _name_to_index(cursor, table_name) - constraints = [] - relations = {} - try: - # This should work for MySQL 5.0. - cursor.execute(""" - SELECT column_name, referenced_table_name, referenced_column_name - FROM information_schema.key_column_usage - WHERE table_name = %s - AND table_schema = DATABASE() - AND referenced_table_name IS NOT NULL - AND referenced_column_name IS NOT NULL""", [table_name]) - constraints.extend(cursor.fetchall()) - except (ProgrammingError, OperationalError): - # Fall back to "SHOW CREATE TABLE", for previous MySQL versions. - # Go through all constraints and save the equal matches. - cursor.execute("SHOW CREATE TABLE %s" % quote_name(table_name)) + def _name_to_index(self, cursor, table_name): + """ + Returns a dictionary of {field_name: field_index} for the given table. + Indexes are 0-based. + """ + return dict([(d[0], i) for i, d in enumerate(self.get_table_description(cursor, table_name))]) + + def get_relations(self, cursor, table_name): + """ + Returns a dictionary of {field_index: (field_index_other_table, other_table)} + representing all relationships to the given table. Indexes are 0-based. + """ + my_field_dict = self._name_to_index(cursor, table_name) + constraints = [] + relations = {} + try: + # This should work for MySQL 5.0. + cursor.execute(""" + SELECT column_name, referenced_table_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE table_name = %s + AND table_schema = DATABASE() + AND referenced_table_name IS NOT NULL + AND referenced_column_name IS NOT NULL""", [table_name]) + constraints.extend(cursor.fetchall()) + except (ProgrammingError, OperationalError): + # Fall back to "SHOW CREATE TABLE", for previous MySQL versions. + # Go through all constraints and save the equal matches. + cursor.execute("SHOW CREATE TABLE %s" % self.connection.ops.quote_name(table_name)) + for row in cursor.fetchall(): + pos = 0 + while True: + match = foreign_key_re.search(row[1], pos) + if match == None: + break + pos = match.end() + constraints.append(match.groups()) + + for my_fieldname, other_table, other_field in constraints: + other_field_index = self._name_to_index(cursor, other_table)[other_field] + my_field_index = my_field_dict[my_fieldname] + relations[my_field_index] = (other_field_index, other_table) + + return relations + + def get_indexes(self, cursor, table_name): + """ + Returns a dictionary of fieldname -> infodict for the given table, + where each infodict is in the format: + {'primary_key': boolean representing whether it's the primary key, + 'unique': boolean representing whether it's a unique index} + """ + cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)) + indexes = {} for row in cursor.fetchall(): - pos = 0 - while True: - match = foreign_key_re.search(row[1], pos) - if match == None: - break - pos = match.end() - constraints.append(match.groups()) + indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])} + return indexes - for my_fieldname, other_table, other_field in constraints: - other_field_index = _name_to_index(cursor, other_table)[other_field] - my_field_index = my_field_dict[my_fieldname] - relations[my_field_index] = (other_field_index, other_table) - - return relations - -def get_indexes(cursor, table_name): - """ - Returns a dictionary of fieldname -> infodict for the given table, - where each infodict is in the format: - {'primary_key': boolean representing whether it's the primary key, - 'unique': boolean representing whether it's a unique index} - """ - cursor.execute("SHOW INDEX FROM %s" % quote_name(table_name)) - indexes = {} - for row in cursor.fetchall(): - indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])} - return indexes - -DATA_TYPES_REVERSE = { - FIELD_TYPE.BLOB: 'TextField', - FIELD_TYPE.CHAR: 'CharField', - FIELD_TYPE.DECIMAL: 'DecimalField', - FIELD_TYPE.DATE: 'DateField', - FIELD_TYPE.DATETIME: 'DateTimeField', - FIELD_TYPE.DOUBLE: 'FloatField', - FIELD_TYPE.FLOAT: 'FloatField', - FIELD_TYPE.INT24: 'IntegerField', - FIELD_TYPE.LONG: 'IntegerField', - FIELD_TYPE.LONGLONG: 'IntegerField', - FIELD_TYPE.SHORT: 'IntegerField', - FIELD_TYPE.STRING: 'CharField', - FIELD_TYPE.TIMESTAMP: 'DateTimeField', - FIELD_TYPE.TINY: 'IntegerField', - FIELD_TYPE.TINY_BLOB: 'TextField', - FIELD_TYPE.MEDIUM_BLOB: 'TextField', - FIELD_TYPE.LONG_BLOB: 'TextField', - FIELD_TYPE.VAR_STRING: 'CharField', -} diff --git a/django/db/backends/mysql/validation.py b/django/db/backends/mysql/validation.py new file mode 100644 index 0000000000..85354a8468 --- /dev/null +++ b/django/db/backends/mysql/validation.py @@ -0,0 +1,13 @@ +from django.db.backends import BaseDatabaseValidation + +class DatabaseValidation(BaseDatabaseValidation): + def validate_field(self, errors, opts, f): + "Prior to MySQL 5.0.3, character fields could not exceed 255 characters" + from django.db import models + from django.db import connection + db_version = connection.get_server_version() + if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255: + errors.add(opts, + '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' % + (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]]))) + \ No newline at end of file diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index bdb73b1864..33900e755c 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -8,8 +8,11 @@ import os import datetime import time -from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util +from django.db.backends import * from django.db.backends.oracle import query +from django.db.backends.oracle.client import DatabaseClient +from django.db.backends.oracle.creation import DatabaseCreation +from django.db.backends.oracle.introspection import DatabaseIntrospection from django.utils.encoding import smart_str, force_unicode # Oracle takes client-side character set encoding from the environment. @@ -24,11 +27,8 @@ DatabaseError = Database.Error IntegrityError = Database.IntegrityError class DatabaseFeatures(BaseDatabaseFeatures): - allows_group_by_ordinal = False empty_fetchmany_value = () needs_datetime_string_cast = False - supports_tablespaces = True - uses_case_insensitive_names = True uses_custom_query_class = True interprets_empty_strings_as_nulls = True @@ -194,10 +194,8 @@ class DatabaseOperations(BaseDatabaseOperations): return [first % value, second % value] - class DatabaseWrapper(BaseDatabaseWrapper): - features = DatabaseFeatures() - ops = DatabaseOperations() + operators = { 'exact': '= %s', 'iexact': '= UPPER(%s)', @@ -214,6 +212,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): } oracle_version = None + def __init__(self, *args, **kwargs): + super(DatabaseWrapper, self).__init__(*args, **kwargs) + + self.features = DatabaseFeatures() + self.ops = DatabaseOperations() + self.client = DatabaseClient() + self.creation = DatabaseCreation(self) + self.introspection = DatabaseIntrospection(self) + self.validation = BaseDatabaseValidation() + def _valid_connection(self): return self.connection is not None diff --git a/django/db/backends/oracle/client.py b/django/db/backends/oracle/client.py index 372783aa97..77fc9b9847 100644 --- a/django/db/backends/oracle/client.py +++ b/django/db/backends/oracle/client.py @@ -1,11 +1,13 @@ +from django.db.backends import BaseDatabaseClient from django.conf import settings import os -def runshell(): - dsn = settings.DATABASE_USER - if settings.DATABASE_PASSWORD: - dsn += "/%s" % settings.DATABASE_PASSWORD - if settings.DATABASE_NAME: - dsn += "@%s" % settings.DATABASE_NAME - args = ["sqlplus", "-L", dsn] - os.execvp("sqlplus", args) +class DatabaseClient(BaseDatabaseClient): + def runshell(self): + dsn = settings.DATABASE_USER + if settings.DATABASE_PASSWORD: + dsn += "/%s" % settings.DATABASE_PASSWORD + if settings.DATABASE_NAME: + dsn += "@%s" % settings.DATABASE_NAME + args = ["sqlplus", "-L", dsn] + os.execvp("sqlplus", args) diff --git a/django/db/backends/oracle/creation.py b/django/db/backends/oracle/creation.py index 2a8badebd5..c36933ae01 100644 --- a/django/db/backends/oracle/creation.py +++ b/django/db/backends/oracle/creation.py @@ -1,291 +1,289 @@ import sys, time +from django.conf import settings from django.core import management - -# This dictionary maps Field objects to their associated Oracle column -# types, as strings. Column-type strings can contain format strings; they'll -# be interpolated against the values of Field.__dict__ before being output. -# If a column type is set to None, it won't be included in the output. -# -# Any format strings starting with "qn_" are quoted before being used in the -# output (the "qn_" prefix is stripped before the lookup is performed. - -DATA_TYPES = { - 'AutoField': 'NUMBER(11)', - 'BooleanField': 'NUMBER(1) CHECK (%(qn_column)s IN (0,1))', - 'CharField': 'NVARCHAR2(%(max_length)s)', - 'CommaSeparatedIntegerField': 'VARCHAR2(%(max_length)s)', - 'DateField': 'DATE', - 'DateTimeField': 'TIMESTAMP', - 'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)', - 'FileField': 'NVARCHAR2(%(max_length)s)', - 'FilePathField': 'NVARCHAR2(%(max_length)s)', - 'FloatField': 'DOUBLE PRECISION', - 'IntegerField': 'NUMBER(11)', - 'IPAddressField': 'VARCHAR2(15)', - 'NullBooleanField': 'NUMBER(1) CHECK ((%(qn_column)s IN (0,1)) OR (%(qn_column)s IS NULL))', - 'OneToOneField': 'NUMBER(11)', - 'PhoneNumberField': 'VARCHAR2(20)', - 'PositiveIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)', - 'PositiveSmallIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)', - 'SlugField': 'NVARCHAR2(50)', - 'SmallIntegerField': 'NUMBER(11)', - 'TextField': 'NCLOB', - 'TimeField': 'TIMESTAMP', - 'URLField': 'VARCHAR2(%(max_length)s)', - 'USStateField': 'CHAR(2)', -} +from django.db.backends.creation import BaseDatabaseCreation TEST_DATABASE_PREFIX = 'test_' PASSWORD = 'Im_a_lumberjack' -REMEMBER = {} -def create_test_db(settings, connection, verbosity=1, autoclobber=False): - TEST_DATABASE_NAME = _test_database_name(settings) - TEST_DATABASE_USER = _test_database_user(settings) - TEST_DATABASE_PASSWD = _test_database_passwd(settings) - TEST_DATABASE_TBLSPACE = _test_database_tblspace(settings) - TEST_DATABASE_TBLSPACE_TMP = _test_database_tblspace_tmp(settings) +class DatabaseCreation(BaseDatabaseCreation): + # This dictionary maps Field objects to their associated Oracle column + # types, as strings. Column-type strings can contain format strings; they'll + # be interpolated against the values of Field.__dict__ before being output. + # If a column type is set to None, it won't be included in the output. + # + # Any format strings starting with "qn_" are quoted before being used in the + # output (the "qn_" prefix is stripped before the lookup is performed. - parameters = { - 'dbname': TEST_DATABASE_NAME, - 'user': TEST_DATABASE_USER, - 'password': TEST_DATABASE_PASSWD, - 'tblspace': TEST_DATABASE_TBLSPACE, - 'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP, - } + data_types = { + 'AutoField': 'NUMBER(11)', + 'BooleanField': 'NUMBER(1) CHECK (%(qn_column)s IN (0,1))', + 'CharField': 'NVARCHAR2(%(max_length)s)', + 'CommaSeparatedIntegerField': 'VARCHAR2(%(max_length)s)', + 'DateField': 'DATE', + 'DateTimeField': 'TIMESTAMP', + 'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)', + 'FileField': 'NVARCHAR2(%(max_length)s)', + 'FilePathField': 'NVARCHAR2(%(max_length)s)', + 'FloatField': 'DOUBLE PRECISION', + 'IntegerField': 'NUMBER(11)', + 'IPAddressField': 'VARCHAR2(15)', + 'NullBooleanField': 'NUMBER(1) CHECK ((%(qn_column)s IN (0,1)) OR (%(qn_column)s IS NULL))', + 'OneToOneField': 'NUMBER(11)', + 'PhoneNumberField': 'VARCHAR2(20)', + 'PositiveIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)', + 'PositiveSmallIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)', + 'SlugField': 'NVARCHAR2(50)', + 'SmallIntegerField': 'NUMBER(11)', + 'TextField': 'NCLOB', + 'TimeField': 'TIMESTAMP', + 'URLField': 'VARCHAR2(%(max_length)s)', + 'USStateField': 'CHAR(2)', + } + + def _create_test_db(self, verbosity, autoclobber): + TEST_DATABASE_NAME = self._test_database_name(settings) + TEST_DATABASE_USER = self._test_database_user(settings) + TEST_DATABASE_PASSWD = self._test_database_passwd(settings) + TEST_DATABASE_TBLSPACE = self._test_database_tblspace(settings) + TEST_DATABASE_TBLSPACE_TMP = self._test_database_tblspace_tmp(settings) - REMEMBER['user'] = settings.DATABASE_USER - REMEMBER['passwd'] = settings.DATABASE_PASSWORD + parameters = { + 'dbname': TEST_DATABASE_NAME, + 'user': TEST_DATABASE_USER, + 'password': TEST_DATABASE_PASSWD, + 'tblspace': TEST_DATABASE_TBLSPACE, + 'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP, + } - cursor = connection.cursor() - if _test_database_create(settings): - if verbosity >= 1: - print 'Creating test database...' - try: - _create_test_db(cursor, parameters, verbosity) - except Exception, e: - sys.stderr.write("Got an error creating the test database: %s\n" % e) - if not autoclobber: - confirm = raw_input("It appears the test database, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_NAME) - if autoclobber or confirm == 'yes': - try: - if verbosity >= 1: - print "Destroying old test database..." - _destroy_test_db(cursor, parameters, verbosity) - if verbosity >= 1: - print "Creating test database..." - _create_test_db(cursor, parameters, verbosity) - except Exception, e: - sys.stderr.write("Got an error recreating the test database: %s\n" % e) - sys.exit(2) - else: - print "Tests cancelled." - sys.exit(1) + self.remember['user'] = settings.DATABASE_USER + self.remember['passwd'] = settings.DATABASE_PASSWORD - if _test_user_create(settings): - if verbosity >= 1: - print "Creating test user..." - try: - _create_test_user(cursor, parameters, verbosity) - except Exception, e: - sys.stderr.write("Got an error creating the test user: %s\n" % e) - if not autoclobber: - confirm = raw_input("It appears the test user, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_USER) - if autoclobber or confirm == 'yes': - try: - if verbosity >= 1: - print "Destroying old test user..." - _destroy_test_user(cursor, parameters, verbosity) - if verbosity >= 1: - print "Creating test user..." - _create_test_user(cursor, parameters, verbosity) - except Exception, e: - sys.stderr.write("Got an error recreating the test user: %s\n" % e) - sys.exit(2) - else: - print "Tests cancelled." - sys.exit(1) + cursor = self.connection.cursor() + if self._test_database_create(settings): + if verbosity >= 1: + print 'Creating test database...' + try: + self._execute_test_db_creation(cursor, parameters, verbosity) + except Exception, e: + sys.stderr.write("Got an error creating the test database: %s\n" % e) + if not autoclobber: + confirm = raw_input("It appears the test database, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_NAME) + if autoclobber or confirm == 'yes': + try: + if verbosity >= 1: + print "Destroying old test database..." + self._execute_test_db_destruction(cursor, parameters, verbosity) + if verbosity >= 1: + print "Creating test database..." + self._execute_test_db_creation(cursor, parameters, verbosity) + except Exception, e: + sys.stderr.write("Got an error recreating the test database: %s\n" % e) + sys.exit(2) + else: + print "Tests cancelled." + sys.exit(1) - connection.close() - settings.DATABASE_USER = TEST_DATABASE_USER - settings.DATABASE_PASSWORD = TEST_DATABASE_PASSWD + if self._test_user_create(settings): + if verbosity >= 1: + print "Creating test user..." + try: + self._create_test_user(cursor, parameters, verbosity) + except Exception, e: + sys.stderr.write("Got an error creating the test user: %s\n" % e) + if not autoclobber: + confirm = raw_input("It appears the test user, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_USER) + if autoclobber or confirm == 'yes': + try: + if verbosity >= 1: + print "Destroying old test user..." + self._destroy_test_user(cursor, parameters, verbosity) + if verbosity >= 1: + print "Creating test user..." + self._create_test_user(cursor, parameters, verbosity) + except Exception, e: + sys.stderr.write("Got an error recreating the test user: %s\n" % e) + sys.exit(2) + else: + print "Tests cancelled." + sys.exit(1) - management.call_command('syncdb', verbosity=verbosity, interactive=False) + settings.DATABASE_USER = TEST_DATABASE_USER + settings.DATABASE_PASSWORD = TEST_DATABASE_PASSWD - # Get a cursor (even though we don't need one yet). This has - # the side effect of initializing the test database. - cursor = connection.cursor() + return TEST_DATABASE_NAME + + def _destroy_test_db(self, test_database_name, verbosity=1): + """ + Destroy a test database, prompting the user for confirmation if the + database already exists. Returns the name of the test database created. + """ + TEST_DATABASE_NAME = self._test_database_name(settings) + TEST_DATABASE_USER = self._test_database_user(settings) + TEST_DATABASE_PASSWD = self._test_database_passwd(settings) + TEST_DATABASE_TBLSPACE = self._test_database_tblspace(settings) + TEST_DATABASE_TBLSPACE_TMP = self._test_database_tblspace_tmp(settings) -def destroy_test_db(settings, connection, old_database_name, verbosity=1): - connection.close() + settings.DATABASE_USER = self.remember['user'] + settings.DATABASE_PASSWORD = self.remember['passwd'] - TEST_DATABASE_NAME = _test_database_name(settings) - TEST_DATABASE_USER = _test_database_user(settings) - TEST_DATABASE_PASSWD = _test_database_passwd(settings) - TEST_DATABASE_TBLSPACE = _test_database_tblspace(settings) - TEST_DATABASE_TBLSPACE_TMP = _test_database_tblspace_tmp(settings) + parameters = { + 'dbname': TEST_DATABASE_NAME, + 'user': TEST_DATABASE_USER, + 'password': TEST_DATABASE_PASSWD, + 'tblspace': TEST_DATABASE_TBLSPACE, + 'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP, + } - settings.DATABASE_NAME = old_database_name - settings.DATABASE_USER = REMEMBER['user'] - settings.DATABASE_PASSWORD = REMEMBER['passwd'] + self.remember['user'] = settings.DATABASE_USER + self.remember['passwd'] = settings.DATABASE_PASSWORD - parameters = { - 'dbname': TEST_DATABASE_NAME, - 'user': TEST_DATABASE_USER, - 'password': TEST_DATABASE_PASSWD, - 'tblspace': TEST_DATABASE_TBLSPACE, - 'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP, - } + cursor = self.connection.cursor() + time.sleep(1) # To avoid "database is being accessed by other users" errors. + if self._test_user_create(settings): + if verbosity >= 1: + print 'Destroying test user...' + self._destroy_test_user(cursor, parameters, verbosity) + if self._test_database_create(settings): + if verbosity >= 1: + print 'Destroying test database tables...' + self._execute_test_db_destruction(cursor, parameters, verbosity) + self.connection.close() - REMEMBER['user'] = settings.DATABASE_USER - REMEMBER['passwd'] = settings.DATABASE_PASSWORD - - cursor = connection.cursor() - time.sleep(1) # To avoid "database is being accessed by other users" errors. - if _test_user_create(settings): - if verbosity >= 1: - print 'Destroying test user...' - _destroy_test_user(cursor, parameters, verbosity) - if _test_database_create(settings): - if verbosity >= 1: - print 'Destroying test database...' - _destroy_test_db(cursor, parameters, verbosity) - connection.close() - -def _create_test_db(cursor, parameters, verbosity): - if verbosity >= 2: - print "_create_test_db(): dbname = %s" % parameters['dbname'] - statements = [ - """CREATE TABLESPACE %(tblspace)s - DATAFILE '%(tblspace)s.dbf' SIZE 20M - REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M - """, - """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s - TEMPFILE '%(tblspace_temp)s.dbf' SIZE 20M - REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M - """, - ] - _execute_statements(cursor, statements, parameters, verbosity) - -def _create_test_user(cursor, parameters, verbosity): - if verbosity >= 2: - print "_create_test_user(): username = %s" % parameters['user'] - statements = [ - """CREATE USER %(user)s - IDENTIFIED BY %(password)s - DEFAULT TABLESPACE %(tblspace)s - TEMPORARY TABLESPACE %(tblspace_temp)s - """, - """GRANT CONNECT, RESOURCE TO %(user)s""", - ] - _execute_statements(cursor, statements, parameters, verbosity) - -def _destroy_test_db(cursor, parameters, verbosity): - if verbosity >= 2: - print "_destroy_test_db(): dbname=%s" % parameters['dbname'] - statements = [ - 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', - 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', - ] - _execute_statements(cursor, statements, parameters, verbosity) - -def _destroy_test_user(cursor, parameters, verbosity): - if verbosity >= 2: - print "_destroy_test_user(): user=%s" % parameters['user'] - print "Be patient. This can take some time..." - statements = [ - 'DROP USER %(user)s CASCADE', - ] - _execute_statements(cursor, statements, parameters, verbosity) - -def _execute_statements(cursor, statements, parameters, verbosity): - for template in statements: - stmt = template % parameters + def _execute_test_db_creation(cursor, parameters, verbosity): if verbosity >= 2: - print stmt + print "_create_test_db(): dbname = %s" % parameters['dbname'] + statements = [ + """CREATE TABLESPACE %(tblspace)s + DATAFILE '%(tblspace)s.dbf' SIZE 20M + REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M + """, + """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s + TEMPFILE '%(tblspace_temp)s.dbf' SIZE 20M + REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M + """, + ] + _execute_statements(cursor, statements, parameters, verbosity) + + def _create_test_user(cursor, parameters, verbosity): + if verbosity >= 2: + print "_create_test_user(): username = %s" % parameters['user'] + statements = [ + """CREATE USER %(user)s + IDENTIFIED BY %(password)s + DEFAULT TABLESPACE %(tblspace)s + TEMPORARY TABLESPACE %(tblspace_temp)s + """, + """GRANT CONNECT, RESOURCE TO %(user)s""", + ] + _execute_statements(cursor, statements, parameters, verbosity) + + def _execute_test_db_destruction(cursor, parameters, verbosity): + if verbosity >= 2: + print "_execute_test_db_destruction(): dbname=%s" % parameters['dbname'] + statements = [ + 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', + 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS', + ] + _execute_statements(cursor, statements, parameters, verbosity) + + def _destroy_test_user(cursor, parameters, verbosity): + if verbosity >= 2: + print "_destroy_test_user(): user=%s" % parameters['user'] + print "Be patient. This can take some time..." + statements = [ + 'DROP USER %(user)s CASCADE', + ] + _execute_statements(cursor, statements, parameters, verbosity) + + def _execute_statements(cursor, statements, parameters, verbosity): + for template in statements: + stmt = template % parameters + if verbosity >= 2: + print stmt + try: + cursor.execute(stmt) + except Exception, err: + sys.stderr.write("Failed (%s)\n" % (err)) + raise + + def _test_database_name(settings): + name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME try: - cursor.execute(stmt) - except Exception, err: - sys.stderr.write("Failed (%s)\n" % (err)) + if settings.TEST_DATABASE_NAME: + name = settings.TEST_DATABASE_NAME + except AttributeError: + pass + except: raise + return name -def _test_database_name(settings): - name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME - try: - if settings.TEST_DATABASE_NAME: - name = settings.TEST_DATABASE_NAME - except AttributeError: - pass - except: - raise - return name + def _test_database_create(settings): + name = True + try: + if settings.TEST_DATABASE_CREATE: + name = True + else: + name = False + except AttributeError: + pass + except: + raise + return name -def _test_database_create(settings): - name = True - try: - if settings.TEST_DATABASE_CREATE: - name = True - else: - name = False - except AttributeError: - pass - except: - raise - return name + def _test_user_create(settings): + name = True + try: + if settings.TEST_USER_CREATE: + name = True + else: + name = False + except AttributeError: + pass + except: + raise + return name -def _test_user_create(settings): - name = True - try: - if settings.TEST_USER_CREATE: - name = True - else: - name = False - except AttributeError: - pass - except: - raise - return name + def _test_database_user(settings): + name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + try: + if settings.TEST_DATABASE_USER: + name = settings.TEST_DATABASE_USER + except AttributeError: + pass + except: + raise + return name -def _test_database_user(settings): - name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME - try: - if settings.TEST_DATABASE_USER: - name = settings.TEST_DATABASE_USER - except AttributeError: - pass - except: - raise - return name + def _test_database_passwd(settings): + name = PASSWORD + try: + if settings.TEST_DATABASE_PASSWD: + name = settings.TEST_DATABASE_PASSWD + except AttributeError: + pass + except: + raise + return name -def _test_database_passwd(settings): - name = PASSWORD - try: - if settings.TEST_DATABASE_PASSWD: - name = settings.TEST_DATABASE_PASSWD - except AttributeError: - pass - except: - raise - return name + def _test_database_tblspace(settings): + name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + try: + if settings.TEST_DATABASE_TBLSPACE: + name = settings.TEST_DATABASE_TBLSPACE + except AttributeError: + pass + except: + raise + return name -def _test_database_tblspace(settings): - name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME - try: - if settings.TEST_DATABASE_TBLSPACE: - name = settings.TEST_DATABASE_TBLSPACE - except AttributeError: - pass - except: - raise - return name - -def _test_database_tblspace_tmp(settings): - name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + '_temp' - try: - if settings.TEST_DATABASE_TBLSPACE_TMP: - name = settings.TEST_DATABASE_TBLSPACE_TMP - except AttributeError: - pass - except: - raise - return name + def _test_database_tblspace_tmp(settings): + name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + '_temp' + try: + if settings.TEST_DATABASE_TBLSPACE_TMP: + name = settings.TEST_DATABASE_TBLSPACE_TMP + except AttributeError: + pass + except: + raise + return name diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index 6f800c8bb6..890e30a694 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -1,98 +1,103 @@ -from django.db.backends.oracle.base import DatabaseOperations -import re +from django.db.backends import BaseDatabaseIntrospection import cx_Oracle +import re -quote_name = DatabaseOperations().quote_name foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)") -def get_table_list(cursor): - "Returns a list of table names in the current database." - cursor.execute("SELECT TABLE_NAME FROM USER_TABLES") - return [row[0].upper() for row in cursor.fetchall()] +class DatabaseIntrospection(BaseDatabaseIntrospection): + # Maps type objects to Django Field types. + data_types_reverse = { + cx_Oracle.CLOB: 'TextField', + cx_Oracle.DATETIME: 'DateTimeField', + cx_Oracle.FIXED_CHAR: 'CharField', + cx_Oracle.NCLOB: 'TextField', + cx_Oracle.NUMBER: 'DecimalField', + cx_Oracle.STRING: 'CharField', + cx_Oracle.TIMESTAMP: 'DateTimeField', + } -def get_table_description(cursor, table_name): - "Returns a description of the table, with the DB-API cursor.description interface." - cursor.execute("SELECT * FROM %s WHERE ROWNUM < 2" % quote_name(table_name)) - return cursor.description + def get_table_list(self, cursor): + "Returns a list of table names in the current database." + cursor.execute("SELECT TABLE_NAME FROM USER_TABLES") + return [row[0].upper() for row in cursor.fetchall()] -def _name_to_index(cursor, table_name): - """ - Returns a dictionary of {field_name: field_index} for the given table. - Indexes are 0-based. - """ - return dict([(d[0], i) for i, d in enumerate(get_table_description(cursor, table_name))]) + def get_table_description(self, cursor, table_name): + "Returns a description of the table, with the DB-API cursor.description interface." + cursor.execute("SELECT * FROM %s WHERE ROWNUM < 2" % self.connection.ops.quote_name(table_name)) + return cursor.description -def get_relations(cursor, table_name): - """ - Returns a dictionary of {field_index: (field_index_other_table, other_table)} - representing all relationships to the given table. Indexes are 0-based. - """ - cursor.execute(""" -SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1 -FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb, - user_tab_cols ta, user_tab_cols tb -WHERE user_constraints.table_name = %s AND - ta.table_name = %s AND - ta.column_name = ca.column_name AND - ca.table_name = %s AND - user_constraints.constraint_name = ca.constraint_name AND - user_constraints.r_constraint_name = cb.constraint_name AND - cb.table_name = tb.table_name AND - cb.column_name = tb.column_name AND - ca.position = cb.position""", [table_name, table_name, table_name]) + def table_name_converter(self, name): + "Table name comparison is case insensitive under Oracle" + return name.upper() + + def _name_to_index(self, cursor, table_name): + """ + Returns a dictionary of {field_name: field_index} for the given table. + Indexes are 0-based. + """ + return dict([(d[0], i) for i, d in enumerate(self.get_table_description(cursor, table_name))]) - relations = {} - for row in cursor.fetchall(): - relations[row[0]] = (row[2], row[1]) - return relations + def get_relations(self, cursor, table_name): + """ + Returns a dictionary of {field_index: (field_index_other_table, other_table)} + representing all relationships to the given table. Indexes are 0-based. + """ + cursor.execute(""" + SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1 + FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb, + user_tab_cols ta, user_tab_cols tb + WHERE user_constraints.table_name = %s AND + ta.table_name = %s AND + ta.column_name = ca.column_name AND + ca.table_name = %s AND + user_constraints.constraint_name = ca.constraint_name AND + user_constraints.r_constraint_name = cb.constraint_name AND + cb.table_name = tb.table_name AND + cb.column_name = tb.column_name AND + ca.position = cb.position""", [table_name, table_name, table_name]) -def get_indexes(cursor, table_name): - """ - Returns a dictionary of fieldname -> infodict for the given table, - where each infodict is in the format: - {'primary_key': boolean representing whether it's the primary key, - 'unique': boolean representing whether it's a unique index} - """ - # This query retrieves each index on the given table, including the - # first associated field name - # "We were in the nick of time; you were in great peril!" - sql = """ -WITH primarycols AS ( - SELECT user_cons_columns.table_name, user_cons_columns.column_name, 1 AS PRIMARYCOL - FROM user_cons_columns, user_constraints - WHERE user_cons_columns.constraint_name = user_constraints.constraint_name AND - user_constraints.constraint_type = 'P' AND - user_cons_columns.table_name = %s), - uniquecols AS ( - SELECT user_ind_columns.table_name, user_ind_columns.column_name, 1 AS UNIQUECOL - FROM user_indexes, user_ind_columns - WHERE uniqueness = 'UNIQUE' AND - user_indexes.index_name = user_ind_columns.index_name AND - user_ind_columns.table_name = %s) -SELECT allcols.column_name, primarycols.primarycol, uniquecols.UNIQUECOL -FROM (SELECT column_name FROM primarycols UNION SELECT column_name FROM -uniquecols) allcols, - primarycols, uniquecols -WHERE allcols.column_name = primarycols.column_name (+) AND - allcols.column_name = uniquecols.column_name (+) - """ - cursor.execute(sql, [table_name, table_name]) - indexes = {} - for row in cursor.fetchall(): - # row[1] (idx.indkey) is stored in the DB as an array. It comes out as - # a string of space-separated integers. This designates the field - # indexes (1-based) of the fields that have indexes on the table. - # Here, we skip any indexes across multiple fields. - indexes[row[0]] = {'primary_key': row[1], 'unique': row[2]} - return indexes + relations = {} + for row in cursor.fetchall(): + relations[row[0]] = (row[2], row[1]) + return relations + + def get_indexes(self, cursor, table_name): + """ + Returns a dictionary of fieldname -> infodict for the given table, + where each infodict is in the format: + {'primary_key': boolean representing whether it's the primary key, + 'unique': boolean representing whether it's a unique index} + """ + # This query retrieves each index on the given table, including the + # first associated field name + # "We were in the nick of time; you were in great peril!" + sql = """ + WITH primarycols AS ( + SELECT user_cons_columns.table_name, user_cons_columns.column_name, 1 AS PRIMARYCOL + FROM user_cons_columns, user_constraints + WHERE user_cons_columns.constraint_name = user_constraints.constraint_name AND + user_constraints.constraint_type = 'P' AND + user_cons_columns.table_name = %s), + uniquecols AS ( + SELECT user_ind_columns.table_name, user_ind_columns.column_name, 1 AS UNIQUECOL + FROM user_indexes, user_ind_columns + WHERE uniqueness = 'UNIQUE' AND + user_indexes.index_name = user_ind_columns.index_name AND + user_ind_columns.table_name = %s) + SELECT allcols.column_name, primarycols.primarycol, uniquecols.UNIQUECOL + FROM (SELECT column_name FROM primarycols UNION SELECT column_name FROM + uniquecols) allcols, + primarycols, uniquecols + WHERE allcols.column_name = primarycols.column_name (+) AND + allcols.column_name = uniquecols.column_name (+) + """ + cursor.execute(sql, [table_name, table_name]) + indexes = {} + for row in cursor.fetchall(): + # row[1] (idx.indkey) is stored in the DB as an array. It comes out as + # a string of space-separated integers. This designates the field + # indexes (1-based) of the fields that have indexes on the table. + # Here, we skip any indexes across multiple fields. + indexes[row[0]] = {'primary_key': row[1], 'unique': row[2]} + return indexes -# Maps type objects to Django Field types. -DATA_TYPES_REVERSE = { - cx_Oracle.CLOB: 'TextField', - cx_Oracle.DATETIME: 'DateTimeField', - cx_Oracle.FIXED_CHAR: 'CharField', - cx_Oracle.NCLOB: 'TextField', - cx_Oracle.NUMBER: 'DecimalField', - cx_Oracle.STRING: 'CharField', - cx_Oracle.TIMESTAMP: 'DateTimeField', -} diff --git a/django/db/backends/postgresql/base.py b/django/db/backends/postgresql/base.py index 1dfe34aceb..4a8d6ebef0 100644 --- a/django/db/backends/postgresql/base.py +++ b/django/db/backends/postgresql/base.py @@ -4,9 +4,13 @@ PostgreSQL database backend for Django. Requires psycopg 1: http://initd.org/projects/psycopg1 """ -from django.utils.encoding import smart_str, smart_unicode -from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, util +from django.db.backends import * +from django.db.backends.postgresql.client import DatabaseClient +from django.db.backends.postgresql.creation import DatabaseCreation +from django.db.backends.postgresql.introspection import DatabaseIntrospection from django.db.backends.postgresql.operations import DatabaseOperations +from django.utils.encoding import smart_str, smart_unicode + try: import psycopg as Database except ImportError, e: @@ -59,12 +63,7 @@ class UnicodeCursorWrapper(object): def __iter__(self): return iter(self.cursor) -class DatabaseFeatures(BaseDatabaseFeatures): - pass # This backend uses all the defaults. - class DatabaseWrapper(BaseDatabaseWrapper): - features = DatabaseFeatures() - ops = DatabaseOperations() operators = { 'exact': '= %s', 'iexact': 'ILIKE %s', @@ -82,6 +81,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'iendswith': 'ILIKE %s', } + def __init__(self, *args, **kwargs): + super(DatabaseWrapper, self).__init__(*args, **kwargs) + + self.features = BaseDatabaseFeatures() + self.ops = DatabaseOperations() + self.client = DatabaseClient() + self.creation = DatabaseCreation(self) + self.introspection = DatabaseIntrospection(self) + self.validation = BaseDatabaseValidation() + def _cursor(self, settings): set_tz = False if self.connection is None: diff --git a/django/db/backends/postgresql/client.py b/django/db/backends/postgresql/client.py index 8123ec7848..28daed833a 100644 --- a/django/db/backends/postgresql/client.py +++ b/django/db/backends/postgresql/client.py @@ -1,15 +1,17 @@ +from django.db.backends import BaseDatabaseClient from django.conf import settings import os -def runshell(): - args = ['psql'] - if settings.DATABASE_USER: - args += ["-U", settings.DATABASE_USER] - if settings.DATABASE_PASSWORD: - args += ["-W"] - if settings.DATABASE_HOST: - args.extend(["-h", settings.DATABASE_HOST]) - if settings.DATABASE_PORT: - args.extend(["-p", str(settings.DATABASE_PORT)]) - args += [settings.DATABASE_NAME] - os.execvp('psql', args) +class DatabaseClient(BaseDatabaseClient): + def runshell(self): + args = ['psql'] + if settings.DATABASE_USER: + args += ["-U", settings.DATABASE_USER] + if settings.DATABASE_PASSWORD: + args += ["-W"] + if settings.DATABASE_HOST: + args.extend(["-h", settings.DATABASE_HOST]) + if settings.DATABASE_PORT: + args.extend(["-p", str(settings.DATABASE_PORT)]) + args += [settings.DATABASE_NAME] + os.execvp('psql', args) diff --git a/django/db/backends/postgresql/creation.py b/django/db/backends/postgresql/creation.py index a8877a7d9b..3e537e345e 100644 --- a/django/db/backends/postgresql/creation.py +++ b/django/db/backends/postgresql/creation.py @@ -1,28 +1,38 @@ -# This dictionary maps Field objects to their associated PostgreSQL column -# types, as strings. Column-type strings can contain format strings; they'll -# be interpolated against the values of Field.__dict__ before being output. -# If a column type is set to None, it won't be included in the output. -DATA_TYPES = { - 'AutoField': 'serial', - 'BooleanField': 'boolean', - 'CharField': 'varchar(%(max_length)s)', - 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'timestamp with time zone', - 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'double precision', - 'IntegerField': 'integer', - 'IPAddressField': 'inet', - 'NullBooleanField': 'boolean', - 'OneToOneField': 'integer', - 'PhoneNumberField': 'varchar(20)', - 'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)', - 'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallIntegerField': 'smallint', - 'TextField': 'text', - 'TimeField': 'time', - 'USStateField': 'varchar(2)', -} +from django.conf import settings +from django.db.backends.creation import BaseDatabaseCreation + +class DatabaseCreation(BaseDatabaseCreation): + # This dictionary maps Field objects to their associated PostgreSQL column + # types, as strings. Column-type strings can contain format strings; they'll + # be interpolated against the values of Field.__dict__ before being output. + # If a column type is set to None, it won't be included in the output. + data_types = { + 'AutoField': 'serial', + 'BooleanField': 'boolean', + 'CharField': 'varchar(%(max_length)s)', + 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', + 'DateField': 'date', + 'DateTimeField': 'timestamp with time zone', + 'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)', + 'FileField': 'varchar(%(max_length)s)', + 'FilePathField': 'varchar(%(max_length)s)', + 'FloatField': 'double precision', + 'IntegerField': 'integer', + 'IPAddressField': 'inet', + 'NullBooleanField': 'boolean', + 'OneToOneField': 'integer', + 'PhoneNumberField': 'varchar(20)', + 'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)', + 'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)', + 'SlugField': 'varchar(%(max_length)s)', + 'SmallIntegerField': 'smallint', + 'TextField': 'text', + 'TimeField': 'time', + 'USStateField': 'varchar(2)', + } + + def sql_table_creation_suffix(self): + assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time." + if settings.TEST_DATABASE_CHARSET: + return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET + return '' diff --git a/django/db/backends/postgresql/introspection.py b/django/db/backends/postgresql/introspection.py index 982c004569..7b3ab3bb8a 100644 --- a/django/db/backends/postgresql/introspection.py +++ b/django/db/backends/postgresql/introspection.py @@ -1,86 +1,86 @@ -from django.db.backends.postgresql.base import DatabaseOperations +from django.db.backends import BaseDatabaseIntrospection -quote_name = DatabaseOperations().quote_name +class DatabaseIntrospection(BaseDatabaseIntrospection): + # Maps type codes to Django Field types. + data_types_reverse = { + 16: 'BooleanField', + 21: 'SmallIntegerField', + 23: 'IntegerField', + 25: 'TextField', + 701: 'FloatField', + 869: 'IPAddressField', + 1043: 'CharField', + 1082: 'DateField', + 1083: 'TimeField', + 1114: 'DateTimeField', + 1184: 'DateTimeField', + 1266: 'TimeField', + 1700: 'DecimalField', + } + + def get_table_list(self, cursor): + "Returns a list of table names in the current database." + cursor.execute(""" + SELECT c.relname + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE c.relkind IN ('r', 'v', '') + AND n.nspname NOT IN ('pg_catalog', 'pg_toast') + AND pg_catalog.pg_table_is_visible(c.oid)""") + return [row[0] for row in cursor.fetchall()] -def get_table_list(cursor): - "Returns a list of table names in the current database." - cursor.execute(""" - SELECT c.relname - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE c.relkind IN ('r', 'v', '') - AND n.nspname NOT IN ('pg_catalog', 'pg_toast') - AND pg_catalog.pg_table_is_visible(c.oid)""") - return [row[0] for row in cursor.fetchall()] + def get_table_description(self, cursor, table_name): + "Returns a description of the table, with the DB-API cursor.description interface." + cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)) + return cursor.description -def get_table_description(cursor, table_name): - "Returns a description of the table, with the DB-API cursor.description interface." - cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name)) - return cursor.description + def get_relations(self, cursor, table_name): + """ + Returns a dictionary of {field_index: (field_index_other_table, other_table)} + representing all relationships to the given table. Indexes are 0-based. + """ + cursor.execute(""" + SELECT con.conkey, con.confkey, c2.relname + FROM pg_constraint con, pg_class c1, pg_class c2 + WHERE c1.oid = con.conrelid + AND c2.oid = con.confrelid + AND c1.relname = %s + AND con.contype = 'f'""", [table_name]) + relations = {} + for row in cursor.fetchall(): + try: + # row[0] and row[1] are like "{2}", so strip the curly braces. + relations[int(row[0][1:-1]) - 1] = (int(row[1][1:-1]) - 1, row[2]) + except ValueError: + continue + return relations -def get_relations(cursor, table_name): - """ - Returns a dictionary of {field_index: (field_index_other_table, other_table)} - representing all relationships to the given table. Indexes are 0-based. - """ - cursor.execute(""" - SELECT con.conkey, con.confkey, c2.relname - FROM pg_constraint con, pg_class c1, pg_class c2 - WHERE c1.oid = con.conrelid - AND c2.oid = con.confrelid - AND c1.relname = %s - AND con.contype = 'f'""", [table_name]) - relations = {} - for row in cursor.fetchall(): - try: - # row[0] and row[1] are like "{2}", so strip the curly braces. - relations[int(row[0][1:-1]) - 1] = (int(row[1][1:-1]) - 1, row[2]) - except ValueError: - continue - return relations + def get_indexes(self, cursor, table_name): + """ + Returns a dictionary of fieldname -> infodict for the given table, + where each infodict is in the format: + {'primary_key': boolean representing whether it's the primary key, + 'unique': boolean representing whether it's a unique index} + """ + # This query retrieves each index on the given table, including the + # first associated field name + cursor.execute(""" + SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary + FROM pg_catalog.pg_class c, pg_catalog.pg_class c2, + pg_catalog.pg_index idx, pg_catalog.pg_attribute attr + WHERE c.oid = idx.indrelid + AND idx.indexrelid = c2.oid + AND attr.attrelid = c.oid + AND attr.attnum = idx.indkey[0] + AND c.relname = %s""", [table_name]) + indexes = {} + for row in cursor.fetchall(): + # row[1] (idx.indkey) is stored in the DB as an array. It comes out as + # a string of space-separated integers. This designates the field + # indexes (1-based) of the fields that have indexes on the table. + # Here, we skip any indexes across multiple fields. + if ' ' in row[1]: + continue + indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]} + return indexes -def get_indexes(cursor, table_name): - """ - Returns a dictionary of fieldname -> infodict for the given table, - where each infodict is in the format: - {'primary_key': boolean representing whether it's the primary key, - 'unique': boolean representing whether it's a unique index} - """ - # This query retrieves each index on the given table, including the - # first associated field name - cursor.execute(""" - SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary - FROM pg_catalog.pg_class c, pg_catalog.pg_class c2, - pg_catalog.pg_index idx, pg_catalog.pg_attribute attr - WHERE c.oid = idx.indrelid - AND idx.indexrelid = c2.oid - AND attr.attrelid = c.oid - AND attr.attnum = idx.indkey[0] - AND c.relname = %s""", [table_name]) - indexes = {} - for row in cursor.fetchall(): - # row[1] (idx.indkey) is stored in the DB as an array. It comes out as - # a string of space-separated integers. This designates the field - # indexes (1-based) of the fields that have indexes on the table. - # Here, we skip any indexes across multiple fields. - if ' ' in row[1]: - continue - indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]} - return indexes - -# Maps type codes to Django Field types. -DATA_TYPES_REVERSE = { - 16: 'BooleanField', - 21: 'SmallIntegerField', - 23: 'IntegerField', - 25: 'TextField', - 701: 'FloatField', - 869: 'IPAddressField', - 1043: 'CharField', - 1082: 'DateField', - 1083: 'TimeField', - 1114: 'DateTimeField', - 1184: 'DateTimeField', - 1266: 'TimeField', - 1700: 'DecimalField', -} diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 6b5233e8de..139e36ba59 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -4,8 +4,12 @@ PostgreSQL database backend for Django. Requires psycopg 2: http://initd.org/projects/psycopg2 """ -from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures +from django.db.backends import * from django.db.backends.postgresql.operations import DatabaseOperations as PostgresqlDatabaseOperations +from django.db.backends.postgresql.client import DatabaseClient +from django.db.backends.postgresql.creation import DatabaseCreation +from django.db.backends.postgresql_psycopg2.introspection import DatabaseIntrospection + from django.utils.safestring import SafeUnicode try: import psycopg2 as Database @@ -31,8 +35,6 @@ class DatabaseOperations(PostgresqlDatabaseOperations): return cursor.query class DatabaseWrapper(BaseDatabaseWrapper): - features = DatabaseFeatures() - ops = DatabaseOperations() operators = { 'exact': '= %s', 'iexact': 'ILIKE %s', @@ -50,6 +52,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'iendswith': 'ILIKE %s', } + def __init__(self, *args, **kwargs): + super(DatabaseWrapper, self).__init__(*args, **kwargs) + + self.features = DatabaseFeatures() + self.ops = DatabaseOperations() + self.client = DatabaseClient() + self.creation = DatabaseCreation(self) + self.introspection = DatabaseIntrospection(self) + self.validation = BaseDatabaseValidation() + def _cursor(self, settings): set_tz = False if self.connection is None: diff --git a/django/db/backends/postgresql_psycopg2/client.py b/django/db/backends/postgresql_psycopg2/client.py index c9d879a1b7..e69de29bb2 100644 --- a/django/db/backends/postgresql_psycopg2/client.py +++ b/django/db/backends/postgresql_psycopg2/client.py @@ -1 +0,0 @@ -from django.db.backends.postgresql.client import * diff --git a/django/db/backends/postgresql_psycopg2/creation.py b/django/db/backends/postgresql_psycopg2/creation.py index 8c87e5c493..e69de29bb2 100644 --- a/django/db/backends/postgresql_psycopg2/creation.py +++ b/django/db/backends/postgresql_psycopg2/creation.py @@ -1 +0,0 @@ -from django.db.backends.postgresql.creation import * diff --git a/django/db/backends/postgresql_psycopg2/introspection.py b/django/db/backends/postgresql_psycopg2/introspection.py index bf839c3e95..83bd9b4c44 100644 --- a/django/db/backends/postgresql_psycopg2/introspection.py +++ b/django/db/backends/postgresql_psycopg2/introspection.py @@ -1,83 +1,21 @@ -from django.db.backends.postgresql_psycopg2.base import DatabaseOperations +from django.db.backends.postgresql.introspection import DatabaseIntrospection as PostgresDatabaseIntrospection -quote_name = DatabaseOperations().quote_name +class DatabaseIntrospection(PostgresDatabaseIntrospection): -def get_table_list(cursor): - "Returns a list of table names in the current database." - cursor.execute(""" - SELECT c.relname - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE c.relkind IN ('r', 'v', '') - AND n.nspname NOT IN ('pg_catalog', 'pg_toast') - AND pg_catalog.pg_table_is_visible(c.oid)""") - return [row[0] for row in cursor.fetchall()] - -def get_table_description(cursor, table_name): - "Returns a description of the table, with the DB-API cursor.description interface." - cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name)) - return cursor.description - -def get_relations(cursor, table_name): - """ - Returns a dictionary of {field_index: (field_index_other_table, other_table)} - representing all relationships to the given table. Indexes are 0-based. - """ - cursor.execute(""" - SELECT con.conkey, con.confkey, c2.relname - FROM pg_constraint con, pg_class c1, pg_class c2 - WHERE c1.oid = con.conrelid - AND c2.oid = con.confrelid - AND c1.relname = %s - AND con.contype = 'f'""", [table_name]) - relations = {} - for row in cursor.fetchall(): - # row[0] and row[1] are single-item lists, so grab the single item. - relations[row[0][0] - 1] = (row[1][0] - 1, row[2]) - return relations - -def get_indexes(cursor, table_name): - """ - Returns a dictionary of fieldname -> infodict for the given table, - where each infodict is in the format: - {'primary_key': boolean representing whether it's the primary key, - 'unique': boolean representing whether it's a unique index} - """ - # This query retrieves each index on the given table, including the - # first associated field name - cursor.execute(""" - SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary - FROM pg_catalog.pg_class c, pg_catalog.pg_class c2, - pg_catalog.pg_index idx, pg_catalog.pg_attribute attr - WHERE c.oid = idx.indrelid - AND idx.indexrelid = c2.oid - AND attr.attrelid = c.oid - AND attr.attnum = idx.indkey[0] - AND c.relname = %s""", [table_name]) - indexes = {} - for row in cursor.fetchall(): - # row[1] (idx.indkey) is stored in the DB as an array. It comes out as - # a string of space-separated integers. This designates the field - # indexes (1-based) of the fields that have indexes on the table. - # Here, we skip any indexes across multiple fields. - if ' ' in row[1]: - continue - indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]} - return indexes - -# Maps type codes to Django Field types. -DATA_TYPES_REVERSE = { - 16: 'BooleanField', - 21: 'SmallIntegerField', - 23: 'IntegerField', - 25: 'TextField', - 701: 'FloatField', - 869: 'IPAddressField', - 1043: 'CharField', - 1082: 'DateField', - 1083: 'TimeField', - 1114: 'DateTimeField', - 1184: 'DateTimeField', - 1266: 'TimeField', - 1700: 'DecimalField', -} + def get_relations(self, cursor, table_name): + """ + Returns a dictionary of {field_index: (field_index_other_table, other_table)} + representing all relationships to the given table. Indexes are 0-based. + """ + cursor.execute(""" + SELECT con.conkey, con.confkey, c2.relname + FROM pg_constraint con, pg_class c1, pg_class c2 + WHERE c1.oid = con.conrelid + AND c2.oid = con.confrelid + AND c1.relname = %s + AND con.contype = 'f'""", [table_name]) + relations = {} + for row in cursor.fetchall(): + # row[0] and row[1] are single-item lists, so grab the single item. + relations[row[0][0] - 1] = (row[1][0] - 1, row[2]) + return relations diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index e7d9f25a97..0ee5e069b2 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -6,7 +6,11 @@ Python 2.3 and 2.4 require pysqlite2 (http://pysqlite.org/). Python 2.5 and later use the sqlite3 module in the standard library. """ -from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util +from django.db.backends import * +from django.db.backends.sqlite3.client import DatabaseClient +from django.db.backends.sqlite3.creation import DatabaseCreation +from django.db.backends.sqlite3.introspection import DatabaseIntrospection + try: try: from sqlite3 import dbapi2 as Database @@ -46,7 +50,6 @@ if Database.version_info >= (2,4,1): Database.register_adapter(str, lambda s:s.decode('utf-8')) class DatabaseFeatures(BaseDatabaseFeatures): - supports_constraints = False # SQLite cannot handle us only partially reading from a cursor's result set # and then writing the same rows to the database in another cursor. This # setting ensures we always read result sets fully into memory all in one @@ -96,11 +99,8 @@ class DatabaseOperations(BaseDatabaseOperations): second = '%s-12-31 23:59:59.999999' return [first % value, second % value] - class DatabaseWrapper(BaseDatabaseWrapper): - features = DatabaseFeatures() - ops = DatabaseOperations() - + # SQLite requires LIKE statements to include an ESCAPE clause if the value # being escaped has a percent or underscore in it. # See http://www.sqlite.org/lang_expr.html for an explanation. @@ -121,6 +121,16 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'iendswith': "LIKE %s ESCAPE '\\'", } + def __init__(self, *args, **kwargs): + super(DatabaseWrapper, self).__init__(*args, **kwargs) + + self.features = DatabaseFeatures() + self.ops = DatabaseOperations() + self.client = DatabaseClient() + self.creation = DatabaseCreation(self) + self.introspection = DatabaseIntrospection(self) + self.validation = BaseDatabaseValidation() + def _cursor(self, settings): if self.connection is None: if not settings.DATABASE_NAME: diff --git a/django/db/backends/sqlite3/client.py b/django/db/backends/sqlite3/client.py index 097218341f..affb1c228c 100644 --- a/django/db/backends/sqlite3/client.py +++ b/django/db/backends/sqlite3/client.py @@ -1,6 +1,8 @@ +from django.db.backends import BaseDatabaseClient from django.conf import settings import os -def runshell(): - args = ['', settings.DATABASE_NAME] - os.execvp('sqlite3', args) +class DatabaseClient(BaseDatabaseClient): + def runshell(self): + args = ['', settings.DATABASE_NAME] + os.execvp('sqlite3', args) diff --git a/django/db/backends/sqlite3/creation.py b/django/db/backends/sqlite3/creation.py index c1c2b3170d..6ad6154d36 100644 --- a/django/db/backends/sqlite3/creation.py +++ b/django/db/backends/sqlite3/creation.py @@ -1,27 +1,73 @@ -# SQLite doesn't actually support most of these types, but it "does the right -# thing" given more verbose field definitions, so leave them as is so that -# schema inspection is more useful. -DATA_TYPES = { - 'AutoField': 'integer', - 'BooleanField': 'bool', - 'CharField': 'varchar(%(max_length)s)', - 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', - 'DateField': 'date', - 'DateTimeField': 'datetime', - 'DecimalField': 'decimal', - 'FileField': 'varchar(%(max_length)s)', - 'FilePathField': 'varchar(%(max_length)s)', - 'FloatField': 'real', - 'IntegerField': 'integer', - 'IPAddressField': 'char(15)', - 'NullBooleanField': 'bool', - 'OneToOneField': 'integer', - 'PhoneNumberField': 'varchar(20)', - 'PositiveIntegerField': 'integer unsigned', - 'PositiveSmallIntegerField': 'smallint unsigned', - 'SlugField': 'varchar(%(max_length)s)', - 'SmallIntegerField': 'smallint', - 'TextField': 'text', - 'TimeField': 'time', - 'USStateField': 'varchar(2)', -} +import os +import sys +from django.conf import settings +from django.db.backends.creation import BaseDatabaseCreation + +class DatabaseCreation(BaseDatabaseCreation): + # SQLite doesn't actually support most of these types, but it "does the right + # thing" given more verbose field definitions, so leave them as is so that + # schema inspection is more useful. + data_types = { + 'AutoField': 'integer', + 'BooleanField': 'bool', + 'CharField': 'varchar(%(max_length)s)', + 'CommaSeparatedIntegerField': 'varchar(%(max_length)s)', + 'DateField': 'date', + 'DateTimeField': 'datetime', + 'DecimalField': 'decimal', + 'FileField': 'varchar(%(max_length)s)', + 'FilePathField': 'varchar(%(max_length)s)', + 'FloatField': 'real', + 'IntegerField': 'integer', + 'IPAddressField': 'char(15)', + 'NullBooleanField': 'bool', + 'OneToOneField': 'integer', + 'PhoneNumberField': 'varchar(20)', + 'PositiveIntegerField': 'integer unsigned', + 'PositiveSmallIntegerField': 'smallint unsigned', + 'SlugField': 'varchar(%(max_length)s)', + 'SmallIntegerField': 'smallint', + 'TextField': 'text', + 'TimeField': 'time', + 'USStateField': 'varchar(2)', + } + + def sql_for_pending_references(self, model, style, pending_references): + "SQLite3 doesn't support constraints" + return [] + + def sql_remove_table_constraints(self, model, references_to_delete): + "SQLite3 doesn't support constraints" + return [] + + def _create_test_db(self, verbosity, autoclobber): + if settings.TEST_DATABASE_NAME and settings.TEST_DATABASE_NAME != ":memory:": + test_database_name = settings.TEST_DATABASE_NAME + # Erase the old test database + if verbosity >= 1: + print "Destroying old test database..." + if os.access(test_database_name, os.F_OK): + if not autoclobber: + confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name) + if autoclobber or confirm == 'yes': + try: + if verbosity >= 1: + print "Destroying old test database..." + os.remove(test_database_name) + except Exception, e: + sys.stderr.write("Got an error deleting the old test database: %s\n" % e) + sys.exit(2) + else: + print "Tests cancelled." + sys.exit(1) + if verbosity >= 1: + print "Creating test database..." + else: + test_database_name = ":memory:" + return test_database_name + + def _destroy_test_db(self, test_database_name, verbosity): + if test_database_name and test_database_name != ":memory:": + # Remove the SQLite database file + os.remove(test_database_name) + \ No newline at end of file diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 52b880aac2..5e26f33ea6 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -1,84 +1,30 @@ -from django.db.backends.sqlite3.base import DatabaseOperations - -quote_name = DatabaseOperations().quote_name - -def get_table_list(cursor): - "Returns a list of table names in the current database." - # Skip the sqlite_sequence system table used for autoincrement key - # generation. - cursor.execute(""" - SELECT name FROM sqlite_master - WHERE type='table' AND NOT name='sqlite_sequence' - ORDER BY name""") - return [row[0] for row in cursor.fetchall()] - -def get_table_description(cursor, table_name): - "Returns a description of the table, with the DB-API cursor.description interface." - return [(info['name'], info['type'], None, None, None, None, - info['null_ok']) for info in _table_info(cursor, table_name)] - -def get_relations(cursor, table_name): - raise NotImplementedError - -def get_indexes(cursor, table_name): - """ - Returns a dictionary of fieldname -> infodict for the given table, - where each infodict is in the format: - {'primary_key': boolean representing whether it's the primary key, - 'unique': boolean representing whether it's a unique index} - """ - indexes = {} - for info in _table_info(cursor, table_name): - indexes[info['name']] = {'primary_key': info['pk'] != 0, - 'unique': False} - cursor.execute('PRAGMA index_list(%s)' % quote_name(table_name)) - # seq, name, unique - for index, unique in [(field[1], field[2]) for field in cursor.fetchall()]: - if not unique: - continue - cursor.execute('PRAGMA index_info(%s)' % quote_name(index)) - info = cursor.fetchall() - # Skip indexes across multiple fields - if len(info) != 1: - continue - name = info[0][2] # seqno, cid, name - indexes[name]['unique'] = True - return indexes - -def _table_info(cursor, name): - cursor.execute('PRAGMA table_info(%s)' % quote_name(name)) - # cid, name, type, notnull, dflt_value, pk - return [{'name': field[1], - 'type': field[2], - 'null_ok': not field[3], - 'pk': field[5] # undocumented - } for field in cursor.fetchall()] - -# Maps SQL types to Django Field types. Some of the SQL types have multiple -# entries here because SQLite allows for anything and doesn't normalize the -# field type; it uses whatever was given. -BASE_DATA_TYPES_REVERSE = { - 'bool': 'BooleanField', - 'boolean': 'BooleanField', - 'smallint': 'SmallIntegerField', - 'smallinteger': 'SmallIntegerField', - 'int': 'IntegerField', - 'integer': 'IntegerField', - 'text': 'TextField', - 'char': 'CharField', - 'date': 'DateField', - 'datetime': 'DateTimeField', - 'time': 'TimeField', -} +from django.db.backends import BaseDatabaseIntrospection # This light wrapper "fakes" a dictionary interface, because some SQLite data # types include variables in them -- e.g. "varchar(30)" -- and can't be matched # as a simple dictionary lookup. class FlexibleFieldLookupDict: + # Maps SQL types to Django Field types. Some of the SQL types have multiple + # entries here because SQLite allows for anything and doesn't normalize the + # field type; it uses whatever was given. + base_data_types_reverse = { + 'bool': 'BooleanField', + 'boolean': 'BooleanField', + 'smallint': 'SmallIntegerField', + 'smallinteger': 'SmallIntegerField', + 'int': 'IntegerField', + 'integer': 'IntegerField', + 'text': 'TextField', + 'char': 'CharField', + 'date': 'DateField', + 'datetime': 'DateTimeField', + 'time': 'TimeField', + } + def __getitem__(self, key): key = key.lower() try: - return BASE_DATA_TYPES_REVERSE[key] + return self.base_data_types_reverse[key] except KeyError: import re m = re.search(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$', key) @@ -86,4 +32,58 @@ class FlexibleFieldLookupDict: return ('CharField', {'max_length': int(m.group(1))}) raise KeyError -DATA_TYPES_REVERSE = FlexibleFieldLookupDict() +class DatabaseIntrospection(BaseDatabaseIntrospection): + data_types_reverse = FlexibleFieldLookupDict() + + def get_table_list(self, cursor): + "Returns a list of table names in the current database." + # Skip the sqlite_sequence system table used for autoincrement key + # generation. + cursor.execute(""" + SELECT name FROM sqlite_master + WHERE type='table' AND NOT name='sqlite_sequence' + ORDER BY name""") + return [row[0] for row in cursor.fetchall()] + + def get_table_description(self, cursor, table_name): + "Returns a description of the table, with the DB-API cursor.description interface." + return [(info['name'], info['type'], None, None, None, None, + info['null_ok']) for info in self._table_info(cursor, table_name)] + + def get_relations(self, cursor, table_name): + raise NotImplementedError + + def get_indexes(self, cursor, table_name): + """ + Returns a dictionary of fieldname -> infodict for the given table, + where each infodict is in the format: + {'primary_key': boolean representing whether it's the primary key, + 'unique': boolean representing whether it's a unique index} + """ + indexes = {} + for info in self._table_info(cursor, table_name): + indexes[info['name']] = {'primary_key': info['pk'] != 0, + 'unique': False} + cursor.execute('PRAGMA index_list(%s)' % self.connection.ops.quote_name(table_name)) + # seq, name, unique + for index, unique in [(field[1], field[2]) for field in cursor.fetchall()]: + if not unique: + continue + cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index)) + info = cursor.fetchall() + # Skip indexes across multiple fields + if len(info) != 1: + continue + name = info[0][2] # seqno, cid, name + indexes[name]['unique'] = True + return indexes + + def _table_info(self, cursor, name): + cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(name)) + # cid, name, type, notnull, dflt_value, pk + return [{'name': field[1], + 'type': field[2], + 'null_ok': not field[3], + 'pk': field[5] # undocumented + } for field in cursor.fetchall()] + diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index f19fb258a3..788c8f470d 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -7,7 +7,7 @@ try: except ImportError: from django.utils import _decimal as decimal # for Python 2.3 -from django.db import connection, get_creation_module +from django.db import connection from django.db.models import signals from django.db.models.query_utils import QueryWrapper from django.dispatch import dispatcher @@ -145,14 +145,14 @@ class Field(object): # as the TextField Django field type, which means XMLField's # get_internal_type() returns 'TextField'. # - # But the limitation of the get_internal_type() / DATA_TYPES approach + # But the limitation of the get_internal_type() / data_types approach # is that it cannot handle database column types that aren't already # mapped to one of the built-in Django field types. In this case, you # can implement db_type() instead of get_internal_type() to specify # exactly which wacky database column type you want to use. data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_") try: - return get_creation_module().DATA_TYPES[self.get_internal_type()] % data + return connection.creation.data_types[self.get_internal_type()] % data except KeyError: return None diff --git a/django/test/simple.py b/django/test/simple.py index 43e1156a0a..ce9f59e90e 100644 --- a/django/test/simple.py +++ b/django/test/simple.py @@ -3,7 +3,6 @@ from django.conf import settings from django.db.models import get_app, get_apps from django.test import _doctest as doctest from django.test.utils import setup_test_environment, teardown_test_environment -from django.test.utils import create_test_db, destroy_test_db from django.test.testcases import OutputChecker, DocTestRunner # The module name for tests outside models.py @@ -139,9 +138,10 @@ def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[]): suite.addTest(test) old_name = settings.DATABASE_NAME - create_test_db(verbosity, autoclobber=not interactive) + from django.db import connection + connection.creation.create_test_db(verbosity, autoclobber=not interactive) result = unittest.TextTestRunner(verbosity=verbosity).run(suite) - destroy_test_db(old_name, verbosity) + connection.creation.destroy_test_db(old_name, verbosity) teardown_test_environment() diff --git a/django/test/utils.py b/django/test/utils.py index 733307a1c0..69bd25bc12 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -1,16 +1,11 @@ import sys, time, os from django.conf import settings -from django.db import connection, get_creation_module +from django.db import connection from django.core import mail -from django.core.management import call_command from django.test import signals from django.template import Template from django.utils.translation import deactivate -# The prefix to put on the default database name when creating -# the test database. -TEST_DATABASE_PREFIX = 'test_' - def instrumented_test_render(self, context): """ An instrumented Template render method, providing a signal @@ -70,147 +65,3 @@ def teardown_test_environment(): del mail.outbox -def _set_autocommit(connection): - "Make sure a connection is in autocommit mode." - if hasattr(connection.connection, "autocommit"): - if callable(connection.connection.autocommit): - connection.connection.autocommit(True) - else: - connection.connection.autocommit = True - elif hasattr(connection.connection, "set_isolation_level"): - connection.connection.set_isolation_level(0) - -def get_mysql_create_suffix(): - suffix = [] - if settings.TEST_DATABASE_CHARSET: - suffix.append('CHARACTER SET %s' % settings.TEST_DATABASE_CHARSET) - if settings.TEST_DATABASE_COLLATION: - suffix.append('COLLATE %s' % settings.TEST_DATABASE_COLLATION) - return ' '.join(suffix) - -def get_postgresql_create_suffix(): - assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time." - if settings.TEST_DATABASE_CHARSET: - return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET - return '' - -def create_test_db(verbosity=1, autoclobber=False): - """ - Creates a test database, prompting the user for confirmation if the - database already exists. Returns the name of the test database created. - """ - # If the database backend wants to create the test DB itself, let it - creation_module = get_creation_module() - if hasattr(creation_module, "create_test_db"): - creation_module.create_test_db(settings, connection, verbosity, autoclobber) - return - - if verbosity >= 1: - print "Creating test database..." - # If we're using SQLite, it's more convenient to test against an - # in-memory database. Using the TEST_DATABASE_NAME setting you can still choose - # to run on a physical database. - if settings.DATABASE_ENGINE == "sqlite3": - if settings.TEST_DATABASE_NAME and settings.TEST_DATABASE_NAME != ":memory:": - TEST_DATABASE_NAME = settings.TEST_DATABASE_NAME - # Erase the old test database - if verbosity >= 1: - print "Destroying old test database..." - if os.access(TEST_DATABASE_NAME, os.F_OK): - if not autoclobber: - confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % TEST_DATABASE_NAME) - if autoclobber or confirm == 'yes': - try: - if verbosity >= 1: - print "Destroying old test database..." - os.remove(TEST_DATABASE_NAME) - except Exception, e: - sys.stderr.write("Got an error deleting the old test database: %s\n" % e) - sys.exit(2) - else: - print "Tests cancelled." - sys.exit(1) - if verbosity >= 1: - print "Creating test database..." - else: - TEST_DATABASE_NAME = ":memory:" - else: - suffix = { - 'postgresql': get_postgresql_create_suffix, - 'postgresql_psycopg2': get_postgresql_create_suffix, - 'mysql': get_mysql_create_suffix, - }.get(settings.DATABASE_ENGINE, lambda: '')() - if settings.TEST_DATABASE_NAME: - TEST_DATABASE_NAME = settings.TEST_DATABASE_NAME - else: - TEST_DATABASE_NAME = TEST_DATABASE_PREFIX + settings.DATABASE_NAME - - qn = connection.ops.quote_name - - # Create the test database and connect to it. We need to autocommit - # if the database supports it because PostgreSQL doesn't allow - # CREATE/DROP DATABASE statements within transactions. - cursor = connection.cursor() - _set_autocommit(connection) - try: - cursor.execute("CREATE DATABASE %s %s" % (qn(TEST_DATABASE_NAME), suffix)) - except Exception, e: - sys.stderr.write("Got an error creating the test database: %s\n" % e) - if not autoclobber: - confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % TEST_DATABASE_NAME) - if autoclobber or confirm == 'yes': - try: - if verbosity >= 1: - print "Destroying old test database..." - cursor.execute("DROP DATABASE %s" % qn(TEST_DATABASE_NAME)) - if verbosity >= 1: - print "Creating test database..." - cursor.execute("CREATE DATABASE %s %s" % (qn(TEST_DATABASE_NAME), suffix)) - except Exception, e: - sys.stderr.write("Got an error recreating the test database: %s\n" % e) - sys.exit(2) - else: - print "Tests cancelled." - sys.exit(1) - - connection.close() - settings.DATABASE_NAME = TEST_DATABASE_NAME - - call_command('syncdb', verbosity=verbosity, interactive=False) - - if settings.CACHE_BACKEND.startswith('db://'): - cache_name = settings.CACHE_BACKEND[len('db://'):] - call_command('createcachetable', cache_name) - - # Get a cursor (even though we don't need one yet). This has - # the side effect of initializing the test database. - cursor = connection.cursor() - - return TEST_DATABASE_NAME - -def destroy_test_db(old_database_name, verbosity=1): - # If the database wants to drop the test DB itself, let it - creation_module = get_creation_module() - if hasattr(creation_module, "destroy_test_db"): - creation_module.destroy_test_db(settings, connection, old_database_name, verbosity) - return - - if verbosity >= 1: - print "Destroying test database..." - connection.close() - TEST_DATABASE_NAME = settings.DATABASE_NAME - settings.DATABASE_NAME = old_database_name - if settings.DATABASE_ENGINE == "sqlite3": - if TEST_DATABASE_NAME and TEST_DATABASE_NAME != ":memory:": - # Remove the SQLite database file - os.remove(TEST_DATABASE_NAME) - else: - # Remove the test database to clean up after - # ourselves. Connect to the previous database (not the test database) - # to do so, because it's not allowed to delete a database while being - # connected to it. - cursor = connection.cursor() - _set_autocommit(connection) - time.sleep(1) # To avoid "database is being accessed by other users" errors. - cursor.execute("DROP DATABASE %s" % connection.ops.quote_name(TEST_DATABASE_NAME)) - connection.close() diff --git a/docs/testing.txt b/docs/testing.txt index c9b23c2948..85f36518a3 100644 --- a/docs/testing.txt +++ b/docs/testing.txt @@ -1026,6 +1026,9 @@ a number of utility methods in the ``django.test.utils`` module. black magic hooks into the template system and restoring normal e-mail services. +The creation module of the database backend (``connection.creation``) also +provides some utilities that can be useful during testing. + ``create_test_db(verbosity=1, autoclobber=False)`` Creates a new test database and runs ``syncdb`` against it. @@ -1044,7 +1047,7 @@ a number of utility methods in the ``django.test.utils`` module. ``create_test_db()`` has the side effect of modifying ``settings.DATABASE_NAME`` to match the name of the test database. - New in the Django development version, this function returns the name of + **New in Django development version:** This function returns the name of the test database that it created. ``destroy_test_db(old_database_name, verbosity=1)`` diff --git a/tests/regressiontests/backends/models.py b/tests/regressiontests/backends/models.py index a041ab6b12..61b7b1af73 100644 --- a/tests/regressiontests/backends/models.py +++ b/tests/regressiontests/backends/models.py @@ -15,10 +15,6 @@ class Person(models.Model): def __unicode__(self): return u'%s %s' % (self.first_name, self.last_name) -if connection.features.uses_case_insensitive_names: - t_convert = lambda x: x.upper() -else: - t_convert = lambda x: x qn = connection.ops.quote_name __test__ = {'API_TESTS': """ @@ -29,7 +25,7 @@ __test__ = {'API_TESTS': """ >>> opts = Square._meta >>> f1, f2 = opts.get_field('root'), opts.get_field('square') >>> query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)' -... % (t_convert(opts.db_table), qn(f1.column), qn(f2.column))) +... % (connection.introspection.table_name_converter(opts.db_table), qn(f1.column), qn(f2.column))) >>> cursor.executemany(query, [(i, i**2) for i in range(-5, 6)]) and None or None >>> Square.objects.order_by('root') [, , , , , , , , , , ] @@ -48,7 +44,7 @@ __test__ = {'API_TESTS': """ >>> opts2 = Person._meta >>> f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name') >>> query2 = ('SELECT %s, %s FROM %s ORDER BY %s' -... % (qn(f3.column), qn(f4.column), t_convert(opts2.db_table), +... % (qn(f3.column), qn(f4.column), connection.introspection.table_name_converter(opts2.db_table), ... qn(f3.column))) >>> cursor.execute(query2) and None or None >>> cursor.fetchone()