Fixed #29048 -- Added **extra_context to database function as_vendor() methods.
This commit is contained in:
parent
08f360355a
commit
83b04d4f88
|
@ -26,10 +26,10 @@ class GeoAggregate(Aggregate):
|
||||||
**extra_context
|
**extra_context
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05)
|
tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05)
|
||||||
template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
|
template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
|
||||||
return self.as_sql(compiler, connection, template=template, tolerance=tolerance)
|
return self.as_sql(compiler, connection, template=template, tolerance=tolerance, **extra_context)
|
||||||
|
|
||||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||||
c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||||
|
|
|
@ -102,19 +102,23 @@ class SQLiteDecimalToFloatMixin:
|
||||||
By default, Decimal values are converted to str by the SQLite backend, which
|
By default, Decimal values are converted to str by the SQLite backend, which
|
||||||
is not acceptable by the GIS functions expecting numeric values.
|
is not acceptable by the GIS functions expecting numeric values.
|
||||||
"""
|
"""
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
for expr in self.get_source_expressions():
|
for expr in self.get_source_expressions():
|
||||||
if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
|
if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
|
||||||
expr.value = float(expr.value)
|
expr.value = float(expr.value)
|
||||||
return super().as_sql(compiler, connection)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class OracleToleranceMixin:
|
class OracleToleranceMixin:
|
||||||
tolerance = 0.05
|
tolerance = 0.05
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
tol = self.extra.get('tolerance', self.tolerance)
|
tol = self.extra.get('tolerance', self.tolerance)
|
||||||
return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
|
return self.as_sql(
|
||||||
|
compiler, connection,
|
||||||
|
template="%%(function)s(%%(expressions)s, %s)" % tol,
|
||||||
|
**extra_context
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Area(OracleToleranceMixin, GeoFunc):
|
class Area(OracleToleranceMixin, GeoFunc):
|
||||||
|
@ -181,11 +185,11 @@ class AsGML(GeoFunc):
|
||||||
|
|
||||||
|
|
||||||
class AsKML(AsGML):
|
class AsKML(AsGML):
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
# No version parameter
|
# No version parameter
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
clone.set_source_expressions(self.get_source_expressions()[1:])
|
clone.set_source_expressions(self.get_source_expressions()[1:])
|
||||||
return clone.as_sql(compiler, connection)
|
return clone.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class AsSVG(GeoFunc):
|
class AsSVG(GeoFunc):
|
||||||
|
@ -205,10 +209,10 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc):
|
||||||
def __init__(self, expression, num_seg=48, **extra):
|
def __init__(self, expression, num_seg=48, **extra):
|
||||||
super().__init__(expression, num_seg, **extra)
|
super().__init__(expression, num_seg, **extra)
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
clone.set_source_expressions([self.get_source_expressions()[0]])
|
clone.set_source_expressions([self.get_source_expressions()[0]])
|
||||||
return super(BoundingCircle, clone).as_oracle(compiler, connection)
|
return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
|
class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||||
|
@ -239,7 +243,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
|
self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
function = None
|
function = None
|
||||||
expr2 = clone.source_expressions[1]
|
expr2 = clone.source_expressions[1]
|
||||||
|
@ -262,7 +266,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||||
else:
|
else:
|
||||||
function = connection.ops.spatial_function_name('DistanceSphere')
|
function = connection.ops.spatial_function_name('DistanceSphere')
|
||||||
return super(Distance, clone).as_sql(compiler, connection, function=function)
|
return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection, **extra_context):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
if self.geo_field.geodetic(connection):
|
if self.geo_field.geodetic(connection):
|
||||||
|
@ -300,12 +304,12 @@ class GeoHash(GeoFunc):
|
||||||
expressions.append(self._handle_param(precision, 'precision', int))
|
expressions.append(self._handle_param(precision, 'precision', int))
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
# If no precision is provided, set it to the maximum.
|
# If no precision is provided, set it to the maximum.
|
||||||
if len(clone.source_expressions) < 2:
|
if len(clone.source_expressions) < 2:
|
||||||
clone.source_expressions.append(Value(100))
|
clone.source_expressions.append(Value(100))
|
||||||
return clone.as_sql(compiler, connection)
|
return clone.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
|
class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||||
|
@ -333,7 +337,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
raise NotSupportedError("This backend doesn't support Length on geodetic fields")
|
raise NotSupportedError("This backend doesn't support Length on geodetic fields")
|
||||||
return super().as_sql(compiler, connection, **extra_context)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
function = None
|
function = None
|
||||||
if self.source_is_geography():
|
if self.source_is_geography():
|
||||||
|
@ -346,13 +350,13 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
dim = min(f.dim for f in self.get_source_fields() if f)
|
dim = min(f.dim for f in self.get_source_fields() if f)
|
||||||
if dim > 2:
|
if dim > 2:
|
||||||
function = connection.ops.length3d
|
function = connection.ops.length3d
|
||||||
return super(Length, clone).as_sql(compiler, connection, function=function)
|
return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
function = None
|
function = None
|
||||||
if self.geo_field.geodetic(connection):
|
if self.geo_field.geodetic(connection):
|
||||||
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
|
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
|
||||||
return super().as_sql(compiler, connection, function=function)
|
return super().as_sql(compiler, connection, function=function, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class LineLocatePoint(GeoFunc):
|
class LineLocatePoint(GeoFunc):
|
||||||
|
@ -383,19 +387,19 @@ class NumPoints(GeoFunc):
|
||||||
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
arity = 1
|
arity = 1
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
function = None
|
function = None
|
||||||
if self.geo_field.geodetic(connection) and not self.source_is_geography():
|
if self.geo_field.geodetic(connection) and not self.source_is_geography():
|
||||||
raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.")
|
raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.")
|
||||||
dim = min(f.dim for f in self.get_source_fields())
|
dim = min(f.dim for f in self.get_source_fields())
|
||||||
if dim > 2:
|
if dim > 2:
|
||||||
function = connection.ops.perimeter3d
|
function = connection.ops.perimeter3d
|
||||||
return super().as_sql(compiler, connection, function=function)
|
return super().as_sql(compiler, connection, function=function, **extra_context)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
if self.geo_field.geodetic(connection):
|
if self.geo_field.geodetic(connection):
|
||||||
raise NotSupportedError("Perimeter cannot use a non-projected field.")
|
raise NotSupportedError("Perimeter cannot use a non-projected field.")
|
||||||
return super().as_sql(compiler, connection)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
|
class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||||
|
@ -454,12 +458,12 @@ class Transform(GeomOutputGeoFunc):
|
||||||
|
|
||||||
|
|
||||||
class Translate(Scale):
|
class Translate(Scale):
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
if len(self.source_expressions) < 4:
|
if len(self.source_expressions) < 4:
|
||||||
# Always provide the z parameter for ST_Translate
|
# Always provide the z parameter for ST_Translate
|
||||||
clone.source_expressions.append(Value(0))
|
clone.source_expressions.append(Value(0))
|
||||||
return super(Translate, clone).as_sqlite(compiler, connection)
|
return super(Translate, clone).as_sqlite(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
|
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||||
|
|
|
@ -64,7 +64,10 @@ class Aggregate(Func):
|
||||||
if connection.features.supports_aggregate_filter_clause:
|
if connection.features.supports_aggregate_filter_clause:
|
||||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||||
template = self.filter_template % extra_context.get('template', self.template)
|
template = self.filter_template % extra_context.get('template', self.template)
|
||||||
sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
|
sql, params = super().as_sql(
|
||||||
|
compiler, connection, template=template, filter=filter_sql,
|
||||||
|
**extra_context
|
||||||
|
)
|
||||||
return sql, params + filter_params
|
return sql, params + filter_params
|
||||||
else:
|
else:
|
||||||
copy = self.copy()
|
copy = self.copy()
|
||||||
|
@ -92,20 +95,20 @@ class Avg(Aggregate):
|
||||||
return FloatField()
|
return FloatField()
|
||||||
return super()._resolve_output_field()
|
return super()._resolve_output_field()
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
sql, params = super().as_sql(compiler, connection)
|
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||||
if self.output_field.get_internal_type() == 'DurationField':
|
if self.output_field.get_internal_type() == 'DurationField':
|
||||||
sql = 'CAST(%s as SIGNED)' % sql
|
sql = 'CAST(%s as SIGNED)' % sql
|
||||||
return sql, params
|
return sql, params
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
if self.output_field.get_internal_type() == 'DurationField':
|
if self.output_field.get_internal_type() == 'DurationField':
|
||||||
expression = self.get_source_expressions()[0]
|
expression = self.get_source_expressions()[0]
|
||||||
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
|
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
|
||||||
return compiler.compile(
|
return compiler.compile(
|
||||||
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
|
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
|
||||||
)
|
)
|
||||||
return super().as_sql(compiler, connection)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Count(Aggregate):
|
class Count(Aggregate):
|
||||||
|
@ -157,20 +160,20 @@ class Sum(Aggregate):
|
||||||
function = 'SUM'
|
function = 'SUM'
|
||||||
name = 'Sum'
|
name = 'Sum'
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
sql, params = super().as_sql(compiler, connection)
|
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||||
if self.output_field.get_internal_type() == 'DurationField':
|
if self.output_field.get_internal_type() == 'DurationField':
|
||||||
sql = 'CAST(%s as SIGNED)' % sql
|
sql = 'CAST(%s as SIGNED)' % sql
|
||||||
return sql, params
|
return sql, params
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
if self.output_field.get_internal_type() == 'DurationField':
|
if self.output_field.get_internal_type() == 'DurationField':
|
||||||
expression = self.get_source_expressions()[0]
|
expression = self.get_source_expressions()[0]
|
||||||
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
|
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
|
||||||
return compiler.compile(
|
return compiler.compile(
|
||||||
SecondsToInterval(Sum(IntervalToSeconds(expression)))
|
SecondsToInterval(Sum(IntervalToSeconds(expression)))
|
||||||
)
|
)
|
||||||
return super().as_sql(compiler, connection)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Variance(Aggregate):
|
class Variance(Aggregate):
|
||||||
|
|
|
@ -14,16 +14,16 @@ class Cast(Func):
|
||||||
extra_context['db_type'] = self.output_field.cast_db_type(connection)
|
extra_context['db_type'] = self.output_field.cast_db_type(connection)
|
||||||
return super().as_sql(compiler, connection, **extra_context)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
# MySQL doesn't support explicit cast to float.
|
# MySQL doesn't support explicit cast to float.
|
||||||
template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None
|
template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None
|
||||||
return self.as_sql(compiler, connection, template=template)
|
return self.as_sql(compiler, connection, template=template, **extra_context)
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
# CAST would be valid too, but the :: shortcut syntax is more readable.
|
# CAST would be valid too, but the :: shortcut syntax is more readable.
|
||||||
# 'expressions' is wrapped in parentheses in case it's a complex
|
# 'expressions' is wrapped in parentheses in case it's a complex
|
||||||
# expression.
|
# expression.
|
||||||
return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s')
|
return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Coalesce(Func):
|
class Coalesce(Func):
|
||||||
|
@ -35,7 +35,7 @@ class Coalesce(Func):
|
||||||
raise ValueError('Coalesce must take at least two expressions')
|
raise ValueError('Coalesce must take at least two expressions')
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
|
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
|
||||||
# so convert all fields to NCLOB when that type is expected.
|
# so convert all fields to NCLOB when that type is expected.
|
||||||
if self.output_field.get_internal_type() == 'TextField':
|
if self.output_field.get_internal_type() == 'TextField':
|
||||||
|
@ -47,8 +47,8 @@ class Coalesce(Func):
|
||||||
]
|
]
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
clone.set_source_expressions(expressions)
|
clone.set_source_expressions(expressions)
|
||||||
return super(Coalesce, clone).as_sql(compiler, connection)
|
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
|
||||||
return self.as_sql(compiler, connection)
|
return self.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Greatest(Func):
|
class Greatest(Func):
|
||||||
|
@ -66,9 +66,9 @@ class Greatest(Func):
|
||||||
raise ValueError('Greatest must take at least two expressions')
|
raise ValueError('Greatest must take at least two expressions')
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
"""Use the MAX function on SQLite."""
|
"""Use the MAX function on SQLite."""
|
||||||
return super().as_sqlite(compiler, connection, function='MAX')
|
return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Least(Func):
|
class Least(Func):
|
||||||
|
@ -86,6 +86,6 @@ class Least(Func):
|
||||||
raise ValueError('Least must take at least two expressions')
|
raise ValueError('Least must take at least two expressions')
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
"""Use the MIN function on SQLite."""
|
"""Use the MIN function on SQLite."""
|
||||||
return super().as_sqlite(compiler, connection, function='MIN')
|
return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
|
||||||
|
|
|
@ -159,11 +159,11 @@ class Now(Func):
|
||||||
template = 'CURRENT_TIMESTAMP'
|
template = 'CURRENT_TIMESTAMP'
|
||||||
output_field = fields.DateTimeField()
|
output_field = fields.DateTimeField()
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
|
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
|
||||||
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
|
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
|
||||||
# other databases.
|
# other databases.
|
||||||
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()')
|
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class TruncBase(TimezoneMixin, Transform):
|
class TruncBase(TimezoneMixin, Transform):
|
||||||
|
|
|
@ -9,7 +9,7 @@ from django.db.models.functions import Cast
|
||||||
|
|
||||||
class DecimalInputMixin:
|
class DecimalInputMixin:
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
||||||
# following function signatures:
|
# following function signatures:
|
||||||
# - LOG(double, double)
|
# - LOG(double, double)
|
||||||
|
@ -20,7 +20,7 @@ class DecimalInputMixin:
|
||||||
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
|
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
|
||||||
else expression for expression in self.get_source_expressions()
|
else expression for expression in self.get_source_expressions()
|
||||||
])
|
])
|
||||||
return clone.as_sql(compiler, connection)
|
return clone.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class OutputFieldMixin:
|
class OutputFieldMixin:
|
||||||
|
@ -54,7 +54,7 @@ class ATan2(OutputFieldMixin, Func):
|
||||||
function = 'ATAN2'
|
function = 'ATAN2'
|
||||||
arity = 2
|
arity = 2
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0):
|
if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0):
|
||||||
return self.as_sql(compiler, connection)
|
return self.as_sql(compiler, connection)
|
||||||
# This function is usually ATan2(y, x), returning the inverse tangent
|
# This function is usually ATan2(y, x), returning the inverse tangent
|
||||||
|
@ -67,15 +67,15 @@ class ATan2(OutputFieldMixin, Func):
|
||||||
Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
|
Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
|
||||||
else expression for expression in self.get_source_expressions()[::-1]
|
else expression for expression in self.get_source_expressions()[::-1]
|
||||||
])
|
])
|
||||||
return clone.as_sql(compiler, connection)
|
return clone.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Ceil(Transform):
|
class Ceil(Transform):
|
||||||
function = 'CEILING'
|
function = 'CEILING'
|
||||||
lookup_name = 'ceil'
|
lookup_name = 'ceil'
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, function='CEIL')
|
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Cos(OutputFieldMixin, Transform):
|
class Cos(OutputFieldMixin, Transform):
|
||||||
|
@ -87,16 +87,20 @@ class Cot(OutputFieldMixin, Transform):
|
||||||
function = 'COT'
|
function = 'COT'
|
||||||
lookup_name = 'cot'
|
lookup_name = 'cot'
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))')
|
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Degrees(OutputFieldMixin, Transform):
|
class Degrees(OutputFieldMixin, Transform):
|
||||||
function = 'DEGREES'
|
function = 'DEGREES'
|
||||||
lookup_name = 'degrees'
|
lookup_name = 'degrees'
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, template='((%%(expressions)s) * 180 / %s)' % math.pi)
|
return super().as_sql(
|
||||||
|
compiler, connection,
|
||||||
|
template='((%%(expressions)s) * 180 / %s)' % math.pi,
|
||||||
|
**extra_context
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Exp(OutputFieldMixin, Transform):
|
class Exp(OutputFieldMixin, Transform):
|
||||||
|
@ -118,14 +122,14 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func):
|
||||||
function = 'LOG'
|
function = 'LOG'
|
||||||
arity = 2
|
arity = 2
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
if not getattr(connection.ops, 'spatialite', False):
|
if not getattr(connection.ops, 'spatialite', False):
|
||||||
return self.as_sql(compiler, connection)
|
return self.as_sql(compiler, connection)
|
||||||
# This function is usually Log(b, x) returning the logarithm of x to
|
# This function is usually Log(b, x) returning the logarithm of x to
|
||||||
# the base b, but on SpatiaLite it's Log(x, b).
|
# the base b, but on SpatiaLite it's Log(x, b).
|
||||||
clone = self.copy()
|
clone = self.copy()
|
||||||
clone.set_source_expressions(self.get_source_expressions()[::-1])
|
clone.set_source_expressions(self.get_source_expressions()[::-1])
|
||||||
return clone.as_sql(compiler, connection)
|
return clone.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Mod(DecimalInputMixin, OutputFieldMixin, Func):
|
class Mod(DecimalInputMixin, OutputFieldMixin, Func):
|
||||||
|
@ -137,8 +141,8 @@ class Pi(OutputFieldMixin, Func):
|
||||||
function = 'PI'
|
function = 'PI'
|
||||||
arity = 0
|
arity = 0
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, template=str(math.pi))
|
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Power(OutputFieldMixin, Func):
|
class Power(OutputFieldMixin, Func):
|
||||||
|
@ -150,8 +154,12 @@ class Radians(OutputFieldMixin, Transform):
|
||||||
function = 'RADIANS'
|
function = 'RADIANS'
|
||||||
lookup_name = 'radians'
|
lookup_name = 'radians'
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, template='((%%(expressions)s) * %s / 180)' % math.pi)
|
return super().as_sql(
|
||||||
|
compiler, connection,
|
||||||
|
template='((%%(expressions)s) * %s / 180)' % math.pi,
|
||||||
|
**extra_context
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Round(Transform):
|
class Round(Transform):
|
||||||
|
|
|
@ -22,13 +22,19 @@ class Chr(Transform):
|
||||||
function = 'CHR'
|
function = 'CHR'
|
||||||
lookup_name = 'chr'
|
lookup_name = 'chr'
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(
|
return super().as_sql(
|
||||||
compiler, connection, function='CHAR', template='%(function)s(%(expressions)s USING utf16)'
|
compiler, connection, function='CHAR',
|
||||||
|
template='%(function)s(%(expressions)s USING utf16)',
|
||||||
|
**extra_context
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, template='%(function)s(%(expressions)s USING NCHAR_CS)')
|
return super().as_sql(
|
||||||
|
compiler, connection,
|
||||||
|
template='%(function)s(%(expressions)s USING NCHAR_CS)',
|
||||||
|
**extra_context
|
||||||
|
)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection, **extra_context):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, function='CHAR', **extra_context)
|
return super().as_sql(compiler, connection, function='CHAR', **extra_context)
|
||||||
|
@ -41,16 +47,19 @@ class ConcatPair(Func):
|
||||||
"""
|
"""
|
||||||
function = 'CONCAT'
|
function = 'CONCAT'
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
coalesced = self.coalesce()
|
coalesced = self.coalesce()
|
||||||
return super(ConcatPair, coalesced).as_sql(
|
return super(ConcatPair, coalesced).as_sql(
|
||||||
compiler, connection, template='%(expressions)s', arg_joiner=' || '
|
compiler, connection, template='%(expressions)s', arg_joiner=' || ',
|
||||||
|
**extra_context
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
||||||
return super().as_sql(
|
return super().as_sql(
|
||||||
compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)"
|
compiler, connection, function='CONCAT_WS',
|
||||||
|
template="%(function)s('', %(expressions)s)",
|
||||||
|
**extra_context
|
||||||
)
|
)
|
||||||
|
|
||||||
def coalesce(self):
|
def coalesce(self):
|
||||||
|
@ -117,8 +126,8 @@ class Length(Transform):
|
||||||
lookup_name = 'length'
|
lookup_name = 'length'
|
||||||
output_field = fields.IntegerField()
|
output_field = fields.IntegerField()
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, function='CHAR_LENGTH')
|
return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Lower(Transform):
|
class Lower(Transform):
|
||||||
|
@ -199,8 +208,8 @@ class StrIndex(Func):
|
||||||
arity = 2
|
arity = 2
|
||||||
output_field = fields.IntegerField()
|
output_field = fields.IntegerField()
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, function='STRPOS')
|
return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Substr(Func):
|
class Substr(Func):
|
||||||
|
@ -220,11 +229,11 @@ class Substr(Func):
|
||||||
expressions.append(length)
|
expressions.append(length)
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, function='SUBSTR')
|
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, function='SUBSTR')
|
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Trim(Transform):
|
class Trim(Transform):
|
||||||
|
|
|
@ -275,7 +275,7 @@ We can change the behavior on a specific backend by creating a subclass of
|
||||||
``NotEqual`` with an ``as_mysql`` method::
|
``NotEqual`` with an ``as_mysql`` method::
|
||||||
|
|
||||||
class MySQLNotEqual(NotEqual):
|
class MySQLNotEqual(NotEqual):
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||||
params = lhs_params + rhs_params
|
params = lhs_params + rhs_params
|
||||||
|
|
|
@ -322,11 +322,12 @@ The ``Func`` API is as follows:
|
||||||
function = 'CONCAT'
|
function = 'CONCAT'
|
||||||
...
|
...
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(
|
return super().as_sql(
|
||||||
compiler, connection,
|
compiler, connection,
|
||||||
function='CONCAT_WS',
|
function='CONCAT_WS',
|
||||||
template="%(function)s('', %(expressions)s)",
|
template="%(function)s('', %(expressions)s)",
|
||||||
|
**extra_context
|
||||||
)
|
)
|
||||||
|
|
||||||
To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must
|
To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must
|
||||||
|
|
|
@ -1083,8 +1083,8 @@ class AggregateTestCase(TestCase):
|
||||||
class Greatest(Func):
|
class Greatest(Func):
|
||||||
function = 'GREATEST'
|
function = 'GREATEST'
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
return super().as_sql(compiler, connection, function='MAX')
|
return super().as_sql(compiler, connection, function='MAX', **extra_context)
|
||||||
|
|
||||||
qs = Publisher.objects.annotate(
|
qs = Publisher.objects.annotate(
|
||||||
price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))
|
price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))
|
||||||
|
|
|
@ -34,7 +34,7 @@ class Div3Transform(models.Transform):
|
||||||
lhs, lhs_params = compiler.compile(self.lhs)
|
lhs, lhs_params = compiler.compile(self.lhs)
|
||||||
return '(%s) %%%% 3' % lhs, lhs_params
|
return '(%s) %%%% 3' % lhs, lhs_params
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection, **extra_context):
|
||||||
lhs, lhs_params = compiler.compile(self.lhs)
|
lhs, lhs_params = compiler.compile(self.lhs)
|
||||||
return 'mod(%s, 3)' % lhs, lhs_params
|
return 'mod(%s, 3)' % lhs, lhs_params
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue