Fixed #28492 -- Defined default output_field of expressions at the class level.
This wasn't possible when settings were accessed during Field initialization time as our test suite setup script was triggering imports of expressions before settings were configured.
This commit is contained in:
parent
13be453080
commit
08654a99bb
|
@ -16,12 +16,9 @@ NUMERIC_TYPES = (int, float, Decimal)
|
|||
|
||||
class GeoFuncMixin:
|
||||
function = None
|
||||
output_field_class = None
|
||||
geom_param_pos = (0,)
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if 'output_field' not in extra and self.output_field_class:
|
||||
extra['output_field'] = self.output_field_class()
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
# Ensure that value expressions are geometric.
|
||||
|
@ -137,13 +134,13 @@ class Area(OracleToleranceMixin, GeoFunc):
|
|||
|
||||
|
||||
class Azimuth(GeoFunc):
|
||||
output_field_class = FloatField
|
||||
output_field = FloatField()
|
||||
arity = 2
|
||||
geom_param_pos = (0, 1)
|
||||
|
||||
|
||||
class AsGeoJSON(GeoFunc):
|
||||
output_field_class = TextField
|
||||
output_field = TextField()
|
||||
|
||||
def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
|
||||
expressions = [expression]
|
||||
|
@ -163,7 +160,7 @@ class AsGeoJSON(GeoFunc):
|
|||
|
||||
class AsGML(GeoFunc):
|
||||
geom_param_pos = (1,)
|
||||
output_field_class = TextField
|
||||
output_field = TextField()
|
||||
|
||||
def __init__(self, expression, version=2, precision=8, **extra):
|
||||
expressions = [version, expression]
|
||||
|
@ -189,7 +186,7 @@ class AsKML(AsGML):
|
|||
|
||||
|
||||
class AsSVG(GeoFunc):
|
||||
output_field_class = TextField
|
||||
output_field = TextField()
|
||||
|
||||
def __init__(self, expression, relative=False, precision=8, **extra):
|
||||
relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
|
||||
|
@ -281,7 +278,7 @@ class ForceRHR(GeomOutputGeoFunc):
|
|||
|
||||
|
||||
class GeoHash(GeoFunc):
|
||||
output_field_class = TextField
|
||||
output_field = TextField()
|
||||
|
||||
def __init__(self, expression, precision=None, **extra):
|
||||
expressions = [expression]
|
||||
|
@ -345,7 +342,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
|
||||
|
||||
class LineLocatePoint(GeoFunc):
|
||||
output_field_class = FloatField
|
||||
output_field = FloatField()
|
||||
arity = 2
|
||||
geom_param_pos = (0, 1)
|
||||
|
||||
|
@ -355,17 +352,17 @@ class MakeValid(GeoFunc):
|
|||
|
||||
|
||||
class MemSize(GeoFunc):
|
||||
output_field_class = IntegerField
|
||||
output_field = IntegerField()
|
||||
arity = 1
|
||||
|
||||
|
||||
class NumGeometries(GeoFunc):
|
||||
output_field_class = IntegerField
|
||||
output_field = IntegerField()
|
||||
arity = 1
|
||||
|
||||
|
||||
class NumPoints(GeoFunc):
|
||||
output_field_class = IntegerField
|
||||
output_field = IntegerField()
|
||||
arity = 1
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,9 @@ __all__ = [
|
|||
|
||||
|
||||
class StatAggregate(Aggregate):
|
||||
def __init__(self, y, x, output_field=FloatField(), filter=None):
|
||||
output_field = FloatField()
|
||||
|
||||
def __init__(self, y, x, output_field=None, filter=None):
|
||||
if not x or not y:
|
||||
raise ValueError('Both y and x must be provided.')
|
||||
super().__init__(y, x, output_field=output_field, filter=filter)
|
||||
|
@ -37,9 +39,7 @@ class RegrAvgY(StatAggregate):
|
|||
|
||||
class RegrCount(StatAggregate):
|
||||
function = 'REGR_COUNT'
|
||||
|
||||
def __init__(self, y, x, filter=None):
|
||||
super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
|
||||
output_field = IntegerField()
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
|
|
|
@ -3,17 +3,9 @@ from django.db.models import DateTimeField, Func, UUIDField
|
|||
|
||||
class RandomUUID(Func):
|
||||
template = 'GEN_RANDOM_UUID()'
|
||||
|
||||
def __init__(self, output_field=None, **extra):
|
||||
if output_field is None:
|
||||
output_field = UUIDField()
|
||||
super().__init__(output_field=output_field, **extra)
|
||||
output_field = UUIDField()
|
||||
|
||||
|
||||
class TransactionNow(Func):
|
||||
template = 'CURRENT_TIMESTAMP'
|
||||
|
||||
def __init__(self, output_field=None, **extra):
|
||||
if output_field is None:
|
||||
output_field = DateTimeField()
|
||||
super().__init__(output_field=output_field, **extra)
|
||||
output_field = DateTimeField()
|
||||
|
|
|
@ -202,10 +202,12 @@ SearchVectorField.register_lookup(SearchVectorExact)
|
|||
|
||||
|
||||
class TrigramBase(Func):
|
||||
output_field = FloatField()
|
||||
|
||||
def __init__(self, expression, string, **extra):
|
||||
if not hasattr(string, 'resolve_expression'):
|
||||
string = Value(string)
|
||||
super().__init__(expression, string, output_field=FloatField(), **extra)
|
||||
super().__init__(expression, string, **extra)
|
||||
|
||||
|
||||
class TrigramSimilarity(TrigramBase):
|
||||
|
|
|
@ -467,8 +467,10 @@ class DurationExpression(CombinedExpression):
|
|||
|
||||
|
||||
class TemporalSubtraction(CombinedExpression):
|
||||
output_field = fields.DurationField()
|
||||
|
||||
def __init__(self, lhs, rhs):
|
||||
super().__init__(lhs, self.SUB, rhs, output_field=fields.DurationField())
|
||||
super().__init__(lhs, self.SUB, rhs)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
connection.ops.check_expression_support(self)
|
||||
|
@ -692,8 +694,7 @@ class Star(Expression):
|
|||
|
||||
|
||||
class Random(Expression):
|
||||
def __init__(self):
|
||||
super().__init__(output_field=fields.FloatField())
|
||||
output_field = fields.FloatField()
|
||||
|
||||
def __repr__(self):
|
||||
return "Random()"
|
||||
|
@ -1017,6 +1018,7 @@ class Subquery(Expression):
|
|||
|
||||
class Exists(Subquery):
|
||||
template = 'EXISTS(%(subquery)s)'
|
||||
output_field = fields.BooleanField()
|
||||
|
||||
def __init__(self, *args, negated=False, **kwargs):
|
||||
self.negated = negated
|
||||
|
@ -1025,10 +1027,6 @@ class Exists(Subquery):
|
|||
def __invert__(self):
|
||||
return type(self)(self.queryset, negated=(not self.negated), **self.extra)
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
return fields.BooleanField()
|
||||
|
||||
def resolve_expression(self, query=None, **kwargs):
|
||||
# As a performance optimization, remove ordering since EXISTS doesn't
|
||||
# care about it, just whether or not a row matches.
|
||||
|
|
|
@ -142,9 +142,7 @@ class Length(Transform):
|
|||
"""Return the number of characters in the expression."""
|
||||
function = 'LENGTH'
|
||||
lookup_name = 'length'
|
||||
|
||||
def __init__(self, expression, *, output_field=None, **extra):
|
||||
super().__init__(expression, output_field=output_field or fields.IntegerField(), **extra)
|
||||
output_field = fields.IntegerField()
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return super().as_sql(compiler, connection, function='CHAR_LENGTH')
|
||||
|
@ -157,11 +155,7 @@ class Lower(Transform):
|
|||
|
||||
class Now(Func):
|
||||
template = 'CURRENT_TIMESTAMP'
|
||||
|
||||
def __init__(self, output_field=None, **extra):
|
||||
if output_field is None:
|
||||
output_field = fields.DateTimeField()
|
||||
super().__init__(output_field=output_field, **extra)
|
||||
output_field = fields.DateTimeField()
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
# Postgres' CURRENT_TIMESTAMP means "the time at the start of the
|
||||
|
@ -178,13 +172,7 @@ class StrIndex(Func):
|
|||
"""
|
||||
function = 'INSTR'
|
||||
arity = 2
|
||||
|
||||
def __init__(self, string, substring, **extra):
|
||||
"""
|
||||
string: the name of a field, or an expression returning a string
|
||||
substring: the name of a field, or an expression returning a string
|
||||
"""
|
||||
super().__init__(string, substring, output_field=fields.IntegerField(), **extra)
|
||||
output_field = fields.IntegerField()
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
return super().as_sql(compiler, connection, function='STRPOS')
|
||||
|
|
|
@ -2,14 +2,13 @@ from datetime import datetime
|
|||
|
||||
from django.conf import settings
|
||||
from django.db.models import (
|
||||
DateField, DateTimeField, DurationField, IntegerField, TimeField,
|
||||
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
|
||||
Transform,
|
||||
)
|
||||
from django.db.models.lookups import (
|
||||
YearExact, YearGt, YearGte, YearLt, YearLte,
|
||||
)
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class TimezoneMixin:
|
||||
|
@ -31,6 +30,7 @@ class TimezoneMixin:
|
|||
|
||||
class Extract(TimezoneMixin, Transform):
|
||||
lookup_name = None
|
||||
output_field = IntegerField()
|
||||
|
||||
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
|
||||
if self.lookup_name is None:
|
||||
|
@ -75,10 +75,6 @@ class Extract(TimezoneMixin, Transform):
|
|||
)
|
||||
return copy
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return IntegerField()
|
||||
|
||||
|
||||
class ExtractYear(Extract):
|
||||
lookup_name = 'year'
|
||||
|
@ -183,17 +179,18 @@ class TruncBase(TimezoneMixin, Transform):
|
|||
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
|
||||
# Passing dates or times to functions expecting datetimes is most
|
||||
# likely a mistake.
|
||||
output_field = copy.output_field
|
||||
explicit_output_field = field.__class__ != copy.output_field.__class__
|
||||
class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
|
||||
output_field = class_output_field or copy.output_field
|
||||
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
|
||||
if type(field) == DateField and (
|
||||
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
|
||||
raise ValueError("Cannot truncate DateField '%s' to %s. " % (
|
||||
field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField'
|
||||
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
|
||||
))
|
||||
elif isinstance(field, TimeField) and (
|
||||
isinstance(output_field, DateTimeField) or copy.kind in ('year', 'quarter', 'month', 'day', 'date')):
|
||||
raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
|
||||
field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField'
|
||||
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
|
||||
))
|
||||
return copy
|
||||
|
||||
|
@ -241,9 +238,7 @@ class TruncDay(TruncBase):
|
|||
class TruncDate(TruncBase):
|
||||
kind = 'date'
|
||||
lookup_name = 'date'
|
||||
|
||||
def __init__(self, *args, output_field=None, **kwargs):
|
||||
super().__init__(*args, output_field=DateField(), **kwargs)
|
||||
output_field = DateField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to date rather than truncate to date.
|
||||
|
@ -256,9 +251,7 @@ class TruncDate(TruncBase):
|
|||
class TruncTime(TruncBase):
|
||||
kind = 'time'
|
||||
lookup_name = 'time'
|
||||
|
||||
def __init__(self, *args, output_field=None, **kwargs):
|
||||
super().__init__(*args, output_field=TimeField(), **kwargs)
|
||||
output_field = TimeField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to date rather than truncate to date.
|
||||
|
|
Loading…
Reference in New Issue