[3.1.x] Fixed #31773 -- Fixed preserving output_field in ExpressionWrapper for combined expressions.

Thanks Thodoris Sotiropoulos for the report and Simon Charette for the
implementation idea.

Regression in df32fd42b8.
Backport of 8a6df55f2d from master
This commit is contained in:
Mariusz Felisiak 2020-07-09 11:55:03 +02:00
parent b19372952c
commit e6285cac83
2 changed files with 13 additions and 6 deletions

View File

@ -857,9 +857,6 @@ class ExpressionWrapper(Expression):
def __init__(self, expression, output_field): def __init__(self, expression, output_field):
super().__init__(output_field=output_field) super().__init__(output_field=output_field)
if getattr(expression, '_output_field_or_none', True) is None:
expression = expression.copy()
expression.output_field = output_field
self.expression = expression self.expression = expression
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
@ -869,7 +866,9 @@ class ExpressionWrapper(Expression):
return [self.expression] return [self.expression]
def get_group_by_cols(self, alias=None): def get_group_by_cols(self, alias=None):
return self.expression.get_group_by_cols(alias=alias) expression = self.expression.copy()
expression.output_field = self.output_field
return expression.get_group_by_cols(alias=alias)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return self.expression.as_sql(compiler, connection) return self.expression.as_sql(compiler, connection)

View File

@ -6,8 +6,8 @@ from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db import connection from django.db import connection
from django.db.models import ( from django.db.models import (
BooleanField, Case, CharField, Count, DateTimeField, Exists, BooleanField, Case, CharField, Count, DateTimeField, Exists,
ExpressionWrapper, F, Func, IntegerField, Max, NullBooleanField, OuterRef, ExpressionWrapper, F, FloatField, Func, IntegerField, Max,
Q, Subquery, Sum, Value, When, NullBooleanField, OuterRef, Q, Subquery, Sum, Value, When,
) )
from django.db.models.expressions import RawSQL from django.db.models.expressions import RawSQL
from django.db.models.functions import Length, Lower from django.db.models.functions import Length, Lower
@ -178,6 +178,14 @@ class NonAggregateAnnotationTestCase(TestCase):
self.assertEqual(book.combined, 12) self.assertEqual(book.combined, 12)
self.assertEqual(book.rating_count, 1) self.assertEqual(book.rating_count, 1)
def test_combined_f_expression_annotation_with_aggregation(self):
book = Book.objects.filter(isbn='159059725').annotate(
combined=ExpressionWrapper(F('price') * F('pages'), output_field=FloatField()),
rating_count=Count('rating'),
).first()
self.assertEqual(book.combined, 13410.0)
self.assertEqual(book.rating_count, 1)
def test_aggregate_over_annotation(self): def test_aggregate_over_annotation(self):
agg = Author.objects.annotate(other_age=F('age')).aggregate(otherage_sum=Sum('other_age')) agg = Author.objects.annotate(other_age=F('age')).aggregate(otherage_sum=Sum('other_age'))
other_agg = Author.objects.aggregate(age_sum=Sum('age')) other_agg = Author.objects.aggregate(age_sum=Sum('age'))