Made model fields comparable to other objects
Fixed #17851 -- Added __lt__ and @total_ordering to models.Field, made sure these work correctly on other objects than Field, too.
This commit is contained in:
parent
1aae1cba99
commit
5cbfb48b92
|
@ -12,7 +12,7 @@ from django import forms
|
||||||
from django.core import exceptions, validators
|
from django.core import exceptions, validators
|
||||||
from django.utils.datastructures import DictWrapper
|
from django.utils.datastructures import DictWrapper
|
||||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||||
from django.utils.functional import curry
|
from django.utils.functional import curry, total_ordering
|
||||||
from django.utils.text import capfirst
|
from django.utils.text import capfirst
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
@ -45,6 +45,7 @@ class FieldDoesNotExist(Exception):
|
||||||
#
|
#
|
||||||
# getattr(obj, opts.pk.attname)
|
# getattr(obj, opts.pk.attname)
|
||||||
|
|
||||||
|
@total_ordering
|
||||||
class Field(object):
|
class Field(object):
|
||||||
"""Base class for all field types"""
|
"""Base class for all field types"""
|
||||||
|
|
||||||
|
@ -118,9 +119,17 @@ class Field(object):
|
||||||
messages.update(error_messages or {})
|
messages.update(error_messages or {})
|
||||||
self.error_messages = messages
|
self.error_messages = messages
|
||||||
|
|
||||||
def __cmp__(self, other):
|
def __eq__(self, other):
|
||||||
|
# Needed for @total_ordering
|
||||||
|
if isinstance(other, Field):
|
||||||
|
return self.creation_counter == other.creation_counter
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
# This is needed because bisect does not take a comparison function.
|
# This is needed because bisect does not take a comparison function.
|
||||||
return cmp(self.creation_counter, other.creation_counter)
|
if isinstance(other, Field):
|
||||||
|
return self.creation_counter < other.creation_counter
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def __deepcopy__(self, memodict):
|
def __deepcopy__(self, memodict):
|
||||||
# We don't have to deepcopy very much here, since most things are not
|
# We don't have to deepcopy very much here, since most things are not
|
||||||
|
|
|
@ -310,3 +310,35 @@ def partition(predicate, values):
|
||||||
for item in values:
|
for item in values:
|
||||||
results[predicate(item)].append(item)
|
results[predicate(item)].append(item)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
try:
|
||||||
|
from functools import total_ordering
|
||||||
|
except ImportError:
|
||||||
|
# For Python < 2.7
|
||||||
|
# Code borrowed from python 2.7.3 stdlib
|
||||||
|
def total_ordering(cls):
|
||||||
|
"""Class decorator that fills in missing ordering methods"""
|
||||||
|
convert = {
|
||||||
|
'__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
|
||||||
|
('__le__', lambda self, other: self < other or self == other),
|
||||||
|
('__ge__', lambda self, other: not self < other)],
|
||||||
|
'__le__': [('__ge__', lambda self, other: not self <= other or self == other),
|
||||||
|
('__lt__', lambda self, other: self <= other and not self == other),
|
||||||
|
('__gt__', lambda self, other: not self <= other)],
|
||||||
|
'__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
|
||||||
|
('__ge__', lambda self, other: self > other or self == other),
|
||||||
|
('__le__', lambda self, other: not self > other)],
|
||||||
|
'__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
|
||||||
|
('__gt__', lambda self, other: self >= other and not self == other),
|
||||||
|
('__lt__', lambda self, other: not self >= other)]
|
||||||
|
}
|
||||||
|
roots = set(dir(cls)) & set(convert)
|
||||||
|
if not roots:
|
||||||
|
raise ValueError('must define at least one ordering operation: < > <= >=')
|
||||||
|
root = max(roots) # prefer __lt__ to __le__ to __gt__ to __ge__
|
||||||
|
for opname, opfunc in convert[root]:
|
||||||
|
if opname not in roots:
|
||||||
|
opfunc.__name__ = opname
|
||||||
|
opfunc.__doc__ = getattr(int, opname).__doc__
|
||||||
|
setattr(cls, opname, opfunc)
|
||||||
|
return cls
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import absolute_import
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
from django.db.models.fields import FieldDoesNotExist
|
from django.db.models.fields import Field, FieldDoesNotExist
|
||||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||||
from django.utils.translation import ugettext_lazy
|
from django.utils.translation import ugettext_lazy
|
||||||
|
|
||||||
|
@ -520,6 +520,22 @@ class ModelTest(TestCase):
|
||||||
s = set([a10, a11, a12])
|
s = set([a10, a11, a12])
|
||||||
self.assertTrue(Article.objects.get(headline='Article 11') in s)
|
self.assertTrue(Article.objects.get(headline='Article 11') in s)
|
||||||
|
|
||||||
|
def test_field_ordering(self):
|
||||||
|
"""
|
||||||
|
Field instances have a `__lt__` comparison function to define an
|
||||||
|
ordering based on their creation. Prior to #17851 this ordering
|
||||||
|
comparison relied on the now unsupported `__cmp__` and was assuming
|
||||||
|
compared objects were both Field instances raising `AttributeError`
|
||||||
|
when it should have returned `NotImplemented`.
|
||||||
|
"""
|
||||||
|
f1 = Field()
|
||||||
|
f2 = Field(auto_created=True)
|
||||||
|
f3 = Field()
|
||||||
|
self.assertTrue(f2 < f1)
|
||||||
|
self.assertTrue(f3 > f1)
|
||||||
|
self.assertFalse(f1 == None)
|
||||||
|
self.assertFalse(f2 in (None, 1, ''))
|
||||||
|
|
||||||
def test_extra_method_select_argument_with_dashes_and_values(self):
|
def test_extra_method_select_argument_with_dashes_and_values(self):
|
||||||
# The 'select' argument to extra() supports names with dashes in
|
# The 'select' argument to extra() supports names with dashes in
|
||||||
# them, as long as you use values().
|
# them, as long as you use values().
|
||||||
|
|
Loading…
Reference in New Issue