Fixed #5461 -- Refactored the database backend code to use classes for the creation and introspection modules. Introduces a new validation module for DB-specific validation. This is a backwards incompatible change; see the wiki for details.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@8296 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Russell Keith-Magee 2008-08-11 12:11:25 +00:00
parent cec69eb70d
commit 9dc4ba875f
43 changed files with 1528 additions and 1423 deletions

View File

@ -1,5 +1,5 @@
from django.test.utils import create_test_db
def create_spatial_db(test=True, verbosity=1, autoclobber=False):
if not test: raise NotImplementedError('This uses `create_test_db` from test/utils.py')
create_test_db(verbosity, autoclobber)
from django.db import connection
connection.creation.create_test_db(verbosity, autoclobber)

View File

@ -1,8 +1,6 @@
from django.db.backends.oracle.creation import create_test_db
def create_spatial_db(test=True, verbosity=1, autoclobber=False):
"A wrapper over the Oracle `create_test_db` routine."
if not test: raise NotImplementedError('This uses `create_test_db` from db/backends/oracle/creation.py')
from django.conf import settings
from django.db import connection
create_test_db(settings, connection, verbosity, autoclobber)
connection.creation.create_test_db(verbosity, autoclobber)

View File

@ -1,7 +1,7 @@
from django.conf import settings
from django.core.management import call_command
from django.db import connection
from django.test.utils import _set_autocommit, TEST_DATABASE_PREFIX
from django.db.backends.creation import TEST_DATABASE_PREFIX
import os, re, sys
def getstatusoutput(cmd):
@ -38,9 +38,9 @@ def _create_with_cursor(db_name, verbosity=1, autoclobber=False):
create_sql = 'CREATE DATABASE %s' % connection.ops.quote_name(db_name)
if settings.DATABASE_USER:
create_sql += ' OWNER %s' % settings.DATABASE_USER
cursor = connection.cursor()
_set_autocommit(connection)
connection.creation.set_autocommit(connection)
try:
# Trying to create the database first.
@ -58,12 +58,12 @@ def _create_with_cursor(db_name, verbosity=1, autoclobber=False):
else:
raise Exception('Spatial Database Creation canceled.')
foo = _create_with_cursor
created_regex = re.compile(r'^createdb: database creation failed: ERROR: database ".+" already exists')
def _create_with_shell(db_name, verbosity=1, autoclobber=False):
"""
If no spatial database already exists, then using a cursor will not work.
Thus, a `createdb` command will be issued through the shell to bootstrap
If no spatial database already exists, then using a cursor will not work.
Thus, a `createdb` command will be issued through the shell to bootstrap
creation of the spatial database.
"""
@ -83,7 +83,7 @@ def _create_with_shell(db_name, verbosity=1, autoclobber=False):
if verbosity >= 1: print 'Destroying old spatial database...'
drop_cmd = 'dropdb %s%s' % (options, db_name)
status, output = getstatusoutput(drop_cmd)
if status != 0:
if status != 0:
raise Exception('Could not drop database %s: %s' % (db_name, output))
if verbosity >= 1: print 'Creating new spatial database...'
status, output = getstatusoutput(create_cmd)
@ -102,10 +102,10 @@ def create_spatial_db(test=False, verbosity=1, autoclobber=False, interactive=Fa
raise Exception('Spatial database creation only supported postgresql_psycopg2 platform.')
# Getting the spatial database name
if test:
if test:
db_name = get_spatial_db(test=True)
_create_with_cursor(db_name, verbosity=verbosity, autoclobber=autoclobber)
else:
else:
db_name = get_spatial_db()
_create_with_shell(db_name, verbosity=verbosity, autoclobber=autoclobber)
@ -125,7 +125,7 @@ def create_spatial_db(test=False, verbosity=1, autoclobber=False, interactive=Fa
# Syncing the database
call_command('syncdb', verbosity=verbosity, interactive=interactive)
def drop_db(db_name=False, test=False):
"""
Drops the given database (defaults to what is returned from
@ -151,7 +151,7 @@ def get_cmd_options(db_name):
def get_spatial_db(test=False):
"""
Returns the name of the spatial database. The 'test' keyword may be set
Returns the name of the spatial database. The 'test' keyword may be set
to return the test spatial database name.
"""
if test:
@ -167,13 +167,13 @@ def get_spatial_db(test=False):
def load_postgis_sql(db_name, verbosity=1):
"""
This routine loads up the PostGIS SQL files lwpostgis.sql and
This routine loads up the PostGIS SQL files lwpostgis.sql and
spatial_ref_sys.sql.
"""
# Getting the path to the PostGIS SQL
try:
# POSTGIS_SQL_PATH may be placed in settings to tell GeoDjango where the
# POSTGIS_SQL_PATH may be placed in settings to tell GeoDjango where the
# PostGIS SQL files are located. This is especially useful on Win32
# platforms since the output of pg_config looks like "C:/PROGRA~1/..".
sql_path = settings.POSTGIS_SQL_PATH
@ -193,7 +193,7 @@ def load_postgis_sql(db_name, verbosity=1):
# Getting the psql command-line options, and command format.
options = get_cmd_options(db_name)
cmd_fmt = 'psql %s-f "%%s"' % options
# Now trying to load up the PostGIS functions
cmd = cmd_fmt % lwpostgis_file
if verbosity >= 1: print cmd
@ -211,8 +211,8 @@ def load_postgis_sql(db_name, verbosity=1):
# Setting the permissions because on Windows platforms the owner
# of the spatial_ref_sys and geometry_columns tables is always
# the postgres user, regardless of how the db is created.
if os.name == 'nt': set_permissions(db_name)
if os.name == 'nt': set_permissions(db_name)
def set_permissions(db_name):
"""
Sets the permissions on the given database to that of the user specified

View File

@ -7,7 +7,7 @@ from django.core.management.commands.inspectdb import Command as InspectCommand
from django.contrib.gis.db.backend import SpatialBackend
class Command(InspectCommand):
# Mapping from lower-case OGC type to the corresponding GeoDjango field.
geofield_mapping = {'point' : 'PointField',
'linestring' : 'LineStringField',
@ -21,11 +21,11 @@ class Command(InspectCommand):
def geometry_columns(self):
"""
Returns a datastructure of metadata information associated with the
Returns a datastructure of metadata information associated with the
`geometry_columns` (or equivalent) table.
"""
# The `geo_cols` is a dictionary data structure that holds information
# about any geographic columns in the database.
# about any geographic columns in the database.
geo_cols = {}
def add_col(table, column, coldata):
if table in geo_cols:
@ -47,7 +47,7 @@ class Command(InspectCommand):
elif SpatialBackend.name == 'mysql':
# On MySQL have to get all table metadata before hand; this means walking through
# each table and seeing if any column types are spatial. Can't detect this with
# `cursor.description` (what the introspection module does) because all spatial types
# `cursor.description` (what the introspection module does) because all spatial types
# have the same integer type (255 for GEOMETRY).
from django.db import connection
cursor = connection.cursor()
@ -67,13 +67,11 @@ class Command(InspectCommand):
def handle_inspection(self):
"Overloaded from Django's version to handle geographic database tables."
from django.db import connection, get_introspection_module
from django.db import connection
import keyword
introspection_module = get_introspection_module()
geo_cols = self.geometry_columns()
table2model = lambda table_name: table_name.title().replace('_', '')
cursor = connection.cursor()
@ -88,20 +86,20 @@ class Command(InspectCommand):
yield ''
yield 'from django.contrib.gis.db import models'
yield ''
for table_name in introspection_module.get_table_list(cursor):
for table_name in connection.introspection.get_table_list(cursor):
# Getting the geographic table dictionary.
geo_table = geo_cols.get(table_name, {})
yield 'class %s(models.Model):' % table2model(table_name)
try:
relations = introspection_module.get_relations(cursor, table_name)
relations = connection.introspection.get_relations(cursor, table_name)
except NotImplementedError:
relations = {}
try:
indexes = introspection_module.get_indexes(cursor, table_name)
indexes = connection.introspection.get_indexes(cursor, table_name)
except NotImplementedError:
indexes = {}
for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)):
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
att_name, iatt_name = row[0].lower(), row[0]
comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
extra_params = {} # Holds Field parameters such as 'db_column'.
@ -133,12 +131,12 @@ class Command(InspectCommand):
if srid != 4326: extra_params['srid'] = srid
else:
try:
field_type = introspection_module.DATA_TYPES_REVERSE[row[1]]
field_type = connection.introspection.data_types_reverse[row[1]]
except KeyError:
field_type = 'TextField'
comment_notes.append('This field type is a guess.')
# This is a hook for DATA_TYPES_REVERSE to return a tuple of
# This is a hook for data_types_reverse to return a tuple of
# (field_type, extra_params_dict).
if type(field_type) is tuple:
field_type, new_params = field_type

View File

@ -6,5 +6,5 @@ class Command(NoArgsCommand):
requires_model_validation = False
def handle_noargs(self, **options):
from django.db import runshell
runshell()
from django.db import connection
connection.client.runshell()

View File

@ -13,11 +13,9 @@ class Command(NoArgsCommand):
raise CommandError("Database inspection isn't supported for the currently selected database backend.")
def handle_inspection(self):
from django.db import connection, get_introspection_module
from django.db import connection
import keyword
introspection_module = get_introspection_module()
table2model = lambda table_name: table_name.title().replace('_', '')
cursor = connection.cursor()
@ -32,17 +30,17 @@ class Command(NoArgsCommand):
yield ''
yield 'from django.db import models'
yield ''
for table_name in introspection_module.get_table_list(cursor):
for table_name in connection.introspection.get_table_list(cursor):
yield 'class %s(models.Model):' % table2model(table_name)
try:
relations = introspection_module.get_relations(cursor, table_name)
relations = connection.introspection.get_relations(cursor, table_name)
except NotImplementedError:
relations = {}
try:
indexes = introspection_module.get_indexes(cursor, table_name)
indexes = connection.introspection.get_indexes(cursor, table_name)
except NotImplementedError:
indexes = {}
for i, row in enumerate(introspection_module.get_table_description(cursor, table_name)):
for i, row in enumerate(connection.introspection.get_table_description(cursor, table_name)):
att_name = row[0].lower()
comment_notes = [] # Holds Field notes, to be displayed in a Python comment.
extra_params = {} # Holds Field parameters such as 'db_column'.
@ -65,7 +63,7 @@ class Command(NoArgsCommand):
extra_params['db_column'] = att_name
else:
try:
field_type = introspection_module.DATA_TYPES_REVERSE[row[1]]
field_type = connection.introspection.data_types_reverse[row[1]]
except KeyError:
field_type = 'TextField'
comment_notes.append('This field type is a guess.')

View File

@ -21,7 +21,7 @@ class Command(NoArgsCommand):
def handle_noargs(self, **options):
from django.db import connection, transaction, models
from django.conf import settings
from django.core.management.sql import table_names, installed_models, sql_model_create, sql_for_pending_references, many_to_many_sql_for_model, custom_sql_for_model, sql_indexes_for_model, emit_post_sync_signal
from django.core.management.sql import custom_sql_for_model, emit_post_sync_signal
verbosity = int(options.get('verbosity', 1))
interactive = options.get('interactive')
@ -50,16 +50,9 @@ class Command(NoArgsCommand):
cursor = connection.cursor()
if connection.features.uses_case_insensitive_names:
table_name_converter = lambda x: x.upper()
else:
table_name_converter = lambda x: x
# Get a list of all existing database tables, so we know what needs to
# be added.
tables = [table_name_converter(name) for name in table_names()]
# Get a list of already installed *models* so that references work right.
seen_models = installed_models(tables)
tables = connection.introspection.table_names()
seen_models = connection.introspection.installed_models(tables)
created_models = set()
pending_references = {}
@ -71,21 +64,21 @@ class Command(NoArgsCommand):
# Create the model's database table, if it doesn't already exist.
if verbosity >= 2:
print "Processing %s.%s model" % (app_name, model._meta.object_name)
if table_name_converter(model._meta.db_table) in tables:
if connection.introspection.table_name_converter(model._meta.db_table) in tables:
continue
sql, references = sql_model_create(model, self.style, seen_models)
sql, references = connection.creation.sql_create_model(model, self.style, seen_models)
seen_models.add(model)
created_models.add(model)
for refto, refs in references.items():
pending_references.setdefault(refto, []).extend(refs)
if refto in seen_models:
sql.extend(sql_for_pending_references(refto, self.style, pending_references))
sql.extend(sql_for_pending_references(model, self.style, pending_references))
sql.extend(connection.creation.sql_for_pending_references(refto, self.style, pending_references))
sql.extend(connection.creation.sql_for_pending_references(model, self.style, pending_references))
if verbosity >= 1:
print "Creating table %s" % model._meta.db_table
for statement in sql:
cursor.execute(statement)
tables.append(table_name_converter(model._meta.db_table))
tables.append(connection.introspection.table_name_converter(model._meta.db_table))
# Create the m2m tables. This must be done after all tables have been created
# to ensure that all referred tables will exist.
@ -94,7 +87,7 @@ class Command(NoArgsCommand):
model_list = models.get_models(app)
for model in model_list:
if model in created_models:
sql = many_to_many_sql_for_model(model, self.style)
sql = connection.creation.sql_for_many_to_many(model, self.style)
if sql:
if verbosity >= 2:
print "Creating many-to-many tables for %s.%s model" % (app_name, model._meta.object_name)
@ -140,7 +133,7 @@ class Command(NoArgsCommand):
app_name = app.__name__.split('.')[-2]
for model in models.get_models(app):
if model in created_models:
index_sql = sql_indexes_for_model(model, self.style)
index_sql = connection.creation.sql_indexes_for_model(model, self.style)
if index_sql:
if verbosity >= 1:
print "Installing index for %s.%s model" % (app_name, model._meta.object_name)

View File

@ -18,13 +18,13 @@ class Command(BaseCommand):
def handle(self, *fixture_labels, **options):
from django.core.management import call_command
from django.test.utils import create_test_db
from django.db import connection
verbosity = int(options.get('verbosity', 1))
addrport = options.get('addrport')
# Create a test database.
db_name = create_test_db(verbosity=verbosity)
db_name = connection.creation.create_test_db(verbosity=verbosity)
# Import the fixture data into the test database.
call_command('loaddata', *fixture_labels, **{'verbosity': verbosity})

View File

@ -7,65 +7,9 @@ try:
except NameError:
from sets import Set as set # Python 2.3 fallback
def table_names():
"Returns a list of all table names that exist in the database."
from django.db import connection, get_introspection_module
cursor = connection.cursor()
return set(get_introspection_module().get_table_list(cursor))
def django_table_names(only_existing=False):
"""
Returns a list of all table names that have associated Django models and
are in INSTALLED_APPS.
If only_existing is True, the resulting list will only include the tables
that actually exist in the database.
"""
from django.db import models
tables = set()
for app in models.get_apps():
for model in models.get_models(app):
tables.add(model._meta.db_table)
tables.update([f.m2m_db_table() for f in model._meta.local_many_to_many])
if only_existing:
tables = [t for t in tables if t in table_names()]
return tables
def installed_models(table_list):
"Returns a set of all models that are installed, given a list of existing table names."
from django.db import connection, models
all_models = []
for app in models.get_apps():
for model in models.get_models(app):
all_models.append(model)
if connection.features.uses_case_insensitive_names:
converter = lambda x: x.upper()
else:
converter = lambda x: x
return set([m for m in all_models if converter(m._meta.db_table) in map(converter, table_list)])
def sequence_list():
"Returns a list of information about all DB sequences for all models in all apps."
from django.db import models
apps = models.get_apps()
sequence_list = []
for app in apps:
for model in models.get_models(app):
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.local_many_to_many:
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
def sql_create(app, style):
"Returns a list of the CREATE TABLE SQL statements for the given app."
from django.db import models
from django.db import connection, models
from django.conf import settings
if settings.DATABASE_ENGINE == 'dummy':
@ -81,23 +25,24 @@ def sql_create(app, style):
# we can be conservative).
app_models = models.get_models(app)
final_output = []
known_models = set([model for model in installed_models(table_names()) if model not in app_models])
tables = connection.introspection.table_names()
known_models = set([model for model in connection.introspection.installed_models(tables) if model not in app_models])
pending_references = {}
for model in app_models:
output, references = sql_model_create(model, style, known_models)
output, references = connection.creation.sql_create_model(model, style, known_models)
final_output.extend(output)
for refto, refs in references.items():
pending_references.setdefault(refto, []).extend(refs)
if refto in known_models:
final_output.extend(sql_for_pending_references(refto, style, pending_references))
final_output.extend(sql_for_pending_references(model, style, pending_references))
final_output.extend(connection.creation.sql_for_pending_references(refto, style, pending_references))
final_output.extend(connection.creation.sql_for_pending_references(model, style, pending_references))
# Keep track of the fact that we've created the table for this model.
known_models.add(model)
# Create the many-to-many join tables.
for model in app_models:
final_output.extend(many_to_many_sql_for_model(model, style))
final_output.extend(connection.creation.sql_for_many_to_many(model, style))
# Handle references to tables that are from other apps
# but don't exist physically.
@ -106,7 +51,7 @@ def sql_create(app, style):
alter_sql = []
for model in not_installed_models:
alter_sql.extend(['-- ' + sql for sql in
sql_for_pending_references(model, style, pending_references)])
connection.creation.sql_for_pending_references(model, style, pending_references)])
if alter_sql:
final_output.append('-- The following references should be added but depend on non-existent tables:')
final_output.extend(alter_sql)
@ -115,10 +60,9 @@ def sql_create(app, style):
def sql_delete(app, style):
"Returns a list of the DROP TABLE SQL statements for the given app."
from django.db import connection, models, get_introspection_module
from django.db import connection, models
from django.db.backends.util import truncate_name
from django.contrib.contenttypes import generic
introspection = get_introspection_module()
# This should work even if a connection isn't available
try:
@ -128,16 +72,11 @@ def sql_delete(app, style):
# Figure out which tables already exist
if cursor:
table_names = introspection.get_table_list(cursor)
table_names = connection.introspection.get_table_list(cursor)
else:
table_names = []
if connection.features.uses_case_insensitive_names:
table_name_converter = lambda x: x.upper()
else:
table_name_converter = lambda x: x
output = []
qn = connection.ops.quote_name
# Output DROP TABLE statements for standard application tables.
to_delete = set()
@ -145,7 +84,7 @@ def sql_delete(app, style):
references_to_delete = {}
app_models = models.get_models(app)
for model in app_models:
if cursor and table_name_converter(model._meta.db_table) in table_names:
if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
# The table exists, so it needs to be dropped
opts = model._meta
for f in opts.local_fields:
@ -155,40 +94,15 @@ def sql_delete(app, style):
to_delete.add(model)
for model in app_models:
if cursor and table_name_converter(model._meta.db_table) in table_names:
# Drop the table now
output.append('%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(model._meta.db_table))))
if connection.features.supports_constraints and model in references_to_delete:
for rel_class, f in references_to_delete[model]:
table = rel_class._meta.db_table
col = f.column
r_table = model._meta.db_table
r_col = model._meta.get_field(f.rel.field_name).column
r_name = '%s_refs_%s_%x' % (col, r_col, abs(hash((table, r_table))))
output.append('%s %s %s %s;' % \
(style.SQL_KEYWORD('ALTER TABLE'),
style.SQL_TABLE(qn(table)),
style.SQL_KEYWORD(connection.ops.drop_foreignkey_sql()),
style.SQL_FIELD(truncate_name(r_name, connection.ops.max_name_length()))))
del references_to_delete[model]
if model._meta.has_auto_field:
ds = connection.ops.drop_sequence_sql(model._meta.db_table)
if ds:
output.append(ds)
if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
# Output DROP TABLE statements for many-to-many tables.
for model in app_models:
opts = model._meta
for f in opts.local_many_to_many:
if not f.creates_table:
continue
if cursor and table_name_converter(f.m2m_db_table()) in table_names:
output.append("%s %s;" % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(f.m2m_db_table()))))
ds = connection.ops.drop_sequence_sql("%s_%s" % (model._meta.db_table, f.column))
if ds:
output.append(ds)
if cursor and connection.introspection.table_name_converter(f.m2m_db_table()) in table_names:
output.extend(connection.creation.sql_destroy_many_to_many(model, f, style))
app_label = app_models[0]._meta.app_label
@ -213,10 +127,10 @@ def sql_flush(style, only_django=False):
"""
from django.db import connection
if only_django:
tables = django_table_names()
tables = connection.introspection.django_table_names()
else:
tables = table_names()
statements = connection.ops.sql_flush(style, tables, sequence_list())
tables = connection.introspection.table_names()
statements = connection.ops.sql_flush(style, tables, connection.introspection.sequence_list())
return statements
def sql_custom(app, style):
@ -234,198 +148,16 @@ def sql_custom(app, style):
def sql_indexes(app, style):
"Returns a list of the CREATE INDEX SQL statements for all models in the given app."
from django.db import models
from django.db import connection, models
output = []
for model in models.get_models(app):
output.extend(sql_indexes_for_model(model, style))
output.extend(connection.creation.sql_indexes_for_model(model, style))
return output
def sql_all(app, style):
"Returns a list of CREATE TABLE SQL, initial-data inserts, and CREATE INDEX SQL for the given module."
return sql_create(app, style) + sql_custom(app, style) + sql_indexes(app, style)
def sql_model_create(model, style, known_models=set()):
"""
Returns the SQL required to create a single model, as a tuple of:
(list_of_sql, pending_references_dict)
"""
from django.db import connection, models
opts = model._meta
final_output = []
table_output = []
pending_references = {}
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
for f in opts.local_fields:
col_type = f.db_type()
tablespace = f.db_tablespace or opts.db_tablespace
if col_type is None:
# Skip ManyToManyFields, because they're not represented as
# database columns in this table.
continue
# Make the definition (e.g. 'foo VARCHAR(30)') for this field.
field_output = [style.SQL_FIELD(qn(f.column)),
style.SQL_COLTYPE(col_type)]
field_output.append(style.SQL_KEYWORD('%sNULL' % (not f.null and 'NOT ' or '')))
if f.primary_key:
field_output.append(style.SQL_KEYWORD('PRIMARY KEY'))
elif f.unique:
field_output.append(style.SQL_KEYWORD('UNIQUE'))
if tablespace and connection.features.supports_tablespaces and f.unique:
# We must specify the index tablespace inline, because we
# won't be generating a CREATE INDEX statement for this field.
field_output.append(connection.ops.tablespace_sql(tablespace, inline=True))
if f.rel:
if inline_references and f.rel.to in known_models:
field_output.append(style.SQL_KEYWORD('REFERENCES') + ' ' + \
style.SQL_TABLE(qn(f.rel.to._meta.db_table)) + ' (' + \
style.SQL_FIELD(qn(f.rel.to._meta.get_field(f.rel.field_name).column)) + ')' +
connection.ops.deferrable_sql()
)
else:
# We haven't yet created the table to which this field
# is related, so save it for later.
pr = pending_references.setdefault(f.rel.to, []).append((model, f))
table_output.append(' '.join(field_output))
if opts.order_with_respect_to:
table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \
style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \
style.SQL_KEYWORD('NULL'))
for field_constraints in opts.unique_together:
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
for i, line in enumerate(table_output): # Combine and add commas.
full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
full_statement.append(')')
if opts.db_tablespace and connection.features.supports_tablespaces:
full_statement.append(connection.ops.tablespace_sql(opts.db_tablespace))
full_statement.append(';')
final_output.append('\n'.join(full_statement))
if opts.has_auto_field:
# Add any extra SQL needed to support auto-incrementing primary keys.
auto_column = opts.auto_field.db_column or opts.auto_field.name
autoinc_sql = connection.ops.autoinc_sql(opts.db_table, auto_column)
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output, pending_references
def sql_for_pending_references(model, style, pending_references):
"""
Returns any ALTER TABLE statements to add constraints after the fact.
"""
from django.db import connection
from django.db.backends.util import truncate_name
qn = connection.ops.quote_name
final_output = []
if connection.features.supports_constraints:
opts = model._meta
if model in pending_references:
for rel_class, f in pending_references[model]:
rel_opts = rel_class._meta
r_table = rel_opts.db_table
r_col = f.column
table = opts.db_table
col = opts.get_field(f.rel.field_name).column
# For MySQL, r_name must be unique in the first 64 characters.
# So we are careful with character usage here.
r_name = '%s_refs_%s_%x' % (r_col, col, abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \
(qn(r_table), truncate_name(r_name, connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
connection.ops.deferrable_sql()))
del pending_references[model]
return final_output
def many_to_many_sql_for_model(model, style):
from django.db import connection, models
from django.contrib.contenttypes import generic
from django.db.backends.util import truncate_name
opts = model._meta
final_output = []
qn = connection.ops.quote_name
inline_references = connection.features.inline_fk_references
for f in opts.local_many_to_many:
if f.creates_table:
tablespace = f.db_tablespace or opts.db_tablespace
if tablespace and connection.features.supports_tablespaces:
tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace, inline=True)
else:
tablespace_sql = ''
table_output = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + \
style.SQL_TABLE(qn(f.m2m_db_table())) + ' (']
table_output.append(' %s %s %s%s,' %
(style.SQL_FIELD(qn('id')),
style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type()),
style.SQL_KEYWORD('NOT NULL PRIMARY KEY'),
tablespace_sql))
if inline_references:
deferred = []
table_output.append(' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(opts.db_table)),
style.SQL_FIELD(qn(opts.pk.column)),
connection.ops.deferrable_sql()))
table_output.append(' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(f.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(f.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(f.rel.to._meta.db_table)),
style.SQL_FIELD(qn(f.rel.to._meta.pk.column)),
connection.ops.deferrable_sql()))
else:
table_output.append(' %s %s %s,' %
(style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL')))
table_output.append(' %s %s %s,' %
(style.SQL_FIELD(qn(f.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(f.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL')))
deferred = [
(f.m2m_db_table(), f.m2m_column_name(), opts.db_table,
opts.pk.column),
( f.m2m_db_table(), f.m2m_reverse_name(),
f.rel.to._meta.db_table, f.rel.to._meta.pk.column)
]
table_output.append(' %s (%s, %s)%s' %
(style.SQL_KEYWORD('UNIQUE'),
style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_FIELD(qn(f.m2m_reverse_name())),
tablespace_sql))
table_output.append(')')
if opts.db_tablespace and connection.features.supports_tablespaces:
# f.db_tablespace is only for indices, so ignore its value here.
table_output.append(connection.ops.tablespace_sql(opts.db_tablespace))
table_output.append(';')
final_output.append('\n'.join(table_output))
for r_table, r_col, table, col in deferred:
r_name = '%s_refs_%s_%x' % (r_col, col,
abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
(qn(r_table),
truncate_name(r_name, connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
connection.ops.deferrable_sql()))
# Add any extra SQL needed to support auto-incrementing PKs
autoinc_sql = connection.ops.autoinc_sql(f.m2m_db_table(), 'id')
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output
def custom_sql_for_model(model, style):
from django.db import models
from django.conf import settings
@ -461,28 +193,6 @@ def custom_sql_for_model(model, style):
return output
def sql_indexes_for_model(model, style):
"Returns the CREATE INDEX SQL statements for a single model"
from django.db import connection
output = []
qn = connection.ops.quote_name
for f in model._meta.local_fields:
if f.db_index and not f.unique:
tablespace = f.db_tablespace or model._meta.db_tablespace
if tablespace and connection.features.supports_tablespaces:
tablespace_sql = ' ' + connection.ops.tablespace_sql(tablespace)
else:
tablespace_sql = ''
output.append(
style.SQL_KEYWORD('CREATE INDEX') + ' ' + \
style.SQL_TABLE(qn('%s_%s' % (model._meta.db_table, f.column))) + ' ' + \
style.SQL_KEYWORD('ON') + ' ' + \
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' + \
"(%s)" % style.SQL_FIELD(qn(f.column)) + \
"%s;" % tablespace_sql
)
return output
def emit_post_sync_signal(created_models, verbosity, interactive):
from django.db import models

View File

@ -61,11 +61,8 @@ def get_validation_errors(outfile, app=None):
if f.db_index not in (None, True, False):
e.add(opts, '"%s": "db_index" should be either None, True or False.' % f.name)
# Check that max_length <= 255 if using older MySQL versions.
if settings.DATABASE_ENGINE == 'mysql':
db_version = connection.get_server_version()
if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255:
e.add(opts, '"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' % (f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]])))
# Perform any backend-specific field validation.
connection.validation.validate_field(e, opts, f)
# Check to see if the related field will clash with any existing
# fields, m2m fields, m2m related objects or related objects

View File

@ -14,14 +14,12 @@ try:
# backends that ships with Django, so look there first.
_import_path = 'django.db.backends.'
backend = __import__('%s%s.base' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
creation = __import__('%s%s.creation' % (_import_path, settings.DATABASE_ENGINE), {}, {}, [''])
except ImportError, e:
# If the import failed, we might be looking for a database backend
# distributed external to Django. So we'll try that next.
try:
_import_path = ''
backend = __import__('%s.base' % settings.DATABASE_ENGINE, {}, {}, [''])
creation = __import__('%s.creation' % settings.DATABASE_ENGINE, {}, {}, [''])
except ImportError, e_user:
# The database backend wasn't found. Display a helpful error message
# listing all possible (built-in) database backends.
@ -29,27 +27,11 @@ except ImportError, e:
available_backends = [f for f in os.listdir(backend_dir) if not f.startswith('_') and not f.startswith('.') and not f.endswith('.py') and not f.endswith('.pyc')]
available_backends.sort()
if settings.DATABASE_ENGINE not in available_backends:
raise ImproperlyConfigured, "%r isn't an available database backend. Available options are: %s" % \
(settings.DATABASE_ENGINE, ", ".join(map(repr, available_backends)))
raise ImproperlyConfigured, "%r isn't an available database backend. Available options are: %s\nError was: %s" % \
(settings.DATABASE_ENGINE, ", ".join(map(repr, available_backends, e_user)))
else:
raise # If there's some other error, this must be an error in Django itself.
def _import_database_module(import_path='', module_name=''):
"""Lazily import a database module when requested."""
return __import__('%s%s.%s' % (import_path, settings.DATABASE_ENGINE, module_name), {}, {}, [''])
# We don't want to import the introspect module unless someone asks for it, so
# lazily load it on demmand.
get_introspection_module = curry(_import_database_module, _import_path, 'introspection')
def get_creation_module():
return creation
# We want runshell() to work the same way, but we have to treat it a
# little differently (since it just runs instead of returning a module like
# the above) and wrap the lazily-loaded runshell() method.
runshell = lambda: _import_database_module(_import_path, "client").runshell()
# Convenient aliases for backend bits.
connection = backend.DatabaseWrapper(**settings.DATABASE_OPTIONS)
DatabaseError = backend.DatabaseError

View File

@ -42,14 +42,9 @@ class BaseDatabaseWrapper(local):
return util.CursorDebugWrapper(cursor, self)
class BaseDatabaseFeatures(object):
allows_group_by_ordinal = True
inline_fk_references = True
# True if django.db.backend.utils.typecast_timestamp is used on values
# returned from dates() calls.
needs_datetime_string_cast = True
supports_constraints = True
supports_tablespaces = False
uses_case_insensitive_names = False
uses_custom_query_class = False
empty_fetchmany_value = []
update_can_self_select = True
@ -253,13 +248,13 @@ class BaseDatabaseOperations(object):
"""
return "BEGIN;"
def tablespace_sql(self, tablespace, inline=False):
def sql_for_tablespace(self, tablespace, inline=False):
"""
Returns the tablespace SQL, or None if the backend doesn't use
tablespaces.
Returns the SQL that will be appended to tables or rows to define
a tablespace. Returns '' if the backend doesn't use tablespaces.
"""
return None
return ''
def prep_for_like_query(self, x):
"""Prepares a value for use in a LIKE query."""
from django.utils.encoding import smart_unicode
@ -325,3 +320,89 @@ class BaseDatabaseOperations(object):
"""
return self.year_lookup_bounds(value)
class BaseDatabaseIntrospection(object):
"""
This class encapsulates all backend-specific introspection utilities
"""
data_types_reverse = {}
def __init__(self, connection):
self.connection = connection
def table_name_converter(self, name):
"""Apply a conversion to the name for the purposes of comparison.
The default table name converter is for case sensitive comparison.
"""
return name
def table_names(self):
"Returns a list of names of all tables that exist in the database."
cursor = self.connection.cursor()
return self.get_table_list(cursor)
def django_table_names(self, only_existing=False):
"""
Returns a list of all table names that have associated Django models and
are in INSTALLED_APPS.
If only_existing is True, the resulting list will only include the tables
that actually exist in the database.
"""
from django.db import models
tables = set()
for app in models.get_apps():
for model in models.get_models(app):
tables.add(model._meta.db_table)
tables.update([f.m2m_db_table() for f in model._meta.local_many_to_many])
if only_existing:
tables = [t for t in tables if t in self.table_names()]
return tables
def installed_models(self, tables):
"Returns a set of all models represented by the provided list of table names."
from django.db import models
all_models = []
for app in models.get_apps():
for model in models.get_models(app):
all_models.append(model)
return set([m for m in all_models
if self.table_name_converter(m._meta.db_table) in map(self.table_name_converter, tables)
])
def sequence_list(self):
"Returns a list of information about all DB sequences for all models in all apps."
from django.db import models
apps = models.get_apps()
sequence_list = []
for app in apps:
for model in models.get_models(app):
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
sequence_list.append({'table': model._meta.db_table, 'column': f.column})
break # Only one AutoField is allowed per model, so don't bother continuing.
for f in model._meta.local_many_to_many:
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
class BaseDatabaseClient(object):
"""
This class encapsualtes all backend-specific methods for opening a
client shell
"""
def runshell(self):
raise NotImplementedError()
class BaseDatabaseValidation(object):
"""
This class encapsualtes all backend-specific model validation.
"""
def validate_field(self, errors, opts, f):
"By default, there is no backend-specific validation"
pass

View File

@ -1,7 +1,396 @@
class BaseCreation(object):
import sys
import time
from django.conf import settings
from django.core.management import call_command
# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = 'test_'
class BaseDatabaseCreation(object):
"""
This class encapsulates all backend-specific differences that pertain to
database *creation*, such as the column types to use for particular Django
Fields.
Fields, the SQL used to create and destroy tables, and the creation and
destruction of test databases.
"""
pass
data_types = {}
def __init__(self, connection):
self.connection = connection
def sql_create_model(self, model, style, known_models=set()):
"""
Returns the SQL required to create a single model, as a tuple of:
(list_of_sql, pending_references_dict)
"""
from django.db import models
opts = model._meta
final_output = []
table_output = []
pending_references = {}
qn = self.connection.ops.quote_name
for f in opts.local_fields:
col_type = f.db_type()
tablespace = f.db_tablespace or opts.db_tablespace
if col_type is None:
# Skip ManyToManyFields, because they're not represented as
# database columns in this table.
continue
# Make the definition (e.g. 'foo VARCHAR(30)') for this field.
field_output = [style.SQL_FIELD(qn(f.column)),
style.SQL_COLTYPE(col_type)]
field_output.append(style.SQL_KEYWORD('%sNULL' % (not f.null and 'NOT ' or '')))
if f.primary_key:
field_output.append(style.SQL_KEYWORD('PRIMARY KEY'))
elif f.unique:
field_output.append(style.SQL_KEYWORD('UNIQUE'))
if tablespace and f.unique:
# We must specify the index tablespace inline, because we
# won't be generating a CREATE INDEX statement for this field.
field_output.append(self.connection.ops.tablespace_sql(tablespace, inline=True))
if f.rel:
ref_output, pending = self.sql_for_inline_foreign_key_references(f, known_models, style)
if pending:
pr = pending_references.setdefault(f.rel.to, []).append((model, f))
else:
field_output.extend(ref_output)
table_output.append(' '.join(field_output))
if opts.order_with_respect_to:
table_output.append(style.SQL_FIELD(qn('_order')) + ' ' + \
style.SQL_COLTYPE(models.IntegerField().db_type()) + ' ' + \
style.SQL_KEYWORD('NULL'))
for field_constraints in opts.unique_together:
table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % \
", ".join([style.SQL_FIELD(qn(opts.get_field(f).column)) for f in field_constraints]))
full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + style.SQL_TABLE(qn(opts.db_table)) + ' (']
for i, line in enumerate(table_output): # Combine and add commas.
full_statement.append(' %s%s' % (line, i < len(table_output)-1 and ',' or ''))
full_statement.append(')')
if opts.db_tablespace:
full_statement.append(self.connection.ops.tablespace_sql(opts.db_tablespace))
full_statement.append(';')
final_output.append('\n'.join(full_statement))
if opts.has_auto_field:
# Add any extra SQL needed to support auto-incrementing primary keys.
auto_column = opts.auto_field.db_column or opts.auto_field.name
autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table, auto_column)
if autoinc_sql:
for stmt in autoinc_sql:
final_output.append(stmt)
return final_output, pending_references
def sql_for_inline_foreign_key_references(self, field, known_models, style):
"Return the SQL snippet defining the foreign key reference for a field"
qn = self.connection.ops.quote_name
if field.rel.to in known_models:
output = [style.SQL_KEYWORD('REFERENCES') + ' ' + \
style.SQL_TABLE(qn(field.rel.to._meta.db_table)) + ' (' + \
style.SQL_FIELD(qn(field.rel.to._meta.get_field(field.rel.field_name).column)) + ')' +
self.connection.ops.deferrable_sql()
]
pending = False
else:
# We haven't yet created the table to which this field
# is related, so save it for later.
output = []
pending = True
return output, pending
def sql_for_pending_references(self, model, style, pending_references):
"Returns any ALTER TABLE statements to add constraints after the fact."
from django.db.backends.util import truncate_name
qn = self.connection.ops.quote_name
final_output = []
opts = model._meta
if model in pending_references:
for rel_class, f in pending_references[model]:
rel_opts = rel_class._meta
r_table = rel_opts.db_table
r_col = f.column
table = opts.db_table
col = opts.get_field(f.rel.field_name).column
# For MySQL, r_name must be unique in the first 64 characters.
# So we are careful with character usage here.
r_name = '%s_refs_%s_%x' % (r_col, col, abs(hash((r_table, table))))
final_output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % \
(qn(r_table), truncate_name(r_name, self.connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
self.connection.ops.deferrable_sql()))
del pending_references[model]
return final_output
def sql_for_many_to_many(self, model, style):
"Return the CREATE TABLE statments for all the many-to-many tables defined on a model"
output = []
for f in model._meta.local_many_to_many:
output.extend(self.sql_for_many_to_many_field(model, f, style))
return output
def sql_for_many_to_many_field(self, model, f, style):
"Return the CREATE TABLE statements for a single m2m field"
from django.db import models
from django.db.backends.util import truncate_name
output = []
if f.creates_table:
opts = model._meta
qn = self.connection.ops.quote_name
tablespace = f.db_tablespace or opts.db_tablespace
if tablespace:
sql = self.connection.ops.tablespace_sql(tablespace, inline=True)
if sql:
tablespace_sql = ' ' + sql
else:
tablespace_sql = ''
else:
tablespace_sql = ''
table_output = [style.SQL_KEYWORD('CREATE TABLE') + ' ' + \
style.SQL_TABLE(qn(f.m2m_db_table())) + ' (']
table_output.append(' %s %s %s%s,' %
(style.SQL_FIELD(qn('id')),
style.SQL_COLTYPE(models.AutoField(primary_key=True).db_type()),
style.SQL_KEYWORD('NOT NULL PRIMARY KEY'),
tablespace_sql))
deferred = []
inline_output, deferred = self.sql_for_inline_many_to_many_references(model, f, style)
table_output.extend(inline_output)
table_output.append(' %s (%s, %s)%s' %
(style.SQL_KEYWORD('UNIQUE'),
style.SQL_FIELD(qn(f.m2m_column_name())),
style.SQL_FIELD(qn(f.m2m_reverse_name())),
tablespace_sql))
table_output.append(')')
if opts.db_tablespace:
# f.db_tablespace is only for indices, so ignore its value here.
table_output.append(self.connection.ops.tablespace_sql(opts.db_tablespace))
table_output.append(';')
output.append('\n'.join(table_output))
for r_table, r_col, table, col in deferred:
r_name = '%s_refs_%s_%x' % (r_col, col,
abs(hash((r_table, table))))
output.append(style.SQL_KEYWORD('ALTER TABLE') + ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
(qn(r_table),
truncate_name(r_name, self.connection.ops.max_name_length()),
qn(r_col), qn(table), qn(col),
self.connection.ops.deferrable_sql()))
# Add any extra SQL needed to support auto-incrementing PKs
autoinc_sql = self.connection.ops.autoinc_sql(f.m2m_db_table(), 'id')
if autoinc_sql:
for stmt in autoinc_sql:
output.append(stmt)
return output
def sql_for_inline_many_to_many_references(self, model, field, style):
"Create the references to other tables required by a many-to-many table"
from django.db import models
opts = model._meta
qn = self.connection.ops.quote_name
table_output = [
' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(field.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(opts.db_table)),
style.SQL_FIELD(qn(opts.pk.column)),
self.connection.ops.deferrable_sql()),
' %s %s %s %s (%s)%s,' %
(style.SQL_FIELD(qn(field.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL REFERENCES'),
style.SQL_TABLE(qn(field.rel.to._meta.db_table)),
style.SQL_FIELD(qn(field.rel.to._meta.pk.column)),
self.connection.ops.deferrable_sql())
]
deferred = []
return table_output, deferred
def sql_indexes_for_model(self, model, style):
"Returns the CREATE INDEX SQL statements for a single model"
output = []
for f in model._meta.local_fields:
output.extend(self.sql_indexes_for_field(model, f, style))
return output
def sql_indexes_for_field(self, model, f, style):
"Return the CREATE INDEX SQL statements for a single model field"
if f.db_index and not f.unique:
qn = self.connection.ops.quote_name
tablespace = f.db_tablespace or model._meta.db_tablespace
if tablespace:
sql = self.connection.ops.tablespace_sql(tablespace)
if sql:
tablespace_sql = ' ' + sql
else:
tablespace_sql = ''
else:
tablespace_sql = ''
output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' +
style.SQL_TABLE(qn('%s_%s' % (model._meta.db_table, f.column))) + ' ' +
style.SQL_KEYWORD('ON') + ' ' +
style.SQL_TABLE(qn(model._meta.db_table)) + ' ' +
"(%s)" % style.SQL_FIELD(qn(f.column)) +
"%s;" % tablespace_sql]
else:
output = []
return output
def sql_destroy_model(self, model, references_to_delete, style):
"Return the DROP TABLE and restraint dropping statements for a single model"
# Drop the table now
qn = self.connection.ops.quote_name
output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(model._meta.db_table)))]
if model in references_to_delete:
output.extend(self.sql_remove_table_constraints(model, references_to_delete))
if model._meta.has_auto_field:
ds = self.connection.ops.drop_sequence_sql(model._meta.db_table)
if ds:
output.append(ds)
return output
def sql_remove_table_constraints(self, model, references_to_delete):
output = []
for rel_class, f in references_to_delete[model]:
table = rel_class._meta.db_table
col = f.column
r_table = model._meta.db_table
r_col = model._meta.get_field(f.rel.field_name).column
r_name = '%s_refs_%s_%x' % (col, r_col, abs(hash((table, r_table))))
output.append('%s %s %s %s;' % \
(style.SQL_KEYWORD('ALTER TABLE'),
style.SQL_TABLE(qn(table)),
style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()),
style.SQL_FIELD(truncate_name(r_name, self.connection.ops.max_name_length()))))
del references_to_delete[model]
return output
def sql_destroy_many_to_many(self, model, f, style):
"Returns the DROP TABLE statements for a single m2m field"
qn = self.connection.ops.quote_name
output = []
if f.creates_table:
output.append("%s %s;" % (style.SQL_KEYWORD('DROP TABLE'),
style.SQL_TABLE(qn(f.m2m_db_table()))))
ds = self.connection.ops.drop_sequence_sql("%s_%s" % (model._meta.db_table, f.column))
if ds:
output.append(ds)
return output
def create_test_db(self, verbosity=1, autoclobber=False):
"""
Creates a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
if verbosity >= 1:
print "Creating test database..."
test_database_name = self._create_test_db(verbosity, autoclobber)
self.connection.close()
settings.DATABASE_NAME = test_database_name
call_command('syncdb', verbosity=verbosity, interactive=False)
if settings.CACHE_BACKEND.startswith('db://'):
cache_name = settings.CACHE_BACKEND[len('db://'):]
call_command('createcachetable', cache_name)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
cursor = self.connection.cursor()
return test_database_name
def _create_test_db(self, verbosity, autoclobber):
"Internal implementation - creates the test db tables."
suffix = self.sql_table_creation_suffix()
if settings.TEST_DATABASE_NAME:
test_database_name = settings.TEST_DATABASE_NAME
else:
test_database_name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
qn = self.connection.ops.quote_name
# Create the test database and connect to it. We need to autocommit
# if the database supports it because PostgreSQL doesn't allow
# CREATE/DROP DATABASE statements within transactions.
cursor = self.connection.cursor()
self.set_autocommit()
try:
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
cursor.execute("DROP DATABASE %s" % qn(test_database_name))
if verbosity >= 1:
print "Creating test database..."
cursor.execute("CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
return test_database_name
def destroy_test_db(self, old_database_name, verbosity=1):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
if verbosity >= 1:
print "Destroying test database..."
self.connection.close()
test_database_name = settings.DATABASE_NAME
settings.DATABASE_NAME = old_database_name
self._destroy_test_db(test_database_name, verbosity)
def _destroy_test_db(self, test_database_name, verbosity):
"Internal implementation - remove the test db tables."
# Remove the test database to clean up after
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
cursor = self.connection.cursor()
self.set_autocommit()
time.sleep(1) # To avoid "database is being accessed by other users" errors.
cursor.execute("DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name))
self.connection.close()
def set_autocommit(self):
"Make sure a connection is in autocommit mode."
if hasattr(self.connection.connection, "autocommit"):
if callable(self.connection.connection.autocommit):
self.connection.connection.autocommit(True)
else:
self.connection.connection.autocommit = True
elif hasattr(self.connection.connection, "set_isolation_level"):
self.connection.connection.set_isolation_level(0)
def sql_table_creation_suffix(self):
"SQL to append to the end of the test table creation statements"
return ''

View File

@ -8,7 +8,8 @@ ImproperlyConfigured.
"""
from django.core.exceptions import ImproperlyConfigured
from django.db.backends import BaseDatabaseFeatures, BaseDatabaseOperations
from django.db.backends import *
from django.db.backends.creation import BaseDatabaseCreation
def complain(*args, **kwargs):
raise ImproperlyConfigured, "You haven't set the DATABASE_ENGINE setting yet."
@ -25,16 +26,30 @@ class IntegrityError(DatabaseError):
class DatabaseOperations(BaseDatabaseOperations):
quote_name = complain
class DatabaseWrapper(object):
features = BaseDatabaseFeatures()
ops = DatabaseOperations()
class DatabaseClient(BaseDatabaseClient):
runshell = complain
class DatabaseIntrospection(BaseDatabaseIntrospection):
get_table_list = complain
get_table_description = complain
get_relations = complain
get_indexes = complain
class DatabaseWrapper(object):
operators = {}
cursor = complain
_commit = complain
_rollback = ignore
def __init__(self, **kwargs):
pass
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = BaseDatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = BaseDatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def close(self):
pass

View File

@ -1,3 +0,0 @@
from django.db.backends.dummy.base import complain
runshell = complain

View File

@ -1 +0,0 @@
DATA_TYPES = {}

View File

@ -1,8 +0,0 @@
from django.db.backends.dummy.base import complain
get_table_list = complain
get_table_description = complain
get_relations = complain
get_indexes = complain
DATA_TYPES_REVERSE = {}

View File

@ -4,7 +4,12 @@ MySQL database backend for Django.
Requires MySQLdb: http://sourceforge.net/projects/mysql-python
"""
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
from django.db.backends import *
from django.db.backends.mysql.client import DatabaseClient
from django.db.backends.mysql.creation import DatabaseCreation
from django.db.backends.mysql.introspection import DatabaseIntrospection
from django.db.backends.mysql.validation import DatabaseValidation
try:
import MySQLdb as Database
except ImportError, e:
@ -60,7 +65,6 @@ server_version_re = re.compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
# TRADITIONAL will automatically cause most warnings to be treated as errors.
class DatabaseFeatures(BaseDatabaseFeatures):
inline_fk_references = False
empty_fetchmany_value = ()
update_can_self_select = False
@ -142,8 +146,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return [first % value, second % value]
class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
operators = {
'exact': '= BINARY %s',
'iexact': 'LIKE %s',
@ -164,6 +167,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, **kwargs):
super(DatabaseWrapper, self).__init__(**kwargs)
self.server_version = None
self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = DatabaseValidation()
def _valid_connection(self):
if self.connection is not None:

View File

@ -1,27 +1,29 @@
from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os
def runshell():
args = ['']
db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME)
user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER)
passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD)
host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST)
port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT)
defaults_file = settings.DATABASE_OPTIONS.get('read_default_file')
# Seems to be no good way to set sql_mode with CLI
class DatabaseClient(BaseDatabaseClient):
def runshell(self):
args = ['']
db = settings.DATABASE_OPTIONS.get('db', settings.DATABASE_NAME)
user = settings.DATABASE_OPTIONS.get('user', settings.DATABASE_USER)
passwd = settings.DATABASE_OPTIONS.get('passwd', settings.DATABASE_PASSWORD)
host = settings.DATABASE_OPTIONS.get('host', settings.DATABASE_HOST)
port = settings.DATABASE_OPTIONS.get('port', settings.DATABASE_PORT)
defaults_file = settings.DATABASE_OPTIONS.get('read_default_file')
# Seems to be no good way to set sql_mode with CLI
if defaults_file:
args += ["--defaults-file=%s" % defaults_file]
if user:
args += ["--user=%s" % user]
if passwd:
args += ["--password=%s" % passwd]
if host:
args += ["--host=%s" % host]
if port:
args += ["--port=%s" % port]
if db:
args += [db]
if defaults_file:
args += ["--defaults-file=%s" % defaults_file]
if user:
args += ["--user=%s" % user]
if passwd:
args += ["--password=%s" % passwd]
if host:
args += ["--host=%s" % host]
if port:
args += ["--port=%s" % port]
if db:
args += [db]
os.execvp('mysql', args)
os.execvp('mysql', args)

View File

@ -1,28 +1,68 @@
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
DATA_TYPES = {
'AutoField': 'integer AUTO_INCREMENT',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'IPAddressField': 'char(15)',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PhoneNumberField': 'varchar(20)',
'PositiveIntegerField': 'integer UNSIGNED',
'PositiveSmallIntegerField': 'smallint UNSIGNED',
'SlugField': 'varchar(%(max_length)s)',
'SmallIntegerField': 'smallint',
'TextField': 'longtext',
'TimeField': 'time',
'USStateField': 'varchar(2)',
}
from django.conf import settings
from django.db.backends.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation):
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
'AutoField': 'integer AUTO_INCREMENT',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'IPAddressField': 'char(15)',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PhoneNumberField': 'varchar(20)',
'PositiveIntegerField': 'integer UNSIGNED',
'PositiveSmallIntegerField': 'smallint UNSIGNED',
'SlugField': 'varchar(%(max_length)s)',
'SmallIntegerField': 'smallint',
'TextField': 'longtext',
'TimeField': 'time',
'USStateField': 'varchar(2)',
}
def sql_table_creation_suffix(self):
suffix = []
if settings.TEST_DATABASE_CHARSET:
suffix.append('CHARACTER SET %s' % settings.TEST_DATABASE_CHARSET)
if settings.TEST_DATABASE_COLLATION:
suffix.append('COLLATE %s' % settings.TEST_DATABASE_COLLATION)
return ' '.join(suffix)
def sql_for_inline_foreign_key_references(self, field, known_models, style):
"All inline references are pending under MySQL"
return [], True
def sql_for_inline_many_to_many_references(self, model, field, style):
from django.db import models
opts = model._meta
qn = self.connection.ops.quote_name
table_output = [
' %s %s %s,' %
(style.SQL_FIELD(qn(field.m2m_column_name())),
style.SQL_COLTYPE(models.ForeignKey(model).db_type()),
style.SQL_KEYWORD('NOT NULL')),
' %s %s %s,' %
(style.SQL_FIELD(qn(field.m2m_reverse_name())),
style.SQL_COLTYPE(models.ForeignKey(field.rel.to).db_type()),
style.SQL_KEYWORD('NOT NULL'))
]
deferred = [
(field.m2m_db_table(), field.m2m_column_name(), opts.db_table,
opts.pk.column),
(field.m2m_db_table(), field.m2m_reverse_name(),
field.rel.to._meta.db_table, field.rel.to._meta.pk.column)
]
return table_output, deferred

View File

@ -1,96 +1,97 @@
from django.db.backends.mysql.base import DatabaseOperations
from django.db.backends import BaseDatabaseIntrospection
from MySQLdb import ProgrammingError, OperationalError
from MySQLdb.constants import FIELD_TYPE
import re
quote_name = DatabaseOperations().quote_name
foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")
def get_table_list(cursor):
"Returns a list of table names in the current database."
cursor.execute("SHOW TABLES")
return [row[0] for row in cursor.fetchall()]
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = {
FIELD_TYPE.BLOB: 'TextField',
FIELD_TYPE.CHAR: 'CharField',
FIELD_TYPE.DECIMAL: 'DecimalField',
FIELD_TYPE.DATE: 'DateField',
FIELD_TYPE.DATETIME: 'DateTimeField',
FIELD_TYPE.DOUBLE: 'FloatField',
FIELD_TYPE.FLOAT: 'FloatField',
FIELD_TYPE.INT24: 'IntegerField',
FIELD_TYPE.LONG: 'IntegerField',
FIELD_TYPE.LONGLONG: 'IntegerField',
FIELD_TYPE.SHORT: 'IntegerField',
FIELD_TYPE.STRING: 'CharField',
FIELD_TYPE.TIMESTAMP: 'DateTimeField',
FIELD_TYPE.TINY: 'IntegerField',
FIELD_TYPE.TINY_BLOB: 'TextField',
FIELD_TYPE.MEDIUM_BLOB: 'TextField',
FIELD_TYPE.LONG_BLOB: 'TextField',
FIELD_TYPE.VAR_STRING: 'CharField',
}
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name))
return cursor.description
def get_table_list(self, cursor):
"Returns a list of table names in the current database."
cursor.execute("SHOW TABLES")
return [row[0] for row in cursor.fetchall()]
def _name_to_index(cursor, table_name):
"""
Returns a dictionary of {field_name: field_index} for the given table.
Indexes are 0-based.
"""
return dict([(d[0], i) for i, d in enumerate(get_table_description(cursor, table_name))])
def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return cursor.description
def get_relations(cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
my_field_dict = _name_to_index(cursor, table_name)
constraints = []
relations = {}
try:
# This should work for MySQL 5.0.
cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL""", [table_name])
constraints.extend(cursor.fetchall())
except (ProgrammingError, OperationalError):
# Fall back to "SHOW CREATE TABLE", for previous MySQL versions.
# Go through all constraints and save the equal matches.
cursor.execute("SHOW CREATE TABLE %s" % quote_name(table_name))
def _name_to_index(self, cursor, table_name):
"""
Returns a dictionary of {field_name: field_index} for the given table.
Indexes are 0-based.
"""
return dict([(d[0], i) for i, d in enumerate(self.get_table_description(cursor, table_name))])
def get_relations(self, cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
my_field_dict = self._name_to_index(cursor, table_name)
constraints = []
relations = {}
try:
# This should work for MySQL 5.0.
cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL""", [table_name])
constraints.extend(cursor.fetchall())
except (ProgrammingError, OperationalError):
# Fall back to "SHOW CREATE TABLE", for previous MySQL versions.
# Go through all constraints and save the equal matches.
cursor.execute("SHOW CREATE TABLE %s" % self.connection.ops.quote_name(table_name))
for row in cursor.fetchall():
pos = 0
while True:
match = foreign_key_re.search(row[1], pos)
if match == None:
break
pos = match.end()
constraints.append(match.groups())
for my_fieldname, other_table, other_field in constraints:
other_field_index = self._name_to_index(cursor, other_table)[other_field]
my_field_index = my_field_dict[my_fieldname]
relations[my_field_index] = (other_field_index, other_table)
return relations
def get_indexes(self, cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
indexes = {}
for row in cursor.fetchall():
pos = 0
while True:
match = foreign_key_re.search(row[1], pos)
if match == None:
break
pos = match.end()
constraints.append(match.groups())
indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])}
return indexes
for my_fieldname, other_table, other_field in constraints:
other_field_index = _name_to_index(cursor, other_table)[other_field]
my_field_index = my_field_dict[my_fieldname]
relations[my_field_index] = (other_field_index, other_table)
return relations
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
cursor.execute("SHOW INDEX FROM %s" % quote_name(table_name))
indexes = {}
for row in cursor.fetchall():
indexes[row[4]] = {'primary_key': (row[2] == 'PRIMARY'), 'unique': not bool(row[1])}
return indexes
DATA_TYPES_REVERSE = {
FIELD_TYPE.BLOB: 'TextField',
FIELD_TYPE.CHAR: 'CharField',
FIELD_TYPE.DECIMAL: 'DecimalField',
FIELD_TYPE.DATE: 'DateField',
FIELD_TYPE.DATETIME: 'DateTimeField',
FIELD_TYPE.DOUBLE: 'FloatField',
FIELD_TYPE.FLOAT: 'FloatField',
FIELD_TYPE.INT24: 'IntegerField',
FIELD_TYPE.LONG: 'IntegerField',
FIELD_TYPE.LONGLONG: 'IntegerField',
FIELD_TYPE.SHORT: 'IntegerField',
FIELD_TYPE.STRING: 'CharField',
FIELD_TYPE.TIMESTAMP: 'DateTimeField',
FIELD_TYPE.TINY: 'IntegerField',
FIELD_TYPE.TINY_BLOB: 'TextField',
FIELD_TYPE.MEDIUM_BLOB: 'TextField',
FIELD_TYPE.LONG_BLOB: 'TextField',
FIELD_TYPE.VAR_STRING: 'CharField',
}

View File

@ -0,0 +1,13 @@
from django.db.backends import BaseDatabaseValidation
class DatabaseValidation(BaseDatabaseValidation):
def validate_field(self, errors, opts, f):
"Prior to MySQL 5.0.3, character fields could not exceed 255 characters"
from django.db import models
from django.db import connection
db_version = connection.get_server_version()
if db_version < (5, 0, 3) and isinstance(f, (models.CharField, models.CommaSeparatedIntegerField, models.SlugField)) and f.max_length > 255:
errors.add(opts,
'"%s": %s cannot have a "max_length" greater than 255 when you are using a version of MySQL prior to 5.0.3 (you are using %s).' %
(f.name, f.__class__.__name__, '.'.join([str(n) for n in db_version[:3]])))

View File

@ -8,8 +8,11 @@ import os
import datetime
import time
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
from django.db.backends import *
from django.db.backends.oracle import query
from django.db.backends.oracle.client import DatabaseClient
from django.db.backends.oracle.creation import DatabaseCreation
from django.db.backends.oracle.introspection import DatabaseIntrospection
from django.utils.encoding import smart_str, force_unicode
# Oracle takes client-side character set encoding from the environment.
@ -24,11 +27,8 @@ DatabaseError = Database.Error
IntegrityError = Database.IntegrityError
class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_ordinal = False
empty_fetchmany_value = ()
needs_datetime_string_cast = False
supports_tablespaces = True
uses_case_insensitive_names = True
uses_custom_query_class = True
interprets_empty_strings_as_nulls = True
@ -194,10 +194,8 @@ class DatabaseOperations(BaseDatabaseOperations):
return [first % value, second % value]
class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
operators = {
'exact': '= %s',
'iexact': '= UPPER(%s)',
@ -214,6 +212,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
}
oracle_version = None
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def _valid_connection(self):
return self.connection is not None

View File

@ -1,11 +1,13 @@
from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os
def runshell():
dsn = settings.DATABASE_USER
if settings.DATABASE_PASSWORD:
dsn += "/%s" % settings.DATABASE_PASSWORD
if settings.DATABASE_NAME:
dsn += "@%s" % settings.DATABASE_NAME
args = ["sqlplus", "-L", dsn]
os.execvp("sqlplus", args)
class DatabaseClient(BaseDatabaseClient):
def runshell(self):
dsn = settings.DATABASE_USER
if settings.DATABASE_PASSWORD:
dsn += "/%s" % settings.DATABASE_PASSWORD
if settings.DATABASE_NAME:
dsn += "@%s" % settings.DATABASE_NAME
args = ["sqlplus", "-L", dsn]
os.execvp("sqlplus", args)

View File

@ -1,291 +1,289 @@
import sys, time
from django.conf import settings
from django.core import management
# This dictionary maps Field objects to their associated Oracle column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
#
# Any format strings starting with "qn_" are quoted before being used in the
# output (the "qn_" prefix is stripped before the lookup is performed.
DATA_TYPES = {
'AutoField': 'NUMBER(11)',
'BooleanField': 'NUMBER(1) CHECK (%(qn_column)s IN (0,1))',
'CharField': 'NVARCHAR2(%(max_length)s)',
'CommaSeparatedIntegerField': 'VARCHAR2(%(max_length)s)',
'DateField': 'DATE',
'DateTimeField': 'TIMESTAMP',
'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)',
'FileField': 'NVARCHAR2(%(max_length)s)',
'FilePathField': 'NVARCHAR2(%(max_length)s)',
'FloatField': 'DOUBLE PRECISION',
'IntegerField': 'NUMBER(11)',
'IPAddressField': 'VARCHAR2(15)',
'NullBooleanField': 'NUMBER(1) CHECK ((%(qn_column)s IN (0,1)) OR (%(qn_column)s IS NULL))',
'OneToOneField': 'NUMBER(11)',
'PhoneNumberField': 'VARCHAR2(20)',
'PositiveIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)',
'PositiveSmallIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)',
'SlugField': 'NVARCHAR2(50)',
'SmallIntegerField': 'NUMBER(11)',
'TextField': 'NCLOB',
'TimeField': 'TIMESTAMP',
'URLField': 'VARCHAR2(%(max_length)s)',
'USStateField': 'CHAR(2)',
}
from django.db.backends.creation import BaseDatabaseCreation
TEST_DATABASE_PREFIX = 'test_'
PASSWORD = 'Im_a_lumberjack'
REMEMBER = {}
def create_test_db(settings, connection, verbosity=1, autoclobber=False):
TEST_DATABASE_NAME = _test_database_name(settings)
TEST_DATABASE_USER = _test_database_user(settings)
TEST_DATABASE_PASSWD = _test_database_passwd(settings)
TEST_DATABASE_TBLSPACE = _test_database_tblspace(settings)
TEST_DATABASE_TBLSPACE_TMP = _test_database_tblspace_tmp(settings)
class DatabaseCreation(BaseDatabaseCreation):
# This dictionary maps Field objects to their associated Oracle column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
#
# Any format strings starting with "qn_" are quoted before being used in the
# output (the "qn_" prefix is stripped before the lookup is performed.
parameters = {
'dbname': TEST_DATABASE_NAME,
'user': TEST_DATABASE_USER,
'password': TEST_DATABASE_PASSWD,
'tblspace': TEST_DATABASE_TBLSPACE,
'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP,
}
data_types = {
'AutoField': 'NUMBER(11)',
'BooleanField': 'NUMBER(1) CHECK (%(qn_column)s IN (0,1))',
'CharField': 'NVARCHAR2(%(max_length)s)',
'CommaSeparatedIntegerField': 'VARCHAR2(%(max_length)s)',
'DateField': 'DATE',
'DateTimeField': 'TIMESTAMP',
'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)',
'FileField': 'NVARCHAR2(%(max_length)s)',
'FilePathField': 'NVARCHAR2(%(max_length)s)',
'FloatField': 'DOUBLE PRECISION',
'IntegerField': 'NUMBER(11)',
'IPAddressField': 'VARCHAR2(15)',
'NullBooleanField': 'NUMBER(1) CHECK ((%(qn_column)s IN (0,1)) OR (%(qn_column)s IS NULL))',
'OneToOneField': 'NUMBER(11)',
'PhoneNumberField': 'VARCHAR2(20)',
'PositiveIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)',
'PositiveSmallIntegerField': 'NUMBER(11) CHECK (%(qn_column)s >= 0)',
'SlugField': 'NVARCHAR2(50)',
'SmallIntegerField': 'NUMBER(11)',
'TextField': 'NCLOB',
'TimeField': 'TIMESTAMP',
'URLField': 'VARCHAR2(%(max_length)s)',
'USStateField': 'CHAR(2)',
}
def _create_test_db(self, verbosity, autoclobber):
TEST_DATABASE_NAME = self._test_database_name(settings)
TEST_DATABASE_USER = self._test_database_user(settings)
TEST_DATABASE_PASSWD = self._test_database_passwd(settings)
TEST_DATABASE_TBLSPACE = self._test_database_tblspace(settings)
TEST_DATABASE_TBLSPACE_TMP = self._test_database_tblspace_tmp(settings)
REMEMBER['user'] = settings.DATABASE_USER
REMEMBER['passwd'] = settings.DATABASE_PASSWORD
parameters = {
'dbname': TEST_DATABASE_NAME,
'user': TEST_DATABASE_USER,
'password': TEST_DATABASE_PASSWD,
'tblspace': TEST_DATABASE_TBLSPACE,
'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP,
}
cursor = connection.cursor()
if _test_database_create(settings):
if verbosity >= 1:
print 'Creating test database...'
try:
_create_test_db(cursor, parameters, verbosity)
except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber:
confirm = raw_input("It appears the test database, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_NAME)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
_destroy_test_db(cursor, parameters, verbosity)
if verbosity >= 1:
print "Creating test database..."
_create_test_db(cursor, parameters, verbosity)
except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
self.remember['user'] = settings.DATABASE_USER
self.remember['passwd'] = settings.DATABASE_PASSWORD
if _test_user_create(settings):
if verbosity >= 1:
print "Creating test user..."
try:
_create_test_user(cursor, parameters, verbosity)
except Exception, e:
sys.stderr.write("Got an error creating the test user: %s\n" % e)
if not autoclobber:
confirm = raw_input("It appears the test user, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_USER)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test user..."
_destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
print "Creating test user..."
_create_test_user(cursor, parameters, verbosity)
except Exception, e:
sys.stderr.write("Got an error recreating the test user: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
cursor = self.connection.cursor()
if self._test_database_create(settings):
if verbosity >= 1:
print 'Creating test database...'
try:
self._execute_test_db_creation(cursor, parameters, verbosity)
except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber:
confirm = raw_input("It appears the test database, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_NAME)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
self._execute_test_db_destruction(cursor, parameters, verbosity)
if verbosity >= 1:
print "Creating test database..."
self._execute_test_db_creation(cursor, parameters, verbosity)
except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
connection.close()
settings.DATABASE_USER = TEST_DATABASE_USER
settings.DATABASE_PASSWORD = TEST_DATABASE_PASSWD
if self._test_user_create(settings):
if verbosity >= 1:
print "Creating test user..."
try:
self._create_test_user(cursor, parameters, verbosity)
except Exception, e:
sys.stderr.write("Got an error creating the test user: %s\n" % e)
if not autoclobber:
confirm = raw_input("It appears the test user, %s, already exists. Type 'yes' to delete it, or 'no' to cancel: " % TEST_DATABASE_USER)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test user..."
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
print "Creating test user..."
self._create_test_user(cursor, parameters, verbosity)
except Exception, e:
sys.stderr.write("Got an error recreating the test user: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
management.call_command('syncdb', verbosity=verbosity, interactive=False)
settings.DATABASE_USER = TEST_DATABASE_USER
settings.DATABASE_PASSWORD = TEST_DATABASE_PASSWD
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
cursor = connection.cursor()
return TEST_DATABASE_NAME
def _destroy_test_db(self, test_database_name, verbosity=1):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
TEST_DATABASE_NAME = self._test_database_name(settings)
TEST_DATABASE_USER = self._test_database_user(settings)
TEST_DATABASE_PASSWD = self._test_database_passwd(settings)
TEST_DATABASE_TBLSPACE = self._test_database_tblspace(settings)
TEST_DATABASE_TBLSPACE_TMP = self._test_database_tblspace_tmp(settings)
def destroy_test_db(settings, connection, old_database_name, verbosity=1):
connection.close()
settings.DATABASE_USER = self.remember['user']
settings.DATABASE_PASSWORD = self.remember['passwd']
TEST_DATABASE_NAME = _test_database_name(settings)
TEST_DATABASE_USER = _test_database_user(settings)
TEST_DATABASE_PASSWD = _test_database_passwd(settings)
TEST_DATABASE_TBLSPACE = _test_database_tblspace(settings)
TEST_DATABASE_TBLSPACE_TMP = _test_database_tblspace_tmp(settings)
parameters = {
'dbname': TEST_DATABASE_NAME,
'user': TEST_DATABASE_USER,
'password': TEST_DATABASE_PASSWD,
'tblspace': TEST_DATABASE_TBLSPACE,
'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP,
}
settings.DATABASE_NAME = old_database_name
settings.DATABASE_USER = REMEMBER['user']
settings.DATABASE_PASSWORD = REMEMBER['passwd']
self.remember['user'] = settings.DATABASE_USER
self.remember['passwd'] = settings.DATABASE_PASSWORD
parameters = {
'dbname': TEST_DATABASE_NAME,
'user': TEST_DATABASE_USER,
'password': TEST_DATABASE_PASSWD,
'tblspace': TEST_DATABASE_TBLSPACE,
'tblspace_temp': TEST_DATABASE_TBLSPACE_TMP,
}
cursor = self.connection.cursor()
time.sleep(1) # To avoid "database is being accessed by other users" errors.
if self._test_user_create(settings):
if verbosity >= 1:
print 'Destroying test user...'
self._destroy_test_user(cursor, parameters, verbosity)
if self._test_database_create(settings):
if verbosity >= 1:
print 'Destroying test database tables...'
self._execute_test_db_destruction(cursor, parameters, verbosity)
self.connection.close()
REMEMBER['user'] = settings.DATABASE_USER
REMEMBER['passwd'] = settings.DATABASE_PASSWORD
cursor = connection.cursor()
time.sleep(1) # To avoid "database is being accessed by other users" errors.
if _test_user_create(settings):
if verbosity >= 1:
print 'Destroying test user...'
_destroy_test_user(cursor, parameters, verbosity)
if _test_database_create(settings):
if verbosity >= 1:
print 'Destroying test database...'
_destroy_test_db(cursor, parameters, verbosity)
connection.close()
def _create_test_db(cursor, parameters, verbosity):
if verbosity >= 2:
print "_create_test_db(): dbname = %s" % parameters['dbname']
statements = [
"""CREATE TABLESPACE %(tblspace)s
DATAFILE '%(tblspace)s.dbf' SIZE 20M
REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M
""",
"""CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
TEMPFILE '%(tblspace_temp)s.dbf' SIZE 20M
REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M
""",
]
_execute_statements(cursor, statements, parameters, verbosity)
def _create_test_user(cursor, parameters, verbosity):
if verbosity >= 2:
print "_create_test_user(): username = %s" % parameters['user']
statements = [
"""CREATE USER %(user)s
IDENTIFIED BY %(password)s
DEFAULT TABLESPACE %(tblspace)s
TEMPORARY TABLESPACE %(tblspace_temp)s
""",
"""GRANT CONNECT, RESOURCE TO %(user)s""",
]
_execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_db(cursor, parameters, verbosity):
if verbosity >= 2:
print "_destroy_test_db(): dbname=%s" % parameters['dbname']
statements = [
'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
]
_execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_user(cursor, parameters, verbosity):
if verbosity >= 2:
print "_destroy_test_user(): user=%s" % parameters['user']
print "Be patient. This can take some time..."
statements = [
'DROP USER %(user)s CASCADE',
]
_execute_statements(cursor, statements, parameters, verbosity)
def _execute_statements(cursor, statements, parameters, verbosity):
for template in statements:
stmt = template % parameters
def _execute_test_db_creation(cursor, parameters, verbosity):
if verbosity >= 2:
print stmt
print "_create_test_db(): dbname = %s" % parameters['dbname']
statements = [
"""CREATE TABLESPACE %(tblspace)s
DATAFILE '%(tblspace)s.dbf' SIZE 20M
REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M
""",
"""CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
TEMPFILE '%(tblspace_temp)s.dbf' SIZE 20M
REUSE AUTOEXTEND ON NEXT 10M MAXSIZE 100M
""",
]
_execute_statements(cursor, statements, parameters, verbosity)
def _create_test_user(cursor, parameters, verbosity):
if verbosity >= 2:
print "_create_test_user(): username = %s" % parameters['user']
statements = [
"""CREATE USER %(user)s
IDENTIFIED BY %(password)s
DEFAULT TABLESPACE %(tblspace)s
TEMPORARY TABLESPACE %(tblspace_temp)s
""",
"""GRANT CONNECT, RESOURCE TO %(user)s""",
]
_execute_statements(cursor, statements, parameters, verbosity)
def _execute_test_db_destruction(cursor, parameters, verbosity):
if verbosity >= 2:
print "_execute_test_db_destruction(): dbname=%s" % parameters['dbname']
statements = [
'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
]
_execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_user(cursor, parameters, verbosity):
if verbosity >= 2:
print "_destroy_test_user(): user=%s" % parameters['user']
print "Be patient. This can take some time..."
statements = [
'DROP USER %(user)s CASCADE',
]
_execute_statements(cursor, statements, parameters, verbosity)
def _execute_statements(cursor, statements, parameters, verbosity):
for template in statements:
stmt = template % parameters
if verbosity >= 2:
print stmt
try:
cursor.execute(stmt)
except Exception, err:
sys.stderr.write("Failed (%s)\n" % (err))
raise
def _test_database_name(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try:
cursor.execute(stmt)
except Exception, err:
sys.stderr.write("Failed (%s)\n" % (err))
if settings.TEST_DATABASE_NAME:
name = settings.TEST_DATABASE_NAME
except AttributeError:
pass
except:
raise
return name
def _test_database_name(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try:
if settings.TEST_DATABASE_NAME:
name = settings.TEST_DATABASE_NAME
except AttributeError:
pass
except:
raise
return name
def _test_database_create(settings):
name = True
try:
if settings.TEST_DATABASE_CREATE:
name = True
else:
name = False
except AttributeError:
pass
except:
raise
return name
def _test_database_create(settings):
name = True
try:
if settings.TEST_DATABASE_CREATE:
name = True
else:
name = False
except AttributeError:
pass
except:
raise
return name
def _test_user_create(settings):
name = True
try:
if settings.TEST_USER_CREATE:
name = True
else:
name = False
except AttributeError:
pass
except:
raise
return name
def _test_user_create(settings):
name = True
try:
if settings.TEST_USER_CREATE:
name = True
else:
name = False
except AttributeError:
pass
except:
raise
return name
def _test_database_user(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try:
if settings.TEST_DATABASE_USER:
name = settings.TEST_DATABASE_USER
except AttributeError:
pass
except:
raise
return name
def _test_database_user(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try:
if settings.TEST_DATABASE_USER:
name = settings.TEST_DATABASE_USER
except AttributeError:
pass
except:
raise
return name
def _test_database_passwd(settings):
name = PASSWORD
try:
if settings.TEST_DATABASE_PASSWD:
name = settings.TEST_DATABASE_PASSWD
except AttributeError:
pass
except:
raise
return name
def _test_database_passwd(settings):
name = PASSWORD
try:
if settings.TEST_DATABASE_PASSWD:
name = settings.TEST_DATABASE_PASSWD
except AttributeError:
pass
except:
raise
return name
def _test_database_tblspace(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try:
if settings.TEST_DATABASE_TBLSPACE:
name = settings.TEST_DATABASE_TBLSPACE
except AttributeError:
pass
except:
raise
return name
def _test_database_tblspace(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
try:
if settings.TEST_DATABASE_TBLSPACE:
name = settings.TEST_DATABASE_TBLSPACE
except AttributeError:
pass
except:
raise
return name
def _test_database_tblspace_tmp(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + '_temp'
try:
if settings.TEST_DATABASE_TBLSPACE_TMP:
name = settings.TEST_DATABASE_TBLSPACE_TMP
except AttributeError:
pass
except:
raise
return name
def _test_database_tblspace_tmp(settings):
name = TEST_DATABASE_PREFIX + settings.DATABASE_NAME + '_temp'
try:
if settings.TEST_DATABASE_TBLSPACE_TMP:
name = settings.TEST_DATABASE_TBLSPACE_TMP
except AttributeError:
pass
except:
raise
return name

View File

@ -1,98 +1,103 @@
from django.db.backends.oracle.base import DatabaseOperations
import re
from django.db.backends import BaseDatabaseIntrospection
import cx_Oracle
import re
quote_name = DatabaseOperations().quote_name
foreign_key_re = re.compile(r"\sCONSTRAINT `[^`]*` FOREIGN KEY \(`([^`]*)`\) REFERENCES `([^`]*)` \(`([^`]*)`\)")
def get_table_list(cursor):
"Returns a list of table names in the current database."
cursor.execute("SELECT TABLE_NAME FROM USER_TABLES")
return [row[0].upper() for row in cursor.fetchall()]
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type objects to Django Field types.
data_types_reverse = {
cx_Oracle.CLOB: 'TextField',
cx_Oracle.DATETIME: 'DateTimeField',
cx_Oracle.FIXED_CHAR: 'CharField',
cx_Oracle.NCLOB: 'TextField',
cx_Oracle.NUMBER: 'DecimalField',
cx_Oracle.STRING: 'CharField',
cx_Oracle.TIMESTAMP: 'DateTimeField',
}
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s WHERE ROWNUM < 2" % quote_name(table_name))
return cursor.description
def get_table_list(self, cursor):
"Returns a list of table names in the current database."
cursor.execute("SELECT TABLE_NAME FROM USER_TABLES")
return [row[0].upper() for row in cursor.fetchall()]
def _name_to_index(cursor, table_name):
"""
Returns a dictionary of {field_name: field_index} for the given table.
Indexes are 0-based.
"""
return dict([(d[0], i) for i, d in enumerate(get_table_description(cursor, table_name))])
def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s WHERE ROWNUM < 2" % self.connection.ops.quote_name(table_name))
return cursor.description
def get_relations(cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
cursor.execute("""
SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb,
user_tab_cols ta, user_tab_cols tb
WHERE user_constraints.table_name = %s AND
ta.table_name = %s AND
ta.column_name = ca.column_name AND
ca.table_name = %s AND
user_constraints.constraint_name = ca.constraint_name AND
user_constraints.r_constraint_name = cb.constraint_name AND
cb.table_name = tb.table_name AND
cb.column_name = tb.column_name AND
ca.position = cb.position""", [table_name, table_name, table_name])
def table_name_converter(self, name):
"Table name comparison is case insensitive under Oracle"
return name.upper()
def _name_to_index(self, cursor, table_name):
"""
Returns a dictionary of {field_name: field_index} for the given table.
Indexes are 0-based.
"""
return dict([(d[0], i) for i, d in enumerate(self.get_table_description(cursor, table_name))])
relations = {}
for row in cursor.fetchall():
relations[row[0]] = (row[2], row[1])
return relations
def get_relations(self, cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
cursor.execute("""
SELECT ta.column_id - 1, tb.table_name, tb.column_id - 1
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb,
user_tab_cols ta, user_tab_cols tb
WHERE user_constraints.table_name = %s AND
ta.table_name = %s AND
ta.column_name = ca.column_name AND
ca.table_name = %s AND
user_constraints.constraint_name = ca.constraint_name AND
user_constraints.r_constraint_name = cb.constraint_name AND
cb.table_name = tb.table_name AND
cb.column_name = tb.column_name AND
ca.position = cb.position""", [table_name, table_name, table_name])
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
# This query retrieves each index on the given table, including the
# first associated field name
# "We were in the nick of time; you were in great peril!"
sql = """
WITH primarycols AS (
SELECT user_cons_columns.table_name, user_cons_columns.column_name, 1 AS PRIMARYCOL
FROM user_cons_columns, user_constraints
WHERE user_cons_columns.constraint_name = user_constraints.constraint_name AND
user_constraints.constraint_type = 'P' AND
user_cons_columns.table_name = %s),
uniquecols AS (
SELECT user_ind_columns.table_name, user_ind_columns.column_name, 1 AS UNIQUECOL
FROM user_indexes, user_ind_columns
WHERE uniqueness = 'UNIQUE' AND
user_indexes.index_name = user_ind_columns.index_name AND
user_ind_columns.table_name = %s)
SELECT allcols.column_name, primarycols.primarycol, uniquecols.UNIQUECOL
FROM (SELECT column_name FROM primarycols UNION SELECT column_name FROM
uniquecols) allcols,
primarycols, uniquecols
WHERE allcols.column_name = primarycols.column_name (+) AND
allcols.column_name = uniquecols.column_name (+)
"""
cursor.execute(sql, [table_name, table_name])
indexes = {}
for row in cursor.fetchall():
# row[1] (idx.indkey) is stored in the DB as an array. It comes out as
# a string of space-separated integers. This designates the field
# indexes (1-based) of the fields that have indexes on the table.
# Here, we skip any indexes across multiple fields.
indexes[row[0]] = {'primary_key': row[1], 'unique': row[2]}
return indexes
relations = {}
for row in cursor.fetchall():
relations[row[0]] = (row[2], row[1])
return relations
def get_indexes(self, cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
# This query retrieves each index on the given table, including the
# first associated field name
# "We were in the nick of time; you were in great peril!"
sql = """
WITH primarycols AS (
SELECT user_cons_columns.table_name, user_cons_columns.column_name, 1 AS PRIMARYCOL
FROM user_cons_columns, user_constraints
WHERE user_cons_columns.constraint_name = user_constraints.constraint_name AND
user_constraints.constraint_type = 'P' AND
user_cons_columns.table_name = %s),
uniquecols AS (
SELECT user_ind_columns.table_name, user_ind_columns.column_name, 1 AS UNIQUECOL
FROM user_indexes, user_ind_columns
WHERE uniqueness = 'UNIQUE' AND
user_indexes.index_name = user_ind_columns.index_name AND
user_ind_columns.table_name = %s)
SELECT allcols.column_name, primarycols.primarycol, uniquecols.UNIQUECOL
FROM (SELECT column_name FROM primarycols UNION SELECT column_name FROM
uniquecols) allcols,
primarycols, uniquecols
WHERE allcols.column_name = primarycols.column_name (+) AND
allcols.column_name = uniquecols.column_name (+)
"""
cursor.execute(sql, [table_name, table_name])
indexes = {}
for row in cursor.fetchall():
# row[1] (idx.indkey) is stored in the DB as an array. It comes out as
# a string of space-separated integers. This designates the field
# indexes (1-based) of the fields that have indexes on the table.
# Here, we skip any indexes across multiple fields.
indexes[row[0]] = {'primary_key': row[1], 'unique': row[2]}
return indexes
# Maps type objects to Django Field types.
DATA_TYPES_REVERSE = {
cx_Oracle.CLOB: 'TextField',
cx_Oracle.DATETIME: 'DateTimeField',
cx_Oracle.FIXED_CHAR: 'CharField',
cx_Oracle.NCLOB: 'TextField',
cx_Oracle.NUMBER: 'DecimalField',
cx_Oracle.STRING: 'CharField',
cx_Oracle.TIMESTAMP: 'DateTimeField',
}

View File

@ -4,9 +4,13 @@ PostgreSQL database backend for Django.
Requires psycopg 1: http://initd.org/projects/psycopg1
"""
from django.utils.encoding import smart_str, smart_unicode
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, util
from django.db.backends import *
from django.db.backends.postgresql.client import DatabaseClient
from django.db.backends.postgresql.creation import DatabaseCreation
from django.db.backends.postgresql.introspection import DatabaseIntrospection
from django.db.backends.postgresql.operations import DatabaseOperations
from django.utils.encoding import smart_str, smart_unicode
try:
import psycopg as Database
except ImportError, e:
@ -59,12 +63,7 @@ class UnicodeCursorWrapper(object):
def __iter__(self):
return iter(self.cursor)
class DatabaseFeatures(BaseDatabaseFeatures):
pass # This backend uses all the defaults.
class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
operators = {
'exact': '= %s',
'iexact': 'ILIKE %s',
@ -82,6 +81,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'iendswith': 'ILIKE %s',
}
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = BaseDatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def _cursor(self, settings):
set_tz = False
if self.connection is None:

View File

@ -1,15 +1,17 @@
from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os
def runshell():
args = ['psql']
if settings.DATABASE_USER:
args += ["-U", settings.DATABASE_USER]
if settings.DATABASE_PASSWORD:
args += ["-W"]
if settings.DATABASE_HOST:
args.extend(["-h", settings.DATABASE_HOST])
if settings.DATABASE_PORT:
args.extend(["-p", str(settings.DATABASE_PORT)])
args += [settings.DATABASE_NAME]
os.execvp('psql', args)
class DatabaseClient(BaseDatabaseClient):
def runshell(self):
args = ['psql']
if settings.DATABASE_USER:
args += ["-U", settings.DATABASE_USER]
if settings.DATABASE_PASSWORD:
args += ["-W"]
if settings.DATABASE_HOST:
args.extend(["-h", settings.DATABASE_HOST])
if settings.DATABASE_PORT:
args.extend(["-p", str(settings.DATABASE_PORT)])
args += [settings.DATABASE_NAME]
os.execvp('psql', args)

View File

@ -1,28 +1,38 @@
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
DATA_TYPES = {
'AutoField': 'serial',
'BooleanField': 'boolean',
'CharField': 'varchar(%(max_length)s)',
'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'timestamp with time zone',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'IPAddressField': 'inet',
'NullBooleanField': 'boolean',
'OneToOneField': 'integer',
'PhoneNumberField': 'varchar(20)',
'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)',
'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)',
'SlugField': 'varchar(%(max_length)s)',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'USStateField': 'varchar(2)',
}
from django.conf import settings
from django.db.backends.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation):
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
'AutoField': 'serial',
'BooleanField': 'boolean',
'CharField': 'varchar(%(max_length)s)',
'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'timestamp with time zone',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'IPAddressField': 'inet',
'NullBooleanField': 'boolean',
'OneToOneField': 'integer',
'PhoneNumberField': 'varchar(20)',
'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)',
'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)',
'SlugField': 'varchar(%(max_length)s)',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'USStateField': 'varchar(2)',
}
def sql_table_creation_suffix(self):
assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time."
if settings.TEST_DATABASE_CHARSET:
return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET
return ''

View File

@ -1,86 +1,86 @@
from django.db.backends.postgresql.base import DatabaseOperations
from django.db.backends import BaseDatabaseIntrospection
quote_name = DatabaseOperations().quote_name
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
16: 'BooleanField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
1700: 'DecimalField',
}
def get_table_list(self, cursor):
"Returns a list of table names in the current database."
cursor.execute("""
SELECT c.relname
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('r', 'v', '')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)""")
return [row[0] for row in cursor.fetchall()]
def get_table_list(cursor):
"Returns a list of table names in the current database."
cursor.execute("""
SELECT c.relname
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('r', 'v', '')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)""")
return [row[0] for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return cursor.description
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name))
return cursor.description
def get_relations(self, cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
cursor.execute("""
SELECT con.conkey, con.confkey, c2.relname
FROM pg_constraint con, pg_class c1, pg_class c2
WHERE c1.oid = con.conrelid
AND c2.oid = con.confrelid
AND c1.relname = %s
AND con.contype = 'f'""", [table_name])
relations = {}
for row in cursor.fetchall():
try:
# row[0] and row[1] are like "{2}", so strip the curly braces.
relations[int(row[0][1:-1]) - 1] = (int(row[1][1:-1]) - 1, row[2])
except ValueError:
continue
return relations
def get_relations(cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
cursor.execute("""
SELECT con.conkey, con.confkey, c2.relname
FROM pg_constraint con, pg_class c1, pg_class c2
WHERE c1.oid = con.conrelid
AND c2.oid = con.confrelid
AND c1.relname = %s
AND con.contype = 'f'""", [table_name])
relations = {}
for row in cursor.fetchall():
try:
# row[0] and row[1] are like "{2}", so strip the curly braces.
relations[int(row[0][1:-1]) - 1] = (int(row[1][1:-1]) - 1, row[2])
except ValueError:
continue
return relations
def get_indexes(self, cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
# This query retrieves each index on the given table, including the
# first associated field name
cursor.execute("""
SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary
FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
pg_catalog.pg_index idx, pg_catalog.pg_attribute attr
WHERE c.oid = idx.indrelid
AND idx.indexrelid = c2.oid
AND attr.attrelid = c.oid
AND attr.attnum = idx.indkey[0]
AND c.relname = %s""", [table_name])
indexes = {}
for row in cursor.fetchall():
# row[1] (idx.indkey) is stored in the DB as an array. It comes out as
# a string of space-separated integers. This designates the field
# indexes (1-based) of the fields that have indexes on the table.
# Here, we skip any indexes across multiple fields.
if ' ' in row[1]:
continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
# This query retrieves each index on the given table, including the
# first associated field name
cursor.execute("""
SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary
FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
pg_catalog.pg_index idx, pg_catalog.pg_attribute attr
WHERE c.oid = idx.indrelid
AND idx.indexrelid = c2.oid
AND attr.attrelid = c.oid
AND attr.attnum = idx.indkey[0]
AND c.relname = %s""", [table_name])
indexes = {}
for row in cursor.fetchall():
# row[1] (idx.indkey) is stored in the DB as an array. It comes out as
# a string of space-separated integers. This designates the field
# indexes (1-based) of the fields that have indexes on the table.
# Here, we skip any indexes across multiple fields.
if ' ' in row[1]:
continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes
# Maps type codes to Django Field types.
DATA_TYPES_REVERSE = {
16: 'BooleanField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
1700: 'DecimalField',
}

View File

@ -4,8 +4,12 @@ PostgreSQL database backend for Django.
Requires psycopg 2: http://initd.org/projects/psycopg2
"""
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures
from django.db.backends import *
from django.db.backends.postgresql.operations import DatabaseOperations as PostgresqlDatabaseOperations
from django.db.backends.postgresql.client import DatabaseClient
from django.db.backends.postgresql.creation import DatabaseCreation
from django.db.backends.postgresql_psycopg2.introspection import DatabaseIntrospection
from django.utils.safestring import SafeUnicode
try:
import psycopg2 as Database
@ -31,8 +35,6 @@ class DatabaseOperations(PostgresqlDatabaseOperations):
return cursor.query
class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
operators = {
'exact': '= %s',
'iexact': 'ILIKE %s',
@ -50,6 +52,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'iendswith': 'ILIKE %s',
}
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def _cursor(self, settings):
set_tz = False
if self.connection is None:

View File

@ -1 +0,0 @@
from django.db.backends.postgresql.client import *

View File

@ -1 +0,0 @@
from django.db.backends.postgresql.creation import *

View File

@ -1,83 +1,21 @@
from django.db.backends.postgresql_psycopg2.base import DatabaseOperations
from django.db.backends.postgresql.introspection import DatabaseIntrospection as PostgresDatabaseIntrospection
quote_name = DatabaseOperations().quote_name
class DatabaseIntrospection(PostgresDatabaseIntrospection):
def get_table_list(cursor):
"Returns a list of table names in the current database."
cursor.execute("""
SELECT c.relname
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('r', 'v', '')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)""")
return [row[0] for row in cursor.fetchall()]
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
cursor.execute("SELECT * FROM %s LIMIT 1" % quote_name(table_name))
return cursor.description
def get_relations(cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
cursor.execute("""
SELECT con.conkey, con.confkey, c2.relname
FROM pg_constraint con, pg_class c1, pg_class c2
WHERE c1.oid = con.conrelid
AND c2.oid = con.confrelid
AND c1.relname = %s
AND con.contype = 'f'""", [table_name])
relations = {}
for row in cursor.fetchall():
# row[0] and row[1] are single-item lists, so grab the single item.
relations[row[0][0] - 1] = (row[1][0] - 1, row[2])
return relations
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
# This query retrieves each index on the given table, including the
# first associated field name
cursor.execute("""
SELECT attr.attname, idx.indkey, idx.indisunique, idx.indisprimary
FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
pg_catalog.pg_index idx, pg_catalog.pg_attribute attr
WHERE c.oid = idx.indrelid
AND idx.indexrelid = c2.oid
AND attr.attrelid = c.oid
AND attr.attnum = idx.indkey[0]
AND c.relname = %s""", [table_name])
indexes = {}
for row in cursor.fetchall():
# row[1] (idx.indkey) is stored in the DB as an array. It comes out as
# a string of space-separated integers. This designates the field
# indexes (1-based) of the fields that have indexes on the table.
# Here, we skip any indexes across multiple fields.
if ' ' in row[1]:
continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes
# Maps type codes to Django Field types.
DATA_TYPES_REVERSE = {
16: 'BooleanField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
701: 'FloatField',
869: 'IPAddressField',
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1266: 'TimeField',
1700: 'DecimalField',
}
def get_relations(self, cursor, table_name):
"""
Returns a dictionary of {field_index: (field_index_other_table, other_table)}
representing all relationships to the given table. Indexes are 0-based.
"""
cursor.execute("""
SELECT con.conkey, con.confkey, c2.relname
FROM pg_constraint con, pg_class c1, pg_class c2
WHERE c1.oid = con.conrelid
AND c2.oid = con.confrelid
AND c1.relname = %s
AND con.contype = 'f'""", [table_name])
relations = {}
for row in cursor.fetchall():
# row[0] and row[1] are single-item lists, so grab the single item.
relations[row[0][0] - 1] = (row[1][0] - 1, row[2])
return relations

View File

@ -6,7 +6,11 @@ Python 2.3 and 2.4 require pysqlite2 (http://pysqlite.org/).
Python 2.5 and later use the sqlite3 module in the standard library.
"""
from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, util
from django.db.backends import *
from django.db.backends.sqlite3.client import DatabaseClient
from django.db.backends.sqlite3.creation import DatabaseCreation
from django.db.backends.sqlite3.introspection import DatabaseIntrospection
try:
try:
from sqlite3 import dbapi2 as Database
@ -46,7 +50,6 @@ if Database.version_info >= (2,4,1):
Database.register_adapter(str, lambda s:s.decode('utf-8'))
class DatabaseFeatures(BaseDatabaseFeatures):
supports_constraints = False
# SQLite cannot handle us only partially reading from a cursor's result set
# and then writing the same rows to the database in another cursor. This
# setting ensures we always read result sets fully into memory all in one
@ -96,11 +99,8 @@ class DatabaseOperations(BaseDatabaseOperations):
second = '%s-12-31 23:59:59.999999'
return [first % value, second % value]
class DatabaseWrapper(BaseDatabaseWrapper):
features = DatabaseFeatures()
ops = DatabaseOperations()
# SQLite requires LIKE statements to include an ESCAPE clause if the value
# being escaped has a percent or underscore in it.
# See http://www.sqlite.org/lang_expr.html for an explanation.
@ -121,6 +121,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
'iendswith': "LIKE %s ESCAPE '\\'",
}
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation()
def _cursor(self, settings):
if self.connection is None:
if not settings.DATABASE_NAME:

View File

@ -1,6 +1,8 @@
from django.db.backends import BaseDatabaseClient
from django.conf import settings
import os
def runshell():
args = ['', settings.DATABASE_NAME]
os.execvp('sqlite3', args)
class DatabaseClient(BaseDatabaseClient):
def runshell(self):
args = ['', settings.DATABASE_NAME]
os.execvp('sqlite3', args)

View File

@ -1,27 +1,73 @@
# SQLite doesn't actually support most of these types, but it "does the right
# thing" given more verbose field definitions, so leave them as is so that
# schema inspection is more useful.
DATA_TYPES = {
'AutoField': 'integer',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
'DecimalField': 'decimal',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'real',
'IntegerField': 'integer',
'IPAddressField': 'char(15)',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PhoneNumberField': 'varchar(20)',
'PositiveIntegerField': 'integer unsigned',
'PositiveSmallIntegerField': 'smallint unsigned',
'SlugField': 'varchar(%(max_length)s)',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'USStateField': 'varchar(2)',
}
import os
import sys
from django.conf import settings
from django.db.backends.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation):
# SQLite doesn't actually support most of these types, but it "does the right
# thing" given more verbose field definitions, so leave them as is so that
# schema inspection is more useful.
data_types = {
'AutoField': 'integer',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'CommaSeparatedIntegerField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
'DecimalField': 'decimal',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'real',
'IntegerField': 'integer',
'IPAddressField': 'char(15)',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PhoneNumberField': 'varchar(20)',
'PositiveIntegerField': 'integer unsigned',
'PositiveSmallIntegerField': 'smallint unsigned',
'SlugField': 'varchar(%(max_length)s)',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'USStateField': 'varchar(2)',
}
def sql_for_pending_references(self, model, style, pending_references):
"SQLite3 doesn't support constraints"
return []
def sql_remove_table_constraints(self, model, references_to_delete):
"SQLite3 doesn't support constraints"
return []
def _create_test_db(self, verbosity, autoclobber):
if settings.TEST_DATABASE_NAME and settings.TEST_DATABASE_NAME != ":memory:":
test_database_name = settings.TEST_DATABASE_NAME
# Erase the old test database
if verbosity >= 1:
print "Destroying old test database..."
if os.access(test_database_name, os.F_OK):
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
os.remove(test_database_name)
except Exception, e:
sys.stderr.write("Got an error deleting the old test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
if verbosity >= 1:
print "Creating test database..."
else:
test_database_name = ":memory:"
return test_database_name
def _destroy_test_db(self, test_database_name, verbosity):
if test_database_name and test_database_name != ":memory:":
# Remove the SQLite database file
os.remove(test_database_name)

View File

@ -1,84 +1,30 @@
from django.db.backends.sqlite3.base import DatabaseOperations
quote_name = DatabaseOperations().quote_name
def get_table_list(cursor):
"Returns a list of table names in the current database."
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND NOT name='sqlite_sequence'
ORDER BY name""")
return [row[0] for row in cursor.fetchall()]
def get_table_description(cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
return [(info['name'], info['type'], None, None, None, None,
info['null_ok']) for info in _table_info(cursor, table_name)]
def get_relations(cursor, table_name):
raise NotImplementedError
def get_indexes(cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
indexes = {}
for info in _table_info(cursor, table_name):
indexes[info['name']] = {'primary_key': info['pk'] != 0,
'unique': False}
cursor.execute('PRAGMA index_list(%s)' % quote_name(table_name))
# seq, name, unique
for index, unique in [(field[1], field[2]) for field in cursor.fetchall()]:
if not unique:
continue
cursor.execute('PRAGMA index_info(%s)' % quote_name(index))
info = cursor.fetchall()
# Skip indexes across multiple fields
if len(info) != 1:
continue
name = info[0][2] # seqno, cid, name
indexes[name]['unique'] = True
return indexes
def _table_info(cursor, name):
cursor.execute('PRAGMA table_info(%s)' % quote_name(name))
# cid, name, type, notnull, dflt_value, pk
return [{'name': field[1],
'type': field[2],
'null_ok': not field[3],
'pk': field[5] # undocumented
} for field in cursor.fetchall()]
# Maps SQL types to Django Field types. Some of the SQL types have multiple
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
BASE_DATA_TYPES_REVERSE = {
'bool': 'BooleanField',
'boolean': 'BooleanField',
'smallint': 'SmallIntegerField',
'smallinteger': 'SmallIntegerField',
'int': 'IntegerField',
'integer': 'IntegerField',
'text': 'TextField',
'char': 'CharField',
'date': 'DateField',
'datetime': 'DateTimeField',
'time': 'TimeField',
}
from django.db.backends import BaseDatabaseIntrospection
# This light wrapper "fakes" a dictionary interface, because some SQLite data
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
# as a simple dictionary lookup.
class FlexibleFieldLookupDict:
# Maps SQL types to Django Field types. Some of the SQL types have multiple
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
base_data_types_reverse = {
'bool': 'BooleanField',
'boolean': 'BooleanField',
'smallint': 'SmallIntegerField',
'smallinteger': 'SmallIntegerField',
'int': 'IntegerField',
'integer': 'IntegerField',
'text': 'TextField',
'char': 'CharField',
'date': 'DateField',
'datetime': 'DateTimeField',
'time': 'TimeField',
}
def __getitem__(self, key):
key = key.lower()
try:
return BASE_DATA_TYPES_REVERSE[key]
return self.base_data_types_reverse[key]
except KeyError:
import re
m = re.search(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$', key)
@ -86,4 +32,58 @@ class FlexibleFieldLookupDict:
return ('CharField', {'max_length': int(m.group(1))})
raise KeyError
DATA_TYPES_REVERSE = FlexibleFieldLookupDict()
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = FlexibleFieldLookupDict()
def get_table_list(self, cursor):
"Returns a list of table names in the current database."
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute("""
SELECT name FROM sqlite_master
WHERE type='table' AND NOT name='sqlite_sequence'
ORDER BY name""")
return [row[0] for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"Returns a description of the table, with the DB-API cursor.description interface."
return [(info['name'], info['type'], None, None, None, None,
info['null_ok']) for info in self._table_info(cursor, table_name)]
def get_relations(self, cursor, table_name):
raise NotImplementedError
def get_indexes(self, cursor, table_name):
"""
Returns a dictionary of fieldname -> infodict for the given table,
where each infodict is in the format:
{'primary_key': boolean representing whether it's the primary key,
'unique': boolean representing whether it's a unique index}
"""
indexes = {}
for info in self._table_info(cursor, table_name):
indexes[info['name']] = {'primary_key': info['pk'] != 0,
'unique': False}
cursor.execute('PRAGMA index_list(%s)' % self.connection.ops.quote_name(table_name))
# seq, name, unique
for index, unique in [(field[1], field[2]) for field in cursor.fetchall()]:
if not unique:
continue
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
info = cursor.fetchall()
# Skip indexes across multiple fields
if len(info) != 1:
continue
name = info[0][2] # seqno, cid, name
indexes[name]['unique'] = True
return indexes
def _table_info(self, cursor, name):
cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(name))
# cid, name, type, notnull, dflt_value, pk
return [{'name': field[1],
'type': field[2],
'null_ok': not field[3],
'pk': field[5] # undocumented
} for field in cursor.fetchall()]

View File

@ -7,7 +7,7 @@ try:
except ImportError:
from django.utils import _decimal as decimal # for Python 2.3
from django.db import connection, get_creation_module
from django.db import connection
from django.db.models import signals
from django.db.models.query_utils import QueryWrapper
from django.dispatch import dispatcher
@ -145,14 +145,14 @@ class Field(object):
# as the TextField Django field type, which means XMLField's
# get_internal_type() returns 'TextField'.
#
# But the limitation of the get_internal_type() / DATA_TYPES approach
# But the limitation of the get_internal_type() / data_types approach
# is that it cannot handle database column types that aren't already
# mapped to one of the built-in Django field types. In this case, you
# can implement db_type() instead of get_internal_type() to specify
# exactly which wacky database column type you want to use.
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
try:
return get_creation_module().DATA_TYPES[self.get_internal_type()] % data
return connection.creation.data_types[self.get_internal_type()] % data
except KeyError:
return None

View File

@ -3,7 +3,6 @@ from django.conf import settings
from django.db.models import get_app, get_apps
from django.test import _doctest as doctest
from django.test.utils import setup_test_environment, teardown_test_environment
from django.test.utils import create_test_db, destroy_test_db
from django.test.testcases import OutputChecker, DocTestRunner
# The module name for tests outside models.py
@ -139,9 +138,10 @@ def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[]):
suite.addTest(test)
old_name = settings.DATABASE_NAME
create_test_db(verbosity, autoclobber=not interactive)
from django.db import connection
connection.creation.create_test_db(verbosity, autoclobber=not interactive)
result = unittest.TextTestRunner(verbosity=verbosity).run(suite)
destroy_test_db(old_name, verbosity)
connection.creation.destroy_test_db(old_name, verbosity)
teardown_test_environment()

View File

@ -1,16 +1,11 @@
import sys, time, os
from django.conf import settings
from django.db import connection, get_creation_module
from django.db import connection
from django.core import mail
from django.core.management import call_command
from django.test import signals
from django.template import Template
from django.utils.translation import deactivate
# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = 'test_'
def instrumented_test_render(self, context):
"""
An instrumented Template render method, providing a signal
@ -70,147 +65,3 @@ def teardown_test_environment():
del mail.outbox
def _set_autocommit(connection):
"Make sure a connection is in autocommit mode."
if hasattr(connection.connection, "autocommit"):
if callable(connection.connection.autocommit):
connection.connection.autocommit(True)
else:
connection.connection.autocommit = True
elif hasattr(connection.connection, "set_isolation_level"):
connection.connection.set_isolation_level(0)
def get_mysql_create_suffix():
suffix = []
if settings.TEST_DATABASE_CHARSET:
suffix.append('CHARACTER SET %s' % settings.TEST_DATABASE_CHARSET)
if settings.TEST_DATABASE_COLLATION:
suffix.append('COLLATE %s' % settings.TEST_DATABASE_COLLATION)
return ' '.join(suffix)
def get_postgresql_create_suffix():
assert settings.TEST_DATABASE_COLLATION is None, "PostgreSQL does not support collation setting at database creation time."
if settings.TEST_DATABASE_CHARSET:
return "WITH ENCODING '%s'" % settings.TEST_DATABASE_CHARSET
return ''
def create_test_db(verbosity=1, autoclobber=False):
"""
Creates a test database, prompting the user for confirmation if the
database already exists. Returns the name of the test database created.
"""
# If the database backend wants to create the test DB itself, let it
creation_module = get_creation_module()
if hasattr(creation_module, "create_test_db"):
creation_module.create_test_db(settings, connection, verbosity, autoclobber)
return
if verbosity >= 1:
print "Creating test database..."
# If we're using SQLite, it's more convenient to test against an
# in-memory database. Using the TEST_DATABASE_NAME setting you can still choose
# to run on a physical database.
if settings.DATABASE_ENGINE == "sqlite3":
if settings.TEST_DATABASE_NAME and settings.TEST_DATABASE_NAME != ":memory:":
TEST_DATABASE_NAME = settings.TEST_DATABASE_NAME
# Erase the old test database
if verbosity >= 1:
print "Destroying old test database..."
if os.access(TEST_DATABASE_NAME, os.F_OK):
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % TEST_DATABASE_NAME)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
os.remove(TEST_DATABASE_NAME)
except Exception, e:
sys.stderr.write("Got an error deleting the old test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
if verbosity >= 1:
print "Creating test database..."
else:
TEST_DATABASE_NAME = ":memory:"
else:
suffix = {
'postgresql': get_postgresql_create_suffix,
'postgresql_psycopg2': get_postgresql_create_suffix,
'mysql': get_mysql_create_suffix,
}.get(settings.DATABASE_ENGINE, lambda: '')()
if settings.TEST_DATABASE_NAME:
TEST_DATABASE_NAME = settings.TEST_DATABASE_NAME
else:
TEST_DATABASE_NAME = TEST_DATABASE_PREFIX + settings.DATABASE_NAME
qn = connection.ops.quote_name
# Create the test database and connect to it. We need to autocommit
# if the database supports it because PostgreSQL doesn't allow
# CREATE/DROP DATABASE statements within transactions.
cursor = connection.cursor()
_set_autocommit(connection)
try:
cursor.execute("CREATE DATABASE %s %s" % (qn(TEST_DATABASE_NAME), suffix))
except Exception, e:
sys.stderr.write("Got an error creating the test database: %s\n" % e)
if not autoclobber:
confirm = raw_input("Type 'yes' if you would like to try deleting the test database '%s', or 'no' to cancel: " % TEST_DATABASE_NAME)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
print "Destroying old test database..."
cursor.execute("DROP DATABASE %s" % qn(TEST_DATABASE_NAME))
if verbosity >= 1:
print "Creating test database..."
cursor.execute("CREATE DATABASE %s %s" % (qn(TEST_DATABASE_NAME), suffix))
except Exception, e:
sys.stderr.write("Got an error recreating the test database: %s\n" % e)
sys.exit(2)
else:
print "Tests cancelled."
sys.exit(1)
connection.close()
settings.DATABASE_NAME = TEST_DATABASE_NAME
call_command('syncdb', verbosity=verbosity, interactive=False)
if settings.CACHE_BACKEND.startswith('db://'):
cache_name = settings.CACHE_BACKEND[len('db://'):]
call_command('createcachetable', cache_name)
# Get a cursor (even though we don't need one yet). This has
# the side effect of initializing the test database.
cursor = connection.cursor()
return TEST_DATABASE_NAME
def destroy_test_db(old_database_name, verbosity=1):
# If the database wants to drop the test DB itself, let it
creation_module = get_creation_module()
if hasattr(creation_module, "destroy_test_db"):
creation_module.destroy_test_db(settings, connection, old_database_name, verbosity)
return
if verbosity >= 1:
print "Destroying test database..."
connection.close()
TEST_DATABASE_NAME = settings.DATABASE_NAME
settings.DATABASE_NAME = old_database_name
if settings.DATABASE_ENGINE == "sqlite3":
if TEST_DATABASE_NAME and TEST_DATABASE_NAME != ":memory:":
# Remove the SQLite database file
os.remove(TEST_DATABASE_NAME)
else:
# Remove the test database to clean up after
# ourselves. Connect to the previous database (not the test database)
# to do so, because it's not allowed to delete a database while being
# connected to it.
cursor = connection.cursor()
_set_autocommit(connection)
time.sleep(1) # To avoid "database is being accessed by other users" errors.
cursor.execute("DROP DATABASE %s" % connection.ops.quote_name(TEST_DATABASE_NAME))
connection.close()

View File

@ -1026,6 +1026,9 @@ a number of utility methods in the ``django.test.utils`` module.
black magic hooks into the template system and restoring normal e-mail
services.
The creation module of the database backend (``connection.creation``) also
provides some utilities that can be useful during testing.
``create_test_db(verbosity=1, autoclobber=False)``
Creates a new test database and runs ``syncdb`` against it.
@ -1044,7 +1047,7 @@ a number of utility methods in the ``django.test.utils`` module.
``create_test_db()`` has the side effect of modifying
``settings.DATABASE_NAME`` to match the name of the test database.
New in the Django development version, this function returns the name of
**New in Django development version:** This function returns the name of
the test database that it created.
``destroy_test_db(old_database_name, verbosity=1)``

View File

@ -15,10 +15,6 @@ class Person(models.Model):
def __unicode__(self):
return u'%s %s' % (self.first_name, self.last_name)
if connection.features.uses_case_insensitive_names:
t_convert = lambda x: x.upper()
else:
t_convert = lambda x: x
qn = connection.ops.quote_name
__test__ = {'API_TESTS': """
@ -29,7 +25,7 @@ __test__ = {'API_TESTS': """
>>> opts = Square._meta
>>> f1, f2 = opts.get_field('root'), opts.get_field('square')
>>> query = ('INSERT INTO %s (%s, %s) VALUES (%%s, %%s)'
... % (t_convert(opts.db_table), qn(f1.column), qn(f2.column)))
... % (connection.introspection.table_name_converter(opts.db_table), qn(f1.column), qn(f2.column)))
>>> cursor.executemany(query, [(i, i**2) for i in range(-5, 6)]) and None or None
>>> Square.objects.order_by('root')
[<Square: -5 ** 2 == 25>, <Square: -4 ** 2 == 16>, <Square: -3 ** 2 == 9>, <Square: -2 ** 2 == 4>, <Square: -1 ** 2 == 1>, <Square: 0 ** 2 == 0>, <Square: 1 ** 2 == 1>, <Square: 2 ** 2 == 4>, <Square: 3 ** 2 == 9>, <Square: 4 ** 2 == 16>, <Square: 5 ** 2 == 25>]
@ -48,7 +44,7 @@ __test__ = {'API_TESTS': """
>>> opts2 = Person._meta
>>> f3, f4 = opts2.get_field('first_name'), opts2.get_field('last_name')
>>> query2 = ('SELECT %s, %s FROM %s ORDER BY %s'
... % (qn(f3.column), qn(f4.column), t_convert(opts2.db_table),
... % (qn(f3.column), qn(f4.column), connection.introspection.table_name_converter(opts2.db_table),
... qn(f3.column)))
>>> cursor.execute(query2) and None or None
>>> cursor.fetchone()