diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 32085daf189..63ed2ff4c7f 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -858,7 +858,7 @@ class ForeignObject(RelatedField): @classmethod @functools.lru_cache(maxsize=None) - def get_lookups(cls): + def get_class_lookups(cls): bases = inspect.getmro(cls) bases = bases[: bases.index(ForeignObject) + 1] class_lookups = [parent.__dict__.get("class_lookups", {}) for parent in bases] diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index f4215ed48e2..5562303e006 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -188,19 +188,42 @@ class DeferredAttribute: return None -class RegisterLookupMixin: - @classmethod - def _get_lookup(cls, lookup_name): - return cls.get_lookups().get(lookup_name, None) +class class_or_instance_method: + """ + Hook used in RegisterLookupMixin to return partial functions depending on + the caller type (instance or class of models.Field). + """ + + def __init__(self, class_method, instance_method): + self.class_method = class_method + self.instance_method = instance_method + + def __get__(self, instance, owner): + if instance is None: + return functools.partial(self.class_method, owner) + return functools.partial(self.instance_method, instance) + + +class RegisterLookupMixin: + def _get_lookup(self, lookup_name): + return self.get_lookups().get(lookup_name, None) - @classmethod @functools.lru_cache(maxsize=None) - def get_lookups(cls): + def get_class_lookups(cls): class_lookups = [ parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls) ] return cls.merge_dicts(class_lookups) + def get_instance_lookups(self): + class_lookups = self.get_class_lookups() + if instance_lookups := getattr(self, "instance_lookups", None): + return {**class_lookups, **instance_lookups} + return class_lookups + + get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups) + get_class_lookups = classmethod(get_class_lookups) + def get_lookup(self, lookup_name): from django.db.models.lookups import Lookup @@ -233,22 +256,33 @@ class RegisterLookupMixin: return merged @classmethod - def _clear_cached_lookups(cls): + def _clear_cached_class_lookups(cls): for subclass in subclasses(cls): - subclass.get_lookups.cache_clear() + subclass.get_class_lookups.cache_clear() - @classmethod - def register_lookup(cls, lookup, lookup_name=None): + def register_class_lookup(cls, lookup, lookup_name=None): if lookup_name is None: lookup_name = lookup.lookup_name if "class_lookups" not in cls.__dict__: cls.class_lookups = {} cls.class_lookups[lookup_name] = lookup - cls._clear_cached_lookups() + cls._clear_cached_class_lookups() return lookup - @classmethod - def _unregister_lookup(cls, lookup, lookup_name=None): + def register_instance_lookup(self, lookup, lookup_name=None): + if lookup_name is None: + lookup_name = lookup.lookup_name + if "instance_lookups" not in self.__dict__: + self.instance_lookups = {} + self.instance_lookups[lookup_name] = lookup + return lookup + + register_lookup = class_or_instance_method( + register_class_lookup, register_instance_lookup + ) + register_class_lookup = classmethod(register_class_lookup) + + def _unregister_class_lookup(cls, lookup, lookup_name=None): """ Remove given lookup from cls lookups. For use in tests only as it's not thread-safe. @@ -256,7 +290,21 @@ class RegisterLookupMixin: if lookup_name is None: lookup_name = lookup.lookup_name del cls.class_lookups[lookup_name] - cls._clear_cached_lookups() + cls._clear_cached_class_lookups() + + def _unregister_instance_lookup(self, lookup, lookup_name=None): + """ + Remove given lookup from instance lookups. For use in tests only as + it's not thread-safe. + """ + if lookup_name is None: + lookup_name = lookup.lookup_name + del self.instance_lookups[lookup_name] + + _unregister_lookup = class_or_instance_method( + _unregister_class_lookup, _unregister_instance_lookup + ) + _unregister_class_lookup = classmethod(_unregister_class_lookup) def select_related_descend(field, restricted, requested, select_mask, reverse=False): diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index ea18d468de3..7f01ae3cfc4 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -520,8 +520,13 @@ Registering and fetching lookups ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ``Field`` implements the :ref:`lookup registration API `. -The API can be used to customize which lookups are available for a field class, and -how lookups are fetched from a field. +The API can be used to customize which lookups are available for a field class +and its instances, and how lookups are fetched from a field. + +.. versionchanged:: 4.2 + + Support for registering lookups on :class:`~django.db.models.Field` + instances was added. .. _model-field-types: diff --git a/docs/ref/models/lookups.txt b/docs/ref/models/lookups.txt index a03cede22fb..d71424c9061 100644 --- a/docs/ref/models/lookups.txt +++ b/docs/ref/models/lookups.txt @@ -34,7 +34,7 @@ Registration API ================ Django uses :class:`~lookups.RegisterLookupMixin` to give a class the interface to -register lookups on itself. The two prominent examples are +register lookups on itself or its instances. The two prominent examples are :class:`~django.db.models.Field`, the base class of all model fields, and :class:`Transform`, the base class of all Django transforms. @@ -44,35 +44,49 @@ register lookups on itself. The two prominent examples are .. classmethod:: register_lookup(lookup, lookup_name=None) - Registers a new lookup in the class. For example - ``DateField.register_lookup(YearExact)`` will register ``YearExact`` - lookup on ``DateField``. It overrides a lookup that already exists with - the same name. ``lookup_name`` will be used for this lookup if + Registers a new lookup in the class or class instance. For example:: + + DateField.register_lookup(YearExact) + User._meta.get_field('date_joined').register_lookup(MonthExact) + + will register ``YearExact`` lookup on ``DateField`` and ``MonthExact`` + lookup on the ``User.date_joined`` (you can use :ref:`Field Access API + ` to retrieve a single field instance). It + overrides a lookup that already exists with the same name. Lookups + registered on field instances take precedence over the lookups + registered on classes. ``lookup_name`` will be used for this lookup if provided, otherwise ``lookup.lookup_name`` will be used. .. method:: get_lookup(lookup_name) - Returns the :class:`Lookup` named ``lookup_name`` registered in the class. - The default implementation looks recursively on all parent classes - and checks if any has a registered lookup named ``lookup_name``, returning - the first match. + Returns the :class:`Lookup` named ``lookup_name`` registered in the + class or class instance depending on what calls it. The default + implementation looks recursively on all parent classes and checks if + any has a registered lookup named ``lookup_name``, returning the first + match. Instance lookups would override any class lookups with the same + ``lookup_name``. .. method:: get_lookups() - Returns a dictionary of each lookup name registered in the class mapped - to the :class:`Lookup` class. + Returns a dictionary of each lookup name registered in the class or + class instance mapped to the :class:`Lookup` class. .. method:: get_transform(transform_name) - Returns a :class:`Transform` named ``transform_name``. The default - implementation looks recursively on all parent classes to check if any - has the registered transform named ``transform_name``, returning the first - match. + Returns a :class:`Transform` named ``transform_name`` registered in the + class or class instance. The default implementation looks recursively + on all parent classes to check if any has the registered transform + named ``transform_name``, returning the first match. For a class to be a lookup, it must follow the :ref:`Query Expression API `. :class:`~Lookup` and :class:`~Transform` naturally follow this API. +.. versionchanged:: 4.2 + + Support for registering lookups on :class:`~django.db.models.Field` + instances was added. + .. _query-expression: The Query Expression API diff --git a/docs/releases/4.2.txt b/docs/releases/4.2.txt index 588d466a669..199abf4e198 100644 --- a/docs/releases/4.2.txt +++ b/docs/releases/4.2.txt @@ -204,6 +204,9 @@ Models * :meth:`~.QuerySet.prefetch_related` now supports :class:`~django.db.models.Prefetch` objects with sliced querysets. +* :ref:`Registering lookups ` on + :class:`~django.db.models.Field` instances is now supported. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/custom_lookups/models.py b/tests/custom_lookups/models.py index b1d809a5c91..17e0ce702b4 100644 --- a/tests/custom_lookups/models.py +++ b/tests/custom_lookups/models.py @@ -3,6 +3,7 @@ from django.db import models class Author(models.Model): name = models.CharField(max_length=20) + alias = models.CharField(max_length=20) age = models.IntegerField(null=True) birthdate = models.DateField(null=True) average_rating = models.FloatField(null=True) diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index f2610138a47..7e5802a9a9e 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -330,7 +330,7 @@ class LookupTests(TestCase): field = Article._meta.get_field("author") # clear and re-cache - field.get_lookups.cache_clear() + field.get_class_lookups.cache_clear() self.assertNotIn("exactly", field.get_lookups()) # registration should bust the cache @@ -670,6 +670,37 @@ class RegisterLookupTests(SimpleTestCase): self.assertEqual(author_name.get_lookup("sw"), CustomStartsWith) self.assertIsNone(author_name.get_lookup("sw")) + def test_instance_lookup(self): + author_name = Author._meta.get_field("name") + author_alias = Author._meta.get_field("alias") + with register_lookup(author_name, CustomStartsWith): + self.assertEqual(author_name.instance_lookups, {"sw": CustomStartsWith}) + self.assertEqual(author_name.get_lookup("sw"), CustomStartsWith) + self.assertIsNone(author_alias.get_lookup("sw")) + self.assertIsNone(author_name.get_lookup("sw")) + self.assertEqual(author_name.instance_lookups, {}) + self.assertIsNone(author_alias.get_lookup("sw")) + + def test_instance_lookup_override_class_lookups(self): + author_name = Author._meta.get_field("name") + author_alias = Author._meta.get_field("alias") + with register_lookup(models.CharField, CustomStartsWith, lookup_name="st_end"): + with register_lookup(author_alias, CustomEndsWith, lookup_name="st_end"): + self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith) + self.assertEqual(author_alias.get_lookup("st_end"), CustomEndsWith) + self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith) + self.assertEqual(author_alias.get_lookup("st_end"), CustomStartsWith) + self.assertIsNone(author_name.get_lookup("st_end")) + self.assertIsNone(author_alias.get_lookup("st_end")) + + def test_instance_lookup_override(self): + author_name = Author._meta.get_field("name") + with register_lookup(author_name, CustomStartsWith, lookup_name="st_end"): + self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith) + author_name.register_lookup(CustomEndsWith, lookup_name="st_end") + self.assertEqual(author_name.get_lookup("st_end"), CustomEndsWith) + self.assertIsNone(author_name.get_lookup("st_end")) + def test_lookup_on_transform(self): transform = Div3Transform with register_lookup(Div3Transform, CustomStartsWith): @@ -682,10 +713,16 @@ class RegisterLookupTests(SimpleTestCase): self.assertEqual(transform.get_lookups(), {}) def test_transform_on_field(self): - author_age = Author._meta.get_field("age") - with register_lookup(models.IntegerField, Div3Transform): - self.assertEqual(author_age.get_transform("div3"), Div3Transform) - self.assertIsNone(author_age.get_transform("div3")) + author_name = Author._meta.get_field("name") + author_alias = Author._meta.get_field("alias") + with register_lookup(models.CharField, Div3Transform): + self.assertEqual(author_alias.get_transform("div3"), Div3Transform) + self.assertEqual(author_name.get_transform("div3"), Div3Transform) + with register_lookup(author_alias, Div3Transform): + self.assertEqual(author_alias.get_transform("div3"), Div3Transform) + self.assertIsNone(author_name.get_transform("div3")) + self.assertIsNone(author_alias.get_transform("div3")) + self.assertIsNone(author_name.get_transform("div3")) def test_related_lookup(self): article_author = Article._meta.get_field("author") @@ -693,3 +730,9 @@ class RegisterLookupTests(SimpleTestCase): self.assertIsNone(article_author.get_lookup("sw")) with register_lookup(models.ForeignKey, RelatedMoreThan): self.assertEqual(article_author.get_lookup("rmt"), RelatedMoreThan) + + def test_instance_related_lookup(self): + article_author = Article._meta.get_field("author") + with register_lookup(article_author, RelatedMoreThan): + self.assertEqual(article_author.get_lookup("rmt"), RelatedMoreThan) + self.assertIsNone(article_author.get_lookup("rmt"))