diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 576ff400c46..43f40c24ec1 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -12,6 +12,7 @@ from django.db.models.query_utils import RegisterLookupMixin from django.utils.datastructures import OrderedSet from django.utils.deprecation import RemovedInDjango40Warning from django.utils.functional import cached_property +from django.utils.hashable import make_hashable class Lookup: @@ -143,6 +144,18 @@ class Lookup: def is_summary(self): return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False) + @property + def identity(self): + return self.__class__, self.lhs, self.rhs + + def __eq__(self, other): + if not isinstance(other, Lookup): + return NotImplemented + return self.identity == other.identity + + def __hash__(self): + return hash(make_hashable(self.identity)) + class Transform(RegisterLookupMixin, Func): """ diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index c2c347b3cf1..c9598b3bd44 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -114,17 +114,28 @@ class Join: self.join_field, self.nullable, filtered_relation=filtered_relation, ) - def equals(self, other, with_filtered_relation): + @property + def identity(self): return ( - isinstance(other, self.__class__) and - self.table_name == other.table_name and - self.parent_alias == other.parent_alias and - self.join_field == other.join_field and - (not with_filtered_relation or self.filtered_relation == other.filtered_relation) + self.__class__, + self.table_name, + self.parent_alias, + self.join_field, + self.filtered_relation, ) def __eq__(self, other): - return self.equals(other, with_filtered_relation=True) + if not isinstance(other, Join): + return NotImplemented + return self.identity == other.identity + + def __hash__(self): + return hash(self.identity) + + def equals(self, other, with_filtered_relation): + if with_filtered_relation: + return self == other + return self.identity[:-1] == other.identity[:-1] def demote(self): new = self.relabeled_clone({}) @@ -160,9 +171,17 @@ class BaseTable: def relabeled_clone(self, change_map): return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias)) + @property + def identity(self): + return self.__class__, self.table_name, self.table_alias + + def __eq__(self, other): + if not isinstance(other, BaseTable): + return NotImplemented + return self.identity == other.identity + + def __hash__(self): + return hash(self.identity) + def equals(self, other, with_filtered_relation): - return ( - isinstance(self, other.__class__) and - self.table_name == other.table_name and - self.table_alias == other.table_alias - ) + return self.identity == other.identity diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index ee98984826e..8d76b436ee6 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -39,6 +39,7 @@ from django.db.models.sql.where import ( ) from django.utils.deprecation import RemovedInDjango40Warning from django.utils.functional import cached_property +from django.utils.hashable import make_hashable from django.utils.tree import Node __all__ = ['Query', 'RawQuery'] @@ -246,6 +247,14 @@ class Query(BaseExpression): for alias in self.alias_map: return alias + @property + def identity(self): + identity = ( + (arg, make_hashable(value)) + for arg, value in self.__dict__.items() + ) + return (self.__class__, *identity) + def __str__(self): """ Return the query as a string of SQL with the parameter values diff --git a/tests/lookup/test_lookups.py b/tests/lookup/test_lookups.py index c3aa48ddad7..4d90612048a 100644 --- a/tests/lookup/test_lookups.py +++ b/tests/lookup/test_lookups.py @@ -1,10 +1,34 @@ from datetime import datetime +from unittest import mock from django.db.models import DateTimeField, Value -from django.db.models.lookups import YearLookup +from django.db.models.lookups import Lookup, YearLookup from django.test import SimpleTestCase +class CustomLookup(Lookup): + pass + + +class LookupTests(SimpleTestCase): + def test_equality(self): + lookup = Lookup(Value(1), Value(2)) + self.assertEqual(lookup, lookup) + self.assertEqual(lookup, Lookup(lookup.lhs, lookup.rhs)) + self.assertEqual(lookup, mock.ANY) + self.assertNotEqual(lookup, Lookup(lookup.lhs, Value(3))) + self.assertNotEqual(lookup, Lookup(Value(3), lookup.rhs)) + self.assertNotEqual(lookup, CustomLookup(lookup.lhs, lookup.rhs)) + + def test_hash(self): + lookup = Lookup(Value(1), Value(2)) + self.assertEqual(hash(lookup), hash(lookup)) + self.assertEqual(hash(lookup), hash(Lookup(lookup.lhs, lookup.rhs))) + self.assertNotEqual(hash(lookup), hash(Lookup(lookup.lhs, Value(3)))) + self.assertNotEqual(hash(lookup), hash(Lookup(Value(3), lookup.rhs))) + self.assertNotEqual(hash(lookup), hash(CustomLookup(lookup.lhs, lookup.rhs))) + + class YearLookupTests(SimpleTestCase): def test_get_bound_params(self): look_up = YearLookup( diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py index 523fa607f07..5db9d961630 100644 --- a/tests/queries/test_query.py +++ b/tests/queries/test_query.py @@ -150,3 +150,31 @@ class TestQuery(SimpleTestCase): msg = 'Cannot filter against a non-conditional expression.' with self.assertRaisesMessage(TypeError, msg): query.build_where(Func(output_field=CharField())) + + def test_equality(self): + self.assertNotEqual( + Author.objects.all().query, + Author.objects.filter(item__name='foo').query, + ) + self.assertEqual( + Author.objects.filter(item__name='foo').query, + Author.objects.filter(item__name='foo').query, + ) + self.assertEqual( + Author.objects.filter(item__name='foo').query, + Author.objects.filter(Q(item__name='foo')).query, + ) + + def test_hash(self): + self.assertNotEqual( + hash(Author.objects.all().query), + hash(Author.objects.filter(item__name='foo').query) + ) + self.assertEqual( + hash(Author.objects.filter(item__name='foo').query), + hash(Author.objects.filter(item__name='foo').query), + ) + self.assertEqual( + hash(Author.objects.filter(item__name='foo').query), + hash(Author.objects.filter(Q(item__name='foo')).query), + )