Fixed #28492 -- Defined default output_field of expressions at the class level.

This wasn't possible when settings were accessed during Field initialization
time as our test suite setup script was triggering imports of expressions
before settings were configured.
This commit is contained in:
Simon Charette 2017-08-08 13:31:59 -04:00 committed by Tim Graham
parent 13be453080
commit 08654a99bb
7 changed files with 35 additions and 65 deletions

View File

@ -16,12 +16,9 @@ NUMERIC_TYPES = (int, float, Decimal)
class GeoFuncMixin:
function = None
output_field_class = None
geom_param_pos = (0,)
def __init__(self, *expressions, **extra):
if 'output_field' not in extra and self.output_field_class:
extra['output_field'] = self.output_field_class()
super().__init__(*expressions, **extra)
# Ensure that value expressions are geometric.
@ -137,13 +134,13 @@ class Area(OracleToleranceMixin, GeoFunc):
class Azimuth(GeoFunc):
output_field_class = FloatField
output_field = FloatField()
arity = 2
geom_param_pos = (0, 1)
class AsGeoJSON(GeoFunc):
output_field_class = TextField
output_field = TextField()
def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
expressions = [expression]
@ -163,7 +160,7 @@ class AsGeoJSON(GeoFunc):
class AsGML(GeoFunc):
geom_param_pos = (1,)
output_field_class = TextField
output_field = TextField()
def __init__(self, expression, version=2, precision=8, **extra):
expressions = [version, expression]
@ -189,7 +186,7 @@ class AsKML(AsGML):
class AsSVG(GeoFunc):
output_field_class = TextField
output_field = TextField()
def __init__(self, expression, relative=False, precision=8, **extra):
relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
@ -281,7 +278,7 @@ class ForceRHR(GeomOutputGeoFunc):
class GeoHash(GeoFunc):
output_field_class = TextField
output_field = TextField()
def __init__(self, expression, precision=None, **extra):
expressions = [expression]
@ -345,7 +342,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
class LineLocatePoint(GeoFunc):
output_field_class = FloatField
output_field = FloatField()
arity = 2
geom_param_pos = (0, 1)
@ -355,17 +352,17 @@ class MakeValid(GeoFunc):
class MemSize(GeoFunc):
output_field_class = IntegerField
output_field = IntegerField()
arity = 1
class NumGeometries(GeoFunc):
output_field_class = IntegerField
output_field = IntegerField()
arity = 1
class NumPoints(GeoFunc):
output_field_class = IntegerField
output_field = IntegerField()
arity = 1

View File

@ -8,7 +8,9 @@ __all__ = [
class StatAggregate(Aggregate):
def __init__(self, y, x, output_field=FloatField(), filter=None):
output_field = FloatField()
def __init__(self, y, x, output_field=None, filter=None):
if not x or not y:
raise ValueError('Both y and x must be provided.')
super().__init__(y, x, output_field=output_field, filter=filter)
@ -37,9 +39,7 @@ class RegrAvgY(StatAggregate):
class RegrCount(StatAggregate):
function = 'REGR_COUNT'
def __init__(self, y, x, filter=None):
super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
output_field = IntegerField()
def convert_value(self, value, expression, connection):
if value is None:

View File

@ -3,17 +3,9 @@ from django.db.models import DateTimeField, Func, UUIDField
class RandomUUID(Func):
template = 'GEN_RANDOM_UUID()'
def __init__(self, output_field=None, **extra):
if output_field is None:
output_field = UUIDField()
super().__init__(output_field=output_field, **extra)
output_field = UUIDField()
class TransactionNow(Func):
template = 'CURRENT_TIMESTAMP'
def __init__(self, output_field=None, **extra):
if output_field is None:
output_field = DateTimeField()
super().__init__(output_field=output_field, **extra)
output_field = DateTimeField()

View File

@ -202,10 +202,12 @@ SearchVectorField.register_lookup(SearchVectorExact)
class TrigramBase(Func):
output_field = FloatField()
def __init__(self, expression, string, **extra):
if not hasattr(string, 'resolve_expression'):
string = Value(string)
super().__init__(expression, string, output_field=FloatField(), **extra)
super().__init__(expression, string, **extra)
class TrigramSimilarity(TrigramBase):

View File

@ -467,8 +467,10 @@ class DurationExpression(CombinedExpression):
class TemporalSubtraction(CombinedExpression):
output_field = fields.DurationField()
def __init__(self, lhs, rhs):
super().__init__(lhs, self.SUB, rhs, output_field=fields.DurationField())
super().__init__(lhs, self.SUB, rhs)
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
@ -692,8 +694,7 @@ class Star(Expression):
class Random(Expression):
def __init__(self):
super().__init__(output_field=fields.FloatField())
output_field = fields.FloatField()
def __repr__(self):
return "Random()"
@ -1017,6 +1018,7 @@ class Subquery(Expression):
class Exists(Subquery):
template = 'EXISTS(%(subquery)s)'
output_field = fields.BooleanField()
def __init__(self, *args, negated=False, **kwargs):
self.negated = negated
@ -1025,10 +1027,6 @@ class Exists(Subquery):
def __invert__(self):
return type(self)(self.queryset, negated=(not self.negated), **self.extra)
@property
def output_field(self):
return fields.BooleanField()
def resolve_expression(self, query=None, **kwargs):
# As a performance optimization, remove ordering since EXISTS doesn't
# care about it, just whether or not a row matches.

View File

@ -142,9 +142,7 @@ class Length(Transform):
"""Return the number of characters in the expression."""
function = 'LENGTH'
lookup_name = 'length'
def __init__(self, expression, *, output_field=None, **extra):
super().__init__(expression, output_field=output_field or fields.IntegerField(), **extra)
output_field = fields.IntegerField()
def as_mysql(self, compiler, connection):
return super().as_sql(compiler, connection, function='CHAR_LENGTH')
@ -157,11 +155,7 @@ class Lower(Transform):
class Now(Func):
template = 'CURRENT_TIMESTAMP'
def __init__(self, output_field=None, **extra):
if output_field is None:
output_field = fields.DateTimeField()
super().__init__(output_field=output_field, **extra)
output_field = fields.DateTimeField()
def as_postgresql(self, compiler, connection):
# Postgres' CURRENT_TIMESTAMP means "the time at the start of the
@ -178,13 +172,7 @@ class StrIndex(Func):
"""
function = 'INSTR'
arity = 2
def __init__(self, string, substring, **extra):
"""
string: the name of a field, or an expression returning a string
substring: the name of a field, or an expression returning a string
"""
super().__init__(string, substring, output_field=fields.IntegerField(), **extra)
output_field = fields.IntegerField()
def as_postgresql(self, compiler, connection):
return super().as_sql(compiler, connection, function='STRPOS')

View File

@ -2,14 +2,13 @@ from datetime import datetime
from django.conf import settings
from django.db.models import (
DateField, DateTimeField, DurationField, IntegerField, TimeField,
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
Transform,
)
from django.db.models.lookups import (
YearExact, YearGt, YearGte, YearLt, YearLte,
)
from django.utils import timezone
from django.utils.functional import cached_property
class TimezoneMixin:
@ -31,6 +30,7 @@ class TimezoneMixin:
class Extract(TimezoneMixin, Transform):
lookup_name = None
output_field = IntegerField()
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
if self.lookup_name is None:
@ -75,10 +75,6 @@ class Extract(TimezoneMixin, Transform):
)
return copy
@cached_property
def output_field(self):
return IntegerField()
class ExtractYear(Extract):
lookup_name = 'year'
@ -183,17 +179,18 @@ class TruncBase(TimezoneMixin, Transform):
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
# Passing dates or times to functions expecting datetimes is most
# likely a mistake.
output_field = copy.output_field
explicit_output_field = field.__class__ != copy.output_field.__class__
class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
output_field = class_output_field or copy.output_field
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
if type(field) == DateField and (
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
raise ValueError("Cannot truncate DateField '%s' to %s. " % (
field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField'
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
))
elif isinstance(field, TimeField) and (
isinstance(output_field, DateTimeField) or copy.kind in ('year', 'quarter', 'month', 'day', 'date')):
raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField'
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
))
return copy
@ -241,9 +238,7 @@ class TruncDay(TruncBase):
class TruncDate(TruncBase):
kind = 'date'
lookup_name = 'date'
def __init__(self, *args, output_field=None, **kwargs):
super().__init__(*args, output_field=DateField(), **kwargs)
output_field = DateField()
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
@ -256,9 +251,7 @@ class TruncDate(TruncBase):
class TruncTime(TruncBase):
kind = 'time'
lookup_name = 'time'
def __init__(self, *args, output_field=None, **kwargs):
super().__init__(*args, output_field=TimeField(), **kwargs)
output_field = TimeField()
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.