Refs #11964, #26167 -- Made Expressions deconstructible.

This commit is contained in:
Ian Foote 2016-11-05 15:49:29 +00:00 committed by Tim Graham
parent 3dcc351691
commit 19b2dfd1bf
2 changed files with 107 additions and 8 deletions

View File

@ -5,6 +5,7 @@ from django.core.exceptions import EmptyResultSet, FieldError
from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.query_utils import Q
from django.utils.deconstruct import deconstructible
from django.utils.functional import cached_property
@ -117,6 +118,7 @@ class Combinable:
)
@deconstructible
class BaseExpression:
"""
Base class for all query expressions.
@ -339,6 +341,27 @@ class BaseExpression:
if expr:
yield from expr.flatten()
def __eq__(self, other):
if self.__class__ != other.__class__:
return False
path, args, kwargs = self.deconstruct()
other_path, other_args, other_kwargs = other.deconstruct()
if (path, args) == (other_path, other_args):
kwargs = kwargs.copy()
other_kwargs = other_kwargs.copy()
output_field = type(kwargs.pop('output_field', None))
other_output_field = type(other_kwargs.pop('output_field', None))
if output_field == other_output_field:
return kwargs == other_kwargs
return False
def __hash__(self):
path, args, kwargs = self.deconstruct()
h = hash(path) ^ hash(args)
for kwarg in kwargs.items():
h ^= hash(kwarg)
return h
class Expression(BaseExpression, Combinable):
"""
@ -445,6 +468,7 @@ class TemporalSubtraction(CombinedExpression):
return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs)
@deconstructible
class F(Combinable):
"""
An object capable of resolving references to existing query objects.
@ -468,6 +492,12 @@ class F(Combinable):
def desc(self, **kwargs):
return OrderBy(self, descending=True, **kwargs)
def __eq__(self, other):
return self.__class__ == other.__class__ and self.name == other.name
def __hash__(self):
return hash(self.name)
class ResolvedOuterRef(F):
"""
@ -647,6 +677,12 @@ class RawSQL(Expression):
def get_group_by_cols(self):
return [self]
def __hash__(self):
h = hash(self.sql) ^ hash(self._output_field)
for param in self.params:
h ^= hash(param)
return h
class Star(Expression):
def __repr__(self):

View File

@ -5,7 +5,7 @@ from copy import deepcopy
from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, models, transaction
from django.db.models import TimeField, UUIDField
from django.db.models import CharField, TimeField, UUIDField
from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance,
)
@ -18,7 +18,9 @@ from django.db.models.functions import (
)
from django.db.models.sql import constants
from django.db.models.sql.datastructures import Join
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test import (
SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature,
)
from django.test.utils import Approximate
from .models import (
@ -653,17 +655,42 @@ class IterableLookupInnerExpressionsTests(TestCase):
self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
class ExpressionsTests(TestCase):
class FTests(SimpleTestCase):
def test_F_object_deepcopy(self):
"""
Make sure F objects can be deepcopied (#23492)
"""
def test_deepcopy(self):
f = F("foo")
g = deepcopy(f)
self.assertEqual(f.name, g.name)
def test_f_reuse(self):
def test_deconstruct(self):
f = F('name')
path, args, kwargs = f.deconstruct()
self.assertEqual(path, 'django.db.models.expressions.F')
self.assertEqual(args, (f.name,))
self.assertEqual(kwargs, {})
def test_equal(self):
f = F('name')
same_f = F('name')
other_f = F('username')
self.assertEqual(f, same_f)
self.assertNotEqual(f, other_f)
def test_hash(self):
d = {F('name'): 'Bob'}
self.assertIn(F('name'), d)
self.assertEqual(d[F('name')], 'Bob')
def test_not_equal_Value(self):
f = F('name')
value = Value('name')
self.assertNotEqual(f, value)
self.assertNotEqual(value, f)
class ExpressionsTests(TestCase):
def test_F_reuse(self):
f = F('id')
n = Number.objects.create(integer=-1)
c = Company.objects.create(
@ -1238,6 +1265,42 @@ class ValueTests(TestCase):
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
def test_deconstruct(self):
value = Value('name')
path, args, kwargs = value.deconstruct()
self.assertEqual(path, 'django.db.models.expressions.Value')
self.assertEqual(args, (value.value,))
self.assertEqual(kwargs, {})
def test_deconstruct_output_field(self):
value = Value('name', output_field=CharField())
path, args, kwargs = value.deconstruct()
self.assertEqual(path, 'django.db.models.expressions.Value')
self.assertEqual(args, (value.value,))
self.assertEqual(len(kwargs), 1)
self.assertEqual(kwargs['output_field'].deconstruct(), CharField().deconstruct())
def test_equal(self):
value = Value('name')
same_value = Value('name')
other_value = Value('username')
self.assertEqual(value, same_value)
self.assertNotEqual(value, other_value)
def test_hash(self):
d = {Value('name'): 'Bob'}
self.assertIn(Value('name'), d)
self.assertEqual(d[Value('name')], 'Bob')
def test_equal_output_field(self):
value = Value('name', output_field=CharField())
same_value = Value('name', output_field=CharField())
other_value = Value('name', output_field=TimeField())
no_output_field = Value('name')
self.assertEqual(value, same_value)
self.assertNotEqual(value, other_value)
self.assertNotEqual(value, no_output_field)
class ReprTests(TestCase):