Stopped special-casing postgres-specific tests

Refs #23879.
This commit is contained in:
Claude Paroz 2015-04-04 18:10:26 +02:00
parent 6b6d13bf6e
commit 36e90d1f45
12 changed files with 154 additions and 71 deletions

View File

@ -12,6 +12,8 @@ class CreateExtension(Operation):
pass
def database_forwards(self, app_label, schema_editor, from_state, to_state):
if schema_editor.connection.vendor != 'postgresql':
return
schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % self.name)
def database_backwards(self, app_label, schema_editor, from_state, to_state):

View File

@ -0,0 +1,9 @@
import unittest
from django.db import connection
from django.test import TestCase
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
class PostgresSQLTestCase(TestCase):
pass

View File

@ -0,0 +1,31 @@
"""
Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
run with a backend other than PostgreSQL.
"""
from django.db import models
try:
from django.contrib.postgres.fields import (
ArrayField, BigIntegerRangeField, DateRangeField, DateTimeRangeField,
FloatRangeField, HStoreField, IntegerRangeField,
)
except ImportError:
class DummyArrayField(models.Field):
def __init__(self, base_field, size=None, **kwargs):
super(DummyArrayField, self).__init__(**kwargs)
def deconstruct(self):
name, path, args, kwargs = super(DummyArrayField, self).deconstruct()
kwargs.update({
'base_field': '',
'size': 1,
})
return name, path, args, kwargs
ArrayField = DummyArrayField
BigIntegerRangeField = models.Field
DateRangeField = models.Field
DateTimeRangeField = models.Field
FloatRangeField = models.Field
HStoreField = models.Field
IntegerRangeField = models.Field

View File

@ -1,11 +1,17 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.contrib.postgres.operations import (
HStoreExtension, UnaccentExtension,
)
from django.db import migrations
try:
from django.contrib.postgres.operations import (
HStoreExtension, UnaccentExtension,
)
except ImportError:
from django.test import mock
HStoreExtension = mock.Mock()
UnaccentExtension = mock.Mock()
class Migration(migrations.Migration):

View File

@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import django.contrib.postgres.fields
import django.contrib.postgres.fields.hstore
from django.db import migrations, models
from ..fields import * # NOQA
class Migration(migrations.Migration):
@ -17,9 +17,10 @@ class Migration(migrations.Migration):
name='CharArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', django.contrib.postgres.fields.ArrayField(models.CharField(max_length=10), size=None)),
('field', ArrayField(models.CharField(max_length=10), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
@ -27,11 +28,12 @@ class Migration(migrations.Migration):
name='DateTimeArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('datetimes', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)),
('dates', django.contrib.postgres.fields.ArrayField(models.DateField(), size=None)),
('times', django.contrib.postgres.fields.ArrayField(models.TimeField(), size=None)),
('datetimes', ArrayField(models.DateTimeField(), size=None)),
('dates', ArrayField(models.DateField(), size=None)),
('times', ArrayField(models.TimeField(), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
@ -39,9 +41,10 @@ class Migration(migrations.Migration):
name='HStoreModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', django.contrib.postgres.fields.hstore.HStoreField(blank=True, null=True)),
('field', HStoreField(blank=True, null=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
@ -49,11 +52,12 @@ class Migration(migrations.Migration):
name='OtherTypesArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('ips', django.contrib.postgres.fields.ArrayField(models.GenericIPAddressField(), size=None)),
('uuids', django.contrib.postgres.fields.ArrayField(models.UUIDField(), size=None)),
('decimals', django.contrib.postgres.fields.ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
('ips', ArrayField(models.GenericIPAddressField(), size=None)),
('uuids', ArrayField(models.UUIDField(), size=None)),
('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
@ -61,9 +65,10 @@ class Migration(migrations.Migration):
name='IntegerArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None)),
('field', ArrayField(models.IntegerField(), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
@ -71,9 +76,10 @@ class Migration(migrations.Migration):
name='NestedIntegerArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', django.contrib.postgres.fields.ArrayField(django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None), size=None)),
('field', ArrayField(ArrayField(models.IntegerField(), size=None), size=None)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
@ -81,9 +87,10 @@ class Migration(migrations.Migration):
name='NullableIntegerArrayModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None, null=True, blank=True)),
('field', ArrayField(models.IntegerField(), size=None, null=True, blank=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
@ -130,20 +137,25 @@ class Migration(migrations.Migration):
name='RangesModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('ints', django.contrib.postgres.fields.IntegerRangeField(null=True, blank=True)),
('bigints', django.contrib.postgres.fields.BigIntegerRangeField(null=True, blank=True)),
('floats', django.contrib.postgres.fields.FloatRangeField(null=True, blank=True)),
('timestamps', django.contrib.postgres.fields.DateTimeRangeField(null=True, blank=True)),
('dates', django.contrib.postgres.fields.DateRangeField(null=True, blank=True)),
('ints', IntegerRangeField(null=True, blank=True)),
('bigints', BigIntegerRangeField(null=True, blank=True)),
('floats', FloatRangeField(null=True, blank=True)),
('timestamps', DateTimeRangeField(null=True, blank=True)),
('dates', DateRangeField(null=True, blank=True)),
],
options={
'required_db_vendor': 'postgresql',
},
bases=(models.Model,),
),
]
def apply(self, project_state, schema_editor, collect_sql=False):
PG_VERSION = schema_editor.connection.pg_version
if PG_VERSION >= 90200:
self.operations = self.operations + self.pg_92_operations
try:
PG_VERSION = schema_editor.connection.pg_version
except AttributeError:
pass # We are probably not on PostgreSQL
else:
if PG_VERSION >= 90200:
self.operations = self.operations + self.pg_92_operations
return super(Migration, self).apply(project_state, schema_editor, collect_sql)

View File

@ -1,39 +1,46 @@
from django.contrib.postgres.fields import (
from django.db import connection, models
from .fields import (
ArrayField, BigIntegerRangeField, DateRangeField, DateTimeRangeField,
FloatRangeField, HStoreField, IntegerRangeField,
)
from django.db import connection, models
class IntegerArrayModel(models.Model):
class PostgreSQLModel(models.Model):
class Meta:
abstract = True
required_db_vendor = 'postgresql'
class IntegerArrayModel(PostgreSQLModel):
field = ArrayField(models.IntegerField())
class NullableIntegerArrayModel(models.Model):
class NullableIntegerArrayModel(PostgreSQLModel):
field = ArrayField(models.IntegerField(), blank=True, null=True)
class CharArrayModel(models.Model):
class CharArrayModel(PostgreSQLModel):
field = ArrayField(models.CharField(max_length=10))
class DateTimeArrayModel(models.Model):
class DateTimeArrayModel(PostgreSQLModel):
datetimes = ArrayField(models.DateTimeField())
dates = ArrayField(models.DateField())
times = ArrayField(models.TimeField())
class NestedIntegerArrayModel(models.Model):
class NestedIntegerArrayModel(PostgreSQLModel):
field = ArrayField(ArrayField(models.IntegerField()))
class OtherTypesArrayModel(models.Model):
class OtherTypesArrayModel(PostgreSQLModel):
ips = ArrayField(models.GenericIPAddressField())
uuids = ArrayField(models.UUIDField())
decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2))
class HStoreModel(models.Model):
class HStoreModel(PostgreSQLModel):
field = HStoreField(blank=True, null=True)
@ -47,7 +54,7 @@ class TextFieldModel(models.Model):
# Only create this model for databases which support it
if connection.vendor == 'postgresql' and connection.pg_version >= 90200:
class RangesModel(models.Model):
class RangesModel(PostgreSQLModel):
ints = IntegerRangeField(blank=True, null=True)
bigints = BigIntegerRangeField(blank=True, null=True)
floats = FloatRangeField(blank=True, null=True)

View File

@ -4,13 +4,13 @@ from django.contrib.postgres.aggregates import (
RegrSYY, StatAggregate, StringAgg,
)
from django.db.models.expressions import F, Value
from django.test import TestCase
from django.test.utils import Approximate
from . import PostgresSQLTestCase
from .models import AggregateTestModel, StatTestModel
class TestGeneralAggregate(TestCase):
class TestGeneralAggregate(PostgresSQLTestCase):
@classmethod
def setUpTestData(cls):
AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0)
@ -111,7 +111,7 @@ class TestGeneralAggregate(TestCase):
self.assertEqual(values, {'stringagg': ''})
class TestStatisticsAggregate(TestCase):
class TestStatisticsAggregate(PostgresSQLTestCase):
@classmethod
def setUpTestData(cls):
StatTestModel.objects.create(

View File

@ -4,21 +4,26 @@ import unittest
import uuid
from django import forms
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField
from django.core import exceptions, serializers, validators
from django.core.management import call_command
from django.db import IntegrityError, connection, models
from django.test import TestCase, TransactionTestCase, override_settings
from django.test import TransactionTestCase, override_settings
from django.utils import timezone
from . import PostgresSQLTestCase
from .models import (
ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel,
NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel,
)
try:
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField
except ImportError:
pass
class TestSaveLoad(TestCase):
class TestSaveLoad(PostgresSQLTestCase):
def test_integer(self):
instance = IntegerArrayModel(field=[1, 2, 3])
@ -93,7 +98,7 @@ class TestSaveLoad(TestCase):
self.assertEqual(instance.decimals, loaded.decimals)
class TestQuerying(TestCase):
class TestQuerying(PostgresSQLTestCase):
def setUp(self):
self.objs = [
@ -224,7 +229,7 @@ class TestQuerying(TestCase):
)
class TestChecks(TestCase):
class TestChecks(PostgresSQLTestCase):
def test_field_checks(self):
field = ArrayField(models.CharField())
@ -241,6 +246,7 @@ class TestChecks(TestCase):
self.assertEqual(errors[0].id, 'postgres.E002')
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
class TestMigrations(TransactionTestCase):
available_apps = ['postgres_tests']
@ -288,7 +294,7 @@ class TestMigrations(TransactionTestCase):
self.assertNotIn(table_name, connection.introspection.table_names(cursor))
class TestSerialization(TestCase):
class TestSerialization(PostgresSQLTestCase):
test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\"]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]'
def test_dumping(self):
@ -301,7 +307,7 @@ class TestSerialization(TestCase):
self.assertEqual(instance.field, [1, 2])
class TestValidation(TestCase):
class TestValidation(PostgresSQLTestCase):
def test_unbounded(self):
field = ArrayField(models.IntegerField())
@ -339,7 +345,7 @@ class TestValidation(TestCase):
self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Ensure this value is greater than or equal to 1.')
class TestSimpleFormField(TestCase):
class TestSimpleFormField(PostgresSQLTestCase):
def test_valid(self):
field = SimpleArrayField(forms.CharField())
@ -411,7 +417,7 @@ class TestSimpleFormField(TestCase):
self.assertEqual(form_field.max_length, 4)
class TestSplitFormField(TestCase):
class TestSplitFormField(PostgresSQLTestCase):
def test_valid(self):
class SplitForm(forms.Form):

View File

@ -1,15 +1,19 @@
import json
from django.contrib.postgres import forms
from django.contrib.postgres.fields import HStoreField
from django.contrib.postgres.validators import KeysValidator
from django.core import exceptions, serializers
from django.test import TestCase
from . import PostgresSQLTestCase
from .models import HStoreModel
try:
from django.contrib.postgres import forms
from django.contrib.postgres.fields import HStoreField
from django.contrib.postgres.validators import KeysValidator
except ImportError:
pass
class SimpleTests(TestCase):
class SimpleTests(PostgresSQLTestCase):
apps = ['django.contrib.postgres']
def test_save_load_success(self):
@ -33,7 +37,7 @@ class SimpleTests(TestCase):
self.assertEqual(reloaded.field, value)
class TestQuerying(TestCase):
class TestQuerying(PostgresSQLTestCase):
def setUp(self):
self.objs = [
@ -111,7 +115,7 @@ class TestQuerying(TestCase):
)
class TestSerialization(TestCase):
class TestSerialization(PostgresSQLTestCase):
test_data = '[{"fields": {"field": "{\\"a\\": \\"b\\"}"}, "model": "postgres_tests.hstoremodel", "pk": null}]'
def test_dumping(self):
@ -124,7 +128,7 @@ class TestSerialization(TestCase):
self.assertEqual(instance.field, {'a': 'b'})
class TestValidation(TestCase):
class TestValidation(PostgresSQLTestCase):
def test_not_a_string(self):
field = HStoreField()
@ -134,7 +138,7 @@ class TestValidation(TestCase):
self.assertEqual(cm.exception.message % cm.exception.params, 'The value of "a" is not a string.')
class TestFormField(TestCase):
class TestFormField(PostgresSQLTestCase):
def test_valid(self):
field = forms.HStoreField()
@ -164,7 +168,7 @@ class TestFormField(TestCase):
self.assertIsInstance(form_field, forms.HStoreField)
class TestValidator(TestCase):
class TestValidator(PostgresSQLTestCase):
def test_simple_valid(self):
validator = KeysValidator(keys=['a', 'b'])

View File

@ -2,23 +2,30 @@ import datetime
import json
import unittest
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
from django import forms
from django.contrib.postgres import fields as pg_fields, forms as pg_forms
from django.contrib.postgres.validators import (
RangeMaxValueValidator, RangeMinValueValidator,
)
from django.core import exceptions, serializers
from django.db import connection
from django.test import TestCase
from django.utils import timezone
from . import PostgresSQLTestCase
from .models import RangesModel
try:
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
from django.contrib.postgres import fields as pg_fields, forms as pg_forms
from django.contrib.postgres.validators import (
RangeMaxValueValidator, RangeMinValueValidator,
)
except ImportError:
pass
def skipUnlessPG92(test):
PG_VERSION = connection.pg_version
try:
PG_VERSION = connection.pg_version
except AttributeError:
PG_VERSION = 0
if PG_VERSION < 90200:
return unittest.skip('PostgreSQL >= 9.2 required')(test)
return test
@ -215,7 +222,7 @@ class TestSerialization(TestCase):
self.assertEqual(instance.dates, None)
class TestValidators(TestCase):
class TestValidators(PostgresSQLTestCase):
def test_max(self):
validator = RangeMaxValueValidator(5)
@ -234,7 +241,7 @@ class TestValidators(TestCase):
self.assertEqual(cm.exception.code, 'min_value')
class TestFormField(TestCase):
class TestFormField(PostgresSQLTestCase):
def test_valid_integer(self):
field = pg_forms.IntegerRangeField()
@ -493,7 +500,7 @@ class TestFormField(TestCase):
self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
class TestWidget(TestCase):
class TestWidget(PostgresSQLTestCase):
def test_range_widget(self):
f = pg_forms.ranges.DateTimeRangeField()
self.assertHTMLEqual(

View File

@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.test import TestCase, modify_settings
from django.test import modify_settings
from . import PostgresSQLTestCase
from .models import CharFieldModel, TextFieldModel
@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
class UnaccentTest(TestCase):
class UnaccentTest(PostgresSQLTestCase):
Model = CharFieldModel

View File

@ -84,8 +84,6 @@ def get_test_modules():
os.path.isfile(f) or
not os.path.exists(os.path.join(dirpath, f, '__init__.py'))):
continue
if connection.vendor != 'postgresql' and f == 'postgres_tests':
continue
modules.append((modpath, f))
return modules