Fixed #33651 -- Added support for prefetching GenericForeignKey.

Co-authored-by: revanthgss <revanthgss@almabase.com>
Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
Clément Escolano 2023-08-01 23:31:40 +02:00 committed by Mariusz Felisiak
parent 190874eadd
commit cac94dd8aa
15 changed files with 473 additions and 42 deletions

View File

@ -1,5 +1,6 @@
import functools import functools
import itertools import itertools
import warnings
from collections import defaultdict from collections import defaultdict
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
@ -19,6 +20,7 @@ from django.db.models.query_utils import PathInfo
from django.db.models.sql import AND from django.db.models.sql import AND
from django.db.models.sql.where import WhereNode from django.db.models.sql.where import WhereNode
from django.db.models.utils import AltersData from django.db.models.utils import AltersData
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -163,20 +165,44 @@ class GenericForeignKey(FieldCacheMixin):
def get_cache_name(self): def get_cache_name(self):
return self.name return self.name
def get_content_type(self, obj=None, id=None, using=None): def get_content_type(self, obj=None, id=None, using=None, model=None):
if obj is not None: if obj is not None:
return ContentType.objects.db_manager(obj._state.db).get_for_model( return ContentType.objects.db_manager(obj._state.db).get_for_model(
obj, for_concrete_model=self.for_concrete_model obj, for_concrete_model=self.for_concrete_model
) )
elif id is not None: elif id is not None:
return ContentType.objects.db_manager(using).get_for_id(id) return ContentType.objects.db_manager(using).get_for_id(id)
elif model is not None:
return ContentType.objects.db_manager(using).get_for_model(
model, for_concrete_model=self.for_concrete_model
)
else: else:
# This should never happen. I love comments like this, don't you? # This should never happen. I love comments like this, don't you?
raise Exception("Impossible arguments to GFK.get_content_type!") raise Exception("Impossible arguments to GFK.get_content_type!")
def get_prefetch_queryset(self, instances, queryset=None): def get_prefetch_queryset(self, instances, queryset=None):
if queryset is not None: warnings.warn(
raise ValueError("Custom queryset can't be used for this lookup.") "get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead.",
RemovedInDjango60Warning,
stacklevel=2,
)
if queryset is None:
return self.get_prefetch_querysets(instances)
return self.get_prefetch_querysets(instances, [queryset])
def get_prefetch_querysets(self, instances, querysets=None):
custom_queryset_dict = {}
if querysets is not None:
for queryset in querysets:
ct_id = self.get_content_type(
model=queryset.query.model, using=queryset.db
).pk
if ct_id in custom_queryset_dict:
raise ValueError(
"Only one queryset is allowed for each content type."
)
custom_queryset_dict[ct_id] = queryset
# For efficiency, group the instances by content type and then do one # For efficiency, group the instances by content type and then do one
# query per model # query per model
@ -195,9 +221,13 @@ class GenericForeignKey(FieldCacheMixin):
ret_val = [] ret_val = []
for ct_id, fkeys in fk_dict.items(): for ct_id, fkeys in fk_dict.items():
instance = instance_dict[ct_id] if ct_id in custom_queryset_dict:
ct = self.get_content_type(id=ct_id, using=instance._state.db) # Return values from the custom queryset, if provided.
ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys)) ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
else:
instance = instance_dict[ct_id]
ct = self.get_content_type(id=ct_id, using=instance._state.db)
ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
# For doing the join in Python, we have to match both the FK val and the # For doing the join in Python, we have to match both the FK val and the
# content type, so we use a callable that returns a (fk, class) pair. # content type, so we use a callable that returns a (fk, class) pair.
@ -616,9 +646,23 @@ def create_generic_related_manager(superclass, rel):
return self._apply_rel_filters(queryset) return self._apply_rel_filters(queryset)
def get_prefetch_queryset(self, instances, queryset=None): def get_prefetch_queryset(self, instances, queryset=None):
warnings.warn(
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead.",
RemovedInDjango60Warning,
stacklevel=2,
)
if queryset is None: if queryset is None:
queryset = super().get_queryset() return self.get_prefetch_querysets(instances)
return self.get_prefetch_querysets(instances, [queryset])
def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0] if querysets else super().get_queryset()
queryset._add_hints(instance=instances[0]) queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db) queryset = queryset.using(queryset._db or self._db)
# Group instances by content types. # Group instances by content types.

View File

@ -0,0 +1,36 @@
from django.db.models import Prefetch
from django.db.models.query import ModelIterable, RawQuerySet
class GenericPrefetch(Prefetch):
def __init__(self, lookup, querysets=None, to_attr=None):
for queryset in querysets:
if queryset is not None and (
isinstance(queryset, RawQuerySet)
or (
hasattr(queryset, "_iterable_class")
and not issubclass(queryset._iterable_class, ModelIterable)
)
):
raise ValueError(
"Prefetch querysets cannot use raw(), values(), and values_list()."
)
self.querysets = querysets
super().__init__(lookup, to_attr=to_attr)
def __getstate__(self):
obj_dict = self.__dict__.copy()
obj_dict["querysets"] = []
for queryset in self.querysets:
if queryset is not None:
queryset = queryset._chain()
# Prevent the QuerySet from being evaluated
queryset._result_cache = []
queryset._prefetch_done = True
obj_dict["querysets"].append(queryset)
return obj_dict
def get_current_querysets(self, level):
if self.get_current_prefetch_to(level) == self.prefetch_to:
return self.querysets
return None

View File

@ -62,6 +62,7 @@ and two directions (forward and reverse) for a total of six combinations.
If you're looking for ``ForwardManyToManyDescriptor`` or If you're looking for ``ForwardManyToManyDescriptor`` or
``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead. ``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead.
""" """
import warnings
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
@ -79,6 +80,7 @@ from django.db.models.lookups import GreaterThan, LessThanOrEqual
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.db.models.query_utils import DeferredAttribute from django.db.models.query_utils import DeferredAttribute
from django.db.models.utils import AltersData, resolve_callables from django.db.models.utils import AltersData, resolve_callables
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -153,8 +155,23 @@ class ForwardManyToOneDescriptor:
return self.field.remote_field.model._base_manager.db_manager(hints=hints).all() return self.field.remote_field.model._base_manager.db_manager(hints=hints).all()
def get_prefetch_queryset(self, instances, queryset=None): def get_prefetch_queryset(self, instances, queryset=None):
warnings.warn(
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead.",
RemovedInDjango60Warning,
stacklevel=2,
)
if queryset is None: if queryset is None:
queryset = self.get_queryset() return self.get_prefetch_querysets(instances)
return self.get_prefetch_querysets(instances, [queryset])
def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a length "
"of 1."
)
queryset = querysets[0] if querysets else self.get_queryset()
queryset._add_hints(instance=instances[0]) queryset._add_hints(instance=instances[0])
rel_obj_attr = self.field.get_foreign_related_value rel_obj_attr = self.field.get_foreign_related_value
@ -427,8 +444,23 @@ class ReverseOneToOneDescriptor:
return self.related.related_model._base_manager.db_manager(hints=hints).all() return self.related.related_model._base_manager.db_manager(hints=hints).all()
def get_prefetch_queryset(self, instances, queryset=None): def get_prefetch_queryset(self, instances, queryset=None):
warnings.warn(
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead.",
RemovedInDjango60Warning,
stacklevel=2,
)
if queryset is None: if queryset is None:
queryset = self.get_queryset() return self.get_prefetch_querysets(instances)
return self.get_prefetch_querysets(instances, [queryset])
def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a length "
"of 1."
)
queryset = querysets[0] if querysets else self.get_queryset()
queryset._add_hints(instance=instances[0]) queryset._add_hints(instance=instances[0])
rel_obj_attr = self.related.field.get_local_related_value rel_obj_attr = self.related.field.get_local_related_value
@ -728,9 +760,23 @@ def create_reverse_many_to_one_manager(superclass, rel):
return self._apply_rel_filters(queryset) return self._apply_rel_filters(queryset)
def get_prefetch_queryset(self, instances, queryset=None): def get_prefetch_queryset(self, instances, queryset=None):
warnings.warn(
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead.",
RemovedInDjango60Warning,
stacklevel=2,
)
if queryset is None: if queryset is None:
queryset = super().get_queryset() return self.get_prefetch_querysets(instances)
return self.get_prefetch_querysets(instances, [queryset])
def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0] if querysets else super().get_queryset()
queryset._add_hints(instance=instances[0]) queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db) queryset = queryset.using(queryset._db or self._db)
@ -1087,9 +1133,23 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
return self._apply_rel_filters(queryset) return self._apply_rel_filters(queryset)
def get_prefetch_queryset(self, instances, queryset=None): def get_prefetch_queryset(self, instances, queryset=None):
warnings.warn(
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead.",
RemovedInDjango60Warning,
stacklevel=2,
)
if queryset is None: if queryset is None:
queryset = super().get_queryset() return self.get_prefetch_querysets(instances)
return self.get_prefetch_querysets(instances, [queryset])
def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a "
"length of 1."
)
queryset = querysets[0] if querysets else super().get_queryset()
queryset._add_hints(instance=instances[0]) queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db) queryset = queryset.using(queryset._db or self._db)
queryset = _filter_prefetch_queryset( queryset = _filter_prefetch_queryset(

View File

@ -33,6 +33,7 @@ from django.db.models.utils import (
resolve_callables, resolve_callables,
) )
from django.utils import timezone from django.utils import timezone
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.functional import cached_property, partition from django.utils.functional import cached_property, partition
# The maximum number of results to fetch in a get() query. # The maximum number of results to fetch in a get() query.
@ -2236,8 +2237,21 @@ class Prefetch:
return to_attr, as_attr return to_attr, as_attr
def get_current_queryset(self, level): def get_current_queryset(self, level):
if self.get_current_prefetch_to(level) == self.prefetch_to: warnings.warn(
return self.queryset "Prefetch.get_current_queryset() is deprecated. Use "
"get_current_querysets() instead.",
RemovedInDjango60Warning,
stacklevel=2,
)
querysets = self.get_current_querysets(level)
return querysets[0] if querysets is not None else None
def get_current_querysets(self, level):
if (
self.get_current_prefetch_to(level) == self.prefetch_to
and self.queryset is not None
):
return [self.queryset]
return None return None
def __eq__(self, other): def __eq__(self, other):
@ -2425,9 +2439,9 @@ async def aprefetch_related_objects(model_instances, *related_lookups):
def get_prefetcher(instance, through_attr, to_attr): def get_prefetcher(instance, through_attr, to_attr):
""" """
For the attribute 'through_attr' on the given instance, find For the attribute 'through_attr' on the given instance, find
an object that has a get_prefetch_queryset(). an object that has a get_prefetch_querysets().
Return a 4 tuple containing: Return a 4 tuple containing:
(the object with get_prefetch_queryset (or None), (the object with get_prefetch_querysets (or None),
the descriptor object representing this relationship (or None), the descriptor object representing this relationship (or None),
a boolean that is False if the attribute was not found at all, a boolean that is False if the attribute was not found at all,
a function that takes an instance and returns a boolean that is True if a function that takes an instance and returns a boolean that is True if
@ -2462,8 +2476,12 @@ def get_prefetcher(instance, through_attr, to_attr):
attr_found = True attr_found = True
if rel_obj_descriptor: if rel_obj_descriptor:
# singly related object, descriptor object has the # singly related object, descriptor object has the
# get_prefetch_queryset() method. # get_prefetch_querysets() method.
if hasattr(rel_obj_descriptor, "get_prefetch_queryset"): if (
hasattr(rel_obj_descriptor, "get_prefetch_querysets")
# RemovedInDjango60Warning.
or hasattr(rel_obj_descriptor, "get_prefetch_queryset")
):
prefetcher = rel_obj_descriptor prefetcher = rel_obj_descriptor
# If to_attr is set, check if the value has already been set, # If to_attr is set, check if the value has already been set,
# which is done with has_to_attr_attribute(). Do not use the # which is done with has_to_attr_attribute(). Do not use the
@ -2476,7 +2494,11 @@ def get_prefetcher(instance, through_attr, to_attr):
# the attribute on the instance rather than the class to # the attribute on the instance rather than the class to
# support many related managers # support many related managers
rel_obj = getattr(instance, through_attr) rel_obj = getattr(instance, through_attr)
if hasattr(rel_obj, "get_prefetch_queryset"): if (
hasattr(rel_obj, "get_prefetch_querysets")
# RemovedInDjango60Warning.
or hasattr(rel_obj, "get_prefetch_queryset")
):
prefetcher = rel_obj prefetcher = rel_obj
if through_attr == to_attr: if through_attr == to_attr:
@ -2497,7 +2519,7 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
Return the prefetched objects along with any additional prefetches that Return the prefetched objects along with any additional prefetches that
must be done due to prefetch_related lookups found from default managers. must be done due to prefetch_related lookups found from default managers.
""" """
# prefetcher must have a method get_prefetch_queryset() which takes a list # prefetcher must have a method get_prefetch_querysets() which takes a list
# of instances, and returns a tuple: # of instances, and returns a tuple:
# (queryset of instances of self.model that are related to passed in instances, # (queryset of instances of self.model that are related to passed in instances,
@ -2510,14 +2532,34 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
# The 'values to be matched' must be hashable as they will be used # The 'values to be matched' must be hashable as they will be used
# in a dictionary. # in a dictionary.
( if hasattr(prefetcher, "get_prefetch_querysets"):
rel_qs, (
rel_obj_attr, rel_qs,
instance_attr, rel_obj_attr,
single, instance_attr,
cache_name, single,
is_descriptor, cache_name,
) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level)) is_descriptor,
) = prefetcher.get_prefetch_querysets(
instances, lookup.get_current_querysets(level)
)
else:
warnings.warn(
"The usage of get_prefetch_queryset() in prefetch_related_objects() is "
"deprecated. Implement get_prefetch_querysets() instead.",
RemovedInDjango60Warning,
stacklevel=2,
)
(
rel_qs,
rel_obj_attr,
instance_attr,
single,
cache_name,
is_descriptor,
) = prefetcher.get_prefetch_queryset(
instances, lookup.get_current_querysets(level)
)
# We have to handle the possibility that the QuerySet we just got back # We have to handle the possibility that the QuerySet we just got back
# contains some prefetch_related lookups. We don't want to trigger the # contains some prefetch_related lookups. We don't want to trigger the
# prefetch_related functionality by evaluating the query. Rather, we need # prefetch_related functionality by evaluating the query. Rather, we need

View File

@ -45,6 +45,14 @@ details on these changes.
* The ``ChoicesMeta`` alias to ``django.db.models.enums.ChoicesType`` will be * The ``ChoicesMeta`` alias to ``django.db.models.enums.ChoicesType`` will be
removed. removed.
* The ``Prefetch.get_current_queryset()`` method will be removed.
* The ``get_prefetch_queryset()`` method of related managers and descriptors
will be removed.
* ``get_prefetcher()`` and ``prefetch_related_objects()`` will no longer
fallback to ``get_prefetch_queryset()``.
.. _deprecation-removed-in-5.1: .. _deprecation-removed-in-5.1:
5.1 5.1

View File

@ -590,3 +590,29 @@ information.
Subclasses of :class:`GenericInlineModelAdmin` with stacked and tabular Subclasses of :class:`GenericInlineModelAdmin` with stacked and tabular
layouts, respectively. layouts, respectively.
.. module:: django.contrib.contenttypes.prefetch
``GenericPrefetch()``
---------------------
.. versionadded:: 5.0
.. class:: GenericPrefetch(lookup, querysets=None, to_attr=None)
This lookup is similar to ``Prefetch()`` and it should only be used on
``GenericForeignKey``. The ``querysets`` argument accepts a list of querysets,
each for a different ``ContentType``. This is useful for ``GenericForeignKey``
with non-homogeneous set of results.
.. code-block:: pycon
>>> bookmark = Bookmark.objects.create(url="https://www.djangoproject.com/")
>>> animal = Animal.objects.create(name="lion", weight=100)
>>> TaggedItem.objects.create(tag="great", content_object=bookmark)
>>> TaggedItem.objects.create(tag="awesome", content_object=animal)
>>> prefetch = GenericPrefetch(
... "content_object", [Bookmark.objects.all(), Animal.objects.only("name")]
... )
>>> TaggedItem.objects.prefetch_related(prefetch).all()
<QuerySet [<TaggedItem: Great>, <TaggedItem: Awesome>]>

View File

@ -1154,10 +1154,15 @@ many-to-many, many-to-one, and
cannot be done using ``select_related``, in addition to the foreign key and cannot be done using ``select_related``, in addition to the foreign key and
one-to-one relationships that are supported by ``select_related``. It also one-to-one relationships that are supported by ``select_related``. It also
supports prefetching of supports prefetching of
:class:`~django.contrib.contenttypes.fields.GenericForeignKey`, however, it :class:`~django.contrib.contenttypes.fields.GenericForeignKey`, however, the
must be restricted to a homogeneous set of results. For example, prefetching queryset for each ``ContentType`` must be provided in the ``querysets``
objects referenced by a ``GenericForeignKey`` is only supported if the query parameter of :class:`~django.contrib.contenttypes.prefetch.GenericPrefetch`.
is restricted to one ``ContentType``.
.. versionchanged:: 5.0
Support for prefetching
:class:`~django.contrib.contenttypes.fields.GenericForeignKey` with
non-homogeneous set of results was added.
For example, suppose you have these models:: For example, suppose you have these models::

View File

@ -243,7 +243,9 @@ Minor features
:mod:`django.contrib.contenttypes` :mod:`django.contrib.contenttypes`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* ... * :meth:`.QuerySet.prefetch_related` now supports prefetching
:class:`~django.contrib.contenttypes.fields.GenericForeignKey` with
non-homogeneous set of results.
:mod:`django.contrib.gis` :mod:`django.contrib.gis`
~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~
@ -711,6 +713,14 @@ Miscellaneous
* The ``django.db.models.enums.ChoicesMeta`` metaclass is renamed to * The ``django.db.models.enums.ChoicesMeta`` metaclass is renamed to
``ChoicesType``. ``ChoicesType``.
* The ``Prefetch.get_current_queryset()`` method is deprecated.
* The ``get_prefetch_queryset()`` method of related managers and descriptors
is deprecated. Starting with Django 6.0, ``get_prefetcher()`` and
``prefetch_related_objects()`` will no longer fallback to
``get_prefetch_queryset()``. Subclasses should implement
``get_prefetch_querysets()`` instead.
.. _`oracledb`: https://oracle.github.io/python-oracledb/ .. _`oracledb`: https://oracle.github.io/python-oracledb/
Features removed in 5.0 Features removed in 5.0

View File

@ -1,9 +1,11 @@
import json import json
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.prefetch import GenericPrefetch
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.test.utils import isolate_apps from django.test.utils import isolate_apps
from django.utils.deprecation import RemovedInDjango60Warning
from .models import Answer, Post, Question from .models import Answer, Post, Question
@ -22,14 +24,6 @@ class GenericForeignKeyTests(TestCase):
): ):
Answer.question.get_content_type() Answer.question.get_content_type()
def test_incorrect_get_prefetch_queryset_arguments(self):
with self.assertRaisesMessage(
ValueError, "Custom queryset can't be used for this lookup."
):
Answer.question.get_prefetch_queryset(
Answer.objects.all(), Answer.objects.all()
)
def test_get_object_cache_respects_deleted_objects(self): def test_get_object_cache_respects_deleted_objects(self):
question = Question.objects.create(text="Who?") question = Question.objects.create(text="Who?")
post = Post.objects.create(title="Answer", parent=question) post = Post.objects.create(title="Answer", parent=question)
@ -59,3 +53,55 @@ class GenericRelationTests(TestCase):
answer2 = Answer.objects.create(question=question) answer2 = Answer.objects.create(question=question)
result = json.loads(Question.answer_set.field.value_to_string(question)) result = json.loads(Question.answer_set.field.value_to_string(question))
self.assertCountEqual(result, [answer1.pk, answer2.pk]) self.assertCountEqual(result, [answer1.pk, answer2.pk])
class GetPrefetchQuerySetDeprecation(TestCase):
def test_generic_relation_warning(self):
Question.objects.create(text="test")
questions = Question.objects.all()
msg = (
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead."
)
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
questions[0].answer_set.get_prefetch_queryset(questions)
def test_generic_foreign_key_warning(self):
answers = Answer.objects.all()
msg = (
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead."
)
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
Answer.question.get_prefetch_queryset(answers)
class GetPrefetchQuerySetsTests(TestCase):
def test_duplicate_querysets(self):
question = Question.objects.create(text="What is your name?")
answer = Answer.objects.create(text="Joe", question=question)
answer = Answer.objects.get(pk=answer.pk)
msg = "Only one queryset is allowed for each content type."
with self.assertRaisesMessage(ValueError, msg):
models.prefetch_related_objects(
[answer],
GenericPrefetch(
"question",
[
Question.objects.all(),
Question.objects.filter(text__startswith="test"),
],
),
)
def test_generic_relation_invalid_length(self):
Question.objects.create(text="test")
questions = Question.objects.all()
msg = (
"querysets argument of get_prefetch_querysets() should have a length of 1."
)
with self.assertRaisesMessage(ValueError, msg):
questions[0].answer_set.get_prefetch_querysets(
instances=questions,
querysets=[Answer.objects.all(), Question.objects.all()],
)

View File

@ -1,5 +1,6 @@
from django.apps import apps from django.apps import apps
from django.contrib.contenttypes.models import ContentType, ContentTypeManager from django.contrib.contenttypes.models import ContentType, ContentTypeManager
from django.contrib.contenttypes.prefetch import GenericPrefetch
from django.db import models from django.db import models
from django.db.migrations.state import ProjectState from django.db.migrations.state import ProjectState
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
@ -328,3 +329,17 @@ class ContentTypesMultidbTests(TestCase):
1, using="other" 1, using="other"
): ):
ContentType.objects.get_for_model(Author) ContentType.objects.get_for_model(Author)
class GenericPrefetchTests(TestCase):
def test_values_queryset(self):
msg = "Prefetch querysets cannot use raw(), values(), and values_list()."
with self.assertRaisesMessage(ValueError, msg):
GenericPrefetch("question", [Author.objects.values("pk")])
with self.assertRaisesMessage(ValueError, msg):
GenericPrefetch("question", [Author.objects.values_list("pk")])
def test_raw_queryset(self):
msg = "Prefetch querysets cannot use raw(), values(), and values_list()."
with self.assertRaisesMessage(ValueError, msg):
GenericPrefetch("question", [Author.objects.raw("select pk from author")])

View File

@ -1,6 +1,7 @@
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.prefetch import GenericPrefetch
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models import Q from django.db.models import Q, prefetch_related_objects
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
from .models import ( from .models import (
@ -747,6 +748,38 @@ class GenericRelationsTests(TestCase):
comparison.first_obj.comparisons.all(), [comparison] comparison.first_obj.comparisons.all(), [comparison]
) )
def test_generic_prefetch(self):
tagged_vegetable = TaggedItem.objects.create(
tag="great", content_object=self.bacon
)
tagged_animal = TaggedItem.objects.create(
tag="awesome", content_object=self.platypus
)
# Getting the instances again so that content object is deferred.
tagged_vegetable = TaggedItem.objects.get(pk=tagged_vegetable.pk)
tagged_animal = TaggedItem.objects.get(pk=tagged_animal.pk)
with self.assertNumQueries(2):
prefetch_related_objects(
[tagged_vegetable, tagged_animal],
GenericPrefetch(
"content_object",
[Vegetable.objects.all(), Animal.objects.only("common_name")],
),
)
with self.assertNumQueries(0):
self.assertEqual(tagged_vegetable.content_object.name, self.bacon.name)
with self.assertNumQueries(0):
self.assertEqual(
tagged_animal.content_object.common_name,
self.platypus.common_name,
)
with self.assertNumQueries(1):
self.assertEqual(
tagged_animal.content_object.latin_name,
self.platypus.latin_name,
)
class ProxyRelatedModelTest(TestCase): class ProxyRelatedModelTest(TestCase):
def test_default_behavior(self): def test_default_behavior(self):

View File

@ -2,6 +2,7 @@ from unittest import mock
from django.db import transaction from django.db import transaction
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.utils.deprecation import RemovedInDjango60Warning
from .models import Article, InheritedArticleA, InheritedArticleB, Publication, User from .models import Article, InheritedArticleA, InheritedArticleB, Publication, User
@ -561,3 +562,23 @@ class ManyToManyTests(TestCase):
self.assertEqual( self.assertEqual(
self.p3.article_set.exists(), self.p3.article_set.all().exists() self.p3.article_set.exists(), self.p3.article_set.all().exists()
) )
def test_get_prefetch_queryset_warning(self):
articles = Article.objects.all()
msg = (
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead."
)
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
self.a1.publications.get_prefetch_queryset(articles)
def test_get_prefetch_querysets_invalid_querysets_length(self):
articles = Article.objects.all()
msg = (
"querysets argument of get_prefetch_querysets() should have a length of 1."
)
with self.assertRaisesMessage(ValueError, msg):
self.a1.publications.get_prefetch_querysets(
instances=articles,
querysets=[Publication.objects.all(), Publication.objects.all()],
)

View File

@ -4,6 +4,7 @@ from copy import deepcopy
from django.core.exceptions import FieldError, MultipleObjectsReturned from django.core.exceptions import FieldError, MultipleObjectsReturned
from django.db import IntegrityError, models, transaction from django.db import IntegrityError, models, transaction
from django.test import TestCase from django.test import TestCase
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.translation import gettext_lazy from django.utils.translation import gettext_lazy
from .models import ( from .models import (
@ -877,3 +878,49 @@ class ManyToOneTests(TestCase):
usa.cities.remove(chicago.pk) usa.cities.remove(chicago.pk)
with self.assertRaisesMessage(TypeError, msg): with self.assertRaisesMessage(TypeError, msg):
usa.cities.set([chicago.pk]) usa.cities.set([chicago.pk])
def test_get_prefetch_queryset_warning(self):
City.objects.create(name="Chicago")
cities = City.objects.all()
msg = (
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead."
)
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
City.country.get_prefetch_queryset(cities)
def test_get_prefetch_queryset_reverse_warning(self):
usa = Country.objects.create(name="United States")
City.objects.create(name="Chicago")
countries = Country.objects.all()
msg = (
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead."
)
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
usa.cities.get_prefetch_queryset(countries)
def test_get_prefetch_querysets_invalid_querysets_length(self):
City.objects.create(name="Chicago")
cities = City.objects.all()
msg = (
"querysets argument of get_prefetch_querysets() should have a length of 1."
)
with self.assertRaisesMessage(ValueError, msg):
City.country.get_prefetch_querysets(
instances=cities,
querysets=[Country.objects.all(), Country.objects.all()],
)
def test_get_prefetch_querysets_reverse_invalid_querysets_length(self):
usa = Country.objects.create(name="United States")
City.objects.create(name="Chicago")
countries = Country.objects.all()
msg = (
"querysets argument of get_prefetch_querysets() should have a length of 1."
)
with self.assertRaisesMessage(ValueError, msg):
usa.cities.get_prefetch_querysets(
instances=countries,
querysets=[City.objects.all(), City.objects.all()],
)

View File

@ -1,5 +1,6 @@
from django.db import IntegrityError, connection, transaction from django.db import IntegrityError, connection, transaction
from django.test import TestCase from django.test import TestCase
from django.utils.deprecation import RemovedInDjango60Warning
from .models import ( from .models import (
Bar, Bar,
@ -606,3 +607,23 @@ class OneToOneTests(TestCase):
self.b1.place_id = self.p2.pk self.b1.place_id = self.p2.pk
self.b1.save() self.b1.save()
self.assertEqual(self.b1.place, self.p2) self.assertEqual(self.b1.place, self.p2)
def test_get_prefetch_queryset_warning(self):
places = Place.objects.all()
msg = (
"get_prefetch_queryset() is deprecated. Use get_prefetch_querysets() "
"instead."
)
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
Place.bar.get_prefetch_queryset(places)
def test_get_prefetch_querysets_invalid_querysets_length(self):
places = Place.objects.all()
msg = (
"querysets argument of get_prefetch_querysets() should have a length of 1."
)
with self.assertRaisesMessage(ValueError, msg):
Place.bar.get_prefetch_querysets(
instances=places,
querysets=[Bar.objects.all(), Bar.objects.all()],
)

View File

@ -13,6 +13,7 @@ from django.test import (
skipUnlessDBFeature, skipUnlessDBFeature,
) )
from django.test.utils import CaptureQueriesContext from django.test.utils import CaptureQueriesContext
from django.utils.deprecation import RemovedInDjango60Warning
from .models import ( from .models import (
Article, Article,
@ -1696,7 +1697,7 @@ class Ticket21760Tests(TestCase):
def test_bug(self): def test_bug(self):
prefetcher = get_prefetcher(self.rooms[0], "house", "house")[0] prefetcher = get_prefetcher(self.rooms[0], "house", "house")[0]
queryset = prefetcher.get_prefetch_queryset(list(Room.objects.all()))[0] queryset = prefetcher.get_prefetch_querysets(list(Room.objects.all()))[0]
self.assertNotIn(" JOIN ", str(queryset.query)) self.assertNotIn(" JOIN ", str(queryset.query))
@ -1994,3 +1995,19 @@ class PrefetchLimitTests(TestDataMixin, TestCase):
) )
with self.assertRaisesMessage(NotSupportedError, msg): with self.assertRaisesMessage(NotSupportedError, msg):
list(Book.objects.prefetch_related(Prefetch("authors", authors[1:]))) list(Book.objects.prefetch_related(Prefetch("authors", authors[1:])))
class GetCurrentQuerySetDeprecation(TestCase):
def test_get_current_queryset_warning(self):
msg = (
"Prefetch.get_current_queryset() is deprecated. Use "
"get_current_querysets() instead."
)
authors = Author.objects.all()
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
self.assertEqual(
Prefetch("authors", authors).get_current_queryset(1),
authors,
)
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
self.assertIsNone(Prefetch("authors").get_current_queryset(1))