Fixed #31640 -- Made Trunc() truncate datetimes to Date/TimeField in a specific timezone.
This commit is contained in:
parent
8d018231ac
commit
ee005328c8
|
@ -99,11 +99,14 @@ class BaseDatabaseOperations:
|
|||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method')
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name):
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
|
||||
truncates the given date field field_name to a date object with only
|
||||
the given specificity.
|
||||
truncates the given date or datetime field field_name to a date object
|
||||
with only the given specificity.
|
||||
|
||||
If `tzname` is provided, the given value is truncated in a specific
|
||||
timezone.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.')
|
||||
|
||||
|
@ -138,11 +141,14 @@ class BaseDatabaseOperations:
|
|||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method')
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name):
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
"""
|
||||
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
|
||||
that truncates the given time field field_name to a time object with
|
||||
only the given specificity.
|
||||
that truncates the given time or datetime field field_name to a time
|
||||
object with only the given specificity.
|
||||
|
||||
If `tzname` is provided, the given value is truncated in a specific
|
||||
timezone.
|
||||
"""
|
||||
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method')
|
||||
|
||||
|
|
|
@ -55,7 +55,8 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
|
||||
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name):
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
fields = {
|
||||
'year': '%%Y-01-01',
|
||||
'month': '%%Y-%%m-01',
|
||||
|
@ -82,7 +83,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
return tzname
|
||||
|
||||
def _convert_field_to_tz(self, field_name, tzname):
|
||||
if settings.USE_TZ and self.connection.timezone_name != tzname:
|
||||
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
|
||||
field_name = "CONVERT_TZ(%s, '%s', '%s')" % (
|
||||
field_name,
|
||||
self.connection.timezone_name,
|
||||
|
@ -128,7 +129,8 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
|
||||
return sql
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name):
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
fields = {
|
||||
'hour': '%%H:00:00',
|
||||
'minute': '%%H:%%i:00',
|
||||
|
|
|
@ -89,7 +89,8 @@ END;
|
|||
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html
|
||||
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name):
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
|
||||
if lookup_type in ('year', 'month'):
|
||||
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
|
||||
|
@ -114,7 +115,7 @@ END;
|
|||
return tzname
|
||||
|
||||
def _convert_field_to_tz(self, field_name, tzname):
|
||||
if not settings.USE_TZ:
|
||||
if not (settings.USE_TZ and tzname):
|
||||
return field_name
|
||||
if not self._tzname_re.match(tzname):
|
||||
raise ValueError("Invalid time zone name: %s" % tzname)
|
||||
|
@ -161,10 +162,11 @@ END;
|
|||
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
|
||||
return sql
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name):
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
# The implementation is similar to `datetime_trunc_sql` as both
|
||||
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
|
||||
# the date part of the later is ignored.
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
if lookup_type == 'hour':
|
||||
sql = "TRUNC(%s, 'HH24')" % field_name
|
||||
elif lookup_type == 'minute':
|
||||
|
|
|
@ -38,7 +38,8 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
else:
|
||||
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name):
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
|
||||
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
|
||||
|
||||
|
@ -50,7 +51,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
return tzname
|
||||
|
||||
def _convert_field_to_tz(self, field_name, tzname):
|
||||
if settings.USE_TZ:
|
||||
if tzname and settings.USE_TZ:
|
||||
field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
|
||||
return field_name
|
||||
|
||||
|
@ -71,7 +72,8 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
|
||||
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name):
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)
|
||||
|
||||
def deferrable_sql(self):
|
||||
|
|
|
@ -213,13 +213,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|||
else:
|
||||
create_deterministic_function = conn.create_function
|
||||
create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract)
|
||||
create_deterministic_function('django_date_trunc', 2, _sqlite_date_trunc)
|
||||
create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc)
|
||||
create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
|
||||
create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
|
||||
create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract)
|
||||
create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
|
||||
create_deterministic_function('django_time_extract', 2, _sqlite_time_extract)
|
||||
create_deterministic_function('django_time_trunc', 2, _sqlite_time_trunc)
|
||||
create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc)
|
||||
create_deterministic_function('django_time_diff', 2, _sqlite_time_diff)
|
||||
create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff)
|
||||
create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta)
|
||||
|
@ -445,8 +445,8 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
|
|||
return dt
|
||||
|
||||
|
||||
def _sqlite_date_trunc(lookup_type, dt):
|
||||
dt = _sqlite_datetime_parse(dt)
|
||||
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == 'year':
|
||||
|
@ -463,13 +463,17 @@ def _sqlite_date_trunc(lookup_type, dt):
|
|||
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
|
||||
|
||||
|
||||
def _sqlite_time_trunc(lookup_type, dt):
|
||||
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
if dt is None:
|
||||
return None
|
||||
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt_parsed is None:
|
||||
try:
|
||||
dt = backend_utils.typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
else:
|
||||
dt = dt_parsed
|
||||
if lookup_type == 'hour':
|
||||
return "%02i:00:00" % dt.hour
|
||||
elif lookup_type == 'minute':
|
||||
|
|
|
@ -77,14 +77,22 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
"""Do nothing since formatting is handled in the custom function."""
|
||||
return sql
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name):
|
||||
return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
return "django_date_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name):
|
||||
return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
return "django_time_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def _convert_tznames_to_sql(self, tzname):
|
||||
if settings.USE_TZ:
|
||||
if tzname and settings.USE_TZ:
|
||||
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
|
||||
return 'NULL', 'NULL'
|
||||
|
||||
|
|
|
@ -193,13 +193,17 @@ class TruncBase(TimezoneMixin, Transform):
|
|||
|
||||
def as_sql(self, compiler, connection):
|
||||
inner_sql, inner_params = compiler.compile(self.lhs)
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
tzname = None
|
||||
if isinstance(self.lhs.output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError('tzinfo can only be used with DateTimeField.')
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
|
||||
elif isinstance(self.output_field, DateField):
|
||||
sql = connection.ops.date_trunc_sql(self.kind, inner_sql)
|
||||
sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
sql = connection.ops.time_trunc_sql(self.kind, inner_sql)
|
||||
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
|
||||
else:
|
||||
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
|
||||
return sql, inner_params
|
||||
|
|
|
@ -458,6 +458,10 @@ backends.
|
|||
* ``DatabaseOperations.random_function_sql()`` is removed in favor of the new
|
||||
:class:`~django.db.models.functions.Random` database function.
|
||||
|
||||
* ``DatabaseOperations.date_trunc_sql()`` and
|
||||
``DatabaseOperations.time_trunc_sql()`` now take the optional ``tzname``
|
||||
argument in order to truncate in a specific timezone.
|
||||
|
||||
:mod:`django.contrib.admin`
|
||||
---------------------------
|
||||
|
||||
|
|
|
@ -672,6 +672,18 @@ class DateFunctionTests(TestCase):
|
|||
lambda m: (m.start_datetime, m.truncated)
|
||||
)
|
||||
|
||||
def test_datetime_to_time_kind(kind):
|
||||
self.assertQuerysetEqual(
|
||||
DTModel.objects.annotate(
|
||||
truncated=Trunc('start_datetime', kind, output_field=TimeField()),
|
||||
).order_by('start_datetime'),
|
||||
[
|
||||
(start_datetime, truncate_to(start_datetime.time(), kind)),
|
||||
(end_datetime, truncate_to(end_datetime.time(), kind)),
|
||||
],
|
||||
lambda m: (m.start_datetime, m.truncated),
|
||||
)
|
||||
|
||||
test_date_kind('year')
|
||||
test_date_kind('quarter')
|
||||
test_date_kind('month')
|
||||
|
@ -688,6 +700,9 @@ class DateFunctionTests(TestCase):
|
|||
test_datetime_kind('hour')
|
||||
test_datetime_kind('minute')
|
||||
test_datetime_kind('second')
|
||||
test_datetime_to_time_kind('hour')
|
||||
test_datetime_to_time_kind('minute')
|
||||
test_datetime_to_time_kind('second')
|
||||
|
||||
qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField()))
|
||||
self.assertEqual(qs.count(), 2)
|
||||
|
@ -1205,6 +1220,60 @@ class DateFunctionWithTimeZoneTests(DateFunctionTests):
|
|||
lambda m: (m.start_datetime, m.truncated)
|
||||
)
|
||||
|
||||
def test_datetime_to_date_kind(kind):
|
||||
self.assertQuerysetEqual(
|
||||
DTModel.objects.annotate(
|
||||
truncated=Trunc(
|
||||
'start_datetime',
|
||||
kind,
|
||||
output_field=DateField(),
|
||||
tzinfo=melb,
|
||||
),
|
||||
).order_by('start_datetime'),
|
||||
[
|
||||
(
|
||||
start_datetime,
|
||||
truncate_to(start_datetime.astimezone(melb).date(), kind),
|
||||
),
|
||||
(
|
||||
end_datetime,
|
||||
truncate_to(end_datetime.astimezone(melb).date(), kind),
|
||||
),
|
||||
],
|
||||
lambda m: (m.start_datetime, m.truncated),
|
||||
)
|
||||
|
||||
def test_datetime_to_time_kind(kind):
|
||||
self.assertQuerysetEqual(
|
||||
DTModel.objects.annotate(
|
||||
truncated=Trunc(
|
||||
'start_datetime',
|
||||
kind,
|
||||
output_field=TimeField(),
|
||||
tzinfo=melb,
|
||||
)
|
||||
).order_by('start_datetime'),
|
||||
[
|
||||
(
|
||||
start_datetime,
|
||||
truncate_to(start_datetime.astimezone(melb).time(), kind),
|
||||
),
|
||||
(
|
||||
end_datetime,
|
||||
truncate_to(end_datetime.astimezone(melb).time(), kind),
|
||||
),
|
||||
],
|
||||
lambda m: (m.start_datetime, m.truncated),
|
||||
)
|
||||
|
||||
test_datetime_to_date_kind('year')
|
||||
test_datetime_to_date_kind('quarter')
|
||||
test_datetime_to_date_kind('month')
|
||||
test_datetime_to_date_kind('week')
|
||||
test_datetime_to_date_kind('day')
|
||||
test_datetime_to_time_kind('hour')
|
||||
test_datetime_to_time_kind('minute')
|
||||
test_datetime_to_time_kind('second')
|
||||
test_datetime_kind('year')
|
||||
test_datetime_kind('quarter')
|
||||
test_datetime_kind('month')
|
||||
|
@ -1216,3 +1285,15 @@ class DateFunctionWithTimeZoneTests(DateFunctionTests):
|
|||
|
||||
qs = DTModel.objects.filter(start_datetime__date=Trunc('start_datetime', 'day', output_field=DateField()))
|
||||
self.assertEqual(qs.count(), 2)
|
||||
|
||||
def test_trunc_invalid_field_with_timezone(self):
|
||||
melb = pytz.timezone('Australia/Melbourne')
|
||||
msg = 'tzinfo can only be used with DateTimeField.'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
DTModel.objects.annotate(
|
||||
day_melb=Trunc('start_date', 'day', tzinfo=melb),
|
||||
).get()
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
DTModel.objects.annotate(
|
||||
hour_melb=Trunc('start_time', 'hour', tzinfo=melb),
|
||||
).get()
|
||||
|
|
Loading…
Reference in New Issue