Refs CVE-2022-34265 -- Properly escaped Extract() and Trunc() parameters.

Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
Simon Charette 2022-06-19 23:46:22 -04:00 committed by Mariusz Felisiak
parent 73766c1187
commit 877c800f25
10 changed files with 263 additions and 220 deletions

View File

@ -9,7 +9,6 @@ from django.db import NotSupportedError, transaction
from django.db.backends import utils from django.db.backends import utils
from django.utils import timezone from django.utils import timezone
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile
class BaseDatabaseOperations: class BaseDatabaseOperations:
@ -55,8 +54,6 @@ class BaseDatabaseOperations:
# Prefix for EXPLAIN queries, or None EXPLAIN isn't supported. # Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
explain_prefix = None explain_prefix = None
extract_trunc_lookup_pattern = _lazy_re_compile(r"[\w\-_()]+")
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
self._cache = None self._cache = None
@ -103,7 +100,7 @@ class BaseDatabaseOperations:
""" """
return "%s" return "%s"
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, sql, params):
""" """
Given a lookup_type of 'year', 'month', or 'day', return the SQL that Given a lookup_type of 'year', 'month', or 'day', return the SQL that
extracts a value from the given date field field_name. extracts a value from the given date field field_name.
@ -113,7 +110,7 @@ class BaseDatabaseOperations:
"method" "method"
) )
def date_trunc_sql(self, lookup_type, field_name, tzname=None): def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
""" """
Given a lookup_type of 'year', 'month', or 'day', return the SQL that Given a lookup_type of 'year', 'month', or 'day', return the SQL that
truncates the given date or datetime field field_name to a date object truncates the given date or datetime field field_name to a date object
@ -127,7 +124,7 @@ class BaseDatabaseOperations:
"method." "method."
) )
def datetime_cast_date_sql(self, field_name, tzname): def datetime_cast_date_sql(self, sql, params, tzname):
""" """
Return the SQL to cast a datetime value to date value. Return the SQL to cast a datetime value to date value.
""" """
@ -136,7 +133,7 @@ class BaseDatabaseOperations:
"datetime_cast_date_sql() method." "datetime_cast_date_sql() method."
) )
def datetime_cast_time_sql(self, field_name, tzname): def datetime_cast_time_sql(self, sql, params, tzname):
""" """
Return the SQL to cast a datetime value to time value. Return the SQL to cast a datetime value to time value.
""" """
@ -145,7 +142,7 @@ class BaseDatabaseOperations:
"datetime_cast_time_sql() method" "datetime_cast_time_sql() method"
) )
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, sql, params, tzname):
""" """
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
'second', return the SQL that extracts a value from the given 'second', return the SQL that extracts a value from the given
@ -156,7 +153,7 @@ class BaseDatabaseOperations:
"method" "method"
) )
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
""" """
Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
'second', return the SQL that truncates the given datetime field 'second', return the SQL that truncates the given datetime field
@ -167,7 +164,7 @@ class BaseDatabaseOperations:
"method" "method"
) )
def time_trunc_sql(self, lookup_type, field_name, tzname=None): def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
""" """
Given a lookup_type of 'hour', 'minute' or 'second', return the SQL Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
that truncates the given time or datetime field field_name to a time that truncates the given time or datetime field field_name to a time
@ -180,12 +177,12 @@ class BaseDatabaseOperations:
"subclasses of BaseDatabaseOperations may require a time_trunc_sql() method" "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
) )
def time_extract_sql(self, lookup_type, field_name): def time_extract_sql(self, lookup_type, sql, params):
""" """
Given a lookup_type of 'hour', 'minute', or 'second', return the SQL Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
that extracts a value from the given time field field_name. that extracts a value from the given time field field_name.
""" """
return self.date_extract_sql(lookup_type, field_name) return self.date_extract_sql(lookup_type, sql, params)
def deferrable_sql(self): def deferrable_sql(self):
""" """

View File

@ -7,6 +7,7 @@ from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict from django.db.models.constants import OnConflict
from django.utils import timezone from django.utils import timezone
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
@ -37,117 +38,115 @@ class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "char" cast_char_field_without_max_length = "char"
explain_prefix = "EXPLAIN" explain_prefix = "EXPLAIN"
def date_extract_sql(self, lookup_type, field_name): # EXTRACT format cannot be passed in parameters.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
def date_extract_sql(self, lookup_type, sql, params):
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
if lookup_type == "week_day": if lookup_type == "week_day":
# DAYOFWEEK() returns an integer, 1-7, Sunday=1. # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
return "DAYOFWEEK(%s)" % field_name return f"DAYOFWEEK({sql})", params
elif lookup_type == "iso_week_day": elif lookup_type == "iso_week_day":
# WEEKDAY() returns an integer, 0-6, Monday=0. # WEEKDAY() returns an integer, 0-6, Monday=0.
return "WEEKDAY(%s) + 1" % field_name return f"WEEKDAY({sql}) + 1", params
elif lookup_type == "week": elif lookup_type == "week":
# Override the value of default_week_format for consistency with # Override the value of default_week_format for consistency with
# other database backends. # other database backends.
# Mode 3: Monday, 1-53, with 4 or more days this year. # Mode 3: Monday, 1-53, with 4 or more days this year.
return "WEEK(%s, 3)" % field_name return f"WEEK({sql}, 3)", params
elif lookup_type == "iso_year": elif lookup_type == "iso_year":
# Get the year part from the YEARWEEK function, which returns a # Get the year part from the YEARWEEK function, which returns a
# number as year * 100 + week. # number as year * 100 + week.
return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
else: else:
# EXTRACT returns 1-53 based on ISO-8601 for the week number. # EXTRACT returns 1-53 based on ISO-8601 for the week number.
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) lookup_type = lookup_type.upper()
if not self._extract_format_re.fullmatch(lookup_type):
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
return f"EXTRACT({lookup_type} FROM {sql})", params
def date_trunc_sql(self, lookup_type, field_name, tzname=None): def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
fields = { fields = {
"year": "%%Y-01-01", "year": "%Y-01-01",
"month": "%%Y-%%m-01", "month": "%Y-%m-01",
} # Use double percents to escape. }
if lookup_type in fields: if lookup_type in fields:
format_str = fields[lookup_type] format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str) return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
elif lookup_type == "quarter": elif lookup_type == "quarter":
return ( return (
"MAKEDATE(YEAR(%s), 1) + " f"MAKEDATE(YEAR({sql}), 1) + "
"INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
% (field_name, field_name) (*params, *params),
) )
elif lookup_type == "week": elif lookup_type == "week":
return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (field_name, field_name) return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
else: else:
return "DATE(%s)" % (field_name) return f"DATE({sql})", params
def _prepare_tzname_delta(self, tzname): def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname) tzname, sign, offset = split_tzname_delta(tzname)
return f"{sign}{offset}" if offset else tzname return f"{sign}{offset}" if offset else tzname
def _convert_field_to_tz(self, field_name, tzname): def _convert_field_to_tz(self, sql, params, tzname):
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname: if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
field_name = "CONVERT_TZ(%s, '%s', '%s')" % ( return f"CONVERT_TZ({sql}, %s, %s)", (
field_name, *params,
self.connection.timezone_name, self.connection.timezone_name,
self._prepare_tzname_delta(tzname), self._prepare_tzname_delta(tzname),
) )
return field_name return sql, params
def datetime_cast_date_sql(self, field_name, tzname): def datetime_cast_date_sql(self, sql, params, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
return "DATE(%s)" % field_name return f"DATE({sql})", params
def datetime_cast_time_sql(self, field_name, tzname): def datetime_cast_time_sql(self, sql, params, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
return "TIME(%s)" % field_name return f"TIME({sql})", params
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, sql, params, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
return self.date_extract_sql(lookup_type, field_name) return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
fields = ["year", "month", "day", "hour", "minute", "second"] fields = ["year", "month", "day", "hour", "minute", "second"]
format = ( format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
"%%Y-",
"%%m",
"-%%d",
" %%H:",
"%%i",
":%%s",
) # Use double percents to escape.
format_def = ("0000-", "01", "-01", " 00:", "00", ":00") format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
if lookup_type == "quarter": if lookup_type == "quarter":
return ( return (
"CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + " f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
"INTERVAL QUARTER({field_name}) QUARTER - " f"INTERVAL QUARTER({sql}) QUARTER - "
+ "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)" f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
).format(field_name=field_name) ), (*params, *params, "%Y-%m-01 00:00:00")
if lookup_type == "week": if lookup_type == "week":
return ( return (
"CAST(DATE_FORMAT(DATE_SUB({field_name}, " f"CAST(DATE_FORMAT("
"INTERVAL WEEKDAY({field_name}) DAY), " f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
"'%%Y-%%m-%%d 00:00:00') AS DATETIME)" ), (*params, *params, "%Y-%m-%d 00:00:00")
).format(field_name=field_name)
try: try:
i = fields.index(lookup_type) + 1 i = fields.index(lookup_type) + 1
except ValueError: except ValueError:
sql = field_name pass
else: else:
format_str = "".join(format[:i] + format_def[i:]) format_str = "".join(format[:i] + format_def[i:])
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str) return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
return sql return sql, params
def time_trunc_sql(self, lookup_type, field_name, tzname=None): def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
fields = { fields = {
"hour": "%%H:00:00", "hour": "%H:00:00",
"minute": "%%H:%%i:00", "minute": "%H:%i:00",
"second": "%%H:%%i:%%s", "second": "%H:%i:%s",
} # Use double percents to escape. }
if lookup_type in fields: if lookup_type in fields:
format_str = fields[lookup_type] format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS TIME)" % (field_name, format_str) return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
else: else:
return "TIME(%s)" % (field_name) return f"TIME({sql})", params
def fetch_returned_insert_rows(self, cursor): def fetch_returned_insert_rows(self, cursor):
""" """

View File

@ -77,34 +77,46 @@ END;
f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY" f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY"
) )
def date_extract_sql(self, lookup_type, field_name): # EXTRACT format cannot be passed in parameters.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
def date_extract_sql(self, lookup_type, sql, params):
extract_sql = f"TO_CHAR({sql}, %s)"
extract_param = None
if lookup_type == "week_day": if lookup_type == "week_day":
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday. # TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
return "TO_CHAR(%s, 'D')" % field_name extract_param = "D"
elif lookup_type == "iso_week_day": elif lookup_type == "iso_week_day":
return "TO_CHAR(%s - 1, 'D')" % field_name extract_sql = f"TO_CHAR({sql} - 1, %s)"
extract_param = "D"
elif lookup_type == "week": elif lookup_type == "week":
# IW = ISO week number # IW = ISO week number
return "TO_CHAR(%s, 'IW')" % field_name extract_param = "IW"
elif lookup_type == "quarter": elif lookup_type == "quarter":
return "TO_CHAR(%s, 'Q')" % field_name extract_param = "Q"
elif lookup_type == "iso_year": elif lookup_type == "iso_year":
return "TO_CHAR(%s, 'IYYY')" % field_name extract_param = "IYYY"
else: else:
lookup_type = lookup_type.upper()
if not self._extract_format_re.fullmatch(lookup_type):
raise ValueError(f"Invalid loookup type: {lookup_type!r}")
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name) return f"EXTRACT({lookup_type} FROM {sql})", params
return extract_sql, (*params, extract_param)
def date_trunc_sql(self, lookup_type, field_name, tzname=None): def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
trunc_param = None
if lookup_type in ("year", "month"): if lookup_type in ("year", "month"):
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper()) trunc_param = lookup_type.upper()
elif lookup_type == "quarter": elif lookup_type == "quarter":
return "TRUNC(%s, 'Q')" % field_name trunc_param = "Q"
elif lookup_type == "week": elif lookup_type == "week":
return "TRUNC(%s, 'IW')" % field_name trunc_param = "IW"
else: else:
return "TRUNC(%s)" % field_name return f"TRUNC({sql})", params
return f"TRUNC({sql}, %s)", (*params, trunc_param)
# Oracle crashes with "ORA-03113: end-of-file on communication channel" # Oracle crashes with "ORA-03113: end-of-file on communication channel"
# if the time zone name is passed in parameter. Use interpolation instead. # if the time zone name is passed in parameter. Use interpolation instead.
@ -116,77 +128,80 @@ END;
tzname, sign, offset = split_tzname_delta(tzname) tzname, sign, offset = split_tzname_delta(tzname)
return f"{sign}{offset}" if offset else tzname return f"{sign}{offset}" if offset else tzname
def _convert_field_to_tz(self, field_name, tzname): def _convert_field_to_tz(self, sql, params, tzname):
if not (settings.USE_TZ and tzname): if not (settings.USE_TZ and tzname):
return field_name return sql, params
if not self._tzname_re.match(tzname): if not self._tzname_re.match(tzname):
raise ValueError("Invalid time zone name: %s" % tzname) raise ValueError("Invalid time zone name: %s" % tzname)
# Convert from connection timezone to the local time, returning # Convert from connection timezone to the local time, returning
# TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the # TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the
# TIME ZONE details. # TIME ZONE details.
if self.connection.timezone_name != tzname: if self.connection.timezone_name != tzname:
return "CAST((FROM_TZ(%s, '%s') AT TIME ZONE '%s') AS TIMESTAMP)" % ( from_timezone_name = self.connection.timezone_name
field_name, to_timezone_name = self._prepare_tzname_delta(tzname)
self.connection.timezone_name, return (
self._prepare_tzname_delta(tzname), f"CAST((FROM_TZ({sql}, '{from_timezone_name}') AT TIME ZONE "
f"'{to_timezone_name}') AS TIMESTAMP)",
params,
) )
return field_name return sql, params
def datetime_cast_date_sql(self, field_name, tzname): def datetime_cast_date_sql(self, sql, params, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
return "TRUNC(%s)" % field_name return f"TRUNC({sql})", params
def datetime_cast_time_sql(self, field_name, tzname): def datetime_cast_time_sql(self, sql, params, tzname):
# Since `TimeField` values are stored as TIMESTAMP change to the # Since `TimeField` values are stored as TIMESTAMP change to the
# default date and convert the field to the specified timezone. # default date and convert the field to the specified timezone.
sql, params = self._convert_field_to_tz(sql, params, tzname)
convert_datetime_sql = ( convert_datetime_sql = (
"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR(%s, 'HH24:MI:SS.FF')), " f"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR({sql}, 'HH24:MI:SS.FF')), "
"'YYYY-MM-DD HH24:MI:SS.FF')" f"'YYYY-MM-DD HH24:MI:SS.FF')"
) % self._convert_field_to_tz(field_name, tzname) )
return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % ( return (
field_name, f"CASE WHEN {sql} IS NOT NULL THEN {convert_datetime_sql} ELSE NULL END",
convert_datetime_sql, (*params, *params),
) )
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, sql, params, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
return self.date_extract_sql(lookup_type, field_name) return self.date_extract_sql(lookup_type, sql, params)
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
trunc_param = None
if lookup_type in ("year", "month"): if lookup_type in ("year", "month"):
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper()) trunc_param = lookup_type.upper()
elif lookup_type == "quarter": elif lookup_type == "quarter":
sql = "TRUNC(%s, 'Q')" % field_name trunc_param = "Q"
elif lookup_type == "week": elif lookup_type == "week":
sql = "TRUNC(%s, 'IW')" % field_name trunc_param = "IW"
elif lookup_type == "day":
sql = "TRUNC(%s)" % field_name
elif lookup_type == "hour": elif lookup_type == "hour":
sql = "TRUNC(%s, 'HH24')" % field_name trunc_param = "HH24"
elif lookup_type == "minute": elif lookup_type == "minute":
sql = "TRUNC(%s, 'MI')" % field_name trunc_param = "MI"
elif lookup_type == "day":
return f"TRUNC({sql})", params
else: else:
sql = ( # Cast to DATE removes sub-second precision.
"CAST(%s AS DATE)" % field_name return f"CAST({sql} AS DATE)", params
) # Cast to DATE removes sub-second precision. return f"TRUNC({sql}, %s)", (*params, trunc_param)
return sql
def time_trunc_sql(self, lookup_type, field_name, tzname=None): def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
# The implementation is similar to `datetime_trunc_sql` as both # The implementation is similar to `datetime_trunc_sql` as both
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where # `DateTimeField` and `TimeField` are stored as TIMESTAMP where
# the date part of the later is ignored. # the date part of the later is ignored.
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_field_to_tz(sql, params, tzname)
trunc_param = None
if lookup_type == "hour": if lookup_type == "hour":
sql = "TRUNC(%s, 'HH24')" % field_name trunc_param = "HH24"
elif lookup_type == "minute": elif lookup_type == "minute":
sql = "TRUNC(%s, 'MI')" % field_name trunc_param = "MI"
elif lookup_type == "second": elif lookup_type == "second":
sql = ( # Cast to DATE removes sub-second precision.
"CAST(%s AS DATE)" % field_name return f"CAST({sql} AS DATE)", params
) # Cast to DATE removes sub-second precision. return f"TRUNC({sql}, %s)", (*params, trunc_param)
return sql
def get_db_converters(self, expression): def get_db_converters(self, expression):
converters = super().get_db_converters(expression) converters = super().get_db_converters(expression)

View File

@ -47,22 +47,24 @@ class DatabaseOperations(BaseDatabaseOperations):
) )
return "%s" return "%s"
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, sql, params):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
extract_sql = f"EXTRACT(%s FROM {sql})"
extract_param = lookup_type
if lookup_type == "week_day": if lookup_type == "week_day":
# For consistency across backends, we return Sunday=1, Saturday=7. # For consistency across backends, we return Sunday=1, Saturday=7.
return "EXTRACT('dow' FROM %s) + 1" % field_name extract_sql = f"EXTRACT(%s FROM {sql}) + 1"
extract_param = "dow"
elif lookup_type == "iso_week_day": elif lookup_type == "iso_week_day":
return "EXTRACT('isodow' FROM %s)" % field_name extract_param = "isodow"
elif lookup_type == "iso_year": elif lookup_type == "iso_year":
return "EXTRACT('isoyear' FROM %s)" % field_name extract_param = "isoyear"
else: return extract_sql, (extract_param, *params)
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
def date_trunc_sql(self, lookup_type, field_name, tzname=None): def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_sql_to_tz(sql, params, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
def _prepare_tzname_delta(self, tzname): def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname) tzname, sign, offset = split_tzname_delta(tzname)
@ -71,43 +73,47 @@ class DatabaseOperations(BaseDatabaseOperations):
return f"{tzname}{sign}{offset}" return f"{tzname}{sign}{offset}"
return tzname return tzname
def _convert_field_to_tz(self, field_name, tzname): def _convert_sql_to_tz(self, sql, params, tzname):
if tzname and settings.USE_TZ: if tzname and settings.USE_TZ:
field_name = "%s AT TIME ZONE '%s'" % ( tzname_param = self._prepare_tzname_delta(tzname)
field_name, return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
self._prepare_tzname_delta(tzname), return sql, params
def datetime_cast_date_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"({sql})::date", params
def datetime_cast_time_sql(self, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
return f"({sql})::time", params
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
sql, params = self._convert_sql_to_tz(sql, params, tzname)
if lookup_type == "second":
# Truncate fractional seconds.
return (
f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
("second", "second", *params),
) )
return field_name return self.date_extract_sql(lookup_type, sql, params)
def datetime_cast_date_sql(self, field_name, tzname): def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_sql_to_tz(sql, params, tzname)
return "(%s)::date" % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "(%s)::time" % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
if lookup_type == "second":
# Truncate fractional seconds.
return f"EXTRACT('second' FROM DATE_TRUNC('second', {field_name}))"
return self.date_extract_sql(lookup_type, field_name)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name) return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
def time_extract_sql(self, lookup_type, field_name): def time_extract_sql(self, lookup_type, sql, params):
if lookup_type == "second": if lookup_type == "second":
# Truncate fractional seconds. # Truncate fractional seconds.
return f"EXTRACT('second' FROM DATE_TRUNC('second', {field_name}))" return (
return self.date_extract_sql(lookup_type, field_name) f"EXTRACT(%s FROM DATE_TRUNC(%s, {sql}))",
("second", "second", *params),
)
return self.date_extract_sql(lookup_type, sql, params)
def time_trunc_sql(self, lookup_type, field_name, tzname=None): def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname) sql, params = self._convert_sql_to_tz(sql, params, tzname)
return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name) return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
def deferrable_sql(self): def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED" return " DEFERRABLE INITIALLY DEFERRED"

View File

@ -69,13 +69,13 @@ class DatabaseOperations(BaseDatabaseOperations):
"accepting multiple arguments." "accepting multiple arguments."
) )
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, sql, params):
""" """
Support EXTRACT with a user-defined function django_date_extract() Support EXTRACT with a user-defined function django_date_extract()
that's registered in connect(). Use single quotes because this is a that's registered in connect(). Use single quotes because this is a
string and could otherwise cause a collision with a field name. string and could otherwise cause a collision with a field name.
""" """
return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name) return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params)
def fetch_returned_insert_rows(self, cursor): def fetch_returned_insert_rows(self, cursor):
""" """
@ -88,53 +88,53 @@ class DatabaseOperations(BaseDatabaseOperations):
"""Do nothing since formatting is handled in the custom function.""" """Do nothing since formatting is handled in the custom function."""
return sql return sql
def date_trunc_sql(self, lookup_type, field_name, tzname=None): def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
return "django_date_trunc('%s', %s, %s, %s)" % ( return f"django_date_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(), lookup_type.lower(),
field_name, *params,
*self._convert_tznames_to_sql(tzname), *self._convert_tznames_to_sql(tzname),
) )
def time_trunc_sql(self, lookup_type, field_name, tzname=None): def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
return "django_time_trunc('%s', %s, %s, %s)" % ( return f"django_time_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(), lookup_type.lower(),
field_name, *params,
*self._convert_tznames_to_sql(tzname), *self._convert_tznames_to_sql(tzname),
) )
def _convert_tznames_to_sql(self, tzname): def _convert_tznames_to_sql(self, tzname):
if tzname and settings.USE_TZ: if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name return tzname, self.connection.timezone_name
return "NULL", "NULL" return None, None
def datetime_cast_date_sql(self, field_name, tzname): def datetime_cast_date_sql(self, sql, params, tzname):
return "django_datetime_cast_date(%s, %s, %s)" % ( return f"django_datetime_cast_date({sql}, %s, %s)", (
field_name, *params,
*self._convert_tznames_to_sql(tzname), *self._convert_tznames_to_sql(tzname),
) )
def datetime_cast_time_sql(self, field_name, tzname): def datetime_cast_time_sql(self, sql, params, tzname):
return "django_datetime_cast_time(%s, %s, %s)" % ( return f"django_datetime_cast_time({sql}, %s, %s)", (
field_name, *params,
*self._convert_tznames_to_sql(tzname), *self._convert_tznames_to_sql(tzname),
) )
def datetime_extract_sql(self, lookup_type, field_name, tzname): def datetime_extract_sql(self, lookup_type, sql, params, tzname):
return "django_datetime_extract('%s', %s, %s, %s)" % ( return f"django_datetime_extract(%s, {sql}, %s, %s)", (
lookup_type.lower(), lookup_type.lower(),
field_name, *params,
*self._convert_tznames_to_sql(tzname), *self._convert_tznames_to_sql(tzname),
) )
def datetime_trunc_sql(self, lookup_type, field_name, tzname): def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
return "django_datetime_trunc('%s', %s, %s, %s)" % ( return f"django_datetime_trunc(%s, {sql}, %s, %s)", (
lookup_type.lower(), lookup_type.lower(),
field_name, *params,
*self._convert_tznames_to_sql(tzname), *self._convert_tznames_to_sql(tzname),
) )
def time_extract_sql(self, lookup_type, field_name): def time_extract_sql(self, lookup_type, sql, params):
return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name) return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params)
def pk_default_value(self): def pk_default_value(self):
return "NULL" return "NULL"

View File

@ -51,25 +51,31 @@ class Extract(TimezoneMixin, Transform):
super().__init__(expression, **extra) super().__init__(expression, **extra)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.lookup_name):
raise ValueError("Invalid lookup_name: %s" % self.lookup_name)
sql, params = compiler.compile(self.lhs) sql, params = compiler.compile(self.lhs)
lhs_output_field = self.lhs.output_field lhs_output_field = self.lhs.output_field
if isinstance(lhs_output_field, DateTimeField): if isinstance(lhs_output_field, DateTimeField):
tzname = self.get_tzname() tzname = self.get_tzname()
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname) sql, params = connection.ops.datetime_extract_sql(
self.lookup_name, sql, tuple(params), tzname
)
elif self.tzinfo is not None: elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.") raise ValueError("tzinfo can only be used with DateTimeField.")
elif isinstance(lhs_output_field, DateField): elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql) sql, params = connection.ops.date_extract_sql(
self.lookup_name, sql, tuple(params)
)
elif isinstance(lhs_output_field, TimeField): elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql) sql, params = connection.ops.time_extract_sql(
self.lookup_name, sql, tuple(params)
)
elif isinstance(lhs_output_field, DurationField): elif isinstance(lhs_output_field, DurationField):
if not connection.features.has_native_duration_field: if not connection.features.has_native_duration_field:
raise ValueError( raise ValueError(
"Extract requires native DurationField database support." "Extract requires native DurationField database support."
) )
sql = connection.ops.time_extract_sql(self.lookup_name, sql) sql, params = connection.ops.time_extract_sql(
self.lookup_name, sql, tuple(params)
)
else: else:
# resolve_expression has already validated the output_field so this # resolve_expression has already validated the output_field so this
# assert should never be hit. # assert should never be hit.
@ -237,25 +243,29 @@ class TruncBase(TimezoneMixin, Transform):
super().__init__(expression, output_field=output_field, **extra) super().__init__(expression, output_field=output_field, **extra)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if not connection.ops.extract_trunc_lookup_pattern.fullmatch(self.kind): sql, params = compiler.compile(self.lhs)
raise ValueError("Invalid kind: %s" % self.kind)
inner_sql, inner_params = compiler.compile(self.lhs)
tzname = None tzname = None
if isinstance(self.lhs.output_field, DateTimeField): if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname() tzname = self.get_tzname()
elif self.tzinfo is not None: elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.") raise ValueError("tzinfo can only be used with DateTimeField.")
if isinstance(self.output_field, DateTimeField): if isinstance(self.output_field, DateTimeField):
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname) sql, params = connection.ops.datetime_trunc_sql(
self.kind, sql, tuple(params), tzname
)
elif isinstance(self.output_field, DateField): elif isinstance(self.output_field, DateField):
sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname) sql, params = connection.ops.date_trunc_sql(
self.kind, sql, tuple(params), tzname
)
elif isinstance(self.output_field, TimeField): elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname) sql, params = connection.ops.time_trunc_sql(
self.kind, sql, tuple(params), tzname
)
else: else:
raise ValueError( raise ValueError(
"Trunc only valid on DateField, TimeField, or DateTimeField." "Trunc only valid on DateField, TimeField, or DateTimeField."
) )
return sql, inner_params return sql, params
def resolve_expression( def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
@ -384,10 +394,9 @@ class TruncDate(TruncBase):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date. # Cast to date rather than truncate to date.
lhs, lhs_params = compiler.compile(self.lhs) sql, params = compiler.compile(self.lhs)
tzname = self.get_tzname() tzname = self.get_tzname()
sql = connection.ops.datetime_cast_date_sql(lhs, tzname) return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)
return sql, lhs_params
class TruncTime(TruncBase): class TruncTime(TruncBase):
@ -397,10 +406,9 @@ class TruncTime(TruncBase):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# Cast to time rather than truncate to time. # Cast to time rather than truncate to time.
lhs, lhs_params = compiler.compile(self.lhs) sql, params = compiler.compile(self.lhs)
tzname = self.get_tzname() tzname = self.get_tzname()
sql = connection.ops.datetime_cast_time_sql(lhs, tzname) return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
return sql, lhs_params
class TruncHour(TruncBase): class TruncHour(TruncBase):

View File

@ -459,6 +459,20 @@ backends.
``DatabaseOperations.insert_statement()`` method is replaced by ``DatabaseOperations.insert_statement()`` method is replaced by
``on_conflict`` that accepts ``django.db.models.constants.OnConflict``. ``on_conflict`` that accepts ``django.db.models.constants.OnConflict``.
* Several date and time methods on ``DatabaseOperations`` now take ``sql`` and
``params`` arguments instead of ``field_name`` and return 2-tuple containing
some SQL and the parameters to be interpolated into that SQL. The changed
methods have these new signatures:
* ``DatabaseOperations.date_extract_sql(lookup_type, sql, params)``
* ``DatabaseOperations.datetime_extract_sql(lookup_type, sql, params, tzname)``
* ``DatabaseOperations.time_extract_sql(lookup_type, sql, params)``
* ``DatabaseOperations.date_trunc_sql(lookup_type, sql, params, tzname=None)``
* ``DatabaseOperations.datetime_trunc_sql(self, lookup_type, sql, params, tzname)``
* ``DatabaseOperations.time_trunc_sql(lookup_type, sql, params, tzname=None)``
* ``DatabaseOperations.datetime_cast_date_sql(sql, params, tzname)``
* ``DatabaseOperations.datetime_cast_time_sql(sql, params, tzname)``
:mod:`django.contrib.gis` :mod:`django.contrib.gis`
------------------------- -------------------------

View File

@ -115,49 +115,49 @@ class SimpleDatabaseOperationTests(SimpleTestCase):
with self.assertRaisesMessage( with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "date_extract_sql" NotImplementedError, self.may_require_msg % "date_extract_sql"
): ):
self.ops.date_extract_sql(None, None) self.ops.date_extract_sql(None, None, None)
def test_time_extract_sql(self): def test_time_extract_sql(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "date_extract_sql" NotImplementedError, self.may_require_msg % "date_extract_sql"
): ):
self.ops.time_extract_sql(None, None) self.ops.time_extract_sql(None, None, None)
def test_date_trunc_sql(self): def test_date_trunc_sql(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "date_trunc_sql" NotImplementedError, self.may_require_msg % "date_trunc_sql"
): ):
self.ops.date_trunc_sql(None, None) self.ops.date_trunc_sql(None, None, None)
def test_time_trunc_sql(self): def test_time_trunc_sql(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "time_trunc_sql" NotImplementedError, self.may_require_msg % "time_trunc_sql"
): ):
self.ops.time_trunc_sql(None, None) self.ops.time_trunc_sql(None, None, None)
def test_datetime_trunc_sql(self): def test_datetime_trunc_sql(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "datetime_trunc_sql" NotImplementedError, self.may_require_msg % "datetime_trunc_sql"
): ):
self.ops.datetime_trunc_sql(None, None, None) self.ops.datetime_trunc_sql(None, None, None, None)
def test_datetime_cast_date_sql(self): def test_datetime_cast_date_sql(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "datetime_cast_date_sql" NotImplementedError, self.may_require_msg % "datetime_cast_date_sql"
): ):
self.ops.datetime_cast_date_sql(None, None) self.ops.datetime_cast_date_sql(None, None, None)
def test_datetime_cast_time_sql(self): def test_datetime_cast_time_sql(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "datetime_cast_time_sql" NotImplementedError, self.may_require_msg % "datetime_cast_time_sql"
): ):
self.ops.datetime_cast_time_sql(None, None) self.ops.datetime_cast_time_sql(None, None, None)
def test_datetime_extract_sql(self): def test_datetime_extract_sql(self):
with self.assertRaisesMessage( with self.assertRaisesMessage(
NotImplementedError, self.may_require_msg % "datetime_extract_sql" NotImplementedError, self.may_require_msg % "datetime_extract_sql"
): ):
self.ops.datetime_extract_sql(None, None, None) self.ops.datetime_extract_sql(None, None, None, None)
class DatabaseOperationTests(TestCase): class DatabaseOperationTests(TestCase):

View File

@ -75,7 +75,7 @@ class YearTransform(models.Transform):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
lhs_sql, params = compiler.compile(self.lhs) lhs_sql, params = compiler.compile(self.lhs)
return connection.ops.date_extract_sql("year", lhs_sql), params return connection.ops.date_extract_sql("year", lhs_sql, params)
@property @property
def output_field(self): def output_field(self):

View File

@ -13,6 +13,7 @@ except ImportError:
pytz = None pytz = None
from django.conf import settings from django.conf import settings
from django.db import DataError, OperationalError
from django.db.models import ( from django.db.models import (
DateField, DateField,
DateTimeField, DateTimeField,
@ -244,8 +245,7 @@ class DateFunctionTests(TestCase):
self.create_model(start_datetime, end_datetime) self.create_model(start_datetime, end_datetime)
self.create_model(end_datetime, start_datetime) self.create_model(end_datetime, start_datetime)
msg = "Invalid lookup_name: " with self.assertRaises((DataError, OperationalError, ValueError)):
with self.assertRaisesMessage(ValueError, msg):
DTModel.objects.filter( DTModel.objects.filter(
start_datetime__year=Extract( start_datetime__year=Extract(
"start_datetime", "day' FROM start_datetime)) OR 1=1;--" "start_datetime", "day' FROM start_datetime)) OR 1=1;--"
@ -940,14 +940,18 @@ class DateFunctionTests(TestCase):
end_datetime = timezone.make_aware(end_datetime) end_datetime = timezone.make_aware(end_datetime)
self.create_model(start_datetime, end_datetime) self.create_model(start_datetime, end_datetime)
self.create_model(end_datetime, start_datetime) self.create_model(end_datetime, start_datetime)
msg = "Invalid kind: " # Database backends raise an exception or don't return any results.
with self.assertRaisesMessage(ValueError, msg): try:
DTModel.objects.filter( exists = DTModel.objects.filter(
start_datetime__date=Trunc( start_datetime__date=Trunc(
"start_datetime", "start_datetime",
"year', start_datetime)) OR 1=1;--", "year', start_datetime)) OR 1=1;--",
) )
).exists() ).exists()
except (DataError, OperationalError):
pass
else:
self.assertIs(exists, False)
def test_trunc_func(self): def test_trunc_func(self):
start_datetime = datetime(999, 6, 15, 14, 30, 50, 321) start_datetime = datetime(999, 6, 15, 14, 30, 50, 321)