Fixed #21160 -- Fixed QuerySet.in_bulk() crash on SQLite when requesting more than 999 ids.

Thanks Andrei Picus and Anssi Kääriäinen for the initial patch
and Tim Graham for the review.
This commit is contained in:
Mariusz Felisiak 2017-03-27 18:43:40 +02:00 committed by GitHub
parent 899c42cc8e
commit 1b6f05e91f
2 changed files with 22 additions and 1 deletions

View File

@ -565,7 +565,17 @@ class QuerySet:
if id_list is not None: if id_list is not None:
if not id_list: if not id_list:
return {} return {}
qs = self.filter(pk__in=id_list).order_by() batch_size = connections[self.db].features.max_query_params
id_list = tuple(id_list)
# If the database has a limit on the number of query parameters
# (e.g. SQLite), retrieve objects in batches if necessary.
if batch_size and batch_size < len(id_list):
qs = ()
for offset in range(0, len(id_list), batch_size):
batch = id_list[offset:offset + batch_size]
qs += tuple(self.filter(pk__in=batch).order_by())
else:
qs = self.filter(pk__in=id_list).order_by()
else: else:
qs = self._clone() qs = self._clone()
return {obj._get_pk_val(): obj for obj in qs} return {obj._get_pk_val(): obj for obj in qs}

View File

@ -1,8 +1,10 @@
import collections import collections
from datetime import datetime from datetime import datetime
from math import ceil
from operator import attrgetter from operator import attrgetter
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from .models import Article, Author, Game, Player, Season, Tag from .models import Article, Author, Game, Player, Season, Tag
@ -127,6 +129,15 @@ class LookupTests(TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
Article.objects.in_bulk(headline__startswith='Blah') Article.objects.in_bulk(headline__startswith='Blah')
def test_in_bulk_lots_of_ids(self):
test_range = 2000
max_query_params = connection.features.max_query_params
expected_num_queries = ceil(test_range / max_query_params) if max_query_params else 1
Author.objects.bulk_create([Author() for i in range(test_range - Author.objects.count())])
authors = {author.pk: author for author in Author.objects.all()}
with self.assertNumQueries(expected_num_queries):
self.assertEqual(Author.objects.in_bulk(authors.keys()), authors)
def test_values(self): def test_values(self):
# values() returns a list of dictionaries instead of object instances -- # values() returns a list of dictionaries instead of object instances --
# and you can specify which fields you want to retrieve. # and you can specify which fields you want to retrieve.