diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 65f28672aa..9a1e3c286b 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -5,7 +5,7 @@ from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions from django.db.models import Field, IntegerField, Transform -from django.db.models.lookups import Exact +from django.db.models.lookups import Exact, In from django.utils import six from django.utils.translation import string_concat, ugettext_lazy as _ @@ -217,6 +217,15 @@ class ArrayLenTransform(Transform): ) % {'lhs': lhs}, params +@ArrayField.register_lookup +class ArrayInLookup(In): + def get_prep_lookup(self): + values = super(ArrayInLookup, self).get_prep_lookup() + # In.process_rhs() expects values to be hashable, so convert lists + # to tuples. + return [tuple(value) for value in values] + + class IndexTransform(Transform): def __init__(self, index, base_field, *args, **kwargs): diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 7b4862a46c..045f7d194f 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -204,12 +204,17 @@ class In(BuiltinLookup): def process_rhs(self, compiler, connection): if self.rhs_is_direct_value(): - # rhs should be an iterable, we use batch_process_rhs - # to prepare/transform those values - rhs = list(self.rhs) + try: + rhs = set(self.rhs) + except TypeError: # Unhashable items in self.rhs + rhs = self.rhs + if not rhs: from django.db.models.sql.datastructures import EmptyResultSet raise EmptyResultSet + + # rhs should be an iterable; use batch_process_rhs() to + # prepare/transform those values. sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs) placeholder = '(' + ', '.join(sqls) + ')' return (placeholder, sqls_params) diff --git a/tests/postgres_tests/test_hstore.py b/tests/postgres_tests/test_hstore.py index 68c54918a1..80e46e5021 100644 --- a/tests/postgres_tests/test_hstore.py +++ b/tests/postgres_tests/test_hstore.py @@ -67,6 +67,14 @@ class TestQuerying(PostgreSQLTestCase): self.objs[:2] ) + def test_in_generator(self): + def search(): + yield {'a': 'b'} + self.assertSequenceEqual( + HStoreModel.objects.filter(field__in=search()), + self.objs[:1] + ) + def test_has_key(self): self.assertSequenceEqual( HStoreModel.objects.filter(field__has_key='c'), diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py index 56bd64b5b9..01c64b0959 100644 --- a/tests/prefetch_related/tests.py +++ b/tests/prefetch_related/tests.py @@ -6,6 +6,7 @@ from django.db import connection from django.db.models import Prefetch from django.db.models.query import get_prefetcher from django.test import TestCase, override_settings +from django.test.utils import CaptureQueriesContext from django.utils import six from django.utils.encoding import force_text @@ -245,6 +246,27 @@ class PrefetchRelatedTests(TestCase): # save of reverse relation assignment. self.assertEqual(self.author1.books.count(), 2) + def test_m2m_then_reverse_fk_object_ids(self): + with CaptureQueriesContext(connection) as queries: + list(Book.objects.prefetch_related('authors__addresses')) + + sql = queries[-1]['sql'] + self.assertEqual(sql.count(self.author1.name), 1) + + def test_m2m_then_m2m_object_ids(self): + with CaptureQueriesContext(connection) as queries: + list(Book.objects.prefetch_related('authors__favorite_authors')) + + sql = queries[-1]['sql'] + self.assertEqual(sql.count(self.author1.name), 1) + + def test_m2m_then_reverse_one_to_one_object_ids(self): + with CaptureQueriesContext(connection) as queries: + list(Book.objects.prefetch_related('authors__authorwithage')) + + sql = queries[-1]['sql'] + self.assertEqual(sql.count(str(self.author1.id)), 1, sql) + class CustomPrefetchTests(TestCase): @classmethod