From 36e90d1f45a13f53ce25fdc2d9c04433b835c9af Mon Sep 17 00:00:00 2001 From: Claude Paroz Date: Sat, 4 Apr 2015 18:10:26 +0200 Subject: [PATCH] Stopped special-casing postgres-specific tests Refs #23879. --- django/contrib/postgres/operations.py | 2 + tests/postgres_tests/__init__.py | 9 ++++ tests/postgres_tests/fields.py | 31 +++++++++++ .../migrations/0001_setup_extensions.py | 12 +++-- .../migrations/0002_create_test_models.py | 54 +++++++++++-------- tests/postgres_tests/models.py | 27 ++++++---- tests/postgres_tests/test_aggregates.py | 6 +-- tests/postgres_tests/test_array.py | 26 +++++---- tests/postgres_tests/test_hstore.py | 24 +++++---- tests/postgres_tests/test_ranges.py | 27 ++++++---- tests/postgres_tests/test_unaccent.py | 5 +- tests/runtests.py | 2 - 12 files changed, 154 insertions(+), 71 deletions(-) create mode 100644 tests/postgres_tests/fields.py diff --git a/django/contrib/postgres/operations.py b/django/contrib/postgres/operations.py index 5b3bd2a3cc..3bb81dc3f9 100644 --- a/django/contrib/postgres/operations.py +++ b/django/contrib/postgres/operations.py @@ -12,6 +12,8 @@ class CreateExtension(Operation): pass def database_forwards(self, app_label, schema_editor, from_state, to_state): + if schema_editor.connection.vendor != 'postgresql': + return schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % self.name) def database_backwards(self, app_label, schema_editor, from_state, to_state): diff --git a/tests/postgres_tests/__init__.py b/tests/postgres_tests/__init__.py index e69de29bb2..9076bef850 100644 --- a/tests/postgres_tests/__init__.py +++ b/tests/postgres_tests/__init__.py @@ -0,0 +1,9 @@ +import unittest + +from django.db import connection +from django.test import TestCase + + +@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") +class PostgresSQLTestCase(TestCase): + pass diff --git a/tests/postgres_tests/fields.py b/tests/postgres_tests/fields.py new file mode 100644 index 0000000000..e21ca1bfcb --- /dev/null +++ b/tests/postgres_tests/fields.py @@ -0,0 +1,31 @@ +""" +Indirection layer for PostgreSQL-specific fields, so the tests don't fail when +run with a backend other than PostgreSQL. +""" +from django.db import models + +try: + from django.contrib.postgres.fields import ( + ArrayField, BigIntegerRangeField, DateRangeField, DateTimeRangeField, + FloatRangeField, HStoreField, IntegerRangeField, + ) +except ImportError: + class DummyArrayField(models.Field): + def __init__(self, base_field, size=None, **kwargs): + super(DummyArrayField, self).__init__(**kwargs) + + def deconstruct(self): + name, path, args, kwargs = super(DummyArrayField, self).deconstruct() + kwargs.update({ + 'base_field': '', + 'size': 1, + }) + return name, path, args, kwargs + + ArrayField = DummyArrayField + BigIntegerRangeField = models.Field + DateRangeField = models.Field + DateTimeRangeField = models.Field + FloatRangeField = models.Field + HStoreField = models.Field + IntegerRangeField = models.Field diff --git a/tests/postgres_tests/migrations/0001_setup_extensions.py b/tests/postgres_tests/migrations/0001_setup_extensions.py index 0915b74343..ad5d1c716a 100644 --- a/tests/postgres_tests/migrations/0001_setup_extensions.py +++ b/tests/postgres_tests/migrations/0001_setup_extensions.py @@ -1,11 +1,17 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -from django.contrib.postgres.operations import ( - HStoreExtension, UnaccentExtension, -) from django.db import migrations +try: + from django.contrib.postgres.operations import ( + HStoreExtension, UnaccentExtension, + ) +except ImportError: + from django.test import mock + HStoreExtension = mock.Mock() + UnaccentExtension = mock.Mock() + class Migration(migrations.Migration): diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index 7853f40a14..106818a7c6 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -import django.contrib.postgres.fields -import django.contrib.postgres.fields.hstore from django.db import migrations, models +from ..fields import * # NOQA + class Migration(migrations.Migration): @@ -17,9 +17,10 @@ class Migration(migrations.Migration): name='CharArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', django.contrib.postgres.fields.ArrayField(models.CharField(max_length=10), size=None)), + ('field', ArrayField(models.CharField(max_length=10), size=None)), ], options={ + 'required_db_vendor': 'postgresql', }, bases=(models.Model,), ), @@ -27,11 +28,12 @@ class Migration(migrations.Migration): name='DateTimeArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('datetimes', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)), - ('dates', django.contrib.postgres.fields.ArrayField(models.DateField(), size=None)), - ('times', django.contrib.postgres.fields.ArrayField(models.TimeField(), size=None)), + ('datetimes', ArrayField(models.DateTimeField(), size=None)), + ('dates', ArrayField(models.DateField(), size=None)), + ('times', ArrayField(models.TimeField(), size=None)), ], options={ + 'required_db_vendor': 'postgresql', }, bases=(models.Model,), ), @@ -39,9 +41,10 @@ class Migration(migrations.Migration): name='HStoreModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', django.contrib.postgres.fields.hstore.HStoreField(blank=True, null=True)), + ('field', HStoreField(blank=True, null=True)), ], options={ + 'required_db_vendor': 'postgresql', }, bases=(models.Model,), ), @@ -49,11 +52,12 @@ class Migration(migrations.Migration): name='OtherTypesArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('ips', django.contrib.postgres.fields.ArrayField(models.GenericIPAddressField(), size=None)), - ('uuids', django.contrib.postgres.fields.ArrayField(models.UUIDField(), size=None)), - ('decimals', django.contrib.postgres.fields.ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), + ('ips', ArrayField(models.GenericIPAddressField(), size=None)), + ('uuids', ArrayField(models.UUIDField(), size=None)), + ('decimals', ArrayField(models.DecimalField(max_digits=5, decimal_places=2), size=None)), ], options={ + 'required_db_vendor': 'postgresql', }, bases=(models.Model,), ), @@ -61,9 +65,10 @@ class Migration(migrations.Migration): name='IntegerArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None)), + ('field', ArrayField(models.IntegerField(), size=None)), ], options={ + 'required_db_vendor': 'postgresql', }, bases=(models.Model,), ), @@ -71,9 +76,10 @@ class Migration(migrations.Migration): name='NestedIntegerArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', django.contrib.postgres.fields.ArrayField(django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None), size=None)), + ('field', ArrayField(ArrayField(models.IntegerField(), size=None), size=None)), ], options={ + 'required_db_vendor': 'postgresql', }, bases=(models.Model,), ), @@ -81,9 +87,10 @@ class Migration(migrations.Migration): name='NullableIntegerArrayModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None, null=True, blank=True)), + ('field', ArrayField(models.IntegerField(), size=None, null=True, blank=True)), ], options={ + 'required_db_vendor': 'postgresql', }, bases=(models.Model,), ), @@ -130,20 +137,25 @@ class Migration(migrations.Migration): name='RangesModel', fields=[ ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), - ('ints', django.contrib.postgres.fields.IntegerRangeField(null=True, blank=True)), - ('bigints', django.contrib.postgres.fields.BigIntegerRangeField(null=True, blank=True)), - ('floats', django.contrib.postgres.fields.FloatRangeField(null=True, blank=True)), - ('timestamps', django.contrib.postgres.fields.DateTimeRangeField(null=True, blank=True)), - ('dates', django.contrib.postgres.fields.DateRangeField(null=True, blank=True)), + ('ints', IntegerRangeField(null=True, blank=True)), + ('bigints', BigIntegerRangeField(null=True, blank=True)), + ('floats', FloatRangeField(null=True, blank=True)), + ('timestamps', DateTimeRangeField(null=True, blank=True)), + ('dates', DateRangeField(null=True, blank=True)), ], options={ + 'required_db_vendor': 'postgresql', }, bases=(models.Model,), ), ] def apply(self, project_state, schema_editor, collect_sql=False): - PG_VERSION = schema_editor.connection.pg_version - if PG_VERSION >= 90200: - self.operations = self.operations + self.pg_92_operations + try: + PG_VERSION = schema_editor.connection.pg_version + except AttributeError: + pass # We are probably not on PostgreSQL + else: + if PG_VERSION >= 90200: + self.operations = self.operations + self.pg_92_operations return super(Migration, self).apply(project_state, schema_editor, collect_sql) diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index b53eda4cca..aafd529443 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -1,39 +1,46 @@ -from django.contrib.postgres.fields import ( +from django.db import connection, models + +from .fields import ( ArrayField, BigIntegerRangeField, DateRangeField, DateTimeRangeField, FloatRangeField, HStoreField, IntegerRangeField, ) -from django.db import connection, models -class IntegerArrayModel(models.Model): +class PostgreSQLModel(models.Model): + class Meta: + abstract = True + required_db_vendor = 'postgresql' + + +class IntegerArrayModel(PostgreSQLModel): field = ArrayField(models.IntegerField()) -class NullableIntegerArrayModel(models.Model): +class NullableIntegerArrayModel(PostgreSQLModel): field = ArrayField(models.IntegerField(), blank=True, null=True) -class CharArrayModel(models.Model): +class CharArrayModel(PostgreSQLModel): field = ArrayField(models.CharField(max_length=10)) -class DateTimeArrayModel(models.Model): +class DateTimeArrayModel(PostgreSQLModel): datetimes = ArrayField(models.DateTimeField()) dates = ArrayField(models.DateField()) times = ArrayField(models.TimeField()) -class NestedIntegerArrayModel(models.Model): +class NestedIntegerArrayModel(PostgreSQLModel): field = ArrayField(ArrayField(models.IntegerField())) -class OtherTypesArrayModel(models.Model): +class OtherTypesArrayModel(PostgreSQLModel): ips = ArrayField(models.GenericIPAddressField()) uuids = ArrayField(models.UUIDField()) decimals = ArrayField(models.DecimalField(max_digits=5, decimal_places=2)) -class HStoreModel(models.Model): +class HStoreModel(PostgreSQLModel): field = HStoreField(blank=True, null=True) @@ -47,7 +54,7 @@ class TextFieldModel(models.Model): # Only create this model for databases which support it if connection.vendor == 'postgresql' and connection.pg_version >= 90200: - class RangesModel(models.Model): + class RangesModel(PostgreSQLModel): ints = IntegerRangeField(blank=True, null=True) bigints = BigIntegerRangeField(blank=True, null=True) floats = FloatRangeField(blank=True, null=True) diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py index 2857d917d5..4ce0d4237a 100644 --- a/tests/postgres_tests/test_aggregates.py +++ b/tests/postgres_tests/test_aggregates.py @@ -4,13 +4,13 @@ from django.contrib.postgres.aggregates import ( RegrSYY, StatAggregate, StringAgg, ) from django.db.models.expressions import F, Value -from django.test import TestCase from django.test.utils import Approximate +from . import PostgresSQLTestCase from .models import AggregateTestModel, StatTestModel -class TestGeneralAggregate(TestCase): +class TestGeneralAggregate(PostgresSQLTestCase): @classmethod def setUpTestData(cls): AggregateTestModel.objects.create(boolean_field=True, char_field='Foo1', integer_field=0) @@ -111,7 +111,7 @@ class TestGeneralAggregate(TestCase): self.assertEqual(values, {'stringagg': ''}) -class TestStatisticsAggregate(TestCase): +class TestStatisticsAggregate(PostgresSQLTestCase): @classmethod def setUpTestData(cls): StatTestModel.objects.create( diff --git a/tests/postgres_tests/test_array.py b/tests/postgres_tests/test_array.py index 35a3bffd58..895c910a8b 100644 --- a/tests/postgres_tests/test_array.py +++ b/tests/postgres_tests/test_array.py @@ -4,21 +4,26 @@ import unittest import uuid from django import forms -from django.contrib.postgres.fields import ArrayField -from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField from django.core import exceptions, serializers, validators from django.core.management import call_command from django.db import IntegrityError, connection, models -from django.test import TestCase, TransactionTestCase, override_settings +from django.test import TransactionTestCase, override_settings from django.utils import timezone +from . import PostgresSQLTestCase from .models import ( ArrayFieldSubclass, CharArrayModel, DateTimeArrayModel, IntegerArrayModel, NestedIntegerArrayModel, NullableIntegerArrayModel, OtherTypesArrayModel, ) +try: + from django.contrib.postgres.fields import ArrayField + from django.contrib.postgres.forms import SimpleArrayField, SplitArrayField +except ImportError: + pass -class TestSaveLoad(TestCase): + +class TestSaveLoad(PostgresSQLTestCase): def test_integer(self): instance = IntegerArrayModel(field=[1, 2, 3]) @@ -93,7 +98,7 @@ class TestSaveLoad(TestCase): self.assertEqual(instance.decimals, loaded.decimals) -class TestQuerying(TestCase): +class TestQuerying(PostgresSQLTestCase): def setUp(self): self.objs = [ @@ -224,7 +229,7 @@ class TestQuerying(TestCase): ) -class TestChecks(TestCase): +class TestChecks(PostgresSQLTestCase): def test_field_checks(self): field = ArrayField(models.CharField()) @@ -241,6 +246,7 @@ class TestChecks(TestCase): self.assertEqual(errors[0].id, 'postgres.E002') +@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests") class TestMigrations(TransactionTestCase): available_apps = ['postgres_tests'] @@ -288,7 +294,7 @@ class TestMigrations(TransactionTestCase): self.assertNotIn(table_name, connection.introspection.table_names(cursor)) -class TestSerialization(TestCase): +class TestSerialization(PostgresSQLTestCase): test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\"]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]' def test_dumping(self): @@ -301,7 +307,7 @@ class TestSerialization(TestCase): self.assertEqual(instance.field, [1, 2]) -class TestValidation(TestCase): +class TestValidation(PostgresSQLTestCase): def test_unbounded(self): field = ArrayField(models.IntegerField()) @@ -339,7 +345,7 @@ class TestValidation(TestCase): self.assertEqual(cm.exception.messages[0], 'Item 0 in the array did not validate: Ensure this value is greater than or equal to 1.') -class TestSimpleFormField(TestCase): +class TestSimpleFormField(PostgresSQLTestCase): def test_valid(self): field = SimpleArrayField(forms.CharField()) @@ -411,7 +417,7 @@ class TestSimpleFormField(TestCase): self.assertEqual(form_field.max_length, 4) -class TestSplitFormField(TestCase): +class TestSplitFormField(PostgresSQLTestCase): def test_valid(self): class SplitForm(forms.Form): diff --git a/tests/postgres_tests/test_hstore.py b/tests/postgres_tests/test_hstore.py index 06544163f8..ed88e2fab5 100644 --- a/tests/postgres_tests/test_hstore.py +++ b/tests/postgres_tests/test_hstore.py @@ -1,15 +1,19 @@ import json -from django.contrib.postgres import forms -from django.contrib.postgres.fields import HStoreField -from django.contrib.postgres.validators import KeysValidator from django.core import exceptions, serializers -from django.test import TestCase +from . import PostgresSQLTestCase from .models import HStoreModel +try: + from django.contrib.postgres import forms + from django.contrib.postgres.fields import HStoreField + from django.contrib.postgres.validators import KeysValidator +except ImportError: + pass -class SimpleTests(TestCase): + +class SimpleTests(PostgresSQLTestCase): apps = ['django.contrib.postgres'] def test_save_load_success(self): @@ -33,7 +37,7 @@ class SimpleTests(TestCase): self.assertEqual(reloaded.field, value) -class TestQuerying(TestCase): +class TestQuerying(PostgresSQLTestCase): def setUp(self): self.objs = [ @@ -111,7 +115,7 @@ class TestQuerying(TestCase): ) -class TestSerialization(TestCase): +class TestSerialization(PostgresSQLTestCase): test_data = '[{"fields": {"field": "{\\"a\\": \\"b\\"}"}, "model": "postgres_tests.hstoremodel", "pk": null}]' def test_dumping(self): @@ -124,7 +128,7 @@ class TestSerialization(TestCase): self.assertEqual(instance.field, {'a': 'b'}) -class TestValidation(TestCase): +class TestValidation(PostgresSQLTestCase): def test_not_a_string(self): field = HStoreField() @@ -134,7 +138,7 @@ class TestValidation(TestCase): self.assertEqual(cm.exception.message % cm.exception.params, 'The value of "a" is not a string.') -class TestFormField(TestCase): +class TestFormField(PostgresSQLTestCase): def test_valid(self): field = forms.HStoreField() @@ -164,7 +168,7 @@ class TestFormField(TestCase): self.assertIsInstance(form_field, forms.HStoreField) -class TestValidator(TestCase): +class TestValidator(PostgresSQLTestCase): def test_simple_valid(self): validator = KeysValidator(keys=['a', 'b']) diff --git a/tests/postgres_tests/test_ranges.py b/tests/postgres_tests/test_ranges.py index d11f09d1a3..35b79f9dc4 100644 --- a/tests/postgres_tests/test_ranges.py +++ b/tests/postgres_tests/test_ranges.py @@ -2,23 +2,30 @@ import datetime import json import unittest -from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange - from django import forms -from django.contrib.postgres import fields as pg_fields, forms as pg_forms -from django.contrib.postgres.validators import ( - RangeMaxValueValidator, RangeMinValueValidator, -) from django.core import exceptions, serializers from django.db import connection from django.test import TestCase from django.utils import timezone +from . import PostgresSQLTestCase from .models import RangesModel +try: + from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange + from django.contrib.postgres import fields as pg_fields, forms as pg_forms + from django.contrib.postgres.validators import ( + RangeMaxValueValidator, RangeMinValueValidator, + ) +except ImportError: + pass + def skipUnlessPG92(test): - PG_VERSION = connection.pg_version + try: + PG_VERSION = connection.pg_version + except AttributeError: + PG_VERSION = 0 if PG_VERSION < 90200: return unittest.skip('PostgreSQL >= 9.2 required')(test) return test @@ -215,7 +222,7 @@ class TestSerialization(TestCase): self.assertEqual(instance.dates, None) -class TestValidators(TestCase): +class TestValidators(PostgresSQLTestCase): def test_max(self): validator = RangeMaxValueValidator(5) @@ -234,7 +241,7 @@ class TestValidators(TestCase): self.assertEqual(cm.exception.code, 'min_value') -class TestFormField(TestCase): +class TestFormField(PostgresSQLTestCase): def test_valid_integer(self): field = pg_forms.IntegerRangeField() @@ -493,7 +500,7 @@ class TestFormField(TestCase): self.assertIsInstance(form_field, pg_forms.DateTimeRangeField) -class TestWidget(TestCase): +class TestWidget(PostgresSQLTestCase): def test_range_widget(self): f = pg_forms.ranges.DateTimeRangeField() self.assertHTMLEqual( diff --git a/tests/postgres_tests/test_unaccent.py b/tests/postgres_tests/test_unaccent.py index ba65155301..af1618d183 100644 --- a/tests/postgres_tests/test_unaccent.py +++ b/tests/postgres_tests/test_unaccent.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -from django.test import TestCase, modify_settings +from django.test import modify_settings +from . import PostgresSQLTestCase from .models import CharFieldModel, TextFieldModel @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'}) -class UnaccentTest(TestCase): +class UnaccentTest(PostgresSQLTestCase): Model = CharFieldModel diff --git a/tests/runtests.py b/tests/runtests.py index a1c1dfdb1b..db0e09c162 100755 --- a/tests/runtests.py +++ b/tests/runtests.py @@ -84,8 +84,6 @@ def get_test_modules(): os.path.isfile(f) or not os.path.exists(os.path.join(dirpath, f, '__init__.py'))): continue - if connection.vendor != 'postgresql' and f == 'postgres_tests': - continue modules.append((modpath, f)) return modules