Fixed #31640 -- Made Trunc() truncate datetimes to Date/TimeField in a specific timezone.

This commit is contained in:
David-Wobrock 2020-10-04 19:28:21 +02:00 committed by Mariusz Felisiak
parent 8d018231ac
commit ee005328c8
9 changed files with 145 additions and 32 deletions

View File

@ -99,11 +99,14 @@ class BaseDatabaseOperations:
""" """
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method') raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method')
def date_trunc_sql(self, lookup_type, field_name): def date_trunc_sql(self, lookup_type, field_name, tzname=None):
""" """
Given a lookup_type of 'year', 'month', or 'day', return the SQL that Given a lookup_type of 'year', 'month', or 'day', return the SQL that
truncates the given date field field_name to a date object with only truncates the given date or datetime field field_name to a date object
the given specificity. with only the given specificity.
If `tzname` is provided, the given value is truncated in a specific
timezone.
""" """
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.') raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.')
@ -138,11 +141,14 @@ class BaseDatabaseOperations:
""" """
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method') raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method')
def time_trunc_sql(self, lookup_type, field_name): def time_trunc_sql(self, lookup_type, field_name, tzname=None):
""" """
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
that truncates the given time field field_name to a time object with that truncates the given time or datetime field field_name to a time
only the given specificity. object with only the given specificity.
If `tzname` is provided, the given value is truncated in a specific
timezone.
""" """
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method') raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method')

View File

@ -55,7 +55,8 @@ class DatabaseOperations(BaseDatabaseOperations):
# EXTRACT returns 1-53 based on ISO-8601 for the week number. # EXTRACT returns 1-53 based on ISO-8601 for the week number.
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
def date_trunc_sql(self, lookup_type, field_name): def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = { fields = {
'year': '%%Y-01-01', 'year': '%%Y-01-01',
'month': '%%Y-%%m-01', 'month': '%%Y-%%m-01',
@ -82,7 +83,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return tzname return tzname
def _convert_field_to_tz(self, field_name, tzname): def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ and self.connection.timezone_name != tzname: if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
field_name = "CONVERT_TZ(%s, '%s', '%s')" % ( field_name = "CONVERT_TZ(%s, '%s', '%s')" % (
field_name, field_name,
self.connection.timezone_name, self.connection.timezone_name,
@ -128,7 +129,8 @@ class DatabaseOperations(BaseDatabaseOperations):
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str) sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql return sql
def time_trunc_sql(self, lookup_type, field_name): def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = { fields = {
'hour': '%%H:00:00', 'hour': '%%H:00:00',
'minute': '%%H:%%i:00', 'minute': '%%H:%%i:00',

View File

@ -89,7 +89,8 @@ END;
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html # https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
def date_trunc_sql(self, lookup_type, field_name): def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html # https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ('year', 'month'): if lookup_type in ('year', 'month'):
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper()) return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
@ -114,7 +115,7 @@ END;
return tzname return tzname
def _convert_field_to_tz(self, field_name, tzname): def _convert_field_to_tz(self, field_name, tzname):
if not settings.USE_TZ: if not (settings.USE_TZ and tzname):
return field_name return field_name
if not self._tzname_re.match(tzname): if not self._tzname_re.match(tzname):
raise ValueError("Invalid time zone name: %s" % tzname) raise ValueError("Invalid time zone name: %s" % tzname)
@ -161,10 +162,11 @@ END;
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision. sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql return sql
def time_trunc_sql(self, lookup_type, field_name): def time_trunc_sql(self, lookup_type, field_name, tzname=None):
# The implementation is similar to `datetime_trunc_sql` as both # The implementation is similar to `datetime_trunc_sql` as both
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where # `DateTimeField` and `TimeField` are stored as TIMESTAMP where
# the date part of the later is ignored. # the date part of the later is ignored.
field_name = self._convert_field_to_tz(field_name, tzname)
if lookup_type == 'hour': if lookup_type == 'hour':
sql = "TRUNC(%s, 'HH24')" % field_name sql = "TRUNC(%s, 'HH24')" % field_name
elif lookup_type == 'minute': elif lookup_type == 'minute':

View File

@ -38,7 +38,8 @@ class DatabaseOperations(BaseDatabaseOperations):
else: else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name) return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
def date_trunc_sql(self, lookup_type, field_name): def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
@ -50,7 +51,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return tzname return tzname
def _convert_field_to_tz(self, field_name, tzname): def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ: if tzname and settings.USE_TZ:
field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname)) field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
return field_name return field_name
@ -71,7 +72,8 @@ class DatabaseOperations(BaseDatabaseOperations):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
def time_trunc_sql(self, lookup_type, field_name): def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name) return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)
def deferrable_sql(self): def deferrable_sql(self):

View File

@ -213,13 +213,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else: else:
create_deterministic_function = conn.create_function create_deterministic_function = conn.create_function
create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract) create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract)
create_deterministic_function('django_date_trunc', 2, _sqlite_date_trunc) create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc)
create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date) create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time) create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract) create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract)
create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc) create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
create_deterministic_function('django_time_extract', 2, _sqlite_time_extract) create_deterministic_function('django_time_extract', 2, _sqlite_time_extract)
create_deterministic_function('django_time_trunc', 2, _sqlite_time_trunc) create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc)
create_deterministic_function('django_time_diff', 2, _sqlite_time_diff) create_deterministic_function('django_time_diff', 2, _sqlite_time_diff)
create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff) create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff)
create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta) create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta)
@ -445,8 +445,8 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
return dt return dt
def _sqlite_date_trunc(lookup_type, dt): def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt) dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None: if dt is None:
return None return None
if lookup_type == 'year': if lookup_type == 'year':
@ -463,13 +463,17 @@ def _sqlite_date_trunc(lookup_type, dt):
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day) return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
def _sqlite_time_trunc(lookup_type, dt): def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
if dt is None: if dt is None:
return None return None
try: dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
dt = backend_utils.typecast_time(dt) if dt_parsed is None:
except (ValueError, TypeError): try:
return None dt = backend_utils.typecast_time(dt)
except (ValueError, TypeError):
return None
else:
dt = dt_parsed
if lookup_type == 'hour': if lookup_type == 'hour':
return "%02i:00:00" % dt.hour return "%02i:00:00" % dt.hour
elif lookup_type == 'minute': elif lookup_type == 'minute':

View File

@ -77,14 +77,22 @@ class DatabaseOperations(BaseDatabaseOperations):
"""Do nothing since formatting is handled in the custom function.""" """Do nothing since formatting is handled in the custom function."""
return sql return sql
def date_trunc_sql(self, lookup_type, field_name): def date_trunc_sql(self, lookup_type, field_name, tzname=None):
return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name) return "django_date_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
)
def time_trunc_sql(self, lookup_type, field_name): def time_trunc_sql(self, lookup_type, field_name, tzname=None):
return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name) return "django_time_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
)
def _convert_tznames_to_sql(self, tzname): def _convert_tznames_to_sql(self, tzname):
if settings.USE_TZ: if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
return 'NULL', 'NULL' return 'NULL', 'NULL'

View File

@ -193,13 +193,17 @@ class TruncBase(TimezoneMixin, Transform):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
inner_sql, inner_params = compiler.compile(self.lhs) inner_sql, inner_params = compiler.compile(self.lhs)
if isinstance(self.output_field, DateTimeField): tzname = None
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname() tzname = self.get_tzname()
elif self.tzinfo is not None:
raise ValueError('tzinfo can only be used with DateTimeField.')
if isinstance(self.output_field, DateTimeField):
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname) sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, DateField): elif isinstance(self.output_field, DateField):
sql = connection.ops.date_trunc_sql(self.kind, inner_sql) sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, TimeField): elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql) sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
else: else:
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.') raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
return sql, inner_params return sql, inner_params

View File

@ -458,6 +458,10 @@ backends.
* ``DatabaseOperations.random_function_sql()`` is removed in favor of the new * ``DatabaseOperations.random_function_sql()`` is removed in favor of the new
:class:`~django.db.models.functions.Random` database function. :class:`~django.db.models.functions.Random` database function.
* ``DatabaseOperations.date_trunc_sql()`` and
``DatabaseOperations.time_trunc_sql()`` now take the optional ``tzname``
argument in order to truncate in a specific timezone.
:mod:`django.contrib.admin` :mod:`django.contrib.admin`
--------------------------- ---------------------------

View File

@ -672,6 +672,18 @@ class DateFunctionTests(TestCase):
lambda m: (m.start_datetime, m.truncated) lambda m: (m.start_datetime, m.truncated)
) )
def test_datetime_to_time_kind(kind):
self.assertQuerysetEqual(
DTModel.objects.annotate(
truncated=Trunc('start_datetime', kind, output_field=TimeField()),
).order_by('start_datetime'),
[
(start_datetime, truncate_to(start_datetime.time(), kind)),
(end_datetime, truncate_to(end_datetime.time(), kind)),
],
lambda m: (m.start_datetime, m.truncated),
)
test_date_kind('year') test_date_kind('year')
test_date_kind('quarter') test_date_kind('quarter')
test_date_kind('month') test_date_kind('month')
@ -688,6 +700,9 @@ class DateFunctionTests(TestCase):
test_datetime_kind('hour') test_datetime_kind('hour')
test_datetime_kind('minute') test_datetime_kind('minute')
test_datetime_kind('second') test_datetime_kind('second')
test_datetime_to_time_kind('hour')
test_datetime_to_time_kind('minute')
test_datetime_to_time_kind('second')
qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField())) qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField()))
self.assertEqual(qs.count(), 2) self.assertEqual(qs.count(), 2)
@ -1205,6 +1220,60 @@ class DateFunctionWithTimeZoneTests(DateFunctionTests):
lambda m: (m.start_datetime, m.truncated) lambda m: (m.start_datetime, m.truncated)
) )
def test_datetime_to_date_kind(kind):
self.assertQuerysetEqual(
DTModel.objects.annotate(
truncated=Trunc(
'start_datetime',
kind,
output_field=DateField(),
tzinfo=melb,
),
).order_by('start_datetime'),
[
(
start_datetime,
truncate_to(start_datetime.astimezone(melb).date(), kind),
),
(
end_datetime,
truncate_to(end_datetime.astimezone(melb).date(), kind),
),
],
lambda m: (m.start_datetime, m.truncated),
)
def test_datetime_to_time_kind(kind):
self.assertQuerysetEqual(
DTModel.objects.annotate(
truncated=Trunc(
'start_datetime',
kind,
output_field=TimeField(),
tzinfo=melb,
)
).order_by('start_datetime'),
[
(
start_datetime,
truncate_to(start_datetime.astimezone(melb).time(), kind),
),
(
end_datetime,
truncate_to(end_datetime.astimezone(melb).time(), kind),
),
],
lambda m: (m.start_datetime, m.truncated),
)
test_datetime_to_date_kind('year')
test_datetime_to_date_kind('quarter')
test_datetime_to_date_kind('month')
test_datetime_to_date_kind('week')
test_datetime_to_date_kind('day')
test_datetime_to_time_kind('hour')
test_datetime_to_time_kind('minute')
test_datetime_to_time_kind('second')
test_datetime_kind('year') test_datetime_kind('year')
test_datetime_kind('quarter') test_datetime_kind('quarter')
test_datetime_kind('month') test_datetime_kind('month')
@ -1216,3 +1285,15 @@ class DateFunctionWithTimeZoneTests(DateFunctionTests):
qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField())) qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField()))
self.assertEqual(qs.count(), 2) self.assertEqual(qs.count(), 2)
def test_trunc_invalid_field_with_timezone(self):
melb = pytz.timezone('Australia/Melbourne')
msg = 'tzinfo can only be used with DateTimeField.'
with self.assertRaisesMessage(ValueError, msg):
DTModel.objects.annotate(
day_melb=Trunc('start_date', 'day', tzinfo=melb),
).get()
with self.assertRaisesMessage(ValueError, msg):
DTModel.objects.annotate(
hour_melb=Trunc('start_time', 'hour', tzinfo=melb),
).get()