mirror of https://github.com/django/django.git
Fixed #10414 -- Made select_related() fail on invalid field names.
This commit is contained in:
parent
b27db97b23
commit
3daa9d60be
|
@ -1,3 +1,4 @@
|
|||
from itertools import chain
|
||||
import warnings
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
|
@ -599,6 +600,14 @@ class SQLCompiler(object):
|
|||
(for example, cur_depth=1 means we are looking at models with direct
|
||||
connections to the root model).
|
||||
"""
|
||||
def _get_field_choices():
|
||||
direct_choices = (f.name for (f, _) in opts.get_fields_with_model() if f.rel)
|
||||
reverse_choices = (
|
||||
f.field.related_query_name()
|
||||
for f in opts.get_all_related_objects() if f.field.unique
|
||||
)
|
||||
return chain(direct_choices, reverse_choices)
|
||||
|
||||
if not restricted and self.query.max_depth and cur_depth > self.query.max_depth:
|
||||
# We've recursed far enough; bail out.
|
||||
return
|
||||
|
@ -611,6 +620,7 @@ class SQLCompiler(object):
|
|||
|
||||
# Setup for the case when only particular related fields should be
|
||||
# included in the related selection.
|
||||
fields_found = set()
|
||||
if requested is None:
|
||||
if isinstance(self.query.select_related, dict):
|
||||
requested = self.query.select_related
|
||||
|
@ -619,6 +629,24 @@ class SQLCompiler(object):
|
|||
restricted = False
|
||||
|
||||
for f, model in opts.get_fields_with_model():
|
||||
fields_found.add(f.name)
|
||||
|
||||
if restricted:
|
||||
next = requested.get(f.name, {})
|
||||
if not f.rel:
|
||||
# If a non-related field is used like a relation,
|
||||
# or if a single non-relational field is given.
|
||||
if next or (cur_depth == 1 and f.name in requested):
|
||||
raise FieldError(
|
||||
"Non-relational field given in select_related: '%s'. "
|
||||
"Choices are: %s" % (
|
||||
f.name,
|
||||
", ".join(_get_field_choices()) or '(none)',
|
||||
)
|
||||
)
|
||||
else:
|
||||
next = False
|
||||
|
||||
# The get_fields_with_model() returns None for fields that live
|
||||
# in the field's local model. So, for those fields we want to use
|
||||
# the f.model - that is the field's local model.
|
||||
|
@ -632,13 +660,9 @@ class SQLCompiler(object):
|
|||
columns, _ = self.get_default_columns(start_alias=alias,
|
||||
opts=f.rel.to._meta, as_pairs=True)
|
||||
self.query.related_select_cols.extend(
|
||||
SelectInfo((col[0], col[1].column), col[1]) for col in columns)
|
||||
if restricted:
|
||||
next = requested.get(f.name, {})
|
||||
else:
|
||||
next = False
|
||||
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
|
||||
next, restricted)
|
||||
SelectInfo((col[0], col[1].column), col[1]) for col in columns
|
||||
)
|
||||
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, next, restricted)
|
||||
|
||||
if restricted:
|
||||
related_fields = [
|
||||
|
@ -651,8 +675,10 @@ class SQLCompiler(object):
|
|||
only_load.get(model), reverse=True):
|
||||
continue
|
||||
|
||||
_, _, _, joins, _ = self.query.setup_joins(
|
||||
[f.related_query_name()], opts, root_alias)
|
||||
related_field_name = f.related_query_name()
|
||||
fields_found.add(related_field_name)
|
||||
|
||||
_, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias)
|
||||
alias = joins[-1]
|
||||
from_parent = (opts.model if issubclass(model, opts.model)
|
||||
else None)
|
||||
|
@ -664,6 +690,17 @@ class SQLCompiler(object):
|
|||
self.fill_related_selections(model._meta, alias, cur_depth + 1,
|
||||
next, restricted)
|
||||
|
||||
fields_not_found = set(requested.keys()).difference(fields_found)
|
||||
if fields_not_found:
|
||||
invalid_fields = ("'%s'" % s for s in fields_not_found)
|
||||
raise FieldError(
|
||||
'Invalid field name(s) given in select_related: %s. '
|
||||
'Choices are: %s' % (
|
||||
', '.join(invalid_fields),
|
||||
', '.join(_get_field_choices()) or '(none)',
|
||||
)
|
||||
)
|
||||
|
||||
def deferred_to_columns(self):
|
||||
"""
|
||||
Converts the self.deferred_loading data structure to mapping of table
|
||||
|
|
|
@ -681,6 +681,24 @@ lookups::
|
|||
...
|
||||
ValueError: Cannot query "<Book: Django>": Must be "Author" instance.
|
||||
|
||||
``select_related()`` now checks given fields
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``select_related()`` now validates that the given fields actually exist.
|
||||
Previously, nonexistent fields were silently ignored. Now, an error is raised::
|
||||
|
||||
>>> book = Book.objects.select_related('nonexistent_field')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
FieldError: Invalid field name(s) given in select_related: 'nonexistent_field'
|
||||
|
||||
The validation also makes sure that the given field is relational::
|
||||
|
||||
>>> book = Book.objects.select_related('name')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
FieldError: Non-relational field given in select_related: 'name'
|
||||
|
||||
Default ``EmailField.max_length`` increased to 254
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -181,7 +181,7 @@ class NonAggregateAnnotationTestCase(TestCase):
|
|||
other_chain=F('chain'),
|
||||
is_open=Value(True, BooleanField()),
|
||||
book_isbn=F('books__isbn')
|
||||
).select_related('store').order_by('book_isbn').filter(chain='Westfield')
|
||||
).order_by('book_isbn').filter(chain='Westfield')
|
||||
|
||||
self.assertQuerysetEqual(
|
||||
qs, [
|
||||
|
|
|
@ -10,6 +10,9 @@ the select-related behavior will traverse.
|
|||
from django.db import models
|
||||
from django.utils.encoding import python_2_unicode_compatible
|
||||
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
|
||||
# Who remembers high school biology?
|
||||
|
||||
|
||||
|
@ -94,3 +97,41 @@ class HybridSpecies(models.Model):
|
|||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Topping(models.Model):
|
||||
name = models.CharField(max_length=30)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Pizza(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
toppings = models.ManyToManyField(Topping)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class TaggedItem(models.Model):
|
||||
tag = models.CharField(max_length=30)
|
||||
|
||||
content_type = models.ForeignKey(ContentType, related_name='select_related_tagged_items')
|
||||
object_id = models.PositiveIntegerField()
|
||||
content_object = GenericForeignKey('content_type', 'object_id')
|
||||
|
||||
def __str__(self):
|
||||
return self.tag
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Bookmark(models.Model):
|
||||
url = models.URLField()
|
||||
tags = GenericRelation(TaggedItem)
|
||||
|
||||
def __str__(self):
|
||||
return self.url
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.test import TestCase
|
||||
from django.core.exceptions import FieldError
|
||||
|
||||
from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies
|
||||
from .models import (
|
||||
Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies,
|
||||
Pizza, TaggedItem, Bookmark,
|
||||
)
|
||||
|
||||
|
||||
class SelectRelatedTests(TestCase):
|
||||
|
@ -126,6 +130,12 @@ class SelectRelatedTests(TestCase):
|
|||
orders = [o.genus.family.order.name for o in world]
|
||||
self.assertEqual(orders, ['Agaricales'])
|
||||
|
||||
def test_single_related_field(self):
|
||||
with self.assertNumQueries(1):
|
||||
species = Species.objects.select_related('genus__name')
|
||||
names = [s.genus.name for s in species]
|
||||
self.assertEqual(sorted(names), ['Amanita', 'Drosophila', 'Homo', 'Pisum'])
|
||||
|
||||
def test_field_traversal(self):
|
||||
with self.assertNumQueries(1):
|
||||
s = (Species.objects.all()
|
||||
|
@ -152,3 +162,47 @@ class SelectRelatedTests(TestCase):
|
|||
obj = queryset[0]
|
||||
self.assertEqual(obj.parent_1, parent_1)
|
||||
self.assertEqual(obj.parent_2, parent_2)
|
||||
|
||||
|
||||
class SelectRelatedValidationTests(TestCase):
|
||||
"""
|
||||
select_related() should thrown an error on fields that do not exist and
|
||||
non-relational fields.
|
||||
"""
|
||||
non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s"
|
||||
invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s"
|
||||
|
||||
def test_non_relational_field(self):
|
||||
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')):
|
||||
list(Species.objects.select_related('name__some_field'))
|
||||
|
||||
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')):
|
||||
list(Species.objects.select_related('name'))
|
||||
|
||||
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', '(none)')):
|
||||
list(Domain.objects.select_related('name'))
|
||||
|
||||
def test_many_to_many_field(self):
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('toppings', '(none)')):
|
||||
list(Pizza.objects.select_related('toppings'))
|
||||
|
||||
def test_reverse_relational_field(self):
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('child_1', 'genus')):
|
||||
list(Species.objects.select_related('child_1'))
|
||||
|
||||
def test_invalid_field(self):
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', 'genus')):
|
||||
list(Species.objects.select_related('invalid_field'))
|
||||
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('related_invalid_field', 'family')):
|
||||
list(Species.objects.select_related('genus__related_invalid_field'))
|
||||
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', '(none)')):
|
||||
list(Domain.objects.select_related('invalid_field'))
|
||||
|
||||
def test_generic_relations(self):
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('tags', '')):
|
||||
list(Bookmark.objects.select_related('tags'))
|
||||
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('content_object', 'content_type')):
|
||||
list(TaggedItem.objects.select_related('content_object'))
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import unittest
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.test import TestCase
|
||||
|
||||
from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
|
||||
|
@ -208,3 +209,21 @@ class ReverseSelectRelatedTestCase(TestCase):
|
|||
self.assertEqual(p.child1.name1, 'n1')
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEqual(p.child1.child4.name1, 'n1')
|
||||
|
||||
|
||||
class ReverseSelectRelatedValidationTests(TestCase):
|
||||
"""
|
||||
Rverse related fields should be listed in the validation message when an
|
||||
invalid field is given in select_related().
|
||||
"""
|
||||
non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s"
|
||||
invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s"
|
||||
|
||||
def test_reverse_related_validation(self):
|
||||
fields = 'userprofile, userstat'
|
||||
|
||||
with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)):
|
||||
list(User.objects.select_related('foobar'))
|
||||
|
||||
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)):
|
||||
list(User.objects.select_related('username'))
|
||||
|
|
Loading…
Reference in New Issue