Fixed #30651 -- Made __eq__() methods return NotImplemented for not implemented comparisons.

Changed __eq__ to return NotImplemented instead of False if compared to
an object of the same type, as is recommended by the Python data model
reference. Now these models can be compared to ANY (or other objects
with __eq__ overwritten) without returning False automatically.
This commit is contained in:
ElizabethU 2019-09-02 19:09:31 -07:00 committed by Mariusz Felisiak
parent 6475e6318c
commit 54ea290e5b
20 changed files with 71 additions and 33 deletions

View File

@ -25,8 +25,9 @@ class Message:
self.extra_tags = str(self.extra_tags) if self.extra_tags is not None else None self.extra_tags = str(self.extra_tags) if self.extra_tags is not None else None
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, Message) and self.level == other.level and \ if not isinstance(other, Message):
self.message == other.message return NotImplemented
return self.level == other.level and self.message == other.message
def __str__(self): def __str__(self):
return str(self.message) return str(self.message)

View File

@ -89,13 +89,14 @@ class ExclusionConstraint(BaseConstraint):
return path, args, kwargs return path, args, kwargs
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__):
return ( return (
isinstance(other, self.__class__) and
self.name == other.name and self.name == other.name and
self.index_type == other.index_type and self.index_type == other.index_type and
self.expressions == other.expressions and self.expressions == other.expressions and
self.condition == other.condition self.condition == other.condition
) )
return super().__eq__(other)
def __repr__(self): def __repr__(self):
return '<%s: index_type=%s, expressions=%s%s>' % ( return '<%s: index_type=%s, expressions=%s%s>' % (

View File

@ -324,8 +324,9 @@ class BaseValidator:
raise ValidationError(self.message, code=self.code, params=params) raise ValidationError(self.message, code=self.code, params=params)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return ( return (
isinstance(other, self.__class__) and
self.limit_value == other.limit_value and self.limit_value == other.limit_value and
self.message == other.message and self.message == other.message and
self.code == other.code self.code == other.code

View File

@ -522,7 +522,7 @@ class Model(metaclass=ModelBase):
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Model): if not isinstance(other, Model):
return False return NotImplemented
if self._meta.concrete_model != other._meta.concrete_model: if self._meta.concrete_model != other._meta.concrete_model:
return False return False
my_pk = self.pk my_pk = self.pk

View File

@ -54,11 +54,9 @@ class CheckConstraint(BaseConstraint):
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
def __eq__(self, other): def __eq__(self, other):
return ( if isinstance(other, CheckConstraint):
isinstance(other, CheckConstraint) and return self.name == other.name and self.check == other.check
self.name == other.name and return super().__eq__(other)
self.check == other.check
)
def deconstruct(self): def deconstruct(self):
path, args, kwargs = super().deconstruct() path, args, kwargs = super().deconstruct()
@ -106,12 +104,13 @@ class UniqueConstraint(BaseConstraint):
) )
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, UniqueConstraint):
return ( return (
isinstance(other, UniqueConstraint) and
self.name == other.name and self.name == other.name and
self.fields == other.fields and self.fields == other.fields and
self.condition == other.condition self.condition == other.condition
) )
return super().__eq__(other)
def deconstruct(self): def deconstruct(self):
path, args, kwargs = super().deconstruct() path, args, kwargs = super().deconstruct()

View File

@ -401,7 +401,9 @@ class BaseExpression:
return tuple(identity) return tuple(identity)
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, BaseExpression) and other.identity == self.identity if not isinstance(other, BaseExpression):
return NotImplemented
return other.identity == self.identity
def __hash__(self): def __hash__(self):
return hash(self.identity) return hash(self.identity)

View File

@ -112,4 +112,6 @@ class Index:
) )
def __eq__(self, other): def __eq__(self, other):
return (self.__class__ == other.__class__) and (self.deconstruct() == other.deconstruct()) if self.__class__ == other.__class__:
return self.deconstruct() == other.deconstruct()
return NotImplemented

View File

@ -1543,7 +1543,9 @@ class Prefetch:
return None return None
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, Prefetch) and self.prefetch_to == other.prefetch_to if not isinstance(other, Prefetch):
return NotImplemented
return self.prefetch_to == other.prefetch_to
def __hash__(self): def __hash__(self):
return hash((self.__class__, self.prefetch_to)) return hash((self.__class__, self.prefetch_to))

View File

@ -309,8 +309,9 @@ class FilteredRelation:
self.path = [] self.path = []
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return ( return (
isinstance(other, self.__class__) and
self.relation_name == other.relation_name and self.relation_name == other.relation_name and
self.alias == other.alias and self.alias == other.alias and
self.condition == other.condition self.condition == other.condition

View File

@ -124,12 +124,10 @@ class BaseContext:
""" """
Compare two contexts by comparing theirs 'dicts' attributes. Compare two contexts by comparing theirs 'dicts' attributes.
""" """
return ( if not isinstance(other, BaseContext):
isinstance(other, BaseContext) and return NotImplemented
# because dictionaries can be put in different order # flatten dictionaries because they can be put in a different order.
# we have to flatten them like in templates return self.flatten() == other.flatten()
self.flatten() == other.flatten()
)
class Context(BaseContext): class Context(BaseContext):

View File

@ -1,5 +1,6 @@
import threading import threading
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest import mock
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
@ -354,6 +355,7 @@ class ModelTest(TestCase):
self.assertNotEqual(object(), Article(id=1)) self.assertNotEqual(object(), Article(id=1))
a = Article() a = Article()
self.assertEqual(a, a) self.assertEqual(a, a)
self.assertEqual(a, mock.ANY)
self.assertNotEqual(Article(), a) self.assertNotEqual(Article(), a)
def test_hash(self): def test_hash(self):

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models from django.db import IntegrityError, connection, models
from django.db.models.constraints import BaseConstraint from django.db.models.constraints import BaseConstraint
@ -39,6 +41,7 @@ class CheckConstraintTests(TestCase):
models.CheckConstraint(check=check1, name='price'), models.CheckConstraint(check=check1, name='price'),
models.CheckConstraint(check=check1, name='price'), models.CheckConstraint(check=check1, name='price'),
) )
self.assertEqual(models.CheckConstraint(check=check1, name='price'), mock.ANY)
self.assertNotEqual( self.assertNotEqual(
models.CheckConstraint(check=check1, name='price'), models.CheckConstraint(check=check1, name='price'),
models.CheckConstraint(check=check1, name='price2'), models.CheckConstraint(check=check1, name='price2'),
@ -102,6 +105,10 @@ class UniqueConstraintTests(TestCase):
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'), models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'), models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
) )
self.assertEqual(
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
mock.ANY,
)
self.assertNotEqual( self.assertNotEqual(
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'), models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'), models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'),

View File

@ -3,6 +3,7 @@ import pickle
import unittest import unittest
import uuid import uuid
from copy import deepcopy from copy import deepcopy
from unittest import mock
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, models from django.db import DatabaseError, connection, models
@ -965,6 +966,7 @@ class SimpleExpressionTests(SimpleTestCase):
Expression(models.IntegerField()), Expression(models.IntegerField()),
Expression(output_field=models.IntegerField()) Expression(output_field=models.IntegerField())
) )
self.assertEqual(Expression(models.IntegerField()), mock.ANY)
self.assertNotEqual( self.assertNotEqual(
Expression(models.IntegerField()), Expression(models.IntegerField()),
Expression(models.CharField()) Expression(models.CharField())

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.db import connection, transaction from django.db import connection, transaction
from django.db.models import Case, Count, F, FilteredRelation, Q, When from django.db.models import Case, Count, F, FilteredRelation, Q, When
from django.test import TestCase from django.test import TestCase
@ -323,6 +325,9 @@ class FilteredRelationTests(TestCase):
[self.book1] [self.book1]
) )
def test_eq(self):
self.assertEqual(FilteredRelation('book', condition=Q(book__title='b')), mock.ANY)
class FilteredRelationAggregationTests(TestCase): class FilteredRelationAggregationTests(TestCase):

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.contrib.messages import constants from django.contrib.messages import constants
from django.contrib.messages.storage.base import Message from django.contrib.messages.storage.base import Message
from django.test import SimpleTestCase from django.test import SimpleTestCase
@ -9,6 +11,7 @@ class MessageTests(SimpleTestCase):
msg_2 = Message(constants.INFO, 'Test message 2') msg_2 = Message(constants.INFO, 'Test message 2')
msg_3 = Message(constants.WARNING, 'Test message 1') msg_3 = Message(constants.WARNING, 'Test message 1')
self.assertEqual(msg_1, msg_1) self.assertEqual(msg_1, msg_1)
self.assertEqual(msg_1, mock.ANY)
self.assertNotEqual(msg_1, msg_2) self.assertNotEqual(msg_1, msg_2)
self.assertNotEqual(msg_1, msg_3) self.assertNotEqual(msg_1, msg_3)
self.assertNotEqual(msg_2, msg_3) self.assertNotEqual(msg_2, msg_3)

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.conf import settings from django.conf import settings
from django.db import connection, models from django.db import connection, models
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
@ -28,6 +30,7 @@ class SimpleIndexesTests(SimpleTestCase):
same_index.model = Book same_index.model = Book
another_index.model = Book another_index.model = Book
self.assertEqual(index, same_index) self.assertEqual(index, same_index)
self.assertEqual(index, mock.ANY)
self.assertNotEqual(index, another_index) self.assertNotEqual(index, another_index)
def test_index_fields_type(self): def test_index_fields_type(self):

View File

@ -1,4 +1,5 @@
import datetime import datetime
from unittest import mock
from django.db import connection, transaction from django.db import connection, transaction
from django.db.models import F, Func, Q from django.db.models import F, Func, Q
@ -175,6 +176,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
condition=Q(cancelled=False), condition=Q(cancelled=False),
) )
self.assertEqual(constraint_1, constraint_1) self.assertEqual(constraint_1, constraint_1)
self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2) self.assertNotEqual(constraint_1, constraint_2)
self.assertNotEqual(constraint_1, constraint_3) self.assertNotEqual(constraint_1, constraint_3)
self.assertNotEqual(constraint_2, constraint_3) self.assertNotEqual(constraint_2, constraint_3)

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db import connection from django.db import connection
@ -243,6 +245,7 @@ class PrefetchRelatedTests(TestDataMixin, TestCase):
prefetch_1 = Prefetch('authors', queryset=Author.objects.all()) prefetch_1 = Prefetch('authors', queryset=Author.objects.all())
prefetch_2 = Prefetch('books', queryset=Book.objects.all()) prefetch_2 = Prefetch('books', queryset=Book.objects.all())
self.assertEqual(prefetch_1, prefetch_1) self.assertEqual(prefetch_1, prefetch_1)
self.assertEqual(prefetch_1, mock.ANY)
self.assertNotEqual(prefetch_1, prefetch_2) self.assertNotEqual(prefetch_1, prefetch_2)
def test_forward_m2m_to_attr_conflict(self): def test_forward_m2m_to_attr_conflict(self):

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.http import HttpRequest from django.http import HttpRequest
from django.template import ( from django.template import (
Context, Engine, RequestContext, Template, Variable, VariableDoesNotExist, Context, Engine, RequestContext, Template, Variable, VariableDoesNotExist,
@ -18,6 +20,7 @@ class ContextTests(SimpleTestCase):
self.assertEqual(c.pop(), {"a": 2}) self.assertEqual(c.pop(), {"a": 2})
self.assertEqual(c["a"], 1) self.assertEqual(c["a"], 1)
self.assertEqual(c.get("foo", 42), 42) self.assertEqual(c.get("foo", 42), 42)
self.assertEqual(c, mock.ANY)
def test_push_context_manager(self): def test_push_context_manager(self):
c = Context({"a": 1}) c = Context({"a": 1})

View File

@ -3,7 +3,7 @@ import re
import types import types
from datetime import datetime, timedelta from datetime import datetime, timedelta
from decimal import Decimal from decimal import Decimal
from unittest import TestCase from unittest import TestCase, mock
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.core.files.base import ContentFile from django.core.files.base import ContentFile
@ -424,6 +424,7 @@ class TestValidatorEquality(TestCase):
MaxValueValidator(44), MaxValueValidator(44),
MaxValueValidator(44), MaxValueValidator(44),
) )
self.assertEqual(MaxValueValidator(44), mock.ANY)
self.assertNotEqual( self.assertNotEqual(
MaxValueValidator(44), MaxValueValidator(44),
MinValueValidator(44), MinValueValidator(44),