From bc7e288ca9554ac1a0a19941302dea19df1acd21 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Tue, 2 Oct 2018 19:15:20 -0400 Subject: [PATCH] Fixed #29745 -- Based Expression equality on detailed initialization signature. The old implementation considered objects initialized with an equivalent signature different if some arguments were provided positionally instead of as keyword arguments. Refs #11964, #26167. --- django/db/models/expressions.py | 51 +++++++++++++++++++-------------- tests/expressions/tests.py | 30 +++++++++++++++++-- 2 files changed, 57 insertions(+), 24 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 2532431821..2bf6316c2e 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1,5 +1,6 @@ import copy import datetime +import inspect from decimal import Decimal from django.core.exceptions import EmptyResultSet, FieldError @@ -137,6 +138,16 @@ class Combinable: ) +def make_hashable(value): + if isinstance(value, list): + return tuple(map(make_hashable, value)) + if isinstance(value, dict): + return tuple([ + (key, make_hashable(nested_value)) for key, nested_value in value.items() + ]) + return value + + @deconstructible class BaseExpression: """Base class for all query expressions.""" @@ -360,28 +371,27 @@ class BaseExpression: if expr: yield from expr.flatten() + @cached_property + def identity(self): + constructor_signature = inspect.signature(self.__init__) + args, kwargs = self._constructor_args + signature = constructor_signature.bind_partial(*args, **kwargs) + signature.apply_defaults() + arguments = signature.arguments.items() + identity = [self.__class__] + for arg, value in arguments: + if isinstance(value, fields.Field): + value = type(value) + else: + value = make_hashable(value) + identity.append((arg, value)) + return tuple(identity) + def __eq__(self, other): - if self.__class__ != other.__class__: - return False - path, args, kwargs = self.deconstruct() - other_path, other_args, other_kwargs = other.deconstruct() - if (path, args) == (other_path, other_args): - kwargs = kwargs.copy() - other_kwargs = other_kwargs.copy() - output_field = type(kwargs.pop('output_field', None)) - other_output_field = type(other_kwargs.pop('output_field', None)) - if output_field == other_output_field: - return kwargs == other_kwargs - return False + return isinstance(other, BaseExpression) and other.identity == self.identity def __hash__(self): - path, args, kwargs = self.deconstruct() - kwargs = kwargs.copy() - output_field = type(kwargs.pop('output_field', None)) - return hash((path, output_field) + args + tuple([ - (key, tuple(value)) if isinstance(value, list) else (key, value) - for key, value in kwargs.items() - ])) + return hash(self.identity) class Expression(BaseExpression, Combinable): @@ -695,9 +705,6 @@ class RawSQL(Expression): def get_group_by_cols(self): return [self] - def __hash__(self): - return hash((self.sql, self.output_field) + tuple(self.params)) - class Star(Expression): def __repr__(self): diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index c631151448..d3f86fcd92 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -11,8 +11,9 @@ from django.db.models.aggregates import ( Avg, Count, Max, Min, StdDev, Sum, Variance, ) from django.db.models.expressions import ( - Case, Col, Combinable, Exists, ExpressionList, ExpressionWrapper, F, Func, - OrderBy, OuterRef, Random, RawSQL, Ref, Subquery, Value, When, + Case, Col, Combinable, Exists, Expression, ExpressionList, + ExpressionWrapper, F, Func, OrderBy, OuterRef, Random, RawSQL, Ref, + Subquery, Value, When, ) from django.db.models.functions import ( Coalesce, Concat, Length, Lower, Substr, Upper, @@ -822,6 +823,31 @@ class ExpressionsTests(TestCase): ) +class SimpleExpressionTests(SimpleTestCase): + + def test_equal(self): + self.assertEqual(Expression(), Expression()) + self.assertEqual( + Expression(models.IntegerField()), + Expression(output_field=models.IntegerField()) + ) + self.assertNotEqual( + Expression(models.IntegerField()), + Expression(models.CharField()) + ) + + def test_hash(self): + self.assertEqual(hash(Expression()), hash(Expression())) + self.assertEqual( + hash(Expression(models.IntegerField())), + hash(Expression(output_field=models.IntegerField())) + ) + self.assertNotEqual( + hash(Expression(models.IntegerField())), + hash(Expression(models.CharField())), + ) + + class ExpressionsNumericTests(TestCase): def setUp(self):