Refs #11964, #26167 -- Made Expressions deconstructible.

This commit is contained in:
Ian Foote 2016-11-05 15:49:29 +00:00 committed by Tim Graham
parent 3dcc351691
commit 19b2dfd1bf
2 changed files with 107 additions and 8 deletions

View File

@ -5,6 +5,7 @@ from django.core.exceptions import EmptyResultSet, FieldError
from django.db.backends import utils as backend_utils from django.db.backends import utils as backend_utils
from django.db.models import fields from django.db.models import fields
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.utils.deconstruct import deconstructible
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -117,6 +118,7 @@ class Combinable:
) )
@deconstructible
class BaseExpression: class BaseExpression:
""" """
Base class for all query expressions. Base class for all query expressions.
@ -339,6 +341,27 @@ class BaseExpression:
if expr: if expr:
yield from expr.flatten() yield from expr.flatten()
def __eq__(self, other):
if self.__class__ != other.__class__:
return False
path, args, kwargs = self.deconstruct()
other_path, other_args, other_kwargs = other.deconstruct()
if (path, args) == (other_path, other_args):
kwargs = kwargs.copy()
other_kwargs = other_kwargs.copy()
output_field = type(kwargs.pop('output_field', None))
other_output_field = type(other_kwargs.pop('output_field', None))
if output_field == other_output_field:
return kwargs == other_kwargs
return False
def __hash__(self):
path, args, kwargs = self.deconstruct()
h = hash(path) ^ hash(args)
for kwarg in kwargs.items():
h ^= hash(kwarg)
return h
class Expression(BaseExpression, Combinable): class Expression(BaseExpression, Combinable):
""" """
@ -445,6 +468,7 @@ class TemporalSubtraction(CombinedExpression):
return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs) return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs)
@deconstructible
class F(Combinable): class F(Combinable):
""" """
An object capable of resolving references to existing query objects. An object capable of resolving references to existing query objects.
@ -468,6 +492,12 @@ class F(Combinable):
def desc(self, **kwargs): def desc(self, **kwargs):
return OrderBy(self, descending=True, **kwargs) return OrderBy(self, descending=True, **kwargs)
def __eq__(self, other):
return self.__class__ == other.__class__ and self.name == other.name
def __hash__(self):
return hash(self.name)
class ResolvedOuterRef(F): class ResolvedOuterRef(F):
""" """
@ -647,6 +677,12 @@ class RawSQL(Expression):
def get_group_by_cols(self): def get_group_by_cols(self):
return [self] return [self]
def __hash__(self):
h = hash(self.sql) ^ hash(self._output_field)
for param in self.params:
h ^= hash(param)
return h
class Star(Expression): class Star(Expression):
def __repr__(self): def __repr__(self):

View File

@ -5,7 +5,7 @@ from copy import deepcopy
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import DatabaseError, connection, models, transaction from django.db import DatabaseError, connection, models, transaction
from django.db.models import TimeField, UUIDField from django.db.models import CharField, TimeField, UUIDField
from django.db.models.aggregates import ( from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance, Avg, Count, Max, Min, StdDev, Sum, Variance,
) )
@ -18,7 +18,9 @@ from django.db.models.functions import (
) )
from django.db.models.sql import constants from django.db.models.sql import constants
from django.db.models.sql.datastructures import Join from django.db.models.sql.datastructures import Join
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import (
SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature,
)
from django.test.utils import Approximate from django.test.utils import Approximate
from .models import ( from .models import (
@ -653,17 +655,42 @@ class IterableLookupInnerExpressionsTests(TestCase):
self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"]) self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
class ExpressionsTests(TestCase): class FTests(SimpleTestCase):
def test_F_object_deepcopy(self): def test_deepcopy(self):
"""
Make sure F objects can be deepcopied (#23492)
"""
f = F("foo") f = F("foo")
g = deepcopy(f) g = deepcopy(f)
self.assertEqual(f.name, g.name) self.assertEqual(f.name, g.name)
def test_f_reuse(self): def test_deconstruct(self):
f = F('name')
path, args, kwargs = f.deconstruct()
self.assertEqual(path, 'django.db.models.expressions.F')
self.assertEqual(args, (f.name,))
self.assertEqual(kwargs, {})
def test_equal(self):
f = F('name')
same_f = F('name')
other_f = F('username')
self.assertEqual(f, same_f)
self.assertNotEqual(f, other_f)
def test_hash(self):
d = {F('name'): 'Bob'}
self.assertIn(F('name'), d)
self.assertEqual(d[F('name')], 'Bob')
def test_not_equal_Value(self):
f = F('name')
value = Value('name')
self.assertNotEqual(f, value)
self.assertNotEqual(value, f)
class ExpressionsTests(TestCase):
def test_F_reuse(self):
f = F('id') f = F('id')
n = Number.objects.create(integer=-1) n = Number.objects.create(integer=-1)
c = Company.objects.create( c = Company.objects.create(
@ -1238,6 +1265,42 @@ class ValueTests(TestCase):
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField())) UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012')) self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
def test_deconstruct(self):
value = Value('name')
path, args, kwargs = value.deconstruct()
self.assertEqual(path, 'django.db.models.expressions.Value')
self.assertEqual(args, (value.value,))
self.assertEqual(kwargs, {})
def test_deconstruct_output_field(self):
value = Value('name', output_field=CharField())
path, args, kwargs = value.deconstruct()
self.assertEqual(path, 'django.db.models.expressions.Value')
self.assertEqual(args, (value.value,))
self.assertEqual(len(kwargs), 1)
self.assertEqual(kwargs['output_field'].deconstruct(), CharField().deconstruct())
def test_equal(self):
value = Value('name')
same_value = Value('name')
other_value = Value('username')
self.assertEqual(value, same_value)
self.assertNotEqual(value, other_value)
def test_hash(self):
d = {Value('name'): 'Bob'}
self.assertIn(Value('name'), d)
self.assertEqual(d[Value('name')], 'Bob')
def test_equal_output_field(self):
value = Value('name', output_field=CharField())
same_value = Value('name', output_field=CharField())
other_value = Value('name', output_field=TimeField())
no_output_field = Value('name')
self.assertEqual(value, same_value)
self.assertNotEqual(value, other_value)
self.assertNotEqual(value, no_output_field)
class ReprTests(TestCase): class ReprTests(TestCase):