Refactored get_query_set_class() to DatabaseOperations.query_set_class(). Also added BaseDatabaseFeatures.uses_custom_queryset. Refs #5106
git-svn-id: http://code.djangoproject.com/svn/django/trunk@5976 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
6d8e6090e5
commit
e13ea3c70d
|
@ -48,6 +48,7 @@ class BaseDatabaseFeatures(object):
|
|||
supports_constraints = True
|
||||
supports_tablespaces = False
|
||||
uses_case_insensitive_names = False
|
||||
uses_custom_queryset = False
|
||||
|
||||
class BaseDatabaseOperations(object):
|
||||
"""
|
||||
|
@ -144,6 +145,15 @@ class BaseDatabaseOperations(object):
|
|||
"""
|
||||
return 'DEFAULT'
|
||||
|
||||
def query_set_class(self, DefaultQuerySet):
|
||||
"""
|
||||
Given the default QuerySet class, returns a custom QuerySet class
|
||||
to use for this backend. Returns None if a custom QuerySet isn't used.
|
||||
See also BaseDatabaseFeatures.uses_custom_queryset, which regulates
|
||||
whether this method is called at all.
|
||||
"""
|
||||
return None
|
||||
|
||||
def quote_name(self, name):
|
||||
"""
|
||||
Returns a quoted version of the given table, index or column name. Does
|
||||
|
|
|
@ -28,6 +28,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
needs_upper_for_iops = True
|
||||
supports_tablespaces = True
|
||||
uses_case_insensitive_names = True
|
||||
uses_custom_queryset = True
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
def autoinc_sql(self, table):
|
||||
|
@ -78,192 +79,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
def max_name_length(self):
|
||||
return 30
|
||||
|
||||
def quote_name(self, name):
|
||||
# SQL92 requires delimited (quoted) names to be case-sensitive. When
|
||||
# not quoted, Oracle has case-insensitive behavior for identifiers, but
|
||||
# always defaults to uppercase.
|
||||
# We simplify things by making Oracle identifiers always uppercase.
|
||||
if not name.startswith('"') and not name.endswith('"'):
|
||||
name = '"%s"' % util.truncate_name(name.upper(), DatabaseOperations().max_name_length())
|
||||
return name.upper()
|
||||
|
||||
def random_function_sql(self):
|
||||
return "DBMS_RANDOM.RANDOM"
|
||||
|
||||
def sql_flush(self, style, tables, sequences):
|
||||
# Return a list of 'TRUNCATE x;', 'TRUNCATE y;',
|
||||
# 'TRUNCATE z;'... style SQL statements
|
||||
if tables:
|
||||
# Oracle does support TRUNCATE, but it seems to get us into
|
||||
# FK referential trouble, whereas DELETE FROM table works.
|
||||
sql = ['%s %s %s;' % \
|
||||
(style.SQL_KEYWORD('DELETE'),
|
||||
style.SQL_KEYWORD('FROM'),
|
||||
style.SQL_FIELD(self.quote_name(table))
|
||||
) for table in tables]
|
||||
# Since we've just deleted all the rows, running our sequence
|
||||
# ALTER code will reset the sequence to 0.
|
||||
for sequence_info in sequences:
|
||||
table_name = sequence_info['table']
|
||||
seq_name = get_sequence_name(table_name)
|
||||
query = _get_sequence_reset_sql() % {'sequence': seq_name, 'table': self.quote_name(table_name)}
|
||||
sql.append(query)
|
||||
return sql
|
||||
else:
|
||||
return []
|
||||
|
||||
def sequence_reset_sql(self, style, model_list):
|
||||
from django.db import models
|
||||
output = []
|
||||
query = _get_sequence_reset_sql()
|
||||
for model in model_list:
|
||||
for f in model._meta.fields:
|
||||
if isinstance(f, models.AutoField):
|
||||
sequence_name = get_sequence_name(model._meta.db_table)
|
||||
output.append(query % {'sequence':sequence_name,
|
||||
'table':model._meta.db_table})
|
||||
break # Only one AutoField is allowed per model, so don't bother continuing.
|
||||
for f in model._meta.many_to_many:
|
||||
sequence_name = get_sequence_name(f.m2m_db_table())
|
||||
output.append(query % {'sequence':sequence_name,
|
||||
'table':f.m2m_db_table()})
|
||||
return output
|
||||
|
||||
def start_transaction_sql(self):
|
||||
return ''
|
||||
|
||||
def tablespace_sql(self, tablespace, inline=False):
|
||||
return "%sTABLESPACE %s" % ((inline and "USING INDEX " or ""), self.quote_name(tablespace))
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
features = DatabaseFeatures()
|
||||
ops = DatabaseOperations()
|
||||
|
||||
def _valid_connection(self):
|
||||
return self.connection is not None
|
||||
|
||||
def _cursor(self, settings):
|
||||
if not self._valid_connection():
|
||||
if len(settings.DATABASE_HOST.strip()) == 0:
|
||||
settings.DATABASE_HOST = 'localhost'
|
||||
if len(settings.DATABASE_PORT.strip()) != 0:
|
||||
dsn = Database.makedsn(settings.DATABASE_HOST, int(settings.DATABASE_PORT), settings.DATABASE_NAME)
|
||||
self.connection = Database.connect(settings.DATABASE_USER, settings.DATABASE_PASSWORD, dsn, **self.options)
|
||||
else:
|
||||
conn_string = "%s/%s@%s" % (settings.DATABASE_USER, settings.DATABASE_PASSWORD, settings.DATABASE_NAME)
|
||||
self.connection = Database.connect(conn_string, **self.options)
|
||||
cursor = FormatStylePlaceholderCursor(self.connection)
|
||||
# Default arraysize of 1 is highly sub-optimal.
|
||||
cursor.arraysize = 100
|
||||
# Set oracle date to ansi date format.
|
||||
cursor.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'")
|
||||
cursor.execute("ALTER SESSION SET NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'")
|
||||
return cursor
|
||||
|
||||
class FormatStylePlaceholderCursor(Database.Cursor):
|
||||
"""
|
||||
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
|
||||
style. This fixes it -- but note that if you want to use a literal "%s" in
|
||||
a query, you'll need to use "%%s".
|
||||
|
||||
We also do automatic conversion between Unicode on the Python side and
|
||||
UTF-8 -- for talking to Oracle -- in here.
|
||||
"""
|
||||
charset = 'utf-8'
|
||||
|
||||
def _rewrite_args(self, query, params=None):
|
||||
if params is None:
|
||||
params = []
|
||||
else:
|
||||
params = self._format_params(params)
|
||||
args = [(':arg%d' % i) for i in range(len(params))]
|
||||
query = smart_str(query, self.charset) % tuple(args)
|
||||
# cx_Oracle wants no trailing ';' for SQL statements. For PL/SQL, it
|
||||
# it does want a trailing ';' but not a trailing '/'. However, these
|
||||
# characters must be included in the original query in case the query
|
||||
# is being passed to SQL*Plus.
|
||||
if query.endswith(';') or query.endswith('/'):
|
||||
query = query[:-1]
|
||||
return query, params
|
||||
|
||||
def _format_params(self, params):
|
||||
if isinstance(params, dict):
|
||||
result = {}
|
||||
charset = self.charset
|
||||
for key, value in params.items():
|
||||
result[smart_str(key, charset)] = smart_str(value, charset)
|
||||
return result
|
||||
else:
|
||||
return tuple([smart_str(p, self.charset, True) for p in params])
|
||||
|
||||
def execute(self, query, params=None):
|
||||
query, params = self._rewrite_args(query, params)
|
||||
return Database.Cursor.execute(self, query, params)
|
||||
|
||||
def executemany(self, query, params=None):
|
||||
query, params = self._rewrite_args(query, params)
|
||||
return Database.Cursor.executemany(self, query, params)
|
||||
|
||||
def fetchone(self):
|
||||
return to_unicode(Database.Cursor.fetchone(self))
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
return tuple([tuple([to_unicode(e) for e in r]) for r in Database.Cursor.fetchmany(self, size)])
|
||||
|
||||
def fetchall(self):
|
||||
return tuple([tuple([to_unicode(e) for e in r]) for r in Database.Cursor.fetchall(self)])
|
||||
|
||||
def to_unicode(s):
|
||||
"""
|
||||
Convert strings to Unicode objects (and return all other data types
|
||||
unchanged).
|
||||
"""
|
||||
if isinstance(s, basestring):
|
||||
return force_unicode(s)
|
||||
return s
|
||||
|
||||
def get_field_cast_sql(db_type):
|
||||
if db_type.endswith('LOB'):
|
||||
return "DBMS_LOB.SUBSTR(%s%s)"
|
||||
else:
|
||||
return "%s%s"
|
||||
|
||||
def get_drop_sequence(table):
|
||||
return "DROP SEQUENCE %s;" % DatabaseOperations().quote_name(get_sequence_name(table))
|
||||
|
||||
def _get_sequence_reset_sql():
|
||||
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
|
||||
return """
|
||||
DECLARE
|
||||
startvalue integer;
|
||||
cval integer;
|
||||
BEGIN
|
||||
LOCK TABLE %(table)s IN SHARE MODE;
|
||||
SELECT NVL(MAX(id), 0) INTO startvalue FROM %(table)s;
|
||||
SELECT %(sequence)s.nextval INTO cval FROM dual;
|
||||
cval := startvalue - cval;
|
||||
IF cval != 0 THEN
|
||||
EXECUTE IMMEDIATE 'ALTER SEQUENCE %(sequence)s MINVALUE 0 INCREMENT BY '||cval;
|
||||
SELECT %(sequence)s.nextval INTO cval FROM dual;
|
||||
EXECUTE IMMEDIATE 'ALTER SEQUENCE %(sequence)s INCREMENT BY 1';
|
||||
END IF;
|
||||
COMMIT;
|
||||
END;
|
||||
/"""
|
||||
|
||||
def get_sequence_name(table):
|
||||
name_length = DatabaseOperations().max_name_length() - 3
|
||||
return '%s_SQ' % util.truncate_name(table, name_length).upper()
|
||||
|
||||
def get_trigger_name(table):
|
||||
name_length = DatabaseOperations().max_name_length() - 3
|
||||
return '%s_TR' % util.truncate_name(table, name_length).upper()
|
||||
|
||||
def get_query_set_class(DefaultQuerySet):
|
||||
"Create a custom QuerySet class for Oracle."
|
||||
|
||||
def query_set_class(self, DefaultQuerySet):
|
||||
from django.db import connection
|
||||
from django.db.models.query import EmptyResultSet, GET_ITERATOR_CHUNK_SIZE, quote_only_if_word
|
||||
|
||||
|
@ -500,6 +316,188 @@ def get_query_set_class(DefaultQuerySet):
|
|||
|
||||
return OracleQuerySet
|
||||
|
||||
def quote_name(self, name):
|
||||
# SQL92 requires delimited (quoted) names to be case-sensitive. When
|
||||
# not quoted, Oracle has case-insensitive behavior for identifiers, but
|
||||
# always defaults to uppercase.
|
||||
# We simplify things by making Oracle identifiers always uppercase.
|
||||
if not name.startswith('"') and not name.endswith('"'):
|
||||
name = '"%s"' % util.truncate_name(name.upper(), DatabaseOperations().max_name_length())
|
||||
return name.upper()
|
||||
|
||||
def random_function_sql(self):
|
||||
return "DBMS_RANDOM.RANDOM"
|
||||
|
||||
def sql_flush(self, style, tables, sequences):
|
||||
# Return a list of 'TRUNCATE x;', 'TRUNCATE y;',
|
||||
# 'TRUNCATE z;'... style SQL statements
|
||||
if tables:
|
||||
# Oracle does support TRUNCATE, but it seems to get us into
|
||||
# FK referential trouble, whereas DELETE FROM table works.
|
||||
sql = ['%s %s %s;' % \
|
||||
(style.SQL_KEYWORD('DELETE'),
|
||||
style.SQL_KEYWORD('FROM'),
|
||||
style.SQL_FIELD(self.quote_name(table))
|
||||
) for table in tables]
|
||||
# Since we've just deleted all the rows, running our sequence
|
||||
# ALTER code will reset the sequence to 0.
|
||||
for sequence_info in sequences:
|
||||
table_name = sequence_info['table']
|
||||
seq_name = get_sequence_name(table_name)
|
||||
query = _get_sequence_reset_sql() % {'sequence': seq_name, 'table': self.quote_name(table_name)}
|
||||
sql.append(query)
|
||||
return sql
|
||||
else:
|
||||
return []
|
||||
|
||||
def sequence_reset_sql(self, style, model_list):
|
||||
from django.db import models
|
||||
output = []
|
||||
query = _get_sequence_reset_sql()
|
||||
for model in model_list:
|
||||
for f in model._meta.fields:
|
||||
if isinstance(f, models.AutoField):
|
||||
sequence_name = get_sequence_name(model._meta.db_table)
|
||||
output.append(query % {'sequence':sequence_name,
|
||||
'table':model._meta.db_table})
|
||||
break # Only one AutoField is allowed per model, so don't bother continuing.
|
||||
for f in model._meta.many_to_many:
|
||||
sequence_name = get_sequence_name(f.m2m_db_table())
|
||||
output.append(query % {'sequence':sequence_name,
|
||||
'table':f.m2m_db_table()})
|
||||
return output
|
||||
|
||||
def start_transaction_sql(self):
|
||||
return ''
|
||||
|
||||
def tablespace_sql(self, tablespace, inline=False):
|
||||
return "%sTABLESPACE %s" % ((inline and "USING INDEX " or ""), self.quote_name(tablespace))
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
features = DatabaseFeatures()
|
||||
ops = DatabaseOperations()
|
||||
|
||||
def _valid_connection(self):
|
||||
return self.connection is not None
|
||||
|
||||
def _cursor(self, settings):
|
||||
if not self._valid_connection():
|
||||
if len(settings.DATABASE_HOST.strip()) == 0:
|
||||
settings.DATABASE_HOST = 'localhost'
|
||||
if len(settings.DATABASE_PORT.strip()) != 0:
|
||||
dsn = Database.makedsn(settings.DATABASE_HOST, int(settings.DATABASE_PORT), settings.DATABASE_NAME)
|
||||
self.connection = Database.connect(settings.DATABASE_USER, settings.DATABASE_PASSWORD, dsn, **self.options)
|
||||
else:
|
||||
conn_string = "%s/%s@%s" % (settings.DATABASE_USER, settings.DATABASE_PASSWORD, settings.DATABASE_NAME)
|
||||
self.connection = Database.connect(conn_string, **self.options)
|
||||
cursor = FormatStylePlaceholderCursor(self.connection)
|
||||
# Default arraysize of 1 is highly sub-optimal.
|
||||
cursor.arraysize = 100
|
||||
# Set oracle date to ansi date format.
|
||||
cursor.execute("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'")
|
||||
cursor.execute("ALTER SESSION SET NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'")
|
||||
return cursor
|
||||
|
||||
class FormatStylePlaceholderCursor(Database.Cursor):
|
||||
"""
|
||||
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
|
||||
style. This fixes it -- but note that if you want to use a literal "%s" in
|
||||
a query, you'll need to use "%%s".
|
||||
|
||||
We also do automatic conversion between Unicode on the Python side and
|
||||
UTF-8 -- for talking to Oracle -- in here.
|
||||
"""
|
||||
charset = 'utf-8'
|
||||
|
||||
def _rewrite_args(self, query, params=None):
|
||||
if params is None:
|
||||
params = []
|
||||
else:
|
||||
params = self._format_params(params)
|
||||
args = [(':arg%d' % i) for i in range(len(params))]
|
||||
query = smart_str(query, self.charset) % tuple(args)
|
||||
# cx_Oracle wants no trailing ';' for SQL statements. For PL/SQL, it
|
||||
# it does want a trailing ';' but not a trailing '/'. However, these
|
||||
# characters must be included in the original query in case the query
|
||||
# is being passed to SQL*Plus.
|
||||
if query.endswith(';') or query.endswith('/'):
|
||||
query = query[:-1]
|
||||
return query, params
|
||||
|
||||
def _format_params(self, params):
|
||||
if isinstance(params, dict):
|
||||
result = {}
|
||||
charset = self.charset
|
||||
for key, value in params.items():
|
||||
result[smart_str(key, charset)] = smart_str(value, charset)
|
||||
return result
|
||||
else:
|
||||
return tuple([smart_str(p, self.charset, True) for p in params])
|
||||
|
||||
def execute(self, query, params=None):
|
||||
query, params = self._rewrite_args(query, params)
|
||||
return Database.Cursor.execute(self, query, params)
|
||||
|
||||
def executemany(self, query, params=None):
|
||||
query, params = self._rewrite_args(query, params)
|
||||
return Database.Cursor.executemany(self, query, params)
|
||||
|
||||
def fetchone(self):
|
||||
return to_unicode(Database.Cursor.fetchone(self))
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
return tuple([tuple([to_unicode(e) for e in r]) for r in Database.Cursor.fetchmany(self, size)])
|
||||
|
||||
def fetchall(self):
|
||||
return tuple([tuple([to_unicode(e) for e in r]) for r in Database.Cursor.fetchall(self)])
|
||||
|
||||
def to_unicode(s):
|
||||
"""
|
||||
Convert strings to Unicode objects (and return all other data types
|
||||
unchanged).
|
||||
"""
|
||||
if isinstance(s, basestring):
|
||||
return force_unicode(s)
|
||||
return s
|
||||
|
||||
def get_field_cast_sql(db_type):
|
||||
if db_type.endswith('LOB'):
|
||||
return "DBMS_LOB.SUBSTR(%s%s)"
|
||||
else:
|
||||
return "%s%s"
|
||||
|
||||
def get_drop_sequence(table):
|
||||
return "DROP SEQUENCE %s;" % DatabaseOperations().quote_name(get_sequence_name(table))
|
||||
|
||||
def _get_sequence_reset_sql():
|
||||
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
|
||||
return """
|
||||
DECLARE
|
||||
startvalue integer;
|
||||
cval integer;
|
||||
BEGIN
|
||||
LOCK TABLE %(table)s IN SHARE MODE;
|
||||
SELECT NVL(MAX(id), 0) INTO startvalue FROM %(table)s;
|
||||
SELECT %(sequence)s.nextval INTO cval FROM dual;
|
||||
cval := startvalue - cval;
|
||||
IF cval != 0 THEN
|
||||
EXECUTE IMMEDIATE 'ALTER SEQUENCE %(sequence)s MINVALUE 0 INCREMENT BY '||cval;
|
||||
SELECT %(sequence)s.nextval INTO cval FROM dual;
|
||||
EXECUTE IMMEDIATE 'ALTER SEQUENCE %(sequence)s INCREMENT BY 1';
|
||||
END IF;
|
||||
COMMIT;
|
||||
END;
|
||||
/"""
|
||||
|
||||
def get_sequence_name(table):
|
||||
name_length = DatabaseOperations().max_name_length() - 3
|
||||
return '%s_SQ' % util.truncate_name(table, name_length).upper()
|
||||
|
||||
def get_trigger_name(table):
|
||||
name_length = DatabaseOperations().max_name_length() - 3
|
||||
return '%s_TR' % util.truncate_name(table, name_length).upper()
|
||||
|
||||
OPERATOR_MAPPING = {
|
||||
'exact': '= %s',
|
||||
|
|
|
@ -564,9 +564,9 @@ class _QuerySet(object):
|
|||
|
||||
return select, " ".join(sql), params
|
||||
|
||||
# Use the backend's QuerySet class if it defines one, otherwise use _QuerySet.
|
||||
if hasattr(backend, 'get_query_set_class'):
|
||||
QuerySet = backend.get_query_set_class(_QuerySet)
|
||||
# Use the backend's QuerySet class if it defines one. Otherwise, use _QuerySet.
|
||||
if connection.features.uses_custom_queryset:
|
||||
QuerySet = connection.ops.query_set_class(_QuerySet)
|
||||
else:
|
||||
QuerySet = _QuerySet
|
||||
|
||||
|
|
Loading…
Reference in New Issue