Fixed #10414 -- Make select_related fail on invalid field names.

This commit is contained in:
Niclas Olofsson 2014-12-04 21:47:48 +01:00
parent ce2eff7e48
commit 45d4e43d2d
4 changed files with 133 additions and 9 deletions

View File

@ -610,6 +610,7 @@ class SQLCompiler(object):
# Setup for the case when only particular related fields should be
# included in the related selection.
fields_found = set()
if requested is None:
if isinstance(self.query.select_related, dict):
requested = self.query.select_related
@ -618,6 +619,18 @@ class SQLCompiler(object):
restricted = False
for f, model in opts.get_fields_with_model():
fields_found.add(f.name)
if restricted:
next = requested.get(f.name, {})
if not f.rel:
# If a non-related field is used like a relation,
# or if a single non-relational field is given.
if next or (cur_depth == 1 and f.name in requested):
raise ValueError("Non-relational field given in select_related: '%s'" % f.name)
else:
next = False
# The get_fields_with_model() returns None for fields that live
# in the field's local model. So, for those fields we want to use
# the f.model - that is the field's local model.
@ -625,6 +638,7 @@ class SQLCompiler(object):
if not select_related_descend(f, restricted, requested,
only_load.get(field_model)):
continue
_, _, _, joins, _ = self.query.setup_joins(
[f.name], opts, root_alias)
alias = joins[-1]
@ -632,12 +646,8 @@ class SQLCompiler(object):
opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend(
SelectInfo((col[0], col[1].column), col[1]) for col in columns)
if restricted:
next = requested.get(f.name, {})
else:
next = False
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
next, restricted)
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, next, restricted)
if restricted:
related_fields = [
@ -645,13 +655,16 @@ class SQLCompiler(object):
for o in opts.get_all_related_objects()
if o.field.unique
]
for f, model in related_fields:
if not select_related_descend(f, restricted, requested,
only_load.get(model), reverse=True):
continue
_, _, _, joins, _ = self.query.setup_joins(
[f.related_query_name()], opts, root_alias)
related_field_name = f.related_query_name()
fields_found.add(related_field_name)
_, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias)
alias = joins[-1]
from_parent = (opts.model if issubclass(model, opts.model)
else None)
@ -663,6 +676,12 @@ class SQLCompiler(object):
self.fill_related_selections(model._meta, alias, cur_depth + 1,
next, restricted)
fields_not_found = set(requested.keys()).difference(fields_found)
if fields_not_found:
field_descriptions = ["'%s'" % s for s in fields_not_found]
raise ValueError('Invalid field name(s) given in select_related: %s' %
(', '.join(field_descriptions)))
def deferred_to_columns(self):
"""
Converts the self.deferred_loading data structure to mapping of table

View File

@ -633,6 +633,24 @@ lookups::
...
ValueError: Cannot query "<Book: Django>": Must be "Author" instance.
``select_related()`` now checks given fields
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``select_related()`` now validates that given fields actually exist. Previously,
nonexisting fields were silently ignored. Now, an error is raised::
>>> book = Book.objects.select_related('nonexisting_field')
Traceback (most recent call last):
...
ValueError: Invalid field name(s) given in select_related: 'nonexisting_field'
The validation also makes sure that the given field is relational::
>>> book = Book.objects.select_related('name')
Traceback (most recent call last):
...
ValueError: Non-relational field given in select_related: 'name'
Default ``EmailField.max_length`` increased to 254
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -10,6 +10,9 @@ the select-related behavior will traverse.
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
# Who remembers high school biology?
@ -94,3 +97,41 @@ class HybridSpecies(models.Model):
def __str__(self):
return self.name
@python_2_unicode_compatible
class Topping(models.Model):
name = models.CharField(max_length=30)
def __str__(self):
return self.name
@python_2_unicode_compatible
class Pizza(models.Model):
name = models.CharField(max_length=100)
toppings = models.ManyToManyField(Topping)
def __str__(self):
return self.name
@python_2_unicode_compatible
class TaggedItem(models.Model):
tag = models.CharField(max_length=30)
content_type = models.ForeignKey(ContentType)
object_id = models.PositiveIntegerField()
content_object = GenericForeignKey('content_type', 'object_id')
def __str__(self):
return self.tag
@python_2_unicode_compatible
class Bookmark(models.Model):
url = models.URLField()
tags = GenericRelation(TaggedItem)
def __str__(self):
return self.url

View File

@ -2,7 +2,8 @@ from __future__ import unicode_literals
from django.test import TestCase
from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies
from .models import (Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies,
Pizza, TaggedItem, Bookmark)
class SelectRelatedTests(TestCase):
@ -126,6 +127,13 @@ class SelectRelatedTests(TestCase):
orders = [o.genus.family.order.name for o in world]
self.assertEqual(orders, ['Agaricales'])
def test_single_related_field(self):
with self.assertNumQueries(1):
species = Species.objects.select_related('genus__name')
names = [s.genus.name for s in species]
self.assertEqual(sorted(names), ['Amanita', 'Drosophila', 'Homo', 'Pisum'])
def test_field_traversal(self):
with self.assertNumQueries(1):
s = (Species.objects.all()
@ -152,3 +160,41 @@ class SelectRelatedTests(TestCase):
obj = queryset[0]
self.assertEqual(obj.parent_1, parent_1)
self.assertEqual(obj.parent_2, parent_2)
class SelectRelatedValidationTests(TestCase):
"""
Test validation of fields that does not exist or cannot be used
with select_related (such as non-relational fields).
"""
non_relational_error = "Non-relational field given in select_related: '%s'"
invalid_error = "Invalid field name(s) given in select_related: '%s'"
def test_non_relational_field(self):
with self.assertRaisesMessage(ValueError, self.non_relational_error % 'name'):
list(Species.objects.select_related('name__some_field'))
with self.assertRaisesMessage(ValueError, self.non_relational_error % 'name'):
list(Species.objects.select_related('name'))
def test_many_to_many_field(self):
with self.assertRaisesMessage(ValueError, self.invalid_error % 'toppings'):
list(Pizza.objects.select_related('toppings'))
def test_reverse_relational_field(self):
with self.assertRaisesMessage(ValueError, self.invalid_error % 'child_1'):
list(Species.objects.select_related('child_1'))
def test_invalid_field(self):
with self.assertRaisesMessage(ValueError, self.invalid_error % 'invalid_field'):
list(Species.objects.select_related('invalid_field'))
with self.assertRaisesMessage(ValueError, self.invalid_error % 'related_invalid_field'):
list(Species.objects.select_related('genus__related_invalid_field'))
def test_generic_relations(self):
with self.assertRaisesMessage(ValueError, self.invalid_error % 'tags'):
list(Bookmark.objects.select_related('tags'))
with self.assertRaisesMessage(ValueError, self.invalid_error % 'content_object'):
list(TaggedItem.objects.select_related('content_object'))