Fixed #28394 -- Allowed setting BaseExpression.output_field (renamed from _output_field).

This commit is contained in:
Sergey Fedoseev 2017-07-15 06:56:01 +05:00 committed by Tim Graham
parent 5debbdfcc8
commit 504ce3914f
8 changed files with 49 additions and 36 deletions

View File

@ -37,7 +37,7 @@ class BoolOr(Aggregate):
class JSONBAgg(Aggregate):
function = 'JSONB_AGG'
_output_field = JSONField()
output_field = JSONField()
def convert_value(self, value, expression, connection, context):
if not value:

View File

@ -115,7 +115,7 @@ class KeyTransform(Transform):
class KeyTextTransform(KeyTransform):
operator = '->>'
nested_operator = '#>>'
_output_field = TextField()
output_field = TextField()
class KeyTransformTextLookupMixin:

View File

@ -47,7 +47,7 @@ class SearchVectorCombinable:
class SearchVector(SearchVectorCombinable, Func):
function = 'to_tsvector'
arg_joiner = " || ' ' || "
_output_field = SearchVectorField()
output_field = SearchVectorField()
config = None
def __init__(self, *expressions, **extra):
@ -125,7 +125,7 @@ class SearchQueryCombinable:
class SearchQuery(SearchQueryCombinable, Value):
_output_field = SearchQueryField()
output_field = SearchQueryField()
def __init__(self, value, output_field=None, *, config=None, invert=False):
self.config = config
@ -170,7 +170,7 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
class SearchRank(Func):
function = 'ts_rank'
_output_field = FloatField()
output_field = FloatField()
def __init__(self, vector, query, **extra):
if not hasattr(vector, 'resolve_expression'):

View File

@ -125,11 +125,11 @@ class BaseExpression:
# aggregate specific fields
is_summary = False
_output_field = None
_output_field_resolved_to_none = False
def __init__(self, output_field=None):
if output_field is not None:
self._output_field = output_field
self.output_field = output_field
def get_db_converters(self, connection):
return [self.convert_value] + self.output_field.get_db_converters(connection)
@ -223,21 +223,23 @@ class BaseExpression:
@cached_property
def output_field(self):
"""Return the output type of this expressions."""
if self._output_field_or_none is None:
raise FieldError("Cannot resolve expression type, unknown output_field")
return self._output_field_or_none
output_field = self._resolve_output_field()
if output_field is None:
self._output_field_resolved_to_none = True
raise FieldError('Cannot resolve expression type, unknown output_field')
return output_field
@cached_property
def _output_field_or_none(self):
"""
Return the output field of this expression, or None if no output type
can be resolved. Note that the 'output_field' property will raise
FieldError if no type can be resolved, but this attribute allows for
None values.
Return the output field of this expression, or None if
_resolve_output_field() didn't return an output type.
"""
if self._output_field is None:
self._output_field = self._resolve_output_field()
return self._output_field
try:
return self.output_field
except FieldError:
if not self._output_field_resolved_to_none:
raise
def _resolve_output_field(self):
"""
@ -249,9 +251,9 @@ class BaseExpression:
the type here is a convenience for the common case. The user should
supply their own output_field with more complex computations.
If a source does not have an `_output_field` then we exclude it from
this check. If all sources are `None`, then an error will be thrown
higher up the stack in the `output_field` property.
If a source's output field resolves to None, exclude it from this check.
If all sources are None, then an error is raised higher up the stack in
the output_field property.
"""
sources_iter = (source for source in self.get_source_fields() if source is not None)
for output_field in sources_iter:
@ -603,14 +605,14 @@ class Value(Expression):
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
val = self.value
# check _output_field to avoid triggering an exception
if self._output_field is not None:
output_field = self._output_field_or_none
if output_field is not None:
if self.for_save:
val = self.output_field.get_db_prep_save(val, connection=connection)
val = output_field.get_db_prep_save(val, connection=connection)
else:
val = self.output_field.get_db_prep_value(val, connection=connection)
if hasattr(self._output_field, 'get_placeholder'):
return self._output_field.get_placeholder(val, compiler, connection), [val]
val = output_field.get_db_prep_value(val, connection=connection)
if hasattr(output_field, 'get_placeholder'):
return output_field.get_placeholder(val, compiler, connection), [val]
if val is None:
# cx_Oracle does not always convert None to the appropriate
# NULL type (like in case expressions using numbers), so we
@ -652,7 +654,7 @@ class RawSQL(Expression):
return [self]
def __hash__(self):
h = hash(self.sql) ^ hash(self._output_field)
h = hash(self.sql) ^ hash(self.output_field)
for param in self.params:
h ^= hash(param)
return h
@ -998,7 +1000,7 @@ class Exists(Subquery):
super().__init__(*args, **kwargs)
def __invert__(self):
return type(self)(self.queryset, self.output_field, negated=(not self.negated), **self.extra)
return type(self)(self.queryset, negated=(not self.negated), **self.extra)
@property
def output_field(self):

View File

@ -24,12 +24,12 @@ class Cast(Func):
def as_sql(self, compiler, connection, **extra_context):
if 'db_type' not in extra_context:
extra_context['db_type'] = self._output_field.db_type(connection)
extra_context['db_type'] = self.output_field.db_type(connection)
return super().as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection):
extra_context = {}
output_field_class = type(self._output_field)
output_field_class = type(self.output_field)
if output_field_class in self.mysql_types:
extra_context['db_type'] = self.mysql_types[output_field_class]
return self.as_sql(compiler, connection, **extra_context)

View File

@ -243,9 +243,8 @@ class TruncDate(TruncBase):
kind = 'date'
lookup_name = 'date'
@cached_property
def output_field(self):
return DateField()
def __init__(self, *args, output_field=None, **kwargs):
super().__init__(*args, output_field=DateField(), **kwargs)
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.
@ -259,9 +258,8 @@ class TruncTime(TruncBase):
kind = 'time'
lookup_name = 'time'
@cached_property
def output_field(self):
return TimeField()
def __init__(self, *args, output_field=None, **kwargs):
super().__init__(*args, output_field=TimeField(), **kwargs)
def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date.

View File

@ -551,6 +551,9 @@ Miscellaneous
in the cache backend as an intermediate class in ``CacheKeyWarning``'s
inheritance of ``RuntimeWarning``.
* Renamed ``BaseExpression._output_field`` to ``output_field``. You may need
to update custom expressions.
.. _deprecated-features-2.0:
Features deprecated in 2.0

View File

@ -532,6 +532,16 @@ class BasicExpressionsTests(TestCase):
outer = Company.objects.filter(pk__in=Subquery(inner.values('pk')))
self.assertFalse(outer.exists())
def test_explicit_output_field(self):
class FuncA(Func):
output_field = models.CharField()
class FuncB(Func):
pass
expr = FuncB(FuncA())
self.assertEqual(expr.output_field, FuncA.output_field)
class IterableLookupInnerExpressionsTests(TestCase):
@classmethod