Fixed #23867 -- removed DateQuerySet hacks

The .dates() queries were implemented by using custom Query, QuerySet,
and Compiler classes. Instead implement them by using expressions and
database converters APIs.
This commit is contained in:
Anssi Kääriäinen 2014-11-18 11:24:33 +02:00 committed by Tim Graham
parent cc870b8ef5
commit cbb5cdd155
10 changed files with 111 additions and 253 deletions

View File

@ -22,11 +22,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
pass
class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
pass
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
pass

View File

@ -235,11 +235,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
pass
class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
pass
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
pass

View File

@ -23,11 +23,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass
class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
pass
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler):
pass

View File

@ -54,11 +54,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
pass
class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
pass
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler):
pass

View File

@ -1,11 +1,13 @@
import copy
import datetime
from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import refs_aggregate
from django.utils import timezone
from django.utils.functional import cached_property
@ -124,6 +126,9 @@ class ExpressionNode(CombinableMixin):
# aggregate specific fields
is_summary = False
def get_db_converters(self, connection):
return [self.convert_value]
def __init__(self, output_field=None):
self._output_field = output_field
@ -531,40 +536,95 @@ class Date(ExpressionNode):
"""
Add a date selection column.
"""
def __init__(self, col, lookup_type):
def __init__(self, lookup, lookup_type):
super(Date, self).__init__(output_field=fields.DateField())
self.col = col
self.lookup = lookup
self.col = None
self.lookup_type = lookup_type
def get_source_expressions(self):
return [self.col]
def set_source_expressions(self, exprs):
self.col, = self.exprs
def as_sql(self, compiler, connection):
sql, params = self.col.as_sql(compiler, connection)
assert not(params)
return connection.ops.date_trunc_sql(self.lookup_type, sql), []
class DateTime(ExpressionNode):
"""
Add a datetime selection column.
"""
def __init__(self, col, lookup_type, tzname):
super(DateTime, self).__init__(output_field=fields.DateTimeField())
self.col = col
self.lookup_type = lookup_type
self.tzname = tzname
def get_source_expressions(self):
return [self.col]
def set_source_expressions(self, exprs):
self.col, = exprs
def resolve_expression(self, query, allow_joins, reuse, summarize):
copy = self.copy()
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
field = copy.col.output_field
assert isinstance(field, fields.DateField), "%r isn't a DateField." % field.name
if settings.USE_TZ:
assert not isinstance(field, fields.DateTimeField), (
"%r is a DateTimeField, not a DateField." % field.name
)
return copy
def as_sql(self, compiler, connection):
sql, params = self.col.as_sql(compiler, connection)
assert not(params)
return connection.ops.date_trunc_sql(self.lookup_type, sql), []
def copy(self):
copy = super(Date, self).copy()
copy.lookup = self.lookup
copy.lookup_type = self.lookup_type
return copy
def convert_value(self, value, connection):
if isinstance(value, datetime.datetime):
value = value.date()
return value
class DateTime(ExpressionNode):
"""
Add a datetime selection column.
"""
def __init__(self, lookup, lookup_type, tzinfo):
super(DateTime, self).__init__(output_field=fields.DateTimeField())
self.lookup = lookup
self.col = None
self.lookup_type = lookup_type
if tzinfo is None:
self.tzname = None
else:
self.tzname = timezone._get_timezone_name(tzinfo)
self.tzinfo = tzinfo
def get_source_expressions(self):
return [self.col]
def set_source_expressions(self, exprs):
self.col, = exprs
def resolve_expression(self, query, allow_joins, reuse, summarize):
copy = self.copy()
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
field = copy.col.output_field
assert isinstance(field, fields.DateTimeField), (
"%r isn't a DateTimeField." % field.name
)
return copy
def as_sql(self, compiler, connection):
sql, params = self.col.as_sql(compiler, connection)
assert not(params)
return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)
def copy(self):
copy = super(DateTime, self).copy()
copy.lookup = self.lookup
copy.lookup_type = self.lookup_type
copy.tzname = self.tzname
return copy
def convert_value(self, value, connection):
if settings.USE_TZ:
if value is None:
raise ValueError(
"Database returned an invalid value in QuerySet.datetimes(). "
"Are time zone definitions for your database and pytz installed?"
)
value = value.replace(tzinfo=None)
value = timezone.make_aware(value, self.tzinfo)
return value

View File

@ -18,6 +18,7 @@ from django.db.models.query_utils import (Q, select_related_descend,
from django.db.models.deletion import Collector
from django.db.models.sql.constants import CURSOR
from django.db.models import sql
from django.db.models.expressions import Date, DateTime, F
from django.utils.functional import partition
from django.utils import six
from django.utils import timezone
@ -658,8 +659,12 @@ class QuerySet(object):
"'kind' must be one of 'year', 'month' or 'day'."
assert order in ('ASC', 'DESC'), \
"'order' must be either 'ASC' or 'DESC'."
return self._clone(klass=DateQuerySet, setup=True,
_field_name=field_name, _kind=kind, _order=order)
return self.annotate(
datefield=Date(field_name, kind),
plain_field=F(field_name)
).values_list(
'datefield', flat=True
).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datefield')
def datetimes(self, field_name, kind, order='ASC', tzinfo=None):
"""
@ -675,8 +680,12 @@ class QuerySet(object):
tzinfo = timezone.get_current_timezone()
else:
tzinfo = None
return self._clone(klass=DateTimeQuerySet, setup=True,
_field_name=field_name, _kind=kind, _order=order, _tzinfo=tzinfo)
return self.annotate(
datetimefield=DateTime(field_name, kind, tzinfo),
plain_field=F(field_name)
).values_list(
'datetimefield', flat=True
).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datetimefield')
def none(self):
"""
@ -1272,57 +1281,6 @@ class ValuesListQuerySet(ValuesQuerySet):
return clone
class DateQuerySet(QuerySet):
def iterator(self):
return self.query.get_compiler(self.db).results_iter()
def _setup_query(self):
"""
Sets up any special features of the query attribute.
Called by the _clone() method after initializing the rest of the
instance.
"""
self.query.clear_deferred_loading()
self.query = self.query.clone(klass=sql.DateQuery, setup=True)
self.query.select = []
self.query.add_select(self._field_name, self._kind, self._order)
def _clone(self, klass=None, setup=False, **kwargs):
c = super(DateQuerySet, self)._clone(klass, False, **kwargs)
c._field_name = self._field_name
c._kind = self._kind
if setup and hasattr(c, '_setup_query'):
c._setup_query()
return c
class DateTimeQuerySet(QuerySet):
def iterator(self):
return self.query.get_compiler(self.db).results_iter()
def _setup_query(self):
"""
Sets up any special features of the query attribute.
Called by the _clone() method after initializing the rest of the
instance.
"""
self.query.clear_deferred_loading()
self.query = self.query.clone(klass=sql.DateTimeQuery, setup=True, tzinfo=self._tzinfo)
self.query.select = []
self.query.add_select(self._field_name, self._kind, self._order)
def _clone(self, klass=None, setup=False, **kwargs):
c = super(DateTimeQuerySet, self)._clone(klass, False, **kwargs)
c._field_name = self._field_name
c._kind = self._kind
c._tzinfo = self._tzinfo
if setup and hasattr(c, '_setup_query'):
c._setup_query()
return c
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
only_load=None, from_parent=None):
"""

View File

@ -1,7 +1,5 @@
import datetime
import warnings
from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends.utils import truncate_name
from django.db.models.constants import LOOKUP_SEP
@ -13,7 +11,6 @@ from django.db.models.sql.query import get_order_dir, Query
from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseError
from django.utils import six
from django.utils import timezone
from django.utils.deprecation import RemovedInDjango20Warning
from django.utils.six.moves import zip
@ -698,10 +695,14 @@ class SQLCompiler(object):
index_extra_select = len(self.query.extra_select)
for i, field in enumerate(fields):
if field:
backend_converters = self.connection.ops.get_db_converters(field.get_internal_type())
try:
output_field = field.output_field
except AttributeError:
output_field = field
backend_converters = self.connection.ops.get_db_converters(output_field.get_internal_type())
field_converters = field.get_db_converters(self.connection)
if backend_converters or field_converters:
converters[index_extra_select + i] = (backend_converters, field_converters, field)
converters[index_extra_select + i] = (backend_converters, field_converters, output_field)
return converters
def apply_converters(self, row, converters):
@ -753,11 +754,8 @@ class SQLCompiler(object):
# annotations come before the related cols
if has_annotation_select:
# extra is always at the start of the field list
prepended_cols = len(self.query.extra_select)
annotation_start = len(fields) + prepended_cols
fields = fields + [
anno.output_field for alias, anno in self.query.annotation_select.items()]
annotation_end = len(fields) + prepended_cols
anno for alias, anno in self.query.annotation_select.items()]
# add related fields
fields = fields + [
@ -768,16 +766,6 @@ class SQLCompiler(object):
]
converters = self.get_converters(fields)
if has_annotation_select:
for (alias, annotation), position in zip(
self.query.annotation_select.items(),
range(annotation_start, annotation_end + 1)):
if position in converters:
# annotation conversions always run first
converters[position][1].insert(0, annotation.convert_value)
else:
converters[position] = ([], [annotation.convert_value], annotation.output_field)
if converters:
row = self.apply_converters(row, converters)
yield row
@ -1122,47 +1110,6 @@ class SQLAggregateCompiler(SQLCompiler):
return sql, params
class SQLDateCompiler(SQLCompiler):
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
from django.db.models.fields import DateField
converters = self.get_converters([DateField()])
offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI):
for row in rows:
date = self.apply_converters(row, converters)[offset]
if isinstance(date, datetime.datetime):
date = date.date()
yield date
class SQLDateTimeCompiler(SQLCompiler):
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
from django.db.models.fields import DateTimeField
converters = self.get_converters([DateTimeField()])
offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI):
for row in rows:
datetime = self.apply_converters(row, converters)[offset]
# Datetimes are artificially returned in UTC on databases that
# don't support time zone. Restore the zone used in the query.
if settings.USE_TZ:
if datetime is None:
raise ValueError("Database returned an invalid value "
"in QuerySet.datetimes(). Are time zone "
"definitions for your database and pytz installed?")
datetime = datetime.replace(tzinfo=None)
datetime = timezone.make_aware(datetime, self.query.tzinfo)
yield datetime
def cursor_iter(cursor, sentinel):
"""
Yields blocks of rows from a cursor and ensures the cursor is closed when

View File

@ -992,7 +992,8 @@ class Query(object):
"""
Adds a single annotation expression to the Query
"""
annotation = annotation.resolve_expression(self, summarize=is_summary)
annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None,
summarize=is_summary)
self.append_annotation_mask([alias])
self.annotations[alias] = annotation

View File

@ -2,21 +2,15 @@
Query subclasses which provide extra functionality beyond simple data retrieval.
"""
from django.conf import settings
from django.core.exceptions import FieldError
from django.db import connections
from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Date, DateTime, Col
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
from django.db.models.sql.query import Query
from django.utils import six
from django.utils import timezone
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
'DateTimeQuery', 'AggregateQuery']
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'AggregateQuery']
class DeleteQuery(Query):
@ -204,77 +198,6 @@ class InsertQuery(Query):
self.raw = raw
class DateQuery(Query):
"""
A DateQuery is a normal query, except that it specifically selects a single
date field. This requires some special handling when converting the results
back to Python objects, so we put it in a separate class.
"""
compiler = 'SQLDateCompiler'
def add_select(self, field_name, lookup_type, order='ASC'):
"""
Converts the query into an extraction query.
"""
try:
field, _, _, joins, _ = self.setup_joins(
field_name.split(LOOKUP_SEP),
self.get_meta(),
self.get_initial_alias(),
)
except FieldError:
raise FieldDoesNotExist("%s has no field named '%s'" % (
self.get_meta().object_name, field_name
))
self._check_field(field) # overridden in DateTimeQuery
alias = joins[-1]
select = self._get_select(Col(alias, field), lookup_type)
self.clear_select_clause()
self.select = [SelectInfo(select, None)]
self.distinct = True
self.order_by = [1] if order == 'ASC' else [-1]
if field.null:
self.add_filter(("%s__isnull" % field_name, False))
def _check_field(self, field):
assert isinstance(field, DateField), \
"%r isn't a DateField." % field.name
if settings.USE_TZ:
assert not isinstance(field, DateTimeField), \
"%r is a DateTimeField, not a DateField." % field.name
def _get_select(self, col, lookup_type):
return Date(col, lookup_type)
class DateTimeQuery(DateQuery):
"""
A DateTimeQuery is like a DateQuery but for a datetime field. If time zone
support is active, the tzinfo attribute contains the time zone to use for
converting the values before truncating them. Otherwise it's set to None.
"""
compiler = 'SQLDateTimeCompiler'
def clone(self, klass=None, memo=None, **kwargs):
if 'tzinfo' not in kwargs and hasattr(self, 'tzinfo'):
kwargs['tzinfo'] = self.tzinfo
return super(DateTimeQuery, self).clone(klass, memo, **kwargs)
def _check_field(self, field):
assert isinstance(field, DateTimeField), \
"%r isn't a DateTimeField." % field.name
def _get_select(self, col, lookup_type):
if self.tzinfo is None:
tzname = None
else:
tzname = timezone._get_timezone_name(self.tzinfo)
return DateTime(col, lookup_type, tzname)
class AggregateQuery(Query):
"""
An AggregateQuery takes another query as a parameter to the FROM

View File

@ -2,7 +2,7 @@ from __future__ import unicode_literals
import datetime
from django.db.models.fields import FieldDoesNotExist
from django.core.exceptions import FieldError
from django.test import TestCase
from django.utils import six
@ -93,8 +93,9 @@ class DatesTests(TestCase):
def test_dates_fails_when_given_invalid_field_argument(self):
six.assertRaisesRegex(
self,
FieldDoesNotExist,
"Article has no field named 'invalid_field'",
FieldError,
"Cannot resolve keyword u?'invalid_field' into field. Choices are: "
"categories, comments, id, pub_date, title",
Article.objects.dates,
"invalid_field",
"year",