Refs #23919 -- Replaced usage of django.utils.functional.curry() with functools.partial()/partialmethod().

This commit is contained in:
Sergey Fedoseev 2017-09-06 22:11:18 +05:00 committed by Tim Graham
parent 34f27f910b
commit 5b1c389603
8 changed files with 32 additions and 30 deletions

View File

@ -1,6 +1,7 @@
import copy import copy
import inspect import inspect
import warnings import warnings
from functools import partialmethod
from itertools import chain from itertools import chain
from django.apps import apps from django.apps import apps
@ -27,7 +28,6 @@ from django.db.models.signals import (
) )
from django.db.models.utils import make_model_tuple from django.db.models.utils import make_model_tuple
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.functional import curry
from django.utils.text import capfirst, get_text_list from django.utils.text import capfirst, get_text_list
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.utils.version import get_version from django.utils.version import get_version
@ -328,8 +328,8 @@ class ModelBase(type):
opts._prepare(cls) opts._prepare(cls)
if opts.order_with_respect_to: if opts.order_with_respect_to:
cls.get_next_in_order = curry(cls._get_next_or_previous_in_order, is_next=True) cls.get_next_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=True)
cls.get_previous_in_order = curry(cls._get_next_or_previous_in_order, is_next=False) cls.get_previous_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=False)
# Defer creating accessors on the foreign class until it has been # Defer creating accessors on the foreign class until it has been
# created and registered. If remote_field is None, we're ordering # created and registered. If remote_field is None, we're ordering
@ -1670,7 +1670,7 @@ class Model(metaclass=ModelBase):
# ORDERING METHODS ######################### # ORDERING METHODS #########################
def method_set_order(ordered_obj, self, id_list, using=None): def method_set_order(self, ordered_obj, id_list, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
order_wrt = ordered_obj._meta.order_with_respect_to order_wrt = ordered_obj._meta.order_with_respect_to
@ -1682,7 +1682,7 @@ def method_set_order(ordered_obj, self, id_list, using=None):
ordered_obj.objects.filter(pk=j, **filter_args).update(_order=i) ordered_obj.objects.filter(pk=j, **filter_args).update(_order=i)
def method_get_order(ordered_obj, self): def method_get_order(self, ordered_obj):
order_wrt = ordered_obj._meta.order_with_respect_to order_wrt = ordered_obj._meta.order_with_respect_to
filter_args = order_wrt.get_forward_related_filter(self) filter_args = order_wrt.get_forward_related_filter(self)
pk_name = ordered_obj._meta.pk.name pk_name = ordered_obj._meta.pk.name
@ -1693,12 +1693,12 @@ def make_foreign_order_accessors(model, related_model):
setattr( setattr(
related_model, related_model,
'get_%s_order' % model.__name__.lower(), 'get_%s_order' % model.__name__.lower(),
curry(method_get_order, model) partialmethod(method_get_order, model)
) )
setattr( setattr(
related_model, related_model,
'set_%s_order' % model.__name__.lower(), 'set_%s_order' % model.__name__.lower(),
curry(method_set_order, model) partialmethod(method_set_order, model)
) )
######## ########

View File

@ -6,7 +6,7 @@ import itertools
import uuid import uuid
import warnings import warnings
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from functools import total_ordering from functools import partialmethod, total_ordering
from django import forms from django import forms
from django.apps import apps from django.apps import apps
@ -26,7 +26,7 @@ from django.utils.dateparse import (
) )
from django.utils.duration import duration_string from django.utils.duration import duration_string
from django.utils.encoding import force_bytes, smart_text from django.utils.encoding import force_bytes, smart_text
from django.utils.functional import Promise, cached_property, curry from django.utils.functional import Promise, cached_property
from django.utils.ipv6 import clean_ipv6_address from django.utils.ipv6 import clean_ipv6_address
from django.utils.itercompat import is_iterable from django.utils.itercompat import is_iterable
from django.utils.text import capfirst from django.utils.text import capfirst
@ -717,7 +717,7 @@ class Field(RegisterLookupMixin):
setattr(cls, self.attname, DeferredAttribute(self.attname, cls)) setattr(cls, self.attname, DeferredAttribute(self.attname, cls))
if self.choices: if self.choices:
setattr(cls, 'get_%s_display' % self.name, setattr(cls, 'get_%s_display' % self.name,
curry(cls._get_FIELD_display, field=self)) partialmethod(cls._get_FIELD_display, field=self))
def get_filter_kwargs_for_object(self, obj): def get_filter_kwargs_for_object(self, obj):
""" """
@ -1254,11 +1254,11 @@ class DateField(DateTimeCheckMixin, Field):
if not self.null: if not self.null:
setattr( setattr(
cls, 'get_next_by_%s' % self.name, cls, 'get_next_by_%s' % self.name,
curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=True) partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=True)
) )
setattr( setattr(
cls, 'get_previous_by_%s' % self.name, cls, 'get_previous_by_%s' % self.name,
curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=False) partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=False)
) )
def get_prep_value(self, value): def get_prep_value(self, value):

View File

@ -12,7 +12,7 @@ from django.db.models.constants import LOOKUP_SEP
from django.db.models.deletion import CASCADE, SET_DEFAULT, SET_NULL from django.db.models.deletion import CASCADE, SET_DEFAULT, SET_NULL
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.db.models.utils import make_model_tuple from django.db.models.utils import make_model_tuple
from django.utils.functional import cached_property, curry from django.utils.functional import cached_property
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from . import Field from . import Field
@ -1567,7 +1567,7 @@ class ManyToManyField(RelatedField):
setattr(cls, self.name, ManyToManyDescriptor(self.remote_field, reverse=False)) setattr(cls, self.name, ManyToManyDescriptor(self.remote_field, reverse=False))
# Set up the accessor for the m2m table name for the relation. # Set up the accessor for the m2m table name for the relation.
self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) self.m2m_db_table = partial(self._get_m2m_db_table, cls._meta)
def contribute_to_related_class(self, cls, related): def contribute_to_related_class(self, cls, related):
# Internal M2Ms (i.e., those with a related name ending with '+') # Internal M2Ms (i.e., those with a related name ending with '+')
@ -1576,15 +1576,15 @@ class ManyToManyField(RelatedField):
setattr(cls, related.get_accessor_name(), ManyToManyDescriptor(self.remote_field, reverse=True)) setattr(cls, related.get_accessor_name(), ManyToManyDescriptor(self.remote_field, reverse=True))
# Set up the accessors for the column names on the m2m table. # Set up the accessors for the column names on the m2m table.
self.m2m_column_name = curry(self._get_m2m_attr, related, 'column') self.m2m_column_name = partial(self._get_m2m_attr, related, 'column')
self.m2m_reverse_name = curry(self._get_m2m_reverse_attr, related, 'column') self.m2m_reverse_name = partial(self._get_m2m_reverse_attr, related, 'column')
self.m2m_field_name = curry(self._get_m2m_attr, related, 'name') self.m2m_field_name = partial(self._get_m2m_attr, related, 'name')
self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name') self.m2m_reverse_field_name = partial(self._get_m2m_reverse_attr, related, 'name')
get_m2m_rel = curry(self._get_m2m_attr, related, 'remote_field') get_m2m_rel = partial(self._get_m2m_attr, related, 'remote_field')
self.m2m_target_field_name = lambda: get_m2m_rel().field_name self.m2m_target_field_name = lambda: get_m2m_rel().field_name
get_m2m_reverse_rel = curry(self._get_m2m_reverse_attr, related, 'remote_field') get_m2m_reverse_rel = partial(self._get_m2m_reverse_attr, related, 'remote_field')
self.m2m_reverse_target_field_name = lambda: get_m2m_reverse_rel().field_name self.m2m_reverse_target_field_name = lambda: get_m2m_reverse_rel().field_name
def set_attributes_from_rel(self): def set_attributes_from_rel(self):

View File

@ -4,6 +4,7 @@ import os
import re import re
import sys import sys
from copy import copy from copy import copy
from functools import partial
from importlib import import_module from importlib import import_module
from io import BytesIO from io import BytesIO
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
@ -21,7 +22,7 @@ from django.test import signals
from django.test.utils import ContextList from django.test.utils import ContextList
from django.urls import resolve from django.urls import resolve
from django.utils.encoding import force_bytes from django.utils.encoding import force_bytes
from django.utils.functional import SimpleLazyObject, curry from django.utils.functional import SimpleLazyObject
from django.utils.http import urlencode from django.utils.http import urlencode
from django.utils.itercompat import is_iterable from django.utils.itercompat import is_iterable
@ -455,7 +456,7 @@ class Client(RequestFactory):
# Curry a data dictionary into an instance of the template renderer # Curry a data dictionary into an instance of the template renderer
# callback function. # callback function.
data = {} data = {}
on_template_render = curry(store_rendered_templates, data) on_template_render = partial(store_rendered_templates, data)
signal_uid = "template-render-%s" % id(request) signal_uid = "template-render-%s" % id(request)
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid) signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
# Capture exceptions created by the handler. # Capture exceptions created by the handler.
@ -491,7 +492,7 @@ class Client(RequestFactory):
response.templates = data.get("templates", []) response.templates = data.get("templates", [])
response.context = data.get("context") response.context = data.get("context")
response.json = curry(self._parse_json, response) response.json = partial(self._parse_json, response)
# Attach the ResolverMatch instance to the response # Attach the ResolverMatch instance to the response
response.resolver_match = SimpleLazyObject(lambda: resolve(request['PATH_INFO'])) response.resolver_match = SimpleLazyObject(lambda: resolve(request['PATH_INFO']))

View File

@ -1,9 +1,10 @@
from functools import partial
from django.db import models from django.db import models
from django.db.models.fields.related import ( from django.db.models.fields.related import (
RECURSIVE_RELATIONSHIP_CONSTANT, ManyToManyDescriptor, ManyToManyField, RECURSIVE_RELATIONSHIP_CONSTANT, ManyToManyDescriptor, ManyToManyField,
ManyToManyRel, RelatedField, create_many_to_many_intermediary_model, ManyToManyRel, RelatedField, create_many_to_many_intermediary_model,
) )
from django.utils.functional import curry
class CustomManyToManyField(RelatedField): class CustomManyToManyField(RelatedField):
@ -43,7 +44,7 @@ class CustomManyToManyField(RelatedField):
if not self.remote_field.through and not cls._meta.abstract and not cls._meta.swapped: if not self.remote_field.through and not cls._meta.abstract and not cls._meta.swapped:
self.remote_field.through = create_many_to_many_intermediary_model(self, cls) self.remote_field.through = create_many_to_many_intermediary_model(self, cls)
setattr(cls, self.name, ManyToManyDescriptor(self.remote_field)) setattr(cls, self.name, ManyToManyDescriptor(self.remote_field))
self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta) self.m2m_db_table = partial(self._get_m2m_db_table, cls._meta)
def get_internal_type(self): def get_internal_type(self):
return 'ManyToManyField' return 'ManyToManyField'

View File

@ -390,7 +390,7 @@ class SerializerDataTests(TestCase):
pass pass
def serializerTest(format, self): def serializerTest(self, format):
# Create all the objects defined in the test data # Create all the objects defined in the test data
objects = [] objects = []

View File

@ -10,7 +10,7 @@ class NaturalKeySerializerTests(TestCase):
pass pass
def natural_key_serializer_test(format, self): def natural_key_serializer_test(self, format):
# Create all the objects defined in the test data # Create all the objects defined in the test data
with connection.constraint_checks_disabled(): with connection.constraint_checks_disabled():
objects = [ objects = [
@ -36,7 +36,7 @@ def natural_key_serializer_test(format, self):
) )
def natural_key_test(format, self): def natural_key_test(self, format):
book1 = { book1 = {
'data': '978-1590597255', 'data': '978-1590597255',
'title': 'The Definitive Guide to Django: Web Development Done Right', 'title': 'The Definitive Guide to Django: Web Development Done Right',

View File

@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from functools import partialmethod
from io import StringIO from io import StringIO
from unittest import mock from unittest import mock
@ -9,7 +10,6 @@ from django.db import connection, transaction
from django.http import HttpResponse from django.http import HttpResponse
from django.test import SimpleTestCase, override_settings, skipUnlessDBFeature from django.test import SimpleTestCase, override_settings, skipUnlessDBFeature
from django.test.utils import Approximate from django.test.utils import Approximate
from django.utils.functional import curry
from .models import ( from .models import (
Actor, Article, Author, AuthorProfile, BaseModel, Category, ComplexModel, Actor, Article, Author, AuthorProfile, BaseModel, Category, ComplexModel,
@ -405,4 +405,4 @@ def register_tests(test_class, method_name, test_func, exclude=None):
(exclude is None or f not in exclude)) (exclude is None or f not in exclude))
] ]
for format_ in formats: for format_ in formats:
setattr(test_class, method_name % format_, curry(test_func, format_)) setattr(test_class, method_name % format_, partialmethod(test_func, format_))