diff --git a/django/db/models/fields/related_lookups.py b/django/db/models/fields/related_lookups.py index 1588128ed1..d3e4f8fa8a 100644 --- a/django/db/models/fields/related_lookups.py +++ b/django/db/models/fields/related_lookups.py @@ -23,7 +23,10 @@ def get_normalized_value(value, lhs): from django.db.models import Model if isinstance(value, Model): value_list = [] - # Account for one-to-one relations when sent a different model + # A case like Restaurant.objects.filter(place=restaurant_instance), + # where place is a OneToOneField and the primary key of Restaurant. + if getattr(lhs.output_field, 'primary_key', False): + return (value.pk,) sources = lhs.output_field.get_path_info()[-1].target_fields for source in sources: while not isinstance(value, source.model) and source.remote_field: diff --git a/django/db/models/query.py b/django/db/models/query.py index 55a94cebe5..5813ec6688 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -19,7 +19,7 @@ from django.db.models.deletion import Collector from django.db.models.expressions import F, Date, DateTime from django.db.models.fields import AutoField from django.db.models.query_utils import ( - Q, InvalidQuery, deferred_class_factory, + Q, InvalidQuery, check_rel_lookup_compatibility, deferred_class_factory, ) from django.db.models.sql.constants import CURSOR from django.utils import six, timezone @@ -1141,16 +1141,19 @@ class QuerySet(object): """ return self.query.has_filters() - def is_compatible_query_object_type(self, opts): - model = self.model - return ( - # We trust that users of values() know what they are doing. - self._fields is not None or - # Otherwise check that models are compatible. - model == opts.concrete_model or - opts.concrete_model in model._meta.get_parent_list() or - model in opts.get_parent_list() - ) + def is_compatible_query_object_type(self, opts, field): + """ + Check that using this queryset as the rhs value for a lookup is + allowed. The opts are the options of the relation's target we are + querying against. For example in .filter(author__in=Author.objects.all()) + the opts would be Author's (from the author field) and self.model would + be Author.objects.all() queryset's .model (Author also). The field is + the related field on the lhs side. + """ + # We trust that users of values() know what they are doing. + if self._fields is not None: + return True + return check_rel_lookup_compatibility(self.model, opts, field) is_compatible_query_object_type.queryset_only = True diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 5d0b4e619d..4c2a4fe80b 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -277,3 +277,31 @@ def refs_expression(lookup_parts, annotations): if level_n_lookup in annotations and annotations[level_n_lookup]: return annotations[level_n_lookup], lookup_parts[n:] return False, () + + +def check_rel_lookup_compatibility(model, target_opts, field): + """ + Check that self.model is compatible with target_opts. Compatibility + is OK if: + 1) model and opts match (where proxy inheritance is removed) + 2) model is parent of opts' model or the other way around + """ + def check(opts): + return ( + model._meta.concrete_model == opts.concrete_model or + opts.concrete_model in model._meta.get_parent_list() or + model in opts.get_parent_list() + ) + # If the field is a primary key, then doing a query against the field's + # model is ok, too. Consider the case: + # class Restaurant(models.Model): + # place = OnetoOneField(Place, primary_key=True): + # Restaurant.objects.filter(pk__in=Restaurant.objects.all()). + # If we didn't have the primary key check, then pk__in (== place__in) would + # give Place's opts as the target opts, but Restaurant isn't compatible + # with that. This logic applies only to primary keys, as when doing __in=qs, + # we are going to turn this into __in=qs.values('pk') later on. + return ( + check(target_opts) or + (getattr(field, 'primary_key', False) and check(field.model._meta)) + ) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index e88dad536d..df654052fb 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -18,7 +18,9 @@ from django.db.models.aggregates import Count from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref from django.db.models.fields.related_lookups import MultiColSource -from django.db.models.query_utils import Q, PathInfo, refs_expression +from django.db.models.query_utils import ( + Q, PathInfo, check_rel_lookup_compatibility, refs_expression, +) from django.db.models.sql.constants import ( INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE, ) @@ -1040,15 +1042,13 @@ class Query(object): (lookup, self.get_meta().model.__name__)) return lookup_parts, field_parts, False - def check_query_object_type(self, value, opts): + def check_query_object_type(self, value, opts, field): """ Checks whether the object passed while querying is of the correct type. If not, it raises a ValueError specifying the wrong object. """ if hasattr(value, '_meta'): - if not (value._meta.concrete_model == opts.concrete_model - or opts.concrete_model in value._meta.get_parent_list() - or value._meta.concrete_model in opts.get_parent_list()): + if not check_rel_lookup_compatibility(value._meta.model, opts, field): raise ValueError( 'Cannot query "%s": Must be "%s" instance.' % (value, opts.object_name)) @@ -1061,16 +1061,16 @@ class Query(object): # QuerySets implement is_compatible_query_object_type() to # determine compatibility with the given field. if hasattr(value, 'is_compatible_query_object_type'): - if not value.is_compatible_query_object_type(opts): + if not value.is_compatible_query_object_type(opts, field): raise ValueError( 'Cannot use QuerySet for "%s": Use a QuerySet for "%s".' % (value.model._meta.model_name, opts.object_name) ) elif hasattr(value, '_meta'): - self.check_query_object_type(value, opts) + self.check_query_object_type(value, opts, field) elif hasattr(value, '__iter__'): for v in value: - self.check_query_object_type(v, opts) + self.check_query_object_type(v, opts, field) def build_lookup(self, lookups, lhs, rhs): """ diff --git a/tests/one_to_one/tests.py b/tests/one_to_one/tests.py index 6604299076..041fe6857c 100644 --- a/tests/one_to_one/tests.py +++ b/tests/one_to_one/tests.py @@ -479,3 +479,21 @@ class OneToOneTests(TestCase): Waiter.objects.update(restaurant=r2) w.refresh_from_db() self.assertEqual(w.restaurant, r2) + + def test_rel_pk_subquery(self): + r = Restaurant.objects.first() + q1 = Restaurant.objects.filter(place_id=r.pk) + # Test that subquery using primary key and a query against the + # same model works correctly. + q2 = Restaurant.objects.filter(place_id__in=q1) + self.assertQuerysetEqual(q2, [r], lambda x: x) + # Test that subquery using 'pk__in' instead of 'place_id__in' work, too. + q2 = Restaurant.objects.filter( + pk__in=Restaurant.objects.filter(place__id=r.place.pk) + ) + self.assertQuerysetEqual(q2, [r], lambda x: x) + + def test_rel_pk_exact(self): + r = Restaurant.objects.first() + r2 = Restaurant.objects.filter(pk__exact=r).first() + self.assertEqual(r, r2)