Fixed #27802 -- Unified return value of db backend datetime SQL methods.

This commit is contained in:
Mariusz Felisiak 2017-02-01 14:48:04 +01:00 committed by Tim Graham
parent 8b62e5df86
commit 15c14f6f16
6 changed files with 48 additions and 57 deletions

View File

@ -44,29 +44,23 @@ class DatabaseOperations(BaseDatabaseOperations):
def _convert_field_to_tz(self, field_name, tzname): def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ: if settings.USE_TZ:
field_name = "CONVERT_TZ(%s, 'UTC', %%s)" % field_name field_name = "CONVERT_TZ(%s, 'UTC', '%s')" % (field_name, tzname)
params = [tzname] return field_name
else:
params = []
return field_name, params
def datetime_cast_date_sql(self, field_name, tzname): def datetime_cast_date_sql(self, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
sql = "DATE(%s)" % field_name return "DATE(%s)" % field_name
return sql, params
def datetime_cast_time_sql(self, field_name, tzname): def datetime_cast_time_sql(self, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
sql = "TIME(%s)" % field_name return "TIME(%s)" % field_name
return sql, params
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
sql = self.date_extract_sql(lookup_type, field_name) return self.date_extract_sql(lookup_type, field_name)
return sql, params
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
fields = ['year', 'month', 'day', 'hour', 'minute', 'second'] fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape. format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape.
format_def = ('0000-', '01', '-01', ' 00:', '00', ':00') format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')
@ -77,7 +71,7 @@ class DatabaseOperations(BaseDatabaseOperations):
else: else:
format_str = ''.join([f for f in format[:i]] + [f for f in format_def[i:]]) format_str = ''.join([f for f in format[:i]] + [f for f in format_def[i:]])
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str) sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql, params return sql
def time_trunc_sql(self, lookup_type, field_name): def time_trunc_sql(self, lookup_type, field_name):
fields = { fields = {

View File

@ -122,19 +122,16 @@ WHEN (new.%(col_name)s IS NULL)
def datetime_cast_date_sql(self, field_name, tzname): def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
sql = 'TRUNC(%s)' % field_name return 'TRUNC(%s)' % field_name
return sql, []
def datetime_cast_time_sql(self, field_name, tzname): def datetime_cast_time_sql(self, field_name, tzname):
# Since `TimeField` values are stored as TIMESTAMP where only the date # Since `TimeField` values are stored as TIMESTAMP where only the date
# part is ignored, convert the field to the specified timezone. # part is ignored, convert the field to the specified timezone.
field_name = self._convert_field_to_tz(field_name, tzname) return self._convert_field_to_tz(field_name, tzname)
return field_name, []
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
sql = self.date_extract_sql(lookup_type, field_name) return self.date_extract_sql(lookup_type, field_name)
return sql, []
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
@ -149,7 +146,7 @@ WHEN (new.%(col_name)s IS NULL)
sql = "TRUNC(%s, 'MI')" % field_name sql = "TRUNC(%s, 'MI')" % field_name
else: else:
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision. sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql, [] return sql
def time_trunc_sql(self, lookup_type, field_name): def time_trunc_sql(self, lookup_type, field_name):
# The implementation is similar to `datetime_trunc_sql` as both # The implementation is similar to `datetime_trunc_sql` as both

View File

@ -32,32 +32,25 @@ class DatabaseOperations(BaseDatabaseOperations):
def _convert_field_to_tz(self, field_name, tzname): def _convert_field_to_tz(self, field_name, tzname):
if settings.USE_TZ: if settings.USE_TZ:
field_name = "%s AT TIME ZONE %%s" % field_name field_name = "%s AT TIME ZONE '%s'" % (field_name, tzname)
params = [tzname] return field_name
else:
params = []
return field_name, params
def datetime_cast_date_sql(self, field_name, tzname): def datetime_cast_date_sql(self, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
sql = '(%s)::date' % field_name return '(%s)::date' % field_name
return sql, params
def datetime_cast_time_sql(self, field_name, tzname): def datetime_cast_time_sql(self, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
sql = '(%s)::time' % field_name return '(%s)::time' % field_name
return sql, params
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
sql = self.date_extract_sql(lookup_type, field_name) return self.date_extract_sql(lookup_type, field_name)
return sql, params
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name, params = self._convert_field_to_tz(field_name, tzname) field_name = self._convert_field_to_tz(field_name, tzname)
# https://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC # https://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
sql = "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
return sql, params
def time_trunc_sql(self, lookup_type, field_name): def time_trunc_sql(self, lookup_type, field_name):
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name) return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)

View File

@ -70,21 +70,30 @@ class DatabaseOperations(BaseDatabaseOperations):
# cause a collision with a field name). # cause a collision with a field name).
return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name) return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name)
def _convert_tzname_to_sql(self, tzname):
return "'%s'" % tzname if settings.USE_TZ else 'NULL'
def datetime_cast_date_sql(self, field_name, tzname): def datetime_cast_date_sql(self, field_name, tzname):
return "django_datetime_cast_date(%s, %%s)" % field_name, [tzname] return "django_datetime_cast_date(%s, %s)" % (
field_name, self._convert_tzname_to_sql(tzname),
)
def datetime_cast_time_sql(self, field_name, tzname): def datetime_cast_time_sql(self, field_name, tzname):
return "django_datetime_cast_time(%s, %%s)" % field_name, [tzname] return "django_datetime_cast_time(%s, %s)" % (
field_name, self._convert_tzname_to_sql(tzname),
)
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, field_name, tzname):
# Same comment as in date_extract_sql. # Same comment as in date_extract_sql.
return "django_datetime_extract('%s', %s, %%s)" % ( return "django_datetime_extract('%s', %s, %s)" % (
lookup_type.lower(), field_name), [tzname] lookup_type.lower(), field_name, self._convert_tzname_to_sql(tzname),
)
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, field_name, tzname):
# Same comment as in date_trunc_sql. # Same comment as in date_trunc_sql.
return "django_datetime_trunc('%s', %s, %%s)" % ( return "django_datetime_trunc('%s', %s, %s)" % (
lookup_type.lower(), field_name), [tzname] lookup_type.lower(), field_name, self._convert_tzname_to_sql(tzname),
)
def time_extract_sql(self, lookup_type, field_name): def time_extract_sql(self, lookup_type, field_name):
# sqlite doesn't support extract, so we fake it with the user-defined # sqlite doesn't support extract, so we fake it with the user-defined

View File

@ -44,8 +44,7 @@ class Extract(TimezoneMixin, Transform):
lhs_output_field = self.lhs.output_field lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField): if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname() tzname = self.get_tzname()
sql, tz_params = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname) sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
params.extend(tz_params)
elif isinstance(lhs_output_field, DateField): elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql) sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField): elif isinstance(lhs_output_field, TimeField):
@ -150,16 +149,14 @@ class TruncBase(TimezoneMixin, Transform):
inner_sql = inner_sql.replace('%s', '%%s') inner_sql = inner_sql.replace('%s', '%%s')
if isinstance(self.output_field, DateTimeField): if isinstance(self.output_field, DateTimeField):
tzname = self.get_tzname() tzname = self.get_tzname()
sql, params = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname) sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, DateField): elif isinstance(self.output_field, DateField):
sql = connection.ops.date_trunc_sql(self.kind, inner_sql) sql = connection.ops.date_trunc_sql(self.kind, inner_sql)
params = []
elif isinstance(self.output_field, TimeField): elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql) sql = connection.ops.time_trunc_sql(self.kind, inner_sql)
params = []
else: else:
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.') raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
return sql, inner_params + params return sql, inner_params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
@ -237,8 +234,7 @@ class TruncDate(TruncBase):
# Cast to date rather than truncate to date. # Cast to date rather than truncate to date.
lhs, lhs_params = compiler.compile(self.lhs) lhs, lhs_params = compiler.compile(self.lhs)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_cast_date_sql(lhs, tzname) sql = connection.ops.datetime_cast_date_sql(lhs, tzname)
lhs_params.extend(tz_params)
return sql, lhs_params return sql, lhs_params
@ -254,8 +250,7 @@ class TruncTime(TruncBase):
# Cast to date rather than truncate to date. # Cast to date rather than truncate to date.
lhs, lhs_params = compiler.compile(self.lhs) lhs, lhs_params = compiler.compile(self.lhs)
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_cast_time_sql(lhs, tzname) sql = connection.ops.datetime_cast_time_sql(lhs, tzname)
lhs_params.extend(tz_params)
return sql, lhs_params return sql, lhs_params

View File

@ -220,7 +220,10 @@ bytestrings in certain code paths.
Database backend API Database backend API
-------------------- --------------------
* ... * The ``DatabaseOperations.datetime_cast_date_sql()``,
``datetime_cast_time_sql()``, ``datetime_trunc_sql()``, and
``datetime_extract_sql()`` methods now return only the SQL to perform the
operation instead of SQL and a list of parameters.
Dropped support for Oracle 11.2 Dropped support for Oracle 11.2
------------------------------- -------------------------------