Fixed #22288 -- Fixed F() expressions with the __range lookup.

Author: Matthew Wilkes (2016-06-02 11:05:25 -07:00), committed by Tim Graham
parent f6cd669ff2
commit 4f138fe5a4
11 changed files with 292 additions and 21 deletions
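For context, a minimal sketch of the behavior this commit enables, using the Result and Experiment test models added later in this commit. Before this change, passing F() expressions to a multi-value lookup such as __range did not work:

    from django.db.models import F

    # Both range endpoints are column references; they are resolved and
    # compiled by the database rather than coerced to datetime values.
    Result.objects.filter(
        result_time__range=(F('experiment__start'), F('experiment__end'))
    )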


@@ -498,6 +498,7 @@ answer newbie questions, and generally made Django that much better:
Matthew Schinckel <matt@schinckel.net>
Matthew Somerville <matthew-django@dracos.co.uk>
Matthew Tretter <m@tthewwithanm.com>
Matthew Wilkes <matt@matthewwilkes.name>
Matthias Kestenholz <mk@406.ch>
Matthias Pronk <django@masida.nl>
Matt Hoskins <skaffenuk@googlemail.com>


@@ -239,7 +239,13 @@ class ArrayInLookup(In):
values = super(ArrayInLookup, self).get_prep_lookup()
# In.process_rhs() expects values to be hashable, so convert lists
# to tuples.
return [tuple(value) for value in values]
prepared_values = []
for value in values:
if hasattr(value, 'resolve_expression'):
prepared_values.append(value)
else:
prepared_values.append(tuple(value))
return prepared_values
class IndexTransform(Transform):
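With this change, expression items in an ArrayField __in filter are passed through untouched instead of being converted to tuples; the new test_in_as_F_object test later in this commit does, for example:

    from django.db import models

    NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')])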


@@ -155,6 +155,10 @@ class DatabaseOperations(BaseDatabaseOperations):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# MySQL doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
@@ -171,6 +175,10 @@ class DatabaseOperations(BaseDatabaseOperations):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("MySQL backend does not support timezone-aware times.")


@@ -408,6 +408,10 @@ WHEN (new.%(col_name)s IS NULL)
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# cx_Oracle doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
@@ -421,6 +425,10 @@ WHEN (new.%(col_name)s IS NULL)
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
if isinstance(value, six.string_types):
return datetime.datetime.strptime(value, '%H:%M:%S')


@@ -182,6 +182,10 @@ class DatabaseOperations(BaseDatabaseOperations):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# SQLite doesn't support tz-aware datetimes
if timezone.is_aware(value):
if settings.USE_TZ:
@@ -195,6 +199,10 @@ class DatabaseOperations(BaseDatabaseOperations):
if value is None:
return None
# Expression values are adapted by the database.
if hasattr(value, 'resolve_expression'):
return value
# SQLite doesn't support tz-aware datetimes
if timezone.is_aware(value):
raise ValueError("SQLite backend does not support timezone-aware times.")


@@ -1,3 +1,4 @@
import itertools
import math
import warnings
from copy import copy
@@ -170,6 +171,12 @@ class FieldGetDbPrepValueMixin(object):
"""
get_db_prep_lookup_value_is_iterable = False
@classmethod
def get_prep_lookup_value(cls, value, output_field):
if hasattr(value, '_prepare'):
return value._prepare(output_field)
return output_field.get_prep_value(value)
def get_db_prep_lookup(self, value, connection):
# For relational fields, use the output_field of the 'field' attribute.
field = getattr(self.lhs.output_field, 'field', None)
@@ -191,6 +198,51 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
"""
get_db_prep_lookup_value_is_iterable = True
def get_prep_lookup(self):
prepared_values = []
if hasattr(self.rhs, '_prepare'):
# A subquery is like an iterable but its items shouldn't be
# prepared independently.
return self.rhs._prepare(self.lhs.output_field)
for rhs_value in self.rhs:
if hasattr(rhs_value, 'resolve_expression'):
# An expression will be handled by the database but can coexist
# alongside real values.
pass
elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
prepared_values.append(rhs_value)
return prepared_values
def process_rhs(self, compiler, connection):
if self.rhs_is_direct_value():
# rhs should be an iterable of values. Use batch_process_rhs()
# to prepare/transform those values.
return self.batch_process_rhs(compiler, connection)
else:
return super(FieldGetDbPrepValueIterableMixin, self).process_rhs(compiler, connection)
def resolve_expression_parameter(self, compiler, connection, sql, param):
params = [param]
if hasattr(param, 'resolve_expression'):
param = param.resolve_expression(compiler.query)
if hasattr(param, 'as_sql'):
sql, params = param.as_sql(compiler, connection)
return sql, params
def batch_process_rhs(self, compiler, connection, rhs=None):
pre_processed = super(FieldGetDbPrepValueIterableMixin, self).batch_process_rhs(compiler, connection, rhs)
# The params list may contain expressions which compile to a
# sql/param pair. Zip them to get sql and param pairs that refer to the
# same argument and attempt to replace them with the result of
# compiling the param step.
sql, params = zip(*(
self.resolve_expression_parameter(compiler, connection, sql, param)
for sql, param in zip(*pre_processed)
))
params = itertools.chain.from_iterable(params)
return sql, tuple(params)
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = 'exact'
@@ -255,13 +307,6 @@ IntegerField.register_lookup(IntegerLessThan)
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = 'in'
def get_prep_lookup(self):
if hasattr(self.rhs, '_prepare'):
return self.rhs._prepare(self.lhs.output_field)
if hasattr(self.lhs.output_field, 'get_prep_value'):
return [self.lhs.output_field.get_prep_value(v) for v in self.rhs]
return self.rhs
def process_rhs(self, compiler, connection):
db_rhs = getattr(self.rhs, '_db', None)
if db_rhs is not None and db_rhs != connection.alias:
@@ -409,21 +454,9 @@ Field.register_lookup(IEndsWith)
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = 'range'
def get_prep_lookup(self):
if hasattr(self.rhs, '_prepare'):
return self.rhs._prepare(self.lhs.output_field)
return [self.lhs.output_field.get_prep_value(v) for v in self.rhs]
def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
def process_rhs(self, compiler, connection):
if self.rhs_is_direct_value():
# rhs should be an iterable of 2 values, we use batch_process_rhs
# to prepare/transform those values
return self.batch_process_rhs(compiler, connection)
else:
return super(Range, self).process_rhs(compiler, connection)
Field.register_lookup(Range)
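A rough sketch (not part of the patch) of what the Range lookup now produces when an endpoint is an expression, using the Company test model added later in this commit; exact table aliases, quoting, and placeholders vary by backend:

    from django.db.models import F

    Company.objects.filter(num_employees__range=(F('num_chairs') - 10, 100))
    # batch_process_rhs() compiles each endpoint separately, roughly:
    #     WHERE num_employees BETWEEN (num_chairs - %s) AND %s
    # with params (10, 100); the expression endpoint goes through
    # resolve_expression_parameter() instead of being treated as a literal.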


@@ -990,6 +990,20 @@ class Query(object):
pre_joins = self.alias_refcount.copy()
value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
elif isinstance(value, (list, tuple)):
# The items of the iterable may be expressions and therefore need
# to be resolved independently.
processed_values = []
used_joins = set()
for sub_value in value:
if hasattr(sub_value, 'resolve_expression'):
pre_joins = self.alias_refcount.copy()
processed_values.append(
sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
)
# The used_joins for a tuple of expressions is the union of
# the used_joins for the individual expressions.
used_joins |= set(k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0))
# Subqueries need to use a different set of aliases than the
# outer query. Call bump_prefix to change aliases of the inner
# query (the value).
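Resolving each item of the iterable separately also means every F() can pull in its own join, and the union of the used joins drives join promotion. The IterableLookupInnerExpressionsTests added below exercise this, roughly:

    from django.db.models import F

    # Both endpoints follow nullable ForeignKeys to Time, adding two joins;
    # filter() keeps them INNER while exclude() promotes them to LEFT OUTER.
    SimulationRun.objects.filter(midpoint__range=[F('start__time'), F('end__time')])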


@@ -234,6 +234,9 @@ Models
* Added support for expressions in :meth:`.QuerySet.values` and
:meth:`~.QuerySet.values_list`.
* Added support for query expressions on lookups that take multiple arguments,
such as ``range``.
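As an illustration (not part of the release notes), expressions and plain values can also be mixed in such lookups, e.g. with the Company test model from this commit:

    from django.db.models import F

    Company.objects.filter(num_employees__in=[50, F('num_chairs')])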
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~


@@ -61,6 +61,15 @@ class Experiment(models.Model):
return self.end - self.start
@python_2_unicode_compatible
class Result(models.Model):
experiment = models.ForeignKey(Experiment, models.CASCADE)
result_time = models.DateTimeField()
def __str__(self):
return "Result at %s" % self.result_time
@python_2_unicode_compatible
class Time(models.Model):
time = models.TimeField(null=True)
@@ -69,6 +78,16 @@ class Time(models.Model):
return "%s" % self.time
@python_2_unicode_compatible
class SimulationRun(models.Model):
start = models.ForeignKey(Time, models.CASCADE, null=True)
end = models.ForeignKey(Time, models.CASCADE, null=True)
midpoint = models.TimeField()
def __str__(self):
return "%s (%s to %s)" % (self.midpoint, self.start, self.end)
@python_2_unicode_compatible
class UUID(models.Model):
uuid = models.UUIDField(null=True)


@@ -1,6 +1,7 @@
from __future__ import unicode_literals
import datetime
import unittest
import uuid
from copy import deepcopy
@@ -17,11 +18,15 @@ from django.db.models.expressions import (
from django.db.models.functions import (
Coalesce, Concat, Length, Lower, Substr, Upper,
)
from django.db.models.sql import constants
from django.db.models.sql.datastructures import Join
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate
from django.utils import six
from .models import UUID, Company, Employee, Experiment, Number, Time
from .models import (
UUID, Company, Employee, Experiment, Number, Result, SimulationRun, Time,
)
class BasicExpressionsTests(TestCase):
@@ -391,6 +396,144 @@ class BasicExpressionsTests(TestCase):
self.assertEqual(str(qs.query).count('JOIN'), 2)
class IterableLookupInnerExpressionsTests(TestCase):
@classmethod
def setUpTestData(cls):
ceo = Employee.objects.create(firstname='Just', lastname='Doit', salary=30)
# MySQL requires that the values calculated for expressions don't pass
# outside of the field's range, so it's inconvenient to use the values
# in the more general tests.
Company.objects.create(name='5020 Ltd', num_employees=50, num_chairs=20, ceo=ceo)
Company.objects.create(name='5040 Ltd', num_employees=50, num_chairs=40, ceo=ceo)
Company.objects.create(name='5050 Ltd', num_employees=50, num_chairs=50, ceo=ceo)
Company.objects.create(name='5060 Ltd', num_employees=50, num_chairs=60, ceo=ceo)
Company.objects.create(name='99300 Ltd', num_employees=99, num_chairs=300, ceo=ceo)
def test_in_lookup_allows_F_expressions_and_expressions_for_integers(self):
# __in lookups can use F() expressions for integers.
queryset = Company.objects.filter(num_employees__in=([F('num_chairs') - 10]))
self.assertQuerysetEqual(queryset, ['<Company: 5060 Ltd>'], ordered=False)
self.assertQuerysetEqual(
Company.objects.filter(num_employees__in=([F('num_chairs') - 10, F('num_chairs') + 10])),
['<Company: 5040 Ltd>', '<Company: 5060 Ltd>'],
ordered=False
)
self.assertQuerysetEqual(
Company.objects.filter(
num_employees__in=([F('num_chairs') - 10, F('num_chairs'), F('num_chairs') + 10])
),
['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
ordered=False
)
def test_expressions_in_lookups_join_choice(self):
midpoint = datetime.time(13, 0)
t1 = Time.objects.create(time=datetime.time(12, 0))
t2 = Time.objects.create(time=datetime.time(14, 0))
SimulationRun.objects.create(start=t1, end=t2, midpoint=midpoint)
SimulationRun.objects.create(start=t1, end=None, midpoint=midpoint)
SimulationRun.objects.create(start=None, end=t2, midpoint=midpoint)
SimulationRun.objects.create(start=None, end=None, midpoint=midpoint)
queryset = SimulationRun.objects.filter(midpoint__range=[F('start__time'), F('end__time')])
self.assertQuerysetEqual(
queryset,
['<SimulationRun: 13:00:00 (12:00:00 to 14:00:00)>'],
ordered=False
)
for alias in queryset.query.alias_map.values():
if isinstance(alias, Join):
self.assertEqual(alias.join_type, constants.INNER)
queryset = SimulationRun.objects.exclude(midpoint__range=[F('start__time'), F('end__time')])
self.assertQuerysetEqual(queryset, [], ordered=False)
for alias in queryset.query.alias_map.values():
if isinstance(alias, Join):
self.assertEqual(alias.join_type, constants.LOUTER)
def test_range_lookup_allows_F_expressions_and_expressions_for_integers(self):
# Range lookups can use F() expressions for integers.
Company.objects.filter(num_employees__exact=F("num_chairs"))
self.assertQuerysetEqual(
Company.objects.filter(num_employees__range=(F('num_chairs'), 100)),
['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>'],
ordered=False
)
self.assertQuerysetEqual(
Company.objects.filter(num_employees__range=(F('num_chairs') - 10, F('num_chairs') + 10)),
['<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
ordered=False
)
self.assertQuerysetEqual(
Company.objects.filter(num_employees__range=(F('num_chairs') - 10, 100)),
['<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>', '<Company: 5060 Ltd>'],
ordered=False
)
self.assertQuerysetEqual(
Company.objects.filter(num_employees__range=(1, 100)),
[
'<Company: 5020 Ltd>', '<Company: 5040 Ltd>', '<Company: 5050 Ltd>',
'<Company: 5060 Ltd>', '<Company: 99300 Ltd>',
],
ordered=False
)
@unittest.skipUnless(connection.vendor == 'sqlite',
"This defensive test only works on databases that don't validate parameter types")
def test_complex_expressions_do_not_introduce_sql_injection_via_untrusted_string_inclusion(self):
"""
This tests that SQL injection isn't possible using compilation of
expressions in iterable filters, as their compilation happens before
the main query compilation. It's limited to SQLite, as PostgreSQL,
Oracle and other vendors have defense in depth against this by type
checking. Testing against SQLite (the most permissive of the built-in
databases) demonstrates that the problem doesn't exist while keeping
the test simple.
"""
queryset = Company.objects.filter(name__in=[F('num_chairs') + '1)) OR ((1==1'])
self.assertQuerysetEqual(queryset, [], ordered=False)
def test_in_lookup_allows_F_expressions_and_expressions_for_datetimes(self):
start = datetime.datetime(2016, 2, 3, 15, 0, 0)
end = datetime.datetime(2016, 2, 5, 15, 0, 0)
experiment_1 = Experiment.objects.create(
name='Integrity testing',
assigned=start.date(),
start=start,
end=end,
completed=end.date(),
estimated_time=end - start,
)
experiment_2 = Experiment.objects.create(
name='Taste testing',
assigned=start.date(),
start=start,
end=end,
completed=end.date(),
estimated_time=end - start,
)
Result.objects.create(
experiment=experiment_1,
result_time=datetime.datetime(2016, 2, 4, 15, 0, 0),
)
Result.objects.create(
experiment=experiment_1,
result_time=datetime.datetime(2016, 3, 10, 2, 0, 0),
)
Result.objects.create(
experiment=experiment_2,
result_time=datetime.datetime(2016, 1, 8, 5, 0, 0),
)
within_experiment_time = [F('experiment__start'), F('experiment__end')]
queryset = Result.objects.filter(result_time__range=within_experiment_time)
self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
within_experiment_time = [F('experiment__start'), F('experiment__end')]
queryset = Result.objects.filter(result_time__range=within_experiment_time)
self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
class ExpressionsTests(TestCase):
def test_F_object_deepcopy(self):


@@ -173,12 +173,40 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[:2]
)
@unittest.expectedFailure
def test_in_including_F_object(self):
# This test asserts that Array objects passed to filters can be
# constructed to contain F objects. This currently doesn't work as the
# psycopg2 mogrify method that generates the ARRAY() syntax is
# expecting literals, not column references (#27095).
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__in=[[models.F('id')]]),
self.objs[:2]
)
def test_in_as_F_object(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__in=[models.F('field')]),
self.objs[:4]
)
def test_contained_by(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
self.objs[:2]
)
@unittest.expectedFailure
def test_contained_by_including_F_object(self):
# This test asserts that Array objects passed to filters can be
# constructed to contain F objects. This currently doesn't work as the
# psycopg2 mogrify method that generates the ARRAY() syntax is
# expecting literals, not column references (#27095).
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]),
self.objs[:2]
)
def test_contains(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__contains=[2]),