Added datetime-handling infrastructure in the ORM layers.

This commit is contained in:
Aymeric Augustin 2013-02-10 21:41:08 +01:00
parent 104d82a777
commit 0c829c23f4
10 changed files with 249 additions and 62 deletions

View File

@ -1,14 +1,16 @@
import datetime
try:
from itertools import zip_longest
except ImportError:
from itertools import izip_longest as zip_longest
from django.utils.six.moves import zip
from django.db.backends.util import truncate_name, typecast_timestamp
from django.conf import settings
from django.db.backends.util import truncate_name, typecast_date, typecast_timestamp
from django.db.models.sql import compiler
from django.db.models.sql.constants import MULTI
from django.utils import six
from django.utils.six.moves import zip
from django.utils import timezone
SQLCompiler = compiler.SQLCompiler
@ -280,5 +282,35 @@ class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
if self.connection.ops.oracle:
date = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
date = typecast_timestamp(str(date))
date = typecast_date(str(date))
if isinstance(date, datetime.datetime):
date = date.date()
yield date
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
"""
This is overridden for GeoDjango to properly cast date columns, since
`GeoQuery.resolve_columns` is used for spatial values.
See #14648, #16757.
"""
def results_iter(self):
if self.connection.ops.oracle:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
else:
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI):
for row in rows:
datetime = row[offset]
if self.connection.ops.oracle:
datetime = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
datetime = typecast_timestamp(str(datetime))
# Datetimes are artifically returned in UTC on databases that
# don't support time zone. Restore the zone used in the query.
if settings.USE_TZ:
datetime = datetime.replace(tzinfo=None)
datetime = timezone.make_aware(datetime, self.query.tzinfo)
yield datetime

View File

@ -1,3 +1,5 @@
import datetime
from django.db.utils import DatabaseError
try:
@ -14,7 +16,7 @@ from django.db.transaction import TransactionManagementError
from django.utils.functional import cached_property
from django.utils.importlib import import_module
from django.utils import six
from django.utils.timezone import is_aware
from django.utils import timezone
class BaseDatabaseWrapper(object):
@ -526,7 +528,7 @@ class BaseDatabaseOperations(object):
def date_trunc_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'year', 'month' or 'day', returns the SQL that
truncates the given date field field_name to a DATE object with only
truncates the given date field field_name to a date object with only
the given specificity.
"""
raise NotImplementedError()
@ -540,6 +542,28 @@ class BaseDatabaseOperations(object):
"""
return "%s"
def datetime_extract_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute' or
'second', returns the SQL that extracts a value from the given
datetime field field_name.
When time zone support is enabled, the SQL should include a '%s'
placeholder for the time zone's name.
"""
raise NotImplementedError()
def datetime_trunc_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute' or
'second', returns the SQL that truncates the given datetime field
field_name to a datetime object with only the given specificity.
When time zone support is enabled, the SQL should include a '%s'
placeholder for the time zone's name.
"""
raise NotImplementedError()
def deferrable_sql(self):
"""
Returns the SQL necessary to make a constraint "initially deferred"
@ -856,7 +880,7 @@ class BaseDatabaseOperations(object):
"""
if value is None:
return None
if is_aware(value):
if timezone.is_aware(value):
raise ValueError("Django does not support timezone-aware times.")
return six.text_type(value)
@ -869,29 +893,33 @@ class BaseDatabaseOperations(object):
return None
return util.format_number(value, max_digits, decimal_places)
def year_lookup_bounds(self, value):
"""
Returns a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a field value using a year lookup
`value` is an int, containing the looked-up year.
"""
first = '%s-01-01 00:00:00'
second = '%s-12-31 23:59:59.999999'
return [first % value, second % value]
def year_lookup_bounds_for_date_field(self, value):
"""
Returns a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a DateField value using a year lookup
with a BETWEEN operator to query a DateField value using a year
lookup.
`value` is an int, containing the looked-up year.
By default, it just calls `self.year_lookup_bounds`. Some backends need
this hook because on their DB date fields can't be compared to values
which include a time part.
"""
return self.year_lookup_bounds(value)
first = datetime.date(value, 1, 1)
second = datetime.date(value, 12, 31)
return [first, second]
def year_lookup_bounds_for_datetime_field(self, value):
"""
Returns a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a DateTimeField value using a year
lookup.
`value` is an int, containing the looked-up year.
"""
first = datetime.datetime(value, 1, 1)
second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
if settings.USE_TZ:
tz = timezone.get_current_timezone()
first = timezone.make_aware(first, tz)
second = timezone.make_aware(second, tz)
return [first, second]
def convert_values(self, value, field):
"""

View File

@ -312,9 +312,10 @@ class Field(object):
return value._prepare()
if lookup_type in (
'regex', 'iregex', 'month', 'day', 'week_day', 'search',
'contains', 'icontains', 'iexact', 'startswith', 'istartswith',
'endswith', 'iendswith', 'isnull'
'iexact', 'contains', 'icontains',
'startswith', 'istartswith', 'endswith', 'iendswith',
'month', 'day', 'week_day', 'hour', 'minute', 'second',
'isnull', 'search', 'regex', 'iregex',
):
return value
elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
@ -350,8 +351,8 @@ class Field(object):
sql, params = value._as_sql(connection=connection)
return QueryWrapper(('(%s)' % sql), params)
if lookup_type in ('regex', 'iregex', 'month', 'day', 'week_day',
'search'):
if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute',
'second', 'search', 'regex', 'iregex'):
return [value]
elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
return [self.get_db_prep_value(value, connection=connection,
@ -370,10 +371,12 @@ class Field(object):
elif lookup_type == 'isnull':
return []
elif lookup_type == 'year':
if self.get_internal_type() == 'DateField':
if isinstance(self, DateTimeField):
return connection.ops.year_lookup_bounds_for_datetime_field(value)
elif isinstance(self, DateField):
return connection.ops.year_lookup_bounds_for_date_field(value)
else:
return connection.ops.year_lookup_bounds(value)
return [value] # this isn't supposed to happen
def has_default(self):
"""
@ -722,9 +725,9 @@ class DateField(Field):
is_next=False))
def get_prep_lookup(self, lookup_type, value):
# For "__month", "__day", and "__week_day" lookups, convert the value
# to an int so the database backend always sees a consistent type.
if lookup_type in ('month', 'day', 'week_day'):
# For dates lookups, convert the value to an int
# so the database backend always sees a consistent type.
if lookup_type in ('month', 'day', 'week_day', 'hour', 'minute', 'second'):
return int(value)
return super(DateField, self).get_prep_lookup(lookup_type, value)

View File

@ -130,6 +130,9 @@ class Manager(object):
def dates(self, *args, **kwargs):
return self.get_query_set().dates(*args, **kwargs)
def datetimes(self, *args, **kwargs):
return self.get_query_set().datetimes(*args, **kwargs)
def distinct(self, *args, **kwargs):
return self.get_query_set().distinct(*args, **kwargs)

View File

@ -7,6 +7,7 @@ import itertools
import sys
import warnings
from django.conf import settings
from django.core import exceptions
from django.db import connections, router, transaction, IntegrityError
from django.db.models.constants import LOOKUP_SEP
@ -17,6 +18,7 @@ from django.db.models.deletion import Collector
from django.db.models import sql
from django.utils.functional import partition
from django.utils import six
from django.utils import timezone
# Used to control how many objects are worked with at once in some cases (e.g.
# when deleting objects).
@ -629,16 +631,33 @@ class QuerySet(object):
def dates(self, field_name, kind, order='ASC'):
"""
Returns a list of datetime objects representing all available dates for
Returns a list of date objects representing all available dates for
the given field_name, scoped to 'kind'.
"""
assert kind in ("month", "year", "day"), \
assert kind in ("year", "month", "day"), \
"'kind' must be one of 'year', 'month' or 'day'."
assert order in ('ASC', 'DESC'), \
"'order' must be either 'ASC' or 'DESC'."
return self._clone(klass=DateQuerySet, setup=True,
_field_name=field_name, _kind=kind, _order=order)
def datetimes(self, field_name, kind, order='ASC', tzinfo=None):
"""
Returns a list of datetime objects representing all available
datetimes for the given field_name, scoped to 'kind'.
"""
assert kind in ("year", "month", "day", "hour", "minute", "second"), \
"'kind' must be one of 'year', 'month', 'day', 'hour', 'minute' or 'second'."
assert order in ('ASC', 'DESC'), \
"'order' must be either 'ASC' or 'DESC'."
if settings.USE_TZ:
if tzinfo is None:
tzinfo = timezone.get_current_timezone()
else:
tzinfo = None
return self._clone(klass=DateTimeQuerySet, setup=True,
_field_name=field_name, _kind=kind, _order=order, _tzinfo=tzinfo)
def none(self):
"""
Returns an empty QuerySet.
@ -1187,7 +1206,7 @@ class DateQuerySet(QuerySet):
self.query.clear_deferred_loading()
self.query = self.query.clone(klass=sql.DateQuery, setup=True)
self.query.select = []
self.query.add_date_select(self._field_name, self._kind, self._order)
self.query.add_select(self._field_name, self._kind, self._order)
def _clone(self, klass=None, setup=False, **kwargs):
c = super(DateQuerySet, self)._clone(klass, False, **kwargs)
@ -1198,6 +1217,32 @@ class DateQuerySet(QuerySet):
return c
class DateTimeQuerySet(QuerySet):
def iterator(self):
return self.query.get_compiler(self.db).results_iter()
def _setup_query(self):
"""
Sets up any special features of the query attribute.
Called by the _clone() method after initializing the rest of the
instance.
"""
self.query.clear_deferred_loading()
self.query = self.query.clone(klass=sql.DateTimeQuery, setup=True, tzinfo=self._tzinfo)
self.query.select = []
self.query.add_select(self._field_name, self._kind, self._order)
def _clone(self, klass=None, setup=False, **kwargs):
c = super(DateTimeQuerySet, self)._clone(klass, False, **kwargs)
c._field_name = self._field_name
c._kind = self._kind
c._tzinfo = self._tzinfo
if setup and hasattr(c, '_setup_query'):
c._setup_query()
return c
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
only_load=None, from_parent=None):
"""

View File

@ -1,5 +1,6 @@
from django.utils.six.moves import zip
import datetime
from django.conf import settings
from django.core.exceptions import FieldError
from django.db import transaction
from django.db.backends.util import truncate_name
@ -12,6 +13,8 @@ from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query
from django.db.utils import DatabaseError
from django.utils import six
from django.utils.six.moves import zip
from django.utils import timezone
class SQLCompiler(object):
@ -1005,10 +1008,10 @@ class SQLDateCompiler(SQLCompiler):
"""
resolve_columns = hasattr(self, 'resolve_columns')
if resolve_columns:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
from django.db.models.fields import DateField
fields = [DateField()]
else:
from django.db.backends.util import typecast_timestamp
from django.db.backends.util import typecast_date
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.query.extra_select)
@ -1018,9 +1021,45 @@ class SQLDateCompiler(SQLCompiler):
if resolve_columns:
date = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
date = typecast_timestamp(str(date))
date = typecast_date(str(date))
if isinstance(date, datetime.datetime):
date = date.date()
yield date
class SQLDateTimeCompiler(SQLCompiler):
def as_sql(self):
sql, params = super(SQLDateTimeCompiler, self).as_sql()
if settings.USE_TZ:
tzname = timezone._get_timezone_name(self.query.tzinfo)
params = (tzname,) + params
return sql, params
def results_iter(self):
"""
Returns an iterator over the results from executing this query.
"""
resolve_columns = hasattr(self, 'resolve_columns')
if resolve_columns:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
else:
from django.db.backends.util import typecast_timestamp
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI):
for row in rows:
datetime = row[offset]
if resolve_columns:
datetime = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
datetime = typecast_timestamp(str(datetime))
# Datetimes are artifically returned in UTC on databases that
# don't support time zone. Restore the zone used in the query.
if settings.USE_TZ:
datetime = datetime.replace(tzinfo=None)
datetime = timezone.make_aware(datetime, self.query.tzinfo)
yield datetime
def order_modified_iter(cursor, trim, sentinel):
"""

View File

@ -11,7 +11,8 @@ import re
QUERY_TERMS = set([
'exact', 'iexact', 'contains', 'icontains', 'gt', 'gte', 'lt', 'lte', 'in',
'startswith', 'istartswith', 'endswith', 'iendswith', 'range', 'year',
'month', 'day', 'week_day', 'isnull', 'search', 'regex', 'iregex',
'month', 'day', 'week_day', 'hour', 'minute', 'second', 'isnull', 'search',
'regex', 'iregex',
])
# Size of each "chunk" for get_iterator calls.

View File

@ -26,6 +26,8 @@ class Date(object):
"""
Add a date selection column.
"""
trunc_func = 'date_trunc_sql'
def __init__(self, col, lookup_type):
self.col = col
self.lookup_type = lookup_type
@ -40,4 +42,10 @@ class Date(object):
col = '%s.%s' % tuple([qn(c) for c in self.col])
else:
col = self.col
return connection.ops.date_trunc_sql(self.lookup_type, col)
return getattr(connection.ops, self.trunc_func)(self.lookup_type, col)
class DateTime(Date):
"""
Add a datetime selection column.
"""
trunc_func = 'datetime_trunc_sql'

View File

@ -2,22 +2,22 @@
Query subclasses which provide extra functionality beyond simple data retrieval.
"""
from django.conf import settings
from django.core.exceptions import FieldError
from django.db import connections
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import DateField, FieldDoesNotExist
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
from django.db.models.sql.constants import *
from django.db.models.sql.datastructures import Date
from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, Constraint
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
from django.utils.encoding import force_text
from django.utils import six
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
'AggregateQuery']
'DateTimeQuery', 'AggregateQuery']
class DeleteQuery(Query):
"""
@ -222,10 +222,11 @@ class DateQuery(Query):
"""
compiler = 'SQLDateCompiler'
select_type = Date
def add_date_select(self, field_name, lookup_type, order='ASC'):
def add_select(self, field_name, lookup_type, order='ASC'):
"""
Converts the query into a date extraction query.
Converts the query into an extraction query.
"""
try:
result = self.setup_joins(
@ -238,10 +239,9 @@ class DateQuery(Query):
self.model._meta.object_name, field_name
))
field = result[0]
assert isinstance(field, DateField), "%r isn't a DateField." \
% field.name
self._check_field(field) # overridden in DateTimeQuery
alias = result[3][-1]
select = Date((alias, field.column), lookup_type)
select = self.select_type((alias, field.column), lookup_type)
self.clear_select_clause()
self.select = [SelectInfo(select, None)]
self.distinct = True
@ -250,6 +250,27 @@ class DateQuery(Query):
if field.null:
self.add_filter(("%s__isnull" % field_name, False))
def _check_field(self, field):
assert isinstance(field, DateField), \
"%r isn't a DateField." % field.name
if settings.USE_TZ:
assert not isinstance(field, DateTimeField), \
"%r is a DateTimeField, not a DateField." % field.name
class DateTimeQuery(DateQuery):
"""
A DateTimeQuery is like a DateQuery but for a datetime field. If time zone
support is active, the tzinfo attribute contains the time zone to use for
converting the values before truncating them. Otherwise it's set to None.
"""
compiler = 'SQLDateTimeCompiler'
select_type = DateTime
def _check_field(self, field):
assert isinstance(field, DateTimeField), \
"%r isn't a DateTimeField." % field.name
class AggregateQuery(Query):
"""
An AggregateQuery takes another query as a parameter to the FROM

View File

@ -8,11 +8,13 @@ import collections
import datetime
from itertools import repeat
from django.utils import tree
from django.db.models.fields import Field
from django.conf import settings
from django.db.models.fields import DateTimeField, Field
from django.db.models.sql.datastructures import EmptyResultSet, Empty
from django.db.models.sql.aggregates import Aggregate
from django.utils.six.moves import xrange
from django.utils import timezone
from django.utils import tree
# Connection types
AND = 'AND'
@ -60,7 +62,8 @@ class WhereNode(tree.Node):
# about the value(s) to the query construction. Specifically, datetime
# and empty values need special handling. Other types could be used
# here in the future (using Python types is suggested for consistency).
if isinstance(value, datetime.datetime):
if (isinstance(value, datetime.datetime)
or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')):
value_annotation = datetime.datetime
elif hasattr(value, 'value_annotation'):
value_annotation = value.value_annotation
@ -174,10 +177,8 @@ class WhereNode(tree.Node):
# A smart object with an as_sql() method.
field_sql = lvalue.as_sql(qn, connection)
if value_annotation is datetime.datetime:
cast_sql = connection.ops.datetime_cast_sql()
else:
cast_sql = '%s'
is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
if hasattr(params, 'as_sql'):
extra, params = params.as_sql(qn, connection)
@ -221,9 +222,15 @@ class WhereNode(tree.Node):
params)
elif lookup_type in ('range', 'year'):
return ('%s BETWEEN %%s and %%s' % field_sql, params)
elif is_datetime_field and lookup_type in ('month', 'day', 'week_day',
'hour', 'minute', 'second'):
if settings.USE_TZ:
params = [timezone.get_current_timezone_name()] + params
return ('%s = %%s'
% connection.ops.datetime_extract_sql(lookup_type, field_sql), params)
elif lookup_type in ('month', 'day', 'week_day'):
return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql),
params)
return ('%s = %%s'
% connection.ops.date_extract_sql(lookup_type, field_sql), params)
elif lookup_type == 'isnull':
assert value_annotation in (True, False), "Invalid value_annotation for isnull"
return ('%s IS %sNULL' % (field_sql, ('' if value_annotation else 'NOT ')), ())