mirror of https://github.com/django/django.git
Fixed #28394 -- Allowed setting BaseExpression.output_field (renamed from _output_field).
This commit is contained in:
parent
5debbdfcc8
commit
504ce3914f
|
@ -37,7 +37,7 @@ class BoolOr(Aggregate):
|
|||
|
||||
class JSONBAgg(Aggregate):
|
||||
function = 'JSONB_AGG'
|
||||
_output_field = JSONField()
|
||||
output_field = JSONField()
|
||||
|
||||
def convert_value(self, value, expression, connection, context):
|
||||
if not value:
|
||||
|
|
|
@ -115,7 +115,7 @@ class KeyTransform(Transform):
|
|||
class KeyTextTransform(KeyTransform):
|
||||
operator = '->>'
|
||||
nested_operator = '#>>'
|
||||
_output_field = TextField()
|
||||
output_field = TextField()
|
||||
|
||||
|
||||
class KeyTransformTextLookupMixin:
|
||||
|
|
|
@ -47,7 +47,7 @@ class SearchVectorCombinable:
|
|||
class SearchVector(SearchVectorCombinable, Func):
|
||||
function = 'to_tsvector'
|
||||
arg_joiner = " || ' ' || "
|
||||
_output_field = SearchVectorField()
|
||||
output_field = SearchVectorField()
|
||||
config = None
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
|
@ -125,7 +125,7 @@ class SearchQueryCombinable:
|
|||
|
||||
|
||||
class SearchQuery(SearchQueryCombinable, Value):
|
||||
_output_field = SearchQueryField()
|
||||
output_field = SearchQueryField()
|
||||
|
||||
def __init__(self, value, output_field=None, *, config=None, invert=False):
|
||||
self.config = config
|
||||
|
@ -170,7 +170,7 @@ class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
|
|||
|
||||
class SearchRank(Func):
|
||||
function = 'ts_rank'
|
||||
_output_field = FloatField()
|
||||
output_field = FloatField()
|
||||
|
||||
def __init__(self, vector, query, **extra):
|
||||
if not hasattr(vector, 'resolve_expression'):
|
||||
|
|
|
@ -125,11 +125,11 @@ class BaseExpression:
|
|||
|
||||
# aggregate specific fields
|
||||
is_summary = False
|
||||
_output_field = None
|
||||
_output_field_resolved_to_none = False
|
||||
|
||||
def __init__(self, output_field=None):
|
||||
if output_field is not None:
|
||||
self._output_field = output_field
|
||||
self.output_field = output_field
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return [self.convert_value] + self.output_field.get_db_converters(connection)
|
||||
|
@ -223,21 +223,23 @@ class BaseExpression:
|
|||
@cached_property
|
||||
def output_field(self):
|
||||
"""Return the output type of this expressions."""
|
||||
if self._output_field_or_none is None:
|
||||
raise FieldError("Cannot resolve expression type, unknown output_field")
|
||||
return self._output_field_or_none
|
||||
output_field = self._resolve_output_field()
|
||||
if output_field is None:
|
||||
self._output_field_resolved_to_none = True
|
||||
raise FieldError('Cannot resolve expression type, unknown output_field')
|
||||
return output_field
|
||||
|
||||
@cached_property
|
||||
def _output_field_or_none(self):
|
||||
"""
|
||||
Return the output field of this expression, or None if no output type
|
||||
can be resolved. Note that the 'output_field' property will raise
|
||||
FieldError if no type can be resolved, but this attribute allows for
|
||||
None values.
|
||||
Return the output field of this expression, or None if
|
||||
_resolve_output_field() didn't return an output type.
|
||||
"""
|
||||
if self._output_field is None:
|
||||
self._output_field = self._resolve_output_field()
|
||||
return self._output_field
|
||||
try:
|
||||
return self.output_field
|
||||
except FieldError:
|
||||
if not self._output_field_resolved_to_none:
|
||||
raise
|
||||
|
||||
def _resolve_output_field(self):
|
||||
"""
|
||||
|
@ -249,9 +251,9 @@ class BaseExpression:
|
|||
the type here is a convenience for the common case. The user should
|
||||
supply their own output_field with more complex computations.
|
||||
|
||||
If a source does not have an `_output_field` then we exclude it from
|
||||
this check. If all sources are `None`, then an error will be thrown
|
||||
higher up the stack in the `output_field` property.
|
||||
If a source's output field resolves to None, exclude it from this check.
|
||||
If all sources are None, then an error is raised higher up the stack in
|
||||
the output_field property.
|
||||
"""
|
||||
sources_iter = (source for source in self.get_source_fields() if source is not None)
|
||||
for output_field in sources_iter:
|
||||
|
@ -603,14 +605,14 @@ class Value(Expression):
|
|||
def as_sql(self, compiler, connection):
|
||||
connection.ops.check_expression_support(self)
|
||||
val = self.value
|
||||
# check _output_field to avoid triggering an exception
|
||||
if self._output_field is not None:
|
||||
output_field = self._output_field_or_none
|
||||
if output_field is not None:
|
||||
if self.for_save:
|
||||
val = self.output_field.get_db_prep_save(val, connection=connection)
|
||||
val = output_field.get_db_prep_save(val, connection=connection)
|
||||
else:
|
||||
val = self.output_field.get_db_prep_value(val, connection=connection)
|
||||
if hasattr(self._output_field, 'get_placeholder'):
|
||||
return self._output_field.get_placeholder(val, compiler, connection), [val]
|
||||
val = output_field.get_db_prep_value(val, connection=connection)
|
||||
if hasattr(output_field, 'get_placeholder'):
|
||||
return output_field.get_placeholder(val, compiler, connection), [val]
|
||||
if val is None:
|
||||
# cx_Oracle does not always convert None to the appropriate
|
||||
# NULL type (like in case expressions using numbers), so we
|
||||
|
@ -652,7 +654,7 @@ class RawSQL(Expression):
|
|||
return [self]
|
||||
|
||||
def __hash__(self):
|
||||
h = hash(self.sql) ^ hash(self._output_field)
|
||||
h = hash(self.sql) ^ hash(self.output_field)
|
||||
for param in self.params:
|
||||
h ^= hash(param)
|
||||
return h
|
||||
|
@ -998,7 +1000,7 @@ class Exists(Subquery):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __invert__(self):
|
||||
return type(self)(self.queryset, self.output_field, negated=(not self.negated), **self.extra)
|
||||
return type(self)(self.queryset, negated=(not self.negated), **self.extra)
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
|
|
|
@ -24,12 +24,12 @@ class Cast(Func):
|
|||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if 'db_type' not in extra_context:
|
||||
extra_context['db_type'] = self._output_field.db_type(connection)
|
||||
extra_context['db_type'] = self.output_field.db_type(connection)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
extra_context = {}
|
||||
output_field_class = type(self._output_field)
|
||||
output_field_class = type(self.output_field)
|
||||
if output_field_class in self.mysql_types:
|
||||
extra_context['db_type'] = self.mysql_types[output_field_class]
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
|
|
@ -243,9 +243,8 @@ class TruncDate(TruncBase):
|
|||
kind = 'date'
|
||||
lookup_name = 'date'
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return DateField()
|
||||
def __init__(self, *args, output_field=None, **kwargs):
|
||||
super().__init__(*args, output_field=DateField(), **kwargs)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to date rather than truncate to date.
|
||||
|
@ -259,9 +258,8 @@ class TruncTime(TruncBase):
|
|||
kind = 'time'
|
||||
lookup_name = 'time'
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return TimeField()
|
||||
def __init__(self, *args, output_field=None, **kwargs):
|
||||
super().__init__(*args, output_field=TimeField(), **kwargs)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Cast to date rather than truncate to date.
|
||||
|
|
|
@ -551,6 +551,9 @@ Miscellaneous
|
|||
in the cache backend as an intermediate class in ``CacheKeyWarning``'s
|
||||
inheritance of ``RuntimeWarning``.
|
||||
|
||||
* Renamed ``BaseExpression._output_field`` to ``output_field``. You may need
|
||||
to update custom expressions.
|
||||
|
||||
.. _deprecated-features-2.0:
|
||||
|
||||
Features deprecated in 2.0
|
||||
|
|
|
@ -532,6 +532,16 @@ class BasicExpressionsTests(TestCase):
|
|||
outer = Company.objects.filter(pk__in=Subquery(inner.values('pk')))
|
||||
self.assertFalse(outer.exists())
|
||||
|
||||
def test_explicit_output_field(self):
|
||||
class FuncA(Func):
|
||||
output_field = models.CharField()
|
||||
|
||||
class FuncB(Func):
|
||||
pass
|
||||
|
||||
expr = FuncB(FuncA())
|
||||
self.assertEqual(expr.output_field, FuncA.output_field)
|
||||
|
||||
|
||||
class IterableLookupInnerExpressionsTests(TestCase):
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in New Issue