Fixed #10414 -- Made select_related() fail on invalid field names.

This commit is contained in:
Niclas Olofsson 2014-12-04 21:47:48 +01:00 committed by Tim Graham
parent b27db97b23
commit 3daa9d60be
6 changed files with 180 additions and 11 deletions

View File

@ -1,3 +1,4 @@
from itertools import chain
import warnings import warnings
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
@ -599,6 +600,14 @@ class SQLCompiler(object):
(for example, cur_depth=1 means we are looking at models with direct (for example, cur_depth=1 means we are looking at models with direct
connections to the root model). connections to the root model).
""" """
def _get_field_choices():
direct_choices = (f.name for (f, _) in opts.get_fields_with_model() if f.rel)
reverse_choices = (
f.field.related_query_name()
for f in opts.get_all_related_objects() if f.field.unique
)
return chain(direct_choices, reverse_choices)
if not restricted and self.query.max_depth and cur_depth > self.query.max_depth: if not restricted and self.query.max_depth and cur_depth > self.query.max_depth:
# We've recursed far enough; bail out. # We've recursed far enough; bail out.
return return
@ -611,6 +620,7 @@ class SQLCompiler(object):
# Setup for the case when only particular related fields should be # Setup for the case when only particular related fields should be
# included in the related selection. # included in the related selection.
fields_found = set()
if requested is None: if requested is None:
if isinstance(self.query.select_related, dict): if isinstance(self.query.select_related, dict):
requested = self.query.select_related requested = self.query.select_related
@ -619,6 +629,24 @@ class SQLCompiler(object):
restricted = False restricted = False
for f, model in opts.get_fields_with_model(): for f, model in opts.get_fields_with_model():
fields_found.add(f.name)
if restricted:
next = requested.get(f.name, {})
if not f.rel:
# If a non-related field is used like a relation,
# or if a single non-relational field is given.
if next or (cur_depth == 1 and f.name in requested):
raise FieldError(
"Non-relational field given in select_related: '%s'. "
"Choices are: %s" % (
f.name,
", ".join(_get_field_choices()) or '(none)',
)
)
else:
next = False
# The get_fields_with_model() returns None for fields that live # The get_fields_with_model() returns None for fields that live
# in the field's local model. So, for those fields we want to use # in the field's local model. So, for those fields we want to use
# the f.model - that is the field's local model. # the f.model - that is the field's local model.
@ -632,13 +660,9 @@ class SQLCompiler(object):
columns, _ = self.get_default_columns(start_alias=alias, columns, _ = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True) opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend( self.query.related_select_cols.extend(
SelectInfo((col[0], col[1].column), col[1]) for col in columns) SelectInfo((col[0], col[1].column), col[1]) for col in columns
if restricted: )
next = requested.get(f.name, {}) self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, next, restricted)
else:
next = False
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
next, restricted)
if restricted: if restricted:
related_fields = [ related_fields = [
@ -651,8 +675,10 @@ class SQLCompiler(object):
only_load.get(model), reverse=True): only_load.get(model), reverse=True):
continue continue
_, _, _, joins, _ = self.query.setup_joins( related_field_name = f.related_query_name()
[f.related_query_name()], opts, root_alias) fields_found.add(related_field_name)
_, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias)
alias = joins[-1] alias = joins[-1]
from_parent = (opts.model if issubclass(model, opts.model) from_parent = (opts.model if issubclass(model, opts.model)
else None) else None)
@ -664,6 +690,17 @@ class SQLCompiler(object):
self.fill_related_selections(model._meta, alias, cur_depth + 1, self.fill_related_selections(model._meta, alias, cur_depth + 1,
next, restricted) next, restricted)
fields_not_found = set(requested.keys()).difference(fields_found)
if fields_not_found:
invalid_fields = ("'%s'" % s for s in fields_not_found)
raise FieldError(
'Invalid field name(s) given in select_related: %s. '
'Choices are: %s' % (
', '.join(invalid_fields),
', '.join(_get_field_choices()) or '(none)',
)
)
def deferred_to_columns(self): def deferred_to_columns(self):
""" """
Converts the self.deferred_loading data structure to mapping of table Converts the self.deferred_loading data structure to mapping of table

View File

@ -681,6 +681,24 @@ lookups::
... ...
ValueError: Cannot query "<Book: Django>": Must be "Author" instance. ValueError: Cannot query "<Book: Django>": Must be "Author" instance.
``select_related()`` now checks given fields
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``select_related()`` now validates that the given fields actually exist.
Previously, nonexistent fields were silently ignored. Now, an error is raised::
>>> book = Book.objects.select_related('nonexistent_field')
Traceback (most recent call last):
...
FieldError: Invalid field name(s) given in select_related: 'nonexistent_field'
The validation also makes sure that the given field is relational::
>>> book = Book.objects.select_related('name')
Traceback (most recent call last):
...
FieldError: Non-relational field given in select_related: 'name'
Default ``EmailField.max_length`` increased to 254 Default ``EmailField.max_length`` increased to 254
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -181,7 +181,7 @@ class NonAggregateAnnotationTestCase(TestCase):
other_chain=F('chain'), other_chain=F('chain'),
is_open=Value(True, BooleanField()), is_open=Value(True, BooleanField()),
book_isbn=F('books__isbn') book_isbn=F('books__isbn')
).select_related('store').order_by('book_isbn').filter(chain='Westfield') ).order_by('book_isbn').filter(chain='Westfield')
self.assertQuerysetEqual( self.assertQuerysetEqual(
qs, [ qs, [

View File

@ -10,6 +10,9 @@ the select-related behavior will traverse.
from django.db import models from django.db import models
from django.utils.encoding import python_2_unicode_compatible from django.utils.encoding import python_2_unicode_compatible
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
# Who remembers high school biology? # Who remembers high school biology?
@ -94,3 +97,41 @@ class HybridSpecies(models.Model):
def __str__(self): def __str__(self):
return self.name return self.name
@python_2_unicode_compatible
class Topping(models.Model):
name = models.CharField(max_length=30)
def __str__(self):
return self.name
@python_2_unicode_compatible
class Pizza(models.Model):
name = models.CharField(max_length=100)
toppings = models.ManyToManyField(Topping)
def __str__(self):
return self.name
@python_2_unicode_compatible
class TaggedItem(models.Model):
tag = models.CharField(max_length=30)
content_type = models.ForeignKey(ContentType, related_name='select_related_tagged_items')
object_id = models.PositiveIntegerField()
content_object = GenericForeignKey('content_type', 'object_id')
def __str__(self):
return self.tag
@python_2_unicode_compatible
class Bookmark(models.Model):
url = models.URLField()
tags = GenericRelation(TaggedItem)
def __str__(self):
return self.url

View File

@ -1,8 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.core.exceptions import FieldError
from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies from .models import (
Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies,
Pizza, TaggedItem, Bookmark,
)
class SelectRelatedTests(TestCase): class SelectRelatedTests(TestCase):
@ -126,6 +130,12 @@ class SelectRelatedTests(TestCase):
orders = [o.genus.family.order.name for o in world] orders = [o.genus.family.order.name for o in world]
self.assertEqual(orders, ['Agaricales']) self.assertEqual(orders, ['Agaricales'])
def test_single_related_field(self):
with self.assertNumQueries(1):
species = Species.objects.select_related('genus__name')
names = [s.genus.name for s in species]
self.assertEqual(sorted(names), ['Amanita', 'Drosophila', 'Homo', 'Pisum'])
def test_field_traversal(self): def test_field_traversal(self):
with self.assertNumQueries(1): with self.assertNumQueries(1):
s = (Species.objects.all() s = (Species.objects.all()
@ -152,3 +162,47 @@ class SelectRelatedTests(TestCase):
obj = queryset[0] obj = queryset[0]
self.assertEqual(obj.parent_1, parent_1) self.assertEqual(obj.parent_1, parent_1)
self.assertEqual(obj.parent_2, parent_2) self.assertEqual(obj.parent_2, parent_2)
class SelectRelatedValidationTests(TestCase):
"""
select_related() should thrown an error on fields that do not exist and
non-relational fields.
"""
non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s"
invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s"
def test_non_relational_field(self):
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')):
list(Species.objects.select_related('name__some_field'))
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')):
list(Species.objects.select_related('name'))
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', '(none)')):
list(Domain.objects.select_related('name'))
def test_many_to_many_field(self):
with self.assertRaisesMessage(FieldError, self.invalid_error % ('toppings', '(none)')):
list(Pizza.objects.select_related('toppings'))
def test_reverse_relational_field(self):
with self.assertRaisesMessage(FieldError, self.invalid_error % ('child_1', 'genus')):
list(Species.objects.select_related('child_1'))
def test_invalid_field(self):
with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', 'genus')):
list(Species.objects.select_related('invalid_field'))
with self.assertRaisesMessage(FieldError, self.invalid_error % ('related_invalid_field', 'family')):
list(Species.objects.select_related('genus__related_invalid_field'))
with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', '(none)')):
list(Domain.objects.select_related('invalid_field'))
def test_generic_relations(self):
with self.assertRaisesMessage(FieldError, self.invalid_error % ('tags', '')):
list(Bookmark.objects.select_related('tags'))
with self.assertRaisesMessage(FieldError, self.invalid_error % ('content_object', 'content_type')):
list(TaggedItem.objects.select_related('content_object'))

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import unittest import unittest
from django.core.exceptions import FieldError
from django.test import TestCase from django.test import TestCase
from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
@ -208,3 +209,21 @@ class ReverseSelectRelatedTestCase(TestCase):
self.assertEqual(p.child1.name1, 'n1') self.assertEqual(p.child1.name1, 'n1')
with self.assertNumQueries(1): with self.assertNumQueries(1):
self.assertEqual(p.child1.child4.name1, 'n1') self.assertEqual(p.child1.child4.name1, 'n1')
class ReverseSelectRelatedValidationTests(TestCase):
"""
Rverse related fields should be listed in the validation message when an
invalid field is given in select_related().
"""
non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s"
invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s"
def test_reverse_related_validation(self):
fields = 'userprofile, userstat'
with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)):
list(User.objects.select_related('foobar'))
with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)):
list(User.objects.select_related('username'))