Refs #14030 -- Added repr methods to all expressions
This commit is contained in:
parent
f218a2ff45
commit
7171bf755b
|
@ -94,6 +94,13 @@ class Count(Aggregate):
|
||||||
super(Count, self).__init__(
|
super(Count, self).__init__(
|
||||||
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
|
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, distinct={})".format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
|
||||||
|
'False' if self.extra['distinct'] == '' else 'True',
|
||||||
|
)
|
||||||
|
|
||||||
def convert_value(self, value, connection, context):
|
def convert_value(self, value, connection, context):
|
||||||
if value is None:
|
if value is None:
|
||||||
return 0
|
return 0
|
||||||
|
@ -117,6 +124,13 @@ class StdDev(Aggregate):
|
||||||
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
|
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
|
||||||
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
|
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, sample={})".format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
|
||||||
|
'False' if self.function == 'STDDEV_POP' else 'True',
|
||||||
|
)
|
||||||
|
|
||||||
def convert_value(self, value, connection, context):
|
def convert_value(self, value, connection, context):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
|
@ -135,6 +149,13 @@ class Variance(Aggregate):
|
||||||
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
|
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
|
||||||
super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
|
super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, sample={})".format(
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
|
||||||
|
'False' if self.function == 'VAR_POP' else 'True',
|
||||||
|
)
|
||||||
|
|
||||||
def convert_value(self, value, connection, context):
|
def convert_value(self, value, connection, context):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
|
|
|
@ -340,6 +340,12 @@ class Expression(ExpressionNode):
|
||||||
self.lhs = lhs
|
self.lhs = lhs
|
||||||
self.rhs = rhs
|
self.rhs = rhs
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<{}: {}>".format(self.__class__.__name__, self)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "{} {} {}".format(self.lhs, self.connector, self.rhs)
|
||||||
|
|
||||||
def get_source_expressions(self):
|
def get_source_expressions(self):
|
||||||
return [self.lhs, self.rhs]
|
return [self.lhs, self.rhs]
|
||||||
|
|
||||||
|
@ -408,7 +414,7 @@ class DurationExpression(Expression):
|
||||||
return expression_wrapper % sql, expression_params
|
return expression_wrapper % sql, expression_params
|
||||||
|
|
||||||
|
|
||||||
class F(CombinableMixin):
|
class F(Combinable):
|
||||||
"""
|
"""
|
||||||
An object capable of resolving references to existing query objects.
|
An object capable of resolving references to existing query objects.
|
||||||
"""
|
"""
|
||||||
|
@ -419,6 +425,9 @@ class F(CombinableMixin):
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({})".format(self.__class__.__name__, self.name)
|
||||||
|
|
||||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||||
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
|
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
|
||||||
|
|
||||||
|
@ -446,6 +455,13 @@ class Func(ExpressionNode):
|
||||||
self.source_expressions = self._parse_expressions(*expressions)
|
self.source_expressions = self._parse_expressions(*expressions)
|
||||||
self.extra = extra
|
self.extra = extra
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
|
||||||
|
extra = ', '.join(str(key) + '=' + str(val) for key, val in self.extra.items())
|
||||||
|
if extra:
|
||||||
|
return "{}({}, {})".format(self.__class__.__name__, args, extra)
|
||||||
|
return "{}({})".format(self.__class__.__name__, args)
|
||||||
|
|
||||||
def get_source_expressions(self):
|
def get_source_expressions(self):
|
||||||
return self.source_expressions
|
return self.source_expressions
|
||||||
|
|
||||||
|
@ -504,6 +520,9 @@ class Value(ExpressionNode):
|
||||||
super(Value, self).__init__(output_field=output_field)
|
super(Value, self).__init__(output_field=output_field)
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({})".format(self.__class__.__name__, self.value)
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
val = self.value
|
val = self.value
|
||||||
|
@ -545,6 +564,9 @@ class RawSQL(ExpressionNode):
|
||||||
self.sql, self.params = sql, params
|
self.sql, self.params = sql, params
|
||||||
super(RawSQL, self).__init__(output_field=output_field)
|
super(RawSQL, self).__init__(output_field=output_field)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
return '(%s)' % self.sql, self.params
|
return '(%s)' % self.sql, self.params
|
||||||
|
|
||||||
|
@ -556,6 +578,9 @@ class Random(ExpressionNode):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Random, self).__init__(output_field=fields.FloatField())
|
super(Random, self).__init__(output_field=fields.FloatField())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "Random()"
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
return connection.ops.random_function_sql(), []
|
return connection.ops.random_function_sql(), []
|
||||||
|
|
||||||
|
@ -567,6 +592,10 @@ class Col(ExpressionNode):
|
||||||
super(Col, self).__init__(output_field=source)
|
super(Col, self).__init__(output_field=source)
|
||||||
self.alias, self.target = alias, target
|
self.alias, self.target = alias, target
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, {})".format(
|
||||||
|
self.__class__.__name__, self.alias, self.target)
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
qn = compiler.quote_name_unless_alias
|
qn = compiler.quote_name_unless_alias
|
||||||
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
|
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
|
||||||
|
@ -588,8 +617,10 @@ class Ref(ExpressionNode):
|
||||||
"""
|
"""
|
||||||
def __init__(self, refs, source):
|
def __init__(self, refs, source):
|
||||||
super(Ref, self).__init__()
|
super(Ref, self).__init__()
|
||||||
self.source = source
|
self.refs, self.source = refs, source
|
||||||
self.refs = refs
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)
|
||||||
|
|
||||||
def get_source_expressions(self):
|
def get_source_expressions(self):
|
||||||
return [self.source]
|
return [self.source]
|
||||||
|
@ -743,6 +774,9 @@ class Date(ExpressionNode):
|
||||||
self.col = None
|
self.col = None
|
||||||
self.lookup_type = lookup_type
|
self.lookup_type = lookup_type
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type)
|
||||||
|
|
||||||
def get_source_expressions(self):
|
def get_source_expressions(self):
|
||||||
return [self.col]
|
return [self.col]
|
||||||
|
|
||||||
|
@ -792,6 +826,10 @@ class DateTime(ExpressionNode):
|
||||||
self.tzname = timezone._get_timezone_name(tzinfo)
|
self.tzname = timezone._get_timezone_name(tzinfo)
|
||||||
self.tzinfo = tzinfo
|
self.tzinfo = tzinfo
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, {}, {})".format(
|
||||||
|
self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo)
|
||||||
|
|
||||||
def get_source_expressions(self):
|
def get_source_expressions(self):
|
||||||
return [self.col]
|
return [self.col]
|
||||||
|
|
||||||
|
@ -833,8 +871,6 @@ class DateTime(ExpressionNode):
|
||||||
|
|
||||||
class OrderBy(BaseExpression):
|
class OrderBy(BaseExpression):
|
||||||
template = '%(expression)s %(ordering)s'
|
template = '%(expression)s %(ordering)s'
|
||||||
descending_template = 'DESC'
|
|
||||||
ascending_template = 'ASC'
|
|
||||||
|
|
||||||
def __init__(self, expression, descending=False):
|
def __init__(self, expression, descending=False):
|
||||||
self.descending = descending
|
self.descending = descending
|
||||||
|
@ -842,6 +878,10 @@ class OrderBy(BaseExpression):
|
||||||
raise ValueError('expression must be an expression type')
|
raise ValueError('expression must be an expression type')
|
||||||
self.expression = expression
|
self.expression = expression
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{}({}, descending={})".format(
|
||||||
|
self.__class__.__name__, self.expression, self.descending)
|
||||||
|
|
||||||
def set_source_expressions(self, exprs):
|
def set_source_expressions(self, exprs):
|
||||||
self.expression = exprs[0]
|
self.expression = exprs[0]
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,17 @@ import uuid
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import connection, transaction, DatabaseError
|
from django.db import connection, transaction, DatabaseError
|
||||||
from django.db.models import F, Value, TimeField, UUIDField
|
from django.db.models import TimeField, UUIDField
|
||||||
|
from django.db.models.aggregates import Avg, Count, Max, Min, StdDev, Sum, Variance
|
||||||
|
from django.db.models.expressions import (
|
||||||
|
Case, Col, Date, DateTime, F, Func, OrderBy,
|
||||||
|
Random, RawSQL, Ref, Value, When
|
||||||
|
)
|
||||||
|
from django.db.models.functions import Coalesce, Concat, Length, Lower, Substr, Upper
|
||||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||||
from django.test.utils import Approximate
|
from django.test.utils import Approximate
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
from django.utils.timezone import utc
|
||||||
|
|
||||||
from .models import Company, Employee, Number, Experiment, Time, UUID
|
from .models import Company, Employee, Number, Experiment, Time, UUID
|
||||||
|
|
||||||
|
@ -812,3 +819,40 @@ class ValueTests(TestCase):
|
||||||
UUID.objects.create()
|
UUID.objects.create()
|
||||||
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
|
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
|
||||||
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
|
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
|
||||||
|
|
||||||
|
|
||||||
|
class ReprTests(TestCase):
|
||||||
|
|
||||||
|
def test_expressions(self):
|
||||||
|
self.assertEqual(
|
||||||
|
repr(Case(When(a=1))),
|
||||||
|
"<Case: CASE WHEN <Q: (AND: ('a', 1))> THEN Value(None), ELSE Value(None)>"
|
||||||
|
)
|
||||||
|
self.assertEqual(repr(Col('alias', 'field')), "Col(alias, field)")
|
||||||
|
self.assertEqual(repr(Date('published', 'exact')), "Date(published, exact)")
|
||||||
|
self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, UTC)")
|
||||||
|
self.assertEqual(repr(F('published')), "F(published)")
|
||||||
|
self.assertEqual(repr(F('cost') + F('tax')), "<Expression: F(cost) + F(tax)>")
|
||||||
|
self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")
|
||||||
|
self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)')
|
||||||
|
self.assertEqual(repr(Random()), "Random()")
|
||||||
|
self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])")
|
||||||
|
self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))")
|
||||||
|
self.assertEqual(repr(Value(1)), "Value(1)")
|
||||||
|
|
||||||
|
def test_functions(self):
|
||||||
|
self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))")
|
||||||
|
self.assertEqual(repr(Concat('a', 'b')), "Concat(ConcatPair(F(a), F(b)))")
|
||||||
|
self.assertEqual(repr(Length('a')), "Length(F(a))")
|
||||||
|
self.assertEqual(repr(Lower('a')), "Lower(F(a))")
|
||||||
|
self.assertEqual(repr(Substr('a', 1, 3)), "Substr(F(a), Value(1), Value(3))")
|
||||||
|
self.assertEqual(repr(Upper('a')), "Upper(F(a))")
|
||||||
|
|
||||||
|
def test_aggregates(self):
|
||||||
|
self.assertEqual(repr(Avg('a')), "Avg(F(a))")
|
||||||
|
self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)")
|
||||||
|
self.assertEqual(repr(Max('a')), "Max(F(a))")
|
||||||
|
self.assertEqual(repr(Min('a')), "Min(F(a))")
|
||||||
|
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
|
||||||
|
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
|
||||||
|
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")
|
||||||
|
|
Loading…
Reference in New Issue