Fixed #30651 -- Made __eq__() methods return NotImplemented for not implemented comparisons.
Changed __eq__ to return NotImplemented instead of False if compared to an object of the same type, as is recommended by the Python data model reference. Now these models can be compared to ANY (or other objects with __eq__ overwritten) without returning False automatically.
This commit is contained in:
parent
6475e6318c
commit
54ea290e5b
|
@ -25,8 +25,9 @@ class Message:
|
||||||
self.extra_tags = str(self.extra_tags) if self.extra_tags is not None else None
|
self.extra_tags = str(self.extra_tags) if self.extra_tags is not None else None
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, Message) and self.level == other.level and \
|
if not isinstance(other, Message):
|
||||||
self.message == other.message
|
return NotImplemented
|
||||||
|
return self.level == other.level and self.message == other.message
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.message)
|
return str(self.message)
|
||||||
|
|
|
@ -89,13 +89,14 @@ class ExclusionConstraint(BaseConstraint):
|
||||||
return path, args, kwargs
|
return path, args, kwargs
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return (
|
if isinstance(other, self.__class__):
|
||||||
isinstance(other, self.__class__) and
|
return (
|
||||||
self.name == other.name and
|
self.name == other.name and
|
||||||
self.index_type == other.index_type and
|
self.index_type == other.index_type and
|
||||||
self.expressions == other.expressions and
|
self.expressions == other.expressions and
|
||||||
self.condition == other.condition
|
self.condition == other.condition
|
||||||
)
|
)
|
||||||
|
return super().__eq__(other)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<%s: index_type=%s, expressions=%s%s>' % (
|
return '<%s: index_type=%s, expressions=%s%s>' % (
|
||||||
|
|
|
@ -324,8 +324,9 @@ class BaseValidator:
|
||||||
raise ValidationError(self.message, code=self.code, params=params)
|
raise ValidationError(self.message, code=self.code, params=params)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, self.__class__):
|
||||||
|
return NotImplemented
|
||||||
return (
|
return (
|
||||||
isinstance(other, self.__class__) and
|
|
||||||
self.limit_value == other.limit_value and
|
self.limit_value == other.limit_value and
|
||||||
self.message == other.message and
|
self.message == other.message and
|
||||||
self.code == other.code
|
self.code == other.code
|
||||||
|
|
|
@ -522,7 +522,7 @@ class Model(metaclass=ModelBase):
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not isinstance(other, Model):
|
if not isinstance(other, Model):
|
||||||
return False
|
return NotImplemented
|
||||||
if self._meta.concrete_model != other._meta.concrete_model:
|
if self._meta.concrete_model != other._meta.concrete_model:
|
||||||
return False
|
return False
|
||||||
my_pk = self.pk
|
my_pk = self.pk
|
||||||
|
|
|
@ -54,11 +54,9 @@ class CheckConstraint(BaseConstraint):
|
||||||
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
|
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return (
|
if isinstance(other, CheckConstraint):
|
||||||
isinstance(other, CheckConstraint) and
|
return self.name == other.name and self.check == other.check
|
||||||
self.name == other.name and
|
return super().__eq__(other)
|
||||||
self.check == other.check
|
|
||||||
)
|
|
||||||
|
|
||||||
def deconstruct(self):
|
def deconstruct(self):
|
||||||
path, args, kwargs = super().deconstruct()
|
path, args, kwargs = super().deconstruct()
|
||||||
|
@ -106,12 +104,13 @@ class UniqueConstraint(BaseConstraint):
|
||||||
)
|
)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return (
|
if isinstance(other, UniqueConstraint):
|
||||||
isinstance(other, UniqueConstraint) and
|
return (
|
||||||
self.name == other.name and
|
self.name == other.name and
|
||||||
self.fields == other.fields and
|
self.fields == other.fields and
|
||||||
self.condition == other.condition
|
self.condition == other.condition
|
||||||
)
|
)
|
||||||
|
return super().__eq__(other)
|
||||||
|
|
||||||
def deconstruct(self):
|
def deconstruct(self):
|
||||||
path, args, kwargs = super().deconstruct()
|
path, args, kwargs = super().deconstruct()
|
||||||
|
|
|
@ -401,7 +401,9 @@ class BaseExpression:
|
||||||
return tuple(identity)
|
return tuple(identity)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, BaseExpression) and other.identity == self.identity
|
if not isinstance(other, BaseExpression):
|
||||||
|
return NotImplemented
|
||||||
|
return other.identity == self.identity
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.identity)
|
return hash(self.identity)
|
||||||
|
|
|
@ -112,4 +112,6 @@ class Index:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return (self.__class__ == other.__class__) and (self.deconstruct() == other.deconstruct())
|
if self.__class__ == other.__class__:
|
||||||
|
return self.deconstruct() == other.deconstruct()
|
||||||
|
return NotImplemented
|
||||||
|
|
|
@ -1543,7 +1543,9 @@ class Prefetch:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return isinstance(other, Prefetch) and self.prefetch_to == other.prefetch_to
|
if not isinstance(other, Prefetch):
|
||||||
|
return NotImplemented
|
||||||
|
return self.prefetch_to == other.prefetch_to
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((self.__class__, self.prefetch_to))
|
return hash((self.__class__, self.prefetch_to))
|
||||||
|
|
|
@ -309,8 +309,9 @@ class FilteredRelation:
|
||||||
self.path = []
|
self.path = []
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, self.__class__):
|
||||||
|
return NotImplemented
|
||||||
return (
|
return (
|
||||||
isinstance(other, self.__class__) and
|
|
||||||
self.relation_name == other.relation_name and
|
self.relation_name == other.relation_name and
|
||||||
self.alias == other.alias and
|
self.alias == other.alias and
|
||||||
self.condition == other.condition
|
self.condition == other.condition
|
||||||
|
|
|
@ -124,12 +124,10 @@ class BaseContext:
|
||||||
"""
|
"""
|
||||||
Compare two contexts by comparing theirs 'dicts' attributes.
|
Compare two contexts by comparing theirs 'dicts' attributes.
|
||||||
"""
|
"""
|
||||||
return (
|
if not isinstance(other, BaseContext):
|
||||||
isinstance(other, BaseContext) and
|
return NotImplemented
|
||||||
# because dictionaries can be put in different order
|
# flatten dictionaries because they can be put in a different order.
|
||||||
# we have to flatten them like in templates
|
return self.flatten() == other.flatten()
|
||||||
self.flatten() == other.flatten()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Context(BaseContext):
|
class Context(BaseContext):
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
|
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
|
||||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
|
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
|
||||||
|
@ -354,6 +355,7 @@ class ModelTest(TestCase):
|
||||||
self.assertNotEqual(object(), Article(id=1))
|
self.assertNotEqual(object(), Article(id=1))
|
||||||
a = Article()
|
a = Article()
|
||||||
self.assertEqual(a, a)
|
self.assertEqual(a, a)
|
||||||
|
self.assertEqual(a, mock.ANY)
|
||||||
self.assertNotEqual(Article(), a)
|
self.assertNotEqual(Article(), a)
|
||||||
|
|
||||||
def test_hash(self):
|
def test_hash(self):
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.db import IntegrityError, connection, models
|
from django.db import IntegrityError, connection, models
|
||||||
from django.db.models.constraints import BaseConstraint
|
from django.db.models.constraints import BaseConstraint
|
||||||
|
@ -39,6 +41,7 @@ class CheckConstraintTests(TestCase):
|
||||||
models.CheckConstraint(check=check1, name='price'),
|
models.CheckConstraint(check=check1, name='price'),
|
||||||
models.CheckConstraint(check=check1, name='price'),
|
models.CheckConstraint(check=check1, name='price'),
|
||||||
)
|
)
|
||||||
|
self.assertEqual(models.CheckConstraint(check=check1, name='price'), mock.ANY)
|
||||||
self.assertNotEqual(
|
self.assertNotEqual(
|
||||||
models.CheckConstraint(check=check1, name='price'),
|
models.CheckConstraint(check=check1, name='price'),
|
||||||
models.CheckConstraint(check=check1, name='price2'),
|
models.CheckConstraint(check=check1, name='price2'),
|
||||||
|
@ -102,6 +105,10 @@ class UniqueConstraintTests(TestCase):
|
||||||
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
|
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
|
||||||
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
|
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
|
||||||
)
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
|
||||||
|
mock.ANY,
|
||||||
|
)
|
||||||
self.assertNotEqual(
|
self.assertNotEqual(
|
||||||
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
|
models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
|
||||||
models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'),
|
models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'),
|
||||||
|
|
|
@ -3,6 +3,7 @@ import pickle
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import DatabaseError, connection, models
|
from django.db import DatabaseError, connection, models
|
||||||
|
@ -965,6 +966,7 @@ class SimpleExpressionTests(SimpleTestCase):
|
||||||
Expression(models.IntegerField()),
|
Expression(models.IntegerField()),
|
||||||
Expression(output_field=models.IntegerField())
|
Expression(output_field=models.IntegerField())
|
||||||
)
|
)
|
||||||
|
self.assertEqual(Expression(models.IntegerField()), mock.ANY)
|
||||||
self.assertNotEqual(
|
self.assertNotEqual(
|
||||||
Expression(models.IntegerField()),
|
Expression(models.IntegerField()),
|
||||||
Expression(models.CharField())
|
Expression(models.CharField())
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.db import connection, transaction
|
from django.db import connection, transaction
|
||||||
from django.db.models import Case, Count, F, FilteredRelation, Q, When
|
from django.db.models import Case, Count, F, FilteredRelation, Q, When
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
@ -323,6 +325,9 @@ class FilteredRelationTests(TestCase):
|
||||||
[self.book1]
|
[self.book1]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_eq(self):
|
||||||
|
self.assertEqual(FilteredRelation('book', condition=Q(book__title='b')), mock.ANY)
|
||||||
|
|
||||||
|
|
||||||
class FilteredRelationAggregationTests(TestCase):
|
class FilteredRelationAggregationTests(TestCase):
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.contrib.messages import constants
|
from django.contrib.messages import constants
|
||||||
from django.contrib.messages.storage.base import Message
|
from django.contrib.messages.storage.base import Message
|
||||||
from django.test import SimpleTestCase
|
from django.test import SimpleTestCase
|
||||||
|
@ -9,6 +11,7 @@ class MessageTests(SimpleTestCase):
|
||||||
msg_2 = Message(constants.INFO, 'Test message 2')
|
msg_2 = Message(constants.INFO, 'Test message 2')
|
||||||
msg_3 = Message(constants.WARNING, 'Test message 1')
|
msg_3 = Message(constants.WARNING, 'Test message 1')
|
||||||
self.assertEqual(msg_1, msg_1)
|
self.assertEqual(msg_1, msg_1)
|
||||||
|
self.assertEqual(msg_1, mock.ANY)
|
||||||
self.assertNotEqual(msg_1, msg_2)
|
self.assertNotEqual(msg_1, msg_2)
|
||||||
self.assertNotEqual(msg_1, msg_3)
|
self.assertNotEqual(msg_1, msg_3)
|
||||||
self.assertNotEqual(msg_2, msg_3)
|
self.assertNotEqual(msg_2, msg_3)
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import connection, models
|
from django.db import connection, models
|
||||||
from django.db.models.query_utils import Q
|
from django.db.models.query_utils import Q
|
||||||
|
@ -28,6 +30,7 @@ class SimpleIndexesTests(SimpleTestCase):
|
||||||
same_index.model = Book
|
same_index.model = Book
|
||||||
another_index.model = Book
|
another_index.model = Book
|
||||||
self.assertEqual(index, same_index)
|
self.assertEqual(index, same_index)
|
||||||
|
self.assertEqual(index, mock.ANY)
|
||||||
self.assertNotEqual(index, another_index)
|
self.assertNotEqual(index, another_index)
|
||||||
|
|
||||||
def test_index_fields_type(self):
|
def test_index_fields_type(self):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.db import connection, transaction
|
from django.db import connection, transaction
|
||||||
from django.db.models import F, Func, Q
|
from django.db.models import F, Func, Q
|
||||||
|
@ -175,6 +176,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
|
||||||
condition=Q(cancelled=False),
|
condition=Q(cancelled=False),
|
||||||
)
|
)
|
||||||
self.assertEqual(constraint_1, constraint_1)
|
self.assertEqual(constraint_1, constraint_1)
|
||||||
|
self.assertEqual(constraint_1, mock.ANY)
|
||||||
self.assertNotEqual(constraint_1, constraint_2)
|
self.assertNotEqual(constraint_1, constraint_2)
|
||||||
self.assertNotEqual(constraint_1, constraint_3)
|
self.assertNotEqual(constraint_1, constraint_3)
|
||||||
self.assertNotEqual(constraint_2, constraint_3)
|
self.assertNotEqual(constraint_2, constraint_3)
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
|
@ -243,6 +245,7 @@ class PrefetchRelatedTests(TestDataMixin, TestCase):
|
||||||
prefetch_1 = Prefetch('authors', queryset=Author.objects.all())
|
prefetch_1 = Prefetch('authors', queryset=Author.objects.all())
|
||||||
prefetch_2 = Prefetch('books', queryset=Book.objects.all())
|
prefetch_2 = Prefetch('books', queryset=Book.objects.all())
|
||||||
self.assertEqual(prefetch_1, prefetch_1)
|
self.assertEqual(prefetch_1, prefetch_1)
|
||||||
|
self.assertEqual(prefetch_1, mock.ANY)
|
||||||
self.assertNotEqual(prefetch_1, prefetch_2)
|
self.assertNotEqual(prefetch_1, prefetch_2)
|
||||||
|
|
||||||
def test_forward_m2m_to_attr_conflict(self):
|
def test_forward_m2m_to_attr_conflict(self):
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.template import (
|
from django.template import (
|
||||||
Context, Engine, RequestContext, Template, Variable, VariableDoesNotExist,
|
Context, Engine, RequestContext, Template, Variable, VariableDoesNotExist,
|
||||||
|
@ -18,6 +20,7 @@ class ContextTests(SimpleTestCase):
|
||||||
self.assertEqual(c.pop(), {"a": 2})
|
self.assertEqual(c.pop(), {"a": 2})
|
||||||
self.assertEqual(c["a"], 1)
|
self.assertEqual(c["a"], 1)
|
||||||
self.assertEqual(c.get("foo", 42), 42)
|
self.assertEqual(c.get("foo", 42), 42)
|
||||||
|
self.assertEqual(c, mock.ANY)
|
||||||
|
|
||||||
def test_push_context_manager(self):
|
def test_push_context_manager(self):
|
||||||
c = Context({"a": 1})
|
c = Context({"a": 1})
|
||||||
|
|
|
@ -3,7 +3,7 @@ import re
|
||||||
import types
|
import types
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from unittest import TestCase
|
from unittest import TestCase, mock
|
||||||
|
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.core.files.base import ContentFile
|
from django.core.files.base import ContentFile
|
||||||
|
@ -424,6 +424,7 @@ class TestValidatorEquality(TestCase):
|
||||||
MaxValueValidator(44),
|
MaxValueValidator(44),
|
||||||
MaxValueValidator(44),
|
MaxValueValidator(44),
|
||||||
)
|
)
|
||||||
|
self.assertEqual(MaxValueValidator(44), mock.ANY)
|
||||||
self.assertNotEqual(
|
self.assertNotEqual(
|
||||||
MaxValueValidator(44),
|
MaxValueValidator(44),
|
||||||
MinValueValidator(44),
|
MinValueValidator(44),
|
||||||
|
|
Loading…
Reference in New Issue