Fixed #29500 -- Fixed SQLite function crashes on null values.

Co-authored-by: Srinivas Reddy Thatiparthy <thatiparthysreenivas@gmail.com>
Co-authored-by: Nick Pope <nick.pope@flightdataservices.com>
This commit is contained in:
Srinivas Reddy Thatiparthy 2018-07-01 02:19:20 +05:30 committed by Tim Graham
parent 76dfa834e7
commit 34d6bceec4
29 changed files with 272 additions and 38 deletions

View File

@ -569,7 +569,7 @@ END;
if internal_type == 'DateField': if internal_type == 'DateField':
lhs_sql, lhs_params = lhs lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs rhs_sql, rhs_params = rhs
return "NUMTODSINTERVAL(%s - %s, 'DAY')" % (lhs_sql, rhs_sql), lhs_params + rhs_params return "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), lhs_params + rhs_params
return super().subtract_temporals(internal_type, lhs, rhs) return super().subtract_temporals(internal_type, lhs, rhs)
def bulk_batch_size(self, fields, objs): def bulk_batch_size(self, fields, objs):

View File

@ -3,6 +3,7 @@ SQLite3 backend for the sqlite3 module in the standard library.
""" """
import datetime import datetime
import decimal import decimal
import functools
import math import math
import operator import operator
import re import re
@ -34,6 +35,19 @@ def decoder(conv_func):
return lambda s: conv_func(s.decode()) return lambda s: conv_func(s.decode())
def none_guard(func):
"""
Decorator that returns None if any of the arguments to the decorated
function are None. Many SQL functions return NULL if any of their arguments
are NULL. This decorator simplifies the implementation of this for the
custom functions registered below.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
return None if None in args else func(*args, **kwargs)
return wrapper
Database.register_converter("bool", b'1'.__eq__) Database.register_converter("bool", b'1'.__eq__)
Database.register_converter("time", decoder(parse_time)) Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime)) Database.register_converter("datetime", decoder(parse_datetime))
@ -171,30 +185,30 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn.create_function("django_time_trunc", 2, _sqlite_time_trunc) conn.create_function("django_time_trunc", 2, _sqlite_time_trunc)
conn.create_function("django_time_diff", 2, _sqlite_time_diff) conn.create_function("django_time_diff", 2, _sqlite_time_diff)
conn.create_function("django_timestamp_diff", 2, _sqlite_timestamp_diff) conn.create_function("django_timestamp_diff", 2, _sqlite_timestamp_diff)
conn.create_function("regexp", 2, _sqlite_regexp)
conn.create_function("django_format_dtdelta", 3, _sqlite_format_dtdelta) conn.create_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
conn.create_function('regexp', 2, _sqlite_regexp)
conn.create_function('ACOS', 1, none_guard(math.acos))
conn.create_function('ASIN', 1, none_guard(math.asin))
conn.create_function('ATAN', 1, none_guard(math.atan))
conn.create_function('ATAN2', 2, none_guard(math.atan2))
conn.create_function('CEILING', 1, none_guard(math.ceil))
conn.create_function('COS', 1, none_guard(math.cos))
conn.create_function('COT', 1, none_guard(lambda x: 1 / math.tan(x)))
conn.create_function('DEGREES', 1, none_guard(math.degrees))
conn.create_function('EXP', 1, none_guard(math.exp))
conn.create_function('FLOOR', 1, none_guard(math.floor))
conn.create_function('LN', 1, none_guard(math.log))
conn.create_function('LOG', 2, none_guard(lambda x, y: math.log(y, x)))
conn.create_function('LPAD', 3, _sqlite_lpad) conn.create_function('LPAD', 3, _sqlite_lpad)
conn.create_function('REPEAT', 2, operator.mul) conn.create_function('MOD', 2, none_guard(math.fmod))
conn.create_function('RPAD', 3, _sqlite_rpad)
conn.create_function('ACOS', 1, math.acos)
conn.create_function('ASIN', 1, math.asin)
conn.create_function('ATAN', 1, math.atan)
conn.create_function('ATAN2', 2, math.atan2)
conn.create_function('CEILING', 1, math.ceil)
conn.create_function('COS', 1, math.cos)
conn.create_function('COT', 1, lambda x: 1 / math.tan(x))
conn.create_function('DEGREES', 1, math.degrees)
conn.create_function('EXP', 1, math.exp)
conn.create_function('FLOOR', 1, math.floor)
conn.create_function('LN', 1, math.log)
conn.create_function('LOG', 2, lambda x, y: math.log(y, x))
conn.create_function('MOD', 2, math.fmod)
conn.create_function('PI', 0, lambda: math.pi) conn.create_function('PI', 0, lambda: math.pi)
conn.create_function('POWER', 2, operator.pow) conn.create_function('POWER', 2, none_guard(operator.pow))
conn.create_function('RADIANS', 1, math.radians) conn.create_function('RADIANS', 1, none_guard(math.radians))
conn.create_function('SIN', 1, math.sin) conn.create_function('REPEAT', 2, none_guard(operator.mul))
conn.create_function('SQRT', 1, math.sqrt) conn.create_function('RPAD', 3, _sqlite_rpad)
conn.create_function('TAN', 1, math.tan) conn.create_function('SIN', 1, none_guard(math.sin))
conn.create_function('SQRT', 1, none_guard(math.sqrt))
conn.create_function('TAN', 1, none_guard(math.tan))
conn.execute('PRAGMA foreign_keys = ON') conn.execute('PRAGMA foreign_keys = ON')
return conn return conn
@ -356,6 +370,8 @@ def _sqlite_date_trunc(lookup_type, dt):
def _sqlite_time_trunc(lookup_type, dt): def _sqlite_time_trunc(lookup_type, dt):
if dt is None:
return None
try: try:
dt = backend_utils.typecast_time(dt) dt = backend_utils.typecast_time(dt)
except (ValueError, TypeError): except (ValueError, TypeError):
@ -432,6 +448,7 @@ def _sqlite_time_extract(lookup_type, dt):
return getattr(dt, lookup_type) return getattr(dt, lookup_type)
@none_guard
def _sqlite_format_dtdelta(conn, lhs, rhs): def _sqlite_format_dtdelta(conn, lhs, rhs):
""" """
LHS and RHS can be either: LHS and RHS can be either:
@ -452,6 +469,7 @@ def _sqlite_format_dtdelta(conn, lhs, rhs):
return str(out) return str(out)
@none_guard
def _sqlite_time_diff(lhs, rhs): def _sqlite_time_diff(lhs, rhs):
left = backend_utils.typecast_time(lhs) left = backend_utils.typecast_time(lhs)
right = backend_utils.typecast_time(rhs) right = backend_utils.typecast_time(rhs)
@ -467,21 +485,25 @@ def _sqlite_time_diff(lhs, rhs):
) )
@none_guard
def _sqlite_timestamp_diff(lhs, rhs): def _sqlite_timestamp_diff(lhs, rhs):
left = backend_utils.typecast_timestamp(lhs) left = backend_utils.typecast_timestamp(lhs)
right = backend_utils.typecast_timestamp(rhs) right = backend_utils.typecast_timestamp(rhs)
return duration_microseconds(left - right) return duration_microseconds(left - right)
@none_guard
def _sqlite_regexp(re_pattern, re_string): def _sqlite_regexp(re_pattern, re_string):
return bool(re.search(re_pattern, str(re_string))) if re_string is not None else False return bool(re.search(re_pattern, str(re_string)))
@none_guard
def _sqlite_lpad(text, length, fill_text): def _sqlite_lpad(text, length, fill_text):
if len(text) >= length: if len(text) >= length:
return text[:length] return text[:length]
return (fill_text * length)[:length - len(text)] + text return (fill_text * length)[:length - len(text)] + text
@none_guard
def _sqlite_rpad(text, length, fill_text): def _sqlite_rpad(text, length, fill_text):
return (text + fill_text * length)[:length] return (text + fill_text * length)[:length]

View File

@ -218,16 +218,20 @@ class TruncBase(TimezoneMixin, Transform):
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if isinstance(self.output_field, DateTimeField): if isinstance(self.output_field, DateTimeField):
if settings.USE_TZ: if not settings.USE_TZ:
if value is None: pass
raise ValueError( elif value is not None:
"Database returned an invalid datetime value. "
"Are time zone definitions for your database installed?"
)
value = value.replace(tzinfo=None) value = value.replace(tzinfo=None)
value = timezone.make_aware(value, self.tzinfo) value = timezone.make_aware(value, self.tzinfo)
elif not connection.features.has_zoneinfo_database:
raise ValueError(
'Database returned an invalid datetime value. Are time '
'zone definitions for your database installed?'
)
elif isinstance(value, datetime): elif isinstance(value, datetime):
if isinstance(self.output_field, DateField): if value is None:
pass
elif isinstance(self.output_field, DateField):
value = value.date() value = value.date()
elif isinstance(self.output_field, TimeField): elif isinstance(self.output_field, TimeField):
value = value.time() value = value.time()

View File

@ -139,7 +139,7 @@ class LPad(BytesToCharFieldConversionMixin, Func):
function = 'LPAD' function = 'LPAD'
def __init__(self, expression, length, fill_text=Value(' '), **extra): def __init__(self, expression, length, fill_text=Value(' '), **extra):
if not hasattr(length, 'resolve_expression') and length < 0: if not hasattr(length, 'resolve_expression') and length is not None and length < 0:
raise ValueError("'length' must be greater or equal to 0.") raise ValueError("'length' must be greater or equal to 0.")
super().__init__(expression, length, fill_text, **extra) super().__init__(expression, length, fill_text, **extra)
@ -165,13 +165,14 @@ class Repeat(BytesToCharFieldConversionMixin, Func):
function = 'REPEAT' function = 'REPEAT'
def __init__(self, expression, number, **extra): def __init__(self, expression, number, **extra):
if not hasattr(number, 'resolve_expression') and number < 0: if not hasattr(number, 'resolve_expression') and number is not None and number < 0:
raise ValueError("'number' must be greater or equal to 0.") raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra) super().__init__(expression, number, **extra)
def as_oracle(self, compiler, connection, **extra_context): def as_oracle(self, compiler, connection, **extra_context):
expression, number = self.source_expressions expression, number = self.source_expressions
rpad = RPad(expression, Length(expression) * number, expression) length = None if number is None else Length(expression) * number
rpad = RPad(expression, length, expression)
return rpad.as_sql(compiler, connection, **extra_context) return rpad.as_sql(compiler, connection, **extra_context)

View File

@ -59,6 +59,22 @@ class Tests(TestCase):
creation = DatabaseWrapper(settings_dict).creation creation = DatabaseWrapper(settings_dict).creation
self.assertEqual(creation._get_test_db_name(), creation.connection.settings_dict['TEST']['NAME']) self.assertEqual(creation._get_test_db_name(), creation.connection.settings_dict['TEST']['NAME'])
def test_regexp_function(self):
tests = (
('test', r'[0-9]+', False),
('test', r'[a-z]+', True),
('test', None, None),
(None, r'[a-z]+', None),
(None, None, None),
)
for string, pattern, expected in tests:
with self.subTest((string, pattern)):
with connection.cursor() as cursor:
cursor.execute('SELECT %s REGEXP %s', [string, pattern])
value = cursor.fetchone()[0]
value = bool(value) if value in {0, 1} else value
self.assertIs(value, expected)
@unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests') @unittest.skipUnless(connection.vendor == 'sqlite', 'SQLite tests')
@isolate_apps('backends') @isolate_apps('backends')

View File

@ -66,11 +66,14 @@ class DateFunctionTests(TestCase):
def create_model(self, start_datetime, end_datetime): def create_model(self, start_datetime, end_datetime):
return DTModel.objects.create( return DTModel.objects.create(
name=start_datetime.isoformat(), name=start_datetime.isoformat() if start_datetime else 'None',
start_datetime=start_datetime, end_datetime=end_datetime, start_datetime=start_datetime,
start_date=start_datetime.date(), end_date=end_datetime.date(), end_datetime=end_datetime,
start_time=start_datetime.time(), end_time=end_datetime.time(), start_date=start_datetime.date() if start_datetime else None,
duration=(end_datetime - start_datetime), end_date=end_datetime.date() if end_datetime else None,
start_time=start_datetime.time() if start_datetime else None,
end_time=end_datetime.time() if end_datetime else None,
duration=(end_datetime - start_datetime) if start_datetime and end_datetime else None,
) )
def test_extract_year_exact_lookup(self): def test_extract_year_exact_lookup(self):
@ -215,6 +218,12 @@ class DateFunctionTests(TestCase):
self.assertEqual(DTModel.objects.filter(start_date__month=Extract('start_date', 'month')).count(), 2) self.assertEqual(DTModel.objects.filter(start_date__month=Extract('start_date', 'month')).count(), 2)
self.assertEqual(DTModel.objects.filter(start_time__hour=Extract('start_time', 'hour')).count(), 2) self.assertEqual(DTModel.objects.filter(start_time__hour=Extract('start_time', 'hour')).count(), 2)
def test_extract_none(self):
self.create_model(None, None)
for t in (Extract('start_datetime', 'year'), Extract('start_date', 'year'), Extract('start_time', 'hour')):
with self.subTest(t):
self.assertIsNone(DTModel.objects.annotate(extracted=t).first().extracted)
@skipUnlessDBFeature('has_native_duration_field') @skipUnlessDBFeature('has_native_duration_field')
def test_extract_duration(self): def test_extract_duration(self):
start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321) start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321)
@ -608,6 +617,12 @@ class DateFunctionTests(TestCase):
qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField())) qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField()))
self.assertEqual(qs.count(), 2) self.assertEqual(qs.count(), 2)
def test_trunc_none(self):
self.create_model(None, None)
for t in (Trunc('start_datetime', 'year'), Trunc('start_date', 'year'), Trunc('start_time', 'hour')):
with self.subTest(t):
self.assertIsNone(DTModel.objects.annotate(truncated=t).first().truncated)
def test_trunc_year_func(self): def test_trunc_year_func(self):
start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321) start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321)
end_datetime = truncate_to(datetime(2016, 6, 15, 14, 10, 50, 123), 'year') end_datetime = truncate_to(datetime(2016, 6, 15, 14, 10, 50, 123), 'year')
@ -761,6 +776,10 @@ class DateFunctionTests(TestCase):
with self.assertRaisesMessage(ValueError, "Cannot truncate TimeField 'start_time' to DateField"): with self.assertRaisesMessage(ValueError, "Cannot truncate TimeField 'start_time' to DateField"):
list(DTModel.objects.annotate(truncated=TruncDate('start_time', output_field=TimeField()))) list(DTModel.objects.annotate(truncated=TruncDate('start_time', output_field=TimeField())))
def test_trunc_date_none(self):
self.create_model(None, None)
self.assertIsNone(DTModel.objects.annotate(truncated=TruncDate('start_datetime')).first().truncated)
def test_trunc_time_func(self): def test_trunc_time_func(self):
start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321) start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321)
end_datetime = datetime(2016, 6, 15, 14, 10, 50, 123) end_datetime = datetime(2016, 6, 15, 14, 10, 50, 123)
@ -785,6 +804,10 @@ class DateFunctionTests(TestCase):
with self.assertRaisesMessage(ValueError, "Cannot truncate DateField 'start_date' to TimeField"): with self.assertRaisesMessage(ValueError, "Cannot truncate DateField 'start_date' to TimeField"):
list(DTModel.objects.annotate(truncated=TruncTime('start_date', output_field=DateField()))) list(DTModel.objects.annotate(truncated=TruncTime('start_date', output_field=DateField())))
def test_trunc_time_none(self):
self.create_model(None, None)
self.assertIsNone(DTModel.objects.annotate(truncated=TruncTime('start_datetime')).first().truncated)
def test_trunc_day_func(self): def test_trunc_day_func(self):
start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321) start_datetime = datetime(2015, 6, 15, 14, 30, 50, 321)
end_datetime = truncate_to(datetime(2016, 6, 15, 14, 10, 50, 123), 'day') end_datetime = truncate_to(datetime(2016, 6, 15, 14, 10, 50, 123), 'day')

View File

@ -10,6 +10,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class AbsTests(TestCase): class AbsTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_abs=Abs('normal')).first()
self.assertIsNone(obj.null_abs)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-0.8'), n2=Decimal('1.2')) DecimalModel.objects.create(n1=Decimal('-0.8'), n2=Decimal('1.2'))
obj = DecimalModel.objects.annotate(n1_abs=Abs('n1'), n2_abs=Abs('n2')).first() obj = DecimalModel.objects.annotate(n1_abs=Abs('n1'), n2_abs=Abs('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class ACosTests(TestCase): class ACosTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_acos=ACos('normal')).first()
self.assertIsNone(obj.null_acos)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-0.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-0.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_acos=ACos('n1'), n2_acos=ACos('n2')).first() obj = DecimalModel.objects.annotate(n1_acos=ACos('n1'), n2_acos=ACos('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class ASinTests(TestCase): class ASinTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_asin=ASin('normal')).first()
self.assertIsNone(obj.null_asin)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('0.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('0.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_asin=ASin('n1'), n2_asin=ASin('n2')).first() obj = DecimalModel.objects.annotate(n1_asin=ASin('n1'), n2_asin=ASin('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class ATanTests(TestCase): class ATanTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_atan=ATan('normal')).first()
self.assertIsNone(obj.null_atan)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_atan=ATan('n1'), n2_atan=ATan('n2')).first() obj = DecimalModel.objects.annotate(n1_atan=ATan('n1'), n2_atan=ATan('n2')).first()

View File

@ -9,6 +9,15 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class ATan2Tests(TestCase): class ATan2Tests(TestCase):
def test_null(self):
IntegerModel.objects.create(big=100)
obj = IntegerModel.objects.annotate(
null_atan2_sn=ATan2('small', 'normal'),
null_atan2_nb=ATan2('normal', 'big'),
).first()
self.assertIsNone(obj.null_atan2_sn)
self.assertIsNone(obj.null_atan2_nb)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-9.9'), n2=Decimal('4.6')) DecimalModel.objects.create(n1=Decimal('-9.9'), n2=Decimal('4.6'))
obj = DecimalModel.objects.annotate(n_atan2=ATan2('n1', 'n2')).first() obj = DecimalModel.objects.annotate(n_atan2=ATan2('n1', 'n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class CeilTests(TestCase): class CeilTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_ceil=Ceil('normal')).first()
self.assertIsNone(obj.null_ceil)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_ceil=Ceil('n1'), n2_ceil=Ceil('n2')).first() obj = DecimalModel.objects.annotate(n1_ceil=Ceil('n1'), n2_ceil=Ceil('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class CosTests(TestCase): class CosTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_cos=Cos('normal')).first()
self.assertIsNone(obj.null_cos)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_cos=Cos('n1'), n2_cos=Cos('n2')).first() obj = DecimalModel.objects.annotate(n1_cos=Cos('n1'), n2_cos=Cos('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class CotTests(TestCase): class CotTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_cot=Cot('normal')).first()
self.assertIsNone(obj.null_cot)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_cot=Cot('n1'), n2_cot=Cot('n2')).first() obj = DecimalModel.objects.annotate(n1_cot=Cot('n1'), n2_cot=Cot('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class DegreesTests(TestCase): class DegreesTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_degrees=Degrees('normal')).first()
self.assertIsNone(obj.null_degrees)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_degrees=Degrees('n1'), n2_degrees=Degrees('n2')).first() obj = DecimalModel.objects.annotate(n1_degrees=Degrees('n1'), n2_degrees=Degrees('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class ExpTests(TestCase): class ExpTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_exp=Exp('normal')).first()
self.assertIsNone(obj.null_exp)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_exp=Exp('n1'), n2_exp=Exp('n2')).first() obj = DecimalModel.objects.annotate(n1_exp=Exp('n1'), n2_exp=Exp('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class FloorTests(TestCase): class FloorTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_floor=Floor('normal')).first()
self.assertIsNone(obj.null_floor)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_floor=Floor('n1'), n2_floor=Floor('n2')).first() obj = DecimalModel.objects.annotate(n1_floor=Floor('n1'), n2_floor=Floor('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class LnTests(TestCase): class LnTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_ln=Ln('normal')).first()
self.assertIsNone(obj.null_ln)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_ln=Ln('n1'), n2_ln=Ln('n2')).first() obj = DecimalModel.objects.annotate(n1_ln=Ln('n1'), n2_ln=Ln('n2')).first()

View File

@ -9,6 +9,15 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class LogTests(TestCase): class LogTests(TestCase):
def test_null(self):
IntegerModel.objects.create(big=100)
obj = IntegerModel.objects.annotate(
null_log_small=Log('small', 'normal'),
null_log_normal=Log('normal', 'big'),
).first()
self.assertIsNone(obj.null_log_small)
self.assertIsNone(obj.null_log_normal)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('12.9'), n2=Decimal('3.6')) DecimalModel.objects.create(n1=Decimal('12.9'), n2=Decimal('3.6'))
obj = DecimalModel.objects.annotate(n_log=Log('n1', 'n2')).first() obj = DecimalModel.objects.annotate(n_log=Log('n1', 'n2')).first()

View File

@ -9,6 +9,15 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class ModTests(TestCase): class ModTests(TestCase):
def test_null(self):
IntegerModel.objects.create(big=100)
obj = IntegerModel.objects.annotate(
null_mod_small=Mod('small', 'normal'),
null_mod_normal=Mod('normal', 'big'),
).first()
self.assertIsNone(obj.null_mod_small)
self.assertIsNone(obj.null_mod_normal)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-9.9'), n2=Decimal('4.6')) DecimalModel.objects.create(n1=Decimal('-9.9'), n2=Decimal('4.6'))
obj = DecimalModel.objects.annotate(n_mod=Mod('n1', 'n2')).first() obj = DecimalModel.objects.annotate(n_mod=Mod('n1', 'n2')).first()

View File

@ -8,6 +8,15 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class PowerTests(TestCase): class PowerTests(TestCase):
def test_null(self):
IntegerModel.objects.create(big=100)
obj = IntegerModel.objects.annotate(
null_power_small=Power('small', 'normal'),
null_power_normal=Power('normal', 'big'),
).first()
self.assertIsNone(obj.null_power_small)
self.assertIsNone(obj.null_power_normal)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('1.0'), n2=Decimal('-0.6')) DecimalModel.objects.create(n1=Decimal('1.0'), n2=Decimal('-0.6'))
obj = DecimalModel.objects.annotate(n_power=Power('n1', 'n2')).first() obj = DecimalModel.objects.annotate(n_power=Power('n1', 'n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class RadiansTests(TestCase): class RadiansTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_radians=Radians('normal')).first()
self.assertIsNone(obj.null_radians)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_radians=Radians('n1'), n2_radians=Radians('n2')).first() obj = DecimalModel.objects.annotate(n1_radians=Radians('n1'), n2_radians=Radians('n2')).first()

View File

@ -10,6 +10,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class RoundTests(TestCase): class RoundTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_round=Round('normal')).first()
self.assertIsNone(obj.null_round)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_round=Round('n1'), n2_round=Round('n2')).first() obj = DecimalModel.objects.annotate(n1_round=Round('n1'), n2_round=Round('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class SinTests(TestCase): class SinTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_sin=Sin('normal')).first()
self.assertIsNone(obj.null_sin)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_sin=Sin('n1'), n2_sin=Sin('n2')).first() obj = DecimalModel.objects.annotate(n1_sin=Sin('n1'), n2_sin=Sin('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class SqrtTests(TestCase): class SqrtTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_sqrt=Sqrt('normal')).first()
self.assertIsNone(obj.null_sqrt)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_sqrt=Sqrt('n1'), n2_sqrt=Sqrt('n2')).first() obj = DecimalModel.objects.annotate(n1_sqrt=Sqrt('n1'), n2_sqrt=Sqrt('n2')).first()

View File

@ -11,6 +11,11 @@ from ..models import DecimalModel, FloatModel, IntegerModel
class TanTests(TestCase): class TanTests(TestCase):
def test_null(self):
IntegerModel.objects.create()
obj = IntegerModel.objects.annotate(null_tan=Tan('normal')).first()
self.assertIsNone(obj.null_tan)
def test_decimal(self): def test_decimal(self):
DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6')) DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
obj = DecimalModel.objects.annotate(n1_tan=Tan('n1'), n2_tan=Tan('n2')).first() obj = DecimalModel.objects.annotate(n1_tan=Tan('n1'), n2_tan=Tan('n2')).first()

View File

@ -1,3 +1,4 @@
from django.db import connection
from django.db.models import CharField, Value from django.db.models import CharField, Value
from django.db.models.functions import Length, LPad, RPad from django.db.models.functions import Length, LPad, RPad
from django.test import TestCase from django.test import TestCase
@ -8,6 +9,7 @@ from ..models import Author
class PadTests(TestCase): class PadTests(TestCase):
def test_pad(self): def test_pad(self):
Author.objects.create(name='John', alias='j') Author.objects.create(name='John', alias='j')
none_value = '' if connection.features.interprets_empty_strings_as_nulls else None
tests = ( tests = (
(LPad('name', 7, Value('xy')), 'xyxJohn'), (LPad('name', 7, Value('xy')), 'xyxJohn'),
(RPad('name', 7, Value('xy')), 'Johnxyx'), (RPad('name', 7, Value('xy')), 'Johnxyx'),
@ -21,6 +23,10 @@ class PadTests(TestCase):
(RPad('name', 2), 'Jo'), (RPad('name', 2), 'Jo'),
(LPad('name', 0), ''), (LPad('name', 0), ''),
(RPad('name', 0), ''), (RPad('name', 0), ''),
(LPad('name', None), none_value),
(RPad('name', None), none_value),
(LPad('goes_by', 1), none_value),
(RPad('goes_by', 1), none_value),
) )
for function, padded_name in tests: for function, padded_name in tests:
with self.subTest(function=function): with self.subTest(function=function):

View File

@ -1,3 +1,4 @@
from django.db import connection
from django.db.models import CharField, Value from django.db.models import CharField, Value
from django.db.models.functions import Length, Repeat from django.db.models.functions import Length, Repeat
from django.test import TestCase from django.test import TestCase
@ -8,11 +9,14 @@ from ..models import Author
class RepeatTests(TestCase): class RepeatTests(TestCase):
def test_basic(self): def test_basic(self):
Author.objects.create(name='John', alias='xyz') Author.objects.create(name='John', alias='xyz')
none_value = '' if connection.features.interprets_empty_strings_as_nulls else None
tests = ( tests = (
(Repeat('name', 0), ''), (Repeat('name', 0), ''),
(Repeat('name', 2), 'JohnJohn'), (Repeat('name', 2), 'JohnJohn'),
(Repeat('name', Length('alias'), output_field=CharField()), 'JohnJohnJohn'), (Repeat('name', Length('alias'), output_field=CharField()), 'JohnJohnJohn'),
(Repeat(Value('x'), 3, output_field=CharField()), 'xxx'), (Repeat(Value('x'), 3, output_field=CharField()), 'xxx'),
(Repeat('name', None), none_value),
(Repeat('goes_by', 1), none_value),
) )
for function, repeated_text in tests: for function, repeated_text in tests:
with self.subTest(function=function): with self.subTest(function=function):

View File

@ -1249,6 +1249,12 @@ class FTimeDeltaTests(TestCase):
] ]
self.assertEqual(delta_math, ['e4']) self.assertEqual(delta_math, ['e4'])
queryset = Experiment.objects.annotate(shifted=ExpressionWrapper(
F('start') + Value(None, output_field=models.DurationField()),
output_field=models.DateTimeField(),
))
self.assertIsNone(queryset.first().shifted)
@skipUnlessDBFeature('supports_temporal_subtraction') @skipUnlessDBFeature('supports_temporal_subtraction')
def test_date_subtraction(self): def test_date_subtraction(self):
queryset = Experiment.objects.annotate( queryset = Experiment.objects.annotate(
@ -1266,6 +1272,18 @@ class FTimeDeltaTests(TestCase):
less_than_5_days = {e.name for e in queryset.filter(completion_duration__lt=datetime.timedelta(days=5))} less_than_5_days = {e.name for e in queryset.filter(completion_duration__lt=datetime.timedelta(days=5))}
self.assertEqual(less_than_5_days, {'e0', 'e1', 'e2'}) self.assertEqual(less_than_5_days, {'e0', 'e1', 'e2'})
queryset = Experiment.objects.annotate(difference=ExpressionWrapper(
F('completed') - Value(None, output_field=models.DateField()),
output_field=models.DurationField(),
))
self.assertIsNone(queryset.first().difference)
queryset = Experiment.objects.annotate(shifted=ExpressionWrapper(
F('completed') - Value(None, output_field=models.DurationField()),
output_field=models.DateField(),
))
self.assertIsNone(queryset.first().shifted)
@skipUnlessDBFeature('supports_temporal_subtraction') @skipUnlessDBFeature('supports_temporal_subtraction')
def test_time_subtraction(self): def test_time_subtraction(self):
Time.objects.create(time=datetime.time(12, 30, 15, 2345)) Time.objects.create(time=datetime.time(12, 30, 15, 2345))
@ -1280,6 +1298,18 @@ class FTimeDeltaTests(TestCase):
datetime.timedelta(hours=1, minutes=15, seconds=15, microseconds=2345) datetime.timedelta(hours=1, minutes=15, seconds=15, microseconds=2345)
) )
queryset = Time.objects.annotate(difference=ExpressionWrapper(
F('time') - Value(None, output_field=models.TimeField()),
output_field=models.DurationField(),
))
self.assertIsNone(queryset.first().difference)
queryset = Time.objects.annotate(shifted=ExpressionWrapper(
F('time') - Value(None, output_field=models.DurationField()),
output_field=models.TimeField(),
))
self.assertIsNone(queryset.first().shifted)
@skipUnlessDBFeature('supports_temporal_subtraction') @skipUnlessDBFeature('supports_temporal_subtraction')
def test_datetime_subtraction(self): def test_datetime_subtraction(self):
under_estimate = [ under_estimate = [
@ -1292,6 +1322,18 @@ class FTimeDeltaTests(TestCase):
] ]
self.assertEqual(over_estimate, ['e4']) self.assertEqual(over_estimate, ['e4'])
queryset = Experiment.objects.annotate(difference=ExpressionWrapper(
F('start') - Value(None, output_field=models.DateTimeField()),
output_field=models.DurationField(),
))
self.assertIsNone(queryset.first().difference)
queryset = Experiment.objects.annotate(shifted=ExpressionWrapper(
F('start') - Value(None, output_field=models.DurationField()),
output_field=models.DateTimeField(),
))
self.assertIsNone(queryset.first().shifted)
@skipUnlessDBFeature('supports_temporal_subtraction') @skipUnlessDBFeature('supports_temporal_subtraction')
def test_datetime_subtraction_microseconds(self): def test_datetime_subtraction_microseconds(self):
delta = datetime.timedelta(microseconds=8999999999999999) delta = datetime.timedelta(microseconds=8999999999999999)