"""Tests for registering and using custom lookups and transforms."""
import time
import unittest
from datetime import date, datetime

from django.core.exceptions import FieldError
from django.db import connection, models
from django.db.models.fields.related_lookups import RelatedGreaterThan
from django.db.models.lookups import EndsWith, StartsWith
from django.test import SimpleTestCase, TestCase, override_settings
from django.test.utils import register_lookup
from django.utils import timezone

from .models import Article, Author, MySQLUnixTimestamp


class Div3Lookup(models.Lookup):
    """Match rows whose value modulo 3 equals the right-hand side."""

    lookup_name = "div3"

    def as_sql(self, compiler, connection):
        lhs, params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        # "%%%%" becomes "%%" after the formatting below; the remaining
        # escape presumably survives the backend's parameter interpolation
        # as a single modulo operator.
        return "(%s) %%%% 3 = %s" % (lhs, rhs), params

    def as_oracle(self, compiler, connection):
        # Oracle variant: use the mod() function instead of the % operator.
        lhs, params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        return "mod(%s, 3) = %s" % (lhs, rhs), params


class Div3Transform(models.Transform):
    """Transform computing the value modulo 3."""

    lookup_name = "div3"

    def as_sql(self, compiler, connection):
        lhs, lhs_params = compiler.compile(self.lhs)
        return "(%s) %%%% 3" % lhs, lhs_params

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle variant: use the mod() function instead of the % operator.
        lhs, lhs_params = compiler.compile(self.lhs)
        return "mod(%s, 3)" % lhs, lhs_params


class Div3BilateralTransform(Div3Transform):
    """Div3Transform that is applied to the right-hand side as well."""

    bilateral = True


class Mult3BilateralTransform(models.Transform):
    """Bilateral transform multiplying the value by 3."""

    bilateral = True
    lookup_name = "mult3"

    def as_sql(self, compiler, connection):
        lhs, lhs_params = compiler.compile(self.lhs)
        return "3 * (%s)" % lhs, lhs_params


class LastDigitTransform(models.Transform):
    """
    Transform returning the second character of the value cast to CHAR(2),
    i.e. the last digit of a two-digit integer.
    """

    lookup_name = "lastdigit"

    def as_sql(self, compiler, connection):
        lhs, lhs_params = compiler.compile(self.lhs)
        return "SUBSTR(CAST(%s AS CHAR(2)), 2, 1)" % lhs, lhs_params


class UpperBilateralTransform(models.Transform):
    """Bilateral transform upper-casing both sides of the comparison."""

    bilateral = True
    lookup_name = "upper"

    def as_sql(self, compiler, connection):
        lhs, lhs_params = compiler.compile(self.lhs)
        return "UPPER(%s)" % lhs, lhs_params


class YearTransform(models.Transform):
    """Extract the year of a date column as an integer."""

    # Use a name that avoids collision with the built-in year lookup.
    lookup_name = "testyear"

    def as_sql(self, compiler, connection):
        lhs_sql, params = compiler.compile(self.lhs)
        # date_extract_sql() receives the params and returns (sql, params).
        return connection.ops.date_extract_sql("year", lhs_sql, params)

    @property
    def output_field(self):
        return models.IntegerField()


@YearTransform.register_lookup
class YearExact(models.lookups.Lookup):
    """
    Exact year match, rewritten as a date range on the untransformed column.
    """

    lookup_name = "exact"

    def as_sql(self, compiler, connection):
        # We will need to skip the extract part, and instead go
        # directly with the originating field, that is self.lhs.lhs
        lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Note that we must be careful so that we have params in the
        # same order as we have the parts in the SQL.
        params = lhs_params + rhs_params + lhs_params + rhs_params
        # We use PostgreSQL specific SQL here. Note that we must do the
        # conversions in SQL instead of in Python to support F() references.
        return (
            "%(lhs)s >= (%(rhs)s || '-01-01')::date "
            "AND %(lhs)s <= (%(rhs)s || '-12-31')::date"
            % {"lhs": lhs_sql, "rhs": rhs_sql},
            params,
        )


@YearTransform.register_lookup
class YearLte(models.lookups.LessThanOrEqual):
    """
    The purpose of this lookup is to efficiently compare the year of the field.
    """

    def as_sql(self, compiler, connection):
        # Skip the YearTransform above us (no possibility for efficient
        # lookup otherwise).
        real_lhs = self.lhs.lhs
        lhs_sql, params = self.process_lhs(compiler, connection, real_lhs)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        # Build SQL where the integer year is concatenated with last month
        # and day, then convert that to date. (We try to have SQL like:
        #     WHERE somecol <= '2013-12-31')
        # but also make it work if the rhs_sql is field reference.
        return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params


class Exactly(models.lookups.Exact):
    """
    This lookup is used to test lookup registration.
    """

    lookup_name = "exactly"

    def get_rhs_op(self, connection, rhs):
        return connection.operators["exact"] % rhs


class SQLFuncMixin:
    """
    Render as a bare SQL function call named ``self.name``.

    The output field is CustomField so that further dynamically named
    lookups/transforms can be chained after this one.
    """

    def as_sql(self, compiler, connection):
        return "%s()" % self.name, []

    @property
    def output_field(self):
        return CustomField()


class SQLFuncLookup(SQLFuncMixin, models.Lookup):
    """Lookup whose SQL function name is supplied at construction time."""

    def __init__(self, name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name


class SQLFuncTransform(SQLFuncMixin, models.Transform):
    """Transform whose SQL function name is supplied at construction time."""

    def __init__(self, name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name


class SQLFuncFactory:
    """
    Callable returned in place of a lookup/transform class.

    Calling it builds an SQLFuncLookup when ``key`` is "lookupfunc" and an
    SQLFuncTransform otherwise, passing ``name`` through.
    """

    def __init__(self, key, name):
        self.key = key
        self.name = name

    def __call__(self, *args, **kwargs):
        if self.key == "lookupfunc":
            return SQLFuncLookup(self.name, *args, **kwargs)
        return SQLFuncTransform(self.name, *args, **kwargs)


class CustomField(models.TextField):
    """
    TextField resolving "lookupfunc_<name>" lookups and
    "transformfunc_<name>" transforms dynamically via SQLFuncFactory.
    """

    def get_lookup(self, lookup_name):
        if lookup_name.startswith("lookupfunc_"):
            key, name = lookup_name.split("_", 1)
            return SQLFuncFactory(key, name)
        return super().get_lookup(lookup_name)

    def get_transform(self, lookup_name):
        if lookup_name.startswith("transformfunc_"):
            key, name = lookup_name.split("_", 1)
            return SQLFuncFactory(key, name)
        return super().get_transform(lookup_name)


class CustomModel(models.Model):
    """Model exposing a CustomField for CustomisedMethodsTests."""

    field = CustomField()


# We will register this class temporarily in the test method.


class InMonth(models.lookups.Lookup):
    """
    InMonth matches if the column's month is the same as value's month.
    """

    lookup_name = "inmonth"

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        # We need to be careful so that we get the params in right
        # places.
        params = lhs_params + rhs_params + lhs_params + rhs_params
        return (
            "%s >= date_trunc('month', %s) and "
            "%s < date_trunc('month', %s) + interval '1 months'" % (lhs, rhs, lhs, rhs),
            params,
        )


class DateTimeTransform(models.Transform):
    """Convert an integer Unix timestamp to a datetime via from_unixtime()."""

    lookup_name = "as_datetime"

    @property
    def output_field(self):
        return models.DateTimeField()

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        return "from_unixtime({})".format(lhs), params


class CustomStartsWith(StartsWith):
    """StartsWith registered under the name "sw"."""

    lookup_name = "sw"


class CustomEndsWith(EndsWith):
    """EndsWith registered under the name "ew"."""

    lookup_name = "ew"


class RelatedMoreThan(RelatedGreaterThan):
    """RelatedGreaterThan registered under the name "rmt"."""

    lookup_name = "rmt"


class LookupTests(TestCase):
    def test_custom_name_lookup(self):
        a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        with register_lookup(models.DateField, YearTransform), register_lookup(
            models.DateField, YearTransform, lookup_name="justtheyear"
        ), register_lookup(YearTransform, Exactly), register_lookup(
            YearTransform, Exactly, lookup_name="isactually"
        ):
            qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
            qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
            self.assertSequenceEqual(qs1, [a1])
            self.assertSequenceEqual(qs2, [a1])

    def test_custom_exact_lookup_none_rhs(self):
        """
        __exact=None is transformed to __isnull=True if a custom lookup class
        with lookup_name != 'exact' is registered as the `exact` lookup.
        """
        field = Author._meta.get_field("birthdate")
        OldExactLookup = field.get_lookup("exact")
        author = Author.objects.create(name="author", birthdate=None)
        try:
            field.register_lookup(Exactly, "exact")
            self.assertEqual(Author.objects.get(birthdate__exact=None), author)
        finally:
            # Restore the original lookup even if the assertion fails.
            field.register_lookup(OldExactLookup, "exact")

    def test_basic_lookup(self):
        a1 = Author.objects.create(name="a1", age=1)
        a2 = Author.objects.create(name="a2", age=2)
        a3 = Author.objects.create(name="a3", age=3)
        a4 = Author.objects.create(name="a4", age=4)
        with register_lookup(models.IntegerField, Div3Lookup):
            self.assertSequenceEqual(Author.objects.filter(age__div3=0), [a3])
            self.assertSequenceEqual(
                Author.objects.filter(age__div3=1).order_by("age"), [a1, a4]
            )
            self.assertSequenceEqual(Author.objects.filter(age__div3=2), [a2])
            self.assertSequenceEqual(Author.objects.filter(age__div3=3), [])

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_birthdate_month(self):
        a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
        a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
        with register_lookup(models.DateField, InMonth):
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), [a3]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), [a2]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), [a1]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), [a4]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), []
            )

    def test_div3_extract(self):
        with register_lookup(models.IntegerField, Div3Transform):
            a1 = Author.objects.create(name="a1", age=1)
            a2 = Author.objects.create(name="a2", age=2)
            a3 = Author.objects.create(name="a3", age=3)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")
            self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
            self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a1, a2, a3, a4])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a2])
            self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [])
            self.assertSequenceEqual(
                baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
            )

    def test_foreignobject_lookup_registration(self):
        field = Article._meta.get_field("author")

        with register_lookup(models.ForeignObject, Exactly):
            self.assertIs(field.get_lookup("exactly"), Exactly)

        # ForeignObject should ignore regular Field lookups
        with register_lookup(models.Field, Exactly):
            self.assertIsNone(field.get_lookup("exactly"))

    def test_lookups_caching(self):
        field = Article._meta.get_field("author")

        # clear and re-cache
        field.get_lookups.cache_clear()
        self.assertNotIn("exactly", field.get_lookups())

        # registration should bust the cache
        with register_lookup(models.ForeignObject, Exactly):
            # getting the lookups again should re-cache
            self.assertIn("exactly", field.get_lookups())
        # Unregistration should bust the cache.
        self.assertNotIn("exactly", field.get_lookups())


class BilateralTransformTests(TestCase):
    def test_bilateral_upper(self):
        with register_lookup(models.CharField, UpperBilateralTransform):
            author1 = Author.objects.create(name="Doe")
            author2 = Author.objects.create(name="doe")
            author3 = Author.objects.create(name="Foo")
            self.assertCountEqual(
                Author.objects.filter(name__upper="doe"),
                [author1, author2],
            )
            self.assertSequenceEqual(
                Author.objects.filter(name__upper__contains="f"),
                [author3],
            )

    def test_bilateral_inner_qs(self):
        with register_lookup(models.CharField, UpperBilateralTransform):
            msg = "Bilateral transformations on nested querysets are not implemented."
            with self.assertRaisesMessage(NotImplementedError, msg):
                Author.objects.filter(
                    name__upper__in=Author.objects.values_list("name")
                )

    def test_bilateral_multi_value(self):
        with register_lookup(models.CharField, UpperBilateralTransform):
            Author.objects.bulk_create(
                [
                    Author(name="Foo"),
                    Author(name="Bar"),
                    Author(name="Ray"),
                ]
            )
            self.assertQuerysetEqual(
                Author.objects.filter(name__upper__in=["foo", "bar", "doe"]).order_by(
                    "name"
                ),
                ["Bar", "Foo"],
                lambda a: a.name,
            )

    def test_div3_bilateral_extract(self):
        with register_lookup(models.IntegerField, Div3BilateralTransform):
            a1 = Author.objects.create(name="a1", age=1)
            a2 = Author.objects.create(name="a2", age=2)
            a3 = Author.objects.create(name="a3", age=3)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")
            self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
            self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a3])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a1, a2, a4])
            self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [a1, a2, a3, a4])
            self.assertSequenceEqual(
                baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
            )

    def test_bilateral_order(self):
        with register_lookup(
            models.IntegerField, Mult3BilateralTransform, Div3BilateralTransform
        ):
            a1 = Author.objects.create(name="a1", age=1)
            a2 = Author.objects.create(name="a2", age=2)
            a3 = Author.objects.create(name="a3", age=3)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")

            # mult3__div3 always leads to 0
            self.assertSequenceEqual(
                baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4]
            )
            self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3])

    def test_transform_order_by(self):
        with register_lookup(models.IntegerField, LastDigitTransform):
            a1 = Author.objects.create(name="a1", age=11)
            a2 = Author.objects.create(name="a2", age=23)
            a3 = Author.objects.create(name="a3", age=32)
            a4 = Author.objects.create(name="a4", age=40)
            qs = Author.objects.order_by("age__lastdigit")
            self.assertSequenceEqual(qs, [a4, a1, a3, a2])

    def test_bilateral_fexpr(self):
        with register_lookup(models.IntegerField, Mult3BilateralTransform):
            a1 = Author.objects.create(name="a1", age=1, average_rating=3.2)
            a2 = Author.objects.create(name="a2", age=2, average_rating=0.5)
            a3 = Author.objects.create(name="a3", age=3, average_rating=1.5)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")
            self.assertSequenceEqual(
                baseqs.filter(age__mult3=models.F("age")), [a1, a2, a3, a4]
            )
            # Same as age >= average_rating
            self.assertSequenceEqual(
                baseqs.filter(age__mult3__gte=models.F("average_rating")), [a2, a3]
            )


@override_settings(USE_TZ=True)
class DateTimeLookupTests(TestCase):
    @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific SQL used")
    def test_datetime_output_field(self):
        with register_lookup(models.PositiveIntegerField, DateTimeTransform):
            ut = MySQLUnixTimestamp.objects.create(timestamp=time.time())
            y2k = timezone.make_aware(datetime(2000, 1, 1))
            self.assertSequenceEqual(
                MySQLUnixTimestamp.objects.filter(timestamp__as_datetime__gt=y2k), [ut]
            )


class YearLteTests(TestCase):
    @classmethod
    def setUpTestData(cls):
        cls.a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        cls.a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        cls.a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
        cls.a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))

    def setUp(self):
        models.DateField.register_lookup(YearTransform)

    def tearDown(self):
        models.DateField._unregister_lookup(YearTransform)

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_year_lte(self):
        baseqs = Author.objects.order_by("name")
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=2012),
            [self.a1, self.a2, self.a3, self.a4],
        )
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear=2012), [self.a2, self.a3, self.a4]
        )

        self.assertNotIn("BETWEEN", str(baseqs.filter(birthdate__testyear=2012).query))
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=2011), [self.a1]
        )
        # The non-optimized version works, too.
        self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=2012), [self.a1])

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_year_lte_fexpr(self):
        self.a2.age = 2011
        self.a2.save()
        self.a3.age = 2012
        self.a3.save()
        self.a4.age = 2013
        self.a4.save()
        baseqs = Author.objects.order_by("name")
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lte=models.F("age")), [self.a3, self.a4]
        )
        self.assertSequenceEqual(
            baseqs.filter(birthdate__testyear__lt=models.F("age")), [self.a4]
        )

    def test_year_lte_sql(self):
        # This test will just check the generated SQL for __lte. This
        # doesn't require running on PostgreSQL and spots the most likely
        # error - not running YearLte SQL at all.
        baseqs = Author.objects.order_by("name")
        sql = str(baseqs.filter(birthdate__testyear__lte=2011).query)
        self.assertIn("<= (2011 || ", sql)
        self.assertIn("-12-31", sql)

    def test_postgres_year_exact(self):
        baseqs = Author.objects.order_by("name")
        sql = str(baseqs.filter(birthdate__testyear=2011).query)
        self.assertIn("= (2011 || ", sql)
        self.assertIn("-12-31", sql)

    def test_custom_implementation_year_exact(self):
        try:
            # Two ways to add a customized implementation for different backends:
            # First is MonkeyPatch of the class.
            def as_custom_sql(self, compiler, connection):
                lhs_sql, lhs_params = self.process_lhs(
                    compiler, connection, self.lhs.lhs
                )
                rhs_sql, rhs_params = self.process_rhs(compiler, connection)
                params = lhs_params + rhs_params + lhs_params + rhs_params
                return (
                    "%(lhs)s >= "
                    "str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
                    "AND %(lhs)s <= "
                    "str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
                    % {"lhs": lhs_sql, "rhs": rhs_sql},
                    params,
                )

            setattr(YearExact, "as_" + connection.vendor, as_custom_sql)
            self.assertIn(
                "concat(", str(Author.objects.filter(birthdate__testyear=2012).query)
            )
        finally:
            delattr(YearExact, "as_" + connection.vendor)
        try:
            # The other way is to subclass the original lookup and register the
            # subclassed lookup instead of the original.
            class CustomYearExact(YearExact):
                # This method should be named "as_mysql" for MySQL,
                # "as_postgresql" for postgres and so on, but as we don't know
                # which DB we are running on, we need to use setattr.
                def as_custom_sql(self, compiler, connection):
                    lhs_sql, lhs_params = self.process_lhs(
                        compiler, connection, self.lhs.lhs
                    )
                    rhs_sql, rhs_params = self.process_rhs(compiler, connection)
                    params = lhs_params + rhs_params + lhs_params + rhs_params
                    return (
                        "%(lhs)s >= "
                        "str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
                        "AND %(lhs)s <= "
                        "str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
                        % {"lhs": lhs_sql, "rhs": rhs_sql},
                        params,
                    )

            setattr(
                CustomYearExact,
                "as_" + connection.vendor,
                CustomYearExact.as_custom_sql,
            )
            YearTransform.register_lookup(CustomYearExact)
            self.assertIn(
                "CONCAT(", str(Author.objects.filter(birthdate__testyear=2012).query)
            )
        finally:
            # Both classes are registered as "exact"; restore the original.
            YearTransform._unregister_lookup(CustomYearExact)
            YearTransform.register_lookup(YearExact)


class TrackCallsYearTransform(YearTransform):
    """
    YearTransform that records, in the class-level ``call_order`` list, each
    call to get_lookup() ("lookup") and get_transform() ("transform").

    SQL generation and ``output_field`` are inherited from YearTransform. The
    previous local copy of as_sql() called date_extract_sql() with the stale
    two-argument signature, which disagreed with the parent.
    """

    # Use a name that avoids collision with the built-in year lookup.
    lookup_name = "testyear"
    # Shared across instances on purpose: the transform instances are created
    # internally by the ORM, so the test inspects the class attribute.
    call_order = []

    def get_lookup(self, lookup_name):
        self.call_order.append("lookup")
        return super().get_lookup(lookup_name)

    def get_transform(self, lookup_name):
        self.call_order.append("transform")
        return super().get_transform(lookup_name)


class LookupTransformCallOrderTests(SimpleTestCase):
    def test_call_order(self):
        # Start from a clean slate so earlier runs can't leak recorded calls.
        TrackCallsYearTransform.call_order = []
        with register_lookup(models.DateField, TrackCallsYearTransform):
            # junk lookup - tries lookup, then transform, then fails
            msg = (
                "Unsupported lookup 'junk' for IntegerField or join on the field not "
                "permitted."
            )
            with self.assertRaisesMessage(FieldError, msg):
                Author.objects.filter(birthdate__testyear__junk=2012)
            self.assertEqual(
                TrackCallsYearTransform.call_order, ["lookup", "transform"]
            )
            TrackCallsYearTransform.call_order = []
            # junk transform - tries transform only, then fails
            with self.assertRaisesMessage(FieldError, msg):
                Author.objects.filter(birthdate__testyear__junk__more_junk=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["transform"])
            TrackCallsYearTransform.call_order = []
            # Just getting the year (implied __exact) - lookup only
            Author.objects.filter(birthdate__testyear=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
            TrackCallsYearTransform.call_order = []
            # Just getting the year (explicit __exact) - lookup only
            Author.objects.filter(birthdate__testyear__exact=2012)
            self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])


class CustomisedMethodsTests(SimpleTestCase):
    def test_overridden_get_lookup(self):
        q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
        self.assertIn("monkeys()", str(q.query))

    def test_overridden_get_transform(self):
        q = CustomModel.objects.filter(field__transformfunc_banana=3)
        self.assertIn("banana()", str(q.query))

    def test_overridden_get_lookup_chain(self):
        q = CustomModel.objects.filter(
            field__transformfunc_banana__lookupfunc_elephants=3
        )
        self.assertIn("elephants()", str(q.query))

    def test_overridden_get_transform_chain(self):
        q = CustomModel.objects.filter(
            field__transformfunc_banana__transformfunc_pear=3
        )
        self.assertIn("pear()", str(q.query))


class SubqueryTransformTests(TestCase):
    def test_subquery_usage(self):
        with register_lookup(models.IntegerField, Div3Transform):
            Author.objects.create(name="a1", age=1)
            a2 = Author.objects.create(name="a2", age=2)
            Author.objects.create(name="a3", age=3)
            Author.objects.create(name="a4", age=4)
            qs = Author.objects.order_by("name").filter(
                id__in=Author.objects.filter(age__div3=2)
            )
            self.assertSequenceEqual(qs, [a2])


class RegisterLookupTests(SimpleTestCase):
    def test_class_lookup(self):
        author_name = Author._meta.get_field("name")
        with register_lookup(models.CharField, CustomStartsWith):
            self.assertEqual(author_name.get_lookup("sw"), CustomStartsWith)
        self.assertIsNone(author_name.get_lookup("sw"))

    def test_lookup_on_transform(self):
        transform = Div3Transform
        with register_lookup(Div3Transform, CustomStartsWith):
            with register_lookup(Div3Transform, CustomEndsWith):
                self.assertEqual(
                    transform.get_lookups(),
                    {"sw": CustomStartsWith, "ew": CustomEndsWith},
                )
            self.assertEqual(transform.get_lookups(), {"sw": CustomStartsWith})
        self.assertEqual(transform.get_lookups(), {})

    def test_transform_on_field(self):
        author_age = Author._meta.get_field("age")
        with register_lookup(models.IntegerField, Div3Transform):
            self.assertEqual(author_age.get_transform("div3"), Div3Transform)
        self.assertIsNone(author_age.get_transform("div3"))

    def test_related_lookup(self):
        article_author = Article._meta.get_field("author")
        with register_lookup(models.Field, CustomStartsWith):
            self.assertIsNone(article_author.get_lookup("sw"))
        with register_lookup(models.ForeignKey, RelatedMoreThan):
            self.assertEqual(article_author.get_lookup("rmt"), RelatedMoreThan)