Fixed #23493 -- Added bilateral attribute to Transform

This commit is contained in:
Thomas Chaumeny 2014-09-14 12:34:41 +02:00 committed by Anssi Kääriäinen
parent 6b39401baf
commit 00aa562884
6 changed files with 268 additions and 28 deletions

View File

@ -1,5 +1,4 @@
from copy import copy
from itertools import repeat
import inspect
from django.conf import settings
@ -7,6 +6,8 @@ from django.utils import timezone
from django.utils.functional import cached_property
from django.utils.six.moves import xrange
from .query_utils import QueryWrapper
class RegisterLookupMixin(object):
def _get_lookup(self, lookup_name):
@ -57,6 +58,9 @@ class RegisterLookupMixin(object):
class Transform(RegisterLookupMixin):
bilateral = False
def __init__(self, lhs, lookups):
self.lhs = lhs
self.init_lookups = lookups[:]
@ -78,9 +82,42 @@ class Transform(RegisterLookupMixin):
class Lookup(RegisterLookupMixin):
lookup_name = None
def __init__(self, lhs, rhs):
def __init__(self, lhs, rhs, bilateral_transforms=None):
# Store both sides of the lookup, prepare the rhs through get_prep_lookup(),
# and remember which bilateral transforms must later be applied to the rhs.
# bilateral_transforms is consumed elsewhere in this class as an iterable of
# (transform, lookups) pairs; None is treated as "no bilateral transforms".
self.lhs, self.rhs = lhs, rhs
self.rhs = self.get_prep_lookup()
if bilateral_transforms is None:
bilateral_transforms = []
if bilateral_transforms:
# Warn the user as soon as possible if they are trying to apply a
# bilateral transformation on a nested QuerySet: that won't work.
# QuerySet is imported here rather than at module level so as to avoid
# a circular import.
from django.db.models.query import QuerySet
# The check is made on the rhs argument as passed in, not on the
# prepared self.rhs.
if isinstance(rhs, QuerySet):
raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
self.bilateral_transforms = bilateral_transforms
def apply_bilateral_transforms(self, value):
# Wrap value in each recorded bilateral transform in turn, following the
# order of self.bilateral_transforms, and return the outermost result.
for transform, lookups in self.bilateral_transforms:
# Each transform is instantiated with the value built so far as its
# first argument and its stored lookups as the second.
value = transform(value, lookups)
return value
def batch_process_rhs(self, qn, connection, rhs=None):
# Build the SQL pieces for a multi-valued right-hand side.
# Returns a pair (sqls, sqls_params): a list of SQL fragments and the list
# of params that go with them. When rhs is None, self.rhs is used.
if rhs is None:
rhs = self.rhs
if self.bilateral_transforms:
# Each value is prepared on its own, wrapped as a bare '%s' placeholder,
# passed through the bilateral transforms and compiled separately, so
# every value ends up with its own SQL fragment.
sqls, sqls_params = [], []
for p in rhs:
value = QueryWrapper('%s',
[self.lhs.output_field.get_db_prep_value(p, connection)])
value = self.apply_bilateral_transforms(value)
sql, sql_params = qn.compile(value)
sqls.append(sql)
sqls_params.extend(sql_params)
else:
# No bilateral transforms: prepare all values in a single call and
# emit one plain '%s' placeholder per resulting param.
params = self.lhs.output_field.get_db_prep_lookup(
self.lookup_name, rhs, connection, prepared=True)
sqls, sqls_params = ['%s'] * len(params), params
return sqls, sqls_params
def get_prep_lookup(self):
return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
@ -96,6 +133,13 @@ class Lookup(RegisterLookupMixin):
def process_rhs(self, qn, connection):
value = self.rhs
if self.bilateral_transforms:
if self.rhs_is_direct_value():
# Do not call get_db_prep_lookup here as the value will be
# transformed before being used for lookup
value = QueryWrapper("%s",
[self.lhs.output_field.get_db_prep_value(value, connection)])
value = self.apply_bilateral_transforms(value)
# Due to historical reasons there are a couple of different
# ways to produce sql here. get_compiler is likely a Query
# instance, _as_sql QuerySet and as_sql just something with
@ -203,15 +247,19 @@ default_lookups['lte'] = LessThanOrEqual
class In(BuiltinLookup):
lookup_name = 'in'
def get_db_prep_lookup(self, value, connection):
params = self.lhs.output_field.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True)
if not params:
# TODO: check why this leads to circular import
from django.db.models.sql.datastructures import EmptyResultSet
raise EmptyResultSet
placeholder = '(' + ', '.join('%s' for p in params) + ')'
return (placeholder, params)
def process_rhs(self, qn, connection):
# Produce (placeholder, params) for the right-hand side of the IN lookup.
if self.rhs_is_direct_value():
# rhs should be an iterable, we use batch_process_rhs
# to prepare/transform those values
rhs = list(self.rhs)
if not rhs:
# An empty list of values is signalled by raising EmptyResultSet;
# the exception class is imported locally at the point of use.
from django.db.models.sql.datastructures import EmptyResultSet
raise EmptyResultSet
sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs)
# Join the per-value SQL fragments into one parenthesised list.
placeholder = '(' + ', '.join(sqls) + ')'
return (placeholder, sqls_params)
else:
# rhs is not a direct value: defer to the parent implementation.
return super(In, self).process_rhs(qn, connection)
def get_rhs_op(self, connection, rhs):
return 'IN %s' % rhs
@ -220,8 +268,10 @@ class In(BuiltinLookup):
max_in_list_size = connection.ops.max_in_list_size()
if self.rhs_is_direct_value() and (max_in_list_size and
len(self.rhs) > max_in_list_size):
rhs, rhs_params = self.process_rhs(qn, connection)
# This is a special case for Oracle which limits the number of elements
# which can appear in an 'IN' clause.
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.batch_process_rhs(qn, connection)
in_clause_elements = ['(']
params = []
for offset in xrange(0, len(rhs_params), max_in_list_size):
@ -229,11 +279,12 @@ class In(BuiltinLookup):
in_clause_elements.append(' OR ')
in_clause_elements.append('%s IN (' % lhs)
params.extend(lhs_params)
group_size = min(len(rhs_params) - offset, max_in_list_size)
param_group = ', '.join(repeat('%s', group_size))
sqls = rhs[offset: offset + max_in_list_size]
sqls_params = rhs_params[offset: offset + max_in_list_size]
param_group = ', '.join(sqls)
in_clause_elements.append(param_group)
in_clause_elements.append(')')
params.extend(rhs_params[offset: offset + max_in_list_size])
params.extend(sqls_params)
in_clause_elements.append(')')
return ''.join(in_clause_elements), params
else:
@ -252,10 +303,10 @@ class PatternLookup(BuiltinLookup):
# we need to add the % pattern match to the lookup by something like
# col LIKE othercol || '%%'
# So, for Python values we don't need any special pattern, but for
# SQL reference values we need the correct pattern added.
value = self.rhs
if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql')
or hasattr(value, '_as_sql')):
# SQL reference values or SQL transformations we need the correct
# pattern added.
if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
return connection.pattern_ops[self.lookup_name] % rhs
else:
return super(PatternLookup, self).get_rhs_op(connection, rhs)
@ -291,8 +342,20 @@ class Year(Between):
default_lookups['year'] = Year
class Range(Between):
class Range(BuiltinLookup):
# 'range' lookup, rendered as "BETWEEN <first> AND <second>".
lookup_name = 'range'
def get_rhs_op(self, connection, rhs):
# rhs is indexed as a two-item sequence of SQL fragments: rhs[0] goes
# before AND, rhs[1] after it.
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
def process_rhs(self, qn, connection):
# For direct values the (sqls, sqls_params) pair from batch_process_rhs
# is returned as-is, so get_rhs_op receives the list of SQL fragments.
if self.rhs_is_direct_value():
# rhs should be an iterable of 2 values, we use batch_process_rhs
# to prepare/transform those values
return self.batch_process_rhs(qn, connection)
else:
return super(Range, self).process_rhs(qn, connection)
default_lookups['range'] = Range

View File

@ -1111,18 +1111,21 @@ class Query(object):
def build_lookup(self, lookups, lhs, rhs):
lookups = lookups[:]
bilaterals = []
while lookups:
lookup = lookups[0]
if len(lookups) == 1:
final_lookup = lhs.get_lookup(lookup)
if final_lookup:
return final_lookup(lhs, rhs)
return final_lookup(lhs, rhs, bilaterals)
# We didn't find a lookup, so we are going to try get_transform
# + get_lookup('exact').
lookups.append('exact')
next = lhs.get_transform(lookup)
if next:
lhs = next(lhs, lookups)
if getattr(next, 'bilateral', False):
bilaterals.append((next, lookups))
else:
raise FieldError(
"Unsupported lookup '%s' for %s or join on the field not "

View File

@ -127,7 +127,7 @@ function ``ABS()`` to transform the value before comparison::
lhs, params = qn.compile(self.lhs)
return "ABS(%s)" % lhs, params
Next, lets register it for ``IntegerField``::
Next, let's register it for ``IntegerField``::
from django.db.models import IntegerField
IntegerField.register_lookup(AbsoluteValue)
@ -144,9 +144,7 @@ SQL::
SELECT ... WHERE ABS("experiments"."change") < 27
Subclasses of ``Transform`` usually only operate on the left-hand side of the
expression. Further lookups will work on the transformed value. Note that in
this case where there is no other lookup specified, Django interprets
Note that if no other lookup is specified, Django interprets
``change__abs=27`` as ``change__abs__exact=27``.
When looking for which lookups are allowable after the ``Transform`` has been
@ -197,7 +195,7 @@ Notice also that as both sides are used multiple times in the query the params
need to contain ``lhs_params`` and ``rhs_params`` multiple times.
The final query does the inversion (``27`` to ``-27``) directly in the
database. The reason for doing this is that if the self.rhs is something else
database. The reason for doing this is that if the ``self.rhs`` is something else
than a plain integer value (for example an ``F()`` reference) we can't do the
transformations in Python.
@ -208,6 +206,46 @@ transformations in Python.
want to add an index on ``abs(change)`` which would allow these queries to
be very efficient.
A bilateral transformer example
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The ``AbsoluteValue`` example we discussed previously is a transformation which
applies to the left-hand side of the lookup. There may be some cases where you
want the transformation to be applied to both the left-hand side and the
right-hand side. For instance, you may want to filter a queryset based on the
equality of the left-hand side and the right-hand side once both have been
passed through the same SQL function.
Let's examine a simple example of a case-insensitive transformation here. This
transformation isn't very useful in practice as Django already comes with a bunch
of built-in case-insensitive lookups, but it will be a nice demonstration of
bilateral transformations in a database-agnostic way.
We define an ``UpperCase`` transformer which uses the SQL function ``UPPER()`` to
transform the values before comparison. We define
:attr:`bilateral = True <django.db.models.Transform.bilateral>` to indicate that
this transformation should apply to both ``lhs`` and ``rhs``::
from django.db.models import Transform
class UpperCase(Transform):
lookup_name = 'upper'
bilateral = True
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return "UPPER(%s)" % lhs, params
Next, let's register it::
from django.db.models import CharField, TextField
CharField.register_lookup(UpperCase)
TextField.register_lookup(UpperCase)
Now, the queryset ``Author.objects.filter(name__upper="doe")`` will generate a
case-insensitive query like this::
SELECT ... WHERE UPPER("author"."name") = UPPER('doe')
Writing alternative implementations for existing lookups
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -129,6 +129,15 @@ Transform reference
This class follows the :ref:`Query Expression API <query-expression>`, which
implies that you can use ``<expression>__<transform1>__<transform2>``.
.. attribute:: bilateral
.. versionadded:: 1.8
A boolean indicating whether this transformation should apply to both
``lhs`` and ``rhs``. Bilateral transformations will be applied to ``rhs`` in
the same order as they appear in the lookup expression. By default it is set
to ``False``. For example usage, see :doc:`/howto/custom-lookups`.
.. attribute:: lhs
The left-hand side - what is being transformed. It must follow the

View File

@ -306,6 +306,11 @@ Models
* :doc:`Custom Lookups</howto/custom-lookups>` can now be registered using
a decorator pattern.
* The new :attr:`Transform.bilateral <django.db.models.Transform.bilateral>`
attribute allows creating bilateral transformations. These transformations
are applied to both ``lhs`` and ``rhs`` when used in a lookup expression,
providing opportunities for more sophisticated lookups.
Signals
^^^^^^^

View File

@ -17,7 +17,7 @@ class Div3Lookup(models.Lookup):
lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params)
return '%s %%%% 3 = %s' % (lhs, rhs), params
return '(%s) %%%% 3 = %s' % (lhs, rhs), params
def as_oracle(self, qn, connection):
lhs, params = self.process_lhs(qn, connection)
@ -31,12 +31,32 @@ class Div3Transform(models.Transform):
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return '%s %%%% 3' % (lhs,), lhs_params
return '(%s) %%%% 3' % lhs, lhs_params
def as_oracle(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return 'mod(%s, 3)' % lhs, lhs_params
class Div3BilateralTransform(Div3Transform):
# Inherits everything from Div3Transform and only switches on the
# bilateral flag.
bilateral = True
class Mult3BilateralTransform(models.Transform):
# Bilateral test transform registered under the name 'mult3'; renders its
# compiled operand as "3 * (<operand>)".
bilateral = True
lookup_name = 'mult3'
def as_sql(self, qn, connection):
# The operand is parenthesised so that the multiplication applies to the
# whole compiled expression.
lhs, lhs_params = qn.compile(self.lhs)
return '3 * (%s)' % lhs, lhs_params
class UpperBilateralTransform(models.Transform):
# Bilateral test transform registered under the name 'upper'; renders its
# compiled operand as "UPPER(<operand>)".
bilateral = True
lookup_name = 'upper'
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return 'UPPER(%s)' % lhs, lhs_params
class YearTransform(models.Transform):
lookup_name = 'year'
@ -225,10 +245,112 @@ class LookupTests(TestCase):
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[2, 4]),
[a2], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__gte=3),
[], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__range=(1, 2)),
[a1, a2, a4], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Div3Transform)
class BilateralTransformTests(TestCase):
# Tests for transforms flagged bilateral = True. Each test registers the
# transform(s) it needs and unregisters them in a finally block so the
# registration does not outlive the test.
def test_bilateral_upper(self):
# UPPER() is applied to both sides, so 'doe' matches 'Doe' and 'doe'
# but not 'Foo'.
models.CharField.register_lookup(UpperBilateralTransform)
try:
Author.objects.bulk_create([
Author(name='Doe'),
Author(name='doe'),
Author(name='Foo'),
])
self.assertQuerysetEqual(
Author.objects.filter(name__upper='doe'),
["<Author: Doe>", "<Author: doe>"], ordered=False)
finally:
models.CharField._unregister_lookup(UpperBilateralTransform)
def test_bilateral_inner_qs(self):
# A queryset on the right-hand side of a bilateral transform must be
# rejected with NotImplementedError.
models.CharField.register_lookup(UpperBilateralTransform)
try:
with self.assertRaises(NotImplementedError):
Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
finally:
models.CharField._unregister_lookup(UpperBilateralTransform)
def test_div3_bilateral_extract(self):
# With the bilateral div3 transform both sides are reduced modulo 3
# before comparison; ages 1, 2, 3, 4 give remainders 1, 2, 0, 1.
models.IntegerField.register_lookup(Div3BilateralTransform)
try:
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
# rhs 2 % 3 == 2: only age 2 matches.
self.assertQuerysetEqual(
baseqs.filter(age__div3=2),
[a2], lambda x: x)
# rhs 3 % 3 == 0: only a remainder of 0 is <= 0.
self.assertQuerysetEqual(
baseqs.filter(age__div3__lte=3),
[a3], lambda x: x)
# rhs values reduce to {0, 2}.
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x)
# rhs values reduce to {2, 1}.
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[2, 4]),
[a1, a2, a4], lambda x: x)
# rhs 3 % 3 == 0: every remainder is >= 0.
self.assertQuerysetEqual(
baseqs.filter(age__div3__gte=3),
[a1, a2, a3, a4], lambda x: x)
# rhs bounds stay (1, 2): remainders 1 and 2 match.
self.assertQuerysetEqual(
baseqs.filter(age__div3__range=(1, 2)),
[a1, a2, a4], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Div3BilateralTransform)
def test_bilateral_order(self):
# The rhs must receive the bilateral transforms in the same order as
# they appear in the lookup expression.
models.IntegerField.register_lookup(Mult3BilateralTransform)
models.IntegerField.register_lookup(Div3BilateralTransform)
try:
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(age__mult3__div3=42),
# mult3__div3 always leads to 0
[a1, a2, a3, a4], lambda x: x)
# div3 then mult3: rhs is 3 * (42 % 3) == 0, so only ages with a
# remainder of 0 match.
self.assertQuerysetEqual(
baseqs.filter(age__div3__mult3=42),
[a3], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Mult3BilateralTransform)
models.IntegerField._unregister_lookup(Div3BilateralTransform)
def test_bilateral_fexpr(self):
# Bilateral transforms must also be applied when the rhs is an F()
# expression rather than a plain value.
models.IntegerField.register_lookup(Mult3BilateralTransform)
try:
a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)
a2 = Author.objects.create(name='a2', age=2, average_rating=0.5)
a3 = Author.objects.create(name='a3', age=3, average_rating=1.5)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
# Both sides become 3 * age, so every row matches.
self.assertQuerysetEqual(
baseqs.filter(age__mult3=models.F('age')),
[a1, a2, a3, a4], lambda x: x)
self.assertQuerysetEqual(
# Same as age >= average_rating
baseqs.filter(age__mult3__gte=models.F('average_rating')),
[a2, a3], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Mult3BilateralTransform)
class YearLteTests(TestCase):
def setUp(self):
models.DateField.register_lookup(YearTransform)