Fixed #28492 -- Defined default output_field of expressions at the class level.

This wasn't possible when settings were accessed during Field initialization
time as our test suite setup script was triggering imports of expressions
before settings were configured.
This commit is contained in:
Simon Charette 2017-08-08 13:31:59 -04:00 committed by Tim Graham
parent 13be453080
commit 08654a99bb
7 changed files with 35 additions and 65 deletions

View File

@ -16,12 +16,9 @@ NUMERIC_TYPES = (int, float, Decimal)
class GeoFuncMixin: class GeoFuncMixin:
function = None function = None
output_field_class = None
geom_param_pos = (0,) geom_param_pos = (0,)
def __init__(self, *expressions, **extra): def __init__(self, *expressions, **extra):
if 'output_field' not in extra and self.output_field_class:
extra['output_field'] = self.output_field_class()
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
# Ensure that value expressions are geometric. # Ensure that value expressions are geometric.
@ -137,13 +134,13 @@ class Area(OracleToleranceMixin, GeoFunc):
class Azimuth(GeoFunc): class Azimuth(GeoFunc):
output_field_class = FloatField output_field = FloatField()
arity = 2 arity = 2
geom_param_pos = (0, 1) geom_param_pos = (0, 1)
class AsGeoJSON(GeoFunc): class AsGeoJSON(GeoFunc):
output_field_class = TextField output_field = TextField()
def __init__(self, expression, bbox=False, crs=False, precision=8, **extra): def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
expressions = [expression] expressions = [expression]
@ -163,7 +160,7 @@ class AsGeoJSON(GeoFunc):
class AsGML(GeoFunc): class AsGML(GeoFunc):
geom_param_pos = (1,) geom_param_pos = (1,)
output_field_class = TextField output_field = TextField()
def __init__(self, expression, version=2, precision=8, **extra): def __init__(self, expression, version=2, precision=8, **extra):
expressions = [version, expression] expressions = [version, expression]
@ -189,7 +186,7 @@ class AsKML(AsGML):
class AsSVG(GeoFunc): class AsSVG(GeoFunc):
output_field_class = TextField output_field = TextField()
def __init__(self, expression, relative=False, precision=8, **extra): def __init__(self, expression, relative=False, precision=8, **extra):
relative = relative if hasattr(relative, 'resolve_expression') else int(relative) relative = relative if hasattr(relative, 'resolve_expression') else int(relative)
@ -281,7 +278,7 @@ class ForceRHR(GeomOutputGeoFunc):
class GeoHash(GeoFunc): class GeoHash(GeoFunc):
output_field_class = TextField output_field = TextField()
def __init__(self, expression, precision=None, **extra): def __init__(self, expression, precision=None, **extra):
expressions = [expression] expressions = [expression]
@ -345,7 +342,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
class LineLocatePoint(GeoFunc): class LineLocatePoint(GeoFunc):
output_field_class = FloatField output_field = FloatField()
arity = 2 arity = 2
geom_param_pos = (0, 1) geom_param_pos = (0, 1)
@ -355,17 +352,17 @@ class MakeValid(GeoFunc):
class MemSize(GeoFunc): class MemSize(GeoFunc):
output_field_class = IntegerField output_field = IntegerField()
arity = 1 arity = 1
class NumGeometries(GeoFunc): class NumGeometries(GeoFunc):
output_field_class = IntegerField output_field = IntegerField()
arity = 1 arity = 1
class NumPoints(GeoFunc): class NumPoints(GeoFunc):
output_field_class = IntegerField output_field = IntegerField()
arity = 1 arity = 1

View File

@ -8,7 +8,9 @@ __all__ = [
class StatAggregate(Aggregate): class StatAggregate(Aggregate):
def __init__(self, y, x, output_field=FloatField(), filter=None): output_field = FloatField()
def __init__(self, y, x, output_field=None, filter=None):
if not x or not y: if not x or not y:
raise ValueError('Both y and x must be provided.') raise ValueError('Both y and x must be provided.')
super().__init__(y, x, output_field=output_field, filter=filter) super().__init__(y, x, output_field=output_field, filter=filter)
@ -37,9 +39,7 @@ class RegrAvgY(StatAggregate):
class RegrCount(StatAggregate): class RegrCount(StatAggregate):
function = 'REGR_COUNT' function = 'REGR_COUNT'
output_field = IntegerField()
def __init__(self, y, x, filter=None):
super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if value is None: if value is None:

View File

@ -3,17 +3,9 @@ from django.db.models import DateTimeField, Func, UUIDField
class RandomUUID(Func): class RandomUUID(Func):
template = 'GEN_RANDOM_UUID()' template = 'GEN_RANDOM_UUID()'
output_field = UUIDField()
def __init__(self, output_field=None, **extra):
if output_field is None:
output_field = UUIDField()
super().__init__(output_field=output_field, **extra)
class TransactionNow(Func): class TransactionNow(Func):
template = 'CURRENT_TIMESTAMP' template = 'CURRENT_TIMESTAMP'
output_field = DateTimeField()
def __init__(self, output_field=None, **extra):
if output_field is None:
output_field = DateTimeField()
super().__init__(output_field=output_field, **extra)

View File

@ -202,10 +202,12 @@ SearchVectorField.register_lookup(SearchVectorExact)
class TrigramBase(Func): class TrigramBase(Func):
output_field = FloatField()
def __init__(self, expression, string, **extra): def __init__(self, expression, string, **extra):
if not hasattr(string, 'resolve_expression'): if not hasattr(string, 'resolve_expression'):
string = Value(string) string = Value(string)
super().__init__(expression, string, output_field=FloatField(), **extra) super().__init__(expression, string, **extra)
class TrigramSimilarity(TrigramBase): class TrigramSimilarity(TrigramBase):

View File

@ -467,8 +467,10 @@ class DurationExpression(CombinedExpression):
class TemporalSubtraction(CombinedExpression): class TemporalSubtraction(CombinedExpression):
output_field = fields.DurationField()
def __init__(self, lhs, rhs): def __init__(self, lhs, rhs):
super().__init__(lhs, self.SUB, rhs, output_field=fields.DurationField()) super().__init__(lhs, self.SUB, rhs)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
@ -692,8 +694,7 @@ class Star(Expression):
class Random(Expression): class Random(Expression):
def __init__(self): output_field = fields.FloatField()
super().__init__(output_field=fields.FloatField())
def __repr__(self): def __repr__(self):
return "Random()" return "Random()"
@ -1017,6 +1018,7 @@ class Subquery(Expression):
class Exists(Subquery): class Exists(Subquery):
template = 'EXISTS(%(subquery)s)' template = 'EXISTS(%(subquery)s)'
output_field = fields.BooleanField()
def __init__(self, *args, negated=False, **kwargs): def __init__(self, *args, negated=False, **kwargs):
self.negated = negated self.negated = negated
@ -1025,10 +1027,6 @@ class Exists(Subquery):
def __invert__(self): def __invert__(self):
return type(self)(self.queryset, negated=(not self.negated), **self.extra) return type(self)(self.queryset, negated=(not self.negated), **self.extra)
@property
def output_field(self):
return fields.BooleanField()
def resolve_expression(self, query=None, **kwargs): def resolve_expression(self, query=None, **kwargs):
# As a performance optimization, remove ordering since EXISTS doesn't # As a performance optimization, remove ordering since EXISTS doesn't
# care about it, just whether or not a row matches. # care about it, just whether or not a row matches.

View File

@ -142,9 +142,7 @@ class Length(Transform):
"""Return the number of characters in the expression.""" """Return the number of characters in the expression."""
function = 'LENGTH' function = 'LENGTH'
lookup_name = 'length' lookup_name = 'length'
output_field = fields.IntegerField()
def __init__(self, expression, *, output_field=None, **extra):
super().__init__(expression, output_field=output_field or fields.IntegerField(), **extra)
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection):
return super().as_sql(compiler, connection, function='CHAR_LENGTH') return super().as_sql(compiler, connection, function='CHAR_LENGTH')
@ -157,11 +155,7 @@ class Lower(Transform):
class Now(Func): class Now(Func):
template = 'CURRENT_TIMESTAMP' template = 'CURRENT_TIMESTAMP'
output_field = fields.DateTimeField()
def __init__(self, output_field=None, **extra):
if output_field is None:
output_field = fields.DateTimeField()
super().__init__(output_field=output_field, **extra)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):
# Postgres' CURRENT_TIMESTAMP means "the time at the start of the # Postgres' CURRENT_TIMESTAMP means "the time at the start of the
@ -178,13 +172,7 @@ class StrIndex(Func):
""" """
function = 'INSTR' function = 'INSTR'
arity = 2 arity = 2
output_field = fields.IntegerField()
def __init__(self, string, substring, **extra):
"""
string: the name of a field, or an expression returning a string
substring: the name of a field, or an expression returning a string
"""
super().__init__(string, substring, output_field=fields.IntegerField(), **extra)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):
return super().as_sql(compiler, connection, function='STRPOS') return super().as_sql(compiler, connection, function='STRPOS')

View File

@ -2,14 +2,13 @@ from datetime import datetime
from django.conf import settings from django.conf import settings
from django.db.models import ( from django.db.models import (
DateField, DateTimeField, DurationField, IntegerField, TimeField, DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
Transform, Transform,
) )
from django.db.models.lookups import ( from django.db.models.lookups import (
YearExact, YearGt, YearGte, YearLt, YearLte, YearExact, YearGt, YearGte, YearLt, YearLte,
) )
from django.utils import timezone from django.utils import timezone
from django.utils.functional import cached_property
class TimezoneMixin: class TimezoneMixin:
@ -31,6 +30,7 @@ class TimezoneMixin:
class Extract(TimezoneMixin, Transform): class Extract(TimezoneMixin, Transform):
lookup_name = None lookup_name = None
output_field = IntegerField()
def __init__(self, expression, lookup_name=None, tzinfo=None, **extra): def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
if self.lookup_name is None: if self.lookup_name is None:
@ -75,10 +75,6 @@ class Extract(TimezoneMixin, Transform):
) )
return copy return copy
@cached_property
def output_field(self):
return IntegerField()
class ExtractYear(Extract): class ExtractYear(Extract):
lookup_name = 'year' lookup_name = 'year'
@ -183,17 +179,18 @@ class TruncBase(TimezoneMixin, Transform):
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField') raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
# Passing dates or times to functions expecting datetimes is most # Passing dates or times to functions expecting datetimes is most
# likely a mistake. # likely a mistake.
output_field = copy.output_field class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
explicit_output_field = field.__class__ != copy.output_field.__class__ output_field = class_output_field or copy.output_field
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
if type(field) == DateField and ( if type(field) == DateField and (
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')): isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
raise ValueError("Cannot truncate DateField '%s' to %s. " % ( raise ValueError("Cannot truncate DateField '%s' to %s. " % (
field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField' field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
)) ))
elif isinstance(field, TimeField) and ( elif isinstance(field, TimeField) and (
isinstance(output_field, DateTimeField) or copy.kind in ('year', 'quarter', 'month', 'day', 'date')): isinstance(output_field, DateTimeField) or copy.kind in ('year', 'quarter', 'month', 'day', 'date')):
raise ValueError("Cannot truncate TimeField '%s' to %s. " % ( raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
field.name, output_field.__class__.__name__ if explicit_output_field else 'DateTimeField' field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
)) ))
return copy return copy
@ -241,9 +238,7 @@ class TruncDay(TruncBase):
class TruncDate(TruncBase): class TruncDate(TruncBase):
kind = 'date' kind = 'date'
lookup_name = 'date' lookup_name = 'date'
output_field = DateField()
def __init__(self, *args, output_field=None, **kwargs):
super().__init__(*args, output_field=DateField(), **kwargs)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date. # Cast to date rather than truncate to date.
@ -256,9 +251,7 @@ class TruncDate(TruncBase):
class TruncTime(TruncBase): class TruncTime(TruncBase):
kind = 'time' kind = 'time'
lookup_name = 'time' lookup_name = 'time'
output_field = TimeField()
def __init__(self, *args, output_field=None, **kwargs):
super().__init__(*args, output_field=TimeField(), **kwargs)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date. # Cast to date rather than truncate to date.