From 19b2dfd1bfe7fd716dd3d8bfa5f972070d83b42f Mon Sep 17 00:00:00 2001 From: Ian Foote Date: Sat, 5 Nov 2016 15:49:29 +0000 Subject: [PATCH] Refs #11964, #26167 -- Made Expressions deconstructible. --- django/db/models/expressions.py | 36 +++++++++++++++ tests/expressions/tests.py | 79 +++++++++++++++++++++++++++++---- 2 files changed, 107 insertions(+), 8 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index cfa23ccd2db..2528da22498 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -5,6 +5,7 @@ from django.core.exceptions import EmptyResultSet, FieldError from django.db.backends import utils as backend_utils from django.db.models import fields from django.db.models.query_utils import Q +from django.utils.deconstruct import deconstructible from django.utils.functional import cached_property @@ -117,6 +118,7 @@ class Combinable: ) +@deconstructible class BaseExpression: """ Base class for all query expressions. @@ -339,6 +341,27 @@ class BaseExpression: if expr: yield from expr.flatten() + def __eq__(self, other): + if self.__class__ != other.__class__: + return False + path, args, kwargs = self.deconstruct() + other_path, other_args, other_kwargs = other.deconstruct() + if (path, args) == (other_path, other_args): + kwargs = kwargs.copy() + other_kwargs = other_kwargs.copy() + output_field = type(kwargs.pop('output_field', None)) + other_output_field = type(other_kwargs.pop('output_field', None)) + if output_field == other_output_field: + return kwargs == other_kwargs + return False + + def __hash__(self): + path, args, kwargs = self.deconstruct() + h = hash(path) ^ hash(args) + for kwarg in kwargs.items(): + h ^= hash(kwarg) + return h + class Expression(BaseExpression, Combinable): """ @@ -445,6 +468,7 @@ class TemporalSubtraction(CombinedExpression): return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs) +@deconstructible class F(Combinable): """ An object capable of resolving references to existing query objects. @@ -468,6 +492,12 @@ class F(Combinable): def desc(self, **kwargs): return OrderBy(self, descending=True, **kwargs) + def __eq__(self, other): + return self.__class__ == other.__class__ and self.name == other.name + + def __hash__(self): + return hash(self.name) + class ResolvedOuterRef(F): """ @@ -647,6 +677,12 @@ class RawSQL(Expression): def get_group_by_cols(self): return [self] + def __hash__(self): + h = hash(self.sql) ^ hash(self._output_field) + for param in self.params: + h ^= hash(param) + return h + class Star(Expression): def __repr__(self): diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index fc05b0a2808..38258e0dd06 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -5,7 +5,7 @@ from copy import deepcopy from django.core.exceptions import FieldError from django.db import DatabaseError, connection, models, transaction -from django.db.models import TimeField, UUIDField +from django.db.models import CharField, TimeField, UUIDField from django.db.models.aggregates import ( Avg, Count, Max, Min, StdDev, Sum, Variance, ) @@ -18,7 +18,9 @@ from django.db.models.functions import ( ) from django.db.models.sql import constants from django.db.models.sql.datastructures import Join -from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature +from django.test import ( + SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature, +) from django.test.utils import Approximate from .models import ( @@ -653,17 +655,42 @@ class IterableLookupInnerExpressionsTests(TestCase): self.assertQuerysetEqual(queryset, [""]) -class ExpressionsTests(TestCase): +class FTests(SimpleTestCase): - def test_F_object_deepcopy(self): - """ - Make sure F objects can be deepcopied (#23492) - """ + def test_deepcopy(self): f = F("foo") g = deepcopy(f) self.assertEqual(f.name, g.name) - def test_f_reuse(self): + def test_deconstruct(self): + f = F('name') + path, args, kwargs = f.deconstruct() + self.assertEqual(path, 'django.db.models.expressions.F') + self.assertEqual(args, (f.name,)) + self.assertEqual(kwargs, {}) + + def test_equal(self): + f = F('name') + same_f = F('name') + other_f = F('username') + self.assertEqual(f, same_f) + self.assertNotEqual(f, other_f) + + def test_hash(self): + d = {F('name'): 'Bob'} + self.assertIn(F('name'), d) + self.assertEqual(d[F('name')], 'Bob') + + def test_not_equal_Value(self): + f = F('name') + value = Value('name') + self.assertNotEqual(f, value) + self.assertNotEqual(value, f) + + +class ExpressionsTests(TestCase): + + def test_F_reuse(self): f = F('id') n = Number.objects.create(integer=-1) c = Company.objects.create( @@ -1238,6 +1265,42 @@ class ValueTests(TestCase): UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField())) self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012')) + def test_deconstruct(self): + value = Value('name') + path, args, kwargs = value.deconstruct() + self.assertEqual(path, 'django.db.models.expressions.Value') + self.assertEqual(args, (value.value,)) + self.assertEqual(kwargs, {}) + + def test_deconstruct_output_field(self): + value = Value('name', output_field=CharField()) + path, args, kwargs = value.deconstruct() + self.assertEqual(path, 'django.db.models.expressions.Value') + self.assertEqual(args, (value.value,)) + self.assertEqual(len(kwargs), 1) + self.assertEqual(kwargs['output_field'].deconstruct(), CharField().deconstruct()) + + def test_equal(self): + value = Value('name') + same_value = Value('name') + other_value = Value('username') + self.assertEqual(value, same_value) + self.assertNotEqual(value, other_value) + + def test_hash(self): + d = {Value('name'): 'Bob'} + self.assertIn(Value('name'), d) + self.assertEqual(d[Value('name')], 'Bob') + + def test_equal_output_field(self): + value = Value('name', output_field=CharField()) + same_value = Value('name', output_field=CharField()) + other_value = Value('name', output_field=TimeField()) + no_output_field = Value('name') + self.assertEqual(value, same_value) + self.assertNotEqual(value, other_value) + self.assertNotEqual(value, no_output_field) + class ReprTests(TestCase):