Refs CVE-2022-34265 -- Properly escaped Extract() and Trunc() parameters.
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
parent
73766c1187
commit
877c800f25
|
@ -9,7 +9,6 @@ from django.db import NotSupportedError, transaction
|
|||
from django.db.backends import utils
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
|
||||
class BaseDatabaseOperations:
|
||||
|
@ -55,8 +54,6 @@ class BaseDatabaseOperations:
|
|||
# Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
|
||||
explain_prefix = None
|
||||
|
||||
extract_trunc_lookup_pattern = _lazy_re_compile(r"[\w\-_()]+")
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self._cache = None
|
||||
|
@ -103,7 +100,7 @@ class BaseDatabaseOperations:
|
|||
"""
|
||||
return "%s"
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
|
||||
extracts a value from the given date field field_name.
|
||||
|
@ -113,7 +110,7 @@ class BaseDatabaseOperations:
|
|||
"method"
|
||||
)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
|
||||
truncates the given date or datetime field field_name to a date object
|
||||
|
@ -127,7 +124,7 @@ class BaseDatabaseOperations:
|
|||
"method."
|
||||
)
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
"""
|
||||
Return the SQL to cast a datetime value to date value.
|
||||
"""
|
||||
|
@ -136,7 +133,7 @@ class BaseDatabaseOperations:
|
|||
"datetime_cast_date_sql() method."
|
||||
)
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
"""
|
||||
Return the SQL to cast a datetime value to time value.
|
||||
"""
|
||||
|
@ -145,7 +142,7 @@ class BaseDatabaseOperations:
|
|||
"datetime_cast_time_sql() method"
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
|
||||
'second', return the SQL that extracts a value from the given
|
||||
|
@ -156,7 +153,7 @@ class BaseDatabaseOperations:
|
|||
"method"
|
||||
)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
|
||||
'second', return the SQL that truncates the given datetime field
|
||||
|
@ -167,7 +164,7 @@ class BaseDatabaseOperations:
|
|||
"method"
|
||||
)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
"""
|
||||
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
|
||||
that truncates the given time or datetime field field_name to a time
|
||||
|
@ -180,12 +177,12 @@ class BaseDatabaseOperations:
|
|||
"subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
|
||||
)
|
||||
|
||||
def time_extract_sql(self, lookup_type, field_name):
|
||||
def time_extract_sql(self, lookup_type, sql, params):
|
||||
"""
|
||||
Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
|
||||
that extracts a value from the given time field field_name.
|
||||
"""
|
||||
return self.date_extract_sql(lookup_type, field_name)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def deferrable_sql(self):
|
||||
"""
|
||||
|
|
|
@ -7,6 +7,7 @@ from django.db.models import Exists, ExpressionWrapper, Lookup
|
|||
from django.db.models.constants import OnConflict
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
|
@ -37,117 +38,115 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
cast_char_field_without_max_length = "char"
|
||||
explain_prefix = "EXPLAIN"
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
# EXTRACT format cannot be passed in parameters.
|
||||
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
|
||||
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
|
||||
if lookup_type == "week_day":
|
||||
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
|
||||
return "DAYOFWEEK(%s)" % field_name
|
||||
return f"DAYOFWEEK({sql})", params
|
||||
elif lookup_type == "iso_week_day":
|
||||
# WEEKDAY() returns an integer, 0-6, Monday=0.
|
||||
return "WEEKDAY(%s) + 1" % field_name
|
||||
return f"WEEKDAY({sql}) + 1", params
|
||||
elif lookup_type == "week":
|
||||
# Override the value of default_week_format for consistency with
|
||||
# other database backends.
|
||||
# Mode 3: Monday, 1-53, with 4 or more days this year.
|
||||
return "WEEK(%s, 3)" % field_name
|
||||
return f"WEEK({sql}, 3)", params
|
||||
elif lookup_type == "iso_year":
|
||||
# Get the year part from the YEARWEEK function, which returns a
|
||||
# number as year * 100 + week.
|
||||
return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name
|
||||
return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
|
||||
else:
|
||||
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
|
||||
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
|
||||
lookup_type = lookup_type.upper()
|
||||
if not self._extract_format_re.fullmatch(lookup_type):
|
||||
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
|
||||
return f"EXTRACT({lookup_type} FROM {sql})", params
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
fields = {
|
||||
"year": "%%Y-01-01",
|
||||
"month": "%%Y-%%m-01",
|
||||
} # Use double percents to escape.
|
||||
"year": "%Y-01-01",
|
||||
"month": "%Y-%m-01",
|
||||
}
|
||||
if lookup_type in fields:
|
||||
format_str = fields[lookup_type]
|
||||
return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str)
|
||||
return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
|
||||
elif lookup_type == "quarter":
|
||||
return (
|
||||
"MAKEDATE(YEAR(%s), 1) + "
|
||||
"INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER"
|
||||
% (field_name, field_name)
|
||||
f"MAKEDATE(YEAR({sql}), 1) + "
|
||||
f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
|
||||
(*params, *params),
|
||||
)
|
||||
elif lookup_type == "week":
|
||||
return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (field_name, field_name)
|
||||
return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
|
||||
else:
|
||||
return "DATE(%s)" % (field_name)
|
||||
return f"DATE({sql})", params
|
||||
|
||||
def _prepare_tzname_delta(self, tzname):
|
||||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
return f"{sign}{offset}" if offset else tzname
|
||||
|
||||
def _convert_field_to_tz(self, field_name, tzname):
|
||||
def _convert_field_to_tz(self, sql, params, tzname):
|
||||
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
|
||||
field_name = "CONVERT_TZ(%s, '%s', '%s')" % (
|
||||
field_name,
|
||||
return f"CONVERT_TZ({sql}, %s, %s)", (
|
||||
*params,
|
||||
self.connection.timezone_name,
|
||||
self._prepare_tzname_delta(tzname),
|
||||
)
|
||||
return field_name
|
||||
return sql, params
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "DATE(%s)" % field_name
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
return f"DATE({sql})", params
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "TIME(%s)" % field_name
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
return f"TIME({sql})", params
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return self.date_extract_sql(lookup_type, field_name)
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
fields = ["year", "month", "day", "hour", "minute", "second"]
|
||||
format = (
|
||||
"%%Y-",
|
||||
"%%m",
|
||||
"-%%d",
|
||||
" %%H:",
|
||||
"%%i",
|
||||
":%%s",
|
||||
) # Use double percents to escape.
|
||||
format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
|
||||
format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
|
||||
if lookup_type == "quarter":
|
||||
return (
|
||||
"CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + "
|
||||
"INTERVAL QUARTER({field_name}) QUARTER - "
|
||||
+ "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
|
||||
).format(field_name=field_name)
|
||||
f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
|
||||
f"INTERVAL QUARTER({sql}) QUARTER - "
|
||||
f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
|
||||
), (*params, *params, "%Y-%m-01 00:00:00")
|
||||
if lookup_type == "week":
|
||||
return (
|
||||
"CAST(DATE_FORMAT(DATE_SUB({field_name}, "
|
||||
"INTERVAL WEEKDAY({field_name}) DAY), "
|
||||
"'%%Y-%%m-%%d 00:00:00') AS DATETIME)"
|
||||
).format(field_name=field_name)
|
||||
f"CAST(DATE_FORMAT("
|
||||
f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
|
||||
), (*params, *params, "%Y-%m-%d 00:00:00")
|
||||
try:
|
||||
i = fields.index(lookup_type) + 1
|
||||
except ValueError:
|
||||
sql = field_name
|
||||
pass
|
||||
else:
|
||||
format_str = "".join(format[:i] + format_def[i:])
|
||||
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
|
||||
return sql
|
||||
return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
|
||||
return sql, params
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
fields = {
|
||||
"hour": "%%H:00:00",
|
||||
"minute": "%%H:%%i:00",
|
||||
"second": "%%H:%%i:%%s",
|
||||
} # Use double percents to escape.
|
||||
"hour": "%H:00:00",
|
||||
"minute": "%H:%i:00",
|
||||
"second": "%H:%i:%s",
|
||||
}
|
||||
if lookup_type in fields:
|
||||
format_str = fields[lookup_type]
|
||||
return "CAST(DATE_FORMAT(%s, '%s') AS TIME)" % (field_name, format_str)
|
||||
return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
|
||||
else:
|
||||
return "TIME(%s)" % (field_name)
|
||||
return f"TIME({sql})", params
|
||||
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
|
|
|
@ -77,34 +77,46 @@ END;
|
|||
f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY"
|
||||
)
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
# EXTRACT format cannot be passed in parameters.
|
||||
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
|
||||
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
extract_sql = f"TO_CHAR({sql}, %s)"
|
||||
extract_param = None
|
||||
if lookup_type == "week_day":
|
||||
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
|
||||
return "TO_CHAR(%s, 'D')" % field_name
|
||||
extract_param = "D"
|
||||
elif lookup_type == "iso_week_day":
|
||||
return "TO_CHAR(%s - 1, 'D')" % field_name
|
||||
extract_sql = f"TO_CHAR({sql} - 1, %s)"
|
||||
extract_param = "D"
|
||||
elif lookup_type == "week":
|
||||
# IW = ISO week number
|
||||
return "TO_CHAR(%s, 'IW')" % field_name
|
||||
extract_param = "IW"
|
||||
elif lookup_type == "quarter":
|
||||
return "TO_CHAR(%s, 'Q')" % field_name
|
||||
extract_param = "Q"
|
||||
elif lookup_type == "iso_year":
|
||||
return "TO_CHAR(%s, 'IYYY')" % field_name
|
||||
extract_param = "IYYY"
|
||||
else:
|
||||
lookup_type = lookup_type.upper()
|
||||
if not self._extract_format_re.fullmatch(lookup_type):
|
||||
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html
|
||||
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
|
||||
return f"EXTRACT({lookup_type} FROM {sql})", params
|
||||
return extract_sql, (*params, extract_param)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
|
||||
trunc_param = None
|
||||
if lookup_type in ("year", "month"):
|
||||
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
|
||||
trunc_param = lookup_type.upper()
|
||||
elif lookup_type == "quarter":
|
||||
return "TRUNC(%s, 'Q')" % field_name
|
||||
trunc_param = "Q"
|
||||
elif lookup_type == "week":
|
||||
return "TRUNC(%s, 'IW')" % field_name
|
||||
trunc_param = "IW"
|
||||
else:
|
||||
return "TRUNC(%s)" % field_name
|
||||
return f"TRUNC({sql})", params
|
||||
return f"TRUNC({sql}, %s)", (*params, trunc_param)
|
||||
|
||||
# Oracle crashes with "ORA-03113: end-of-file on communication channel"
|
||||
# if the time zone name is passed in parameter. Use interpolation instead.
|
||||
|
@ -116,77 +128,80 @@ END;
|
|||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
return f"{sign}{offset}" if offset else tzname
|
||||
|
||||
def _convert_field_to_tz(self, field_name, tzname):
|
||||
def _convert_field_to_tz(self, sql, params, tzname):
|
||||
if not (settings.USE_TZ and tzname):
|
||||
return field_name
|
||||
return sql, params
|
||||
if not self._tzname_re.match(tzname):
|
||||
raise ValueError("Invalid time zone name: %s" % tzname)
|
||||
# Convert from connection timezone to the local time, returning
|
||||
# TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the
|
||||
# TIME ZONE details.
|
||||
if self.connection.timezone_name != tzname:
|
||||
return "CAST((FROM_TZ(%s, '%s') AT TIME ZONE '%s') AS TIMESTAMP)" % (
|
||||
field_name,
|
||||
self.connection.timezone_name,
|
||||
self._prepare_tzname_delta(tzname),
|
||||
from_timezone_name = self.connection.timezone_name
|
||||
to_timezone_name = self._prepare_tzname_delta(tzname)
|
||||
return (
|
||||
f"CAST((FROM_TZ({sql}, '{from_timezone_name}') AT TIME ZONE "
|
||||
f"'{to_timezone_name}') AS TIMESTAMP)",
|
||||
params,
|
||||
)
|
||||
return field_name
|
||||
return sql, params
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "TRUNC(%s)" % field_name
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
return f"TRUNC({sql})", params
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
# Since `TimeField` values are stored as TIMESTAMP change to the
|
||||
# default date and convert the field to the specified timezone.
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
convert_datetime_sql = (
|
||||
"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR(%s, 'HH24:MI:SS.FF')), "
|
||||
"'YYYY-MM-DD HH24:MI:SS.FF')"
|
||||
) % self._convert_field_to_tz(field_name, tzname)
|
||||
return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % (
|
||||
field_name,
|
||||
convert_datetime_sql,
|
||||
f"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR({sql}, 'HH24:MI:SS.FF')), "
|
||||
f"'YYYY-MM-DD HH24:MI:SS.FF')"
|
||||
)
|
||||
return (
|
||||
f"CASE WHEN {sql} IS NOT NULL THEN {convert_datetime_sql} ELSE NULL END",
|
||||
(*params, *params),
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return self.date_extract_sql(lookup_type, field_name)
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
|
||||
trunc_param = None
|
||||
if lookup_type in ("year", "month"):
|
||||
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
|
||||
trunc_param = lookup_type.upper()
|
||||
elif lookup_type == "quarter":
|
||||
sql = "TRUNC(%s, 'Q')" % field_name
|
||||
trunc_param = "Q"
|
||||
elif lookup_type == "week":
|
||||
sql = "TRUNC(%s, 'IW')" % field_name
|
||||
elif lookup_type == "day":
|
||||
sql = "TRUNC(%s)" % field_name
|
||||
trunc_param = "IW"
|
||||
elif lookup_type == "hour":
|
||||
sql = "TRUNC(%s, 'HH24')" % field_name
|
||||
trunc_param = "HH24"
|
||||
elif lookup_type == "minute":
|
||||
sql = "TRUNC(%s, 'MI')" % field_name
|
||||
trunc_param = "MI"
|
||||
elif lookup_type == "day":
|
||||
return f"TRUNC({sql})", params
|
||||
else:
|
||||
sql = (
|
||||
"CAST(%s AS DATE)" % field_name
|
||||
) # Cast to DATE removes sub-second precision.
|
||||
return sql
|
||||
# Cast to DATE removes sub-second precision.
|
||||
return f"CAST({sql} AS DATE)", params
|
||||
return f"TRUNC({sql}, %s)", (*params, trunc_param)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
# The implementation is similar to `datetime_trunc_sql` as both
|
||||
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
|
||||
# the date part of the later is ignored.
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
sql, params = self._convert_field_to_tz(sql, params, tzname)
|
||||
trunc_param = None
|
||||
if lookup_type == "hour":
|
||||
sql = "TRUNC(%s, 'HH24')" % field_name
|
||||
trunc_param = "HH24"
|
||||
elif lookup_type == "minute":
|
||||
sql = "TRUNC(%s, 'MI')" % field_name
|
||||
trunc_param = "MI"
|
||||
elif lookup_type == "second":
|
||||
sql = (
|
||||
"CAST(%s AS DATE)" % field_name
|
||||
) # Cast to DATE removes sub-second precision.
|
||||
return sql
|
||||
# Cast to DATE removes sub-second precision.
|
||||
return f"CAST({sql} AS DATE)", params
|
||||
return f"TRUNC({sql}, %s)", (*params, trunc_param)
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
|
|
|
@ -47,22 +47,24 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
)
|
||||
return "%s"
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
|
||||
extract_sql = f"EXTRACT(%s FROM {sql})"
|
||||
extract_param = lookup_type
|
||||
if lookup_type == "week_day":
|
||||
# For consistency across backends, we return Sunday=1, Saturday=7.
|
||||
return "EXTRACT('dow' FROM %s) + 1" % field_name
|
||||
extract_sql = f"EXTRACT(%s FROM {sql}) + 1"
|
||||
extract_param = "dow"
|
||||
elif lookup_type == "iso_week_day":
|
||||
return "EXTRACT('isodow' FROM %s)" % field_name
|
||||
extract_param = "isodow"
|
||||
elif lookup_type == "iso_year":
|
||||
return "EXTRACT('isoyear' FROM %s)" % field_name
|
||||
else:
|
||||
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
|
||||
extract_param = "isoyear"
|
||||
return extract_sql, (extract_param, *params)
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
|
||||
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
|
||||
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
|
||||
|
||||
def _prepare_tzname_delta(self, tzname):
|
||||
tzname, sign, offset = split_tzname_delta(tzname)
|
||||
|
@ -71,43 +73,47 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
return f"{tzname}{sign}{offset}"
|
||||
return tzname
|
||||
|
||||
def _convert_field_to_tz(self, field_name, tzname):
|
||||
def _convert_sql_to_tz(self, sql, params, tzname):
|
||||
if tzname and settings.USE_TZ:
|
||||
field_name = "%s AT TIME ZONE '%s'" % (
|
||||
field_name,
|
||||
self._prepare_tzname_delta(tzname),
|
||||
tzname_param = self._prepare_tzname_delta(tzname)
|
||||
return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
|
||||
return sql, params
|
||||
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"({sql})::date", params
|
||||
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"({sql})::time", params
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return (
|
||||
f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
|
||||
("second", "second", *params),
|
||||
)
|
||||
return field_name
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "(%s)::date" % field_name
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "(%s)::time" % field_name
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return f"EXTRACT('second' FROM DATE_TRUNC('second', {field_name}))"
|
||||
return self.date_extract_sql(lookup_type, field_name)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
|
||||
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
|
||||
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
|
||||
|
||||
def time_extract_sql(self, lookup_type, field_name):
|
||||
def time_extract_sql(self, lookup_type, sql, params):
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return f"EXTRACT('second' FROM DATE_TRUNC('second', {field_name}))"
|
||||
return self.date_extract_sql(lookup_type, field_name)
|
||||
return (
|
||||
f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
|
||||
("second", "second", *params),
|
||||
)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
field_name = self._convert_field_to_tz(field_name, tzname)
|
||||
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
|
||||
|
||||
def deferrable_sql(self):
|
||||
return " DEFERRABLE INITIALLY DEFERRED"
|
||||
|
|
|
@ -69,13 +69,13 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
"accepting multiple arguments."
|
||||
)
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
def date_extract_sql(self, lookup_type, sql, params):
|
||||
"""
|
||||
Support EXTRACT with a user-defined function django_date_extract()
|
||||
that's registered in connect(). Use single quotes because this is a
|
||||
string and could otherwise cause a collision with a field name.
|
||||
"""
|
||||
return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params)
|
||||
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
|
@ -88,53 +88,53 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
"""Do nothing since formatting is handled in the custom function."""
|
||||
return sql
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
return "django_date_trunc('%s', %s, %s, %s)" % (
|
||||
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
return f"django_date_trunc(%s, {sql}, %s, %s)", (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
return "django_time_trunc('%s', %s, %s, %s)" % (
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
return f"django_time_trunc(%s, {sql}, %s, %s)", (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def _convert_tznames_to_sql(self, tzname):
|
||||
if tzname and settings.USE_TZ:
|
||||
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
|
||||
return "NULL", "NULL"
|
||||
return tzname, self.connection.timezone_name
|
||||
return None, None
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
return "django_datetime_cast_date(%s, %s, %s)" % (
|
||||
field_name,
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
return f"django_datetime_cast_date({sql}, %s, %s)", (
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
return "django_datetime_cast_time(%s, %s, %s)" % (
|
||||
field_name,
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
return f"django_datetime_cast_time({sql}, %s, %s)", (
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
return "django_datetime_extract('%s', %s, %s, %s)" % (
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
return f"django_datetime_extract(%s, {sql}, %s, %s)", (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
return "django_datetime_trunc('%s', %s, %s, %s)" % (
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
return f"django_datetime_trunc(%s, {sql}, %s, %s)", (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*params,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_extract_sql(self, lookup_type, field_name):
|
||||
return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
def time_extract_sql(self, lookup_type, sql, params):
|
||||
return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params)
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
|
|
@ -51,25 +51,31 @@ class Extract(TimezoneMixin, Transform):
|
|||
super().__init__(expression, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.lookup_name):
|
||||
raise ValueError("Invalid lookup_name: %s" % self.lookup_name)
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
lhs_output_field = self.lhs.output_field
|
||||
if isinstance(lhs_output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
|
||||
sql, params = connection.ops.datetime_extract_sql(
|
||||
self.lookup_name, sql, tuple(params), tzname
|
||||
)
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError("tzinfo can only be used with DateTimeField.")
|
||||
elif isinstance(lhs_output_field, DateField):
|
||||
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
|
||||
sql, params = connection.ops.date_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
elif isinstance(lhs_output_field, TimeField):
|
||||
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
|
||||
sql, params = connection.ops.time_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
elif isinstance(lhs_output_field, DurationField):
|
||||
if not connection.features.has_native_duration_field:
|
||||
raise ValueError(
|
||||
"Extract requires native DurationField database support."
|
||||
)
|
||||
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
|
||||
sql, params = connection.ops.time_extract_sql(
|
||||
self.lookup_name, sql, tuple(params)
|
||||
)
|
||||
else:
|
||||
# resolve_expression has already validated the output_field so this
|
||||
# assert should never be hit.
|
||||
|
@ -237,25 +243,29 @@ class TruncBase(TimezoneMixin, Transform):
|
|||
super().__init__(expression, output_field=output_field, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.kind):
|
||||
raise ValueError("Invalid kind: %s" % self.kind)
|
||||
inner_sql, inner_params = compiler.compile(self.lhs)
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = None
|
||||
if isinstance(self.lhs.output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError("tzinfo can only be used with DateTimeField.")
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
|
||||
sql, params = connection.ops.datetime_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
elif isinstance(self.output_field, DateField):
|
||||
sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
|
||||
sql, params = connection.ops.date_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
|
||||
sql, params = connection.ops.time_trunc_sql(
|
||||
self.kind, sql, tuple(params), tzname
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Trunc only valid on DateField, TimeField, or DateTimeField."
|
||||
)
|
||||
return sql, inner_params
|
||||
return sql, params
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
|
@ -384,10 +394,9 @@ class TruncDate(TruncBase):
|
|||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to date rather than truncate to date.
|
||||
lhs, lhs_params = compiler.compile(self.lhs)
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = self.get_tzname()
|
||||
sql = connection.ops.datetime_cast_date_sql(lhs, tzname)
|
||||
return sql, lhs_params
|
||||
return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
|
||||
|
||||
|
||||
class TruncTime(TruncBase):
|
||||
|
@ -397,10 +406,9 @@ class TruncTime(TruncBase):
|
|||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to time rather than truncate to time.
|
||||
lhs, lhs_params = compiler.compile(self.lhs)
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
tzname = self.get_tzname()
|
||||
sql = connection.ops.datetime_cast_time_sql(lhs, tzname)
|
||||
return sql, lhs_params
|
||||
return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
|
||||
|
||||
|
||||
class TruncHour(TruncBase):
|
||||
|
|
|
@ -459,6 +459,20 @@ backends.
|
|||
``DatabaseOperations.insert_statement()`` method is replaced by
|
||||
``on_conflict`` that accepts ``django.db.models.constants.OnConflict``.
|
||||
|
||||
* Several date and time methods on ``DatabaseOperations`` now take ``sql`` and
|
||||
``params`` arguments instead of ``field_name`` and return 2-tuple containing
|
||||
some SQL and the parameters to be interpolated into that SQL. The changed
|
||||
methods have these new signatures:
|
||||
|
||||
* ``DatabaseOperations.date_extract_sql(lookup_type, sql, params)``
|
||||
* ``DatabaseOperations.datetime_extract_sql(lookup_type, sql, params, tzname)``
|
||||
* ``DatabaseOperations.time_extract_sql(lookup_type, sql, params)``
|
||||
* ``DatabaseOperations.date_trunc_sql(lookup_type, sql, params, tzname=None)``
|
||||
* ``DatabaseOperations.datetime_trunc_sql(self, lookup_type, sql, params, tzname)``
|
||||
* ``DatabaseOperations.time_trunc_sql(lookup_type, sql, params, tzname=None)``
|
||||
* ``DatabaseOperations.datetime_cast_date_sql(sql, params, tzname)``
|
||||
* ``DatabaseOperations.datetime_cast_time_sql(sql, params, tzname)``
|
||||
|
||||
:mod:`django.contrib.gis`
|
||||
-------------------------
|
||||
|
||||
|
|
|
@ -115,49 +115,49 @@ class SimpleDatabaseOperationTests(SimpleTestCase):
|
|||
with self.assertRaisesMessage(
|
||||
NotImplementedError, self.may_require_msg % "date_extract_sql"
|
||||
):
|
||||
self.ops.date_extract_sql(None, None)
|
||||
self.ops.date_extract_sql(None, None, None)
|
||||
|
||||
def test_time_extract_sql(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotImplementedError, self.may_require_msg % "date_extract_sql"
|
||||
):
|
||||
self.ops.time_extract_sql(None, None)
|
||||
self.ops.time_extract_sql(None, None, None)
|
||||
|
||||
def test_date_trunc_sql(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotImplementedError, self.may_require_msg % "date_trunc_sql"
|
||||
):
|
||||
self.ops.date_trunc_sql(None, None)
|
||||
self.ops.date_trunc_sql(None, None, None)
|
||||
|
||||
def test_time_trunc_sql(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotImplementedError, self.may_require_msg % "time_trunc_sql"
|
||||
):
|
||||
self.ops.time_trunc_sql(None, None)
|
||||
self.ops.time_trunc_sql(None, None, None)
|
||||
|
||||
def test_datetime_trunc_sql(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotImplementedError, self.may_require_msg % "datetime_trunc_sql"
|
||||
):
|
||||
self.ops.datetime_trunc_sql(None, None, None)
|
||||
self.ops.datetime_trunc_sql(None, None, None, None)
|
||||
|
||||
def test_datetime_cast_date_sql(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotImplementedError, self.may_require_msg % "datetime_cast_date_sql"
|
||||
):
|
||||
self.ops.datetime_cast_date_sql(None, None)
|
||||
self.ops.datetime_cast_date_sql(None, None, None)
|
||||
|
||||
def test_datetime_cast_time_sql(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotImplementedError, self.may_require_msg % "datetime_cast_time_sql"
|
||||
):
|
||||
self.ops.datetime_cast_time_sql(None, None)
|
||||
self.ops.datetime_cast_time_sql(None, None, None)
|
||||
|
||||
def test_datetime_extract_sql(self):
|
||||
with self.assertRaisesMessage(
|
||||
NotImplementedError, self.may_require_msg % "datetime_extract_sql"
|
||||
):
|
||||
self.ops.datetime_extract_sql(None, None, None)
|
||||
self.ops.datetime_extract_sql(None, None, None, None)
|
||||
|
||||
|
||||
class DatabaseOperationTests(TestCase):
|
||||
|
|
|
@ -75,7 +75,7 @@ class YearTransform(models.Transform):
|
|||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs_sql, params = compiler.compile(self.lhs)
|
||||
return connection.ops.date_extract_sql("year", lhs_sql), params
|
||||
return connection.ops.date_extract_sql("year", lhs_sql, params)
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
|
|
|
@ -13,6 +13,7 @@ except ImportError:
|
|||
pytz = None
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DataError, OperationalError
|
||||
from django.db.models import (
|
||||
DateField,
|
||||
DateTimeField,
|
||||
|
@ -244,8 +245,7 @@ class DateFunctionTests(TestCase):
|
|||
self.create_model(start_datetime, end_datetime)
|
||||
self.create_model(end_datetime, start_datetime)
|
||||
|
||||
msg = "Invalid lookup_name: "
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
with self.assertRaises((DataError, OperationalError, ValueError)):
|
||||
DTModel.objects.filter(
|
||||
start_datetime__year=Extract(
|
||||
"start_datetime", "day' FROM start_datetime)) OR 1=1;--"
|
||||
|
@ -940,14 +940,18 @@ class DateFunctionTests(TestCase):
|
|||
end_datetime = timezone.make_aware(end_datetime)
|
||||
self.create_model(start_datetime, end_datetime)
|
||||
self.create_model(end_datetime, start_datetime)
|
||||
msg = "Invalid kind: "
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
DTModel.objects.filter(
|
||||
# Database backends raise an exception or don't return any results.
|
||||
try:
|
||||
exists = DTModel.objects.filter(
|
||||
start_datetime__date=Trunc(
|
||||
"start_datetime",
|
||||
"year', start_datetime)) OR 1=1;--",
|
||||
)
|
||||
).exists()
|
||||
except (DataError, OperationalError):
|
||||
pass
|
||||
else:
|
||||
self.assertIs(exists, False)
|
||||
|
||||
def test_trunc_func(self):
|
||||
start_datetime = datetime(999, 6, 15, 14, 30, 50, 321)
|
||||
|
|
Loading…
Reference in New Issue