Fixed #5543: callproc() and friends now work with Oracle and our FormatStylePlaceholderCursor.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9767 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Matt Boersma 2009-01-16 22:23:58 +00:00
parent 6332ad4804
commit b41a45f8e5
2 changed files with 49 additions and 12 deletions

View File

@ -291,10 +291,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
pass
if not cursor:
cursor = FormatStylePlaceholderCursor(self.connection)
# Necessary to retrieve decimal values without rounding error.
cursor.numbersAsStrings = True
# Default arraysize of 1 is highly sub-optimal.
cursor.arraysize = 100
return cursor
@ -320,7 +316,7 @@ class OracleParam(object):
self.input_size = None
class FormatStylePlaceholderCursor(Database.Cursor):
class FormatStylePlaceholderCursor(object):
"""
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
style. This fixes it -- but note that if you want to use a literal "%s" in
@ -331,6 +327,13 @@ class FormatStylePlaceholderCursor(Database.Cursor):
"""
charset = 'utf-8'
def __init__(self, connection):
self.cursor = connection.cursor()
# Necessary to retrieve decimal values without rounding error.
self.cursor.numbersAsStrings = True
# Default arraysize of 1 is highly sub-optimal.
self.cursor.arraysize = 100
def _format_params(self, params):
return tuple([OracleParam(p, self.charset, True) for p in params])
@ -360,8 +363,7 @@ class FormatStylePlaceholderCursor(Database.Cursor):
query = smart_str(query, self.charset) % tuple(args)
self._guess_input_sizes([params])
try:
return Database.Cursor.execute(self, query,
self._param_generator(params))
return self.cursor.execute(query, self._param_generator(params))
except DatabaseError, e:
# cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
if e.args[0].code == 1400 and not isinstance(e, IntegrityError):
@ -384,7 +386,7 @@ class FormatStylePlaceholderCursor(Database.Cursor):
formatted = [self._format_params(i) for i in params]
self._guess_input_sizes(formatted)
try:
return Database.Cursor.executemany(self, query,
return self.cursor.executemany(query,
[self._param_generator(p) for p in formatted])
except DatabaseError, e:
# cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
@ -393,7 +395,7 @@ class FormatStylePlaceholderCursor(Database.Cursor):
raise e
def fetchone(self):
row = Database.Cursor.fetchone(self)
row = self.cursor.fetchone()
if row is None:
return row
return self._rowfactory(row)
@ -402,17 +404,17 @@ class FormatStylePlaceholderCursor(Database.Cursor):
if size is None:
size = self.arraysize
return tuple([self._rowfactory(r)
for r in Database.Cursor.fetchmany(self, size)])
for r in self.cursor.fetchmany(size)])
def fetchall(self):
return tuple([self._rowfactory(r)
for r in Database.Cursor.fetchall(self)])
for r in self.cursor.fetchall()])
def _rowfactory(self, row):
# Cast numeric values as the appropriate Python type based upon the
# cursor description, and convert strings to unicode.
casted = []
for value, desc in zip(row, self.description):
for value, desc in zip(row, self.cursor.description):
if value is not None and desc[1] is Database.NUMBER:
precision, scale = desc[4:6]
if scale == -127:
@ -447,6 +449,15 @@ class FormatStylePlaceholderCursor(Database.Cursor):
casted.append(value)
return tuple(casted)
def __getattr__(self, attr):
if attr in self.__dict__:
return self.__dict__[attr]
else:
return getattr(self.cursor, attr)
def __iter__(self):
return iter(self.cursor)
def to_unicode(s):
"""

View File

@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Unit tests for specific database backends.
import unittest
from django.db import connection
from django.conf import settings
class Callproc(unittest.TestCase):
def test_dbms_session(self):
# If the backend is Oracle, test that we can call a standard
# stored procedure through our cursor wrapper.
if settings.DATABASE_ENGINE == 'oracle':
cursor = connection.cursor()
cursor.callproc('DBMS_SESSION.SET_IDENTIFIER',
['_django_testing!',])
return True
else:
return True
if __name__ == '__main__':
unittest.main()