Fixed #29745 -- Based Expression equality on detailed initialization signature.

The old implementation considered objects initialized with an equivalent
signature different if some arguments were provided positionally instead of
as keyword arguments.

Refs #11964, #26167.
This commit is contained in:
Simon Charette 2018-10-02 19:15:20 -04:00 committed by Tim Graham
parent e4df8e6dc0
commit bc7e288ca9
2 changed files with 57 additions and 24 deletions

View File

@ -1,5 +1,6 @@
import copy import copy
import datetime import datetime
import inspect
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError
@ -137,6 +138,16 @@ class Combinable:
) )
def make_hashable(value):
if isinstance(value, list):
return tuple(map(make_hashable, value))
if isinstance(value, dict):
return tuple([
(key, make_hashable(nested_value)) for key, nested_value in value.items()
])
return value
@deconstructible @deconstructible
class BaseExpression: class BaseExpression:
"""Base class for all query expressions.""" """Base class for all query expressions."""
@ -360,28 +371,27 @@ class BaseExpression:
if expr: if expr:
yield from expr.flatten() yield from expr.flatten()
@cached_property
def identity(self):
constructor_signature = inspect.signature(self.__init__)
args, kwargs = self._constructor_args
signature = constructor_signature.bind_partial(*args, **kwargs)
signature.apply_defaults()
arguments = signature.arguments.items()
identity = [self.__class__]
for arg, value in arguments:
if isinstance(value, fields.Field):
value = type(value)
else:
value = make_hashable(value)
identity.append((arg, value))
return tuple(identity)
def __eq__(self, other): def __eq__(self, other):
if self.__class__ != other.__class__: return isinstance(other, BaseExpression) and other.identity == self.identity
return False
path, args, kwargs = self.deconstruct()
other_path, other_args, other_kwargs = other.deconstruct()
if (path, args) == (other_path, other_args):
kwargs = kwargs.copy()
other_kwargs = other_kwargs.copy()
output_field = type(kwargs.pop('output_field', None))
other_output_field = type(other_kwargs.pop('output_field', None))
if output_field == other_output_field:
return kwargs == other_kwargs
return False
def __hash__(self): def __hash__(self):
path, args, kwargs = self.deconstruct() return hash(self.identity)
kwargs = kwargs.copy()
output_field = type(kwargs.pop('output_field', None))
return hash((path, output_field) + args + tuple([
(key, tuple(value)) if isinstance(value, list) else (key, value)
for key, value in kwargs.items()
]))
class Expression(BaseExpression, Combinable): class Expression(BaseExpression, Combinable):
@ -695,9 +705,6 @@ class RawSQL(Expression):
def get_group_by_cols(self): def get_group_by_cols(self):
return [self] return [self]
def __hash__(self):
return hash((self.sql, self.output_field) + tuple(self.params))
class Star(Expression): class Star(Expression):
def __repr__(self): def __repr__(self):

View File

@ -11,8 +11,9 @@ from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance, Avg, Count, Max, Min, StdDev, Sum, Variance,
) )
from django.db.models.expressions import ( from django.db.models.expressions import (
Case, Col, Combinable, Exists, ExpressionList, ExpressionWrapper, F, Func, Case, Col, Combinable, Exists, Expression, ExpressionList,
OrderBy, OuterRef, Random, RawSQL, Ref, Subquery, Value, When, ExpressionWrapper, F, Func, OrderBy, OuterRef, Random, RawSQL, Ref,
Subquery, Value, When,
) )
from django.db.models.functions import ( from django.db.models.functions import (
Coalesce, Concat, Length, Lower, Substr, Upper, Coalesce, Concat, Length, Lower, Substr, Upper,
@ -822,6 +823,31 @@ class ExpressionsTests(TestCase):
) )
class SimpleExpressionTests(SimpleTestCase):
def test_equal(self):
self.assertEqual(Expression(), Expression())
self.assertEqual(
Expression(models.IntegerField()),
Expression(output_field=models.IntegerField())
)
self.assertNotEqual(
Expression(models.IntegerField()),
Expression(models.CharField())
)
def test_hash(self):
self.assertEqual(hash(Expression()), hash(Expression()))
self.assertEqual(
hash(Expression(models.IntegerField())),
hash(Expression(output_field=models.IntegerField()))
)
self.assertNotEqual(
hash(Expression(models.IntegerField())),
hash(Expression(models.CharField())),
)
class ExpressionsNumericTests(TestCase): class ExpressionsNumericTests(TestCase):
def setUp(self): def setUp(self):