django/tests/custom_lookups/tests.py

256 lines
9.7 KiB
Python
Raw Normal View History

from copy import copy
from datetime import date
import unittest
from django.test import TestCase
from .models import Author
from django.db import models
from django.db import connection
from django.db.backends.utils import add_implementation
class Div3Lookup(models.lookups.Lookup):
lookup_name = 'div3'
def as_sql(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
return '%s %%%% 3 = %s' % (lhs, rhs), params
class Div3Extract(models.lookups.Extract):
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return '%s %%%% 3' % (lhs,), lhs_params
class Div3LookupWithExtract(Div3Lookup):
lookup_name = 'div3'
extract_class = Div3Extract
class YearLte(models.lookups.LessThanOrEqual):
"""
The purpose of this lookup is to efficiently compare the year of the field.
"""
def as_sql(self, qn, connection):
# Skip the YearExtract above us (no possibility for efficient
# lookup otherwise).
real_lhs = self.lhs.lhs
lhs_sql, params = self.process_lhs(qn, connection, real_lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
# Build SQL where the integer year is concatenated with last month
# and day, then convert that to date. (We try to have SQL like:
# WHERE somecol <= '2013-12-31')
# but also make it work if the rhs_sql is field reference.
return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
class YearExtract(models.lookups.Extract):
def as_sql(self, qn, connection):
lhs_sql, params = qn.compile(self.lhs)
return connection.ops.date_extract_sql('year', lhs_sql), params
@property
def output_type(self):
return models.IntegerField()
def get_lookup(self, lookup):
if lookup == 'lte':
return YearLte
else:
return super(YearExtract, self).get_lookup(lookup)
class YearWithExtract(models.lookups.Year):
extract_class = YearExtract
class InMonth(models.lookups.Lookup):
"""
InMonth matches if the column's month is contained in the value's month.
"""
lookup_name = 'inmonth'
def as_sql(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
# We need to be careful so that we get the params in right
# places.
full_params = params[:]
full_params.extend(rhs_params)
full_params.extend(params)
full_params.extend(rhs_params)
return ("%s >= date_trunc('month', %s) and "
"%s < date_trunc('month', %s) + interval '1 months'" %
(lhs, rhs, lhs, rhs), full_params)
class LookupTests(TestCase):
def test_basic_lookup(self):
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
models.IntegerField.register_lookup(Div3Lookup)
try:
self.assertQuerysetEqual(
Author.objects.filter(age__div3=0),
[a3], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=1).order_by('age'),
[a1, a4], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=2),
[a2], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(age__div3=3),
[], lambda x: x
)
finally:
models.IntegerField._unregister_lookup(Div3Lookup)
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
def test_birthdate_month(self):
a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
models.DateField.register_lookup(InMonth)
try:
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)),
[a3], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)),
[a2], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)),
[a1], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)),
[a4], lambda x: x
)
self.assertQuerysetEqual(
Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)),
[], lambda x: x
)
finally:
models.DateField._unregister_lookup(InMonth)
def test_custom_compiles(self):
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
class AnotherEqual(models.lookups.Exact):
lookup_name = 'anotherequal'
models.Field.register_lookup(AnotherEqual)
try:
@add_implementation(AnotherEqual, connection.vendor)
def custom_eq_sql(node, compiler):
return '1 = 1', []
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='asdf').order_by('name'),
[a1, a2, a3, a4], lambda x: x)
@add_implementation(AnotherEqual, connection.vendor)
def another_custom_eq_sql(node, compiler):
# If you need to override one method, it seems this is the best
# option.
node = copy(node)
class OverriddenAnotherEqual(AnotherEqual):
def get_rhs_op(self, connection, rhs):
return ' <> %s'
node.__class__ = OverriddenAnotherEqual
return node.as_sql(compiler, compiler.connection)
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
self.assertQuerysetEqual(
Author.objects.filter(name__anotherequal='a1').order_by('name'),
[a2, a3, a4], lambda x: x
)
finally:
models.Field._unregister_lookup(AnotherEqual)
def test_div3_extract(self):
models.IntegerField.register_lookup(Div3LookupWithExtract)
try:
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(age__div3__lte=3),
[a1, a2, a3, a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Div3LookupWithExtract)
class YearLteTests(TestCase):
def setUp(self):
models.DateField.register_lookup(YearWithExtract)
self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
def tearDown(self):
models.DateField._unregister_lookup(YearWithExtract)
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
def test_year_lte(self):
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lte=2012),
[self.a1, self.a2, self.a3, self.a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lte=2011),
[self.a1], lambda x: x)
# The non-optimized version works, too.
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lt=2012),
[self.a1], lambda x: x)
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
def test_year_lte_fexpr(self):
self.a2.age = 2011
self.a2.save()
self.a3.age = 2012
self.a3.save()
self.a4.age = 2013
self.a4.save()
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lte=models.F('age')),
[self.a3, self.a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(birthdate__year__lt=models.F('age')),
[self.a4], lambda x: x)
def test_year_lte_sql(self):
# This test will just check the generated SQL for __lte. This
# doesn't require running on PostgreSQL and spots the most likely
# error - not running YearLte SQL at all.
baseqs = Author.objects.order_by('name')
self.assertIn(
'<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
self.assertIn(
'-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))