mirror of https://github.com/django/django.git
82 lines
3.1 KiB
Python
82 lines
3.1 KiB
Python
import unittest
|
|
from decimal import Decimal
|
|
|
|
from django.db import connection
|
|
from django.db.backends.signals import connection_created
|
|
from django.db.migrations.writer import MigrationWriter
|
|
from django.test import TestCase
|
|
from django.test.utils import modify_settings
|
|
|
|
try:
|
|
from django.contrib.postgres.fields import (
|
|
DateRangeField,
|
|
DateTimeRangeField,
|
|
DecimalRangeField,
|
|
IntegerRangeField,
|
|
)
|
|
from django.db.backends.postgresql.psycopg_any import (
|
|
DateRange,
|
|
DateTimeRange,
|
|
DateTimeTZRange,
|
|
NumericRange,
|
|
is_psycopg3,
|
|
)
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgresConfigTests(TestCase):
    """Effects of adding django.contrib.postgres to INSTALLED_APPS."""

    def test_register_type_handlers_connection(self):
        """
        register_type_handlers is among the live receivers of
        connection_created only while django.contrib.postgres is installed.
        """
        from django.contrib.postgres.signals import register_type_handlers

        def live_receivers():
            # Looked up afresh on every call so that each assertion sees the
            # receivers registered at that moment.
            return connection_created._live_receivers(None)

        # Before the app is installed.
        self.assertNotIn(register_type_handlers, live_receivers())
        with modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}):
            self.assertIn(register_type_handlers, live_receivers())
        # After the settings override has been reverted.
        self.assertNotIn(register_type_handlers, live_receivers())

    def test_register_serializer_for_migrations(self):
        """
        Fields with range defaults are serializable by MigrationWriter only
        while django.contrib.postgres is installed.
        """
        # (default value, field class) pairs exercised by every check below.
        range_defaults = (
            (DateRange(empty=True), DateRangeField),
            (DateTimeRange(empty=True), DateRangeField),
            (DateTimeTZRange(None, None, "[]"), DateTimeRangeField),
            (NumericRange(Decimal("1.0"), Decimal("5.0"), "()"), DecimalRangeField),
            (NumericRange(1, 10), IntegerRangeField),
        )

        def check_not_serializable():
            """Serializing each field raises ValueError naming the range class."""
            for default, field_class in range_defaults:
                with self.subTest(default=default):
                    field = field_class(default=default)
                    message = "Cannot serialize: %s" % default.__class__.__name__
                    with self.assertRaisesMessage(ValueError, message):
                        MigrationWriter.serialize(field)

        check_not_serializable()
        import_name = "psycopg.types.range" if is_psycopg3 else "psycopg2.extras"
        # The expected imports are identical for every pair, so build the set
        # once instead of on each iteration.
        expected_imports = {
            "import django.contrib.postgres.fields.ranges",
            f"import {import_name}",
        }
        with self.modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"}):
            for default, field_class in range_defaults:
                with self.subTest(default=default):
                    field = field_class(default=default)
                    serialized, imports = MigrationWriter.serialize(field)
                    self.assertEqual(imports, expected_imports)
                    self.assertIn(
                        f"{field.__module__}.{field.__class__.__name__}"
                        f"(default={import_name}.{default!r})",
                        serialized,
                    )
        # Once the override is reverted, serialization must fail again.
        check_not_serializable()
|