Fixed #25544 -- Removed duplicate ids in prefetch_related() queries.

This commit is contained in:
Ian Foote 2015-11-07 16:06:06 +01:00 committed by Tim Graham
parent ed1bcf0515
commit 86eccdc8b6
4 changed files with 48 additions and 4 deletions

View File

@ -5,7 +5,7 @@ from django.contrib.postgres.forms import SimpleArrayField
from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.contrib.postgres.validators import ArrayMaxLengthValidator
from django.core import checks, exceptions from django.core import checks, exceptions
from django.db.models import Field, IntegerField, Transform from django.db.models import Field, IntegerField, Transform
from django.db.models.lookups import Exact from django.db.models.lookups import Exact, In
from django.utils import six from django.utils import six
from django.utils.translation import string_concat, ugettext_lazy as _ from django.utils.translation import string_concat, ugettext_lazy as _
@ -217,6 +217,15 @@ class ArrayLenTransform(Transform):
) % {'lhs': lhs}, params ) % {'lhs': lhs}, params
@ArrayField.register_lookup
class ArrayInLookup(In):
def get_prep_lookup(self):
values = super(ArrayInLookup, self).get_prep_lookup()
# In.process_rhs() expects values to be hashable, so convert lists
# to tuples.
return [tuple(value) for value in values]
class IndexTransform(Transform): class IndexTransform(Transform):
def __init__(self, index, base_field, *args, **kwargs): def __init__(self, index, base_field, *args, **kwargs):

View File

@ -204,12 +204,17 @@ class In(BuiltinLookup):
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
if self.rhs_is_direct_value(): if self.rhs_is_direct_value():
# rhs should be an iterable, we use batch_process_rhs try:
# to prepare/transform those values rhs = set(self.rhs)
rhs = list(self.rhs) except TypeError: # Unhashable items in self.rhs
rhs = self.rhs
if not rhs: if not rhs:
from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.datastructures import EmptyResultSet
raise EmptyResultSet raise EmptyResultSet
# rhs should be an iterable; use batch_process_rhs() to
# prepare/transform those values.
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs) sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
placeholder = '(' + ', '.join(sqls) + ')' placeholder = '(' + ', '.join(sqls) + ')'
return (placeholder, sqls_params) return (placeholder, sqls_params)

View File

@ -67,6 +67,14 @@ class TestQuerying(PostgreSQLTestCase):
self.objs[:2] self.objs[:2]
) )
def test_in_generator(self):
def search():
yield {'a': 'b'}
self.assertSequenceEqual(
HStoreModel.objects.filter(field__in=search()),
self.objs[:1]
)
def test_has_key(self): def test_has_key(self):
self.assertSequenceEqual( self.assertSequenceEqual(
HStoreModel.objects.filter(field__has_key='c'), HStoreModel.objects.filter(field__has_key='c'),

View File

@ -6,6 +6,7 @@ from django.db import connection
from django.db.models import Prefetch from django.db.models import Prefetch
from django.db.models.query import get_prefetcher from django.db.models.query import get_prefetcher
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.test.utils import CaptureQueriesContext
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
@ -245,6 +246,27 @@ class PrefetchRelatedTests(TestCase):
# save of reverse relation assignment. # save of reverse relation assignment.
self.assertEqual(self.author1.books.count(), 2) self.assertEqual(self.author1.books.count(), 2)
def test_m2m_then_reverse_fk_object_ids(self):
with CaptureQueriesContext(connection) as queries:
list(Book.objects.prefetch_related('authors__addresses'))
sql = queries[-1]['sql']
self.assertEqual(sql.count(self.author1.name), 1)
def test_m2m_then_m2m_object_ids(self):
with CaptureQueriesContext(connection) as queries:
list(Book.objects.prefetch_related('authors__favorite_authors'))
sql = queries[-1]['sql']
self.assertEqual(sql.count(self.author1.name), 1)
def test_m2m_then_reverse_one_to_one_object_ids(self):
with CaptureQueriesContext(connection) as queries:
list(Book.objects.prefetch_related('authors__authorwithage'))
sql = queries[-1]['sql']
self.assertEqual(sql.count(str(self.author1.id)), 1, sql)
class CustomPrefetchTests(TestCase): class CustomPrefetchTests(TestCase):
@classmethod @classmethod