From cbb5cdd155668ba771cad6b975676d3b20fed37b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Tue, 18 Nov 2014 11:24:33 +0200 Subject: [PATCH] Fixed #23867 -- removed DateQuerySet hacks The .dates() queries were implemented by using custom Query, QuerySet, and Compiler classes. Instead implement them by using expressions and database converters APIs. --- .../gis/db/backends/oracle/compiler.py | 8 -- django/contrib/gis/db/models/sql/compiler.py | 8 -- django/db/backends/mysql/compiler.py | 8 -- django/db/backends/oracle/compiler.py | 8 -- django/db/models/expressions.py | 108 ++++++++++++++---- django/db/models/query.py | 68 +++-------- django/db/models/sql/compiler.py | 67 ++--------- django/db/models/sql/query.py | 3 +- django/db/models/sql/subqueries.py | 79 +------------ tests/dates/tests.py | 7 +- 10 files changed, 111 insertions(+), 253 deletions(-) diff --git a/django/contrib/gis/db/backends/oracle/compiler.py b/django/contrib/gis/db/backends/oracle/compiler.py index 4fe65bce28..a78765b901 100644 --- a/django/contrib/gis/db/backends/oracle/compiler.py +++ b/django/contrib/gis/db/backends/oracle/compiler.py @@ -22,11 +22,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler): class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler): pass - - -class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler): - pass - - -class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler): - pass diff --git a/django/contrib/gis/db/models/sql/compiler.py b/django/contrib/gis/db/models/sql/compiler.py index 4323798136..dd156ea4b6 100644 --- a/django/contrib/gis/db/models/sql/compiler.py +++ b/django/contrib/gis/db/models/sql/compiler.py @@ -235,11 +235,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler): class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler): pass - - -class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler): - pass - - -class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler): - pass diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py index ef1be1d069..fc74ac1991 100644 --- a/django/db/backends/mysql/compiler.py +++ b/django/db/backends/mysql/compiler.py @@ -23,11 +23,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): pass - - -class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler): - pass - - -class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler): - pass diff --git a/django/db/backends/oracle/compiler.py b/django/db/backends/oracle/compiler.py index a380d8a5ce..cfc8a08c1b 100644 --- a/django/db/backends/oracle/compiler.py +++ b/django/db/backends/oracle/compiler.py @@ -54,11 +54,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): pass - - -class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler): - pass - - -class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler): - pass diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 7407d1c40b..996a215da3 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1,11 +1,13 @@ import copy import datetime +from django.conf import settings from django.core.exceptions import FieldError from django.db.backends import utils as backend_utils from django.db.models import fields from django.db.models.constants import LOOKUP_SEP from django.db.models.query_utils import refs_aggregate +from django.utils import timezone from django.utils.functional import cached_property @@ -124,6 +126,9 @@ class ExpressionNode(CombinableMixin): # aggregate specific fields is_summary = False + def get_db_converters(self, connection): + return [self.convert_value] + def __init__(self, output_field=None): self._output_field = output_field @@ -531,40 +536,95 @@ class Date(ExpressionNode): """ Add a date selection column. """ - def __init__(self, col, lookup_type): + def __init__(self, lookup, lookup_type): super(Date, self).__init__(output_field=fields.DateField()) - self.col = col + self.lookup = lookup + self.col = None self.lookup_type = lookup_type def get_source_expressions(self): return [self.col] - def set_source_expressions(self, exprs): - self.col, = self.exprs - - def as_sql(self, compiler, connection): - sql, params = self.col.as_sql(compiler, connection) - assert not(params) - return connection.ops.date_trunc_sql(self.lookup_type, sql), [] - - -class DateTime(ExpressionNode): - """ - Add a datetime selection column. - """ - def __init__(self, col, lookup_type, tzname): - super(DateTime, self).__init__(output_field=fields.DateTimeField()) - self.col = col - self.lookup_type = lookup_type - self.tzname = tzname - - def get_source_expressions(self): - return [self.col] - def set_source_expressions(self, exprs): self.col, = exprs + def resolve_expression(self, query, allow_joins, reuse, summarize): + copy = self.copy() + copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize) + field = copy.col.output_field + assert isinstance(field, fields.DateField), "%r isn't a DateField." % field.name + if settings.USE_TZ: + assert not isinstance(field, fields.DateTimeField), ( + "%r is a DateTimeField, not a DateField." % field.name + ) + return copy + + def as_sql(self, compiler, connection): + sql, params = self.col.as_sql(compiler, connection) + assert not(params) + return connection.ops.date_trunc_sql(self.lookup_type, sql), [] + + def copy(self): + copy = super(Date, self).copy() + copy.lookup = self.lookup + copy.lookup_type = self.lookup_type + return copy + + def convert_value(self, value, connection): + if isinstance(value, datetime.datetime): + value = value.date() + return value + + +class DateTime(ExpressionNode): + """ + Add a datetime selection column. + """ + def __init__(self, lookup, lookup_type, tzinfo): + super(DateTime, self).__init__(output_field=fields.DateTimeField()) + self.lookup = lookup + self.col = None + self.lookup_type = lookup_type + if tzinfo is None: + self.tzname = None + else: + self.tzname = timezone._get_timezone_name(tzinfo) + self.tzinfo = tzinfo + + def get_source_expressions(self): + return [self.col] + + def set_source_expressions(self, exprs): + self.col, = exprs + + def resolve_expression(self, query, allow_joins, reuse, summarize): + copy = self.copy() + copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize) + field = copy.col.output_field + assert isinstance(field, fields.DateTimeField), ( + "%r isn't a DateTimeField." % field.name + ) + return copy + def as_sql(self, compiler, connection): sql, params = self.col.as_sql(compiler, connection) assert not(params) return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname) + + def copy(self): + copy = super(DateTime, self).copy() + copy.lookup = self.lookup + copy.lookup_type = self.lookup_type + copy.tzname = self.tzname + return copy + + def convert_value(self, value, connection): + if settings.USE_TZ: + if value is None: + raise ValueError( + "Database returned an invalid value in QuerySet.datetimes(). " + "Are time zone definitions for your database and pytz installed?" + ) + value = value.replace(tzinfo=None) + value = timezone.make_aware(value, self.tzinfo) + return value diff --git a/django/db/models/query.py b/django/db/models/query.py index a7474bbc45..2aa1f6464e 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -18,6 +18,7 @@ from django.db.models.query_utils import (Q, select_related_descend, from django.db.models.deletion import Collector from django.db.models.sql.constants import CURSOR from django.db.models import sql +from django.db.models.expressions import Date, DateTime, F from django.utils.functional import partition from django.utils import six from django.utils import timezone @@ -658,8 +659,12 @@ class QuerySet(object): "'kind' must be one of 'year', 'month' or 'day'." assert order in ('ASC', 'DESC'), \ "'order' must be either 'ASC' or 'DESC'." - return self._clone(klass=DateQuerySet, setup=True, - _field_name=field_name, _kind=kind, _order=order) + return self.annotate( + datefield=Date(field_name, kind), + plain_field=F(field_name) + ).values_list( + 'datefield', flat=True + ).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datefield') def datetimes(self, field_name, kind, order='ASC', tzinfo=None): """ @@ -675,8 +680,12 @@ class QuerySet(object): tzinfo = timezone.get_current_timezone() else: tzinfo = None - return self._clone(klass=DateTimeQuerySet, setup=True, - _field_name=field_name, _kind=kind, _order=order, _tzinfo=tzinfo) + return self.annotate( + datetimefield=DateTime(field_name, kind, tzinfo), + plain_field=F(field_name) + ).values_list( + 'datetimefield', flat=True + ).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datetimefield') def none(self): """ @@ -1272,57 +1281,6 @@ class ValuesListQuerySet(ValuesQuerySet): return clone -class DateQuerySet(QuerySet): - def iterator(self): - return self.query.get_compiler(self.db).results_iter() - - def _setup_query(self): - """ - Sets up any special features of the query attribute. - - Called by the _clone() method after initializing the rest of the - instance. - """ - self.query.clear_deferred_loading() - self.query = self.query.clone(klass=sql.DateQuery, setup=True) - self.query.select = [] - self.query.add_select(self._field_name, self._kind, self._order) - - def _clone(self, klass=None, setup=False, **kwargs): - c = super(DateQuerySet, self)._clone(klass, False, **kwargs) - c._field_name = self._field_name - c._kind = self._kind - if setup and hasattr(c, '_setup_query'): - c._setup_query() - return c - - -class DateTimeQuerySet(QuerySet): - def iterator(self): - return self.query.get_compiler(self.db).results_iter() - - def _setup_query(self): - """ - Sets up any special features of the query attribute. - - Called by the _clone() method after initializing the rest of the - instance. - """ - self.query.clear_deferred_loading() - self.query = self.query.clone(klass=sql.DateTimeQuery, setup=True, tzinfo=self._tzinfo) - self.query.select = [] - self.query.add_select(self._field_name, self._kind, self._order) - - def _clone(self, klass=None, setup=False, **kwargs): - c = super(DateTimeQuerySet, self)._clone(klass, False, **kwargs) - c._field_name = self._field_name - c._kind = self._kind - c._tzinfo = self._tzinfo - if setup and hasattr(c, '_setup_query'): - c._setup_query() - return c - - def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, only_load=None, from_parent=None): """ diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 4825fae3fa..b800c3fc3e 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1,7 +1,5 @@ -import datetime import warnings -from django.conf import settings from django.core.exceptions import FieldError from django.db.backends.utils import truncate_name from django.db.models.constants import LOOKUP_SEP @@ -13,7 +11,6 @@ from django.db.models.sql.query import get_order_dir, Query from django.db.transaction import TransactionManagementError from django.db.utils import DatabaseError from django.utils import six -from django.utils import timezone from django.utils.deprecation import RemovedInDjango20Warning from django.utils.six.moves import zip @@ -698,10 +695,14 @@ class SQLCompiler(object): index_extra_select = len(self.query.extra_select) for i, field in enumerate(fields): if field: - backend_converters = self.connection.ops.get_db_converters(field.get_internal_type()) + try: + output_field = field.output_field + except AttributeError: + output_field = field + backend_converters = self.connection.ops.get_db_converters(output_field.get_internal_type()) field_converters = field.get_db_converters(self.connection) if backend_converters or field_converters: - converters[index_extra_select + i] = (backend_converters, field_converters, field) + converters[index_extra_select + i] = (backend_converters, field_converters, output_field) return converters def apply_converters(self, row, converters): @@ -753,11 +754,8 @@ class SQLCompiler(object): # annotations come before the related cols if has_annotation_select: # extra is always at the start of the field list - prepended_cols = len(self.query.extra_select) - annotation_start = len(fields) + prepended_cols fields = fields + [ - anno.output_field for alias, anno in self.query.annotation_select.items()] - annotation_end = len(fields) + prepended_cols + anno for alias, anno in self.query.annotation_select.items()] # add related fields fields = fields + [ @@ -768,16 +766,6 @@ class SQLCompiler(object): ] converters = self.get_converters(fields) - if has_annotation_select: - for (alias, annotation), position in zip( - self.query.annotation_select.items(), - range(annotation_start, annotation_end + 1)): - if position in converters: - # annotation conversions always run first - converters[position][1].insert(0, annotation.convert_value) - else: - converters[position] = ([], [annotation.convert_value], annotation.output_field) - if converters: row = self.apply_converters(row, converters) yield row @@ -1122,47 +1110,6 @@ class SQLAggregateCompiler(SQLCompiler): return sql, params -class SQLDateCompiler(SQLCompiler): - def results_iter(self): - """ - Returns an iterator over the results from executing this query. - """ - from django.db.models.fields import DateField - converters = self.get_converters([DateField()]) - - offset = len(self.query.extra_select) - for rows in self.execute_sql(MULTI): - for row in rows: - date = self.apply_converters(row, converters)[offset] - if isinstance(date, datetime.datetime): - date = date.date() - yield date - - -class SQLDateTimeCompiler(SQLCompiler): - def results_iter(self): - """ - Returns an iterator over the results from executing this query. - """ - from django.db.models.fields import DateTimeField - converters = self.get_converters([DateTimeField()]) - - offset = len(self.query.extra_select) - for rows in self.execute_sql(MULTI): - for row in rows: - datetime = self.apply_converters(row, converters)[offset] - # Datetimes are artificially returned in UTC on databases that - # don't support time zone. Restore the zone used in the query. - if settings.USE_TZ: - if datetime is None: - raise ValueError("Database returned an invalid value " - "in QuerySet.datetimes(). Are time zone " - "definitions for your database and pytz installed?") - datetime = datetime.replace(tzinfo=None) - datetime = timezone.make_aware(datetime, self.query.tzinfo) - yield datetime - - def cursor_iter(cursor, sentinel): """ Yields blocks of rows from a cursor and ensures the cursor is closed when diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 4702bc1945..a5d067a37b 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -992,7 +992,8 @@ class Query(object): """ Adds a single annotation expression to the Query """ - annotation = annotation.resolve_expression(self, summarize=is_summary) + annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None, + summarize=is_summary) self.append_annotation_mask([alias]) self.annotations[alias] = annotation diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 6f3f7358d3..12bde13bf3 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -2,21 +2,15 @@ Query subclasses which provide extra functionality beyond simple data retrieval. """ -from django.conf import settings from django.core.exceptions import FieldError from django.db import connections from django.db.models.query_utils import Q -from django.db.models.constants import LOOKUP_SEP -from django.db.models.expressions import Date, DateTime, Col -from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo from django.db.models.sql.query import Query from django.utils import six -from django.utils import timezone -__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', - 'DateTimeQuery', 'AggregateQuery'] +__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'AggregateQuery'] class DeleteQuery(Query): @@ -204,77 +198,6 @@ class InsertQuery(Query): self.raw = raw -class DateQuery(Query): - """ - A DateQuery is a normal query, except that it specifically selects a single - date field. This requires some special handling when converting the results - back to Python objects, so we put it in a separate class. - """ - - compiler = 'SQLDateCompiler' - - def add_select(self, field_name, lookup_type, order='ASC'): - """ - Converts the query into an extraction query. - """ - try: - field, _, _, joins, _ = self.setup_joins( - field_name.split(LOOKUP_SEP), - self.get_meta(), - self.get_initial_alias(), - ) - except FieldError: - raise FieldDoesNotExist("%s has no field named '%s'" % ( - self.get_meta().object_name, field_name - )) - self._check_field(field) # overridden in DateTimeQuery - alias = joins[-1] - select = self._get_select(Col(alias, field), lookup_type) - self.clear_select_clause() - self.select = [SelectInfo(select, None)] - self.distinct = True - self.order_by = [1] if order == 'ASC' else [-1] - - if field.null: - self.add_filter(("%s__isnull" % field_name, False)) - - def _check_field(self, field): - assert isinstance(field, DateField), \ - "%r isn't a DateField." % field.name - if settings.USE_TZ: - assert not isinstance(field, DateTimeField), \ - "%r is a DateTimeField, not a DateField." % field.name - - def _get_select(self, col, lookup_type): - return Date(col, lookup_type) - - -class DateTimeQuery(DateQuery): - """ - A DateTimeQuery is like a DateQuery but for a datetime field. If time zone - support is active, the tzinfo attribute contains the time zone to use for - converting the values before truncating them. Otherwise it's set to None. - """ - - compiler = 'SQLDateTimeCompiler' - - def clone(self, klass=None, memo=None, **kwargs): - if 'tzinfo' not in kwargs and hasattr(self, 'tzinfo'): - kwargs['tzinfo'] = self.tzinfo - return super(DateTimeQuery, self).clone(klass, memo, **kwargs) - - def _check_field(self, field): - assert isinstance(field, DateTimeField), \ - "%r isn't a DateTimeField." % field.name - - def _get_select(self, col, lookup_type): - if self.tzinfo is None: - tzname = None - else: - tzname = timezone._get_timezone_name(self.tzinfo) - return DateTime(col, lookup_type, tzname) - - class AggregateQuery(Query): """ An AggregateQuery takes another query as a parameter to the FROM diff --git a/tests/dates/tests.py b/tests/dates/tests.py index 2177ad046e..cc6a571d0d 100644 --- a/tests/dates/tests.py +++ b/tests/dates/tests.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals import datetime -from django.db.models.fields import FieldDoesNotExist +from django.core.exceptions import FieldError from django.test import TestCase from django.utils import six @@ -93,8 +93,9 @@ class DatesTests(TestCase): def test_dates_fails_when_given_invalid_field_argument(self): six.assertRaisesRegex( self, - FieldDoesNotExist, - "Article has no field named 'invalid_field'", + FieldError, + "Cannot resolve keyword u?'invalid_field' into field. Choices are: " + "categories, comments, id, pub_date, title", Article.objects.dates, "invalid_field", "year",