diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index e61bb7207d..993d9f91fc 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -26,10 +26,10 @@ class GeoAggregate(Aggregate): **extra_context ) - def as_oracle(self, compiler, connection): + def as_oracle(self, compiler, connection, **extra_context): tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05) template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' - return self.as_sql(compiler, connection, template=template, tolerance=tolerance) + return self.as_sql(compiler, connection, template=template, tolerance=tolerance, **extra_context) def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 45ce8193ee..710e7c5d3d 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -102,19 +102,23 @@ class SQLiteDecimalToFloatMixin: By default, Decimal values are converted to str by the SQLite backend, which is not acceptable by the GIS functions expecting numeric values. """ - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): for expr in self.get_source_expressions(): if hasattr(expr, 'value') and isinstance(expr.value, Decimal): expr.value = float(expr.value) - return super().as_sql(compiler, connection) + return super().as_sql(compiler, connection, **extra_context) class OracleToleranceMixin: tolerance = 0.05 - def as_oracle(self, compiler, connection): + def as_oracle(self, compiler, connection, **extra_context): tol = self.extra.get('tolerance', self.tolerance) - return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol) + return self.as_sql( + compiler, connection, + template="%%(function)s(%%(expressions)s, %s)" % tol, + **extra_context + ) class Area(OracleToleranceMixin, GeoFunc): @@ -181,11 +185,11 @@ class AsGML(GeoFunc): class AsKML(AsGML): - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): # No version parameter clone = self.copy() clone.set_source_expressions(self.get_source_expressions()[1:]) - return clone.as_sql(compiler, connection) + return clone.as_sql(compiler, connection, **extra_context) class AsSVG(GeoFunc): @@ -205,10 +209,10 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc): def __init__(self, expression, num_seg=48, **extra): super().__init__(expression, num_seg, **extra) - def as_oracle(self, compiler, connection): + def as_oracle(self, compiler, connection, **extra_context): clone = self.copy() clone.set_source_expressions([self.get_source_expressions()[0]]) - return super(BoundingCircle, clone).as_oracle(compiler, connection) + return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context) class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): @@ -239,7 +243,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): self.spheroid = self._handle_param(spheroid, 'spheroid', bool) super().__init__(*expressions, **extra) - def as_postgresql(self, compiler, connection): + def as_postgresql(self, compiler, connection, **extra_context): clone = self.copy() function = None expr2 = clone.source_expressions[1] @@ -262,7 +266,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) else: function = connection.ops.spatial_function_name('DistanceSphere') - return super(Distance, clone).as_sql(compiler, connection, function=function) + return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context) def as_sqlite(self, compiler, connection, **extra_context): if self.geo_field.geodetic(connection): @@ -300,12 +304,12 @@ class GeoHash(GeoFunc): expressions.append(self._handle_param(precision, 'precision', int)) super().__init__(*expressions, **extra) - def as_mysql(self, compiler, connection): + def as_mysql(self, compiler, connection, **extra_context): clone = self.copy() # If no precision is provided, set it to the maximum. if len(clone.source_expressions) < 2: clone.source_expressions.append(Value(100)) - return clone.as_sql(compiler, connection) + return clone.as_sql(compiler, connection, **extra_context) class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): @@ -333,7 +337,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): raise NotSupportedError("This backend doesn't support Length on geodetic fields") return super().as_sql(compiler, connection, **extra_context) - def as_postgresql(self, compiler, connection): + def as_postgresql(self, compiler, connection, **extra_context): clone = self.copy() function = None if self.source_is_geography(): @@ -346,13 +350,13 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): dim = min(f.dim for f in self.get_source_fields() if f) if dim > 2: function = connection.ops.length3d - return super(Length, clone).as_sql(compiler, connection, function=function) + return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context) - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): function = None if self.geo_field.geodetic(connection): function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' - return super().as_sql(compiler, connection, function=function) + return super().as_sql(compiler, connection, function=function, **extra_context) class LineLocatePoint(GeoFunc): @@ -383,19 +387,19 @@ class NumPoints(GeoFunc): class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): arity = 1 - def as_postgresql(self, compiler, connection): + def as_postgresql(self, compiler, connection, **extra_context): function = None if self.geo_field.geodetic(connection) and not self.source_is_geography(): raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.") dim = min(f.dim for f in self.get_source_fields()) if dim > 2: function = connection.ops.perimeter3d - return super().as_sql(compiler, connection, function=function) + return super().as_sql(compiler, connection, function=function, **extra_context) - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): if self.geo_field.geodetic(connection): raise NotSupportedError("Perimeter cannot use a non-projected field.") - return super().as_sql(compiler, connection) + return super().as_sql(compiler, connection, **extra_context) class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): @@ -454,12 +458,12 @@ class Transform(GeomOutputGeoFunc): class Translate(Scale): - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): clone = self.copy() if len(self.source_expressions) < 4: # Always provide the z parameter for ST_Translate clone.source_expressions.append(Value(0)) - return super(Translate, clone).as_sqlite(compiler, connection) + return super(Translate, clone).as_sqlite(compiler, connection, **extra_context) class Union(OracleToleranceMixin, GeomOutputGeoFunc): diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 8b67151663..cc4b0c52a6 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -64,7 +64,10 @@ class Aggregate(Func): if connection.features.supports_aggregate_filter_clause: filter_sql, filter_params = self.filter.as_sql(compiler, connection) template = self.filter_template % extra_context.get('template', self.template) - sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql) + sql, params = super().as_sql( + compiler, connection, template=template, filter=filter_sql, + **extra_context + ) return sql, params + filter_params else: copy = self.copy() @@ -92,20 +95,20 @@ class Avg(Aggregate): return FloatField() return super()._resolve_output_field() - def as_mysql(self, compiler, connection): - sql, params = super().as_sql(compiler, connection) + def as_mysql(self, compiler, connection, **extra_context): + sql, params = super().as_sql(compiler, connection, **extra_context) if self.output_field.get_internal_type() == 'DurationField': sql = 'CAST(%s as SIGNED)' % sql return sql, params - def as_oracle(self, compiler, connection): + def as_oracle(self, compiler, connection, **extra_context): if self.output_field.get_internal_type() == 'DurationField': expression = self.get_source_expressions()[0] from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval return compiler.compile( SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter)) ) - return super().as_sql(compiler, connection) + return super().as_sql(compiler, connection, **extra_context) class Count(Aggregate): @@ -157,20 +160,20 @@ class Sum(Aggregate): function = 'SUM' name = 'Sum' - def as_mysql(self, compiler, connection): - sql, params = super().as_sql(compiler, connection) + def as_mysql(self, compiler, connection, **extra_context): + sql, params = super().as_sql(compiler, connection, **extra_context) if self.output_field.get_internal_type() == 'DurationField': sql = 'CAST(%s as SIGNED)' % sql return sql, params - def as_oracle(self, compiler, connection): + def as_oracle(self, compiler, connection, **extra_context): if self.output_field.get_internal_type() == 'DurationField': expression = self.get_source_expressions()[0] from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval return compiler.compile( SecondsToInterval(Sum(IntervalToSeconds(expression))) ) - return super().as_sql(compiler, connection) + return super().as_sql(compiler, connection, **extra_context) class Variance(Aggregate): diff --git a/django/db/models/functions/comparison.py b/django/db/models/functions/comparison.py index 69e9d09abb..aa733816ec 100644 --- a/django/db/models/functions/comparison.py +++ b/django/db/models/functions/comparison.py @@ -14,16 +14,16 @@ class Cast(Func): extra_context['db_type'] = self.output_field.cast_db_type(connection) return super().as_sql(compiler, connection, **extra_context) - def as_mysql(self, compiler, connection): + def as_mysql(self, compiler, connection, **extra_context): # MySQL doesn't support explicit cast to float. template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None - return self.as_sql(compiler, connection, template=template) + return self.as_sql(compiler, connection, template=template, **extra_context) - def as_postgresql(self, compiler, connection): + def as_postgresql(self, compiler, connection, **extra_context): # CAST would be valid too, but the :: shortcut syntax is more readable. # 'expressions' is wrapped in parentheses in case it's a complex # expression. - return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s') + return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context) class Coalesce(Func): @@ -35,7 +35,7 @@ class Coalesce(Func): raise ValueError('Coalesce must take at least two expressions') super().__init__(*expressions, **extra) - def as_oracle(self, compiler, connection): + def as_oracle(self, compiler, connection, **extra_context): # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2), # so convert all fields to NCLOB when that type is expected. if self.output_field.get_internal_type() == 'TextField': @@ -47,8 +47,8 @@ class Coalesce(Func): ] clone = self.copy() clone.set_source_expressions(expressions) - return super(Coalesce, clone).as_sql(compiler, connection) - return self.as_sql(compiler, connection) + return super(Coalesce, clone).as_sql(compiler, connection, **extra_context) + return self.as_sql(compiler, connection, **extra_context) class Greatest(Func): @@ -66,9 +66,9 @@ class Greatest(Func): raise ValueError('Greatest must take at least two expressions') super().__init__(*expressions, **extra) - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): """Use the MAX function on SQLite.""" - return super().as_sqlite(compiler, connection, function='MAX') + return super().as_sqlite(compiler, connection, function='MAX', **extra_context) class Least(Func): @@ -86,6 +86,6 @@ class Least(Func): raise ValueError('Least must take at least two expressions') super().__init__(*expressions, **extra) - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): """Use the MIN function on SQLite.""" - return super().as_sqlite(compiler, connection, function='MIN') + return super().as_sqlite(compiler, connection, function='MIN', **extra_context) diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py index 4d24d2a694..0a68f075aa 100644 --- a/django/db/models/functions/datetime.py +++ b/django/db/models/functions/datetime.py @@ -159,11 +159,11 @@ class Now(Func): template = 'CURRENT_TIMESTAMP' output_field = fields.DateTimeField() - def as_postgresql(self, compiler, connection): + def as_postgresql(self, compiler, connection, **extra_context): # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with # other databases. - return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()') + return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context) class TruncBase(TimezoneMixin, Transform): diff --git a/django/db/models/functions/math.py b/django/db/models/functions/math.py index 6341e101a6..a08ac552aa 100644 --- a/django/db/models/functions/math.py +++ b/django/db/models/functions/math.py @@ -9,7 +9,7 @@ from django.db.models.functions import Cast class DecimalInputMixin: - def as_postgresql(self, compiler, connection): + def as_postgresql(self, compiler, connection, **extra_context): # Cast FloatField to DecimalField as PostgreSQL doesn't support the # following function signatures: # - LOG(double, double) @@ -20,7 +20,7 @@ class DecimalInputMixin: Cast(expression, output_field) if isinstance(expression.output_field, FloatField) else expression for expression in self.get_source_expressions() ]) - return clone.as_sql(compiler, connection) + return clone.as_sql(compiler, connection, **extra_context) class OutputFieldMixin: @@ -54,7 +54,7 @@ class ATan2(OutputFieldMixin, Func): function = 'ATAN2' arity = 2 - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0): return self.as_sql(compiler, connection) # This function is usually ATan2(y, x), returning the inverse tangent @@ -67,15 +67,15 @@ class ATan2(OutputFieldMixin, Func): Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField) else expression for expression in self.get_source_expressions()[::-1] ]) - return clone.as_sql(compiler, connection) + return clone.as_sql(compiler, connection, **extra_context) class Ceil(Transform): function = 'CEILING' lookup_name = 'ceil' - def as_oracle(self, compiler, connection): - return super().as_sql(compiler, connection, function='CEIL') + def as_oracle(self, compiler, connection, **extra_context): + return super().as_sql(compiler, connection, function='CEIL', **extra_context) class Cos(OutputFieldMixin, Transform): @@ -87,16 +87,20 @@ class Cot(OutputFieldMixin, Transform): function = 'COT' lookup_name = 'cot' - def as_oracle(self, compiler, connection): - return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))') + def as_oracle(self, compiler, connection, **extra_context): + return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context) class Degrees(OutputFieldMixin, Transform): function = 'DEGREES' lookup_name = 'degrees' - def as_oracle(self, compiler, connection): - return super().as_sql(compiler, connection, template='((%%(expressions)s) * 180 / %s)' % math.pi) + def as_oracle(self, compiler, connection, **extra_context): + return super().as_sql( + compiler, connection, + template='((%%(expressions)s) * 180 / %s)' % math.pi, + **extra_context + ) class Exp(OutputFieldMixin, Transform): @@ -118,14 +122,14 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func): function = 'LOG' arity = 2 - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): if not getattr(connection.ops, 'spatialite', False): return self.as_sql(compiler, connection) # This function is usually Log(b, x) returning the logarithm of x to # the base b, but on SpatiaLite it's Log(x, b). clone = self.copy() clone.set_source_expressions(self.get_source_expressions()[::-1]) - return clone.as_sql(compiler, connection) + return clone.as_sql(compiler, connection, **extra_context) class Mod(DecimalInputMixin, OutputFieldMixin, Func): @@ -137,8 +141,8 @@ class Pi(OutputFieldMixin, Func): function = 'PI' arity = 0 - def as_oracle(self, compiler, connection): - return super().as_sql(compiler, connection, template=str(math.pi)) + def as_oracle(self, compiler, connection, **extra_context): + return super().as_sql(compiler, connection, template=str(math.pi), **extra_context) class Power(OutputFieldMixin, Func): @@ -150,8 +154,12 @@ class Radians(OutputFieldMixin, Transform): function = 'RADIANS' lookup_name = 'radians' - def as_oracle(self, compiler, connection): - return super().as_sql(compiler, connection, template='((%%(expressions)s) * %s / 180)' % math.pi) + def as_oracle(self, compiler, connection, **extra_context): + return super().as_sql( + compiler, connection, + template='((%%(expressions)s) * %s / 180)' % math.pi, + **extra_context + ) class Round(Transform): diff --git a/django/db/models/functions/text.py b/django/db/models/functions/text.py index 1f04842153..8cf10bb76d 100644 --- a/django/db/models/functions/text.py +++ b/django/db/models/functions/text.py @@ -22,13 +22,19 @@ class Chr(Transform): function = 'CHR' lookup_name = 'chr' - def as_mysql(self, compiler, connection): + def as_mysql(self, compiler, connection, **extra_context): return super().as_sql( - compiler, connection, function='CHAR', template='%(function)s(%(expressions)s USING utf16)' + compiler, connection, function='CHAR', + template='%(function)s(%(expressions)s USING utf16)', + **extra_context ) - def as_oracle(self, compiler, connection): - return super().as_sql(compiler, connection, template='%(function)s(%(expressions)s USING NCHAR_CS)') + def as_oracle(self, compiler, connection, **extra_context): + return super().as_sql( + compiler, connection, + template='%(function)s(%(expressions)s USING NCHAR_CS)', + **extra_context + ) def as_sqlite(self, compiler, connection, **extra_context): return super().as_sql(compiler, connection, function='CHAR', **extra_context) @@ -41,16 +47,19 @@ class ConcatPair(Func): """ function = 'CONCAT' - def as_sqlite(self, compiler, connection): + def as_sqlite(self, compiler, connection, **extra_context): coalesced = self.coalesce() return super(ConcatPair, coalesced).as_sql( - compiler, connection, template='%(expressions)s', arg_joiner=' || ' + compiler, connection, template='%(expressions)s', arg_joiner=' || ', + **extra_context ) - def as_mysql(self, compiler, connection): + def as_mysql(self, compiler, connection, **extra_context): # Use CONCAT_WS with an empty separator so that NULLs are ignored. return super().as_sql( - compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)" + compiler, connection, function='CONCAT_WS', + template="%(function)s('', %(expressions)s)", + **extra_context ) def coalesce(self): @@ -117,8 +126,8 @@ class Length(Transform): lookup_name = 'length' output_field = fields.IntegerField() - def as_mysql(self, compiler, connection): - return super().as_sql(compiler, connection, function='CHAR_LENGTH') + def as_mysql(self, compiler, connection, **extra_context): + return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context) class Lower(Transform): @@ -199,8 +208,8 @@ class StrIndex(Func): arity = 2 output_field = fields.IntegerField() - def as_postgresql(self, compiler, connection): - return super().as_sql(compiler, connection, function='STRPOS') + def as_postgresql(self, compiler, connection, **extra_context): + return super().as_sql(compiler, connection, function='STRPOS', **extra_context) class Substr(Func): @@ -220,11 +229,11 @@ class Substr(Func): expressions.append(length) super().__init__(*expressions, **extra) - def as_sqlite(self, compiler, connection): - return super().as_sql(compiler, connection, function='SUBSTR') + def as_sqlite(self, compiler, connection, **extra_context): + return super().as_sql(compiler, connection, function='SUBSTR', **extra_context) - def as_oracle(self, compiler, connection): - return super().as_sql(compiler, connection, function='SUBSTR') + def as_oracle(self, compiler, connection, **extra_context): + return super().as_sql(compiler, connection, function='SUBSTR', **extra_context) class Trim(Transform): diff --git a/docs/howto/custom-lookups.txt b/docs/howto/custom-lookups.txt index 55fdc42237..1e73d329e7 100644 --- a/docs/howto/custom-lookups.txt +++ b/docs/howto/custom-lookups.txt @@ -275,7 +275,7 @@ We can change the behavior on a specific backend by creating a subclass of ``NotEqual`` with an ``as_mysql`` method:: class MySQLNotEqual(NotEqual): - def as_mysql(self, compiler, connection): + def as_mysql(self, compiler, connection, **extra_context): lhs, lhs_params = self.process_lhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 5969085ee4..281b3144ae 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -322,11 +322,12 @@ The ``Func`` API is as follows: function = 'CONCAT' ... - def as_mysql(self, compiler, connection): + def as_mysql(self, compiler, connection, **extra_context): return super().as_sql( compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)", + **extra_context ) To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index a9562784ff..a55ccfbfa2 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1083,8 +1083,8 @@ class AggregateTestCase(TestCase): class Greatest(Func): function = 'GREATEST' - def as_sqlite(self, compiler, connection): - return super().as_sql(compiler, connection, function='MAX') + def as_sqlite(self, compiler, connection, **extra_context): + return super().as_sql(compiler, connection, function='MAX', **extra_context) qs = Publisher.objects.annotate( price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index 4bf85339ed..8c82b710c6 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -34,7 +34,7 @@ class Div3Transform(models.Transform): lhs, lhs_params = compiler.compile(self.lhs) return '(%s) %%%% 3' % lhs, lhs_params - def as_oracle(self, compiler, connection): + def as_oracle(self, compiler, connection, **extra_context): lhs, lhs_params = compiler.compile(self.lhs) return 'mod(%s, 3)' % lhs, lhs_params