From 5b1c389603a353625ae1603ba345147356336afb Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Wed, 6 Sep 2017 22:11:18 +0500 Subject: [PATCH] Refs #23919 -- Replaced usage of django.utils.functional.curry() with functools.partial()/partialmethod(). --- django/db/models/base.py | 14 +++++++------- django/db/models/fields/__init__.py | 10 +++++----- django/db/models/fields/related.py | 16 ++++++++-------- django/test/client.py | 7 ++++--- tests/schema/fields.py | 5 +++-- tests/serializers/test_data.py | 2 +- tests/serializers/test_natural.py | 4 ++-- tests/serializers/tests.py | 4 ++-- 8 files changed, 32 insertions(+), 30 deletions(-) diff --git a/django/db/models/base.py b/django/db/models/base.py index dd2ac1de8c..34e0d65980 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -1,6 +1,7 @@ import copy import inspect import warnings +from functools import partialmethod from itertools import chain from django.apps import apps @@ -27,7 +28,6 @@ from django.db.models.signals import ( ) from django.db.models.utils import make_model_tuple from django.utils.encoding import force_text -from django.utils.functional import curry from django.utils.text import capfirst, get_text_list from django.utils.translation import gettext_lazy as _ from django.utils.version import get_version @@ -328,8 +328,8 @@ class ModelBase(type): opts._prepare(cls) if opts.order_with_respect_to: - cls.get_next_in_order = curry(cls._get_next_or_previous_in_order, is_next=True) - cls.get_previous_in_order = curry(cls._get_next_or_previous_in_order, is_next=False) + cls.get_next_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=True) + cls.get_previous_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=False) # Defer creating accessors on the foreign class until it has been # created and registered. If remote_field is None, we're ordering @@ -1670,7 +1670,7 @@ class Model(metaclass=ModelBase): # ORDERING METHODS ######################### -def method_set_order(ordered_obj, self, id_list, using=None): +def method_set_order(self, ordered_obj, id_list, using=None): if using is None: using = DEFAULT_DB_ALIAS order_wrt = ordered_obj._meta.order_with_respect_to @@ -1682,7 +1682,7 @@ def method_set_order(ordered_obj, self, id_list, using=None): ordered_obj.objects.filter(pk=j, **filter_args).update(_order=i) -def method_get_order(ordered_obj, self): +def method_get_order(self, ordered_obj): order_wrt = ordered_obj._meta.order_with_respect_to filter_args = order_wrt.get_forward_related_filter(self) pk_name = ordered_obj._meta.pk.name @@ -1693,12 +1693,12 @@ def make_foreign_order_accessors(model, related_model): setattr( related_model, 'get_%s_order' % model.__name__.lower(), - curry(method_get_order, model) + partialmethod(method_get_order, model) ) setattr( related_model, 'set_%s_order' % model.__name__.lower(), - curry(method_set_order, model) + partialmethod(method_set_order, model) ) ######## diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index d4ccddc726..40801d0be2 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -6,7 +6,7 @@ import itertools import uuid import warnings from base64 import b64decode, b64encode -from functools import total_ordering +from functools import partialmethod, total_ordering from django import forms from django.apps import apps @@ -26,7 +26,7 @@ from django.utils.dateparse import ( ) from django.utils.duration import duration_string from django.utils.encoding import force_bytes, smart_text -from django.utils.functional import Promise, cached_property, curry +from django.utils.functional import Promise, cached_property from django.utils.ipv6 import clean_ipv6_address from django.utils.itercompat import is_iterable from django.utils.text import capfirst @@ -717,7 +717,7 @@ class Field(RegisterLookupMixin): setattr(cls, self.attname, DeferredAttribute(self.attname, cls)) if self.choices: setattr(cls, 'get_%s_display' % self.name, - curry(cls._get_FIELD_display, field=self)) + partialmethod(cls._get_FIELD_display, field=self)) def get_filter_kwargs_for_object(self, obj): """ @@ -1254,11 +1254,11 @@ class DateField(DateTimeCheckMixin, Field): if not self.null: setattr( cls, 'get_next_by_%s' % self.name, - curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=True) + partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=True) ) setattr( cls, 'get_previous_by_%s' % self.name, - curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=False) + partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=False) ) def get_prep_value(self, value): diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 0e0910277f..5cf540d385 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -12,7 +12,7 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.deletion import CASCADE, SET_DEFAULT, SET_NULL from django.db.models.query_utils import PathInfo from django.db.models.utils import make_model_tuple -from django.utils.functional import cached_property, curry +from django.utils.functional import cached_property from django.utils.translation import gettext_lazy as _ from . import Field @@ -1567,7 +1567,7 @@ class ManyToManyField(RelatedField): setattr(cls, self.name, ManyToManyDescriptor(self.remote_field, reverse=False)) # Set up the accessor for the m2m table name for the relation. - self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) + self.m2m_db_table = partial(self._get_m2m_db_table, cls._meta) def contribute_to_related_class(self, cls, related): # Internal M2Ms (i.e., those with a related name ending with '+') @@ -1576,15 +1576,15 @@ class ManyToManyField(RelatedField): setattr(cls, related.get_accessor_name(), ManyToManyDescriptor(self.remote_field, reverse=True)) # Set up the accessors for the column names on the m2m table. - self.m2m_column_name = curry(self._get_m2m_attr, related, 'column') - self.m2m_reverse_name = curry(self._get_m2m_reverse_attr, related, 'column') + self.m2m_column_name = partial(self._get_m2m_attr, related, 'column') + self.m2m_reverse_name = partial(self._get_m2m_reverse_attr, related, 'column') - self.m2m_field_name = curry(self._get_m2m_attr, related, 'name') - self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name') + self.m2m_field_name = partial(self._get_m2m_attr, related, 'name') + self.m2m_reverse_field_name = partial(self._get_m2m_reverse_attr, related, 'name') - get_m2m_rel = curry(self._get_m2m_attr, related, 'remote_field') + get_m2m_rel = partial(self._get_m2m_attr, related, 'remote_field') self.m2m_target_field_name = lambda: get_m2m_rel().field_name - get_m2m_reverse_rel = curry(self._get_m2m_reverse_attr, related, 'remote_field') + get_m2m_reverse_rel = partial(self._get_m2m_reverse_attr, related, 'remote_field') self.m2m_reverse_target_field_name = lambda: get_m2m_reverse_rel().field_name def set_attributes_from_rel(self): diff --git a/django/test/client.py b/django/test/client.py index f1bfa23eb6..d47cb087b8 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -4,6 +4,7 @@ import os import re import sys from copy import copy +from functools import partial from importlib import import_module from io import BytesIO from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit @@ -21,7 +22,7 @@ from django.test import signals from django.test.utils import ContextList from django.urls import resolve from django.utils.encoding import force_bytes -from django.utils.functional import SimpleLazyObject, curry +from django.utils.functional import SimpleLazyObject from django.utils.http import urlencode from django.utils.itercompat import is_iterable @@ -455,7 +456,7 @@ class Client(RequestFactory): # Curry a data dictionary into an instance of the template renderer # callback function. data = {} - on_template_render = curry(store_rendered_templates, data) + on_template_render = partial(store_rendered_templates, data) signal_uid = "template-render-%s" % id(request) signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid) # Capture exceptions created by the handler. @@ -491,7 +492,7 @@ class Client(RequestFactory): response.templates = data.get("templates", []) response.context = data.get("context") - response.json = curry(self._parse_json, response) + response.json = partial(self._parse_json, response) # Attach the ResolverMatch instance to the response response.resolver_match = SimpleLazyObject(lambda: resolve(request['PATH_INFO'])) diff --git a/tests/schema/fields.py b/tests/schema/fields.py index f03b9813b6..5f3244b767 100644 --- a/tests/schema/fields.py +++ b/tests/schema/fields.py @@ -1,9 +1,10 @@ +from functools import partial + from django.db import models from django.db.models.fields.related import ( RECURSIVE_RELATIONSHIP_CONSTANT, ManyToManyDescriptor, ManyToManyField, ManyToManyRel, RelatedField, create_many_to_many_intermediary_model, ) -from django.utils.functional import curry class CustomManyToManyField(RelatedField): @@ -43,7 +44,7 @@ class CustomManyToManyField(RelatedField): if not self.remote_field.through and not cls._meta.abstract and not cls._meta.swapped: self.remote_field.through = create_many_to_many_intermediary_model(self, cls) setattr(cls, self.name, ManyToManyDescriptor(self.remote_field)) - self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) + self.m2m_db_table = partial(self._get_m2m_db_table, cls._meta) def get_internal_type(self): return 'ManyToManyField' diff --git a/tests/serializers/test_data.py b/tests/serializers/test_data.py index b61dfe075f..62ce2bbfec 100644 --- a/tests/serializers/test_data.py +++ b/tests/serializers/test_data.py @@ -390,7 +390,7 @@ class SerializerDataTests(TestCase): pass -def serializerTest(format, self): +def serializerTest(self, format): # Create all the objects defined in the test data objects = [] diff --git a/tests/serializers/test_natural.py b/tests/serializers/test_natural.py index 99fc2bec9e..0c99e8e13f 100644 --- a/tests/serializers/test_natural.py +++ b/tests/serializers/test_natural.py @@ -10,7 +10,7 @@ class NaturalKeySerializerTests(TestCase): pass -def natural_key_serializer_test(format, self): +def natural_key_serializer_test(self, format): # Create all the objects defined in the test data with connection.constraint_checks_disabled(): objects = [ @@ -36,7 +36,7 @@ def natural_key_serializer_test(format, self): ) -def natural_key_test(format, self): +def natural_key_test(self, format): book1 = { 'data': '978-1590597255', 'title': 'The Definitive Guide to Django: Web Development Done Right', diff --git a/tests/serializers/tests.py b/tests/serializers/tests.py index 02184d735d..0ad95f7397 100644 --- a/tests/serializers/tests.py +++ b/tests/serializers/tests.py @@ -1,4 +1,5 @@ from datetime import datetime +from functools import partialmethod from io import StringIO from unittest import mock @@ -9,7 +10,6 @@ from django.db import connection, transaction from django.http import HttpResponse from django.test import SimpleTestCase, override_settings, skipUnlessDBFeature from django.test.utils import Approximate -from django.utils.functional import curry from .models import ( Actor, Article, Author, AuthorProfile, BaseModel, Category, ComplexModel, @@ -405,4 +405,4 @@ def register_tests(test_class, method_name, test_func, exclude=None): (exclude is None or f not in exclude)) ] for format_ in formats: - setattr(test_class, method_name % format_, curry(test_func, format_)) + setattr(test_class, method_name % format_, partialmethod(test_func, format_))