From d70c8907cd3b83cff2127b6a4af9cae438be8b24 Mon Sep 17 00:00:00 2001 From: Malcolm Tredinnick Date: Wed, 27 Aug 2008 05:22:33 +0000 Subject: [PATCH] Fixed #5937 -- When filtering on generic relations, restrict the target objects to those with the right content type. This isn't a complete solution to this class of problem, but it will do for 1.0, which only has generic relations as a multicolumn type. A more general multicolumn solution will be available after that release. git-svn-id: http://code.djangoproject.com/svn/django/trunk@8608 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/contrib/contenttypes/generic.py | 10 ++++++++ django/db/models/sql/query.py | 27 ++++++++++++++------ tests/modeltests/generic_relations/models.py | 23 ++++++++++++----- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py index 9c85af2f8f..9b6c15a22a 100644 --- a/django/contrib/contenttypes/generic.py +++ b/django/contrib/contenttypes/generic.py @@ -166,6 +166,16 @@ class GenericRelation(RelatedField, Field): # same db_type as well. return None + def extra_filters(self, pieces, pos): + """ + Return an extra filter to the queryset so that the results are filtered + on the appropriate content type. + """ + ContentType = get_model("contenttypes", "contenttype") + content_type = ContentType.objects.get_for_model(self.model) + prefix = "__".join(pieces[:pos + 1]) + return "%s__%s" % (prefix, self.content_type_field_name), content_type + class ReverseGenericRelatedObjectsDescriptor(object): """ This class provides the functionality that makes the related-object diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index fa6c4da506..1628387096 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -647,8 +647,8 @@ class Query(object): pieces = name.split(LOOKUP_SEP) if not alias: alias = self.get_initial_alias() - field, target, opts, joins, last = self.setup_joins(pieces, opts, - alias, False) + field, target, opts, joins, last, extra = self.setup_joins(pieces, + opts, alias, False) alias = joins[-1] col = target.column if not field.rel: @@ -1006,7 +1006,7 @@ class Query(object): used, next, restricted, new_nullable, dupe_set, avoid) def add_filter(self, filter_expr, connector=AND, negate=False, trim=False, - can_reuse=None): + can_reuse=None, process_extras=True): """ Add a single filter to the query. The 'filter_expr' is a pair: (filter_string, value). E.g. ('name__contains', 'fred') @@ -1026,6 +1026,10 @@ class Query(object): will be a set of table aliases that can be reused in this filter, even if we would otherwise force the creation of new aliases for a join (needed for nested Q-filters). The set is updated by this method. + + If 'process_extras' is set, any extra filters returned from the table + joining process will be processed. This parameter is set to False + during the processing of extra filters to avoid infinite recursion. """ arg, value = filter_expr parts = arg.split(LOOKUP_SEP) @@ -1053,8 +1057,8 @@ class Query(object): allow_many = trim or not negate try: - field, target, opts, join_list, last = self.setup_joins(parts, opts, - alias, True, allow_many, can_reuse=can_reuse) + field, target, opts, join_list, last, extra_filters = self.setup_joins( + parts, opts, alias, True, allow_many, can_reuse=can_reuse) except MultiJoin, e: self.split_exclude(filter_expr, LOOKUP_SEP.join(parts[:e.level])) return @@ -1152,6 +1156,10 @@ class Query(object): if can_reuse is not None: can_reuse.update(join_list) + if process_extras: + for filter in extra_filters: + self.add_filter(filter, negate=negate, can_reuse=can_reuse, + process_extras=False) def add_q(self, q_object, used_aliases=None): """ @@ -1207,6 +1215,7 @@ class Query(object): last = [0] dupe_set = set() exclusions = set() + extra_filters = [] for pos, name in enumerate(names): try: exclusions.add(int_alias) @@ -1262,6 +1271,8 @@ class Query(object): exclusions.update(self.dupe_avoidance.get((id(opts), dupe_col), ())) + if hasattr(field, 'extra_filters'): + extra_filters.append(field.extra_filters(names, pos)) if direct: if m2m: # Many-to-many field defined on the current model. @@ -1365,7 +1376,7 @@ class Query(object): if pos != len(names) - 1: raise FieldError("Join on field %r not permitted." % name) - return field, target, opts, joins, last + return field, target, opts, joins, last, extra_filters def update_dupe_avoidance(self, opts, col, alias): """ @@ -1437,7 +1448,7 @@ class Query(object): opts = self.get_meta() try: for name in field_names: - field, target, u2, joins, u3 = self.setup_joins( + field, target, u2, joins, u3, u4 = self.setup_joins( name.split(LOOKUP_SEP), opts, alias, False, allow_m2m, True) final_alias = joins[-1] @@ -1601,7 +1612,7 @@ class Query(object): """ opts = self.model._meta alias = self.get_initial_alias() - field, col, opts, joins, last = self.setup_joins( + field, col, opts, joins, last, extra = self.setup_joins( start.split(LOOKUP_SEP), opts, alias, False) alias = joins[last[-1]] self.select = [(alias, self.alias_map[alias][RHS_JOIN_COL])] diff --git a/tests/modeltests/generic_relations/models.py b/tests/modeltests/generic_relations/models.py index cbfb37dbad..a8e3577a2d 100644 --- a/tests/modeltests/generic_relations/models.py +++ b/tests/modeltests/generic_relations/models.py @@ -82,7 +82,7 @@ __test__ = {'API_TESTS':""" >>> eggplant = Vegetable(name="Eggplant", is_yucky=True) >>> bacon = Vegetable(name="Bacon", is_yucky=False) >>> quartz = Mineral(name="Quartz", hardness=7) ->>> for o in (lion, platypus, eggplant, bacon, quartz): +>>> for o in (platypus, lion, eggplant, bacon, quartz): ... o.save() # Objects with declared GenericRelations can be tagged directly -- the API @@ -95,6 +95,8 @@ __test__ = {'API_TESTS':""" >>> lion.tags.create(tag="hairy") +>>> platypus.tags.create(tag="fatty") + >>> lion.tags.all() [, ] @@ -124,25 +126,29 @@ __test__ = {'API_TESTS':""" >>> tag1.content_object = platypus >>> tag1.save() >>> platypus.tags.all() -[] +[, ] >>> TaggedItem.objects.filter(content_type__pk=ctype.id, object_id=quartz.id) [] +# Queries across generic relations respect the content types. Even though there are two TaggedItems with a tag of "fatty", this query only pulls out the one with the content type related to Animals. +>>> Animal.objects.filter(tags__tag='fatty') +[] + # If you delete an object with an explicit Generic relation, the related # objects are deleted when the source object is deleted. # Original list of tags: >>> [(t.tag, t.content_type, t.object_id) for t in TaggedItem.objects.all()] -[(u'clearish', , 1), (u'fatty', , 2), (u'hairy', , 1), (u'salty', , 2), (u'shiny', , 2), (u'yellow', , 1)] +[(u'clearish', , 1), (u'fatty', , 2), (u'fatty', , 1), (u'hairy', , 2), (u'salty', , 2), (u'shiny', , 1), (u'yellow', , 2)] >>> lion.delete() >>> [(t.tag, t.content_type, t.object_id) for t in TaggedItem.objects.all()] -[(u'clearish', , 1), (u'fatty', , 2), (u'salty', , 2), (u'shiny', , 2)] +[(u'clearish', , 1), (u'fatty', , 2), (u'fatty', , 1), (u'salty', , 2), (u'shiny', , 1)] # If Generic Relation is not explicitly defined, any related objects # remain after deletion of the source object. >>> quartz.delete() >>> [(t.tag, t.content_type, t.object_id) for t in TaggedItem.objects.all()] -[(u'clearish', , 1), (u'fatty', , 2), (u'salty', , 2), (u'shiny', , 2)] +[(u'clearish', , 1), (u'fatty', , 2), (u'fatty', , 1), (u'salty', , 2), (u'shiny', , 1)] # If you delete a tag, the objects using the tag are unaffected # (other than losing a tag) @@ -151,7 +157,9 @@ __test__ = {'API_TESTS':""" >>> bacon.tags.all() [] >>> [(t.tag, t.content_type, t.object_id) for t in TaggedItem.objects.all()] -[(u'clearish', , 1), (u'salty', , 2), (u'shiny', , 2)] +[(u'clearish', , 1), (u'fatty', , 1), (u'salty', , 2), (u'shiny', , 1)] + +>>> TaggedItem.objects.filter(tag='fatty').delete() >>> ctype = ContentType.objects.get_for_model(lion) >>> Animal.objects.filter(tags__content_type=ctype) @@ -192,6 +200,7 @@ __test__ = {'API_TESTS':""" >>> Comparison.objects.all() [] + # GenericInlineFormSet tests ################################################## >>> from django.contrib.contenttypes.generic import generic_inlineformset_factory @@ -207,7 +216,7 @@ __test__ = {'API_TESTS':""" >>> for form in formset.forms: ... print form.as_p()

-

+