Fixed #28394 -- Allowed setting BaseExpression.output_field (renamed from _output_field).
This commit is contained in:
parent
5debbdfcc8
commit
504ce3914f
|
@ -37,7 +37,7 @@ class BoolOr(Aggregate):
|
||||||
|
|
||||||
class JSONBAgg(Aggregate):
|
class JSONBAgg(Aggregate):
|
||||||
function = 'JSONB_AGG'
|
function = 'JSONB_AGG'
|
||||||
_output_field = JSONField()
|
output_field = JSONField()
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection, context):
|
def convert_value(self, value, expression, connection, context):
|
||||||
if not value:
|
if not value:
|
||||||
|
|
|
@ -115,7 +115,7 @@ class KeyTransform(Transform):
|
||||||
class KeyTextTransform(KeyTransform):
|
class KeyTextTransform(KeyTransform):
|
||||||
operator = '->>'
|
operator = '->>'
|
||||||
nested_operator = '#>>'
|
nested_operator = '#>>'
|
||||||
_output_field = TextField()
|
output_field = TextField()
|
||||||
|
|
||||||
|
|
||||||
class KeyTransformTextLookupMixin:
|
class KeyTransformTextLookupMixin:
|
||||||
|
|
|
@ -47,7 +47,7 @@ class SearchVectorCombinable:
|
||||||
class SearchVector(SearchVectorCombinable, Func):
|
class SearchVector(SearchVectorCombinable, Func):
|
||||||
function = 'to_tsvector'
|
function = 'to_tsvector'
|
||||||
arg_joiner = " || ' ' || "
|
arg_joiner = " || ' ' || "
|
||||||
_output_field = SearchVectorField()
|
output_field = SearchVectorField()
|
||||||
config = None
|
config = None
|
||||||
|
|
||||||
def __init__(self, *expressions, **extra):
|
def __init__(self, *expressions, **extra):
|
||||||
|
@ -125,7 +125,7 @@ class SearchQueryCombinable:
|
||||||
|
|
||||||
|
|
||||||
class SearchQuery(SearchQueryCombinable, Value):
|
class SearchQuery(SearchQueryCombinable, Value):
|
||||||
_output_field = SearchQueryField()
|
output_field = SearchQueryField()
|
||||||
|
|
||||||
def __init__(self, value, output_field=None, *, config=None, invert=False):
|
def __init__(self, value, output_field=None, *, config=None, invert=False):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -170,7 +170,7 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
|
||||||
|
|
||||||
class SearchRank(Func):
|
class SearchRank(Func):
|
||||||
function = 'ts_rank'
|
function = 'ts_rank'
|
||||||
_output_field = FloatField()
|
output_field = FloatField()
|
||||||
|
|
||||||
def __init__(self, vector, query, **extra):
|
def __init__(self, vector, query, **extra):
|
||||||
if not hasattr(vector, 'resolve_expression'):
|
if not hasattr(vector, 'resolve_expression'):
|
||||||
|
|
|
@ -125,11 +125,11 @@ class BaseExpression:
|
||||||
|
|
||||||
# aggregate specific fields
|
# aggregate specific fields
|
||||||
is_summary = False
|
is_summary = False
|
||||||
_output_field = None
|
_output_field_resolved_to_none = False
|
||||||
|
|
||||||
def __init__(self, output_field=None):
|
def __init__(self, output_field=None):
|
||||||
if output_field is not None:
|
if output_field is not None:
|
||||||
self._output_field = output_field
|
self.output_field = output_field
|
||||||
|
|
||||||
def get_db_converters(self, connection):
|
def get_db_converters(self, connection):
|
||||||
return [self.convert_value] + self.output_field.get_db_converters(connection)
|
return [self.convert_value] + self.output_field.get_db_converters(connection)
|
||||||
|
@ -223,21 +223,23 @@ class BaseExpression:
|
||||||
@cached_property
|
@cached_property
|
||||||
def output_field(self):
|
def output_field(self):
|
||||||
"""Return the output type of this expressions."""
|
"""Return the output type of this expressions."""
|
||||||
if self._output_field_or_none is None:
|
output_field = self._resolve_output_field()
|
||||||
raise FieldError("Cannot resolve expression type, unknown output_field")
|
if output_field is None:
|
||||||
return self._output_field_or_none
|
self._output_field_resolved_to_none = True
|
||||||
|
raise FieldError('Cannot resolve expression type, unknown output_field')
|
||||||
|
return output_field
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def _output_field_or_none(self):
|
def _output_field_or_none(self):
|
||||||
"""
|
"""
|
||||||
Return the output field of this expression, or None if no output type
|
Return the output field of this expression, or None if
|
||||||
can be resolved. Note that the 'output_field' property will raise
|
_resolve_output_field() didn't return an output type.
|
||||||
FieldError if no type can be resolved, but this attribute allows for
|
|
||||||
None values.
|
|
||||||
"""
|
"""
|
||||||
if self._output_field is None:
|
try:
|
||||||
self._output_field = self._resolve_output_field()
|
return self.output_field
|
||||||
return self._output_field
|
except FieldError:
|
||||||
|
if not self._output_field_resolved_to_none:
|
||||||
|
raise
|
||||||
|
|
||||||
def _resolve_output_field(self):
|
def _resolve_output_field(self):
|
||||||
"""
|
"""
|
||||||
|
@ -249,9 +251,9 @@ class BaseExpression:
|
||||||
the type here is a convenience for the common case. The user should
|
the type here is a convenience for the common case. The user should
|
||||||
supply their own output_field with more complex computations.
|
supply their own output_field with more complex computations.
|
||||||
|
|
||||||
If a source does not have an `_output_field` then we exclude it from
|
If a source's output field resolves to None, exclude it from this check.
|
||||||
this check. If all sources are `None`, then an error will be thrown
|
If all sources are None, then an error is raised higher up the stack in
|
||||||
higher up the stack in the `output_field` property.
|
the output_field property.
|
||||||
"""
|
"""
|
||||||
sources_iter = (source for source in self.get_source_fields() if source is not None)
|
sources_iter = (source for source in self.get_source_fields() if source is not None)
|
||||||
for output_field in sources_iter:
|
for output_field in sources_iter:
|
||||||
|
@ -603,14 +605,14 @@ class Value(Expression):
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
val = self.value
|
val = self.value
|
||||||
# check _output_field to avoid triggering an exception
|
output_field = self._output_field_or_none
|
||||||
if self._output_field is not None:
|
if output_field is not None:
|
||||||
if self.for_save:
|
if self.for_save:
|
||||||
val = self.output_field.get_db_prep_save(val, connection=connection)
|
val = output_field.get_db_prep_save(val, connection=connection)
|
||||||
else:
|
else:
|
||||||
val = self.output_field.get_db_prep_value(val, connection=connection)
|
val = output_field.get_db_prep_value(val, connection=connection)
|
||||||
if hasattr(self._output_field, 'get_placeholder'):
|
if hasattr(output_field, 'get_placeholder'):
|
||||||
return self._output_field.get_placeholder(val, compiler, connection), [val]
|
return output_field.get_placeholder(val, compiler, connection), [val]
|
||||||
if val is None:
|
if val is None:
|
||||||
# cx_Oracle does not always convert None to the appropriate
|
# cx_Oracle does not always convert None to the appropriate
|
||||||
# NULL type (like in case expressions using numbers), so we
|
# NULL type (like in case expressions using numbers), so we
|
||||||
|
@ -652,7 +654,7 @@ class RawSQL(Expression):
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
h = hash(self.sql) ^ hash(self._output_field)
|
h = hash(self.sql) ^ hash(self.output_field)
|
||||||
for param in self.params:
|
for param in self.params:
|
||||||
h ^= hash(param)
|
h ^= hash(param)
|
||||||
return h
|
return h
|
||||||
|
@ -998,7 +1000,7 @@ class Exists(Subquery):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def __invert__(self):
|
def __invert__(self):
|
||||||
return type(self)(self.queryset, self.output_field, negated=(not self.negated), **self.extra)
|
return type(self)(self.queryset, negated=(not self.negated), **self.extra)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_field(self):
|
def output_field(self):
|
||||||
|
|
|
@ -24,12 +24,12 @@ class Cast(Func):
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, **extra_context):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
if 'db_type' not in extra_context:
|
if 'db_type' not in extra_context:
|
||||||
extra_context['db_type'] = self._output_field.db_type(connection)
|
extra_context['db_type'] = self.output_field.db_type(connection)
|
||||||
return super().as_sql(compiler, connection, **extra_context)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection):
|
||||||
extra_context = {}
|
extra_context = {}
|
||||||
output_field_class = type(self._output_field)
|
output_field_class = type(self.output_field)
|
||||||
if output_field_class in self.mysql_types:
|
if output_field_class in self.mysql_types:
|
||||||
extra_context['db_type'] = self.mysql_types[output_field_class]
|
extra_context['db_type'] = self.mysql_types[output_field_class]
|
||||||
return self.as_sql(compiler, connection, **extra_context)
|
return self.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
|
@ -243,9 +243,8 @@ class TruncDate(TruncBase):
|
||||||
kind = 'date'
|
kind = 'date'
|
||||||
lookup_name = 'date'
|
lookup_name = 'date'
|
||||||
|
|
||||||
@cached_property
|
def __init__(self, *args, output_field=None, **kwargs):
|
||||||
def output_field(self):
|
super().__init__(*args, output_field=DateField(), **kwargs)
|
||||||
return DateField()
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
# Cast to date rather than truncate to date.
|
# Cast to date rather than truncate to date.
|
||||||
|
@ -259,9 +258,8 @@ class TruncTime(TruncBase):
|
||||||
kind = 'time'
|
kind = 'time'
|
||||||
lookup_name = 'time'
|
lookup_name = 'time'
|
||||||
|
|
||||||
@cached_property
|
def __init__(self, *args, output_field=None, **kwargs):
|
||||||
def output_field(self):
|
super().__init__(*args, output_field=TimeField(), **kwargs)
|
||||||
return TimeField()
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
# Cast to date rather than truncate to date.
|
# Cast to date rather than truncate to date.
|
||||||
|
|
|
@ -551,6 +551,9 @@ Miscellaneous
|
||||||
in the cache backend as an intermediate class in ``CacheKeyWarning``'s
|
in the cache backend as an intermediate class in ``CacheKeyWarning``'s
|
||||||
inheritance of ``RuntimeWarning``.
|
inheritance of ``RuntimeWarning``.
|
||||||
|
|
||||||
|
* Renamed ``BaseExpression._output_field`` to ``output_field``. You may need
|
||||||
|
to update custom expressions.
|
||||||
|
|
||||||
.. _deprecated-features-2.0:
|
.. _deprecated-features-2.0:
|
||||||
|
|
||||||
Features deprecated in 2.0
|
Features deprecated in 2.0
|
||||||
|
|
|
@ -532,6 +532,16 @@ class BasicExpressionsTests(TestCase):
|
||||||
outer = Company.objects.filter(pk__in=Subquery(inner.values('pk')))
|
outer = Company.objects.filter(pk__in=Subquery(inner.values('pk')))
|
||||||
self.assertFalse(outer.exists())
|
self.assertFalse(outer.exists())
|
||||||
|
|
||||||
|
def test_explicit_output_field(self):
|
||||||
|
class FuncA(Func):
|
||||||
|
output_field = models.CharField()
|
||||||
|
|
||||||
|
class FuncB(Func):
|
||||||
|
pass
|
||||||
|
|
||||||
|
expr = FuncB(FuncA())
|
||||||
|
self.assertEqual(expr.output_field, FuncA.output_field)
|
||||||
|
|
||||||
|
|
||||||
class IterableLookupInnerExpressionsTests(TestCase):
|
class IterableLookupInnerExpressionsTests(TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue