Fixed #30651 -- Made __eq__() methods return NotImplemented for not implemented comparisons.

Changed __eq__ to return NotImplemented instead of False if compared to
an object of the same type, as is recommended by the Python data model
reference. Now these models can be compared to ANY (or other objects
with __eq__ overwritten) without returning False automatically.
This commit is contained in:
ElizabethU 2019-09-02 19:09:31 -07:00 committed by Mariusz Felisiak
parent 6475e6318c
commit 54ea290e5b
20 changed files with 71 additions and 33 deletions

View File

@ -25,8 +25,9 @@ class Message:
self.extra_tags = str(self.extra_tags) if self.extra_tags is not None else None
def __eq__(self, other):
return isinstance(other, Message) and self.level == other.level and \
self.message == other.message
if not isinstance(other, Message):
return NotImplemented
return self.level == other.level and self.message == other.message
def __str__(self):
return str(self.message)

View File

@ -89,13 +89,14 @@ class ExclusionConstraint(BaseConstraint):
return path, args, kwargs
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
self.name == other.name and
self.index_type == other.index_type and
self.expressions == other.expressions and
self.condition == other.condition
)
if isinstance(other, self.__class__):
return (
self.name == other.name and
self.index_type == other.index_type and
self.expressions == other.expressions and
self.condition == other.condition
)
return super().__eq__(other)
def __repr__(self):
return '<%s: index_type=%s, expressions=%s%s>' % (

View File

@ -324,8 +324,9 @@ class BaseValidator:
raise ValidationError(self.message, code=self.code, params=params)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (
isinstance(other, self.__class__) and
self.limit_value == other.limit_value and
self.message == other.message and
self.code == other.code

View File

@ -522,7 +522,7 @@ class Model(metaclass=ModelBase):
def __eq__(self, other):
if not isinstance(other, Model):
return False
return NotImplemented
if self._meta.concrete_model != other._meta.concrete_model:
return False
my_pk = self.pk

View File

@ -54,11 +54,9 @@ class CheckConstraint(BaseConstraint):
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
def __eq__(self, other):
return (
isinstance(other, CheckConstraint) and
self.name == other.name and
self.check == other.check
)
if isinstance(other, CheckConstraint):
return self.name == other.name and self.check == other.check
return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
@ -106,12 +104,13 @@ class UniqueConstraint(BaseConstraint):
)
def __eq__(self, other):
return (
isinstance(other, UniqueConstraint) and
self.name == other.name and
self.fields == other.fields and
self.condition == other.condition
)
if isinstance(other, UniqueConstraint):
return (
self.name == other.name and
self.fields == other.fields and
self.condition == other.condition
)
return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()

View File

@ -401,7 +401,9 @@ class BaseExpression:
return tuple(identity)
def __eq__(self, other):
return isinstance(other, BaseExpression) and other.identity == self.identity
if not isinstance(other, BaseExpression):
return NotImplemented
return other.identity == self.identity
def __hash__(self):
return hash(self.identity)

View File

@ -112,4 +112,6 @@ class Index:
)
def __eq__(self, other):
return (self.__class__ == other.__class__) and (self.deconstruct() == other.deconstruct())
if self.__class__ == other.__class__:
return self.deconstruct() == other.deconstruct()
return NotImplemented

View File

@ -1543,7 +1543,9 @@ class Prefetch:
return None
def __eq__(self, other):
return isinstance(other, Prefetch) and self.prefetch_to == other.prefetch_to
if not isinstance(other, Prefetch):
return NotImplemented
return self.prefetch_to == other.prefetch_to
def __hash__(self):
return hash((self.__class__, self.prefetch_to))

View File

@ -309,8 +309,9 @@ class FilteredRelation:
self.path = []
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (
isinstance(other, self.__class__) and
self.relation_name == other.relation_name and
self.alias == other.alias and
self.condition == other.condition

View File

@ -124,12 +124,10 @@ class BaseContext:
"""
Compare two contexts by comparing theirs 'dicts' attributes.
"""
return (
isinstance(other, BaseContext) and
# because dictionaries can be put in different order
# we have to flatten them like in templates
self.flatten() == other.flatten()
)
if not isinstance(other, BaseContext):
return NotImplemented
# flatten dictionaries because they can be put in a different order.
return self.flatten() == other.flatten()
class Context(BaseContext):

View File

@ -1,5 +1,6 @@
import threading
from datetime import datetime, timedelta
from unittest import mock
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
@ -354,6 +355,7 @@ class ModelTest(TestCase):
self.assertNotEqual(object(), Article(id=1))
a = Article()
self.assertEqual(a, a)
self.assertEqual(a, mock.ANY)
self.assertNotEqual(Article(), a)
def test_hash(self):

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.core.exceptions import ValidationError
from django.db import IntegrityError, connection, models
from django.db.models.constraints import BaseConstraint
@ -39,6 +41,7 @@ class CheckConstraintTests(TestCase):
models.CheckConstraint(check=check1, name='price'),
models.CheckConstraint(check=check1, name='price'),
)
self.assertEqual(models.CheckConstraint(check=check1, name='price'), mock.ANY)
self.assertNotEqual(
models.CheckConstraint(check=check1, name='price'),
models.CheckConstraint(check=check1, name='price2'),
@ -102,6 +105,10 @@ class UniqueConstraintTests(TestCase):
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
)
self.assertEqual(
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
mock.ANY,
)
self.assertNotEqual(
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'),

View File

@ -3,6 +3,7 @@ import pickle
import unittest
import uuid
from copy import deepcopy
from unittest import mock
from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, models
@ -965,6 +966,7 @@ class SimpleExpressionTests(SimpleTestCase):
Expression(models.IntegerField()),
Expression(output_field=models.IntegerField())
)
self.assertEqual(Expression(models.IntegerField()), mock.ANY)
self.assertNotEqual(
Expression(models.IntegerField()),
Expression(models.CharField())

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.db import connection, transaction
from django.db.models import Case, Count, F, FilteredRelation, Q, When
from django.test import TestCase
@ -323,6 +325,9 @@ class FilteredRelationTests(TestCase):
[self.book1]
)
def test_eq(self):
self.assertEqual(FilteredRelation('book', condition=Q(book__title='b')), mock.ANY)
class FilteredRelationAggregationTests(TestCase):

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.contrib.messages import constants
from django.contrib.messages.storage.base import Message
from django.test import SimpleTestCase
@ -9,6 +11,7 @@ class MessageTests(SimpleTestCase):
msg_2 = Message(constants.INFO, 'Test message 2')
msg_3 = Message(constants.WARNING, 'Test message 1')
self.assertEqual(msg_1, msg_1)
self.assertEqual(msg_1, mock.ANY)
self.assertNotEqual(msg_1, msg_2)
self.assertNotEqual(msg_1, msg_3)
self.assertNotEqual(msg_2, msg_3)

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.conf import settings
from django.db import connection, models
from django.db.models.query_utils import Q
@ -28,6 +30,7 @@ class SimpleIndexesTests(SimpleTestCase):
same_index.model = Book
another_index.model = Book
self.assertEqual(index, same_index)
self.assertEqual(index, mock.ANY)
self.assertNotEqual(index, another_index)
def test_index_fields_type(self):

View File

@ -1,4 +1,5 @@
import datetime
from unittest import mock
from django.db import connection, transaction
from django.db.models import F, Func, Q
@ -175,6 +176,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
condition=Q(cancelled=False),
)
self.assertEqual(constraint_1, constraint_1)
self.assertEqual(constraint_1, mock.ANY)
self.assertNotEqual(constraint_1, constraint_2)
self.assertNotEqual(constraint_1, constraint_3)
self.assertNotEqual(constraint_2, constraint_3)

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
@ -243,6 +245,7 @@ class PrefetchRelatedTests(TestDataMixin, TestCase):
prefetch_1 = Prefetch('authors', queryset=Author.objects.all())
prefetch_2 = Prefetch('books', queryset=Book.objects.all())
self.assertEqual(prefetch_1, prefetch_1)
self.assertEqual(prefetch_1, mock.ANY)
self.assertNotEqual(prefetch_1, prefetch_2)
def test_forward_m2m_to_attr_conflict(self):

View File

@ -1,3 +1,5 @@
from unittest import mock
from django.http import HttpRequest
from django.template import (
Context, Engine, RequestContext, Template, Variable, VariableDoesNotExist,
@ -18,6 +20,7 @@ class ContextTests(SimpleTestCase):
self.assertEqual(c.pop(), {"a": 2})
self.assertEqual(c["a"], 1)
self.assertEqual(c.get("foo", 42), 42)
self.assertEqual(c, mock.ANY)
def test_push_context_manager(self):
c = Context({"a": 1})

View File

@ -3,7 +3,7 @@ import re
import types
from datetime import datetime, timedelta
from decimal import Decimal
from unittest import TestCase
from unittest import TestCase, mock
from django.core.exceptions import ValidationError
from django.core.files.base import ContentFile
@ -424,6 +424,7 @@ class TestValidatorEquality(TestCase):
MaxValueValidator(44),
MaxValueValidator(44),
)
self.assertEqual(MaxValueValidator(44), mock.ANY)
self.assertNotEqual(
MaxValueValidator(44),
MinValueValidator(44),