Fixed #10414 -- Make select_related fail on invalid field names.
This commit is contained in:
parent
ce2eff7e48
commit
45d4e43d2d
|
@ -610,6 +610,7 @@ class SQLCompiler(object):
|
||||||
|
|
||||||
# Setup for the case when only particular related fields should be
|
# Setup for the case when only particular related fields should be
|
||||||
# included in the related selection.
|
# included in the related selection.
|
||||||
|
fields_found = set()
|
||||||
if requested is None:
|
if requested is None:
|
||||||
if isinstance(self.query.select_related, dict):
|
if isinstance(self.query.select_related, dict):
|
||||||
requested = self.query.select_related
|
requested = self.query.select_related
|
||||||
|
@ -618,6 +619,18 @@ class SQLCompiler(object):
|
||||||
restricted = False
|
restricted = False
|
||||||
|
|
||||||
for f, model in opts.get_fields_with_model():
|
for f, model in opts.get_fields_with_model():
|
||||||
|
fields_found.add(f.name)
|
||||||
|
|
||||||
|
if restricted:
|
||||||
|
next = requested.get(f.name, {})
|
||||||
|
if not f.rel:
|
||||||
|
# If a non-related field is used like a relation,
|
||||||
|
# or if a single non-relational field is given.
|
||||||
|
if next or (cur_depth == 1 and f.name in requested):
|
||||||
|
raise ValueError("Non-relational field given in select_related: '%s'" % f.name)
|
||||||
|
else:
|
||||||
|
next = False
|
||||||
|
|
||||||
# The get_fields_with_model() returns None for fields that live
|
# The get_fields_with_model() returns None for fields that live
|
||||||
# in the field's local model. So, for those fields we want to use
|
# in the field's local model. So, for those fields we want to use
|
||||||
# the f.model - that is the field's local model.
|
# the f.model - that is the field's local model.
|
||||||
|
@ -625,6 +638,7 @@ class SQLCompiler(object):
|
||||||
if not select_related_descend(f, restricted, requested,
|
if not select_related_descend(f, restricted, requested,
|
||||||
only_load.get(field_model)):
|
only_load.get(field_model)):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_, _, _, joins, _ = self.query.setup_joins(
|
_, _, _, joins, _ = self.query.setup_joins(
|
||||||
[f.name], opts, root_alias)
|
[f.name], opts, root_alias)
|
||||||
alias = joins[-1]
|
alias = joins[-1]
|
||||||
|
@ -632,12 +646,8 @@ class SQLCompiler(object):
|
||||||
opts=f.rel.to._meta, as_pairs=True)
|
opts=f.rel.to._meta, as_pairs=True)
|
||||||
self.query.related_select_cols.extend(
|
self.query.related_select_cols.extend(
|
||||||
SelectInfo((col[0], col[1].column), col[1]) for col in columns)
|
SelectInfo((col[0], col[1].column), col[1]) for col in columns)
|
||||||
if restricted:
|
|
||||||
next = requested.get(f.name, {})
|
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, next, restricted)
|
||||||
else:
|
|
||||||
next = False
|
|
||||||
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
|
|
||||||
next, restricted)
|
|
||||||
|
|
||||||
if restricted:
|
if restricted:
|
||||||
related_fields = [
|
related_fields = [
|
||||||
|
@ -645,13 +655,16 @@ class SQLCompiler(object):
|
||||||
for o in opts.get_all_related_objects()
|
for o in opts.get_all_related_objects()
|
||||||
if o.field.unique
|
if o.field.unique
|
||||||
]
|
]
|
||||||
|
|
||||||
for f, model in related_fields:
|
for f, model in related_fields:
|
||||||
if not select_related_descend(f, restricted, requested,
|
if not select_related_descend(f, restricted, requested,
|
||||||
only_load.get(model), reverse=True):
|
only_load.get(model), reverse=True):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_, _, _, joins, _ = self.query.setup_joins(
|
related_field_name = f.related_query_name()
|
||||||
[f.related_query_name()], opts, root_alias)
|
fields_found.add(related_field_name)
|
||||||
|
|
||||||
|
_, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias)
|
||||||
alias = joins[-1]
|
alias = joins[-1]
|
||||||
from_parent = (opts.model if issubclass(model, opts.model)
|
from_parent = (opts.model if issubclass(model, opts.model)
|
||||||
else None)
|
else None)
|
||||||
|
@ -663,6 +676,12 @@ class SQLCompiler(object):
|
||||||
self.fill_related_selections(model._meta, alias, cur_depth + 1,
|
self.fill_related_selections(model._meta, alias, cur_depth + 1,
|
||||||
next, restricted)
|
next, restricted)
|
||||||
|
|
||||||
|
fields_not_found = set(requested.keys()).difference(fields_found)
|
||||||
|
if fields_not_found:
|
||||||
|
field_descriptions = ["'%s'" % s for s in fields_not_found]
|
||||||
|
raise ValueError('Invalid field name(s) given in select_related: %s' %
|
||||||
|
(', '.join(field_descriptions)))
|
||||||
|
|
||||||
def deferred_to_columns(self):
|
def deferred_to_columns(self):
|
||||||
"""
|
"""
|
||||||
Converts the self.deferred_loading data structure to mapping of table
|
Converts the self.deferred_loading data structure to mapping of table
|
||||||
|
|
|
@ -633,6 +633,24 @@ lookups::
|
||||||
...
|
...
|
||||||
ValueError: Cannot query "<Book: Django>": Must be "Author" instance.
|
ValueError: Cannot query "<Book: Django>": Must be "Author" instance.
|
||||||
|
|
||||||
|
``select_related()`` now checks given fields
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
``select_related()`` now validates that given fields actually exist. Previously,
|
||||||
|
nonexisting fields were silently ignored. Now, an error is raised::
|
||||||
|
|
||||||
|
>>> book = Book.objects.select_related('nonexisting_field')
|
||||||
|
Traceback (most recent call last):
|
||||||
|
...
|
||||||
|
ValueError: Invalid field name(s) given in select_related: 'nonexisting_field'
|
||||||
|
|
||||||
|
The validation also makes sure that the given field is relational::
|
||||||
|
|
||||||
|
>>> book = Book.objects.select_related('name')
|
||||||
|
Traceback (most recent call last):
|
||||||
|
...
|
||||||
|
ValueError: Non-relational field given in select_related: 'name'
|
||||||
|
|
||||||
Default ``EmailField.max_length`` increased to 254
|
Default ``EmailField.max_length`` increased to 254
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,9 @@ the select-related behavior will traverse.
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.encoding import python_2_unicode_compatible
|
from django.utils.encoding import python_2_unicode_compatible
|
||||||
|
|
||||||
|
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
|
||||||
|
from django.contrib.contenttypes.models import ContentType
|
||||||
|
|
||||||
# Who remembers high school biology?
|
# Who remembers high school biology?
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,3 +97,41 @@ class HybridSpecies(models.Model):
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
@python_2_unicode_compatible
|
||||||
|
class Topping(models.Model):
|
||||||
|
name = models.CharField(max_length=30)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
@python_2_unicode_compatible
|
||||||
|
class Pizza(models.Model):
|
||||||
|
name = models.CharField(max_length=100)
|
||||||
|
toppings = models.ManyToManyField(Topping)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
@python_2_unicode_compatible
|
||||||
|
class TaggedItem(models.Model):
|
||||||
|
tag = models.CharField(max_length=30)
|
||||||
|
|
||||||
|
content_type = models.ForeignKey(ContentType)
|
||||||
|
object_id = models.PositiveIntegerField()
|
||||||
|
content_object = GenericForeignKey('content_type', 'object_id')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.tag
|
||||||
|
|
||||||
|
|
||||||
|
@python_2_unicode_compatible
|
||||||
|
class Bookmark(models.Model):
|
||||||
|
url = models.URLField()
|
||||||
|
tags = GenericRelation(TaggedItem)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.url
|
||||||
|
|
|
@ -2,7 +2,8 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies
|
from .models import (Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies,
|
||||||
|
Pizza, TaggedItem, Bookmark)
|
||||||
|
|
||||||
|
|
||||||
class SelectRelatedTests(TestCase):
|
class SelectRelatedTests(TestCase):
|
||||||
|
@ -126,6 +127,13 @@ class SelectRelatedTests(TestCase):
|
||||||
orders = [o.genus.family.order.name for o in world]
|
orders = [o.genus.family.order.name for o in world]
|
||||||
self.assertEqual(orders, ['Agaricales'])
|
self.assertEqual(orders, ['Agaricales'])
|
||||||
|
|
||||||
|
def test_single_related_field(self):
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
species = Species.objects.select_related('genus__name')
|
||||||
|
names = [s.genus.name for s in species]
|
||||||
|
|
||||||
|
self.assertEqual(sorted(names), ['Amanita', 'Drosophila', 'Homo', 'Pisum'])
|
||||||
|
|
||||||
def test_field_traversal(self):
|
def test_field_traversal(self):
|
||||||
with self.assertNumQueries(1):
|
with self.assertNumQueries(1):
|
||||||
s = (Species.objects.all()
|
s = (Species.objects.all()
|
||||||
|
@ -152,3 +160,41 @@ class SelectRelatedTests(TestCase):
|
||||||
obj = queryset[0]
|
obj = queryset[0]
|
||||||
self.assertEqual(obj.parent_1, parent_1)
|
self.assertEqual(obj.parent_1, parent_1)
|
||||||
self.assertEqual(obj.parent_2, parent_2)
|
self.assertEqual(obj.parent_2, parent_2)
|
||||||
|
|
||||||
|
|
||||||
|
class SelectRelatedValidationTests(TestCase):
|
||||||
|
"""
|
||||||
|
Test validation of fields that does not exist or cannot be used
|
||||||
|
with select_related (such as non-relational fields).
|
||||||
|
"""
|
||||||
|
non_relational_error = "Non-relational field given in select_related: '%s'"
|
||||||
|
invalid_error = "Invalid field name(s) given in select_related: '%s'"
|
||||||
|
|
||||||
|
def test_non_relational_field(self):
|
||||||
|
with self.assertRaisesMessage(ValueError, self.non_relational_error % 'name'):
|
||||||
|
list(Species.objects.select_related('name__some_field'))
|
||||||
|
|
||||||
|
with self.assertRaisesMessage(ValueError, self.non_relational_error % 'name'):
|
||||||
|
list(Species.objects.select_related('name'))
|
||||||
|
|
||||||
|
def test_many_to_many_field(self):
|
||||||
|
with self.assertRaisesMessage(ValueError, self.invalid_error % 'toppings'):
|
||||||
|
list(Pizza.objects.select_related('toppings'))
|
||||||
|
|
||||||
|
def test_reverse_relational_field(self):
|
||||||
|
with self.assertRaisesMessage(ValueError, self.invalid_error % 'child_1'):
|
||||||
|
list(Species.objects.select_related('child_1'))
|
||||||
|
|
||||||
|
def test_invalid_field(self):
|
||||||
|
with self.assertRaisesMessage(ValueError, self.invalid_error % 'invalid_field'):
|
||||||
|
list(Species.objects.select_related('invalid_field'))
|
||||||
|
|
||||||
|
with self.assertRaisesMessage(ValueError, self.invalid_error % 'related_invalid_field'):
|
||||||
|
list(Species.objects.select_related('genus__related_invalid_field'))
|
||||||
|
|
||||||
|
def test_generic_relations(self):
|
||||||
|
with self.assertRaisesMessage(ValueError, self.invalid_error % 'tags'):
|
||||||
|
list(Bookmark.objects.select_related('tags'))
|
||||||
|
|
||||||
|
with self.assertRaisesMessage(ValueError, self.invalid_error % 'content_object'):
|
||||||
|
list(TaggedItem.objects.select_related('content_object'))
|
||||||
|
|
Loading…
Reference in New Issue