Fixed #23867 -- removed DateQuerySet hacks
The .dates() queries were implemented by using custom Query, QuerySet, and Compiler classes. Instead implement them by using expressions and database converters APIs.
This commit is contained in:
parent
cc870b8ef5
commit
cbb5cdd155
|
@ -22,11 +22,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
|
|||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
|
|
@ -235,11 +235,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
|
|||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
|
|
@ -23,11 +23,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
|
|||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler):
|
||||
pass
|
||||
|
|
|
@ -54,11 +54,3 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
|
|||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, SQLCompiler):
|
||||
pass
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import copy
|
||||
import datetime
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.backends import utils as backend_utils
|
||||
from django.db.models import fields
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.query_utils import refs_aggregate
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
|
@ -124,6 +126,9 @@ class ExpressionNode(CombinableMixin):
|
|||
# aggregate specific fields
|
||||
is_summary = False
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return [self.convert_value]
|
||||
|
||||
def __init__(self, output_field=None):
|
||||
self._output_field = output_field
|
||||
|
||||
|
@ -531,40 +536,95 @@ class Date(ExpressionNode):
|
|||
"""
|
||||
Add a date selection column.
|
||||
"""
|
||||
def __init__(self, col, lookup_type):
|
||||
def __init__(self, lookup, lookup_type):
|
||||
super(Date, self).__init__(output_field=fields.DateField())
|
||||
self.col = col
|
||||
self.lookup = lookup
|
||||
self.col = None
|
||||
self.lookup_type = lookup_type
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.col]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.col, = self.exprs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = self.col.as_sql(compiler, connection)
|
||||
assert not(params)
|
||||
return connection.ops.date_trunc_sql(self.lookup_type, sql), []
|
||||
|
||||
|
||||
class DateTime(ExpressionNode):
|
||||
"""
|
||||
Add a datetime selection column.
|
||||
"""
|
||||
def __init__(self, col, lookup_type, tzname):
|
||||
super(DateTime, self).__init__(output_field=fields.DateTimeField())
|
||||
self.col = col
|
||||
self.lookup_type = lookup_type
|
||||
self.tzname = tzname
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.col]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.col, = exprs
|
||||
|
||||
def resolve_expression(self, query, allow_joins, reuse, summarize):
|
||||
copy = self.copy()
|
||||
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
|
||||
field = copy.col.output_field
|
||||
assert isinstance(field, fields.DateField), "%r isn't a DateField." % field.name
|
||||
if settings.USE_TZ:
|
||||
assert not isinstance(field, fields.DateTimeField), (
|
||||
"%r is a DateTimeField, not a DateField." % field.name
|
||||
)
|
||||
return copy
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = self.col.as_sql(compiler, connection)
|
||||
assert not(params)
|
||||
return connection.ops.date_trunc_sql(self.lookup_type, sql), []
|
||||
|
||||
def copy(self):
|
||||
copy = super(Date, self).copy()
|
||||
copy.lookup = self.lookup
|
||||
copy.lookup_type = self.lookup_type
|
||||
return copy
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
if isinstance(value, datetime.datetime):
|
||||
value = value.date()
|
||||
return value
|
||||
|
||||
|
||||
class DateTime(ExpressionNode):
|
||||
"""
|
||||
Add a datetime selection column.
|
||||
"""
|
||||
def __init__(self, lookup, lookup_type, tzinfo):
|
||||
super(DateTime, self).__init__(output_field=fields.DateTimeField())
|
||||
self.lookup = lookup
|
||||
self.col = None
|
||||
self.lookup_type = lookup_type
|
||||
if tzinfo is None:
|
||||
self.tzname = None
|
||||
else:
|
||||
self.tzname = timezone._get_timezone_name(tzinfo)
|
||||
self.tzinfo = tzinfo
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.col]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.col, = exprs
|
||||
|
||||
def resolve_expression(self, query, allow_joins, reuse, summarize):
|
||||
copy = self.copy()
|
||||
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
|
||||
field = copy.col.output_field
|
||||
assert isinstance(field, fields.DateTimeField), (
|
||||
"%r isn't a DateTimeField." % field.name
|
||||
)
|
||||
return copy
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
sql, params = self.col.as_sql(compiler, connection)
|
||||
assert not(params)
|
||||
return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)
|
||||
|
||||
def copy(self):
|
||||
copy = super(DateTime, self).copy()
|
||||
copy.lookup = self.lookup
|
||||
copy.lookup_type = self.lookup_type
|
||||
copy.tzname = self.tzname
|
||||
return copy
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
if settings.USE_TZ:
|
||||
if value is None:
|
||||
raise ValueError(
|
||||
"Database returned an invalid value in QuerySet.datetimes(). "
|
||||
"Are time zone definitions for your database and pytz installed?"
|
||||
)
|
||||
value = value.replace(tzinfo=None)
|
||||
value = timezone.make_aware(value, self.tzinfo)
|
||||
return value
|
||||
|
|
|
@ -18,6 +18,7 @@ from django.db.models.query_utils import (Q, select_related_descend,
|
|||
from django.db.models.deletion import Collector
|
||||
from django.db.models.sql.constants import CURSOR
|
||||
from django.db.models import sql
|
||||
from django.db.models.expressions import Date, DateTime, F
|
||||
from django.utils.functional import partition
|
||||
from django.utils import six
|
||||
from django.utils import timezone
|
||||
|
@ -658,8 +659,12 @@ class QuerySet(object):
|
|||
"'kind' must be one of 'year', 'month' or 'day'."
|
||||
assert order in ('ASC', 'DESC'), \
|
||||
"'order' must be either 'ASC' or 'DESC'."
|
||||
return self._clone(klass=DateQuerySet, setup=True,
|
||||
_field_name=field_name, _kind=kind, _order=order)
|
||||
return self.annotate(
|
||||
datefield=Date(field_name, kind),
|
||||
plain_field=F(field_name)
|
||||
).values_list(
|
||||
'datefield', flat=True
|
||||
).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datefield')
|
||||
|
||||
def datetimes(self, field_name, kind, order='ASC', tzinfo=None):
|
||||
"""
|
||||
|
@ -675,8 +680,12 @@ class QuerySet(object):
|
|||
tzinfo = timezone.get_current_timezone()
|
||||
else:
|
||||
tzinfo = None
|
||||
return self._clone(klass=DateTimeQuerySet, setup=True,
|
||||
_field_name=field_name, _kind=kind, _order=order, _tzinfo=tzinfo)
|
||||
return self.annotate(
|
||||
datetimefield=DateTime(field_name, kind, tzinfo),
|
||||
plain_field=F(field_name)
|
||||
).values_list(
|
||||
'datetimefield', flat=True
|
||||
).distinct().filter(plain_field__isnull=False).order_by(('-' if order == 'DESC' else '') + 'datetimefield')
|
||||
|
||||
def none(self):
|
||||
"""
|
||||
|
@ -1272,57 +1281,6 @@ class ValuesListQuerySet(ValuesQuerySet):
|
|||
return clone
|
||||
|
||||
|
||||
class DateQuerySet(QuerySet):
|
||||
def iterator(self):
|
||||
return self.query.get_compiler(self.db).results_iter()
|
||||
|
||||
def _setup_query(self):
|
||||
"""
|
||||
Sets up any special features of the query attribute.
|
||||
|
||||
Called by the _clone() method after initializing the rest of the
|
||||
instance.
|
||||
"""
|
||||
self.query.clear_deferred_loading()
|
||||
self.query = self.query.clone(klass=sql.DateQuery, setup=True)
|
||||
self.query.select = []
|
||||
self.query.add_select(self._field_name, self._kind, self._order)
|
||||
|
||||
def _clone(self, klass=None, setup=False, **kwargs):
|
||||
c = super(DateQuerySet, self)._clone(klass, False, **kwargs)
|
||||
c._field_name = self._field_name
|
||||
c._kind = self._kind
|
||||
if setup and hasattr(c, '_setup_query'):
|
||||
c._setup_query()
|
||||
return c
|
||||
|
||||
|
||||
class DateTimeQuerySet(QuerySet):
|
||||
def iterator(self):
|
||||
return self.query.get_compiler(self.db).results_iter()
|
||||
|
||||
def _setup_query(self):
|
||||
"""
|
||||
Sets up any special features of the query attribute.
|
||||
|
||||
Called by the _clone() method after initializing the rest of the
|
||||
instance.
|
||||
"""
|
||||
self.query.clear_deferred_loading()
|
||||
self.query = self.query.clone(klass=sql.DateTimeQuery, setup=True, tzinfo=self._tzinfo)
|
||||
self.query.select = []
|
||||
self.query.add_select(self._field_name, self._kind, self._order)
|
||||
|
||||
def _clone(self, klass=None, setup=False, **kwargs):
|
||||
c = super(DateTimeQuerySet, self)._clone(klass, False, **kwargs)
|
||||
c._field_name = self._field_name
|
||||
c._kind = self._kind
|
||||
c._tzinfo = self._tzinfo
|
||||
if setup and hasattr(c, '_setup_query'):
|
||||
c._setup_query()
|
||||
return c
|
||||
|
||||
|
||||
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
||||
only_load=None, from_parent=None):
|
||||
"""
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import datetime
|
||||
import warnings
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.backends.utils import truncate_name
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
|
@ -13,7 +11,6 @@ from django.db.models.sql.query import get_order_dir, Query
|
|||
from django.db.transaction import TransactionManagementError
|
||||
from django.db.utils import DatabaseError
|
||||
from django.utils import six
|
||||
from django.utils import timezone
|
||||
from django.utils.deprecation import RemovedInDjango20Warning
|
||||
from django.utils.six.moves import zip
|
||||
|
||||
|
@ -698,10 +695,14 @@ class SQLCompiler(object):
|
|||
index_extra_select = len(self.query.extra_select)
|
||||
for i, field in enumerate(fields):
|
||||
if field:
|
||||
backend_converters = self.connection.ops.get_db_converters(field.get_internal_type())
|
||||
try:
|
||||
output_field = field.output_field
|
||||
except AttributeError:
|
||||
output_field = field
|
||||
backend_converters = self.connection.ops.get_db_converters(output_field.get_internal_type())
|
||||
field_converters = field.get_db_converters(self.connection)
|
||||
if backend_converters or field_converters:
|
||||
converters[index_extra_select + i] = (backend_converters, field_converters, field)
|
||||
converters[index_extra_select + i] = (backend_converters, field_converters, output_field)
|
||||
return converters
|
||||
|
||||
def apply_converters(self, row, converters):
|
||||
|
@ -753,11 +754,8 @@ class SQLCompiler(object):
|
|||
# annotations come before the related cols
|
||||
if has_annotation_select:
|
||||
# extra is always at the start of the field list
|
||||
prepended_cols = len(self.query.extra_select)
|
||||
annotation_start = len(fields) + prepended_cols
|
||||
fields = fields + [
|
||||
anno.output_field for alias, anno in self.query.annotation_select.items()]
|
||||
annotation_end = len(fields) + prepended_cols
|
||||
anno for alias, anno in self.query.annotation_select.items()]
|
||||
|
||||
# add related fields
|
||||
fields = fields + [
|
||||
|
@ -768,16 +766,6 @@ class SQLCompiler(object):
|
|||
]
|
||||
|
||||
converters = self.get_converters(fields)
|
||||
if has_annotation_select:
|
||||
for (alias, annotation), position in zip(
|
||||
self.query.annotation_select.items(),
|
||||
range(annotation_start, annotation_end + 1)):
|
||||
if position in converters:
|
||||
# annotation conversions always run first
|
||||
converters[position][1].insert(0, annotation.convert_value)
|
||||
else:
|
||||
converters[position] = ([], [annotation.convert_value], annotation.output_field)
|
||||
|
||||
if converters:
|
||||
row = self.apply_converters(row, converters)
|
||||
yield row
|
||||
|
@ -1122,47 +1110,6 @@ class SQLAggregateCompiler(SQLCompiler):
|
|||
return sql, params
|
||||
|
||||
|
||||
class SQLDateCompiler(SQLCompiler):
|
||||
def results_iter(self):
|
||||
"""
|
||||
Returns an iterator over the results from executing this query.
|
||||
"""
|
||||
from django.db.models.fields import DateField
|
||||
converters = self.get_converters([DateField()])
|
||||
|
||||
offset = len(self.query.extra_select)
|
||||
for rows in self.execute_sql(MULTI):
|
||||
for row in rows:
|
||||
date = self.apply_converters(row, converters)[offset]
|
||||
if isinstance(date, datetime.datetime):
|
||||
date = date.date()
|
||||
yield date
|
||||
|
||||
|
||||
class SQLDateTimeCompiler(SQLCompiler):
|
||||
def results_iter(self):
|
||||
"""
|
||||
Returns an iterator over the results from executing this query.
|
||||
"""
|
||||
from django.db.models.fields import DateTimeField
|
||||
converters = self.get_converters([DateTimeField()])
|
||||
|
||||
offset = len(self.query.extra_select)
|
||||
for rows in self.execute_sql(MULTI):
|
||||
for row in rows:
|
||||
datetime = self.apply_converters(row, converters)[offset]
|
||||
# Datetimes are artificially returned in UTC on databases that
|
||||
# don't support time zone. Restore the zone used in the query.
|
||||
if settings.USE_TZ:
|
||||
if datetime is None:
|
||||
raise ValueError("Database returned an invalid value "
|
||||
"in QuerySet.datetimes(). Are time zone "
|
||||
"definitions for your database and pytz installed?")
|
||||
datetime = datetime.replace(tzinfo=None)
|
||||
datetime = timezone.make_aware(datetime, self.query.tzinfo)
|
||||
yield datetime
|
||||
|
||||
|
||||
def cursor_iter(cursor, sentinel):
|
||||
"""
|
||||
Yields blocks of rows from a cursor and ensures the cursor is closed when
|
||||
|
|
|
@ -992,7 +992,8 @@ class Query(object):
|
|||
"""
|
||||
Adds a single annotation expression to the Query
|
||||
"""
|
||||
annotation = annotation.resolve_expression(self, summarize=is_summary)
|
||||
annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None,
|
||||
summarize=is_summary)
|
||||
self.append_annotation_mask([alias])
|
||||
self.annotations[alias] = annotation
|
||||
|
||||
|
|
|
@ -2,21 +2,15 @@
|
|||
Query subclasses which provide extra functionality beyond simple data retrieval.
|
||||
"""
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import connections
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import Date, DateTime, Col
|
||||
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
|
||||
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
|
||||
from django.db.models.sql.query import Query
|
||||
from django.utils import six
|
||||
from django.utils import timezone
|
||||
|
||||
|
||||
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
|
||||
'DateTimeQuery', 'AggregateQuery']
|
||||
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'AggregateQuery']
|
||||
|
||||
|
||||
class DeleteQuery(Query):
|
||||
|
@ -204,77 +198,6 @@ class InsertQuery(Query):
|
|||
self.raw = raw
|
||||
|
||||
|
||||
class DateQuery(Query):
|
||||
"""
|
||||
A DateQuery is a normal query, except that it specifically selects a single
|
||||
date field. This requires some special handling when converting the results
|
||||
back to Python objects, so we put it in a separate class.
|
||||
"""
|
||||
|
||||
compiler = 'SQLDateCompiler'
|
||||
|
||||
def add_select(self, field_name, lookup_type, order='ASC'):
|
||||
"""
|
||||
Converts the query into an extraction query.
|
||||
"""
|
||||
try:
|
||||
field, _, _, joins, _ = self.setup_joins(
|
||||
field_name.split(LOOKUP_SEP),
|
||||
self.get_meta(),
|
||||
self.get_initial_alias(),
|
||||
)
|
||||
except FieldError:
|
||||
raise FieldDoesNotExist("%s has no field named '%s'" % (
|
||||
self.get_meta().object_name, field_name
|
||||
))
|
||||
self._check_field(field) # overridden in DateTimeQuery
|
||||
alias = joins[-1]
|
||||
select = self._get_select(Col(alias, field), lookup_type)
|
||||
self.clear_select_clause()
|
||||
self.select = [SelectInfo(select, None)]
|
||||
self.distinct = True
|
||||
self.order_by = [1] if order == 'ASC' else [-1]
|
||||
|
||||
if field.null:
|
||||
self.add_filter(("%s__isnull" % field_name, False))
|
||||
|
||||
def _check_field(self, field):
|
||||
assert isinstance(field, DateField), \
|
||||
"%r isn't a DateField." % field.name
|
||||
if settings.USE_TZ:
|
||||
assert not isinstance(field, DateTimeField), \
|
||||
"%r is a DateTimeField, not a DateField." % field.name
|
||||
|
||||
def _get_select(self, col, lookup_type):
|
||||
return Date(col, lookup_type)
|
||||
|
||||
|
||||
class DateTimeQuery(DateQuery):
|
||||
"""
|
||||
A DateTimeQuery is like a DateQuery but for a datetime field. If time zone
|
||||
support is active, the tzinfo attribute contains the time zone to use for
|
||||
converting the values before truncating them. Otherwise it's set to None.
|
||||
"""
|
||||
|
||||
compiler = 'SQLDateTimeCompiler'
|
||||
|
||||
def clone(self, klass=None, memo=None, **kwargs):
|
||||
if 'tzinfo' not in kwargs and hasattr(self, 'tzinfo'):
|
||||
kwargs['tzinfo'] = self.tzinfo
|
||||
return super(DateTimeQuery, self).clone(klass, memo, **kwargs)
|
||||
|
||||
def _check_field(self, field):
|
||||
assert isinstance(field, DateTimeField), \
|
||||
"%r isn't a DateTimeField." % field.name
|
||||
|
||||
def _get_select(self, col, lookup_type):
|
||||
if self.tzinfo is None:
|
||||
tzname = None
|
||||
else:
|
||||
tzname = timezone._get_timezone_name(self.tzinfo)
|
||||
return DateTime(col, lookup_type, tzname)
|
||||
|
||||
|
||||
class AggregateQuery(Query):
|
||||
"""
|
||||
An AggregateQuery takes another query as a parameter to the FROM
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import datetime
|
||||
|
||||
from django.db.models.fields import FieldDoesNotExist
|
||||
from django.core.exceptions import FieldError
|
||||
from django.test import TestCase
|
||||
from django.utils import six
|
||||
|
||||
|
@ -93,8 +93,9 @@ class DatesTests(TestCase):
|
|||
def test_dates_fails_when_given_invalid_field_argument(self):
|
||||
six.assertRaisesRegex(
|
||||
self,
|
||||
FieldDoesNotExist,
|
||||
"Article has no field named 'invalid_field'",
|
||||
FieldError,
|
||||
"Cannot resolve keyword u?'invalid_field' into field. Choices are: "
|
||||
"categories, comments, id, pub_date, title",
|
||||
Article.objects.dates,
|
||||
"invalid_field",
|
||||
"year",
|
||||
|
|
Loading…
Reference in New Issue