Made model fields comparable to other objects

Fixed #17851 -- Added __lt__ and @total_ordering to models.Field,
made sure these work correctly on other objects than Field, too.
This commit is contained in:
Simon Charette 2012-05-07 20:08:20 +03:00 committed by Anssi Kääriäinen
parent 1aae1cba99
commit 5cbfb48b92
3 changed files with 61 additions and 4 deletions

View File

@ -12,7 +12,7 @@ from django import forms
from django.core import exceptions, validators
from django.utils.datastructures import DictWrapper
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.functional import curry
from django.utils.functional import curry, total_ordering
from django.utils.text import capfirst
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _
@ -45,6 +45,7 @@ class FieldDoesNotExist(Exception):
#
# getattr(obj, opts.pk.attname)
@total_ordering
class Field(object):
"""Base class for all field types"""
@ -118,9 +119,17 @@ class Field(object):
messages.update(error_messages or {})
self.error_messages = messages
def __cmp__(self, other):
def __eq__(self, other):
# Needed for @total_ordering
if isinstance(other, Field):
return self.creation_counter == other.creation_counter
return NotImplemented
def __lt__(self, other):
# This is needed because bisect does not take a comparison function.
return cmp(self.creation_counter, other.creation_counter)
if isinstance(other, Field):
return self.creation_counter < other.creation_counter
return NotImplemented
def __deepcopy__(self, memodict):
# We don't have to deepcopy very much here, since most things are not

View File

@ -310,3 +310,35 @@ def partition(predicate, values):
for item in values:
results[predicate(item)].append(item)
return results
try:
from functools import total_ordering
except ImportError:
# For Python < 2.7
# Code borrowed from python 2.7.3 stdlib
def total_ordering(cls):
"""Class decorator that fills in missing ordering methods"""
convert = {
'__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
('__le__', lambda self, other: self < other or self == other),
('__ge__', lambda self, other: not self < other)],
'__le__': [('__ge__', lambda self, other: not self <= other or self == other),
('__lt__', lambda self, other: self <= other and not self == other),
('__gt__', lambda self, other: not self <= other)],
'__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
('__ge__', lambda self, other: self > other or self == other),
('__le__', lambda self, other: not self > other)],
'__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
('__gt__', lambda self, other: self >= other and not self == other),
('__lt__', lambda self, other: not self >= other)]
}
roots = set(dir(cls)) & set(convert)
if not roots:
raise ValueError('must define at least one ordering operation: < > <= >=')
root = max(roots) # prefer __lt__ to __le__ to __gt__ to __ge__
for opname, opfunc in convert[root]:
if opname not in roots:
opfunc.__name__ = opname
opfunc.__doc__ = getattr(int, opname).__doc__
setattr(cls, opname, opfunc)
return cls

View File

@ -3,7 +3,7 @@ from __future__ import absolute_import
from datetime import datetime
from django.core.exceptions import ObjectDoesNotExist
from django.db.models.fields import FieldDoesNotExist
from django.db.models.fields import Field, FieldDoesNotExist
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.utils.translation import ugettext_lazy
@ -520,6 +520,22 @@ class ModelTest(TestCase):
s = set([a10, a11, a12])
self.assertTrue(Article.objects.get(headline='Article 11') in s)
def test_field_ordering(self):
"""
Field instances have a `__lt__` comparison function to define an
ordering based on their creation. Prior to #17851 this ordering
comparison relied on the now unsupported `__cmp__` and was assuming
compared objects were both Field instances raising `AttributeError`
when it should have returned `NotImplemented`.
"""
f1 = Field()
f2 = Field(auto_created=True)
f3 = Field()
self.assertTrue(f2 < f1)
self.assertTrue(f3 > f1)
self.assertFalse(f1 == None)
self.assertFalse(f2 in (None, 1, ''))
def test_extra_method_select_argument_with_dashes_and_values(self):
# The 'select' argument to extra() supports names with dashes in
# them, as long as you use values().