Fixed #17728 -- When filtering an annotation, ensured the values used in the filter are properly converted to their database representation. This bug was particularly visible with timezone-aware DateTimeFields. Thanks gg for the report and Carl for the review.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17576 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Aymeric Augustin 2012-02-22 19:40:27 +00:00
parent c870318996
commit 8b53616198
4 changed files with 80 additions and 34 deletions

View File

@ -2,18 +2,11 @@
Classes to represent the default SQL aggregate functions Classes to represent the default SQL aggregate functions
""" """
class AggregateField(object): from django.db.models.fields import IntegerField, FloatField
"""An internal field mockup used to identify aggregates in the
data-conversion parts of the database backend.
"""
def __init__(self, internal_type):
self.internal_type = internal_type
def get_internal_type(self): # Fake fields used to identify aggregate types in data-conversion operations.
return self.internal_type ordinal_aggregate_field = IntegerField()
computed_aggregate_field = FloatField()
ordinal_aggregate_field = AggregateField('IntegerField')
computed_aggregate_field = AggregateField('FloatField')
class Aggregate(object): class Aggregate(object):
""" """

View File

@ -10,6 +10,7 @@ from itertools import repeat
from django.utils import tree from django.utils import tree
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet
from django.db.models.sql.aggregates import Aggregate
# Connection types # Connection types
AND = 'AND' AND = 'AND'
@ -30,9 +31,8 @@ class WhereNode(tree.Node):
the correct SQL). the correct SQL).
The children in this tree are usually either Q-like objects or lists of The children in this tree are usually either Q-like objects or lists of
[table_alias, field_name, db_type, lookup_type, value_annotation, [table_alias, field_name, db_type, lookup_type, value_annotation, params].
params]. However, a child could also be any class with as_sql() and However, a child could also be any class with as_sql() and relabel_aliases() methods.
relabel_aliases() methods.
""" """
default = AND default = AND
@ -54,25 +54,22 @@ class WhereNode(tree.Node):
# emptiness and transform any non-empty values correctly. # emptiness and transform any non-empty values correctly.
value = list(value) value = list(value)
# The "annotation" parameter is used to pass auxilliary information # The "value_annotation" parameter is used to pass auxilliary information
# about the value(s) to the query construction. Specifically, datetime # about the value(s) to the query construction. Specifically, datetime
# and empty values need special handling. Other types could be used # and empty values need special handling. Other types could be used
# here in the future (using Python types is suggested for consistency). # here in the future (using Python types is suggested for consistency).
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
annotation = datetime.datetime value_annotation = datetime.datetime
elif hasattr(value, 'value_annotation'): elif hasattr(value, 'value_annotation'):
annotation = value.value_annotation value_annotation = value.value_annotation
else: else:
annotation = bool(value) value_annotation = bool(value)
if hasattr(obj, "prepare"): if hasattr(obj, "prepare"):
value = obj.prepare(lookup_type, value) value = obj.prepare(lookup_type, value)
super(WhereNode, self).add((obj, lookup_type, annotation, value),
connector)
return
super(WhereNode, self).add((obj, lookup_type, annotation, value), super(WhereNode, self).add(
connector) (obj, lookup_type, value_annotation, value), connector)
def as_sql(self, qn, connection): def as_sql(self, qn, connection):
""" """
@ -132,21 +129,26 @@ class WhereNode(tree.Node):
def make_atom(self, child, qn, connection): def make_atom(self, child, qn, connection):
""" """
Turn a tuple (table_alias, column_name, db_type, lookup_type, Turn a tuple (Constraint(table_alias, column_name, db_type),
value_annot, params) into valid SQL. lookup_type, value_annotation, params) into valid SQL.
The first item of the tuple may also be an Aggregate.
Returns the string for the SQL fragment and the parameters to use for Returns the string for the SQL fragment and the parameters to use for
it. it.
""" """
lvalue, lookup_type, value_annot, params_or_value = child lvalue, lookup_type, value_annotation, params_or_value = child
if hasattr(lvalue, 'process'): if isinstance(lvalue, Constraint):
try: try:
lvalue, params = lvalue.process(lookup_type, params_or_value, connection) lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
except EmptyShortCircuit: except EmptyShortCircuit:
raise EmptyResultSet raise EmptyResultSet
elif isinstance(lvalue, Aggregate):
params = lvalue.field.get_db_prep_lookup(lookup_type, params_or_value, connection)
else: else:
params = Field().get_db_prep_lookup(lookup_type, params_or_value, raise TypeError("'make_atom' expects a Constraint or an Aggregate "
connection=connection, prepared=True) "as the first item of its 'child' argument.")
if isinstance(lvalue, tuple): if isinstance(lvalue, tuple):
# A direct database column lookup. # A direct database column lookup.
field_sql = self.sql_for_columns(lvalue, qn, connection) field_sql = self.sql_for_columns(lvalue, qn, connection)
@ -154,7 +156,7 @@ class WhereNode(tree.Node):
# A smart object with an as_sql() method. # A smart object with an as_sql() method.
field_sql = lvalue.as_sql(qn, connection) field_sql = lvalue.as_sql(qn, connection)
if value_annot is datetime.datetime: if value_annotation is datetime.datetime:
cast_sql = connection.ops.datetime_cast_sql() cast_sql = connection.ops.datetime_cast_sql()
else: else:
cast_sql = '%s' cast_sql = '%s'
@ -168,7 +170,7 @@ class WhereNode(tree.Node):
if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
and connection.features.interprets_empty_strings_as_nulls): and connection.features.interprets_empty_strings_as_nulls):
lookup_type = 'isnull' lookup_type = 'isnull'
value_annot = True value_annotation = True
if lookup_type in connection.operators: if lookup_type in connection.operators:
format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
@ -177,7 +179,7 @@ class WhereNode(tree.Node):
extra), params) extra), params)
if lookup_type == 'in': if lookup_type == 'in':
if not value_annot: if not value_annotation:
raise EmptyResultSet raise EmptyResultSet
if extra: if extra:
return ('%s IN %s' % (field_sql, extra), params) return ('%s IN %s' % (field_sql, extra), params)
@ -206,7 +208,7 @@ class WhereNode(tree.Node):
params) params)
elif lookup_type == 'isnull': elif lookup_type == 'isnull':
return ('%s IS %sNULL' % (field_sql, return ('%s IS %sNULL' % (field_sql,
(not value_annot and 'NOT ' or '')), ()) (not value_annotation and 'NOT ' or '')), ())
elif lookup_type == 'search': elif lookup_type == 'search':
return (connection.ops.fulltext_search_sql(field_sql), params) return (connection.ops.fulltext_search_sql(field_sql), params)
elif lookup_type in ('regex', 'iregex'): elif lookup_type in ('regex', 'iregex'):

View File

@ -6,6 +6,13 @@ class Event(models.Model):
class MaybeEvent(models.Model): class MaybeEvent(models.Model):
dt = models.DateTimeField(blank=True, null=True) dt = models.DateTimeField(blank=True, null=True)
class Session(models.Model):
name = models.CharField(max_length=20)
class SessionEvent(models.Model):
dt = models.DateTimeField()
session = models.ForeignKey(Session, related_name='events')
class Timestamp(models.Model): class Timestamp(models.Model):
created = models.DateTimeField(auto_now_add=True) created = models.DateTimeField(auto_now_add=True)
updated = models.DateTimeField(auto_now=True) updated = models.DateTimeField(auto_now=True)

View File

@ -25,7 +25,7 @@ from django.utils.tzinfo import FixedOffset
from django.utils.unittest import skipIf, skipUnless from django.utils.unittest import skipIf, skipUnless
from .forms import EventForm, EventSplitForm, EventModelForm from .forms import EventForm, EventSplitForm, EventModelForm
from .models import Event, MaybeEvent, Timestamp from .models import Event, MaybeEvent, Session, SessionEvent, Timestamp
# These tests use the EAT (Eastern Africa Time) and ICT (Indochina Time) # These tests use the EAT (Eastern Africa Time) and ICT (Indochina Time)
@ -231,6 +231,28 @@ class LegacyDatabaseTests(BaseDateTimeTests):
'dt__max': datetime.datetime(2011, 9, 1, 23, 20, 20), 'dt__max': datetime.datetime(2011, 9, 1, 23, 20, 20),
}) })
def test_query_annotation(self):
# Only min and max make sense for datetimes.
morning = Session.objects.create(name='morning')
afternoon = Session.objects.create(name='afternoon')
SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 23, 20, 20), session=afternoon)
SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 13, 20, 30), session=afternoon)
SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 3, 20, 40), session=morning)
morning_min_dt = datetime.datetime(2011, 9, 1, 3, 20, 40)
afternoon_min_dt = datetime.datetime(2011, 9, 1, 13, 20, 30)
self.assertQuerysetEqual(
Session.objects.annotate(dt=Min('events__dt')).order_by('dt'),
[morning_min_dt, afternoon_min_dt],
transform=lambda d: d.dt)
self.assertQuerysetEqual(
Session.objects.annotate(dt=Min('events__dt')).filter(dt__lt=afternoon_min_dt),
[morning_min_dt],
transform=lambda d: d.dt)
self.assertQuerysetEqual(
Session.objects.annotate(dt=Min('events__dt')).filter(dt__gte=afternoon_min_dt),
[afternoon_min_dt],
transform=lambda d: d.dt)
def test_query_dates(self): def test_query_dates(self):
Event.objects.create(dt=datetime.datetime(2011, 1, 1, 1, 30, 0)) Event.objects.create(dt=datetime.datetime(2011, 1, 1, 1, 30, 0))
Event.objects.create(dt=datetime.datetime(2011, 1, 1, 4, 30, 0)) Event.objects.create(dt=datetime.datetime(2011, 1, 1, 4, 30, 0))
@ -412,6 +434,28 @@ class NewDatabaseTests(BaseDateTimeTests):
'dt__max': datetime.datetime(2011, 9, 1, 23, 20, 20, tzinfo=EAT), 'dt__max': datetime.datetime(2011, 9, 1, 23, 20, 20, tzinfo=EAT),
}) })
def test_query_annotation(self):
# Only min and max make sense for datetimes.
morning = Session.objects.create(name='morning')
afternoon = Session.objects.create(name='afternoon')
SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 23, 20, 20, tzinfo=EAT), session=afternoon)
SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), session=afternoon)
SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 3, 20, 40, tzinfo=EAT), session=morning)
morning_min_dt = datetime.datetime(2011, 9, 1, 3, 20, 40, tzinfo=EAT)
afternoon_min_dt = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT)
self.assertQuerysetEqual(
Session.objects.annotate(dt=Min('events__dt')).order_by('dt'),
[morning_min_dt, afternoon_min_dt],
transform=lambda d: d.dt)
self.assertQuerysetEqual(
Session.objects.annotate(dt=Min('events__dt')).filter(dt__lt=afternoon_min_dt),
[morning_min_dt],
transform=lambda d: d.dt)
self.assertQuerysetEqual(
Session.objects.annotate(dt=Min('events__dt')).filter(dt__gte=afternoon_min_dt),
[afternoon_min_dt],
transform=lambda d: d.dt)
def test_query_dates(self): def test_query_dates(self):
# Same comment as in test_query_date_related_filters. # Same comment as in test_query_date_related_filters.
Event.objects.create(dt=datetime.datetime(2011, 1, 1, 1, 30, 0, tzinfo=EAT)) Event.objects.create(dt=datetime.datetime(2011, 1, 1, 1, 30, 0, tzinfo=EAT))