Changes to get raw queries working on the oracle backend.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@11968 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Ian Kelly 2009-12-22 21:05:15 +00:00
parent cec64b96b0
commit cdf5ad4217
3 changed files with 66 additions and 47 deletions

View File

@ -514,58 +514,18 @@ class FormatStylePlaceholderCursor(object):
row = self.cursor.fetchone()
if row is None:
return row
return self._rowfactory(row)
return _rowfactory(row, self.cursor)
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
return tuple([self._rowfactory(r)
return tuple([_rowfactory(r, self.cursor)
for r in self.cursor.fetchmany(size)])
def fetchall(self):
return tuple([self._rowfactory(r)
return tuple([_rowfactory(r, self.cursor)
for r in self.cursor.fetchall()])
def _rowfactory(self, row):
# Cast numeric values as the appropriate Python type based upon the
# cursor description, and convert strings to unicode.
casted = []
for value, desc in zip(row, self.cursor.description):
if value is not None and desc[1] is Database.NUMBER:
precision, scale = desc[4:6]
if scale == -127:
if precision == 0:
# NUMBER column: decimal-precision floating point
# This will normally be an integer from a sequence,
# but it could be a decimal value.
if '.' in value:
value = Decimal(value)
else:
value = int(value)
else:
# FLOAT column: binary-precision floating point.
# This comes from FloatField columns.
value = float(value)
elif precision > 0:
# NUMBER(p,s) column: decimal-precision fixed point.
# This comes from IntField and DecimalField columns.
if scale == 0:
value = int(value)
else:
value = Decimal(value)
elif '.' in value:
# No type information. This normally comes from a
# mathematical expression in the SELECT list. Guess int
# or Decimal based on whether it has a decimal point.
value = Decimal(value)
else:
value = int(value)
elif desc[1] in (Database.STRING, Database.FIXED_CHAR,
Database.LONG_STRING):
value = to_unicode(value)
casted.append(value)
return tuple(casted)
def __getattr__(self, attr):
if attr in self.__dict__:
return self.__dict__[attr]
@ -573,7 +533,63 @@ class FormatStylePlaceholderCursor(object):
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
return CursorIterator(self.cursor)
class CursorIterator(object):
"""Cursor iterator wrapper that invokes our custom row factory."""
def __init__(self, cursor):
self.cursor = cursor
self.iter = iter(cursor)
def __iter__(self):
return self
def next(self):
return _rowfactory(self.iter.next(), self.cursor)
def _rowfactory(row, cursor):
# Cast numeric values as the appropriate Python type based upon the
# cursor description, and convert strings to unicode.
casted = []
for value, desc in zip(row, cursor.description):
if value is not None and desc[1] is Database.NUMBER:
precision, scale = desc[4:6]
if scale == -127:
if precision == 0:
# NUMBER column: decimal-precision floating point
# This will normally be an integer from a sequence,
# but it could be a decimal value.
if '.' in value:
value = Decimal(value)
else:
value = int(value)
else:
# FLOAT column: binary-precision floating point.
# This comes from FloatField columns.
value = float(value)
elif precision > 0:
# NUMBER(p,s) column: decimal-precision fixed point.
# This comes from IntField and DecimalField columns.
if scale == 0:
value = int(value)
else:
value = Decimal(value)
elif '.' in value:
# No type information. This normally comes from a
# mathematical expression in the SELECT list. Guess int
# or Decimal based on whether it has a decimal point.
value = Decimal(value)
else:
value = int(value)
elif desc[1] in (Database.STRING, Database.FIXED_CHAR,
Database.LONG_STRING):
value = to_unicode(value)
casted.append(value)
return tuple(casted)
def to_unicode(s):

View File

@ -1210,10 +1210,11 @@ class RawQuerySet(object):
A dict mapping column names to model field names.
"""
if not hasattr(self, '_model_fields'):
converter = connections[self.db].introspection.table_name_converter
self._model_fields = {}
for field in self.model._meta.fields:
name, column = field.get_attname_column()
self._model_fields[column] = name
self._model_fields[converter(column)] = name
return self._model_fields
def transform_results(self, values):

View File

@ -42,7 +42,9 @@ class RawQuery(object):
def get_columns(self):
if self.cursor is None:
self._execute_query()
return [column_meta[0] for column_meta in self.cursor.description]
converter = connections[self.using].introspection.table_name_converter
return [converter(column_meta[0])
for column_meta in self.cursor.description]
def validate_sql(self, sql):
if not sql.lower().strip().startswith('select'):
@ -53,7 +55,7 @@ class RawQuery(object):
# Always execute a new query for a new iterator.
# This could be optomized with a cache at the expense of RAM.
self._execute_query()
return self.cursor
return iter(self.cursor)
def __repr__(self):
return "<RawQuery: %r>" % (self.sql % self.params)