Fixed #34362 -- Fixed FilteredRelation() crash on conditional expressions.

Thanks zhu for the report and Simon Charette for reviews.
This commit is contained in:
Francesco Panico 2023-03-04 17:59:53 +00:00 committed by Mariusz Felisiak
parent 1506f498fe
commit 59f4754704
2 changed files with 123 additions and 7 deletions

View File

@ -65,21 +65,53 @@ def get_field_names_from_opts(opts):
) )
def get_paths_from_expression(expr):
if isinstance(expr, F):
yield expr.name
elif hasattr(expr, "flatten"):
for child in expr.flatten():
if isinstance(child, F):
yield child.name
elif isinstance(child, Q):
yield from get_children_from_q(child)
def get_children_from_q(q): def get_children_from_q(q):
for child in q.children: for child in q.children:
if isinstance(child, Node): if isinstance(child, Node):
yield from get_children_from_q(child) yield from get_children_from_q(child)
else: elif isinstance(child, tuple):
yield child lhs, rhs = child
yield lhs
if hasattr(rhs, "resolve_expression"):
yield from get_paths_from_expression(rhs)
elif hasattr(child, "resolve_expression"):
yield from get_paths_from_expression(child)
def get_child_with_renamed_prefix(prefix, replacement, child): def get_child_with_renamed_prefix(prefix, replacement, child):
if isinstance(child, Node): if isinstance(child, Node):
return rename_prefix_from_q(prefix, replacement, child) return rename_prefix_from_q(prefix, replacement, child)
if isinstance(child, tuple):
lhs, rhs = child lhs, rhs = child
lhs = lhs.replace(prefix, replacement, 1) lhs = lhs.replace(prefix, replacement, 1)
if not isinstance(rhs, F) and hasattr(rhs, "resolve_expression"):
rhs = get_child_with_renamed_prefix(prefix, replacement, rhs)
return lhs, rhs return lhs, rhs
if isinstance(child, F):
child = child.copy()
child.name = child.name.replace(prefix, replacement, 1)
elif hasattr(child, "resolve_expression"):
child = child.copy()
child.set_source_expressions(
[
get_child_with_renamed_prefix(prefix, replacement, grand_child)
for grand_child in child.get_source_expressions()
]
)
return child
def rename_prefix_from_q(prefix, replacement, q): def rename_prefix_from_q(prefix, replacement, q):
return Q.create( return Q.create(
@ -1618,7 +1650,6 @@ class Query(BaseExpression):
def add_filtered_relation(self, filtered_relation, alias): def add_filtered_relation(self, filtered_relation, alias):
filtered_relation.alias = alias filtered_relation.alias = alias
lookups = dict(get_children_from_q(filtered_relation.condition))
relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type( relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
filtered_relation.relation_name filtered_relation.relation_name
) )
@ -1627,7 +1658,7 @@ class Query(BaseExpression):
"FilteredRelation's relation_name cannot contain lookups " "FilteredRelation's relation_name cannot contain lookups "
"(got %r)." % filtered_relation.relation_name "(got %r)." % filtered_relation.relation_name
) )
for lookup in chain(lookups): for lookup in get_children_from_q(filtered_relation.condition):
lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup) lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup)
shift = 2 if not lookup_parts else 1 shift = 2 if not lookup_parts else 1
lookup_field_path = lookup_field_parts[:-shift] lookup_field_path = lookup_field_parts[:-shift]

View File

@ -4,9 +4,11 @@ from unittest import mock
from django.db import connection, transaction from django.db import connection, transaction
from django.db.models import ( from django.db.models import (
BooleanField,
Case, Case,
Count, Count,
DecimalField, DecimalField,
ExpressionWrapper,
F, F,
FilteredRelation, FilteredRelation,
Q, Q,
@ -15,6 +17,7 @@ from django.db.models import (
When, When,
) )
from django.db.models.functions import Concat from django.db.models.functions import Concat
from django.db.models.lookups import Exact, IStartsWith
from django.test import TestCase from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature from django.test.testcases import skipUnlessDBFeature
@ -707,6 +710,88 @@ class FilteredRelationTests(TestCase):
FilteredRelation("book", condition=Q(book__title="b")), mock.ANY FilteredRelation("book", condition=Q(book__title="b")), mock.ANY
) )
def test_conditional_expression(self):
qs = Author.objects.annotate(
the_book=FilteredRelation("book", condition=Q(Value(False))),
).filter(the_book__isnull=False)
self.assertSequenceEqual(qs, [])
def test_expression_outside_relation_name(self):
qs = Author.objects.annotate(
book_editor=FilteredRelation(
"book__editor",
condition=Q(
Exact(F("book__author__name"), "Alice"),
Value(True),
book__title__startswith="Poem",
),
),
).filter(book_editor__isnull=False)
self.assertSequenceEqual(qs, [self.author1])
def test_conditional_expression_with_case(self):
qs = Book.objects.annotate(
alice_author=FilteredRelation(
"author",
condition=Q(
Case(When(author__name="Alice", then=True), default=False),
),
),
).filter(alice_author__isnull=False)
self.assertCountEqual(qs, [self.book1, self.book4])
def test_conditional_expression_outside_relation_name(self):
tests = [
Q(Case(When(book__author__name="Alice", then=True), default=False)),
Q(
ExpressionWrapper(
Q(Value(True), Exact(F("book__author__name"), "Alice")),
output_field=BooleanField(),
),
),
]
for condition in tests:
with self.subTest(condition=condition):
qs = Author.objects.annotate(
book_editor=FilteredRelation("book__editor", condition=condition),
).filter(book_editor__isnull=True)
self.assertSequenceEqual(qs, [self.author2, self.author2])
def test_conditional_expression_with_lookup(self):
lookups = [
Q(book__title__istartswith="poem"),
Q(IStartsWith(F("book__title"), "poem")),
]
for condition in lookups:
with self.subTest(condition=condition):
qs = Author.objects.annotate(
poem_book=FilteredRelation("book", condition=condition)
).filter(poem_book__isnull=False)
self.assertSequenceEqual(qs, [self.author1])
def test_conditional_expression_with_expressionwrapper(self):
qs = Author.objects.annotate(
poem_book=FilteredRelation(
"book",
condition=Q(
ExpressionWrapper(
Q(Exact(F("book__title"), "Poem by Alice")),
output_field=BooleanField(),
),
),
),
).filter(poem_book__isnull=False)
self.assertSequenceEqual(qs, [self.author1])
def test_conditional_expression_with_multiple_fields(self):
qs = Author.objects.annotate(
my_books=FilteredRelation(
"book__author",
condition=Q(Exact(F("book__author__name"), F("book__author__name"))),
),
).filter(my_books__isnull=True)
self.assertSequenceEqual(qs, [])
class FilteredRelationAggregationTests(TestCase): class FilteredRelationAggregationTests(TestCase):
@classmethod @classmethod