Refs CVE-2022-34265 -- Properly escaped Extract() and Trunc() parameters.

Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
Simon Charette 2022-06-19 23:46:22 -04:00 committed by Mariusz Felisiak
parent 73766c1187
commit 877c800f25
10 changed files with 263 additions and 220 deletions

View File

@ -9,7 +9,6 @@ from django.db import NotSupportedError, transaction
from django.db.backends import utils
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile
class BaseDatabaseOperations:
@ -55,8 +54,6 @@ class BaseDatabaseOperations:
# Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
explain_prefix = None
extract_trunc_lookup_pattern = _lazy_re_compile(r"[\w\-_()]+")
def __init__(self, connection):
self.connection = connection
self._cache = None
@ -103,7 +100,7 @@ class BaseDatabaseOperations:
"""
return "%s"
def date_extract_sql(self, lookup_type, field_name):
def date_extract_sql(self, lookup_type, sql, params):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
extracts a value from the given date field field_name.
@ -113,7 +110,7 @@ class BaseDatabaseOperations:
"method"
)
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
truncates the given date or datetime field field_name to a date object
@ -127,7 +124,7 @@ class BaseDatabaseOperations:
"method."
)
def datetime_cast_date_sql(self, field_name, tzname):
def datetime_cast_date_sql(self, sql, params, tzname):
"""
Return the SQL to cast a datetime value to date value.
"""
@ -136,7 +133,7 @@ class BaseDatabaseOperations:
"datetime_cast_date_sql() method."
)
def datetime_cast_time_sql(self, field_name, tzname):
def datetime_cast_time_sql(self, sql, params, tzname):
"""
Return the SQL to cast a datetime value to time value.
"""
@ -145,7 +142,7 @@ class BaseDatabaseOperations:
"datetime_cast_time_sql() method"
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
"""
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
'second', return the SQL that extracts a value from the given
@ -156,7 +153,7 @@ class BaseDatabaseOperations:
"method"
)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
"""
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
'second', return the SQL that truncates the given datetime field
@ -167,7 +164,7 @@ class BaseDatabaseOperations:
"method"
)
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
"""
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
that truncates the given time or datetime field field_name to a time
@ -180,12 +177,12 @@ class BaseDatabaseOperations:
"subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
)
def time_extract_sql(self, lookup_type, field_name):
def time_extract_sql(self, lookup_type, sql, params):
"""
Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
that extracts a value from the given time field field_name.
"""
return self.date_extract_sql(lookup_type, field_name)
return self.date_extract_sql(lookup_type, sql, params)
def deferrable_sql(self):
"""

View File

@ -7,6 +7,7 @@ from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile
class DatabaseOperations(BaseDatabaseOperations):
@ -37,117 +38,115 @@ class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "char"
explain_prefix = "EXPLAIN"
def date_extract_sql(self, lookup_type, field_name):
# EXTRACT format cannot be passed in parameters.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
def date_extract_sql(self, lookup_type, sql, params):
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
if lookup_type == "week_day":
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
return "DAYOFWEEK(%s)" % field_name
return f"DAYOFWEEK({sql})", params
elif lookup_type == "iso_week_day":
# WEEKDAY() returns an integer, 0-6, Monday=0.
return "WEEKDAY(%s) + 1" % field_name
return f"WEEKDAY({sql}) + 1", params
elif lookup_type == "week":
# Override the value of default_week_format for consistency with
# other database backends.
# Mode 3: Monday, 1-53, with 4 or more days this year.
return "WEEK(%s, 3)" % field_name
return f"WEEK({sql}, 3)", params
elif lookup_type == "iso_year":
# Get the year part from the YEARWEEK function, which returns a
# number as year * 100 + week.
return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name
return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
else:
# EXTRACT returns 1-53 based on ISO-8601 for the week number.
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
lookup_type = lookup_type.upper()
if not self._extract_format_re.fullmatch(lookup_type):
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
return f"EXTRACT({lookup_type} FROM {sql})", params
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_field_to_tz(sql, params, tzname)
fields = {
"year": "%%Y-01-01",
"month": "%%Y-%%m-01",
} # Use double percents to escape.
"year": "%Y-01-01",
"month": "%Y-%m-01",
}
if lookup_type in fields:
format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str)
return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
elif lookup_type == "quarter":
return (
"MAKEDATE(YEAR(%s), 1) + "
"INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER"
% (field_name, field_name)
f"MAKEDATE(YEAR({sql}), 1) + "
f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
(*params, *params),
)
elif lookup_type == "week":
return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (field_name, field_name)
return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
else:
return "DATE(%s)" % (field_name)
return f"DATE({sql})", params
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
return f"{sign}{offset}" if offset else tzname
def _convert_field_to_tz(self, field_name, tzname):
def _convert_field_to_tz(self, sql, params, tzname):
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
field_name = "CONVERT_TZ(%s, '%s', '%s')" % (
field_name,
return f"CONVERT_TZ({sql}, %s, %s)", (
*params,
self.connection.timezone_name,
self._prepare_tzname_delta(tzname),
)
return field_name
return sql, params
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "DATE(%s)" % field_name
def datetime_cast_date_sql(self, sql, params, tzname):
sql, params = self._convert_field_to_tz(sql, params, tzname)
return f"DATE({sql})", params
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "TIME(%s)" % field_name
def datetime_cast_time_sql(self, sql, params, tzname):
sql, params = self._convert_field_to_tz(sql, params, tzname)
return f"TIME({sql})", params
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return self.date_extract_sql(lookup_type, field_name)
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_field_to_tz(sql, params, tzname)
return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_field_to_tz(sql, params, tzname)
fields = ["year", "month", "day", "hour", "minute", "second"]
format = (
"%%Y-",
"%%m",
"-%%d",
" %%H:",
"%%i",
":%%s",
) # Use double percents to escape.
format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
if lookup_type == "quarter":
return (
"CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + "
"INTERVAL QUARTER({field_name}) QUARTER - "
+ "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
).format(field_name=field_name)
f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
f"INTERVAL QUARTER({sql}) QUARTER - "
f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
), (*params, *params, "%Y-%m-01 00:00:00")
if lookup_type == "week":
return (
"CAST(DATE_FORMAT(DATE_SUB({field_name}, "
"INTERVAL WEEKDAY({field_name}) DAY), "
"'%%Y-%%m-%%d 00:00:00') AS DATETIME)"
).format(field_name=field_name)
f"CAST(DATE_FORMAT("
f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
), (*params, *params, "%Y-%m-%d 00:00:00")
try:
i = fields.index(lookup_type) + 1
except ValueError:
sql = field_name
pass
else:
format_str = "".join(format[:i] + format_def[i:])
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql
return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
return sql, params
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_field_to_tz(sql, params, tzname)
fields = {
"hour": "%%H:00:00",
"minute": "%%H:%%i:00",
"second": "%%H:%%i:%%s",
} # Use double percents to escape.
"hour": "%H:00:00",
"minute": "%H:%i:00",
"second": "%H:%i:%s",
}
if lookup_type in fields:
format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS TIME)" % (field_name, format_str)
return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
else:
return "TIME(%s)" % (field_name)
return f"TIME({sql})", params
def fetch_returned_insert_rows(self, cursor):
"""

View File

@ -77,34 +77,46 @@ END;
f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY"
)
def date_extract_sql(self, lookup_type, field_name):
# EXTRACT format cannot be passed in parameters.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
def date_extract_sql(self, lookup_type, sql, params):
extract_sql = f"TO_CHAR({sql}, %s)"
extract_param = None
if lookup_type == "week_day":
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
return "TO_CHAR(%s, 'D')" % field_name
extract_param = "D"
elif lookup_type == "iso_week_day":
return "TO_CHAR(%s - 1, 'D')" % field_name
extract_sql = f"TO_CHAR({sql} - 1, %s)"
extract_param = "D"
elif lookup_type == "week":
# IW = ISO week number
return "TO_CHAR(%s, 'IW')" % field_name
extract_param = "IW"
elif lookup_type == "quarter":
return "TO_CHAR(%s, 'Q')" % field_name
extract_param = "Q"
elif lookup_type == "iso_year":
return "TO_CHAR(%s, 'IYYY')" % field_name
extract_param = "IYYY"
else:
lookup_type = lookup_type.upper()
if not self._extract_format_re.fullmatch(lookup_type):
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
return f"EXTRACT({lookup_type} FROM {sql})", params
return extract_sql, (*params, extract_param)
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_field_to_tz(sql, params, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
trunc_param = None
if lookup_type in ("year", "month"):
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
trunc_param = lookup_type.upper()
elif lookup_type == "quarter":
return "TRUNC(%s, 'Q')" % field_name
trunc_param = "Q"
elif lookup_type == "week":
return "TRUNC(%s, 'IW')" % field_name
trunc_param = "IW"
else:
return "TRUNC(%s)" % field_name
return f"TRUNC({sql})", params
return f"TRUNC({sql}, %s)", (*params, trunc_param)
# Oracle crashes with "ORA-03113: end-of-file on communication channel"
# if the time zone name is passed in parameter. Use interpolation instead.
@ -116,77 +128,80 @@ END;
tzname, sign, offset = split_tzname_delta(tzname)
return f"{sign}{offset}" if offset else tzname
def _convert_field_to_tz(self, field_name, tzname):
def _convert_field_to_tz(self, sql, params, tzname):
if not (settings.USE_TZ and tzname):
return field_name
return sql, params
if not self._tzname_re.match(tzname):
raise ValueError("Invalid time zone name: %s" % tzname)
# Convert from connection timezone to the local time, returning
# TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the
# TIME ZONE details.
if self.connection.timezone_name != tzname:
return "CAST((FROM_TZ(%s, '%s') AT TIME ZONE '%s') AS TIMESTAMP)" % (
field_name,
self.connection.timezone_name,
self._prepare_tzname_delta(tzname),
from_timezone_name = self.connection.timezone_name
to_timezone_name = self._prepare_tzname_delta(tzname)
return (
f"CAST((FROM_TZ({sql}, '{from_timezone_name}') AT TIME ZONE "
f"'{to_timezone_name}') AS TIMESTAMP)",
params,
)
return field_name
return sql, params
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "TRUNC(%s)" % field_name
def datetime_cast_date_sql(self, sql, params, tzname):
sql, params = self._convert_field_to_tz(sql, params, tzname)
return f"TRUNC({sql})", params
def datetime_cast_time_sql(self, field_name, tzname):
def datetime_cast_time_sql(self, sql, params, tzname):
# Since `TimeField` values are stored as TIMESTAMP change to the
# default date and convert the field to the specified timezone.
sql, params = self._convert_field_to_tz(sql, params, tzname)
convert_datetime_sql = (
"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR(%s, 'HH24:MI:SS.FF')), "
"'YYYY-MM-DD HH24:MI:SS.FF')"
) % self._convert_field_to_tz(field_name, tzname)
return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % (
field_name,
convert_datetime_sql,
f"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR({sql}, 'HH24:MI:SS.FF')), "
f"'YYYY-MM-DD HH24:MI:SS.FF')"
)
return (
f"CASE WHEN {sql} IS NOT NULL THEN {convert_datetime_sql} ELSE NULL END",
(*params, *params),
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return self.date_extract_sql(lookup_type, field_name)
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_field_to_tz(sql, params, tzname)
return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_field_to_tz(sql, params, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
trunc_param = None
if lookup_type in ("year", "month"):
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
trunc_param = lookup_type.upper()
elif lookup_type == "quarter":
sql = "TRUNC(%s, 'Q')" % field_name
trunc_param = "Q"
elif lookup_type == "week":
sql = "TRUNC(%s, 'IW')" % field_name
elif lookup_type == "day":
sql = "TRUNC(%s)" % field_name
trunc_param = "IW"
elif lookup_type == "hour":
sql = "TRUNC(%s, 'HH24')" % field_name
trunc_param = "HH24"
elif lookup_type == "minute":
sql = "TRUNC(%s, 'MI')" % field_name
trunc_param = "MI"
elif lookup_type == "day":
return f"TRUNC({sql})", params
else:
sql = (
"CAST(%s AS DATE)" % field_name
) # Cast to DATE removes sub-second precision.
return sql
# Cast to DATE removes sub-second precision.
return f"CAST({sql} AS DATE)", params
return f"TRUNC({sql}, %s)", (*params, trunc_param)
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
# The implementation is similar to `datetime_trunc_sql` as both
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
# the date part of the later is ignored.
field_name = self._convert_field_to_tz(field_name, tzname)
sql, params = self._convert_field_to_tz(sql, params, tzname)
trunc_param = None
if lookup_type == "hour":
sql = "TRUNC(%s, 'HH24')" % field_name
trunc_param = "HH24"
elif lookup_type == "minute":
sql = "TRUNC(%s, 'MI')" % field_name
trunc_param = "MI"
elif lookup_type == "second":
sql = (
"CAST(%s AS DATE)" % field_name
) # Cast to DATE removes sub-second precision.
return sql
# Cast to DATE removes sub-second precision.
return f"CAST({sql} AS DATE)", params
return f"TRUNC({sql}, %s)", (*params, trunc_param)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)

View File

@ -47,22 +47,24 @@ class DatabaseOperations(BaseDatabaseOperations):
)
return "%s"
def date_extract_sql(self, lookup_type, field_name):
def date_extract_sql(self, lookup_type, sql, params):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
extract_sql = f"EXTRACT(%s FROM {sql})"
extract_param = lookup_type
if lookup_type == "week_day":
# For consistency across backends, we return Sunday=1, Saturday=7.
return "EXTRACT('dow' FROM %s) + 1" % field_name
extract_sql = f"EXTRACT(%s FROM {sql}) + 1"
extract_param = "dow"
elif lookup_type == "iso_week_day":
return "EXTRACT('isodow' FROM %s)" % field_name
extract_param = "isodow"
elif lookup_type == "iso_year":
return "EXTRACT('isoyear' FROM %s)" % field_name
else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
extract_param = "isoyear"
return extract_sql, (extract_param, *params)
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
@ -71,43 +73,47 @@ class DatabaseOperations(BaseDatabaseOperations):
return f"{tzname}{sign}{offset}"
return tzname
def _convert_field_to_tz(self, field_name, tzname):
def _convert_sql_to_tz(self, sql, params, tzname):
if tzname and settings.USE_TZ:
field_name = "%s AT TIME ZONE '%s'" % (
field_name,
self._prepare_tzname_delta(tzname),
tzname_param = self._prepare_tzname_delta(tzname)
return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
return sql, params
def datetime_cast_date_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"({sql})::date", params
def datetime_cast_time_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"({sql})::time", params
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
if lookup_type == "second":
# Truncate fractional seconds.
return (
f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
("second", "second", *params),
)
return field_name
return self.date_extract_sql(lookup_type, sql, params)
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "(%s)::date" % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "(%s)::time" % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
if lookup_type == "second":
# Truncate fractional seconds.
return f"EXTRACT('second' FROM DATE_TRUNC('second', {field_name}))"
return self.date_extract_sql(lookup_type, field_name)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
def time_extract_sql(self, lookup_type, field_name):
def time_extract_sql(self, lookup_type, sql, params):
if lookup_type == "second":
# Truncate fractional seconds.
return f"EXTRACT('second' FROM DATE_TRUNC('second', {field_name}))"
return self.date_extract_sql(lookup_type, field_name)
return (
f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
("second", "second", *params),
)
return self.date_extract_sql(lookup_type, sql, params)
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"

View File

@ -69,13 +69,13 @@ class DatabaseOperations(BaseDatabaseOperations):
"accepting multiple arguments."
)
def date_extract_sql(self, lookup_type, field_name):
def date_extract_sql(self, lookup_type, sql, params):
"""
Support EXTRACT with a user-defined function django_date_extract()
that's registered in connect(). Use single quotes because this is a
string and could otherwise cause a collision with a field name.
"""
return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)
return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params)
def fetch_returned_insert_rows(self, cursor):
"""
@ -88,53 +88,53 @@ class DatabaseOperations(BaseDatabaseOperations):
"""Do nothing since formatting is handled in the custom function."""
return sql
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
return "django_date_trunc('%s', %s, %s, %s)" % (
def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
return f"django_date_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(),
field_name,
*params,
*self._convert_tznames_to_sql(tzname),
)
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
return "django_time_trunc('%s', %s, %s, %s)" % (
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
return f"django_time_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(),
field_name,
*params,
*self._convert_tznames_to_sql(tzname),
)
def _convert_tznames_to_sql(self, tzname):
if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
return "NULL", "NULL"
return tzname, self.connection.timezone_name
return None, None
def datetime_cast_date_sql(self, field_name, tzname):
return "django_datetime_cast_date(%s, %s, %s)" % (
field_name,
def datetime_cast_date_sql(self, sql, params, tzname):
return f"django_datetime_cast_date({sql}, %s, %s)", (
*params,
*self._convert_tznames_to_sql(tzname),
)
def datetime_cast_time_sql(self, field_name, tzname):
return "django_datetime_cast_time(%s, %s, %s)" % (
field_name,
def datetime_cast_time_sql(self, sql, params, tzname):
return f"django_datetime_cast_time({sql}, %s, %s)", (
*params,
*self._convert_tznames_to_sql(tzname),
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
return "django_datetime_extract('%s', %s, %s, %s)" % (
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
return f"django_datetime_extract(%s, {sql}, %s, %s)", (
lookup_type.lower(),
field_name,
*params,
*self._convert_tznames_to_sql(tzname),
)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
return "django_datetime_trunc('%s', %s, %s, %s)" % (
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
return f"django_datetime_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(),
field_name,
*params,
*self._convert_tznames_to_sql(tzname),
)
def time_extract_sql(self, lookup_type, field_name):
return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)
def time_extract_sql(self, lookup_type, sql, params):
return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params)
def pk_default_value(self):
return "NULL"

View File

@ -51,25 +51,31 @@ class Extract(TimezoneMixin, Transform):
super().__init__(expression, **extra)
def as_sql(self, compiler, connection):
if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.lookup_name):
raise ValueError("Invalid lookup_name: %s" % self.lookup_name)
sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname()
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
sql, params = connection.ops.datetime_extract_sql(
self.lookup_name, sql, tuple(params), tzname
)
elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.")
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
sql, params = connection.ops.date_extract_sql(
self.lookup_name, sql, tuple(params)
)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
sql, params = connection.ops.time_extract_sql(
self.lookup_name, sql, tuple(params)
)
elif isinstance(lhs_output_field, DurationField):
if not connection.features.has_native_duration_field:
raise ValueError(
"Extract requires native DurationField database support."
)
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
sql, params = connection.ops.time_extract_sql(
self.lookup_name, sql, tuple(params)
)
else:
# resolve_expression has already validated the output_field so this
# assert should never be hit.
@ -237,25 +243,29 @@ class TruncBase(TimezoneMixin, Transform):
super().__init__(expression, output_field=output_field, **extra)
def as_sql(self, compiler, connection):
if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.kind):
raise ValueError("Invalid kind: %s" % self.kind)
inner_sql, inner_params = compiler.compile(self.lhs)
sql, params = compiler.compile(self.lhs)
tzname = None
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.")
if isinstance(self.output_field, DateTimeField):
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
sql, params = connection.ops.datetime_trunc_sql(
self.kind, sql, tuple(params), tzname
)
elif isinstance(self.output_field, DateField):
sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
sql, params = connection.ops.date_trunc_sql(
self.kind, sql, tuple(params), tzname
)
elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
sql, params = connection.ops.time_trunc_sql(
self.kind, sql, tuple(params), tzname
)
else:
raise ValueError(
"Trunc only valid on DateField, TimeField, or DateTimeField."
)
return sql, inner_params
return sql, params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
@ -384,10 +394,9 @@ class TruncDate(TruncBase):
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
lhs, lhs_params = compiler.compile(self.lhs)
sql, params = compiler.compile(self.lhs)
tzname = self.get_tzname()
sql = connection.ops.datetime_cast_date_sql(lhs, tzname)
return sql, lhs_params
return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
class TruncTime(TruncBase):
@ -397,10 +406,9 @@ class TruncTime(TruncBase):
def as_sql(self, compiler, connection):
# Cast to time rather than truncate to time.
lhs, lhs_params = compiler.compile(self.lhs)
sql, params = compiler.compile(self.lhs)
tzname = self.get_tzname()
sql = connection.ops.datetime_cast_time_sql(lhs, tzname)
return sql, lhs_params
return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
class TruncHour(TruncBase):

View File

@ -459,6 +459,20 @@ backends.
``DatabaseOperations.insert_statement()`` method is replaced by
``on_conflict`` that accepts ``django.db.models.constants.OnConflict``.
* Several date and time methods on ``DatabaseOperations`` now take ``sql`` and
``params`` arguments instead of ``field_name`` and return 2-tuple containing
some SQL and the parameters to be interpolated into that SQL. The changed
methods have these new signatures:
* ``DatabaseOperations.date_extract_sql(lookup_type, sql, params)``
* ``DatabaseOperations.datetime_extract_sql(lookup_type, sql, params, tzname)``
* ``DatabaseOperations.time_extract_sql(lookup_type, sql, params)``
* ``DatabaseOperations.date_trunc_sql(lookup_type, sql, params, tzname=None)``
* ``DatabaseOperations.datetime_trunc_sql(self, lookup_type, sql, params, tzname)``
* ``DatabaseOperations.time_trunc_sql(lookup_type, sql, params, tzname=None)``
* ``DatabaseOperations.datetime_cast_date_sql(sql, params, tzname)``
* ``DatabaseOperations.datetime_cast_time_sql(sql, params, tzname)``
:mod:`django.contrib.gis`
-------------------------

View File

@ -115,49 +115,49 @@ class SimpleDatabaseOperationTests(SimpleTestCase):
with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "date_extract_sql"
):
self.ops.date_extract_sql(None, None)
self.ops.date_extract_sql(None, None, None)
def test_time_extract_sql(self):
with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "date_extract_sql"
):
self.ops.time_extract_sql(None, None)
self.ops.time_extract_sql(None, None, None)
def test_date_trunc_sql(self):
with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "date_trunc_sql"
):
self.ops.date_trunc_sql(None, None)
self.ops.date_trunc_sql(None, None, None)
def test_time_trunc_sql(self):
with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "time_trunc_sql"
):
self.ops.time_trunc_sql(None, None)
self.ops.time_trunc_sql(None, None, None)
def test_datetime_trunc_sql(self):
with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "datetime_trunc_sql"
):
self.ops.datetime_trunc_sql(None, None, None)
self.ops.datetime_trunc_sql(None, None, None, None)
def test_datetime_cast_date_sql(self):
with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "datetime_cast_date_sql"
):
self.ops.datetime_cast_date_sql(None, None)
self.ops.datetime_cast_date_sql(None, None, None)
def test_datetime_cast_time_sql(self):
with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "datetime_cast_time_sql"
):
self.ops.datetime_cast_time_sql(None, None)
self.ops.datetime_cast_time_sql(None, None, None)
def test_datetime_extract_sql(self):
with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "datetime_extract_sql"
):
self.ops.datetime_extract_sql(None, None, None)
self.ops.datetime_extract_sql(None, None, None, None)
class DatabaseOperationTests(TestCase):

View File

@ -75,7 +75,7 @@ class YearTransform(models.Transform):
def as_sql(self, compiler, connection):
lhs_sql, params = compiler.compile(self.lhs)
return connection.ops.date_extract_sql("year", lhs_sql), params
return connection.ops.date_extract_sql("year", lhs_sql, params)
@property
def output_field(self):

View File

@ -13,6 +13,7 @@ except ImportError:
pytz = None
from django.conf import settings
from django.db import DataError, OperationalError
from django.db.models import (
DateField,
DateTimeField,
@ -244,8 +245,7 @@ class DateFunctionTests(TestCase):
self.create_model(start_datetime, end_datetime)
self.create_model(end_datetime, start_datetime)
msg = "Invalid lookup_name: "
with self.assertRaisesMessage(ValueError, msg):
with self.assertRaises((DataError, OperationalError, ValueError)):
DTModel.objects.filter(
start_datetime__year=Extract(
"start_datetime", "day' FROM start_datetime)) OR 1=1;--"
@ -940,14 +940,18 @@ class DateFunctionTests(TestCase):
end_datetime = timezone.make_aware(end_datetime)
self.create_model(start_datetime, end_datetime)
self.create_model(end_datetime, start_datetime)
msg = "Invalid kind: "
with self.assertRaisesMessage(ValueError, msg):
DTModel.objects.filter(
# Database backends raise an exception or don't return any results.
try:
exists = DTModel.objects.filter(
start_datetime__date=Trunc(
"start_datetime",
"year', start_datetime)) OR 1=1;--",
)
).exists()
except (DataError, OperationalError):
pass
else:
self.assertIs(exists, False)
def test_trunc_func(self):
start_datetime = datetime(999, 6, 15, 14, 30, 50, 321)