Fixed #28492 -- Defined default output_field of expressions at the class level.
This wasn't possible when settings were accessed during Field initialization time as our test suite setup script was triggering imports of expressions before settings were configured.
This commit is contained in:
parent
13be453080
commit
08654a99bb
|
@ -16,12 +16,9 @@ NUMERIC_TYPES = (int, float, Decimal)
|
||||||
|
|
||||||
class GeoFuncMixin:
|
class GeoFuncMixin:
|
||||||
function = None
|
function = None
|
||||||
output_field_class = None
|
|
||||||
geom_param_pos = (0,)
|
geom_param_pos = (0,)
|
||||||
|
|
||||||
def __init__(self, *expressions, **extra):
|
def __init__(self, *expressions, **extra):
|
||||||
if 'output_field' not in extra and self.output_field_class:
|
|
||||||
extra['output_field'] = self.output_field_class()
|
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
# Ensure that value expressions are geometric.
|
# Ensure that value expressions are geometric.
|
||||||
|
@ -137,13 +134,13 @@ class Area(OracleToleranceMixin, GeoFunc):
|
||||||
|
|
||||||
|
|
||||||
class Azimuth(GeoFunc):
|
class Azimuth(GeoFunc):
|
||||||
output_field_class = FloatField
|
output_field = FloatField()
|
||||||
arity = 2
|
arity = 2
|
||||||
geom_param_pos = (0, 1)
|
geom_param_pos = (0, 1)
|
||||||
|
|
||||||
|
|
||||||
class AsGeoJSON(GeoFunc):
|
class AsGeoJSON(GeoFunc):
|
||||||
output_field_class = TextField
|
output_field = TextField()
|
||||||
|
|
||||||
def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
|
def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
|
||||||
expressions = [expression]
|
expressions = [expression]
|
||||||
|
@ -163,7 +160,7 @@ class AsGeoJSON(GeoFunc):
|
||||||
|
|
||||||
class AsGML(GeoFunc):
|
class AsGML(GeoFunc):
|
||||||
geom_param_pos = (1,)
|
geom_param_pos = (1,)
|
||||||
output_field_class = TextField
|
output_field = TextField()
|
||||||
|
|
||||||
def __init__(self, expression, version=2, precision=8, **extra):
|
def __init__(self, expression, version=2, precision=8, **extra):
|
||||||
expressions = [version, expression]
|
expressions = [version, expression]
|
||||||
|
@ -189,7 +186,7 @@ class AsKML(AsGML):
|
||||||
|
|
||||||
|
|
||||||
class AsSVG(GeoFunc):
|
class AsSVG(GeoFunc):
|
||||||
output_field_class = TextField
|
output_field = TextField()
|
||||||
|
|
||||||
def __init__(self, expression, relative=False, precision=8, **extra):
|
def __init__(self, expression, relative=False, precision=8, **extra):
|
||||||
relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
|
relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
|
||||||
|
@ -281,7 +278,7 @@ class ForceRHR(GeomOutputGeoFunc):
|
||||||
|
|
||||||
|
|
||||||
class GeoHash(GeoFunc):
|
class GeoHash(GeoFunc):
|
||||||
output_field_class = TextField
|
output_field = TextField()
|
||||||
|
|
||||||
def __init__(self, expression, precision=None, **extra):
|
def __init__(self, expression, precision=None, **extra):
|
||||||
expressions = [expression]
|
expressions = [expression]
|
||||||
|
@ -345,7 +342,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
|
|
||||||
|
|
||||||
class LineLocatePoint(GeoFunc):
|
class LineLocatePoint(GeoFunc):
|
||||||
output_field_class = FloatField
|
output_field = FloatField()
|
||||||
arity = 2
|
arity = 2
|
||||||
geom_param_pos = (0, 1)
|
geom_param_pos = (0, 1)
|
||||||
|
|
||||||
|
@ -355,17 +352,17 @@ class MakeValid(GeoFunc):
|
||||||
|
|
||||||
|
|
||||||
class MemSize(GeoFunc):
|
class MemSize(GeoFunc):
|
||||||
output_field_class = IntegerField
|
output_field = IntegerField()
|
||||||
arity = 1
|
arity = 1
|
||||||
|
|
||||||
|
|
||||||
class NumGeometries(GeoFunc):
|
class NumGeometries(GeoFunc):
|
||||||
output_field_class = IntegerField
|
output_field = IntegerField()
|
||||||
arity = 1
|
arity = 1
|
||||||
|
|
||||||
|
|
||||||
class NumPoints(GeoFunc):
|
class NumPoints(GeoFunc):
|
||||||
output_field_class = IntegerField
|
output_field = IntegerField()
|
||||||
arity = 1
|
arity = 1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,9 @@ __all__ = [
|
||||||
|
|
||||||
|
|
||||||
class StatAggregate(Aggregate):
|
class StatAggregate(Aggregate):
|
||||||
def __init__(self, y, x, output_field=FloatField(), filter=None):
|
output_field = FloatField()
|
||||||
|
|
||||||
|
def __init__(self, y, x, output_field=None, filter=None):
|
||||||
if not x or not y:
|
if not x or not y:
|
||||||
raise ValueError('Both y and x must be provided.')
|
raise ValueError('Both y and x must be provided.')
|
||||||
super().__init__(y, x, output_field=output_field, filter=filter)
|
super().__init__(y, x, output_field=output_field, filter=filter)
|
||||||
|
@ -37,9 +39,7 @@ class RegrAvgY(StatAggregate):
|
||||||
|
|
||||||
class RegrCount(StatAggregate):
|
class RegrCount(StatAggregate):
|
||||||
function = 'REGR_COUNT'
|
function = 'REGR_COUNT'
|
||||||
|
output_field = IntegerField()
|
||||||
def __init__(self, y, x, filter=None):
|
|
||||||
super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
|
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection):
|
def convert_value(self, value, expression, connection):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|
|
@ -3,17 +3,9 @@ from django.db.models import DateTimeField, Func, UUIDField
|
||||||
|
|
||||||
class RandomUUID(Func):
|
class RandomUUID(Func):
|
||||||
template = 'GEN_RANDOM_UUID()'
|
template = 'GEN_RANDOM_UUID()'
|
||||||
|
output_field = UUIDField()
|
||||||
def __init__(self, output_field=None, **extra):
|
|
||||||
if output_field is None:
|
|
||||||
output_field = UUIDField()
|
|
||||||
super().__init__(output_field=output_field, **extra)
|
|
||||||
|
|
||||||
|
|
||||||
class TransactionNow(Func):
|
class TransactionNow(Func):
|
||||||
template = 'CURRENT_TIMESTAMP'
|
template = 'CURRENT_TIMESTAMP'
|
||||||
|
output_field = DateTimeField()
|
||||||
def __init__(self, output_field=None, **extra):
|
|
||||||
if output_field is None:
|
|
||||||
output_field = DateTimeField()
|
|
||||||
super().__init__(output_field=output_field, **extra)
|
|
||||||
|
|
|
@ -202,10 +202,12 @@ SearchVectorField.register_lookup(SearchVectorExact)
|
||||||
|
|
||||||
|
|
||||||
class TrigramBase(Func):
|
class TrigramBase(Func):
|
||||||
|
output_field = FloatField()
|
||||||
|
|
||||||
def __init__(self, expression, string, **extra):
|
def __init__(self, expression, string, **extra):
|
||||||
if not hasattr(string, 'resolve_expression'):
|
if not hasattr(string, 'resolve_expression'):
|
||||||
string = Value(string)
|
string = Value(string)
|
||||||
super().__init__(expression, string, output_field=FloatField(), **extra)
|
super().__init__(expression, string, **extra)
|
||||||
|
|
||||||
|
|
||||||
class TrigramSimilarity(TrigramBase):
|
class TrigramSimilarity(TrigramBase):
|
||||||
|
|
|
@ -467,8 +467,10 @@ class DurationExpression(CombinedExpression):
|
||||||
|
|
||||||
|
|
||||||
class TemporalSubtraction(CombinedExpression):
|
class TemporalSubtraction(CombinedExpression):
|
||||||
|
output_field = fields.DurationField()
|
||||||
|
|
||||||
def __init__(self, lhs, rhs):
|
def __init__(self, lhs, rhs):
|
||||||
super().__init__(lhs, self.SUB, rhs, output_field=fields.DurationField())
|
super().__init__(lhs, self.SUB, rhs)
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
|
@ -692,8 +694,7 @@ class Star(Expression):
|
||||||
|
|
||||||
|
|
||||||
class Random(Expression):
|
class Random(Expression):
|
||||||
def __init__(self):
|
output_field = fields.FloatField()
|
||||||
super().__init__(output_field=fields.FloatField())
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "Random()"
|
return "Random()"
|
||||||
|
@ -1017,6 +1018,7 @@ class Subquery(Expression):
|
||||||
|
|
||||||
class Exists(Subquery):
|
class Exists(Subquery):
|
||||||
template = 'EXISTS(%(subquery)s)'
|
template = 'EXISTS(%(subquery)s)'
|
||||||
|
output_field = fields.BooleanField()
|
||||||
|
|
||||||
def __init__(self, *args, negated=False, **kwargs):
|
def __init__(self, *args, negated=False, **kwargs):
|
||||||
self.negated = negated
|
self.negated = negated
|
||||||
|
@ -1025,10 +1027,6 @@ class Exists(Subquery):
|
||||||
def __invert__(self):
|
def __invert__(self):
|
||||||
return type(self)(self.queryset, negated=(not self.negated), **self.extra)
|
return type(self)(self.queryset, negated=(not self.negated), **self.extra)
|
||||||
|
|
||||||
@property
|
|
||||||
def output_field(self):
|
|
||||||
return fields.BooleanField()
|
|
||||||
|
|
||||||
def resolve_expression(self, query=None, **kwargs):
|
def resolve_expression(self, query=None, **kwargs):
|
||||||
# As a performance optimization, remove ordering since EXISTS doesn't
|
# As a performance optimization, remove ordering since EXISTS doesn't
|
||||||
# care about it, just whether or not a row matches.
|
# care about it, just whether or not a row matches.
|
||||||
|
|
|
@ -142,9 +142,7 @@ class Length(Transform):
|
||||||
"""Return the number of characters in the expression."""
|
"""Return the number of characters in the expression."""
|
||||||
function = 'LENGTH'
|
function = 'LENGTH'
|
||||||
lookup_name = 'length'
|
lookup_name = 'length'
|
||||||
|
output_field = fields.IntegerField()
|
||||||
def __init__(self, expression, *, output_field=None, **extra):
|
|
||||||
super().__init__(expression, output_field=output_field or fields.IntegerField(), **extra)
|
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection):
|
||||||
return super().as_sql(compiler, connection, function='CHAR_LENGTH')
|
return super().as_sql(compiler, connection, function='CHAR_LENGTH')
|
||||||
|
@ -157,11 +155,7 @@ class Lower(Transform):
|
||||||
|
|
||||||
class Now(Func):
|
class Now(Func):
|
||||||
template = 'CURRENT_TIMESTAMP'
|
template = 'CURRENT_TIMESTAMP'
|
||||||
|
output_field = fields.DateTimeField()
|
||||||
def __init__(self, output_field=None, **extra):
|
|
||||||
if output_field is None:
|
|
||||||
output_field = fields.DateTimeField()
|
|
||||||
super().__init__(output_field=output_field, **extra)
|
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
# Postgres' CURRENT_TIMESTAMP means "the time at the start of the
|
# Postgres' CURRENT_TIMESTAMP means "the time at the start of the
|
||||||
|
@ -178,13 +172,7 @@ class StrIndex(Func):
|
||||||
"""
|
"""
|
||||||
function = 'INSTR'
|
function = 'INSTR'
|
||||||
arity = 2
|
arity = 2
|
||||||
|
output_field = fields.IntegerField()
|
||||||
def __init__(self, string, substring, **extra):
|
|
||||||
"""
|
|
||||||
string: the name of a field, or an expression returning a string
|
|
||||||
substring: the name of a field, or an expression returning a string
|
|
||||||
"""
|
|
||||||
super().__init__(string, substring, output_field=fields.IntegerField(), **extra)
|
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
return super().as_sql(compiler, connection, function='STRPOS')
|
return super().as_sql(compiler, connection, function='STRPOS')
|
||||||
|
|
|
@ -2,14 +2,13 @@ from datetime import datetime
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
DateField, DateTimeField, DurationField, IntegerField, TimeField,
|
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
|
||||||
Transform,
|
Transform,
|
||||||
)
|
)
|
||||||
from django.db.models.lookups import (
|
from django.db.models.lookups import (
|
||||||
YearExact, YearGt, YearGte, YearLt, YearLte,
|
YearExact, YearGt, YearGte, YearLt, YearLte,
|
||||||
)
|
)
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.functional import cached_property
|
|
||||||
|
|
||||||
|
|
||||||
class TimezoneMixin:
|
class TimezoneMixin:
|
||||||
|
@ -31,6 +30,7 @@ class TimezoneMixin:
|
||||||
|
|
||||||
class Extract(TimezoneMixin, Transform):
|
class Extract(TimezoneMixin, Transform):
|
||||||
lookup_name = None
|
lookup_name = None
|
||||||
|
output_field = IntegerField()
|
||||||
|
|
||||||
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
|
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
|
||||||
if self.lookup_name is None:
|
if self.lookup_name is None:
|
||||||
|
@ -75,10 +75,6 @@ class Extract(TimezoneMixin, Transform):
|
||||||
)
|
)
|
||||||
return copy
|
return copy
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def output_field(self):
|
|
||||||
return IntegerField()
|
|
||||||
|
|
||||||
|
|
||||||
class ExtractYear(Extract):
|
class ExtractYear(Extract):
|
||||||
lookup_name = 'year'
|
lookup_name = 'year'
|
||||||
|
@ -183,17 +179,18 @@ class TruncBase(TimezoneMixin, Transform):
|
||||||
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
|
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
|
||||||
# Passing dates or times to functions expecting datetimes is most
|
# Passing dates or times to functions expecting datetimes is most
|
||||||
# likely a mistake.
|
# likely a mistake.
|
||||||
output_field = copy.output_field
|
class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
|
||||||
explicit_output_field = field.__class__ != copy.output_field.__class__
|
output_field = class_output_field or copy.output_field
|
||||||
|
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
|
||||||
if type(field) == DateField and (
|
if type(field) == DateField and (
|
||||||
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
|
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
|
||||||
raise ValueError("Cannot truncate DateField '%s' to %s. " % (
|
raise ValueError("Cannot truncate DateField '%s' to %s. " % (
|
||||||
field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField'
|
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
|
||||||
))
|
))
|
||||||
elif isinstance(field, TimeField) and (
|
elif isinstance(field, TimeField) and (
|
||||||
isinstance(output_field, DateTimeField) or copy.kind in ('year', 'quarter', 'month', 'day', 'date')):
|
isinstance(output_field, DateTimeField) or copy.kind in ('year', 'quarter', 'month', 'day', 'date')):
|
||||||
raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
|
raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
|
||||||
field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField'
|
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
|
||||||
))
|
))
|
||||||
return copy
|
return copy
|
||||||
|
|
||||||
|
@ -241,9 +238,7 @@ class TruncDay(TruncBase):
|
||||||
class TruncDate(TruncBase):
|
class TruncDate(TruncBase):
|
||||||
kind = 'date'
|
kind = 'date'
|
||||||
lookup_name = 'date'
|
lookup_name = 'date'
|
||||||
|
output_field = DateField()
|
||||||
def __init__(self, *args, output_field=None, **kwargs):
|
|
||||||
super().__init__(*args, output_field=DateField(), **kwargs)
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
# Cast to date rather than truncate to date.
|
# Cast to date rather than truncate to date.
|
||||||
|
@ -256,9 +251,7 @@ class TruncDate(TruncBase):
|
||||||
class TruncTime(TruncBase):
|
class TruncTime(TruncBase):
|
||||||
kind = 'time'
|
kind = 'time'
|
||||||
lookup_name = 'time'
|
lookup_name = 'time'
|
||||||
|
output_field = TimeField()
|
||||||
def __init__(self, *args, output_field=None, **kwargs):
|
|
||||||
super().__init__(*args, output_field=TimeField(), **kwargs)
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
# Cast to date rather than truncate to date.
|
# Cast to date rather than truncate to date.
|
||||||
|
|
Loading…
Reference in New Issue