Fixed #19198 -- merged 4 different Oracle fixes

This commit is contained in:
Anssi Kääriäinen 2012-10-27 19:05:41 +03:00
commit 908efca817
6 changed files with 42 additions and 9 deletions

View File

@ -26,7 +26,7 @@ mysql = _default_db == 'mysql'
spatialite = _default_db == 'spatialite' spatialite = _default_db == 'spatialite'
HAS_SPATIALREFSYS = True HAS_SPATIALREFSYS = True
if oracle: if oracle and 'gis' in settings.DATABASES[DEFAULT_DB_ALIAS]['ENGINE']:
from django.contrib.gis.db.backends.oracle.models import SpatialRefSys from django.contrib.gis.db.backends.oracle.models import SpatialRefSys
elif postgis: elif postgis:
from django.contrib.gis.db.backends.postgis.models import SpatialRefSys from django.contrib.gis.db.backends.postgis.models import SpatialRefSys

View File

@ -256,6 +256,10 @@ WHEN (new.%(col_name)s IS NULL)
if not name.startswith('"') and not name.endswith('"'): if not name.startswith('"') and not name.endswith('"'):
name = '"%s"' % util.truncate_name(name.upper(), name = '"%s"' % util.truncate_name(name.upper(),
self.max_name_length()) self.max_name_length())
# Oracle puts the query text into a (query % args) construct, so % signs
# in names need to be escaped. The '%%' will be collapsed back to '%' at
# that stage so we aren't really making the name longer here.
name = name.replace('%','%%')
return name.upper() return name.upper()
def random_function_sql(self): def random_function_sql(self):

View File

@ -1418,8 +1418,15 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
fields = row[index_start : index_start + field_count] fields = row[index_start : index_start + field_count]
# If all the select_related columns are None, then the related # If all the select_related columns are None, then the related
# object must be non-existent - set the relation to None. # object must be non-existent - set the relation to None.
# Otherwise, construct the related object. # Otherwise, construct the related object. Also, some backends treat ''
if fields == (None,) * field_count: # and None equivalently for char fields, so we have to be prepared for
# '' values.
if connections[using].features.interprets_empty_strings_as_nulls:
vals = tuple([None if f == '' else f for f in fields])
else:
vals = fields
if vals == (None,) * field_count:
obj = None obj = None
else: else:
if field_names: if field_names:

View File

@ -25,6 +25,14 @@ from . import models
class OracleChecks(unittest.TestCase): class OracleChecks(unittest.TestCase):
@unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle quote_name semantics")
def test_quote_name(self):
# Check that '%' chars are escaped for query execution.
name = '"SOME%NAME"'
quoted_name = connection.ops.quote_name(name)
self.assertEquals(quoted_name % (), name)
@unittest.skipUnless(connection.vendor == 'oracle', @unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle cursor semantics") "No need to check Oracle cursor semantics")
def test_dbms_session(self): def test_dbms_session(self):

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.management import call_command from django.core.management import call_command
from django.db import connection
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.utils.six import StringIO from django.utils.six import StringIO
@ -60,14 +61,16 @@ class InspectDBTestCase(TestCase):
self.assertIn("number_45extra = models.CharField", output) self.assertIn("number_45extra = models.CharField", output)
def test_special_column_name_introspection(self): def test_special_column_name_introspection(self):
"""Introspection of column names containing special characters, """
unsuitable for Python identifiers Introspection of column names containing special characters,
unsuitable for Python identifiers
""" """
out = StringIO() out = StringIO()
call_command('inspectdb', stdout=out) call_command('inspectdb', stdout=out)
output = out.getvalue() output = out.getvalue()
base_name = 'Field' if connection.vendor != 'oracle' else 'field'
self.assertIn("field = models.IntegerField()", output) self.assertIn("field = models.IntegerField()", output)
self.assertIn("field_field = models.IntegerField(db_column='Field_')", output) self.assertIn("field_field = models.IntegerField(db_column='%s_')" % base_name, output)
self.assertIn("field_field_0 = models.IntegerField(db_column='Field__')", output) self.assertIn("field_field_0 = models.IntegerField(db_column='%s__')" % base_name, output)
self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output) self.assertIn("field_field_1 = models.IntegerField(db_column='__field')", output)
self.assertIn("prc_x = models.IntegerField(db_column='prc(%) x')", output) self.assertIn("prc_x = models.IntegerField(db_column='prc(%) x')", output)

View File

@ -4,10 +4,15 @@ from functools import update_wrapper
from django.db import connection from django.db import connection
from django.test import TestCase, skipUnlessDBFeature, skipIfDBFeature from django.test import TestCase, skipUnlessDBFeature, skipIfDBFeature
from django.utils import six from django.utils import six, unittest
from .models import Reporter, Article from .models import Reporter, Article
if connection.vendor == 'oracle':
expectedFailureOnOracle = unittest.expectedFailure
else:
expectedFailureOnOracle = lambda f: f
# #
# The introspection module is optional, so methods tested here might raise # The introspection module is optional, so methods tested here might raise
# NotImplementedError. This is perfectly acceptable behavior for the backend # NotImplementedError. This is perfectly acceptable behavior for the backend
@ -89,7 +94,13 @@ class IntrospectionTests(six.with_metaclass(IgnoreNotimplementedError, TestCase)
[datatype(r[1], r) for r in desc], [datatype(r[1], r) for r in desc],
['IntegerField', 'CharField', 'CharField', 'CharField', 'BigIntegerField'] ['IntegerField', 'CharField', 'CharField', 'CharField', 'BigIntegerField']
) )
# Check also length of CharFields
# The following test fails on Oracle due to #17202 (can't correctly
# inspect the length of character columns).
@expectedFailureOnOracle
def test_get_table_description_col_lengths(self):
cursor = connection.cursor()
desc = connection.introspection.get_table_description(cursor, Reporter._meta.db_table)
self.assertEqual( self.assertEqual(
[r[3] for r in desc if datatype(r[1], r) == 'CharField'], [r[3] for r in desc if datatype(r[1], r) == 'CharField'],
[30, 30, 75] [30, 30, 75]