From 45d4e43d2d25b902e3821b612209afa951a8bcb8 Mon Sep 17 00:00:00 2001 From: Niclas Olofsson Date: Thu, 4 Dec 2014 21:47:48 +0100 Subject: [PATCH] Fixed #10414 -- Make select_related fail on invalid field names. --- django/db/models/sql/compiler.py | 35 +++++++++++++++++------ docs/releases/1.8.txt | 18 ++++++++++++ tests/select_related/models.py | 41 +++++++++++++++++++++++++++ tests/select_related/tests.py | 48 +++++++++++++++++++++++++++++++- 4 files changed, 133 insertions(+), 9 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index b0e3e48c15..c31a543c61 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -610,6 +610,7 @@ class SQLCompiler(object): # Setup for the case when only particular related fields should be # included in the related selection. + fields_found = set() if requested is None: if isinstance(self.query.select_related, dict): requested = self.query.select_related @@ -618,6 +619,18 @@ class SQLCompiler(object): restricted = False for f, model in opts.get_fields_with_model(): + fields_found.add(f.name) + + if restricted: + next = requested.get(f.name, {}) + if not f.rel: + # If a non-related field is used like a relation, + # or if a single non-relational field is given. + if next or (cur_depth == 1 and f.name in requested): + raise ValueError("Non-relational field given in select_related: '%s'" % f.name) + else: + next = False + # The get_fields_with_model() returns None for fields that live # in the field's local model. So, for those fields we want to use # the f.model - that is the field's local model. @@ -625,6 +638,7 @@ class SQLCompiler(object): if not select_related_descend(f, restricted, requested, only_load.get(field_model)): continue + _, _, _, joins, _ = self.query.setup_joins( [f.name], opts, root_alias) alias = joins[-1] @@ -632,12 +646,8 @@ class SQLCompiler(object): opts=f.rel.to._meta, as_pairs=True) self.query.related_select_cols.extend( SelectInfo((col[0], col[1].column), col[1]) for col in columns) - if restricted: - next = requested.get(f.name, {}) - else: - next = False - self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, - next, restricted) + + self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, next, restricted) if restricted: related_fields = [ @@ -645,13 +655,16 @@ class SQLCompiler(object): for o in opts.get_all_related_objects() if o.field.unique ] + for f, model in related_fields: if not select_related_descend(f, restricted, requested, only_load.get(model), reverse=True): continue - _, _, _, joins, _ = self.query.setup_joins( - [f.related_query_name()], opts, root_alias) + related_field_name = f.related_query_name() + fields_found.add(related_field_name) + + _, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias) alias = joins[-1] from_parent = (opts.model if issubclass(model, opts.model) else None) @@ -663,6 +676,12 @@ class SQLCompiler(object): self.fill_related_selections(model._meta, alias, cur_depth + 1, next, restricted) + fields_not_found = set(requested.keys()).difference(fields_found) + if fields_not_found: + field_descriptions = ["'%s'" % s for s in fields_not_found] + raise ValueError('Invalid field name(s) given in select_related: %s' % + (', '.join(field_descriptions))) + def deferred_to_columns(self): """ Converts the self.deferred_loading data structure to mapping of table diff --git a/docs/releases/1.8.txt b/docs/releases/1.8.txt index 375a39920f..d7bc5fddab 100644 --- a/docs/releases/1.8.txt +++ b/docs/releases/1.8.txt @@ -633,6 +633,24 @@ lookups:: ... ValueError: Cannot query "": Must be "Author" instance. +``select_related()`` now checks given fields +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``select_related()`` now validates that given fields actually exist. Previously, +nonexisting fields were silently ignored. Now, an error is raised:: + + >>> book = Book.objects.select_related('nonexisting_field') + Traceback (most recent call last): + ... + ValueError: Invalid field name(s) given in select_related: 'nonexisting_field' + +The validation also makes sure that the given field is relational:: + + >>> book = Book.objects.select_related('name') + Traceback (most recent call last): + ... + ValueError: Non-relational field given in select_related: 'name' + Default ``EmailField.max_length`` increased to 254 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/select_related/models.py b/tests/select_related/models.py index 13170e7e69..05c8f1c027 100644 --- a/tests/select_related/models.py +++ b/tests/select_related/models.py @@ -10,6 +10,9 @@ the select-related behavior will traverse. from django.db import models from django.utils.encoding import python_2_unicode_compatible +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation +from django.contrib.contenttypes.models import ContentType + # Who remembers high school biology? @@ -94,3 +97,41 @@ class HybridSpecies(models.Model): def __str__(self): return self.name + + +@python_2_unicode_compatible +class Topping(models.Model): + name = models.CharField(max_length=30) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Pizza(models.Model): + name = models.CharField(max_length=100) + toppings = models.ManyToManyField(Topping) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class TaggedItem(models.Model): + tag = models.CharField(max_length=30) + + content_type = models.ForeignKey(ContentType) + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey('content_type', 'object_id') + + def __str__(self): + return self.tag + + +@python_2_unicode_compatible +class Bookmark(models.Model): + url = models.URLField() + tags = GenericRelation(TaggedItem) + + def __str__(self): + return self.url diff --git a/tests/select_related/tests.py b/tests/select_related/tests.py index 6b029fa64b..3dab493752 100644 --- a/tests/select_related/tests.py +++ b/tests/select_related/tests.py @@ -2,7 +2,8 @@ from __future__ import unicode_literals from django.test import TestCase -from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies +from .models import (Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species, HybridSpecies, + Pizza, TaggedItem, Bookmark) class SelectRelatedTests(TestCase): @@ -126,6 +127,13 @@ class SelectRelatedTests(TestCase): orders = [o.genus.family.order.name for o in world] self.assertEqual(orders, ['Agaricales']) + def test_single_related_field(self): + with self.assertNumQueries(1): + species = Species.objects.select_related('genus__name') + names = [s.genus.name for s in species] + + self.assertEqual(sorted(names), ['Amanita', 'Drosophila', 'Homo', 'Pisum']) + def test_field_traversal(self): with self.assertNumQueries(1): s = (Species.objects.all() @@ -152,3 +160,41 @@ class SelectRelatedTests(TestCase): obj = queryset[0] self.assertEqual(obj.parent_1, parent_1) self.assertEqual(obj.parent_2, parent_2) + + +class SelectRelatedValidationTests(TestCase): + """ + Test validation of fields that does not exist or cannot be used + with select_related (such as non-relational fields). + """ + non_relational_error = "Non-relational field given in select_related: '%s'" + invalid_error = "Invalid field name(s) given in select_related: '%s'" + + def test_non_relational_field(self): + with self.assertRaisesMessage(ValueError, self.non_relational_error % 'name'): + list(Species.objects.select_related('name__some_field')) + + with self.assertRaisesMessage(ValueError, self.non_relational_error % 'name'): + list(Species.objects.select_related('name')) + + def test_many_to_many_field(self): + with self.assertRaisesMessage(ValueError, self.invalid_error % 'toppings'): + list(Pizza.objects.select_related('toppings')) + + def test_reverse_relational_field(self): + with self.assertRaisesMessage(ValueError, self.invalid_error % 'child_1'): + list(Species.objects.select_related('child_1')) + + def test_invalid_field(self): + with self.assertRaisesMessage(ValueError, self.invalid_error % 'invalid_field'): + list(Species.objects.select_related('invalid_field')) + + with self.assertRaisesMessage(ValueError, self.invalid_error % 'related_invalid_field'): + list(Species.objects.select_related('genus__related_invalid_field')) + + def test_generic_relations(self): + with self.assertRaisesMessage(ValueError, self.invalid_error % 'tags'): + list(Bookmark.objects.select_related('tags')) + + with self.assertRaisesMessage(ValueError, self.invalid_error % 'content_object'): + list(TaggedItem.objects.select_related('content_object'))