Fixed #9596 -- Added date transform for DateTimeField.

This commit is contained in:
Jon Dufresne 2015-03-07 13:20:29 -08:00 committed by Tim Graham
parent 076a63e672
commit 44f3ee7716
10 changed files with 155 additions and 50 deletions

View File

@ -99,6 +99,12 @@ class BaseDatabaseOperations(object):
""" """
return "%s" return "%s"
def datetime_cast_date_sql(self, field_name, tzname):
"""
Returns the SQL necessary to cast a datetime value to date value.
"""
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_cast_date() method')
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, field_name, tzname):
""" """
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute' or Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute' or

View File

@ -39,27 +39,26 @@ class DatabaseOperations(BaseDatabaseOperations):
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str) sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql return sql
def datetime_extract_sql(self, lookup_type, field_name, tzname): def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ: if settings.USE_TZ:
field_name = "CONVERT_TZ(%s, 'UTC', %%s)" % field_name field_name = "CONVERT_TZ(%s, 'UTC', %%s)" % field_name
params = [tzname] params = [tzname]
else: else:
params = [] params = []
# http://dev.mysql.com/doc/mysql/en/date-and-time-functions.html return field_name, params
if lookup_type == 'week_day':
# DAYOFWEEK() returns an integer, 1-7, Sunday=1. def datetime_cast_date_sql(self, field_name, tzname):
# Note: WEEKDAY() returns 0-6, Monday=0. field_name, params = self._convert_field_to_tz(field_name, tzname)
sql = "DAYOFWEEK(%s)" % field_name sql = "DATE(%s)" % field_name
else: return sql, params
sql = "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname)
sql = self.date_extract_sql(lookup_type, field_name)
return sql, params return sql, params
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, field_name, tzname):
if settings.USE_TZ: field_name, params = self._convert_field_to_tz(field_name, tzname)
field_name = "CONVERT_TZ(%s, 'UTC', %%s)" % field_name
params = [tzname]
else:
params = []
fields = ['year', 'month', 'day', 'hour', 'minute', 'second'] fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape. format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape.
format_def = ('0000-', '01', '-01', ' 00:', '00', ':00') format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')

View File

@ -114,6 +114,8 @@ WHEN (new.%(col_name)s IS NULL)
_tzname_re = re.compile(r'^[\w/:+-]+$') _tzname_re = re.compile(r'^[\w/:+-]+$')
def _convert_field_to_tz(self, field_name, tzname): def _convert_field_to_tz(self, field_name, tzname):
if not settings.USE_TZ:
return field_name
if not self._tzname_re.match(tzname): if not self._tzname_re.match(tzname):
raise ValueError("Invalid time zone name: %s" % tzname) raise ValueError("Invalid time zone name: %s" % tzname)
# Convert from UTC to local time, returning TIMESTAMP WITH TIME ZONE. # Convert from UTC to local time, returning TIMESTAMP WITH TIME ZONE.
@ -127,19 +129,17 @@ WHEN (new.%(col_name)s IS NULL)
# on DATE values, even though they actually store the time part. # on DATE values, even though they actually store the time part.
return "CAST(%s AS TIMESTAMP)" % result return "CAST(%s AS TIMESTAMP)" % result
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_cast_date_sql(self, field_name, tzname):
if settings.USE_TZ:
field_name = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
if lookup_type == 'week_day': sql = 'TRUNC(%s)' % field_name
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday. return sql, []
sql = "TO_CHAR(%s, 'D')" % field_name
else: def datetime_extract_sql(self, lookup_type, field_name, tzname):
# http://docs.oracle.com/cd/B19306_01/server.102/b14200/functions050.htm field_name = self._convert_field_to_tz(field_name, tzname)
sql = "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) sql = self.date_extract_sql(lookup_type, field_name)
return sql, [] return sql, []
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, field_name, tzname):
if settings.USE_TZ:
field_name = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
# http://docs.oracle.com/cd/B19306_01/server.102/b14200/functions230.htm#i1002084 # http://docs.oracle.com/cd/B19306_01/server.102/b14200/functions230.htm#i1002084
if lookup_type in ('year', 'month'): if lookup_type in ('year', 'month'):

View File

@ -32,26 +32,26 @@ class DatabaseOperations(BaseDatabaseOperations):
# http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC # http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
def datetime_extract_sql(self, lookup_type, field_name, tzname): def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ: if settings.USE_TZ:
field_name = "%s AT TIME ZONE %%s" % field_name field_name = "%s AT TIME ZONE %%s" % field_name
params = [tzname] params = [tzname]
else: else:
params = [] params = []
# http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT return field_name, params
if lookup_type == 'week_day':
# For consistency across backends, we return Sunday=1, Saturday=7. def datetime_cast_date_sql(self, field_name, tzname):
sql = "EXTRACT('dow' FROM %s) + 1" % field_name field_name, params = self._convert_field_to_tz(field_name, tzname)
else: sql = '(%s)::date' % field_name
sql = "EXTRACT('%s' FROM %s)" % (lookup_type, field_name) return sql, params
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname)
sql = self.date_extract_sql(lookup_type, field_name)
return sql, params return sql, params
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, field_name, tzname):
if settings.USE_TZ: field_name, params = self._convert_field_to_tz(field_name, tzname)
field_name = "%s AT TIME ZONE %%s" % field_name
params = [tzname]
else:
params = []
# http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC # http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
sql = "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) sql = "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
return sql, params return sql, params

View File

@ -207,6 +207,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn = Database.connect(**conn_params) conn = Database.connect(**conn_params)
conn.create_function("django_date_extract", 2, _sqlite_date_extract) conn.create_function("django_date_extract", 2, _sqlite_date_extract)
conn.create_function("django_date_trunc", 2, _sqlite_date_trunc) conn.create_function("django_date_trunc", 2, _sqlite_date_trunc)
conn.create_function("django_datetime_cast_date", 2, _sqlite_datetime_cast_date)
conn.create_function("django_datetime_extract", 3, _sqlite_datetime_extract) conn.create_function("django_datetime_extract", 3, _sqlite_datetime_extract)
conn.create_function("django_datetime_trunc", 3, _sqlite_datetime_trunc) conn.create_function("django_datetime_trunc", 3, _sqlite_datetime_trunc)
conn.create_function("regexp", 2, _sqlite_regexp) conn.create_function("regexp", 2, _sqlite_regexp)
@ -354,7 +355,7 @@ def _sqlite_date_trunc(lookup_type, dt):
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day) return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
def _sqlite_datetime_extract(lookup_type, dt, tzname): def _sqlite_datetime_parse(dt, tzname):
if dt is None: if dt is None:
return None return None
try: try:
@ -363,6 +364,20 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname):
return None return None
if tzname is not None: if tzname is not None:
dt = timezone.localtime(dt, pytz.timezone(tzname)) dt = timezone.localtime(dt, pytz.timezone(tzname))
return dt
def _sqlite_datetime_cast_date(dt, tzname):
dt = _sqlite_datetime_parse(dt, tzname)
if dt is None:
return None
return dt.date().isoformat()
def _sqlite_datetime_extract(lookup_type, dt, tzname):
dt = _sqlite_datetime_parse(dt, tzname)
if dt is None:
return None
if lookup_type == 'week_day': if lookup_type == 'week_day':
return (dt.isoweekday() % 7) + 1 return (dt.isoweekday() % 7) + 1
else: else:
@ -370,12 +385,9 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname):
def _sqlite_datetime_trunc(lookup_type, dt, tzname): def _sqlite_datetime_trunc(lookup_type, dt, tzname):
try: dt = _sqlite_datetime_parse(dt, tzname)
dt = backend_utils.typecast_timestamp(dt) if dt is None:
except (ValueError, TypeError):
return None return None
if tzname is not None:
dt = timezone.localtime(dt, pytz.timezone(tzname))
if lookup_type == 'year': if lookup_type == 'year':
return "%i-01-01 00:00:00" % dt.year return "%i-01-01 00:00:00" % dt.year
elif lookup_type == 'month': elif lookup_type == 'month':

View File

@ -68,21 +68,23 @@ class DatabaseOperations(BaseDatabaseOperations):
# cause a collision with a field name). # cause a collision with a field name).
return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name) return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name)
def _require_pytz(self):
if settings.USE_TZ and pytz is None:
raise ImproperlyConfigured("This query requires pytz, but it isn't installed.")
def datetime_cast_date_sql(self, field_name, tzname):
self._require_pytz()
return "django_datetime_cast_date(%s, %%s)" % field_name, [tzname]
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, field_name, tzname):
# Same comment as in date_extract_sql. # Same comment as in date_extract_sql.
if settings.USE_TZ: self._require_pytz()
if pytz is None:
raise ImproperlyConfigured("This query requires pytz, "
"but it isn't installed.")
return "django_datetime_extract('%s', %s, %%s)" % ( return "django_datetime_extract('%s', %s, %%s)" % (
lookup_type.lower(), field_name), [tzname] lookup_type.lower(), field_name), [tzname]
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, field_name, tzname):
# Same comment as in date_trunc_sql. # Same comment as in date_trunc_sql.
if settings.USE_TZ: self._require_pytz()
if pytz is None:
raise ImproperlyConfigured("This query requires pytz, "
"but it isn't installed.")
return "django_datetime_trunc('%s', %s, %%s)" % ( return "django_datetime_trunc('%s', %s, %%s)" % (
lookup_type.lower(), field_name), [tzname] lookup_type.lower(), field_name), [tzname]

View File

@ -1463,6 +1463,22 @@ class DateTimeField(DateField):
return super(DateTimeField, self).formfield(**defaults) return super(DateTimeField, self).formfield(**defaults)
@DateTimeField.register_lookup
class DateTimeDateTransform(Transform):
lookup_name = 'date'
@cached_property
def output_field(self):
return DateField()
def as_sql(self, compiler, connection):
lhs, lhs_params = compiler.compile(self.lhs)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname)
lhs_params.extend(tz_params)
return sql, lhs_params
class DecimalField(Field): class DecimalField(Field):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {

View File

@ -2463,6 +2463,27 @@ numbers and even characters.
Generally speaking, you can't mix dates and datetimes. Generally speaking, you can't mix dates and datetimes.
.. fieldlookup:: date
date
~~~~
.. versionadded:: 1.9
For datetime fields, casts the value as date. Allows chaining additional field
lookups. Takes a date value.
Example::
Entry.objects.filter(pub_date__date=datetime.date(2005, 1, 1))
Entry.objects.filter(pub_date__date__gt=datetime.date(2005, 1, 1))
(No equivalent SQL code fragment is included for this lookup because
implementation of the relevant query varies among different database engines.)
When :setting:`USE_TZ` is ``True``, fields are converted to the current time
zone before filtering.
.. fieldlookup:: year .. fieldlookup:: year
year year

View File

@ -233,6 +233,9 @@ Models
:class:`~django.db.models.Avg` aggregate in order to aggregate over :class:`~django.db.models.Avg` aggregate in order to aggregate over
non-numeric columns, such as ``DurationField``. non-numeric columns, such as ``DurationField``.
* Added the :lookup:`date` lookup to :class:`~django.db.models.DateTimeField`
to allow querying the field by only the date portion.
CSRF CSRF
^^^^ ^^^^
@ -346,6 +349,9 @@ Database backend API
``adapt_<type>field_value()`` to mirror the ``convert_<type>field_value()`` ``adapt_<type>field_value()`` to mirror the ``convert_<type>field_value()``
methods. methods.
* To use the new ``date`` lookup, third-party database backends may need to
implement the ``DatabaseOperations.datetime_cast_date_sql()`` method.
Default settings that were tuples are now lists Default settings that were tuples are now lists
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -18,7 +18,8 @@ from django.db.models.fields import (
TimeField, URLField, TimeField, URLField,
) )
from django.db.models.fields.files import FileField, ImageField from django.db.models.fields.files import FileField, ImageField
from django.utils import six from django.test.utils import requires_tz_support
from django.utils import six, timezone
from django.utils.functional import lazy from django.utils.functional import lazy
from .models import ( from .models import (
@ -274,6 +275,48 @@ class DateTimeFieldTests(test.TestCase):
self.assertEqual(obj.dt, datetim) self.assertEqual(obj.dt, datetim)
self.assertEqual(obj.t, tim) self.assertEqual(obj.t, tim)
@test.override_settings(USE_TZ=False)
def test_lookup_date_without_use_tz(self):
d = datetime.date(2014, 3, 12)
dt1 = datetime.datetime(2014, 3, 12, 21, 22, 23, 240000)
dt2 = datetime.datetime(2014, 3, 11, 21, 22, 23, 240000)
t = datetime.time(21, 22, 23, 240000)
m = DateTimeModel.objects.create(d=d, dt=dt1, t=t)
# Other model with different datetime.
DateTimeModel.objects.create(d=d, dt=dt2, t=t)
self.assertEqual(m, DateTimeModel.objects.get(dt__date=d))
@requires_tz_support
@test.skipUnlessDBFeature('has_zoneinfo_database')
@test.override_settings(USE_TZ=True, TIME_ZONE='America/Vancouver')
def test_lookup_date_with_use_tz(self):
d = datetime.date(2014, 3, 12)
# The following is equivalent to UTC 2014-03-12 18:34:23.24000.
dt1 = datetime.datetime(
2014, 3, 12, 10, 22, 23, 240000,
tzinfo=timezone.get_current_timezone()
)
# The following is equivalent to UTC 2014-03-13 05:34:23.24000.
dt2 = datetime.datetime(
2014, 3, 12, 21, 22, 23, 240000,
tzinfo=timezone.get_current_timezone()
)
t = datetime.time(21, 22, 23, 240000)
m1 = DateTimeModel.objects.create(d=d, dt=dt1, t=t)
m2 = DateTimeModel.objects.create(d=d, dt=dt2, t=t)
# In Vancouver, we expect both results.
self.assertQuerysetEqual(
DateTimeModel.objects.filter(dt__date=d),
[repr(m1), repr(m2)],
ordered=False
)
with self.settings(TIME_ZONE='UTC'):
# But in UTC, the __date only matches one of them.
self.assertQuerysetEqual(
DateTimeModel.objects.filter(dt__date=d),
[repr(m1)]
)
class BooleanFieldTests(test.TestCase): class BooleanFieldTests(test.TestCase):
def _test_get_db_prep_lookup(self, f): def _test_get_db_prep_lookup(self, f):