diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index b216e9eb8de..b096a20e883 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -7,8 +7,8 @@ from django.utils.functional import cached_property class Extract(object): - def __init__(self, constraint_class, lhs): - self.constraint_class, self.lhs = constraint_class, lhs + def __init__(self, lhs): + self.lhs = lhs def get_lookup(self, lookup): return self.output_type.get_lookup(lookup) @@ -21,15 +21,18 @@ class Extract(object): return self.lhs.output_type def relabeled_clone(self, relabels): - return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels)) + return self.__class__(self.lhs.relabeled_clone(relabels)) + + def get_cols(self): + return self.lhs.get_cols() class Lookup(object): lookup_name = None extract_class = None - def __init__(self, constraint_class, lhs, rhs): - self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs + def __init__(self, lhs, rhs): + self.lhs, self.rhs = lhs, rhs if rhs is None: if not self.extract_class: raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name) @@ -37,7 +40,7 @@ class Lookup(object): self.rhs = self.get_prep_lookup() def get_extract(self): - return self.extract_class(self.constraint_class, self.lhs) + return self.extract_class(self.lhs) def get_prep_lookup(self): return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index d8b9c0ccb93..ef561e663f0 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -71,7 +71,8 @@ class SQLCompiler(object): def compile(self, node): if node.__class__ in self.connection.compile_implementations: - return self.connection.compile_implementations[node.__class__](node, self) + return self.connection.compile_implementations[node.__class__]( + node, self, self.connection) else: return node.as_sql(self, self.connection) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index bffdd551055..971db28ae12 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -18,6 +18,7 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.aggregates import refs_aggregate from django.db.models.expressions import ExpressionNode from django.db.models.fields import FieldDoesNotExist +from django.db.models.lookups import Extract from django.db.models.query_utils import Q from django.db.models.related import PathInfo from django.db.models.sql import aggregates as base_aggregates_module @@ -1088,9 +1089,12 @@ class Query(object): if next: if not lookups: # This was the last lookup, so return value lookup. - return next(self.where_class, lhs, rhs) + if issubclass(next, Extract): + lhs = next(lhs) + next = lhs.get_lookup('exact') + return next(lhs, rhs) else: - lhs = next(self.where_class, lhs, None).get_extract() + lhs = next(lhs) # A field's get_lookup() can return None to opt for backwards # compatibility path. elif len(lookups) > 1: diff --git a/docs/ref/models/lookups.txt b/docs/ref/models/lookups.txt new file mode 100644 index 00000000000..b5dc5c20d14 --- /dev/null +++ b/docs/ref/models/lookups.txt @@ -0,0 +1,243 @@ +============== +Custom lookups +============== + +.. module:: django.db.models.lookups + :synopsis: Custom lookups + +.. currentmodule:: django.db.models + +(This documentation is candidate for complete rewrite, but contains +useful information of how to test the current implementation.) + +This documentation constains instructions of how to create custom lookups +for model fields. + +Django's ORM works using lookup paths when building query filters and other +query structures. For example in the query Book.filter(author__age__lte=30) +the author__age__lte is the lookup path. + +The lookup path consist of three different part. First is the related lookups, +above part author refers to Book's related model Author. Second part of the +lookup path is the final field, above this is Author's field age. Finally the +lte part is commonly called just lookup (TODO: this nomenclature is confusing, +can we invent something better). + +This documentation concentrates on writing custom lookups, that is custom +implementations for lte or any other lookup you wish to use. + +Django will fetch a ``Lookup`` class from the final field using the field's +method get_lookup(lookup_name). This method can do three things: + + 1. Return a Lookup class + 2. Raise a FieldError + 3. Return None + +Above return None is only available during backwards compatibility period and +returning None will not be allowed in Django 1.9 or later. The interpretation +is to use the old way of lookup hadling inside the ORM. + +The returned Lookup will be used to build the query. + +The Lookup class +~~~~~~~~~~~~~~~~ + +The API is as follows: + +.. attribute:: lookup_name + +A string used by Django to distinguish different lookups. + +.. method:: __init__(lhs, rhs) + +The lhs and rhs are the field reference (reference to field age in the +author__age__lte=30 example), and rhs is the value (30 in the example). + +.. attribute:: Lookup.lhs + +The left hand side part of this lookup. You can assume it implements the +query part interface (TODO: write interface definition...). + +.. method:: Lookup.as_sql(qn, connection) + +This method is used to produce the query string of the Lookup. A typical +implementation is usually something like:: + + def as_sql(self, qn, connection): + lhs, params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params.extend(rhs_params) + return '%s %s', (lhs, rhs), params + +where the is some query operator. The qn is a callable that +can be used to convert strings to quoted variants (that is, colname to +"colname"). Note that the quotation is *not* safe against SQL injection. + +In addition the qn implements method compile() which can be used to turn +anything with as_sql() method to query string. You should always call +qn.compile(part) instead of part.as_sql(qn, connection) so that 3rd party +backends have ability to customize the produced query string. More of this +later on. + +The connection is the used connection. + +.. method:: Lookup.process_lhs(qn, connection, lhs=None) + +This method is used to convert the left hand side of the lookup into query +string. The left hand side can be a field reference or a nested lookup. The +lhs kwarg can be used to convert something else than self.lhs to query string. + +.. method:: Lookup.process_rhs(qn, connection, rhs=None) + +The process_rhs method is used to convert the right hand side into query string. +The rhs is the value given in the filter clause. It can be a raw value to +compare agains, a F() reference to another field or even a QuerySet. + +.. method:: get_extract() + +The get_extract method is used in nested lookups. It must return an Extract instance. + +.. classattribute:: Lookup.extract_class + +The default implementation of get_extract() will return an instance of extract_class. + +In addition there are some private methods - that is, implementing just the above +mentioned attributes and methods is not enough, you must subclass Lookup instead. + +The Extract class +~~~~~~~~~~~~~~~~~ + +An Extract is something that converts a value to another value in the query string. +For example you could have an Extract that procudes modulo 3 of the given value. +In SQL this would be something like "author"."age" % 3. + +Extracts are used in nested lookups. The Extract class must implement the query +part interface. + +A simple Lookup example +~~~~~~~~~~~~~~~~~~~~~~~ + +This is how to write a simple div3 lookup for IntegerField:: + + from django.db.models import Lookup, IntegerField + class Div3(Lookup): + lookup_name = 'div3' + + def as_sql(self, qn, connection): + lhs_sql, params = self.process_lhs(qn, connection) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params.extend(rhs_params) + # We need doulbe-escaping for the %%%% operator. + return '%s %%%% %s' % (lhs_sql, rhs_sql), params + + IntegerField.register_lookup(Div3) + +Now all IntegerFields or subclasses of IntegerField will have +a div3 lookup. For example you could do Author.objects.filter(age__div3=2). +This query would return every author whose age % 3 == 2. + +A simple nested lookup example +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Here is how to write an Extract and a Lookup for IntegerField. The example +lookup can be used similarly as the above div3 lookup, and in addition it +support nesting lookups:: + + class Div3Extract(Extract): + lookup_name = 'div3' + + def as_sql(self, qn, connection): + lhs, lhs_params = qn.compile(self.lhs) + return '%s %%%% 3' % (lhs,), lhs_params + + IntegerField.register_lookup(Div3Extract) + +Note that if you already added Div3 for IntegerField in the above +example, now Div3LookupWithExtract will override that lookup. + +This lookup can be used like Div3 lookup, but in addition it supports +nesting, too. The default output type for Extracts is the same type as the +lhs' output_type. So, the Div3Extract supports all the same lookups as +IntegerField. For example Author.objects.filter(age__div3__in=[1, 2]) +returns all authors for which age % 3 in (1, 2). + +A more complex nested lookup +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We will write a Year lookup that extracts year from date field. This +field will convert the output type of the field - the lhs (or "input") +field is DateField, but output is of type IntegerField.:: + + from django.db.models import IntegerField, DateField + from django.db.models.lookups import Extract + + class YearExtract(Extract): + lookup_name = 'year' + + def as_sql(self, qn, connection): + lhs_sql, params = qn.compile(self.lhs) + # hmmh - this is internal API... + return connection.ops.date_extract_sql('year', lhs_sql), params + + @property + def output_type(self): + return IntegerField() + + DateField.register_lookup(YearExtract) + +Now you could write Author.objects.filter(birthdate__year=1981). This will +produce SQL like 'EXTRACT('year' from "author"."birthdate") = 1981'. The +produces SQL depends on used backend. In addtition you can use any lookup +defined for IntegerField, even div3 if you added that. So, +Authos.objects.filter(birthdate__year__div3=2) will return every author +with birthdate.year % 3 == 2. + +We could go further and add an optimized implementation for exact lookups:: + + from django.db.models.lookups import Lookup + + class YearExtractOptimized(YearExtract): + def get_lookup(self, lookup): + if lookup == 'exact': + return YearExact + return super(YearExtractOptimized, self).get_lookup() + + class YearExact(Lookup): + def as_sql(self, qn, connection): + # We will need to skip the extract part, and instead go + # directly with the originating field, that is self.lhs.lhs + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + # Note that we must be careful so that we have params in the + # same order as we have the parts in the SQL. + params = [] + params.extend(lhs_params) + params.extend(rhs_params) + params.extend(lhs_params) + params.extend(rhs_params) + # We use PostgreSQL specific SQL here. Note that we must do the + # conversions in SQL instead of in Python to support F() references. + return ("%(lhs)s >= (%(rhs)s || '-01-01')::date " + "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + +Note that we used PostgreSQL specific SQL above. What if we want to support +MySQL, too? This can be done by registering a different compiling implementation +for MySQL:: + + from django.db.backends.utils import add_implementation + @add_implementation(YearExact, 'mysql') + def mysql_year_exact(node, qn, connection): + lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs) + rhs_sql, rhs_params = node.process_rhs(qn, connection) + params = [] + params.extend(lhs_params) + params.extend(rhs_params) + params.extend(lhs_params) + params.extend(rhs_params) + return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " + "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + +Now, on MySQL instead of calling as_sql() of the YearExact Django will use the +above compile implementation. diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index 5864bf3546c..19d952c4d72 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -20,16 +20,13 @@ class Div3Lookup(models.lookups.Lookup): class Div3Extract(models.lookups.Extract): + lookup_name = 'div3' + def as_sql(self, qn, connection): lhs, lhs_params = qn.compile(self.lhs) return '%s %%%% 3' % (lhs,), lhs_params -class Div3LookupWithExtract(Div3Lookup): - lookup_name = 'div3' - extract_class = Div3Extract - - class YearLte(models.lookups.LessThanOrEqual): """ The purpose of this lookup is to efficiently compare the year of the field. @@ -50,6 +47,8 @@ class YearLte(models.lookups.LessThanOrEqual): class YearExtract(models.lookups.Extract): + lookup_name = 'year' + def as_sql(self, qn, connection): lhs_sql, params = qn.compile(self.lhs) return connection.ops.date_extract_sql('year', lhs_sql), params @@ -61,12 +60,44 @@ class YearExtract(models.lookups.Extract): def get_lookup(self, lookup): if lookup == 'lte': return YearLte + elif lookup == 'exact': + return YearExact else: return super(YearExtract, self).get_lookup(lookup) -class YearWithExtract(models.lookups.Year): - extract_class = YearExtract +class YearExact(models.lookups.Lookup): + def as_sql(self, qn, connection): + # We will need to skip the extract part, and instead go + # directly with the originating field, that is self.lhs.lhs + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + # Note that we must be careful so that we have params in the + # same order as we have the parts in the SQL. + params = [] + params.extend(lhs_params) + params.extend(rhs_params) + params.extend(lhs_params) + params.extend(rhs_params) + # We use PostgreSQL specific SQL here. Note that we must do the + # conversions in SQL instead of in Python to support F() references. + return ("%(lhs)s >= (%(rhs)s || '-01-01')::date " + "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + + +@add_implementation(YearExact, 'mysql') +def mysql_year_exact(node, qn, connection): + lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs) + rhs_sql, rhs_params = node.process_rhs(qn, connection) + params = [] + params.extend(lhs_params) + params.extend(rhs_params) + params.extend(lhs_params) + params.extend(rhs_params) + return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " + "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) class InMonth(models.lookups.Lookup): @@ -158,7 +189,7 @@ class LookupTests(TestCase): models.Field.register_lookup(AnotherEqual) try: @add_implementation(AnotherEqual, connection.vendor) - def custom_eq_sql(node, compiler): + def custom_eq_sql(node, qn, connection): return '1 = 1', [] self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query)) @@ -167,7 +198,7 @@ class LookupTests(TestCase): [a1, a2, a3, a4], lambda x: x) @add_implementation(AnotherEqual, connection.vendor) - def another_custom_eq_sql(node, compiler): + def another_custom_eq_sql(node, qn, connection): # If you need to override one method, it seems this is the best # option. node = copy(node) @@ -176,7 +207,7 @@ class LookupTests(TestCase): def get_rhs_op(self, connection, rhs): return ' <> %s' node.__class__ = OverriddenAnotherEqual - return node.as_sql(compiler, compiler.connection) + return node.as_sql(qn, connection) self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query)) self.assertQuerysetEqual( Author.objects.filter(name__anotherequal='a1').order_by('name'), @@ -186,13 +217,16 @@ class LookupTests(TestCase): models.Field._unregister_lookup(AnotherEqual) def test_div3_extract(self): - models.IntegerField.register_lookup(Div3LookupWithExtract) + models.IntegerField.register_lookup(Div3Extract) try: a1 = Author.objects.create(name='a1', age=1) a2 = Author.objects.create(name='a2', age=2) a3 = Author.objects.create(name='a3', age=3) a4 = Author.objects.create(name='a4', age=4) baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(age__div3=2), + [a2], lambda x: x) self.assertQuerysetEqual( baseqs.filter(age__div3__lte=3), [a1, a2, a3, a4], lambda x: x) @@ -200,19 +234,19 @@ class LookupTests(TestCase): baseqs.filter(age__div3__in=[0, 2]), [a2, a3], lambda x: x) finally: - models.IntegerField._unregister_lookup(Div3LookupWithExtract) + models.IntegerField._unregister_lookup(Div3Extract) class YearLteTests(TestCase): def setUp(self): - models.DateField.register_lookup(YearWithExtract) + models.DateField.register_lookup(YearExtract) self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31)) self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1)) def tearDown(self): - models.DateField._unregister_lookup(YearWithExtract) + models.DateField._unregister_lookup(YearExtract) @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") def test_year_lte(self): @@ -220,6 +254,11 @@ class YearLteTests(TestCase): self.assertQuerysetEqual( baseqs.filter(birthdate__year__lte=2012), [self.a1, self.a2, self.a3, self.a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(birthdate__year=2012), + [self.a2, self.a3, self.a4], lambda x: x) + + self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query)) self.assertQuerysetEqual( baseqs.filter(birthdate__year__lte=2011), [self.a1], lambda x: x) @@ -253,3 +292,12 @@ class YearLteTests(TestCase): '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query)) self.assertIn( '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query)) + + @unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used') + def test_mysql_year_exact(self): + self.assertQuerysetEqual( + Author.objects.filter(birthdate__year=2012).order_by('name'), + [self.a2, self.a3, self.a4], lambda x: x) + self.assertIn( + 'concat(', + str(Author.objects.filter(birthdate__year=2012).query))