Added documentation, polished implementation
This commit is contained in:
parent
32c04357a8
commit
2adf50428d
|
@ -7,8 +7,8 @@ from django.utils.functional import cached_property
|
|||
|
||||
|
||||
class Extract(object):
|
||||
def __init__(self, constraint_class, lhs):
|
||||
self.constraint_class, self.lhs = constraint_class, lhs
|
||||
def __init__(self, lhs):
|
||||
self.lhs = lhs
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_type.get_lookup(lookup)
|
||||
|
@ -21,15 +21,18 @@ class Extract(object):
|
|||
return self.lhs.output_type
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(self.constraint_class, self.lhs.relabeled_clone(relabels))
|
||||
return self.__class__(self.lhs.relabeled_clone(relabels))
|
||||
|
||||
def get_cols(self):
|
||||
return self.lhs.get_cols()
|
||||
|
||||
|
||||
class Lookup(object):
|
||||
lookup_name = None
|
||||
extract_class = None
|
||||
|
||||
def __init__(self, constraint_class, lhs, rhs):
|
||||
self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs
|
||||
def __init__(self, lhs, rhs):
|
||||
self.lhs, self.rhs = lhs, rhs
|
||||
if rhs is None:
|
||||
if not self.extract_class:
|
||||
raise FieldError("Lookup '%s' doesn't support nesting." % self.lookup_name)
|
||||
|
@ -37,7 +40,7 @@ class Lookup(object):
|
|||
self.rhs = self.get_prep_lookup()
|
||||
|
||||
def get_extract(self):
|
||||
return self.extract_class(self.constraint_class, self.lhs)
|
||||
return self.extract_class(self.lhs)
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs)
|
||||
|
|
|
@ -71,7 +71,8 @@ class SQLCompiler(object):
|
|||
|
||||
def compile(self, node):
|
||||
if node.__class__ in self.connection.compile_implementations:
|
||||
return self.connection.compile_implementations[node.__class__](node, self)
|
||||
return self.connection.compile_implementations[node.__class__](
|
||||
node, self, self.connection)
|
||||
else:
|
||||
return node.as_sql(self, self.connection)
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from django.db.models.constants import LOOKUP_SEP
|
|||
from django.db.models.aggregates import refs_aggregate
|
||||
from django.db.models.expressions import ExpressionNode
|
||||
from django.db.models.fields import FieldDoesNotExist
|
||||
from django.db.models.lookups import Extract
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.related import PathInfo
|
||||
from django.db.models.sql import aggregates as base_aggregates_module
|
||||
|
@ -1088,9 +1089,12 @@ class Query(object):
|
|||
if next:
|
||||
if not lookups:
|
||||
# This was the last lookup, so return value lookup.
|
||||
return next(self.where_class, lhs, rhs)
|
||||
if issubclass(next, Extract):
|
||||
lhs = next(lhs)
|
||||
next = lhs.get_lookup('exact')
|
||||
return next(lhs, rhs)
|
||||
else:
|
||||
lhs = next(self.where_class, lhs, None).get_extract()
|
||||
lhs = next(lhs)
|
||||
# A field's get_lookup() can return None to opt for backwards
|
||||
# compatibility path.
|
||||
elif len(lookups) > 1:
|
||||
|
|
|
@ -0,0 +1,243 @@
|
|||
==============
|
||||
Custom lookups
|
||||
==============
|
||||
|
||||
.. module:: django.db.models.lookups
|
||||
:synopsis: Custom lookups
|
||||
|
||||
.. currentmodule:: django.db.models
|
||||
|
||||
(This documentation is candidate for complete rewrite, but contains
|
||||
useful information of how to test the current implementation.)
|
||||
|
||||
This documentation constains instructions of how to create custom lookups
|
||||
for model fields.
|
||||
|
||||
Django's ORM works using lookup paths when building query filters and other
|
||||
query structures. For example in the query Book.filter(author__age__lte=30)
|
||||
the author__age__lte is the lookup path.
|
||||
|
||||
The lookup path consist of three different part. First is the related lookups,
|
||||
above part author refers to Book's related model Author. Second part of the
|
||||
lookup path is the final field, above this is Author's field age. Finally the
|
||||
lte part is commonly called just lookup (TODO: this nomenclature is confusing,
|
||||
can we invent something better).
|
||||
|
||||
This documentation concentrates on writing custom lookups, that is custom
|
||||
implementations for lte or any other lookup you wish to use.
|
||||
|
||||
Django will fetch a ``Lookup`` class from the final field using the field's
|
||||
method get_lookup(lookup_name). This method can do three things:
|
||||
|
||||
1. Return a Lookup class
|
||||
2. Raise a FieldError
|
||||
3. Return None
|
||||
|
||||
Above return None is only available during backwards compatibility period and
|
||||
returning None will not be allowed in Django 1.9 or later. The interpretation
|
||||
is to use the old way of lookup hadling inside the ORM.
|
||||
|
||||
The returned Lookup will be used to build the query.
|
||||
|
||||
The Lookup class
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
The API is as follows:
|
||||
|
||||
.. attribute:: lookup_name
|
||||
|
||||
A string used by Django to distinguish different lookups.
|
||||
|
||||
.. method:: __init__(lhs, rhs)
|
||||
|
||||
The lhs and rhs are the field reference (reference to field age in the
|
||||
author__age__lte=30 example), and rhs is the value (30 in the example).
|
||||
|
||||
.. attribute:: Lookup.lhs
|
||||
|
||||
The left hand side part of this lookup. You can assume it implements the
|
||||
query part interface (TODO: write interface definition...).
|
||||
|
||||
.. method:: Lookup.as_sql(qn, connection)
|
||||
|
||||
This method is used to produce the query string of the Lookup. A typical
|
||||
implementation is usually something like::
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, params = self.process_lhs(qn, connection)
|
||||
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||
params = lhs_params.extend(rhs_params)
|
||||
return '%s <OPERATOR> %s', (lhs, rhs), params
|
||||
|
||||
where the <OPERATOR> is some query operator. The qn is a callable that
|
||||
can be used to convert strings to quoted variants (that is, colname to
|
||||
"colname"). Note that the quotation is *not* safe against SQL injection.
|
||||
|
||||
In addition the qn implements method compile() which can be used to turn
|
||||
anything with as_sql() method to query string. You should always call
|
||||
qn.compile(part) instead of part.as_sql(qn, connection) so that 3rd party
|
||||
backends have ability to customize the produced query string. More of this
|
||||
later on.
|
||||
|
||||
The connection is the used connection.
|
||||
|
||||
.. method:: Lookup.process_lhs(qn, connection, lhs=None)
|
||||
|
||||
This method is used to convert the left hand side of the lookup into query
|
||||
string. The left hand side can be a field reference or a nested lookup. The
|
||||
lhs kwarg can be used to convert something else than self.lhs to query string.
|
||||
|
||||
.. method:: Lookup.process_rhs(qn, connection, rhs=None)
|
||||
|
||||
The process_rhs method is used to convert the right hand side into query string.
|
||||
The rhs is the value given in the filter clause. It can be a raw value to
|
||||
compare agains, a F() reference to another field or even a QuerySet.
|
||||
|
||||
.. method:: get_extract()
|
||||
|
||||
The get_extract method is used in nested lookups. It must return an Extract instance.
|
||||
|
||||
.. classattribute:: Lookup.extract_class
|
||||
|
||||
The default implementation of get_extract() will return an instance of extract_class.
|
||||
|
||||
In addition there are some private methods - that is, implementing just the above
|
||||
mentioned attributes and methods is not enough, you must subclass Lookup instead.
|
||||
|
||||
The Extract class
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
An Extract is something that converts a value to another value in the query string.
|
||||
For example you could have an Extract that procudes modulo 3 of the given value.
|
||||
In SQL this would be something like "author"."age" % 3.
|
||||
|
||||
Extracts are used in nested lookups. The Extract class must implement the query
|
||||
part interface.
|
||||
|
||||
A simple Lookup example
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
This is how to write a simple div3 lookup for IntegerField::
|
||||
|
||||
from django.db.models import Lookup, IntegerField
|
||||
class Div3(Lookup):
|
||||
lookup_name = 'div3'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs_sql, params = self.process_lhs(qn, connection)
|
||||
rhs_sql, rhs_params = self.process_rhs(qn, connection)
|
||||
params.extend(rhs_params)
|
||||
# We need doulbe-escaping for the %%%% operator.
|
||||
return '%s %%%% %s' % (lhs_sql, rhs_sql), params
|
||||
|
||||
IntegerField.register_lookup(Div3)
|
||||
|
||||
Now all IntegerFields or subclasses of IntegerField will have
|
||||
a div3 lookup. For example you could do Author.objects.filter(age__div3=2).
|
||||
This query would return every author whose age % 3 == 2.
|
||||
|
||||
A simple nested lookup example
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Here is how to write an Extract and a Lookup for IntegerField. The example
|
||||
lookup can be used similarly as the above div3 lookup, and in addition it
|
||||
support nesting lookups::
|
||||
|
||||
class Div3Extract(Extract):
|
||||
lookup_name = 'div3'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, lhs_params = qn.compile(self.lhs)
|
||||
return '%s %%%% 3' % (lhs,), lhs_params
|
||||
|
||||
IntegerField.register_lookup(Div3Extract)
|
||||
|
||||
Note that if you already added Div3 for IntegerField in the above
|
||||
example, now Div3LookupWithExtract will override that lookup.
|
||||
|
||||
This lookup can be used like Div3 lookup, but in addition it supports
|
||||
nesting, too. The default output type for Extracts is the same type as the
|
||||
lhs' output_type. So, the Div3Extract supports all the same lookups as
|
||||
IntegerField. For example Author.objects.filter(age__div3__in=[1, 2])
|
||||
returns all authors for which age % 3 in (1, 2).
|
||||
|
||||
A more complex nested lookup
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
We will write a Year lookup that extracts year from date field. This
|
||||
field will convert the output type of the field - the lhs (or "input")
|
||||
field is DateField, but output is of type IntegerField.::
|
||||
|
||||
from django.db.models import IntegerField, DateField
|
||||
from django.db.models.lookups import Extract
|
||||
|
||||
class YearExtract(Extract):
|
||||
lookup_name = 'year'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs_sql, params = qn.compile(self.lhs)
|
||||
# hmmh - this is internal API...
|
||||
return connection.ops.date_extract_sql('year', lhs_sql), params
|
||||
|
||||
@property
|
||||
def output_type(self):
|
||||
return IntegerField()
|
||||
|
||||
DateField.register_lookup(YearExtract)
|
||||
|
||||
Now you could write Author.objects.filter(birthdate__year=1981). This will
|
||||
produce SQL like 'EXTRACT('year' from "author"."birthdate") = 1981'. The
|
||||
produces SQL depends on used backend. In addtition you can use any lookup
|
||||
defined for IntegerField, even div3 if you added that. So,
|
||||
Authos.objects.filter(birthdate__year__div3=2) will return every author
|
||||
with birthdate.year % 3 == 2.
|
||||
|
||||
We could go further and add an optimized implementation for exact lookups::
|
||||
|
||||
from django.db.models.lookups import Lookup
|
||||
|
||||
class YearExtractOptimized(YearExtract):
|
||||
def get_lookup(self, lookup):
|
||||
if lookup == 'exact':
|
||||
return YearExact
|
||||
return super(YearExtractOptimized, self).get_lookup()
|
||||
|
||||
class YearExact(Lookup):
|
||||
def as_sql(self, qn, connection):
|
||||
# We will need to skip the extract part, and instead go
|
||||
# directly with the originating field, that is self.lhs.lhs
|
||||
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
|
||||
rhs_sql, rhs_params = self.process_rhs(qn, connection)
|
||||
# Note that we must be careful so that we have params in the
|
||||
# same order as we have the parts in the SQL.
|
||||
params = []
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
# We use PostgreSQL specific SQL here. Note that we must do the
|
||||
# conversions in SQL instead of in Python to support F() references.
|
||||
return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
|
||||
"AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
|
||||
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
|
||||
|
||||
Note that we used PostgreSQL specific SQL above. What if we want to support
|
||||
MySQL, too? This can be done by registering a different compiling implementation
|
||||
for MySQL::
|
||||
|
||||
from django.db.backends.utils import add_implementation
|
||||
@add_implementation(YearExact, 'mysql')
|
||||
def mysql_year_exact(node, qn, connection):
|
||||
lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
|
||||
rhs_sql, rhs_params = node.process_rhs(qn, connection)
|
||||
params = []
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
|
||||
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
|
||||
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
|
||||
|
||||
Now, on MySQL instead of calling as_sql() of the YearExact Django will use the
|
||||
above compile implementation.
|
|
@ -20,16 +20,13 @@ class Div3Lookup(models.lookups.Lookup):
|
|||
|
||||
|
||||
class Div3Extract(models.lookups.Extract):
|
||||
lookup_name = 'div3'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, lhs_params = qn.compile(self.lhs)
|
||||
return '%s %%%% 3' % (lhs,), lhs_params
|
||||
|
||||
|
||||
class Div3LookupWithExtract(Div3Lookup):
|
||||
lookup_name = 'div3'
|
||||
extract_class = Div3Extract
|
||||
|
||||
|
||||
class YearLte(models.lookups.LessThanOrEqual):
|
||||
"""
|
||||
The purpose of this lookup is to efficiently compare the year of the field.
|
||||
|
@ -50,6 +47,8 @@ class YearLte(models.lookups.LessThanOrEqual):
|
|||
|
||||
|
||||
class YearExtract(models.lookups.Extract):
|
||||
lookup_name = 'year'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs_sql, params = qn.compile(self.lhs)
|
||||
return connection.ops.date_extract_sql('year', lhs_sql), params
|
||||
|
@ -61,12 +60,44 @@ class YearExtract(models.lookups.Extract):
|
|||
def get_lookup(self, lookup):
|
||||
if lookup == 'lte':
|
||||
return YearLte
|
||||
elif lookup == 'exact':
|
||||
return YearExact
|
||||
else:
|
||||
return super(YearExtract, self).get_lookup(lookup)
|
||||
|
||||
|
||||
class YearWithExtract(models.lookups.Year):
|
||||
extract_class = YearExtract
|
||||
class YearExact(models.lookups.Lookup):
|
||||
def as_sql(self, qn, connection):
|
||||
# We will need to skip the extract part, and instead go
|
||||
# directly with the originating field, that is self.lhs.lhs
|
||||
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs)
|
||||
rhs_sql, rhs_params = self.process_rhs(qn, connection)
|
||||
# Note that we must be careful so that we have params in the
|
||||
# same order as we have the parts in the SQL.
|
||||
params = []
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
# We use PostgreSQL specific SQL here. Note that we must do the
|
||||
# conversions in SQL instead of in Python to support F() references.
|
||||
return ("%(lhs)s >= (%(rhs)s || '-01-01')::date "
|
||||
"AND %(lhs)s <= (%(rhs)s || '-12-31')::date" %
|
||||
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
|
||||
|
||||
|
||||
@add_implementation(YearExact, 'mysql')
|
||||
def mysql_year_exact(node, qn, connection):
|
||||
lhs_sql, lhs_params = node.process_lhs(qn, connection, node.lhs.lhs)
|
||||
rhs_sql, rhs_params = node.process_rhs(qn, connection)
|
||||
params = []
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
params.extend(lhs_params)
|
||||
params.extend(rhs_params)
|
||||
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
|
||||
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
|
||||
{'lhs': lhs_sql, 'rhs': rhs_sql}, params)
|
||||
|
||||
|
||||
class InMonth(models.lookups.Lookup):
|
||||
|
@ -158,7 +189,7 @@ class LookupTests(TestCase):
|
|||
models.Field.register_lookup(AnotherEqual)
|
||||
try:
|
||||
@add_implementation(AnotherEqual, connection.vendor)
|
||||
def custom_eq_sql(node, compiler):
|
||||
def custom_eq_sql(node, qn, connection):
|
||||
return '1 = 1', []
|
||||
|
||||
self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query))
|
||||
|
@ -167,7 +198,7 @@ class LookupTests(TestCase):
|
|||
[a1, a2, a3, a4], lambda x: x)
|
||||
|
||||
@add_implementation(AnotherEqual, connection.vendor)
|
||||
def another_custom_eq_sql(node, compiler):
|
||||
def another_custom_eq_sql(node, qn, connection):
|
||||
# If you need to override one method, it seems this is the best
|
||||
# option.
|
||||
node = copy(node)
|
||||
|
@ -176,7 +207,7 @@ class LookupTests(TestCase):
|
|||
def get_rhs_op(self, connection, rhs):
|
||||
return ' <> %s'
|
||||
node.__class__ = OverriddenAnotherEqual
|
||||
return node.as_sql(compiler, compiler.connection)
|
||||
return node.as_sql(qn, connection)
|
||||
self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query))
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(name__anotherequal='a1').order_by('name'),
|
||||
|
@ -186,13 +217,16 @@ class LookupTests(TestCase):
|
|||
models.Field._unregister_lookup(AnotherEqual)
|
||||
|
||||
def test_div3_extract(self):
|
||||
models.IntegerField.register_lookup(Div3LookupWithExtract)
|
||||
models.IntegerField.register_lookup(Div3Extract)
|
||||
try:
|
||||
a1 = Author.objects.create(name='a1', age=1)
|
||||
a2 = Author.objects.create(name='a2', age=2)
|
||||
a3 = Author.objects.create(name='a3', age=3)
|
||||
a4 = Author.objects.create(name='a4', age=4)
|
||||
baseqs = Author.objects.order_by('name')
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3=2),
|
||||
[a2], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__lte=3),
|
||||
[a1, a2, a3, a4], lambda x: x)
|
||||
|
@ -200,19 +234,19 @@ class LookupTests(TestCase):
|
|||
baseqs.filter(age__div3__in=[0, 2]),
|
||||
[a2, a3], lambda x: x)
|
||||
finally:
|
||||
models.IntegerField._unregister_lookup(Div3LookupWithExtract)
|
||||
models.IntegerField._unregister_lookup(Div3Extract)
|
||||
|
||||
|
||||
class YearLteTests(TestCase):
|
||||
def setUp(self):
|
||||
models.DateField.register_lookup(YearWithExtract)
|
||||
models.DateField.register_lookup(YearExtract)
|
||||
self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16))
|
||||
self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29))
|
||||
self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31))
|
||||
self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1))
|
||||
|
||||
def tearDown(self):
|
||||
models.DateField._unregister_lookup(YearWithExtract)
|
||||
models.DateField._unregister_lookup(YearExtract)
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used")
|
||||
def test_year_lte(self):
|
||||
|
@ -220,6 +254,11 @@ class YearLteTests(TestCase):
|
|||
self.assertQuerysetEqual(
|
||||
baseqs.filter(birthdate__year__lte=2012),
|
||||
[self.a1, self.a2, self.a3, self.a4], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(birthdate__year=2012),
|
||||
[self.a2, self.a3, self.a4], lambda x: x)
|
||||
|
||||
self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query))
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(birthdate__year__lte=2011),
|
||||
[self.a1], lambda x: x)
|
||||
|
@ -253,3 +292,12 @@ class YearLteTests(TestCase):
|
|||
'<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query))
|
||||
self.assertIn(
|
||||
'-12-31', str(baseqs.filter(birthdate__year__lte=2011).query))
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific SQL used')
|
||||
def test_mysql_year_exact(self):
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(birthdate__year=2012).order_by('name'),
|
||||
[self.a2, self.a3, self.a4], lambda x: x)
|
||||
self.assertIn(
|
||||
'concat(',
|
||||
str(Author.objects.filter(birthdate__year=2012).query))
|
||||
|
|
Loading…
Reference in New Issue