diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 66bdde54b4..abb5645147 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -1,5 +1,4 @@ from copy import copy -from itertools import repeat import inspect from django.conf import settings @@ -7,6 +6,8 @@ from django.utils import timezone from django.utils.functional import cached_property from django.utils.six.moves import xrange +from .query_utils import QueryWrapper + class RegisterLookupMixin(object): def _get_lookup(self, lookup_name): @@ -57,6 +58,9 @@ class RegisterLookupMixin(object): class Transform(RegisterLookupMixin): + + bilateral = False + def __init__(self, lhs, lookups): self.lhs = lhs self.init_lookups = lookups[:] @@ -78,9 +82,42 @@ class Transform(RegisterLookupMixin): class Lookup(RegisterLookupMixin): lookup_name = None - def __init__(self, lhs, rhs): + def __init__(self, lhs, rhs, bilateral_transforms=None): self.lhs, self.rhs = lhs, rhs self.rhs = self.get_prep_lookup() + if bilateral_transforms is None: + bilateral_transforms = [] + if bilateral_transforms: + # We should warn the user as soon as possible if he is trying to apply + # a bilateral transformation on a nested QuerySet: that won't work. + # We need to import QuerySet here so as to avoid circular + from django.db.models.query import QuerySet + if isinstance(rhs, QuerySet): + raise NotImplementedError("Bilateral transformations on nested querysets are not supported.") + self.bilateral_transforms = bilateral_transforms + + def apply_bilateral_transforms(self, value): + for transform, lookups in self.bilateral_transforms: + value = transform(value, lookups) + return value + + def batch_process_rhs(self, qn, connection, rhs=None): + if rhs is None: + rhs = self.rhs + if self.bilateral_transforms: + sqls, sqls_params = [], [] + for p in rhs: + value = QueryWrapper('%s', + [self.lhs.output_field.get_db_prep_value(p, connection)]) + value = self.apply_bilateral_transforms(value) + sql, sql_params = qn.compile(value) + sqls.append(sql) + sqls_params.extend(sql_params) + else: + params = self.lhs.output_field.get_db_prep_lookup( + self.lookup_name, rhs, connection, prepared=True) + sqls, sqls_params = ['%s'] * len(params), params + return sqls, sqls_params def get_prep_lookup(self): return self.lhs.output_field.get_prep_lookup(self.lookup_name, self.rhs) @@ -96,6 +133,13 @@ class Lookup(RegisterLookupMixin): def process_rhs(self, qn, connection): value = self.rhs + if self.bilateral_transforms: + if self.rhs_is_direct_value(): + # Do not call get_db_prep_lookup here as the value will be + # transformed before being used for lookup + value = QueryWrapper("%s", + [self.lhs.output_field.get_db_prep_value(value, connection)]) + value = self.apply_bilateral_transforms(value) # Due to historical reasons there are a couple of different # ways to produce sql here. get_compiler is likely a Query # instance, _as_sql QuerySet and as_sql just something with @@ -203,15 +247,19 @@ default_lookups['lte'] = LessThanOrEqual class In(BuiltinLookup): lookup_name = 'in' - def get_db_prep_lookup(self, value, connection): - params = self.lhs.output_field.get_db_prep_lookup( - self.lookup_name, value, connection, prepared=True) - if not params: - # TODO: check why this leads to circular import - from django.db.models.sql.datastructures import EmptyResultSet - raise EmptyResultSet - placeholder = '(' + ', '.join('%s' for p in params) + ')' - return (placeholder, params) + def process_rhs(self, qn, connection): + if self.rhs_is_direct_value(): + # rhs should be an iterable, we use batch_process_rhs + # to prepare/transform those values + rhs = list(self.rhs) + if not rhs: + from django.db.models.sql.datastructures import EmptyResultSet + raise EmptyResultSet + sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs) + placeholder = '(' + ', '.join(sqls) + ')' + return (placeholder, sqls_params) + else: + return super(In, self).process_rhs(qn, connection) def get_rhs_op(self, connection, rhs): return 'IN %s' % rhs @@ -220,8 +268,10 @@ class In(BuiltinLookup): max_in_list_size = connection.ops.max_in_list_size() if self.rhs_is_direct_value() and (max_in_list_size and len(self.rhs) > max_in_list_size): - rhs, rhs_params = self.process_rhs(qn, connection) + # This is a special case for Oracle which limits the number of elements + # which can appear in an 'IN' clause. lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.batch_process_rhs(qn, connection) in_clause_elements = ['('] params = [] for offset in xrange(0, len(rhs_params), max_in_list_size): @@ -229,11 +279,12 @@ class In(BuiltinLookup): in_clause_elements.append(' OR ') in_clause_elements.append('%s IN (' % lhs) params.extend(lhs_params) - group_size = min(len(rhs_params) - offset, max_in_list_size) - param_group = ', '.join(repeat('%s', group_size)) + sqls = rhs[offset: offset + max_in_list_size] + sqls_params = rhs_params[offset: offset + max_in_list_size] + param_group = ', '.join(sqls) in_clause_elements.append(param_group) in_clause_elements.append(')') - params.extend(rhs_params[offset: offset + max_in_list_size]) + params.extend(sqls_params) in_clause_elements.append(')') return ''.join(in_clause_elements), params else: @@ -252,10 +303,10 @@ class PatternLookup(BuiltinLookup): # we need to add the % pattern match to the lookup by something like # col LIKE othercol || '%%' # So, for Python values we don't need any special pattern, but for - # SQL reference values we need the correct pattern added. - value = self.rhs - if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql') - or hasattr(value, '_as_sql')): + # SQL reference values or SQL transformations we need the correct + # pattern added. + if (hasattr(self.rhs, 'get_compiler') or hasattr(self.rhs, 'as_sql') + or hasattr(self.rhs, '_as_sql') or self.bilateral_transforms): return connection.pattern_ops[self.lookup_name] % rhs else: return super(PatternLookup, self).get_rhs_op(connection, rhs) @@ -291,8 +342,20 @@ class Year(Between): default_lookups['year'] = Year -class Range(Between): +class Range(BuiltinLookup): lookup_name = 'range' + + def get_rhs_op(self, connection, rhs): + return "BETWEEN %s AND %s" % (rhs[0], rhs[1]) + + def process_rhs(self, qn, connection): + if self.rhs_is_direct_value(): + # rhs should be an iterable of 2 values, we use batch_process_rhs + # to prepare/transform those values + return self.batch_process_rhs(qn, connection) + else: + return super(Range, self).process_rhs(qn, connection) + default_lookups['range'] = Range diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 08af1fb008..b6690e4526 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1111,18 +1111,21 @@ class Query(object): def build_lookup(self, lookups, lhs, rhs): lookups = lookups[:] + bilaterals = [] while lookups: lookup = lookups[0] if len(lookups) == 1: final_lookup = lhs.get_lookup(lookup) if final_lookup: - return final_lookup(lhs, rhs) + return final_lookup(lhs, rhs, bilaterals) # We didn't find a lookup, so we are going to try get_transform # + get_lookup('exact'). lookups.append('exact') next = lhs.get_transform(lookup) if next: lhs = next(lhs, lookups) + if getattr(next, 'bilateral', False): + bilaterals.append((next, lookups)) else: raise FieldError( "Unsupported lookup '%s' for %s or join on the field not " diff --git a/docs/howto/custom-lookups.txt b/docs/howto/custom-lookups.txt index 820a2ef574..d3ed726ba3 100644 --- a/docs/howto/custom-lookups.txt +++ b/docs/howto/custom-lookups.txt @@ -127,7 +127,7 @@ function ``ABS()`` to transform the value before comparison:: lhs, params = qn.compile(self.lhs) return "ABS(%s)" % lhs, params -Next, lets register it for ``IntegerField``:: +Next, let's register it for ``IntegerField``:: from django.db.models import IntegerField IntegerField.register_lookup(AbsoluteValue) @@ -144,9 +144,7 @@ SQL:: SELECT ... WHERE ABS("experiments"."change") < 27 -Subclasses of ``Transform`` usually only operate on the left-hand side of the -expression. Further lookups will work on the transformed value. Note that in -this case where there is no other lookup specified, Django interprets +Note that in case there is no other lookup specified, Django interprets ``change__abs=27`` as ``change__abs__exact=27``. When looking for which lookups are allowable after the ``Transform`` has been @@ -197,7 +195,7 @@ Notice also that as both sides are used multiple times in the query the params need to contain ``lhs_params`` and ``rhs_params`` multiple times. The final query does the inversion (``27`` to ``-27``) directly in the -database. The reason for doing this is that if the self.rhs is something else +database. The reason for doing this is that if the ``self.rhs`` is something else than a plain integer value (for example an ``F()`` reference) we can't do the transformations in Python. @@ -208,6 +206,46 @@ transformations in Python. want to add an index on ``abs(change)`` which would allow these queries to be very efficient. +A bilateral transformer example +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``AbsoluteValue`` example we discussed previously is a transformation which +applies to the left-hand side of the lookup. There may be some cases where you +want the transformation to be applied to both the left-hand side and the +right-hand side. For instance, if you want to filter a queryset based on the +equality of the left and right-hand side insensitively to some SQL function. + +Let's examine the simple example of case-insensitive transformation here. This +transformation isn't very useful in practice as Django already comes with a bunch +of built-in case-insensitive lookups, but it will be a nice demonstration of +bilateral transformations in a database-agnostic way. + +We define an ``UpperCase`` transformer which uses the SQL function ``UPPER()`` to +transform the values before comparison. We define +:attr:`bilateral = True ` to indicate that +this transformation should apply to both ``lhs`` and ``rhs``:: + + from django.db.models import Transform + + class UpperCase(Transform): + lookup_name = 'upper' + bilateral = True + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return "UPPER(%s)" % lhs, params + +Next, let's register it:: + + from django.db.models import CharField, TextField + CharField.register_lookup(UpperCase) + TextField.register_lookup(UpperCase) + +Now, the queryset ``Author.objects.filter(name__upper="doe")`` will generate a case +insensitive query like this:: + + SELECT ... WHERE UPPER("author"."name") = UPPER('doe') + Writing alternative implementations for existing lookups ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/ref/models/lookups.txt b/docs/ref/models/lookups.txt index d3f64c07a9..da338b7cb2 100644 --- a/docs/ref/models/lookups.txt +++ b/docs/ref/models/lookups.txt @@ -129,6 +129,15 @@ Transform reference This class follows the :ref:`Query Expression API `, which implies that you can use ``____``. + .. attribute:: bilateral + + .. versionadded:: 1.8 + + A boolean indicating whether this transformation should apply to both + ``lhs`` and ``rhs``. Bilateral transformations will be applied to ``rhs`` in + the same order as they appear in the lookup expression. By default it is set + to ``False``. For example usage, see :doc:`/howto/custom-lookups`. + .. attribute:: lhs The left-hand side - what is being transformed. It must follow the diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index 6b08e2b5a1..e5d3282874 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -306,6 +306,11 @@ Models * :doc:`Custom Lookups` can now be registered using a decorator pattern. +* The new :attr:`Transform.bilateral ` + attribute allows creating bilateral transformations. These transformations + are applied to both ``lhs`` and ``rhs`` when used in a lookup expression, + providing opportunities for more sophisticated lookups. + Signals ^^^^^^^ diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index a965e5a4e8..d0f18c5d7b 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -17,7 +17,7 @@ class Div3Lookup(models.Lookup): lhs, params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) params.extend(rhs_params) - return '%s %%%% 3 = %s' % (lhs, rhs), params + return '(%s) %%%% 3 = %s' % (lhs, rhs), params def as_oracle(self, qn, connection): lhs, params = self.process_lhs(qn, connection) @@ -31,12 +31,32 @@ class Div3Transform(models.Transform): def as_sql(self, qn, connection): lhs, lhs_params = qn.compile(self.lhs) - return '%s %%%% 3' % (lhs,), lhs_params + return '(%s) %%%% 3' % lhs, lhs_params def as_oracle(self, qn, connection): lhs, lhs_params = qn.compile(self.lhs) return 'mod(%s, 3)' % lhs, lhs_params +class Div3BilateralTransform(Div3Transform): + bilateral = True + + +class Mult3BilateralTransform(models.Transform): + bilateral = True + lookup_name = 'mult3' + + def as_sql(self, qn, connection): + lhs, lhs_params = qn.compile(self.lhs) + return '3 * (%s)' % lhs, lhs_params + +class UpperBilateralTransform(models.Transform): + bilateral = True + lookup_name = 'upper' + + def as_sql(self, qn, connection): + lhs, lhs_params = qn.compile(self.lhs) + return 'UPPER(%s)' % lhs, lhs_params + class YearTransform(models.Transform): lookup_name = 'year' @@ -225,10 +245,112 @@ class LookupTests(TestCase): self.assertQuerysetEqual( baseqs.filter(age__div3__in=[0, 2]), [a2, a3], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__in=[2, 4]), + [a2], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__gte=3), + [], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__range=(1, 2)), + [a1, a2, a4], lambda x: x) finally: models.IntegerField._unregister_lookup(Div3Transform) +class BilateralTransformTests(TestCase): + + def test_bilateral_upper(self): + models.CharField.register_lookup(UpperBilateralTransform) + try: + Author.objects.bulk_create([ + Author(name='Doe'), + Author(name='doe'), + Author(name='Foo'), + ]) + self.assertQuerysetEqual( + Author.objects.filter(name__upper='doe'), + ["", ""], ordered=False) + finally: + models.CharField._unregister_lookup(UpperBilateralTransform) + + def test_bilateral_inner_qs(self): + models.CharField.register_lookup(UpperBilateralTransform) + try: + with self.assertRaises(NotImplementedError): + Author.objects.filter(name__upper__in=Author.objects.values_list('name')) + finally: + models.CharField._unregister_lookup(UpperBilateralTransform) + + def test_div3_bilateral_extract(self): + models.IntegerField.register_lookup(Div3BilateralTransform) + try: + a1 = Author.objects.create(name='a1', age=1) + a2 = Author.objects.create(name='a2', age=2) + a3 = Author.objects.create(name='a3', age=3) + a4 = Author.objects.create(name='a4', age=4) + baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(age__div3=2), + [a2], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__lte=3), + [a3], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__in=[0, 2]), + [a2, a3], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__in=[2, 4]), + [a1, a2, a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__gte=3), + [a1, a2, a3, a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__range=(1, 2)), + [a1, a2, a4], lambda x: x) + finally: + models.IntegerField._unregister_lookup(Div3BilateralTransform) + + def test_bilateral_order(self): + models.IntegerField.register_lookup(Mult3BilateralTransform) + models.IntegerField.register_lookup(Div3BilateralTransform) + try: + a1 = Author.objects.create(name='a1', age=1) + a2 = Author.objects.create(name='a2', age=2) + a3 = Author.objects.create(name='a3', age=3) + a4 = Author.objects.create(name='a4', age=4) + baseqs = Author.objects.order_by('name') + + self.assertQuerysetEqual( + baseqs.filter(age__mult3__div3=42), + # mult3__div3 always leads to 0 + [a1, a2, a3, a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__mult3=42), + [a3], lambda x: x) + finally: + models.IntegerField._unregister_lookup(Mult3BilateralTransform) + models.IntegerField._unregister_lookup(Div3BilateralTransform) + + def test_bilateral_fexpr(self): + models.IntegerField.register_lookup(Mult3BilateralTransform) + try: + a1 = Author.objects.create(name='a1', age=1, average_rating=3.2) + a2 = Author.objects.create(name='a2', age=2, average_rating=0.5) + a3 = Author.objects.create(name='a3', age=3, average_rating=1.5) + a4 = Author.objects.create(name='a4', age=4) + baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(age__mult3=models.F('age')), + [a1, a2, a3, a4], lambda x: x) + self.assertQuerysetEqual( + # Same as age >= average_rating + baseqs.filter(age__mult3__gte=models.F('average_rating')), + [a2, a3], lambda x: x) + finally: + models.IntegerField._unregister_lookup(Mult3BilateralTransform) + + class YearLteTests(TestCase): def setUp(self): models.DateField.register_lookup(YearTransform)