Fixed #34015 -- Allowed filtering by transforms on relation fields.

This commit is contained in:
Mariusz Felisiak 2022-09-22 00:17:04 +02:00 committed by GitHub
parent cfe3008123
commit ce6230aa97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 15 deletions

View File

@ -1278,10 +1278,6 @@ class Query(BaseExpression):
# supports both transform and lookup for the name.
lookup_class = lhs.get_lookup(lookup_name)
if not lookup_class:
if lhs.field.is_relation:
raise FieldError(
"Related Field got invalid lookup: {}".format(lookup_name)
)
# A lookup wasn't found. Try to interpret the name as a transform
# and do an Exact lookup against it.
lhs = self.try_transform(lhs, lookup_name)
@ -1450,12 +1446,6 @@ class Query(BaseExpression):
can_reuse.update(join_list)
if join_info.final_field.is_relation:
# No support for transforms for relational fields
num_lookups = len(lookups)
if num_lookups > 1:
raise FieldError(
"Related Field got invalid lookup: {}".format(lookups[0])
)
if len(targets) == 1:
col = self._get_col(targets[0], join_info.final_field, alias)
else:

View File

@ -786,10 +786,16 @@ class LookupTests(TestCase):
def test_relation_nested_lookup_error(self):
# An invalid nested lookup on a related field raises a useful error.
msg = "Related Field got invalid lookup: editor"
msg = (
"Unsupported lookup 'editor' for ForeignKey or join on the field not "
"permitted."
)
with self.assertRaisesMessage(FieldError, msg):
Article.objects.filter(author__editor__name="James")
msg = "Related Field got invalid lookup: foo"
msg = (
"Unsupported lookup 'foo' for ForeignKey or join on the field not "
"permitted."
)
with self.assertRaisesMessage(FieldError, msg):
Tag.objects.filter(articles__foo="bar")

View File

@ -1,6 +1,8 @@
"""
Various complex queries that have been problematic in the past.
"""
import datetime
from django.db import models
from django.db.models.functions import Now
@ -64,7 +66,7 @@ class Annotation(models.Model):
class DateTimePK(models.Model):
date = models.DateTimeField(primary_key=True, auto_now_add=True)
date = models.DateTimeField(primary_key=True, default=datetime.datetime.now)
class ExtraInfo(models.Model):

View File

@ -7,12 +7,13 @@ from threading import Lock
from django.core.exceptions import EmptyResultSet, FieldError
from django.db import DEFAULT_DB_ALIAS, connection
from django.db.models import Count, Exists, F, Max, OuterRef, Q
from django.db.models import CharField, Count, Exists, F, Max, OuterRef, Q
from django.db.models.expressions import RawSQL
from django.db.models.functions import ExtractYear, Length, LTrim
from django.db.models.sql.constants import LOUTER
from django.db.models.sql.where import AND, OR, NothingNode, WhereNode
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext, ignore_warnings
from django.test.utils import CaptureQueriesContext, ignore_warnings, register_lookup
from django.utils.deprecation import RemovedInDjango50Warning
from .models import (
@ -391,6 +392,33 @@ class Queries1Tests(TestCase):
qs = qs.order_by("id")
self.assertNotIn("OUTER JOIN", str(qs.query))
def test_filter_by_related_field_transform(self):
extra_old = ExtraInfo.objects.create(
info="extra 12",
date=DateTimePK.objects.create(date=datetime.datetime(2020, 12, 10)),
)
ExtraInfo.objects.create(info="extra 11", date=DateTimePK.objects.create())
a5 = Author.objects.create(name="a5", num=5005, extra=extra_old)
fk_field = ExtraInfo._meta.get_field("date")
with register_lookup(fk_field, ExtractYear):
self.assertSequenceEqual(
ExtraInfo.objects.filter(date__year=2020),
[extra_old],
)
self.assertSequenceEqual(
Author.objects.filter(extra__date__year=2020), [a5]
)
def test_filter_by_related_field_nested_transforms(self):
extra = ExtraInfo.objects.create(info=" extra")
a5 = Author.objects.create(name="a5", num=5005, extra=extra)
info_field = ExtraInfo._meta.get_field("info")
with register_lookup(info_field, Length), register_lookup(CharField, LTrim):
self.assertSequenceEqual(
Author.objects.filter(extra__info__ltrim__length=5), [a5]
)
def test_get_clears_ordering(self):
"""
get() should clear ordering for optimization purposes.