Fixed #27498 -- Fixed filtering on annotated DecimalField on SQLite.

This commit is contained in:
Peter Inglesby 2016-12-10 18:05:34 +00:00 committed by Tim Graham
parent 96181080ba
commit a4cac17200
3 changed files with 86 additions and 1 deletions

View File

@ -2,10 +2,13 @@ import itertools
import math import math
import warnings import warnings
from copy import copy from copy import copy
from decimal import Decimal
from django.core.exceptions import EmptyResultSet from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Func, Value from django.db.models.expressions import Func, Value
from django.db.models.fields import DateTimeField, Field, IntegerField from django.db.models.fields import (
DateTimeField, DecimalField, Field, IntegerField,
)
from django.db.models.query_utils import RegisterLookupMixin from django.db.models.query_utils import RegisterLookupMixin
from django.utils.deprecation import RemovedInDjango20Warning from django.utils.deprecation import RemovedInDjango20Warning
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -306,6 +309,40 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
IntegerField.register_lookup(IntegerLessThan) IntegerField.register_lookup(IntegerLessThan)
class DecimalComparisonLookup(object):
def as_sqlite(self, compiler, connection):
lhs_sql, params = self.process_lhs(compiler, connection)
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
# For comparisons whose lhs is a DecimalField, cast rhs AS NUMERIC
# because the rhs will have been converted to a string by the
# rev_typecast_decimal() adapter.
if isinstance(self.rhs, Decimal):
rhs_sql = 'CAST(%s AS NUMERIC)' % rhs_sql
rhs_sql = self.get_rhs_op(connection, rhs_sql)
return '%s %s' % (lhs_sql, rhs_sql), params
@DecimalField.register_lookup
class DecimalGreaterThan(DecimalComparisonLookup, GreaterThan):
pass
@DecimalField.register_lookup
class DecimalGreaterThanOrEqual(DecimalComparisonLookup, GreaterThanOrEqual):
pass
@DecimalField.register_lookup
class DecimalLessThan(DecimalComparisonLookup, LessThan):
pass
@DecimalField.register_lookup
class DecimalLessThanOrEqual(DecimalComparisonLookup, LessThanOrEqual):
pass
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = 'in' lookup_name = 'in'

View File

@ -86,3 +86,13 @@ class MyISAMArticle(models.Model):
class Meta: class Meta:
db_table = 'myisam_article' db_table = 'myisam_article'
managed = False managed = False
class Product(models.Model):
name = models.CharField(max_length=80)
qty_target = models.DecimalField(max_digits=6, decimal_places=2)
class Stock(models.Model):
product = models.ForeignKey(Product, models.CASCADE)
qty_available = models.DecimalField(max_digits=6, decimal_places=2)

View File

@ -0,0 +1,38 @@
from django.db.models.aggregates import Sum
from django.db.models.expressions import F
from django.test import TestCase
from .models import Product, Stock
class DecimalFieldLookupTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.p1 = Product.objects.create(name='Product1', qty_target=10)
Stock.objects.create(product=cls.p1, qty_available=5)
Stock.objects.create(product=cls.p1, qty_available=6)
cls.p2 = Product.objects.create(name='Product2', qty_target=10)
Stock.objects.create(product=cls.p2, qty_available=5)
Stock.objects.create(product=cls.p2, qty_available=5)
cls.p3 = Product.objects.create(name='Product3', qty_target=10)
Stock.objects.create(product=cls.p3, qty_available=5)
Stock.objects.create(product=cls.p3, qty_available=4)
cls.queryset = Product.objects.annotate(
qty_available_sum=Sum('stock__qty_available'),
).annotate(qty_needed=F('qty_target') - F('qty_available_sum'))
def test_gt(self):
qs = self.queryset.filter(qty_needed__gt=0)
self.assertCountEqual(qs, [self.p3])
def test_gte(self):
qs = self.queryset.filter(qty_needed__gte=0)
self.assertCountEqual(qs, [self.p2, self.p3])
def test_lt(self):
qs = self.queryset.filter(qty_needed__lt=0)
self.assertCountEqual(qs, [self.p1])
def test_lte(self):
qs = self.queryset.filter(qty_needed__lte=0)
self.assertCountEqual(qs, [self.p1, self.p2])