Fixed #28394 -- Allowed setting BaseExpression.output_field (renamed from _output_field).

This commit is contained in:
Sergey Fedoseev 2017-07-15 06:56:01 +05:00 committed by Tim Graham
parent 5debbdfcc8
commit 504ce3914f
8 changed files with 49 additions and 36 deletions

View File

@ -37,7 +37,7 @@ class BoolOr(Aggregate):
class JSONBAgg(Aggregate): class JSONBAgg(Aggregate):
function = 'JSONB_AGG' function = 'JSONB_AGG'
_output_field = JSONField() output_field = JSONField()
def convert_value(self, value, expression, connection, context): def convert_value(self, value, expression, connection, context):
if not value: if not value:

View File

@ -115,7 +115,7 @@ class KeyTransform(Transform):
class KeyTextTransform(KeyTransform): class KeyTextTransform(KeyTransform):
operator = '->>' operator = '->>'
nested_operator = '#>>' nested_operator = '#>>'
_output_field = TextField() output_field = TextField()
class KeyTransformTextLookupMixin: class KeyTransformTextLookupMixin:

View File

@ -47,7 +47,7 @@ class SearchVectorCombinable:
class SearchVector(SearchVectorCombinable, Func): class SearchVector(SearchVectorCombinable, Func):
function = 'to_tsvector' function = 'to_tsvector'
arg_joiner = " || ' ' || " arg_joiner = " || ' ' || "
_output_field = SearchVectorField() output_field = SearchVectorField()
config = None config = None
def __init__(self, *expressions, **extra): def __init__(self, *expressions, **extra):
@ -125,7 +125,7 @@ class SearchQueryCombinable:
class SearchQuery(SearchQueryCombinable, Value): class SearchQuery(SearchQueryCombinable, Value):
_output_field = SearchQueryField() output_field = SearchQueryField()
def __init__(self, value, output_field=None, *, config=None, invert=False): def __init__(self, value, output_field=None, *, config=None, invert=False):
self.config = config self.config = config
@ -170,7 +170,7 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
class SearchRank(Func): class SearchRank(Func):
function = 'ts_rank' function = 'ts_rank'
_output_field = FloatField() output_field = FloatField()
def __init__(self, vector, query, **extra): def __init__(self, vector, query, **extra):
if not hasattr(vector, 'resolve_expression'): if not hasattr(vector, 'resolve_expression'):

View File

@ -125,11 +125,11 @@ class BaseExpression:
# aggregate specific fields # aggregate specific fields
is_summary = False is_summary = False
_output_field = None _output_field_resolved_to_none = False
def __init__(self, output_field=None): def __init__(self, output_field=None):
if output_field is not None: if output_field is not None:
self._output_field = output_field self.output_field = output_field
def get_db_converters(self, connection): def get_db_converters(self, connection):
return [self.convert_value] + self.output_field.get_db_converters(connection) return [self.convert_value] + self.output_field.get_db_converters(connection)
@ -223,21 +223,23 @@ class BaseExpression:
@cached_property @cached_property
def output_field(self): def output_field(self):
"""Return the output type of this expressions.""" """Return the output type of this expressions."""
if self._output_field_or_none is None: output_field = self._resolve_output_field()
raise FieldError("Cannot resolve expression type, unknown output_field") if output_field is None:
return self._output_field_or_none self._output_field_resolved_to_none = True
raise FieldError('Cannot resolve expression type, unknown output_field')
return output_field
@cached_property @cached_property
def _output_field_or_none(self): def _output_field_or_none(self):
""" """
Return the output field of this expression, or None if no output type Return the output field of this expression, or None if
can be resolved. Note that the 'output_field' property will raise _resolve_output_field() didn't return an output type.
FieldError if no type can be resolved, but this attribute allows for
None values.
""" """
if self._output_field is None: try:
self._output_field = self._resolve_output_field() return self.output_field
return self._output_field except FieldError:
if not self._output_field_resolved_to_none:
raise
def _resolve_output_field(self): def _resolve_output_field(self):
""" """
@ -249,9 +251,9 @@ class BaseExpression:
the type here is a convenience for the common case. The user should the type here is a convenience for the common case. The user should
supply their own output_field with more complex computations. supply their own output_field with more complex computations.
If a source does not have an `_output_field` then we exclude it from If a source's output field resolves to None, exclude it from this check.
this check. If all sources are `None`, then an error will be thrown If all sources are None, then an error is raised higher up the stack in
higher up the stack in the `output_field` property. the output_field property.
""" """
sources_iter = (source for source in self.get_source_fields() if source is not None) sources_iter = (source for source in self.get_source_fields() if source is not None)
for output_field in sources_iter: for output_field in sources_iter:
@ -603,14 +605,14 @@ class Value(Expression):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
val = self.value val = self.value
# check _output_field to avoid triggering an exception output_field = self._output_field_or_none
if self._output_field is not None: if output_field is not None:
if self.for_save: if self.for_save:
val = self.output_field.get_db_prep_save(val, connection=connection) val = output_field.get_db_prep_save(val, connection=connection)
else: else:
val = self.output_field.get_db_prep_value(val, connection=connection) val = output_field.get_db_prep_value(val, connection=connection)
if hasattr(self._output_field, 'get_placeholder'): if hasattr(output_field, 'get_placeholder'):
return self._output_field.get_placeholder(val, compiler, connection), [val] return output_field.get_placeholder(val, compiler, connection), [val]
if val is None: if val is None:
# cx_Oracle does not always convert None to the appropriate # cx_Oracle does not always convert None to the appropriate
# NULL type (like in case expressions using numbers), so we # NULL type (like in case expressions using numbers), so we
@ -652,7 +654,7 @@ class RawSQL(Expression):
return [self] return [self]
def __hash__(self): def __hash__(self):
h = hash(self.sql) ^ hash(self._output_field) h = hash(self.sql) ^ hash(self.output_field)
for param in self.params: for param in self.params:
h ^= hash(param) h ^= hash(param)
return h return h
@ -998,7 +1000,7 @@ class Exists(Subquery):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def __invert__(self): def __invert__(self):
return type(self)(self.queryset, self.output_field, negated=(not self.negated), **self.extra) return type(self)(self.queryset, negated=(not self.negated), **self.extra)
@property @property
def output_field(self): def output_field(self):

View File

@ -24,12 +24,12 @@ class Cast(Func):
def as_sql(self, compiler, connection, **extra_context): def as_sql(self, compiler, connection, **extra_context):
if 'db_type' not in extra_context: if 'db_type' not in extra_context:
extra_context['db_type'] = self._output_field.db_type(connection) extra_context['db_type'] = self.output_field.db_type(connection)
return super().as_sql(compiler, connection, **extra_context) return super().as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection):
extra_context = {} extra_context = {}
output_field_class = type(self._output_field) output_field_class = type(self.output_field)
if output_field_class in self.mysql_types: if output_field_class in self.mysql_types:
extra_context['db_type'] = self.mysql_types[output_field_class] extra_context['db_type'] = self.mysql_types[output_field_class]
return self.as_sql(compiler, connection, **extra_context) return self.as_sql(compiler, connection, **extra_context)

View File

@ -243,9 +243,8 @@ class TruncDate(TruncBase):
kind = 'date' kind = 'date'
lookup_name = 'date' lookup_name = 'date'
@cached_property def __init__(self, *args, output_field=None, **kwargs):
def output_field(self): super().__init__(*args, output_field=DateField(), **kwargs)
return DateField()
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date. # Cast to date rather than truncate to date.
@ -259,9 +258,8 @@ class TruncTime(TruncBase):
kind = 'time' kind = 'time'
lookup_name = 'time' lookup_name = 'time'
@cached_property def __init__(self, *args, output_field=None, **kwargs):
def output_field(self): super().__init__(*args, output_field=TimeField(), **kwargs)
return TimeField()
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# Cast to date rather than truncate to date. # Cast to date rather than truncate to date.

View File

@ -551,6 +551,9 @@ Miscellaneous
in the cache backend as an intermediate class in ``CacheKeyWarning``'s in the cache backend as an intermediate class in ``CacheKeyWarning``'s
inheritance of ``RuntimeWarning``. inheritance of ``RuntimeWarning``.
* Renamed ``BaseExpression._output_field`` to ``output_field``. You may need
to update custom expressions.
.. _deprecated-features-2.0: .. _deprecated-features-2.0:
Features deprecated in 2.0 Features deprecated in 2.0

View File

@ -532,6 +532,16 @@ class BasicExpressionsTests(TestCase):
outer = Company.objects.filter(pk__in=Subquery(inner.values('pk'))) outer = Company.objects.filter(pk__in=Subquery(inner.values('pk')))
self.assertFalse(outer.exists()) self.assertFalse(outer.exists())
def test_explicit_output_field(self):
class FuncA(Func):
output_field = models.CharField()
class FuncB(Func):
pass
expr = FuncB(FuncA())
self.assertEqual(expr.output_field, FuncA.output_field)
class IterableLookupInnerExpressionsTests(TestCase): class IterableLookupInnerExpressionsTests(TestCase):
@classmethod @classmethod