Fixed #29048 -- Added **extra_context to database function as_vendor() methods.

This commit is contained in:
priyanshsaxena 2018-02-08 12:39:00 +05:30 committed by Tim Graham
parent 08f360355a
commit 83b04d4f88
11 changed files with 108 additions and 83 deletions

View File

@ -26,10 +26,10 @@ class GeoAggregate(Aggregate):
**extra_context **extra_context
) )
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05) tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05)
template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
return self.as_sql(compiler, connection, template=template, tolerance=tolerance) return self.as_sql(compiler, connection, template=template, tolerance=tolerance, **extra_context)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)

View File

@ -102,19 +102,23 @@ class SQLiteDecimalToFloatMixin:
By default, Decimal values are converted to str by the SQLite backend, which By default, Decimal values are converted to str by the SQLite backend, which
is not acceptable by the GIS functions expecting numeric values. is not acceptable by the GIS functions expecting numeric values.
""" """
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
for expr in self.get_source_expressions(): for expr in self.get_source_expressions():
if hasattr(expr, 'value') and isinstance(expr.value, Decimal): if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
expr.value = float(expr.value) expr.value = float(expr.value)
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection, **extra_context)
class OracleToleranceMixin: class OracleToleranceMixin:
tolerance = 0.05 tolerance = 0.05
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
tol = self.extra.get('tolerance', self.tolerance) tol = self.extra.get('tolerance', self.tolerance)
return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol) return self.as_sql(
compiler, connection,
template="%%(function)s(%%(expressions)s, %s)" % tol,
**extra_context
)
class Area(OracleToleranceMixin, GeoFunc): class Area(OracleToleranceMixin, GeoFunc):
@ -181,11 +185,11 @@ class AsGML(GeoFunc):
class AsKML(AsGML): class AsKML(AsGML):
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
# No version parameter # No version parameter
clone = self.copy() clone = self.copy()
clone.set_source_expressions(self.get_source_expressions()[1:]) clone.set_source_expressions(self.get_source_expressions()[1:])
return clone.as_sql(compiler, connection) return clone.as_sql(compiler, connection, **extra_context)
class AsSVG(GeoFunc): class AsSVG(GeoFunc):
@ -205,10 +209,10 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc):
def __init__(self, expression, num_seg=48, **extra): def __init__(self, expression, num_seg=48, **extra):
super().__init__(expression, num_seg, **extra) super().__init__(expression, num_seg, **extra)
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
clone = self.copy() clone = self.copy()
clone.set_source_expressions([self.get_source_expressions()[0]]) clone.set_source_expressions([self.get_source_expressions()[0]])
return super(BoundingCircle, clone).as_oracle(compiler, connection) return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context)
class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
@ -239,7 +243,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
self.spheroid = self._handle_param(spheroid, 'spheroid', bool) self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection, **extra_context):
clone = self.copy() clone = self.copy()
function = None function = None
expr2 = clone.source_expressions[1] expr2 = clone.source_expressions[1]
@ -262,7 +266,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
else: else:
function = connection.ops.spatial_function_name('DistanceSphere') function = connection.ops.spatial_function_name('DistanceSphere')
return super(Distance, clone).as_sql(compiler, connection, function=function) return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context): def as_sqlite(self, compiler, connection, **extra_context):
if self.geo_field.geodetic(connection): if self.geo_field.geodetic(connection):
@ -300,12 +304,12 @@ class GeoHash(GeoFunc):
expressions.append(self._handle_param(precision, 'precision', int)) expressions.append(self._handle_param(precision, 'precision', int))
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
clone = self.copy() clone = self.copy()
# If no precision is provided, set it to the maximum. # If no precision is provided, set it to the maximum.
if len(clone.source_expressions) < 2: if len(clone.source_expressions) < 2:
clone.source_expressions.append(Value(100)) clone.source_expressions.append(Value(100))
return clone.as_sql(compiler, connection) return clone.as_sql(compiler, connection, **extra_context)
class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
@ -333,7 +337,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
raise NotSupportedError("This backend doesn't support Length on geodetic fields") raise NotSupportedError("This backend doesn't support Length on geodetic fields")
return super().as_sql(compiler, connection, **extra_context) return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection, **extra_context):
clone = self.copy() clone = self.copy()
function = None function = None
if self.source_is_geography(): if self.source_is_geography():
@ -346,13 +350,13 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
dim = min(f.dim for f in self.get_source_fields() if f) dim = min(f.dim for f in self.get_source_fields() if f)
if dim > 2: if dim > 2:
function = connection.ops.length3d function = connection.ops.length3d
return super(Length, clone).as_sql(compiler, connection, function=function) return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context)
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
function = None function = None
if self.geo_field.geodetic(connection): if self.geo_field.geodetic(connection):
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
return super().as_sql(compiler, connection, function=function) return super().as_sql(compiler, connection, function=function, **extra_context)
class LineLocatePoint(GeoFunc): class LineLocatePoint(GeoFunc):
@ -383,19 +387,19 @@ class NumPoints(GeoFunc):
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
arity = 1 arity = 1
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection, **extra_context):
function = None function = None
if self.geo_field.geodetic(connection) and not self.source_is_geography(): if self.geo_field.geodetic(connection) and not self.source_is_geography():
raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.") raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.")
dim = min(f.dim for f in self.get_source_fields()) dim = min(f.dim for f in self.get_source_fields())
if dim > 2: if dim > 2:
function = connection.ops.perimeter3d function = connection.ops.perimeter3d
return super().as_sql(compiler, connection, function=function) return super().as_sql(compiler, connection, function=function, **extra_context)
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
if self.geo_field.geodetic(connection): if self.geo_field.geodetic(connection):
raise NotSupportedError("Perimeter cannot use a non-projected field.") raise NotSupportedError("Perimeter cannot use a non-projected field.")
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection, **extra_context)
class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
@ -454,12 +458,12 @@ class Transform(GeomOutputGeoFunc):
class Translate(Scale): class Translate(Scale):
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
clone = self.copy() clone = self.copy()
if len(self.source_expressions) < 4: if len(self.source_expressions) < 4:
# Always provide the z parameter for ST_Translate # Always provide the z parameter for ST_Translate
clone.source_expressions.append(Value(0)) clone.source_expressions.append(Value(0))
return super(Translate, clone).as_sqlite(compiler, connection) return super(Translate, clone).as_sqlite(compiler, connection, **extra_context)
class Union(OracleToleranceMixin, GeomOutputGeoFunc): class Union(OracleToleranceMixin, GeomOutputGeoFunc):

View File

@ -64,7 +64,10 @@ class Aggregate(Func):
if connection.features.supports_aggregate_filter_clause: if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection) filter_sql, filter_params = self.filter.as_sql(compiler, connection)
template = self.filter_template % extra_context.get('template', self.template) template = self.filter_template % extra_context.get('template', self.template)
sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql) sql, params = super().as_sql(
compiler, connection, template=template, filter=filter_sql,
**extra_context
)
return sql, params + filter_params return sql, params + filter_params
else: else:
copy = self.copy() copy = self.copy()
@ -92,20 +95,20 @@ class Avg(Aggregate):
return FloatField() return FloatField()
return super()._resolve_output_field() return super()._resolve_output_field()
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection) sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField': if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s as SIGNED)' % sql sql = 'CAST(%s as SIGNED)' % sql
return sql, params return sql, params
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField': if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0] expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile( return compiler.compile(
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter)) SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
) )
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection, **extra_context)
class Count(Aggregate): class Count(Aggregate):
@ -157,20 +160,20 @@ class Sum(Aggregate):
function = 'SUM' function = 'SUM'
name = 'Sum' name = 'Sum'
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection) sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField': if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s as SIGNED)' % sql sql = 'CAST(%s as SIGNED)' % sql
return sql, params return sql, params
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField': if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0] expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile( return compiler.compile(
SecondsToInterval(Sum(IntervalToSeconds(expression))) SecondsToInterval(Sum(IntervalToSeconds(expression)))
) )
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection, **extra_context)
class Variance(Aggregate): class Variance(Aggregate):

View File

@ -14,16 +14,16 @@ class Cast(Func):
extra_context['db_type'] = self.output_field.cast_db_type(connection) extra_context['db_type'] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context) return super().as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
# MySQL doesn't support explicit cast to float. # MySQL doesn't support explicit cast to float.
template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None
return self.as_sql(compiler, connection, template=template) return self.as_sql(compiler, connection, template=template, **extra_context)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection, **extra_context):
# CAST would be valid too, but the :: shortcut syntax is more readable. # CAST would be valid too, but the :: shortcut syntax is more readable.
# 'expressions' is wrapped in parentheses in case it's a complex # 'expressions' is wrapped in parentheses in case it's a complex
# expression. # expression.
return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s') return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context)
class Coalesce(Func): class Coalesce(Func):
@ -35,7 +35,7 @@ class Coalesce(Func):
raise ValueError('Coalesce must take at least two expressions') raise ValueError('Coalesce must take at least two expressions')
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2), # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
# so convert all fields to NCLOB when that type is expected. # so convert all fields to NCLOB when that type is expected.
if self.output_field.get_internal_type() == 'TextField': if self.output_field.get_internal_type() == 'TextField':
@ -47,8 +47,8 @@ class Coalesce(Func):
] ]
clone = self.copy() clone = self.copy()
clone.set_source_expressions(expressions) clone.set_source_expressions(expressions)
return super(Coalesce, clone).as_sql(compiler, connection) return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection) return self.as_sql(compiler, connection, **extra_context)
class Greatest(Func): class Greatest(Func):
@ -66,9 +66,9 @@ class Greatest(Func):
raise ValueError('Greatest must take at least two expressions') raise ValueError('Greatest must take at least two expressions')
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MAX function on SQLite.""" """Use the MAX function on SQLite."""
return super().as_sqlite(compiler, connection, function='MAX') return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
class Least(Func): class Least(Func):
@ -86,6 +86,6 @@ class Least(Func):
raise ValueError('Least must take at least two expressions') raise ValueError('Least must take at least two expressions')
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MIN function on SQLite.""" """Use the MIN function on SQLite."""
return super().as_sqlite(compiler, connection, function='MIN') return super().as_sqlite(compiler, connection, function='MIN', **extra_context)

View File

@ -159,11 +159,11 @@ class Now(Func):
template = 'CURRENT_TIMESTAMP' template = 'CURRENT_TIMESTAMP'
output_field = fields.DateTimeField() output_field = fields.DateTimeField()
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection, **extra_context):
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
# other databases. # other databases.
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()') return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
class TruncBase(TimezoneMixin, Transform): class TruncBase(TimezoneMixin, Transform):

View File

@ -9,7 +9,7 @@ from django.db.models.functions import Cast
class DecimalInputMixin: class DecimalInputMixin:
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the # Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures: # following function signatures:
# - LOG(double, double) # - LOG(double, double)
@ -20,7 +20,7 @@ class DecimalInputMixin:
Cast(expression, output_field) if isinstance(expression.output_field, FloatField) Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
else expression for expression in self.get_source_expressions() else expression for expression in self.get_source_expressions()
]) ])
return clone.as_sql(compiler, connection) return clone.as_sql(compiler, connection, **extra_context)
class OutputFieldMixin: class OutputFieldMixin:
@ -54,7 +54,7 @@ class ATan2(OutputFieldMixin, Func):
function = 'ATAN2' function = 'ATAN2'
arity = 2 arity = 2
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0): if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0):
return self.as_sql(compiler, connection) return self.as_sql(compiler, connection)
# This function is usually ATan2(y, x), returning the inverse tangent # This function is usually ATan2(y, x), returning the inverse tangent
@ -67,15 +67,15 @@ class ATan2(OutputFieldMixin, Func):
Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField) Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
else expression for expression in self.get_source_expressions()[::-1] else expression for expression in self.get_source_expressions()[::-1]
]) ])
return clone.as_sql(compiler, connection) return clone.as_sql(compiler, connection, **extra_context)
class Ceil(Transform): class Ceil(Transform):
function = 'CEILING' function = 'CEILING'
lookup_name = 'ceil' lookup_name = 'ceil'
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CEIL') return super().as_sql(compiler, connection, function='CEIL', **extra_context)
class Cos(OutputFieldMixin, Transform): class Cos(OutputFieldMixin, Transform):
@ -87,16 +87,20 @@ class Cot(OutputFieldMixin, Transform):
function = 'COT' function = 'COT'
lookup_name = 'cot' lookup_name = 'cot'
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))') return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
class Degrees(OutputFieldMixin, Transform): class Degrees(OutputFieldMixin, Transform):
function = 'DEGREES' function = 'DEGREES'
lookup_name = 'degrees' lookup_name = 'degrees'
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template='((%%(expressions)s) * 180 / %s)' % math.pi) return super().as_sql(
compiler, connection,
template='((%%(expressions)s) * 180 / %s)' % math.pi,
**extra_context
)
class Exp(OutputFieldMixin, Transform): class Exp(OutputFieldMixin, Transform):
@ -118,14 +122,14 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func):
function = 'LOG' function = 'LOG'
arity = 2 arity = 2
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, 'spatialite', False): if not getattr(connection.ops, 'spatialite', False):
return self.as_sql(compiler, connection) return self.as_sql(compiler, connection)
# This function is usually Log(b, x) returning the logarithm of x to # This function is usually Log(b, x) returning the logarithm of x to
# the base b, but on SpatiaLite it's Log(x, b). # the base b, but on SpatiaLite it's Log(x, b).
clone = self.copy() clone = self.copy()
clone.set_source_expressions(self.get_source_expressions()[::-1]) clone.set_source_expressions(self.get_source_expressions()[::-1])
return clone.as_sql(compiler, connection) return clone.as_sql(compiler, connection, **extra_context)
class Mod(DecimalInputMixin, OutputFieldMixin, Func): class Mod(DecimalInputMixin, OutputFieldMixin, Func):
@ -137,8 +141,8 @@ class Pi(OutputFieldMixin, Func):
function = 'PI' function = 'PI'
arity = 0 arity = 0
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template=str(math.pi)) return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
class Power(OutputFieldMixin, Func): class Power(OutputFieldMixin, Func):
@ -150,8 +154,12 @@ class Radians(OutputFieldMixin, Transform):
function = 'RADIANS' function = 'RADIANS'
lookup_name = 'radians' lookup_name = 'radians'
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template='((%%(expressions)s) * %s / 180)' % math.pi) return super().as_sql(
compiler, connection,
template='((%%(expressions)s) * %s / 180)' % math.pi,
**extra_context
)
class Round(Transform): class Round(Transform):

View File

@ -22,13 +22,19 @@ class Chr(Transform):
function = 'CHR' function = 'CHR'
lookup_name = 'chr' lookup_name = 'chr'
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql( return super().as_sql(
compiler, connection, function='CHAR', template='%(function)s(%(expressions)s USING utf16)' compiler, connection, function='CHAR',
template='%(function)s(%(expressions)s USING utf16)',
**extra_context
) )
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, template='%(function)s(%(expressions)s USING NCHAR_CS)') return super().as_sql(
compiler, connection,
template='%(function)s(%(expressions)s USING NCHAR_CS)',
**extra_context
)
def as_sqlite(self, compiler, connection, **extra_context): def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CHAR', **extra_context) return super().as_sql(compiler, connection, function='CHAR', **extra_context)
@ -41,16 +47,19 @@ class ConcatPair(Func):
""" """
function = 'CONCAT' function = 'CONCAT'
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
coalesced = self.coalesce() coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql( return super(ConcatPair, coalesced).as_sql(
compiler, connection, template='%(expressions)s', arg_joiner=' || ' compiler, connection, template='%(expressions)s', arg_joiner=' || ',
**extra_context
) )
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored. # Use CONCAT_WS with an empty separator so that NULLs are ignored.
return super().as_sql( return super().as_sql(
compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)" compiler, connection, function='CONCAT_WS',
template="%(function)s('', %(expressions)s)",
**extra_context
) )
def coalesce(self): def coalesce(self):
@ -117,8 +126,8 @@ class Length(Transform):
lookup_name = 'length' lookup_name = 'length'
output_field = fields.IntegerField() output_field = fields.IntegerField()
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='CHAR_LENGTH') return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
class Lower(Transform): class Lower(Transform):
@ -199,8 +208,8 @@ class StrIndex(Func):
arity = 2 arity = 2
output_field = fields.IntegerField() output_field = fields.IntegerField()
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='STRPOS') return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
class Substr(Func): class Substr(Func):
@ -220,11 +229,11 @@ class Substr(Func):
expressions.append(length) expressions.append(length)
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='SUBSTR') return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='SUBSTR') return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
class Trim(Transform): class Trim(Transform):

View File

@ -275,7 +275,7 @@ We can change the behavior on a specific backend by creating a subclass of
``NotEqual`` with an ``as_mysql`` method:: ``NotEqual`` with an ``as_mysql`` method::
class MySQLNotEqual(NotEqual): class MySQLNotEqual(NotEqual):
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
lhs, lhs_params = self.process_lhs(compiler, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params

View File

@ -322,11 +322,12 @@ The ``Func`` API is as follows:
function = 'CONCAT' function = 'CONCAT'
... ...
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql( return super().as_sql(
compiler, connection, compiler, connection,
function='CONCAT_WS', function='CONCAT_WS',
template="%(function)s('', %(expressions)s)", template="%(function)s('', %(expressions)s)",
**extra_context
) )
To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must

View File

@ -1083,8 +1083,8 @@ class AggregateTestCase(TestCase):
class Greatest(Func): class Greatest(Func):
function = 'GREATEST' function = 'GREATEST'
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='MAX') return super().as_sql(compiler, connection, function='MAX', **extra_context)
qs = Publisher.objects.annotate( qs = Publisher.objects.annotate(
price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))

View File

@ -34,7 +34,7 @@ class Div3Transform(models.Transform):
lhs, lhs_params = compiler.compile(self.lhs) lhs, lhs_params = compiler.compile(self.lhs)
return '(%s) %%%% 3' % lhs, lhs_params return '(%s) %%%% 3' % lhs, lhs_params
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection, **extra_context):
lhs, lhs_params = compiler.compile(self.lhs) lhs, lhs_params = compiler.compile(self.lhs)
return 'mod(%s, 3)' % lhs, lhs_params return 'mod(%s, 3)' % lhs, lhs_params