diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 207bc0c6c8..b41314a686 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -2,18 +2,11 @@ Classes to represent the default SQL aggregate functions """ -class AggregateField(object): - """An internal field mockup used to identify aggregates in the - data-conversion parts of the database backend. - """ - def __init__(self, internal_type): - self.internal_type = internal_type +from django.db.models.fields import IntegerField, FloatField - def get_internal_type(self): - return self.internal_type - -ordinal_aggregate_field = AggregateField('IntegerField') -computed_aggregate_field = AggregateField('FloatField') +# Fake fields used to identify aggregate types in data-conversion operations. +ordinal_aggregate_field = IntegerField() +computed_aggregate_field = FloatField() class Aggregate(object): """ diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 1455ba6e18..2bd705dd60 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -10,6 +10,7 @@ from itertools import repeat from django.utils import tree from django.db.models.fields import Field from django.db.models.sql.datastructures import EmptyResultSet, FullResultSet +from django.db.models.sql.aggregates import Aggregate # Connection types AND = 'AND' @@ -30,9 +31,8 @@ class WhereNode(tree.Node): the correct SQL). The children in this tree are usually either Q-like objects or lists of - [table_alias, field_name, db_type, lookup_type, value_annotation, - params]. However, a child could also be any class with as_sql() and - relabel_aliases() methods. + [table_alias, field_name, db_type, lookup_type, value_annotation, params]. + However, a child could also be any class with as_sql() and relabel_aliases() methods. """ default = AND @@ -54,25 +54,22 @@ class WhereNode(tree.Node): # emptiness and transform any non-empty values correctly. value = list(value) - # The "annotation" parameter is used to pass auxilliary information + # The "value_annotation" parameter is used to pass auxilliary information # about the value(s) to the query construction. Specifically, datetime # and empty values need special handling. Other types could be used # here in the future (using Python types is suggested for consistency). if isinstance(value, datetime.datetime): - annotation = datetime.datetime + value_annotation = datetime.datetime elif hasattr(value, 'value_annotation'): - annotation = value.value_annotation + value_annotation = value.value_annotation else: - annotation = bool(value) + value_annotation = bool(value) if hasattr(obj, "prepare"): value = obj.prepare(lookup_type, value) - super(WhereNode, self).add((obj, lookup_type, annotation, value), - connector) - return - super(WhereNode, self).add((obj, lookup_type, annotation, value), - connector) + super(WhereNode, self).add( + (obj, lookup_type, value_annotation, value), connector) def as_sql(self, qn, connection): """ @@ -132,21 +129,26 @@ class WhereNode(tree.Node): def make_atom(self, child, qn, connection): """ - Turn a tuple (table_alias, column_name, db_type, lookup_type, - value_annot, params) into valid SQL. + Turn a tuple (Constraint(table_alias, column_name, db_type), + lookup_type, value_annotation, params) into valid SQL. + + The first item of the tuple may also be an Aggregate. Returns the string for the SQL fragment and the parameters to use for it. """ - lvalue, lookup_type, value_annot, params_or_value = child - if hasattr(lvalue, 'process'): + lvalue, lookup_type, value_annotation, params_or_value = child + if isinstance(lvalue, Constraint): try: lvalue, params = lvalue.process(lookup_type, params_or_value, connection) except EmptyShortCircuit: raise EmptyResultSet + elif isinstance(lvalue, Aggregate): + params = lvalue.field.get_db_prep_lookup(lookup_type, params_or_value, connection) else: - params = Field().get_db_prep_lookup(lookup_type, params_or_value, - connection=connection, prepared=True) + raise TypeError("'make_atom' expects a Constraint or an Aggregate " + "as the first item of its 'child' argument.") + if isinstance(lvalue, tuple): # A direct database column lookup. field_sql = self.sql_for_columns(lvalue, qn, connection) @@ -154,7 +156,7 @@ class WhereNode(tree.Node): # A smart object with an as_sql() method. field_sql = lvalue.as_sql(qn, connection) - if value_annot is datetime.datetime: + if value_annotation is datetime.datetime: cast_sql = connection.ops.datetime_cast_sql() else: cast_sql = '%s' @@ -168,7 +170,7 @@ class WhereNode(tree.Node): if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' and connection.features.interprets_empty_strings_as_nulls): lookup_type = 'isnull' - value_annot = True + value_annotation = True if lookup_type in connection.operators: format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) @@ -177,7 +179,7 @@ class WhereNode(tree.Node): extra), params) if lookup_type == 'in': - if not value_annot: + if not value_annotation: raise EmptyResultSet if extra: return ('%s IN %s' % (field_sql, extra), params) @@ -206,7 +208,7 @@ class WhereNode(tree.Node): params) elif lookup_type == 'isnull': return ('%s IS %sNULL' % (field_sql, - (not value_annot and 'NOT ' or '')), ()) + (not value_annotation and 'NOT ' or '')), ()) elif lookup_type == 'search': return (connection.ops.fulltext_search_sql(field_sql), params) elif lookup_type in ('regex', 'iregex'): diff --git a/tests/modeltests/timezones/models.py b/tests/modeltests/timezones/models.py index 9296edf924..f0cc79275d 100644 --- a/tests/modeltests/timezones/models.py +++ b/tests/modeltests/timezones/models.py @@ -6,6 +6,13 @@ class Event(models.Model): class MaybeEvent(models.Model): dt = models.DateTimeField(blank=True, null=True) +class Session(models.Model): + name = models.CharField(max_length=20) + +class SessionEvent(models.Model): + dt = models.DateTimeField() + session = models.ForeignKey(Session, related_name='events') + class Timestamp(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) diff --git a/tests/modeltests/timezones/tests.py b/tests/modeltests/timezones/tests.py index dc9cfb3fd1..818405971c 100644 --- a/tests/modeltests/timezones/tests.py +++ b/tests/modeltests/timezones/tests.py @@ -25,7 +25,7 @@ from django.utils.tzinfo import FixedOffset from django.utils.unittest import skipIf, skipUnless from .forms import EventForm, EventSplitForm, EventModelForm -from .models import Event, MaybeEvent, Timestamp +from .models import Event, MaybeEvent, Session, SessionEvent, Timestamp # These tests use the EAT (Eastern Africa Time) and ICT (Indochina Time) @@ -231,6 +231,28 @@ class LegacyDatabaseTests(BaseDateTimeTests): 'dt__max': datetime.datetime(2011, 9, 1, 23, 20, 20), }) + def test_query_annotation(self): + # Only min and max make sense for datetimes. + morning = Session.objects.create(name='morning') + afternoon = Session.objects.create(name='afternoon') + SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 23, 20, 20), session=afternoon) + SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 13, 20, 30), session=afternoon) + SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 3, 20, 40), session=morning) + morning_min_dt = datetime.datetime(2011, 9, 1, 3, 20, 40) + afternoon_min_dt = datetime.datetime(2011, 9, 1, 13, 20, 30) + self.assertQuerysetEqual( + Session.objects.annotate(dt=Min('events__dt')).order_by('dt'), + [morning_min_dt, afternoon_min_dt], + transform=lambda d: d.dt) + self.assertQuerysetEqual( + Session.objects.annotate(dt=Min('events__dt')).filter(dt__lt=afternoon_min_dt), + [morning_min_dt], + transform=lambda d: d.dt) + self.assertQuerysetEqual( + Session.objects.annotate(dt=Min('events__dt')).filter(dt__gte=afternoon_min_dt), + [afternoon_min_dt], + transform=lambda d: d.dt) + def test_query_dates(self): Event.objects.create(dt=datetime.datetime(2011, 1, 1, 1, 30, 0)) Event.objects.create(dt=datetime.datetime(2011, 1, 1, 4, 30, 0)) @@ -412,6 +434,28 @@ class NewDatabaseTests(BaseDateTimeTests): 'dt__max': datetime.datetime(2011, 9, 1, 23, 20, 20, tzinfo=EAT), }) + def test_query_annotation(self): + # Only min and max make sense for datetimes. + morning = Session.objects.create(name='morning') + afternoon = Session.objects.create(name='afternoon') + SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 23, 20, 20, tzinfo=EAT), session=afternoon) + SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), session=afternoon) + SessionEvent.objects.create(dt=datetime.datetime(2011, 9, 1, 3, 20, 40, tzinfo=EAT), session=morning) + morning_min_dt = datetime.datetime(2011, 9, 1, 3, 20, 40, tzinfo=EAT) + afternoon_min_dt = datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT) + self.assertQuerysetEqual( + Session.objects.annotate(dt=Min('events__dt')).order_by('dt'), + [morning_min_dt, afternoon_min_dt], + transform=lambda d: d.dt) + self.assertQuerysetEqual( + Session.objects.annotate(dt=Min('events__dt')).filter(dt__lt=afternoon_min_dt), + [morning_min_dt], + transform=lambda d: d.dt) + self.assertQuerysetEqual( + Session.objects.annotate(dt=Min('events__dt')).filter(dt__gte=afternoon_min_dt), + [afternoon_min_dt], + transform=lambda d: d.dt) + def test_query_dates(self): # Same comment as in test_query_date_related_filters. Event.objects.create(dt=datetime.datetime(2011, 1, 1, 1, 30, 0, tzinfo=EAT))