mirror of https://github.com/django/django.git
Fixed #22288 -- Fixed F() expressions with the __range lookup.
This commit is contained in:
parent
f6cd669ff2
commit
4f138fe5a4
1
AUTHORS
1
AUTHORS
|
@ -498,6 +498,7 @@ answer newbie questions, and generally made Django that much better:
|
|||
Matthew Schinckel <matt@schinckel.net>
|
||||
Matthew Somerville <matthew-django@dracos.co.uk>
|
||||
Matthew Tretter <m@tthewwithanm.com>
|
||||
Matthew Wilkes <matt@matthewwilkes.name>
|
||||
Matthias Kestenholz <mk@406.ch>
|
||||
Matthias Pronk <django@masida.nl>
|
||||
Matt Hoskins <skaffenuk@googlemail.com>
|
||||
|
|
|
@ -239,7 +239,13 @@ class ArrayInLookup(In):
|
|||
values = super(ArrayInLookup, self).get_prep_lookup()
|
||||
# In.process_rhs() expects values to be hashable, so convert lists
|
||||
# to tuples.
|
||||
return [tuple(value) for value in values]
|
||||
prepared_values = []
|
||||
for value in values:
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
prepared_values.append(value)
|
||||
else:
|
||||
prepared_values.append(tuple(value))
|
||||
return prepared_values
|
||||
|
||||
|
||||
class IndexTransform(Transform):
|
||||
|
|
|
@ -155,6 +155,10 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# MySQL doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
|
@ -171,6 +175,10 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# MySQL doesn't support tz-aware times
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("MySQL backend does not support timezone-aware times.")
|
||||
|
|
|
@ -408,6 +408,10 @@ WHEN (new.%(col_name)s IS NULL)
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# cx_Oracle doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
|
@ -421,6 +425,10 @@ WHEN (new.%(col_name)s IS NULL)
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
if isinstance(value, six.string_types):
|
||||
return datetime.datetime.strptime(value, '%H:%M:%S')
|
||||
|
||||
|
|
|
@ -182,6 +182,10 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# SQLite doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
|
@ -195,6 +199,10 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
return value
|
||||
|
||||
# SQLite doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("SQLite backend does not support timezone-aware times.")
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import itertools
|
||||
import math
|
||||
import warnings
|
||||
from copy import copy
|
||||
|
@ -170,6 +171,12 @@ class FieldGetDbPrepValueMixin(object):
|
|||
"""
|
||||
get_db_prep_lookup_value_is_iterable = False
|
||||
|
||||
@classmethod
|
||||
def get_prep_lookup_value(cls, value, output_field):
|
||||
if hasattr(value, '_prepare'):
|
||||
return value._prepare(output_field)
|
||||
return output_field.get_prep_value(value)
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
# For relational fields, use the output_field of the 'field' attribute.
|
||||
field = getattr(self.lhs.output_field, 'field', None)
|
||||
|
@ -191,6 +198,51 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
|
|||
"""
|
||||
get_db_prep_lookup_value_is_iterable = True
|
||||
|
||||
def get_prep_lookup(self):
|
||||
prepared_values = []
|
||||
if hasattr(self.rhs, '_prepare'):
|
||||
# A subquery is like an iterable but its items shouldn't be
|
||||
# prepared independently.
|
||||
return self.rhs._prepare(self.lhs.output_field)
|
||||
for rhs_value in self.rhs:
|
||||
if hasattr(rhs_value, 'resolve_expression'):
|
||||
# An expression will be handled by the database but can coexist
|
||||
# alongside real values.
|
||||
pass
|
||||
elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
|
||||
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
|
||||
prepared_values.append(rhs_value)
|
||||
return prepared_values
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if self.rhs_is_direct_value():
|
||||
# rhs should be an iterable of values. Use batch_process_rhs()
|
||||
# to prepare/transform those values.
|
||||
return self.batch_process_rhs(compiler, connection)
|
||||
else:
|
||||
return super(FieldGetDbPrepValueIterableMixin, self).process_rhs(compiler, connection)
|
||||
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
params = [param]
|
||||
if hasattr(param, 'resolve_expression'):
|
||||
param = param.resolve_expression(compiler.query)
|
||||
if hasattr(param, 'as_sql'):
|
||||
sql, params = param.as_sql(compiler, connection)
|
||||
return sql, params
|
||||
|
||||
def batch_process_rhs(self, compiler, connection, rhs=None):
|
||||
pre_processed = super(FieldGetDbPrepValueIterableMixin, self).batch_process_rhs(compiler, connection, rhs)
|
||||
# The params list may contain expressions which compile to a
|
||||
# sql/param pair. Zip them to get sql and param pairs that refer to the
|
||||
# same argument and attempt to replace them with the result of
|
||||
# compiling the param step.
|
||||
sql, params = zip(*(
|
||||
self.resolve_expression_parameter(compiler, connection, sql, param)
|
||||
for sql, param in zip(*pre_processed)
|
||||
))
|
||||
params = itertools.chain.from_iterable(params)
|
||||
return sql, tuple(params)
|
||||
|
||||
|
||||
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = 'exact'
|
||||
|
@ -255,13 +307,6 @@ IntegerField.register_lookup(IntegerLessThan)
|
|||
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = 'in'
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if hasattr(self.rhs, '_prepare'):
|
||||
return self.rhs._prepare(self.lhs.output_field)
|
||||
if hasattr(self.lhs.output_field, 'get_prep_value'):
|
||||
return [self.lhs.output_field.get_prep_value(v) for v in self.rhs]
|
||||
return self.rhs
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
db_rhs = getattr(self.rhs, '_db', None)
|
||||
if db_rhs is not None and db_rhs != connection.alias:
|
||||
|
@ -409,21 +454,9 @@ Field.register_lookup(IEndsWith)
|
|||
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = 'range'
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if hasattr(self.rhs, '_prepare'):
|
||||
return self.rhs._prepare(self.lhs.output_field)
|
||||
return [self.lhs.output_field.get_prep_value(v) for v in self.rhs]
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if self.rhs_is_direct_value():
|
||||
# rhs should be an iterable of 2 values, we use batch_process_rhs
|
||||
# to prepare/transform those values
|
||||
return self.batch_process_rhs(compiler, connection)
|
||||
else:
|
||||
return super(Range, self).process_rhs(compiler, connection)
|
||||
Field.register_lookup(Range)
|
||||
|
||||
|
||||
|
|
|
@ -990,6 +990,20 @@ class Query(object):
|
|||
pre_joins = self.alias_refcount.copy()
|
||||
value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
|
||||
used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# The items of the iterable may be expressions and therefore need
|
||||
# to be resolved independently.
|
||||
processed_values = []
|
||||
used_joins = set()
|
||||
for sub_value in value:
|
||||
if hasattr(sub_value, 'resolve_expression'):
|
||||
pre_joins = self.alias_refcount.copy()
|
||||
processed_values.append(
|
||||
sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
|
||||
)
|
||||
# The used_joins for a tuple of expressions is the union of
|
||||
# the used_joins for the individual expressions.
|
||||
used_joins |= set(k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0))
|
||||
# Subqueries need to use a different set of aliases than the
|
||||
# outer query. Call bump_prefix to change aliases of the inner
|
||||
# query (the value).
|
||||
|
|
|
@ -234,6 +234,9 @@ Models
|
|||
* Added support for expressions in :meth:`.QuerySet.values` and
|
||||
:meth:`~.QuerySet.values_list`.
|
||||
|
||||
* Added support for query expressions on lookups that take multiple arguments,
|
||||
such as ``range``.
|
||||
|
||||
Requests and Responses
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -61,6 +61,15 @@ class Experiment(models.Model):
|
|||
return self.end - self.start
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Result(models.Model):
|
||||
experiment = models.ForeignKey(Experiment, models.CASCADE)
|
||||
result_time = models.DateTimeField()
|
||||
|
||||
def __str__(self):
|
||||
return "Result at %s" % self.result_time
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Time(models.Model):
|
||||
time = models.TimeField(null=True)
|
||||
|
@ -69,6 +78,16 @@ class Time(models.Model):
|
|||
return "%s" % self.time
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class SimulationRun(models.Model):
|
||||
start = models.ForeignKey(Time, models.CASCADE, null=True)
|
||||
end = models.ForeignKey(Time, models.CASCADE, null=True)
|
||||
midpoint = models.TimeField()
|
||||
|
||||
def __str__(self):
|
||||
return "%s (%s to %s)" % (self.midpoint, self.start, self.end)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class UUID(models.Model):
|
||||
uuid = models.UUIDField(null=True)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import datetime
|
||||
import unittest
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
|
||||
|
@ -17,11 +18,15 @@ from django.db.models.expressions import (
|
|||
from django.db.models.functions import (
|
||||
Coalesce, Concat, Length, Lower, Substr, Upper,
|
||||
)
|
||||
from django.db.models.sql import constants
|
||||
from django.db.models.sql.datastructures import Join
|
||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||
from django.test.utils import Approximate
|
||||
from django.utils import six
|
||||
|
||||
from .models import UUID, Company, Employee, Experiment, Number, Time
|
||||
from .models import (
|
||||
UUID, Company, Employee, Experiment, Number, Result, SimulationRun, Time,
|
||||
)
|
||||
|
||||
|
||||
class BasicExpressionsTests(TestCase):
|
||||
|
@ -391,6 +396,144 @@ class BasicExpressionsTests(TestCase):
|
|||
self.assertEqual(str(qs.query).count('JOIN'), 2)
|
||||
|
||||
|
||||
class IterableLookupInnerExpressionsTests(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
ceo = Employee.objects.create(firstname='Just', lastname='Doit', salary=30)
|
||||
# MySQL requires that the values calculated for expressions don't pass
|
||||
# outside of the field's range, so it's inconvenient to use the values
|
||||
# in the more general tests.
|
||||
Company.objects.create(name='5020 Ltd', num_employees=50, num_chairs=20, ceo=ceo)
|
||||
Company.objects.create(name='5040 Ltd', num_employees=50, num_chairs=40, ceo=ceo)
|
||||
Company.objects.create(name='5050 Ltd', num_employees=50, num_chairs=50, ceo=ceo)
|
||||
Company.objects.create(name='5060 Ltd', num_employees=50, num_chairs=60, ceo=ceo)
|
||||
Company.objects.create(name='99300 Ltd', num_employees=99, num_chairs=300, ceo=ceo)
|
||||
|
||||
def test_in_lookup_allows_F_expressions_and_expressions_for_integers(self):
|
||||
# __in lookups can use F() expressions for integers.
|
||||
queryset = Company.objects.filter(num_employees__in=([F('num_chairs') - 10]))
|
||||
self.assertQuerysetEqual(queryset, ['<Company: 5060 Ltd>'], ordered=False)
|
||||
self.assertQuerysetEqual(
|
||||
Company.objects.filter(num_employees__in=([F('num_chairs') - 10, F('num_chairs') + 10])),
|
||||
['<Company: 5040 Ltd>', '<Company: 5060 Ltd>'],
|
||||
ordered=False
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Company.objects.filter(
|
||||
num_employees__in=([F('num_chairs') - 10, F('num_chairs'), F('num_chairs') + 10])
|
||||
),
|
||||
['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
|
||||
ordered=False
|
||||
)
|
||||
|
||||
def test_expressions_in_lookups_join_choice(self):
|
||||
midpoint = datetime.time(13, 0)
|
||||
t1 = Time.objects.create(time=datetime.time(12, 0))
|
||||
t2 = Time.objects.create(time=datetime.time(14, 0))
|
||||
SimulationRun.objects.create(start=t1, end=t2, midpoint=midpoint)
|
||||
SimulationRun.objects.create(start=t1, end=None, midpoint=midpoint)
|
||||
SimulationRun.objects.create(start=None, end=t2, midpoint=midpoint)
|
||||
SimulationRun.objects.create(start=None, end=None, midpoint=midpoint)
|
||||
|
||||
queryset = SimulationRun.objects.filter(midpoint__range=[F('start__time'), F('end__time')])
|
||||
self.assertQuerysetEqual(
|
||||
queryset,
|
||||
['<SimulationRun: 13:00:00 (12:00:00 to 14:00:00)>'],
|
||||
ordered=False
|
||||
)
|
||||
for alias in queryset.query.alias_map.values():
|
||||
if isinstance(alias, Join):
|
||||
self.assertEqual(alias.join_type, constants.INNER)
|
||||
|
||||
queryset = SimulationRun.objects.exclude(midpoint__range=[F('start__time'), F('end__time')])
|
||||
self.assertQuerysetEqual(queryset, [], ordered=False)
|
||||
for alias in queryset.query.alias_map.values():
|
||||
if isinstance(alias, Join):
|
||||
self.assertEqual(alias.join_type, constants.LOUTER)
|
||||
|
||||
def test_range_lookup_allows_F_expressions_and_expressions_for_integers(self):
|
||||
# Range lookups can use F() expressions for integers.
|
||||
Company.objects.filter(num_employees__exact=F("num_chairs"))
|
||||
self.assertQuerysetEqual(
|
||||
Company.objects.filter(num_employees__range=(F('num_chairs'), 100)),
|
||||
['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>'],
|
||||
ordered=False
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Company.objects.filter(num_employees__range=(F('num_chairs') - 10, F('num_chairs') + 10)),
|
||||
['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
|
||||
ordered=False
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Company.objects.filter(num_employees__range=(F('num_chairs') - 10, 100)),
|
||||
['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
|
||||
ordered=False
|
||||
)
|
||||
self.assertQuerysetEqual(
|
||||
Company.objects.filter(num_employees__range=(1, 100)),
|
||||
[
|
||||
'<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>',
|
||||
'<Company: 5060 Ltd>', '<Company: 99300 Ltd>',
|
||||
],
|
||||
ordered=False
|
||||
)
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'sqlite',
|
||||
"This defensive test only works on databases that don't validate parameter types")
|
||||
def test_complex_expressions_do_not_introduce_sql_injection_via_untrusted_string_inclusion(self):
|
||||
"""
|
||||
This tests that SQL injection isn't possible using compilation of
|
||||
expressions in iterable filters, as their compilation happens before
|
||||
the main query compilation. It's limited to SQLite, as PostgreSQL,
|
||||
Oracle and other vendors have defense in depth against this by type
|
||||
checking. Testing against SQLite (the most permissive of the built-in
|
||||
databases) demonstrates that the problem doesn't exist while keeping
|
||||
the test simple.
|
||||
"""
|
||||
queryset = Company.objects.filter(name__in=[F('num_chairs') + '1)) OR ((1==1'])
|
||||
self.assertQuerysetEqual(queryset, [], ordered=False)
|
||||
|
||||
def test_in_lookup_allows_F_expressions_and_expressions_for_datetimes(self):
|
||||
start = datetime.datetime(2016, 2, 3, 15, 0, 0)
|
||||
end = datetime.datetime(2016, 2, 5, 15, 0, 0)
|
||||
experiment_1 = Experiment.objects.create(
|
||||
name='Integrity testing',
|
||||
assigned=start.date(),
|
||||
start=start,
|
||||
end=end,
|
||||
completed=end.date(),
|
||||
estimated_time=end - start,
|
||||
)
|
||||
experiment_2 = Experiment.objects.create(
|
||||
name='Taste testing',
|
||||
assigned=start.date(),
|
||||
start=start,
|
||||
end=end,
|
||||
completed=end.date(),
|
||||
estimated_time=end - start,
|
||||
)
|
||||
Result.objects.create(
|
||||
experiment=experiment_1,
|
||||
result_time=datetime.datetime(2016, 2, 4, 15, 0, 0),
|
||||
)
|
||||
Result.objects.create(
|
||||
experiment=experiment_1,
|
||||
result_time=datetime.datetime(2016, 3, 10, 2, 0, 0),
|
||||
)
|
||||
Result.objects.create(
|
||||
experiment=experiment_2,
|
||||
result_time=datetime.datetime(2016, 1, 8, 5, 0, 0),
|
||||
)
|
||||
|
||||
within_experiment_time = [F('experiment__start'), F('experiment__end')]
|
||||
queryset = Result.objects.filter(result_time__range=within_experiment_time)
|
||||
self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
|
||||
|
||||
within_experiment_time = [F('experiment__start'), F('experiment__end')]
|
||||
queryset = Result.objects.filter(result_time__range=within_experiment_time)
|
||||
self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
|
||||
|
||||
|
||||
class ExpressionsTests(TestCase):
|
||||
|
||||
def test_F_object_deepcopy(self):
|
||||
|
|
|
@ -173,12 +173,40 @@ class TestQuerying(PostgreSQLTestCase):
|
|||
self.objs[:2]
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_in_including_F_object(self):
|
||||
# This test asserts that Array objects passed to filters can be
|
||||
# constructed to contain F objects. This currently doesn't work as the
|
||||
# psycopg2 mogrify method that generates the ARRAY() syntax is
|
||||
# expecting literals, not column references (#27095).
|
||||
self.assertSequenceEqual(
|
||||
NullableIntegerArrayModel.objects.filter(field__in=[[models.F('id')]]),
|
||||
self.objs[:2]
|
||||
)
|
||||
|
||||
def test_in_as_F_object(self):
|
||||
self.assertSequenceEqual(
|
||||
NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')]),
|
||||
self.objs[:4]
|
||||
)
|
||||
|
||||
def test_contained_by(self):
|
||||
self.assertSequenceEqual(
|
||||
NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
|
||||
self.objs[:2]
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_contained_by_including_F_object(self):
|
||||
# This test asserts that Array objects passed to filters can be
|
||||
# constructed to contain F objects. This currently doesn't work as the
|
||||
# psycopg2 mogrify method that generates the ARRAY() syntax is
|
||||
# expecting literals, not column references (#27095).
|
||||
self.assertSequenceEqual(
|
||||
NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]),
|
||||
self.objs[:2]
|
||||
)
|
||||
|
||||
def test_contains(self):
|
||||
self.assertSequenceEqual(
|
||||
NullableIntegerArrayModel.objects.filter(field__contains=[2]),
|
||||
|
|
Loading…
Reference in New Issue