Fixed #23493 -- Added bilateral attribute to Transform

This commit is contained in:
Thomas Chaumeny 2014-09-14 12:34:41 +02:00 committed by Anssi Kääriäinen
parent 6b39401baf
commit 00aa562884
6 changed files with 268 additions and 28 deletions

View File

@ -1,5 +1,4 @@
from copy import copy from copy import copy
from itertools import repeat
import inspect import inspect
from django.conf import settings from django.conf import settings
@ -7,6 +6,8 @@ from django.utils import timezone
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.six.moves import xrange from django.utils.six.moves import xrange
from .query_utils import QueryWrapper
class RegisterLookupMixin(object): class RegisterLookupMixin(object):
def _get_lookup(self, lookup_name): def _get_lookup(self, lookup_name):
@ -57,6 +58,9 @@ class RegisterLookupMixin(object):
class Transform(RegisterLookupMixin): class Transform(RegisterLookupMixin):
bilateral = False
def __init__(self, lhs, lookups): def __init__(self, lhs, lookups):
self.lhs = lhs self.lhs = lhs
self.init_lookups = lookups[:] self.init_lookups = lookups[:]
@ -78,9 +82,42 @@ class Transform(RegisterLookupMixin):
class Lookup(RegisterLookupMixin): class Lookup(RegisterLookupMixin):
lookup_name = None lookup_name = None
def __init__(self, lhs, rhs): def __init__(self, lhs, rhs, bilateral_transforms=None):
self.lhs, self.rhs = lhs, rhs self.lhs, self.rhs = lhs, rhs
self.rhs = self.get_prep_lookup() self.rhs = self.get_prep_lookup()
if bilateral_transforms is None:
bilateral_transforms = []
if bilateral_transforms:
# We should warn the user as soon as possible if he is trying to apply
# a bilateral transformation on a nested QuerySet: that won't work.
# We need to import QuerySet here so as to avoid circular
from django.db.models.query import QuerySet
if isinstance(rhs, QuerySet):
raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
self.bilateral_transforms = bilateral_transforms
def apply_bilateral_transforms(self, value):
for transform, lookups in self.bilateral_transforms:
value = transform(value, lookups)
return value
def batch_process_rhs(self, qn, connection, rhs=None):
if rhs is None:
rhs = self.rhs
if self.bilateral_transforms:
sqls, sqls_params = [], []
for p in rhs:
value = QueryWrapper('%s',
[self.lhs.output_field.get_db_prep_value(p, connection)])
value = self.apply_bilateral_transforms(value)
sql, sql_params = qn.compile(value)
sqls.append(sql)
sqls_params.extend(sql_params)
else:
params = self.lhs.output_field.get_db_prep_lookup(
self.lookup_name, rhs, connection, prepared=True)
sqls, sqls_params = ['%s'] * len(params), params
return sqls, sqls_params
def get_prep_lookup(self): def get_prep_lookup(self):
return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs) return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
@ -96,6 +133,13 @@ class Lookup(RegisterLookupMixin):
def process_rhs(self, qn, connection): def process_rhs(self, qn, connection):
value = self.rhs value = self.rhs
if self.bilateral_transforms:
if self.rhs_is_direct_value():
# Do not call get_db_prep_lookup here as the value will be
# transformed before being used for lookup
value = QueryWrapper("%s",
[self.lhs.output_field.get_db_prep_value(value, connection)])
value = self.apply_bilateral_transforms(value)
# Due to historical reasons there are a couple of different # Due to historical reasons there are a couple of different
# ways to produce sql here. get_compiler is likely a Query # ways to produce sql here. get_compiler is likely a Query
# instance, _as_sql QuerySet and as_sql just something with # instance, _as_sql QuerySet and as_sql just something with
@ -203,15 +247,19 @@ default_lookups['lte'] = LessThanOrEqual
class In(BuiltinLookup): class In(BuiltinLookup):
lookup_name = 'in' lookup_name = 'in'
def get_db_prep_lookup(self, value, connection): def process_rhs(self, qn, connection):
params = self.lhs.output_field.get_db_prep_lookup( if self.rhs_is_direct_value():
self.lookup_name, value, connection, prepared=True) # rhs should be an iterable, we use batch_process_rhs
if not params: # to prepare/transform those values
# TODO: check why this leads to circular import rhs = list(self.rhs)
if not rhs:
from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.datastructures import EmptyResultSet
raise EmptyResultSet raise EmptyResultSet
placeholder = '(' + ', '.join('%s' for p in params) + ')' sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs)
return (placeholder, params) placeholder = '(' + ', '.join(sqls) + ')'
return (placeholder, sqls_params)
else:
return super(In, self).process_rhs(qn, connection)
def get_rhs_op(self, connection, rhs): def get_rhs_op(self, connection, rhs):
return 'IN %s' % rhs return 'IN %s' % rhs
@ -220,8 +268,10 @@ class In(BuiltinLookup):
max_in_list_size = connection.ops.max_in_list_size() max_in_list_size = connection.ops.max_in_list_size()
if self.rhs_is_direct_value() and (max_in_list_size and if self.rhs_is_direct_value() and (max_in_list_size and
len(self.rhs) > max_in_list_size): len(self.rhs) > max_in_list_size):
rhs, rhs_params = self.process_rhs(qn, connection) # This is a special case for Oracle which limits the number of elements
# which can appear in an 'IN' clause.
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.batch_process_rhs(qn, connection)
in_clause_elements = ['('] in_clause_elements = ['(']
params = [] params = []
for offset in xrange(0, len(rhs_params), max_in_list_size): for offset in xrange(0, len(rhs_params), max_in_list_size):
@ -229,11 +279,12 @@ class In(BuiltinLookup):
in_clause_elements.append(' OR ') in_clause_elements.append(' OR ')
in_clause_elements.append('%s IN (' % lhs) in_clause_elements.append('%s IN (' % lhs)
params.extend(lhs_params) params.extend(lhs_params)
group_size = min(len(rhs_params) - offset, max_in_list_size) sqls = rhs[offset: offset + max_in_list_size]
param_group = ', '.join(repeat('%s', group_size)) sqls_params = rhs_params[offset: offset + max_in_list_size]
param_group = ', '.join(sqls)
in_clause_elements.append(param_group) in_clause_elements.append(param_group)
in_clause_elements.append(')') in_clause_elements.append(')')
params.extend(rhs_params[offset: offset + max_in_list_size]) params.extend(sqls_params)
in_clause_elements.append(')') in_clause_elements.append(')')
return ''.join(in_clause_elements), params return ''.join(in_clause_elements), params
else: else:
@ -252,10 +303,10 @@ class PatternLookup(BuiltinLookup):
# we need to add the % pattern match to the lookup by something like # we need to add the % pattern match to the lookup by something like
# col LIKE othercol || '%%' # col LIKE othercol || '%%'
# So, for Python values we don't need any special pattern, but for # So, for Python values we don't need any special pattern, but for
# SQL reference values we need the correct pattern added. # SQL reference values or SQL transformations we need the correct
value = self.rhs # pattern added.
if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql') if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
or hasattr(value, '_as_sql')): or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
return connection.pattern_ops[self.lookup_name] % rhs return connection.pattern_ops[self.lookup_name] % rhs
else: else:
return super(PatternLookup, self).get_rhs_op(connection, rhs) return super(PatternLookup, self).get_rhs_op(connection, rhs)
@ -291,8 +342,20 @@ class Year(Between):
default_lookups['year'] = Year default_lookups['year'] = Year
class Range(Between): class Range(BuiltinLookup):
lookup_name = 'range' lookup_name = 'range'
def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
def process_rhs(self, qn, connection):
if self.rhs_is_direct_value():
# rhs should be an iterable of 2 values, we use batch_process_rhs
# to prepare/transform those values
return self.batch_process_rhs(qn, connection)
else:
return super(Range, self).process_rhs(qn, connection)
default_lookups['range'] = Range default_lookups['range'] = Range

View File

@ -1111,18 +1111,21 @@ class Query(object):
def build_lookup(self, lookups, lhs, rhs): def build_lookup(self, lookups, lhs, rhs):
lookups = lookups[:] lookups = lookups[:]
bilaterals = []
while lookups: while lookups:
lookup = lookups[0] lookup = lookups[0]
if len(lookups) == 1: if len(lookups) == 1:
final_lookup = lhs.get_lookup(lookup) final_lookup = lhs.get_lookup(lookup)
if final_lookup: if final_lookup:
return final_lookup(lhs, rhs) return final_lookup(lhs, rhs, bilaterals)
# We didn't find a lookup, so we are going to try get_transform # We didn't find a lookup, so we are going to try get_transform
# + get_lookup('exact'). # + get_lookup('exact').
lookups.append('exact') lookups.append('exact')
next = lhs.get_transform(lookup) next = lhs.get_transform(lookup)
if next: if next:
lhs = next(lhs, lookups) lhs = next(lhs, lookups)
if getattr(next, 'bilateral', False):
bilaterals.append((next, lookups))
else: else:
raise FieldError( raise FieldError(
"Unsupported lookup '%s' for %s or join on the field not " "Unsupported lookup '%s' for %s or join on the field not "

View File

@ -127,7 +127,7 @@ function ``ABS()`` to transform the value before comparison::
lhs, params = qn.compile(self.lhs) lhs, params = qn.compile(self.lhs)
return "ABS(%s)" % lhs, params return "ABS(%s)" % lhs, params
Next, lets register it for ``IntegerField``:: Next, let's register it for ``IntegerField``::
from django.db.models import IntegerField from django.db.models import IntegerField
IntegerField.register_lookup(AbsoluteValue) IntegerField.register_lookup(AbsoluteValue)
@ -144,9 +144,7 @@ SQL::
SELECT ... WHERE ABS("experiments"."change") < 27 SELECT ... WHERE ABS("experiments"."change") < 27
Subclasses of ``Transform`` usually only operate on the left-hand side of the Note that in case there is no other lookup specified, Django interprets
expression. Further lookups will work on the transformed value. Note that in
this case where there is no other lookup specified, Django interprets
``change__abs=27`` as ``change__abs__exact=27``. ``change__abs=27`` as ``change__abs__exact=27``.
When looking for which lookups are allowable after the ``Transform`` has been When looking for which lookups are allowable after the ``Transform`` has been
@ -197,7 +195,7 @@ Notice also that as both sides are used multiple times in the query the params
need to contain ``lhs_params`` and ``rhs_params`` multiple times. need to contain ``lhs_params`` and ``rhs_params`` multiple times.
The final query does the inversion (``27`` to ``-27``) directly in the The final query does the inversion (``27`` to ``-27``) directly in the
database. The reason for doing this is that if the self.rhs is something else database. The reason for doing this is that if the ``self.rhs`` is something else
than a plain integer value (for example an ``F()`` reference) we can't do the than a plain integer value (for example an ``F()`` reference) we can't do the
transformations in Python. transformations in Python.
@ -208,6 +206,46 @@ transformations in Python.
want to add an index on ``abs(change)`` which would allow these queries to want to add an index on ``abs(change)`` which would allow these queries to
be very efficient. be very efficient.
A bilateral transformer example
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The ``AbsoluteValue`` example we discussed previously is a transformation which
applies to the left-hand side of the lookup. There may be some cases where you
want the transformation to be applied to both the left-hand side and the
right-hand side. For instance, if you want to filter a queryset based on the
equality of the left and right-hand side insensitively to some SQL function.
Let's examine the simple example of case-insensitive transformation here. This
transformation isn't very useful in practice as Django already comes with a bunch
of built-in case-insensitive lookups, but it will be a nice demonstration of
bilateral transformations in a database-agnostic way.
We define an ``UpperCase`` transformer which uses the SQL function ``UPPER()`` to
transform the values before comparison. We define
:attr:`bilateral = True <django.db.models.Transform.bilateral>` to indicate that
this transformation should apply to both ``lhs`` and ``rhs``::
from django.db.models import Transform
class UpperCase(Transform):
lookup_name = 'upper'
bilateral = True
def as_sql(self, qn, connection):
lhs, params = qn.compile(self.lhs)
return "UPPER(%s)" % lhs, params
Next, let's register it::
from django.db.models import CharField, TextField
CharField.register_lookup(UpperCase)
TextField.register_lookup(UpperCase)
Now, the queryset ``Author.objects.filter(name__upper="doe")`` will generate a case
insensitive query like this::
SELECT ... WHERE UPPER("author"."name") = UPPER('doe')
Writing alternative implementations for existing lookups Writing alternative implementations for existing lookups
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -129,6 +129,15 @@ Transform reference
This class follows the :ref:`Query Expression API <query-expression>`, which This class follows the :ref:`Query Expression API <query-expression>`, which
implies that you can use ``<expression>__<transform1>__<transform2>``. implies that you can use ``<expression>__<transform1>__<transform2>``.
.. attribute:: bilateral
.. versionadded:: 1.8
A boolean indicating whether this transformation should apply to both
``lhs`` and ``rhs``. Bilateral transformations will be applied to ``rhs`` in
the same order as they appear in the lookup expression. By default it is set
to ``False``. For example usage, see :doc:`/howto/custom-lookups`.
.. attribute:: lhs .. attribute:: lhs
The left-hand side - what is being transformed. It must follow the The left-hand side - what is being transformed. It must follow the

View File

@ -306,6 +306,11 @@ Models
* :doc:`Custom Lookups</howto/custom-lookups>` can now be registered using * :doc:`Custom Lookups</howto/custom-lookups>` can now be registered using
a decorator pattern. a decorator pattern.
* The new :attr:`Transform.bilateral <django.db.models.Transform.bilateral>`
attribute allows creating bilateral transformations. These transformations
are applied to both ``lhs`` and ``rhs`` when used in a lookup expression,
providing opportunities for more sophisticated lookups.
Signals Signals
^^^^^^^ ^^^^^^^

View File

@ -17,7 +17,7 @@ class Div3Lookup(models.Lookup):
lhs, params = self.process_lhs(qn, connection) lhs, params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection)
params.extend(rhs_params) params.extend(rhs_params)
return '%s %%%% 3 = %s' % (lhs, rhs), params return '(%s) %%%% 3 = %s' % (lhs, rhs), params
def as_oracle(self, qn, connection): def as_oracle(self, qn, connection):
lhs, params = self.process_lhs(qn, connection) lhs, params = self.process_lhs(qn, connection)
@ -31,12 +31,32 @@ class Div3Transform(models.Transform):
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs) lhs, lhs_params = qn.compile(self.lhs)
return '%s %%%% 3' % (lhs,), lhs_params return '(%s) %%%% 3' % lhs, lhs_params
def as_oracle(self, qn, connection): def as_oracle(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs) lhs, lhs_params = qn.compile(self.lhs)
return 'mod(%s, 3)' % lhs, lhs_params return 'mod(%s, 3)' % lhs, lhs_params
class Div3BilateralTransform(Div3Transform):
bilateral = True
class Mult3BilateralTransform(models.Transform):
bilateral = True
lookup_name = 'mult3'
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return '3 * (%s)' % lhs, lhs_params
class UpperBilateralTransform(models.Transform):
bilateral = True
lookup_name = 'upper'
def as_sql(self, qn, connection):
lhs, lhs_params = qn.compile(self.lhs)
return 'UPPER(%s)' % lhs, lhs_params
class YearTransform(models.Transform): class YearTransform(models.Transform):
lookup_name = 'year' lookup_name = 'year'
@ -225,10 +245,112 @@ class LookupTests(TestCase):
self.assertQuerysetEqual( self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[0, 2]), baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x) [a2, a3], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[2, 4]),
[a2], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__gte=3),
[], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__range=(1, 2)),
[a1, a2, a4], lambda x: x)
finally: finally:
models.IntegerField._unregister_lookup(Div3Transform) models.IntegerField._unregister_lookup(Div3Transform)
class BilateralTransformTests(TestCase):
def test_bilateral_upper(self):
models.CharField.register_lookup(UpperBilateralTransform)
try:
Author.objects.bulk_create([
Author(name='Doe'),
Author(name='doe'),
Author(name='Foo'),
])
self.assertQuerysetEqual(
Author.objects.filter(name__upper='doe'),
["<Author: Doe>", "<Author: doe>"], ordered=False)
finally:
models.CharField._unregister_lookup(UpperBilateralTransform)
def test_bilateral_inner_qs(self):
models.CharField.register_lookup(UpperBilateralTransform)
try:
with self.assertRaises(NotImplementedError):
Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
finally:
models.CharField._unregister_lookup(UpperBilateralTransform)
def test_div3_bilateral_extract(self):
models.IntegerField.register_lookup(Div3BilateralTransform)
try:
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(age__div3=2),
[a2], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__lte=3),
[a3], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[0, 2]),
[a2, a3], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__in=[2, 4]),
[a1, a2, a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__gte=3),
[a1, a2, a3, a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__range=(1, 2)),
[a1, a2, a4], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Div3BilateralTransform)
def test_bilateral_order(self):
models.IntegerField.register_lookup(Mult3BilateralTransform)
models.IntegerField.register_lookup(Div3BilateralTransform)
try:
a1 = Author.objects.create(name='a1', age=1)
a2 = Author.objects.create(name='a2', age=2)
a3 = Author.objects.create(name='a3', age=3)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(age__mult3__div3=42),
# mult3__div3 always leads to 0
[a1, a2, a3, a4], lambda x: x)
self.assertQuerysetEqual(
baseqs.filter(age__div3__mult3=42),
[a3], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Mult3BilateralTransform)
models.IntegerField._unregister_lookup(Div3BilateralTransform)
def test_bilateral_fexpr(self):
models.IntegerField.register_lookup(Mult3BilateralTransform)
try:
a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)
a2 = Author.objects.create(name='a2', age=2, average_rating=0.5)
a3 = Author.objects.create(name='a3', age=3, average_rating=1.5)
a4 = Author.objects.create(name='a4', age=4)
baseqs = Author.objects.order_by('name')
self.assertQuerysetEqual(
baseqs.filter(age__mult3=models.F('age')),
[a1, a2, a3, a4], lambda x: x)
self.assertQuerysetEqual(
# Same as age >= average_rating
baseqs.filter(age__mult3__gte=models.F('average_rating')),
[a2, a3], lambda x: x)
finally:
models.IntegerField._unregister_lookup(Mult3BilateralTransform)
class YearLteTests(TestCase): class YearLteTests(TestCase):
def setUp(self): def setUp(self):
models.DateField.register_lookup(YearTransform) models.DateField.register_lookup(YearTransform)