2017-09-11 23:56:39 +08:00
|
|
|
import copy
|
2015-10-22 20:42:18 +08:00
|
|
|
import unittest
|
|
|
|
from functools import wraps
|
2017-09-11 23:56:39 +08:00
|
|
|
from unittest import mock
|
2013-09-24 18:22:50 +08:00
|
|
|
|
2008-08-06 02:13:06 +08:00
|
|
|
from django.conf import settings
|
2015-10-22 20:42:18 +08:00
|
|
|
from django.db import DEFAULT_DB_ALIAS, connection
|
2019-08-20 15:54:41 +08:00
|
|
|
from django.db.models import Func
|
2015-10-22 20:42:18 +08:00
|
|
|
|
|
|
|
|
|
|
|
def skipUnlessGISLookup(*gis_lookups):
    """
    Return a test decorator that skips the test when the database connection
    lacks any of the named GIS lookups.

    Each name in ``gis_lookups`` is checked against
    ``connection.ops.gis_operators`` when the decorated test is called, not
    when it is decorated.
    """

    def decorator(test_func):
        @wraps(test_func)
        def skip_wrapper(*args, **kwargs):
            for lookup in gis_lookups:
                # The first unsupported lookup is enough to skip; the message
                # still lists every lookup that was requested.
                if lookup not in connection.ops.gis_operators:
                    requested = ", ".join(gis_lookups)
                    raise unittest.SkipTest(
                        "Database doesn't support all the lookups: %s" % requested
                    )
            return test_func(*args, **kwargs)

        return skip_wrapper

    return decorator
|
2008-08-06 02:13:06 +08:00
|
|
|
|
|
|
|
|
2009-12-22 23:18:51 +08:00
|
|
|
# Text after the last "." of the default database's ENGINE setting (the whole
# string when it contains no dot).
_default_db = settings.DATABASES[DEFAULT_DB_ALIAS]["ENGINE"].rpartition(".")[2]
# MySQL spatial indices can't handle NULL geometries.
gisfield_may_be_null = not (_default_db == "mysql")
|
2014-08-22 00:47:57 +08:00
|
|
|
|
2017-09-11 23:56:39 +08:00
|
|
|
|
|
|
|
class FuncTestMixin:
    """
    Assert that Func expressions aren't mutated during their as_sql().

    For the duration of each test, ``Func.__getattribute__`` is patched so
    that looking up the vendor-specific ``as_<vendor>`` method returns a
    wrapper. The wrapper snapshots the expression's ``__dict__`` before
    compilation and compares it with the ``__dict__`` afterwards.
    """

    def setUp(self):
        """Install the ``Func.__getattribute__`` patch, then run the parent setUp()."""

        def as_sql_wrapper(original_as_sql):
            def inner(*args, **kwargs):
                func = original_as_sql.__self__
                # Resolve output_field before as_sql() so touching it in
                # as_sql() won't change __dict__.
                func.output_field
                dict_before = copy.deepcopy(func.__dict__)
                result = original_as_sql(*args, **kwargs)
                msg = (
                    "%s Func was mutated during compilation." % func.__class__.__name__
                )
                # ``self`` here is the running test case (closure over
                # setUp's argument), not the Func instance.
                self.assertEqual(func.__dict__, dict_before, msg)
                return result

            return inner

        def patched_getattribute(expr, name):
            # Only the vendor-specific compilation hook is intercepted; every
            # other attribute goes through the original lookup untouched.
            if name != vendor_impl:
                return original_getattribute(expr, name)
            try:
                as_sql = original_getattribute(expr, vendor_impl)
            except AttributeError:
                # No vendor-specific implementation: wrap the generic as_sql().
                as_sql = original_getattribute(expr, "as_sql")
            return as_sql_wrapper(as_sql)

        vendor_impl = "as_" + connection.vendor
        original_getattribute = Func.__getattribute__
        self.func_patcher = mock.patch.object(
            Func, "__getattribute__", patched_getattribute
        )
        self.func_patcher.start()
        try:
            super().setUp()
        except BaseException:
            # unittest doesn't call tearDown() when setUp() raises (including
            # SkipTest), so undo the patch here to avoid leaking it into
            # later tests.
            self.func_patcher.stop()
            raise

    def tearDown(self):
        """Run the parent tearDown(), then always remove the patch."""
        try:
            super().tearDown()
        finally:
            # Stop the patcher even if the parent tearDown() raised.
            self.func_patcher.stop()
|