Refs #18247 -- Fixed SQLite QuerySet filtering on decimal result of Least and Greatest.

This commit is contained in:
Sergey Fedoseev 2017-03-29 22:29:53 +05:00 committed by Tim Graham
parent d5977e492e
commit 068d75688f
4 changed files with 29 additions and 5 deletions

View File

@ -575,8 +575,8 @@ class Func(Expression):
data['expressions'] = data['field'] = arg_joiner.join(sql_parts)
return template % data, params
def as_sqlite(self, compiler, connection):
sql, params = self.as_sql(compiler, connection)
def as_sqlite(self, compiler, connection, **extra_context):
sql, params = self.as_sql(compiler, connection, **extra_context)
try:
if self.output_field.get_internal_type() == 'DecimalField':
sql = 'CAST(%s AS NUMERIC)' % sql

View File

@ -132,7 +132,7 @@ class Greatest(Func):
def as_sqlite(self, compiler, connection):
"""Use the MAX function on SQLite."""
return super().as_sql(compiler, connection, function='MAX')
return super().as_sqlite(compiler, connection, function='MAX')
class Least(Func):
@ -152,7 +152,7 @@ class Least(Func):
def as_sqlite(self, compiler, connection):
"""Use the MIN function on SQLite."""
return super().as_sql(compiler, connection, function='MIN')
return super().as_sqlite(compiler, connection, function='MIN')
class Length(Transform):

View File

@ -49,3 +49,8 @@ class DTModel(models.Model):
def __str__(self):
return 'DTModel({0})'.format(self.name)
class DecimalModel(models.Model):
n1 = models.DecimalField(decimal_places=2, max_digits=6)
n2 = models.DecimalField(decimal_places=2, max_digits=6)

View File

@ -1,4 +1,5 @@
from datetime import datetime, timedelta
from decimal import Decimal
from unittest import skipIf, skipUnless
from django.db import connection
@ -11,7 +12,7 @@ from django.db.models.functions import (
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.utils import timezone
from .models import Article, Author, Fan
from .models import Article, Author, DecimalModel, Fan
lorem_ipsum = """
@ -202,6 +203,15 @@ class FunctionTests(TestCase):
author.refresh_from_db()
self.assertEqual(author.alias, 'Jim')
def test_greatest_decimal_filter(self):
obj = DecimalModel.objects.create(n1=Decimal('1.1'), n2=Decimal('1.2'))
self.assertCountEqual(
DecimalModel.objects.annotate(
greatest=Greatest('n1', 'n2'),
).filter(greatest=Decimal('1.2')),
[obj],
)
def test_least(self):
now = timezone.now()
before = now - timedelta(hours=1)
@ -297,6 +307,15 @@ class FunctionTests(TestCase):
author.refresh_from_db()
self.assertEqual(author.alias, 'James Smith')
def test_least_decimal_filter(self):
obj = DecimalModel.objects.create(n1=Decimal('1.1'), n2=Decimal('1.2'))
self.assertCountEqual(
DecimalModel.objects.annotate(
least=Least('n1', 'n2'),
).filter(least=Decimal('1.1')),
[obj],
)
def test_concat(self):
Author.objects.create(name='Jayden')
Author.objects.create(name='John Smith', alias='smithj', goes_by='John')