mirror of https://github.com/django/django.git
Fixed #23493 -- Added bilateral attribute to Transform
This commit is contained in:
parent
6b39401baf
commit
00aa562884
|
@ -1,5 +1,4 @@
|
|||
from copy import copy
|
||||
from itertools import repeat
|
||||
import inspect
|
||||
|
||||
from django.conf import settings
|
||||
|
@ -7,6 +6,8 @@ from django.utils import timezone
|
|||
from django.utils.functional import cached_property
|
||||
from django.utils.six.moves import xrange
|
||||
|
||||
from .query_utils import QueryWrapper
|
||||
|
||||
|
||||
class RegisterLookupMixin(object):
|
||||
def _get_lookup(self, lookup_name):
|
||||
|
@ -57,6 +58,9 @@ class RegisterLookupMixin(object):
|
|||
|
||||
|
||||
class Transform(RegisterLookupMixin):
|
||||
|
||||
bilateral = False
|
||||
|
||||
def __init__(self, lhs, lookups):
|
||||
self.lhs = lhs
|
||||
self.init_lookups = lookups[:]
|
||||
|
@ -78,9 +82,42 @@ class Transform(RegisterLookupMixin):
|
|||
class Lookup(RegisterLookupMixin):
|
||||
lookup_name = None
|
||||
|
||||
def __init__(self, lhs, rhs):
|
||||
def __init__(self, lhs, rhs, bilateral_transforms=None):
|
||||
self.lhs, self.rhs = lhs, rhs
|
||||
self.rhs = self.get_prep_lookup()
|
||||
if bilateral_transforms is None:
|
||||
bilateral_transforms = []
|
||||
if bilateral_transforms:
|
||||
# We should warn the user as soon as possible if he is trying to apply
|
||||
# a bilateral transformation on a nested QuerySet: that won't work.
|
||||
# We need to import QuerySet here so as to avoid circular
|
||||
from django.db.models.query import QuerySet
|
||||
if isinstance(rhs, QuerySet):
|
||||
raise NotImplementedError("Bilateral transformations on nested querysets are not supported.")
|
||||
self.bilateral_transforms = bilateral_transforms
|
||||
|
||||
def apply_bilateral_transforms(self, value):
|
||||
for transform, lookups in self.bilateral_transforms:
|
||||
value = transform(value, lookups)
|
||||
return value
|
||||
|
||||
def batch_process_rhs(self, qn, connection, rhs=None):
|
||||
if rhs is None:
|
||||
rhs = self.rhs
|
||||
if self.bilateral_transforms:
|
||||
sqls, sqls_params = [], []
|
||||
for p in rhs:
|
||||
value = QueryWrapper('%s',
|
||||
[self.lhs.output_field.get_db_prep_value(p, connection)])
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
sql, sql_params = qn.compile(value)
|
||||
sqls.append(sql)
|
||||
sqls_params.extend(sql_params)
|
||||
else:
|
||||
params = self.lhs.output_field.get_db_prep_lookup(
|
||||
self.lookup_name, rhs, connection, prepared=True)
|
||||
sqls, sqls_params = ['%s'] * len(params), params
|
||||
return sqls, sqls_params
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs)
|
||||
|
@ -96,6 +133,13 @@ class Lookup(RegisterLookupMixin):
|
|||
|
||||
def process_rhs(self, qn, connection):
|
||||
value = self.rhs
|
||||
if self.bilateral_transforms:
|
||||
if self.rhs_is_direct_value():
|
||||
# Do not call get_db_prep_lookup here as the value will be
|
||||
# transformed before being used for lookup
|
||||
value = QueryWrapper("%s",
|
||||
[self.lhs.output_field.get_db_prep_value(value, connection)])
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
# Due to historical reasons there are a couple of different
|
||||
# ways to produce sql here. get_compiler is likely a Query
|
||||
# instance, _as_sql QuerySet and as_sql just something with
|
||||
|
@ -203,15 +247,19 @@ default_lookups['lte'] = LessThanOrEqual
|
|||
class In(BuiltinLookup):
|
||||
lookup_name = 'in'
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
params = self.lhs.output_field.get_db_prep_lookup(
|
||||
self.lookup_name, value, connection, prepared=True)
|
||||
if not params:
|
||||
# TODO: check why this leads to circular import
|
||||
from django.db.models.sql.datastructures import EmptyResultSet
|
||||
raise EmptyResultSet
|
||||
placeholder = '(' + ', '.join('%s' for p in params) + ')'
|
||||
return (placeholder, params)
|
||||
def process_rhs(self, qn, connection):
|
||||
if self.rhs_is_direct_value():
|
||||
# rhs should be an iterable, we use batch_process_rhs
|
||||
# to prepare/transform those values
|
||||
rhs = list(self.rhs)
|
||||
if not rhs:
|
||||
from django.db.models.sql.datastructures import EmptyResultSet
|
||||
raise EmptyResultSet
|
||||
sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs)
|
||||
placeholder = '(' + ', '.join(sqls) + ')'
|
||||
return (placeholder, sqls_params)
|
||||
else:
|
||||
return super(In, self).process_rhs(qn, connection)
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return 'IN %s' % rhs
|
||||
|
@ -220,8 +268,10 @@ class In(BuiltinLookup):
|
|||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
if self.rhs_is_direct_value() and (max_in_list_size and
|
||||
len(self.rhs) > max_in_list_size):
|
||||
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||
# This is a special case for Oracle which limits the number of elements
|
||||
# which can appear in an 'IN' clause.
|
||||
lhs, lhs_params = self.process_lhs(qn, connection)
|
||||
rhs, rhs_params = self.batch_process_rhs(qn, connection)
|
||||
in_clause_elements = ['(']
|
||||
params = []
|
||||
for offset in xrange(0, len(rhs_params), max_in_list_size):
|
||||
|
@ -229,11 +279,12 @@ class In(BuiltinLookup):
|
|||
in_clause_elements.append(' OR ')
|
||||
in_clause_elements.append('%s IN (' % lhs)
|
||||
params.extend(lhs_params)
|
||||
group_size = min(len(rhs_params) - offset, max_in_list_size)
|
||||
param_group = ', '.join(repeat('%s', group_size))
|
||||
sqls = rhs[offset: offset + max_in_list_size]
|
||||
sqls_params = rhs_params[offset: offset + max_in_list_size]
|
||||
param_group = ', '.join(sqls)
|
||||
in_clause_elements.append(param_group)
|
||||
in_clause_elements.append(')')
|
||||
params.extend(rhs_params[offset: offset + max_in_list_size])
|
||||
params.extend(sqls_params)
|
||||
in_clause_elements.append(')')
|
||||
return ''.join(in_clause_elements), params
|
||||
else:
|
||||
|
@ -252,10 +303,10 @@ class PatternLookup(BuiltinLookup):
|
|||
# we need to add the % pattern match to the lookup by something like
|
||||
# col LIKE othercol || '%%'
|
||||
# So, for Python values we don't need any special pattern, but for
|
||||
# SQL reference values we need the correct pattern added.
|
||||
value = self.rhs
|
||||
if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql')
|
||||
or hasattr(value, '_as_sql')):
|
||||
# SQL reference values or SQL transformations we need the correct
|
||||
# pattern added.
|
||||
if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql')
|
||||
or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms):
|
||||
return connection.pattern_ops[self.lookup_name] % rhs
|
||||
else:
|
||||
return super(PatternLookup, self).get_rhs_op(connection, rhs)
|
||||
|
@ -291,8 +342,20 @@ class Year(Between):
|
|||
default_lookups['year'] = Year
|
||||
|
||||
|
||||
class Range(Between):
|
||||
class Range(BuiltinLookup):
|
||||
lookup_name = 'range'
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
if self.rhs_is_direct_value():
|
||||
# rhs should be an iterable of 2 values, we use batch_process_rhs
|
||||
# to prepare/transform those values
|
||||
return self.batch_process_rhs(qn, connection)
|
||||
else:
|
||||
return super(Range, self).process_rhs(qn, connection)
|
||||
|
||||
default_lookups['range'] = Range
|
||||
|
||||
|
||||
|
|
|
@ -1111,18 +1111,21 @@ class Query(object):
|
|||
|
||||
def build_lookup(self, lookups, lhs, rhs):
|
||||
lookups = lookups[:]
|
||||
bilaterals = []
|
||||
while lookups:
|
||||
lookup = lookups[0]
|
||||
if len(lookups) == 1:
|
||||
final_lookup = lhs.get_lookup(lookup)
|
||||
if final_lookup:
|
||||
return final_lookup(lhs, rhs)
|
||||
return final_lookup(lhs, rhs, bilaterals)
|
||||
# We didn't find a lookup, so we are going to try get_transform
|
||||
# + get_lookup('exact').
|
||||
lookups.append('exact')
|
||||
next = lhs.get_transform(lookup)
|
||||
if next:
|
||||
lhs = next(lhs, lookups)
|
||||
if getattr(next, 'bilateral', False):
|
||||
bilaterals.append((next, lookups))
|
||||
else:
|
||||
raise FieldError(
|
||||
"Unsupported lookup '%s' for %s or join on the field not "
|
||||
|
|
|
@ -127,7 +127,7 @@ function ``ABS()`` to transform the value before comparison::
|
|||
lhs, params = qn.compile(self.lhs)
|
||||
return "ABS(%s)" % lhs, params
|
||||
|
||||
Next, lets register it for ``IntegerField``::
|
||||
Next, let's register it for ``IntegerField``::
|
||||
|
||||
from django.db.models import IntegerField
|
||||
IntegerField.register_lookup(AbsoluteValue)
|
||||
|
@ -144,9 +144,7 @@ SQL::
|
|||
|
||||
SELECT ... WHERE ABS("experiments"."change") < 27
|
||||
|
||||
Subclasses of ``Transform`` usually only operate on the left-hand side of the
|
||||
expression. Further lookups will work on the transformed value. Note that in
|
||||
this case where there is no other lookup specified, Django interprets
|
||||
Note that in case there is no other lookup specified, Django interprets
|
||||
``change__abs=27`` as ``change__abs__exact=27``.
|
||||
|
||||
When looking for which lookups are allowable after the ``Transform`` has been
|
||||
|
@ -197,7 +195,7 @@ Notice also that as both sides are used multiple times in the query the params
|
|||
need to contain ``lhs_params`` and ``rhs_params`` multiple times.
|
||||
|
||||
The final query does the inversion (``27`` to ``-27``) directly in the
|
||||
database. The reason for doing this is that if the self.rhs is something else
|
||||
database. The reason for doing this is that if the ``self.rhs`` is something else
|
||||
than a plain integer value (for example an ``F()`` reference) we can't do the
|
||||
transformations in Python.
|
||||
|
||||
|
@ -208,6 +206,46 @@ transformations in Python.
|
|||
want to add an index on ``abs(change)`` which would allow these queries to
|
||||
be very efficient.
|
||||
|
||||
A bilateral transformer example
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The ``AbsoluteValue`` example we discussed previously is a transformation which
|
||||
applies to the left-hand side of the lookup. There may be some cases where you
|
||||
want the transformation to be applied to both the left-hand side and the
|
||||
right-hand side. For instance, if you want to filter a queryset based on the
|
||||
equality of the left and right-hand side insensitively to some SQL function.
|
||||
|
||||
Let's examine the simple example of case-insensitive transformation here. This
|
||||
transformation isn't very useful in practice as Django already comes with a bunch
|
||||
of built-in case-insensitive lookups, but it will be a nice demonstration of
|
||||
bilateral transformations in a database-agnostic way.
|
||||
|
||||
We define an ``UpperCase`` transformer which uses the SQL function ``UPPER()`` to
|
||||
transform the values before comparison. We define
|
||||
:attr:`bilateral = True <django.db.models.Transform.bilateral>` to indicate that
|
||||
this transformation should apply to both ``lhs`` and ``rhs``::
|
||||
|
||||
from django.db.models import Transform
|
||||
|
||||
class UpperCase(Transform):
|
||||
lookup_name = 'upper'
|
||||
bilateral = True
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, params = qn.compile(self.lhs)
|
||||
return "UPPER(%s)" % lhs, params
|
||||
|
||||
Next, let's register it::
|
||||
|
||||
from django.db.models import CharField, TextField
|
||||
CharField.register_lookup(UpperCase)
|
||||
TextField.register_lookup(UpperCase)
|
||||
|
||||
Now, the queryset ``Author.objects.filter(name__upper="doe")`` will generate a case
|
||||
insensitive query like this::
|
||||
|
||||
SELECT ... WHERE UPPER("author"."name") = UPPER('doe')
|
||||
|
||||
Writing alternative implementations for existing lookups
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -129,6 +129,15 @@ Transform reference
|
|||
This class follows the :ref:`Query Expression API <query-expression>`, which
|
||||
implies that you can use ``<expression>__<transform1>__<transform2>``.
|
||||
|
||||
.. attribute:: bilateral
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
A boolean indicating whether this transformation should apply to both
|
||||
``lhs`` and ``rhs``. Bilateral transformations will be applied to ``rhs`` in
|
||||
the same order as they appear in the lookup expression. By default it is set
|
||||
to ``False``. For example usage, see :doc:`/howto/custom-lookups`.
|
||||
|
||||
.. attribute:: lhs
|
||||
|
||||
The left-hand side - what is being transformed. It must follow the
|
||||
|
|
|
@ -306,6 +306,11 @@ Models
|
|||
* :doc:`Custom Lookups</howto/custom-lookups>` can now be registered using
|
||||
a decorator pattern.
|
||||
|
||||
* The new :attr:`Transform.bilateral <django.db.models.Transform.bilateral>`
|
||||
attribute allows creating bilateral transformations. These transformations
|
||||
are applied to both ``lhs`` and ``rhs`` when used in a lookup expression,
|
||||
providing opportunities for more sophisticated lookups.
|
||||
|
||||
Signals
|
||||
^^^^^^^
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ class Div3Lookup(models.Lookup):
|
|||
lhs, params = self.process_lhs(qn, connection)
|
||||
rhs, rhs_params = self.process_rhs(qn, connection)
|
||||
params.extend(rhs_params)
|
||||
return '%s %%%% 3 = %s' % (lhs, rhs), params
|
||||
return '(%s) %%%% 3 = %s' % (lhs, rhs), params
|
||||
|
||||
def as_oracle(self, qn, connection):
|
||||
lhs, params = self.process_lhs(qn, connection)
|
||||
|
@ -31,12 +31,32 @@ class Div3Transform(models.Transform):
|
|||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, lhs_params = qn.compile(self.lhs)
|
||||
return '%s %%%% 3' % (lhs,), lhs_params
|
||||
return '(%s) %%%% 3' % lhs, lhs_params
|
||||
|
||||
def as_oracle(self, qn, connection):
|
||||
lhs, lhs_params = qn.compile(self.lhs)
|
||||
return 'mod(%s, 3)' % lhs, lhs_params
|
||||
|
||||
class Div3BilateralTransform(Div3Transform):
|
||||
bilateral = True
|
||||
|
||||
|
||||
class Mult3BilateralTransform(models.Transform):
|
||||
bilateral = True
|
||||
lookup_name = 'mult3'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, lhs_params = qn.compile(self.lhs)
|
||||
return '3 * (%s)' % lhs, lhs_params
|
||||
|
||||
class UpperBilateralTransform(models.Transform):
|
||||
bilateral = True
|
||||
lookup_name = 'upper'
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
lhs, lhs_params = qn.compile(self.lhs)
|
||||
return 'UPPER(%s)' % lhs, lhs_params
|
||||
|
||||
|
||||
class YearTransform(models.Transform):
|
||||
lookup_name = 'year'
|
||||
|
@ -225,10 +245,112 @@ class LookupTests(TestCase):
|
|||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__in=[0, 2]),
|
||||
[a2, a3], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__in=[2, 4]),
|
||||
[a2], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__gte=3),
|
||||
[], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__range=(1, 2)),
|
||||
[a1, a2, a4], lambda x: x)
|
||||
finally:
|
||||
models.IntegerField._unregister_lookup(Div3Transform)
|
||||
|
||||
|
||||
class BilateralTransformTests(TestCase):
|
||||
|
||||
def test_bilateral_upper(self):
|
||||
models.CharField.register_lookup(UpperBilateralTransform)
|
||||
try:
|
||||
Author.objects.bulk_create([
|
||||
Author(name='Doe'),
|
||||
Author(name='doe'),
|
||||
Author(name='Foo'),
|
||||
])
|
||||
self.assertQuerysetEqual(
|
||||
Author.objects.filter(name__upper='doe'),
|
||||
["<Author: Doe>", "<Author: doe>"], ordered=False)
|
||||
finally:
|
||||
models.CharField._unregister_lookup(UpperBilateralTransform)
|
||||
|
||||
def test_bilateral_inner_qs(self):
|
||||
models.CharField.register_lookup(UpperBilateralTransform)
|
||||
try:
|
||||
with self.assertRaises(NotImplementedError):
|
||||
Author.objects.filter(name__upper__in=Author.objects.values_list('name'))
|
||||
finally:
|
||||
models.CharField._unregister_lookup(UpperBilateralTransform)
|
||||
|
||||
def test_div3_bilateral_extract(self):
|
||||
models.IntegerField.register_lookup(Div3BilateralTransform)
|
||||
try:
|
||||
a1 = Author.objects.create(name='a1', age=1)
|
||||
a2 = Author.objects.create(name='a2', age=2)
|
||||
a3 = Author.objects.create(name='a3', age=3)
|
||||
a4 = Author.objects.create(name='a4', age=4)
|
||||
baseqs = Author.objects.order_by('name')
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3=2),
|
||||
[a2], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__lte=3),
|
||||
[a3], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__in=[0, 2]),
|
||||
[a2, a3], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__in=[2, 4]),
|
||||
[a1, a2, a4], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__gte=3),
|
||||
[a1, a2, a3, a4], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__range=(1, 2)),
|
||||
[a1, a2, a4], lambda x: x)
|
||||
finally:
|
||||
models.IntegerField._unregister_lookup(Div3BilateralTransform)
|
||||
|
||||
def test_bilateral_order(self):
|
||||
models.IntegerField.register_lookup(Mult3BilateralTransform)
|
||||
models.IntegerField.register_lookup(Div3BilateralTransform)
|
||||
try:
|
||||
a1 = Author.objects.create(name='a1', age=1)
|
||||
a2 = Author.objects.create(name='a2', age=2)
|
||||
a3 = Author.objects.create(name='a3', age=3)
|
||||
a4 = Author.objects.create(name='a4', age=4)
|
||||
baseqs = Author.objects.order_by('name')
|
||||
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__mult3__div3=42),
|
||||
# mult3__div3 always leads to 0
|
||||
[a1, a2, a3, a4], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__div3__mult3=42),
|
||||
[a3], lambda x: x)
|
||||
finally:
|
||||
models.IntegerField._unregister_lookup(Mult3BilateralTransform)
|
||||
models.IntegerField._unregister_lookup(Div3BilateralTransform)
|
||||
|
||||
def test_bilateral_fexpr(self):
|
||||
models.IntegerField.register_lookup(Mult3BilateralTransform)
|
||||
try:
|
||||
a1 = Author.objects.create(name='a1', age=1, average_rating=3.2)
|
||||
a2 = Author.objects.create(name='a2', age=2, average_rating=0.5)
|
||||
a3 = Author.objects.create(name='a3', age=3, average_rating=1.5)
|
||||
a4 = Author.objects.create(name='a4', age=4)
|
||||
baseqs = Author.objects.order_by('name')
|
||||
self.assertQuerysetEqual(
|
||||
baseqs.filter(age__mult3=models.F('age')),
|
||||
[a1, a2, a3, a4], lambda x: x)
|
||||
self.assertQuerysetEqual(
|
||||
# Same as age >= average_rating
|
||||
baseqs.filter(age__mult3__gte=models.F('average_rating')),
|
||||
[a2, a3], lambda x: x)
|
||||
finally:
|
||||
models.IntegerField._unregister_lookup(Mult3BilateralTransform)
|
||||
|
||||
|
||||
class YearLteTests(TestCase):
|
||||
def setUp(self):
|
||||
models.DateField.register_lookup(YearTransform)
|
||||
|
|
Loading…
Reference in New Issue