parent
3dcc351691
commit
19b2dfd1bf
|
@ -5,6 +5,7 @@ from django.core.exceptions import EmptyResultSet, FieldError
|
|||
from django.db.backends import utils as backend_utils
|
||||
from django.db.models import fields
|
||||
from django.db.models.query_utils import Q
|
||||
from django.utils.deconstruct import deconstructible
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
|
@ -117,6 +118,7 @@ class Combinable:
|
|||
)
|
||||
|
||||
|
||||
@deconstructible
|
||||
class BaseExpression:
|
||||
"""
|
||||
Base class for all query expressions.
|
||||
|
@ -339,6 +341,27 @@ class BaseExpression:
|
|||
if expr:
|
||||
yield from expr.flatten()
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.__class__ != other.__class__:
|
||||
return False
|
||||
path, args, kwargs = self.deconstruct()
|
||||
other_path, other_args, other_kwargs = other.deconstruct()
|
||||
if (path, args) == (other_path, other_args):
|
||||
kwargs = kwargs.copy()
|
||||
other_kwargs = other_kwargs.copy()
|
||||
output_field = type(kwargs.pop('output_field', None))
|
||||
other_output_field = type(other_kwargs.pop('output_field', None))
|
||||
if output_field == other_output_field:
|
||||
return kwargs == other_kwargs
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
path, args, kwargs = self.deconstruct()
|
||||
h = hash(path) ^ hash(args)
|
||||
for kwarg in kwargs.items():
|
||||
h ^= hash(kwarg)
|
||||
return h
|
||||
|
||||
|
||||
class Expression(BaseExpression, Combinable):
|
||||
"""
|
||||
|
@ -445,6 +468,7 @@ class TemporalSubtraction(CombinedExpression):
|
|||
return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs)
|
||||
|
||||
|
||||
@deconstructible
|
||||
class F(Combinable):
|
||||
"""
|
||||
An object capable of resolving references to existing query objects.
|
||||
|
@ -468,6 +492,12 @@ class F(Combinable):
|
|||
def desc(self, **kwargs):
|
||||
return OrderBy(self, descending=True, **kwargs)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__class__ == other.__class__ and self.name == other.name
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
|
||||
class ResolvedOuterRef(F):
|
||||
"""
|
||||
|
@ -647,6 +677,12 @@ class RawSQL(Expression):
|
|||
def get_group_by_cols(self):
|
||||
return [self]
|
||||
|
||||
def __hash__(self):
|
||||
h = hash(self.sql) ^ hash(self._output_field)
|
||||
for param in self.params:
|
||||
h ^= hash(param)
|
||||
return h
|
||||
|
||||
|
||||
class Star(Expression):
|
||||
def __repr__(self):
|
||||
|
|
|
@ -5,7 +5,7 @@ from copy import deepcopy
|
|||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import DatabaseError, connection, models, transaction
|
||||
from django.db.models import TimeField, UUIDField
|
||||
from django.db.models import CharField, TimeField, UUIDField
|
||||
from django.db.models.aggregates import (
|
||||
Avg, Count, Max, Min, StdDev, Sum, Variance,
|
||||
)
|
||||
|
@ -18,7 +18,9 @@ from django.db.models.functions import (
|
|||
)
|
||||
from django.db.models.sql import constants
|
||||
from django.db.models.sql.datastructures import Join
|
||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||
from django.test import (
|
||||
SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature,
|
||||
)
|
||||
from django.test.utils import Approximate
|
||||
|
||||
from .models import (
|
||||
|
@ -653,17 +655,42 @@ class IterableLookupInnerExpressionsTests(TestCase):
|
|||
self.assertQuerysetEqual(queryset, ["<Result: Result at 2016-02-04 15:00:00>"])
|
||||
|
||||
|
||||
class ExpressionsTests(TestCase):
|
||||
class FTests(SimpleTestCase):
|
||||
|
||||
def test_F_object_deepcopy(self):
|
||||
"""
|
||||
Make sure F objects can be deepcopied (#23492)
|
||||
"""
|
||||
def test_deepcopy(self):
|
||||
f = F("foo")
|
||||
g = deepcopy(f)
|
||||
self.assertEqual(f.name, g.name)
|
||||
|
||||
def test_f_reuse(self):
|
||||
def test_deconstruct(self):
|
||||
f = F('name')
|
||||
path, args, kwargs = f.deconstruct()
|
||||
self.assertEqual(path, 'django.db.models.expressions.F')
|
||||
self.assertEqual(args, (f.name,))
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
def test_equal(self):
|
||||
f = F('name')
|
||||
same_f = F('name')
|
||||
other_f = F('username')
|
||||
self.assertEqual(f, same_f)
|
||||
self.assertNotEqual(f, other_f)
|
||||
|
||||
def test_hash(self):
|
||||
d = {F('name'): 'Bob'}
|
||||
self.assertIn(F('name'), d)
|
||||
self.assertEqual(d[F('name')], 'Bob')
|
||||
|
||||
def test_not_equal_Value(self):
|
||||
f = F('name')
|
||||
value = Value('name')
|
||||
self.assertNotEqual(f, value)
|
||||
self.assertNotEqual(value, f)
|
||||
|
||||
|
||||
class ExpressionsTests(TestCase):
|
||||
|
||||
def test_F_reuse(self):
|
||||
f = F('id')
|
||||
n = Number.objects.create(integer=-1)
|
||||
c = Company.objects.create(
|
||||
|
@ -1238,6 +1265,42 @@ class ValueTests(TestCase):
|
|||
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
|
||||
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
|
||||
|
||||
def test_deconstruct(self):
|
||||
value = Value('name')
|
||||
path, args, kwargs = value.deconstruct()
|
||||
self.assertEqual(path, 'django.db.models.expressions.Value')
|
||||
self.assertEqual(args, (value.value,))
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
def test_deconstruct_output_field(self):
|
||||
value = Value('name', output_field=CharField())
|
||||
path, args, kwargs = value.deconstruct()
|
||||
self.assertEqual(path, 'django.db.models.expressions.Value')
|
||||
self.assertEqual(args, (value.value,))
|
||||
self.assertEqual(len(kwargs), 1)
|
||||
self.assertEqual(kwargs['output_field'].deconstruct(), CharField().deconstruct())
|
||||
|
||||
def test_equal(self):
|
||||
value = Value('name')
|
||||
same_value = Value('name')
|
||||
other_value = Value('username')
|
||||
self.assertEqual(value, same_value)
|
||||
self.assertNotEqual(value, other_value)
|
||||
|
||||
def test_hash(self):
|
||||
d = {Value('name'): 'Bob'}
|
||||
self.assertIn(Value('name'), d)
|
||||
self.assertEqual(d[Value('name')], 'Bob')
|
||||
|
||||
def test_equal_output_field(self):
|
||||
value = Value('name', output_field=CharField())
|
||||
same_value = Value('name', output_field=CharField())
|
||||
other_value = Value('name', output_field=TimeField())
|
||||
no_output_field = Value('name')
|
||||
self.assertEqual(value, same_value)
|
||||
self.assertNotEqual(value, other_value)
|
||||
self.assertNotEqual(value, no_output_field)
|
||||
|
||||
|
||||
class ReprTests(TestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue