diff --git a/tests/regressiontests/serializers_regress/tests.py b/tests/regressiontests/serializers_regress/tests.py index 84e90ff7e1..be920c6920 100644 --- a/tests/regressiontests/serializers_regress/tests.py +++ b/tests/regressiontests/serializers_regress/tests.py @@ -10,14 +10,16 @@ forward, backwards and self references. import datetime import decimal -import unittest -from cStringIO import StringIO +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO -from django.utils.functional import curry -from django.core import serializers -from django.db import transaction, DEFAULT_DB_ALIAS -from django.core import management from django.conf import settings +from django.core import serializers, management +from django.db import transaction, DEFAULT_DB_ALIAS +from django.test import TestCase +from django.utils.functional import curry from models import * @@ -59,10 +61,10 @@ def im2m_create(pk, klass, data): def im_create(pk, klass, data): instance = klass(id=pk) - setattr(instance, 'right_id', data['right']) - setattr(instance, 'left_id', data['left']) + instance.right_id = data['right'] + instance.left_id = data['left'] if 'extra' in data: - setattr(instance, 'extra', data['extra']) + instance.extra = data['extra'] models.Model.save_base(instance, raw=True) return [instance] @@ -96,7 +98,9 @@ def inherited_create(pk, klass, data): def data_compare(testcase, pk, klass, data): instance = klass.objects.get(id=pk) testcase.assertEqual(data, instance.data, - "Objects with PK=%d not equal; expected '%s' (%s), got '%s' (%s)" % (pk,data, type(data), instance.data, type(instance.data))) + "Objects with PK=%d not equal; expected '%s' (%s), got '%s' (%s)" % ( + pk, data, type(data), instance.data, type(instance.data)) + ) def generic_compare(testcase, pk, klass, data): instance = klass.objects.get(id=pk) @@ -348,28 +352,16 @@ if settings.DATABASES[DEFAULT_DB_ALIAS]['ENGINE'] != 'django.db.backends.mysql': # Dynamically create serializer tests to ensure that all # registered serializers are automatically tested. -class SerializerTests(unittest.TestCase): +class SerializerTests(TestCase): pass def serializerTest(format, self): - # Clear the database first - management.call_command('flush', verbosity=0, interactive=False) # Create all the objects defined in the test data objects = [] instance_count = {} - transaction.enter_transaction_management() - try: - transaction.managed(True) - for (func, pk, klass, datum) in test_data: - objects.extend(func[0](pk, klass, datum)) - instance_count[klass] = 0 - transaction.commit() - except: - transaction.rollback() - transaction.leave_transaction_management() - raise - transaction.leave_transaction_management() + for (func, pk, klass, datum) in test_data: + objects.extend(func[0](pk, klass, datum)) # Get a count of the number of objects created for each class for klass in instance_count: @@ -381,19 +373,8 @@ def serializerTest(format, self): # Serialize the test database serialized_data = serializers.serialize(format, objects, indent=2) - # Flush the database and recreate from the serialized data - management.call_command('flush', verbosity=0, interactive=False) - transaction.enter_transaction_management() - try: - transaction.managed(True) - for obj in serializers.deserialize(format, serialized_data): - obj.save() - transaction.commit() - except: - transaction.rollback() - transaction.leave_transaction_management() - raise - transaction.leave_transaction_management() + for obj in serializers.deserialize(format, serialized_data): + obj.save() # Assert that the deserialized data is the same # as the original source @@ -406,10 +387,7 @@ def serializerTest(format, self): self.assertEquals(count, klass.objects.count()) def fieldsTest(format, self): - # Clear the database first - management.call_command('flush', verbosity=0, interactive=False) - - obj = ComplexModel(field1='first',field2='second',field3='third') + obj = ComplexModel(field1='first', field2='second', field3='third') obj.save_base(raw=True) # Serialize then deserialize the test database @@ -422,9 +400,6 @@ def fieldsTest(format, self): self.assertEqual(result.object.field3, 'third') def streamTest(format, self): - # Clear the database first - management.call_command('flush', verbosity=0, interactive=False) - obj = ComplexModel(field1='first',field2='second',field3='third') obj.save_base(raw=True) @@ -440,7 +415,7 @@ def streamTest(format, self): stream.close() for format in serializers.get_serializer_formats(): - setattr(SerializerTests, 'test_'+format+'_serializer', curry(serializerTest, format)) - setattr(SerializerTests, 'test_'+format+'_serializer_fields', curry(fieldsTest, format)) + setattr(SerializerTests, 'test_' + format + '_serializer', curry(serializerTest, format)) + setattr(SerializerTests, 'test_' + format + '_serializer_fields', curry(fieldsTest, format)) if format != 'python': - setattr(SerializerTests, 'test_'+format+'_serializer_stream', curry(streamTest, format)) + setattr(SerializerTests, 'test_' + format + '_serializer_stream', curry(streamTest, format))