Refs #14030 -- Added repr methods to all expressions

This commit is contained in:
Josh Smeaton 2015-01-27 13:40:32 +11:00
parent f218a2ff45
commit 7171bf755b
3 changed files with 111 additions and 6 deletions

View File

@ -94,6 +94,13 @@ class Count(Aggregate):
super(Count, self).__init__(
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
def __repr__(self):
return "{}({}, distinct={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.extra['distinct'] == '' else 'True',
)
def convert_value(self, value, connection, context):
if value is None:
return 0
@ -117,6 +124,13 @@ class StdDev(Aggregate):
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
def __repr__(self):
return "{}({}, sample={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.function == 'STDDEV_POP' else 'True',
)
def convert_value(self, value, connection, context):
if value is None:
return value
@ -135,6 +149,13 @@ class Variance(Aggregate):
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
def __repr__(self):
return "{}({}, sample={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.function == 'VAR_POP' else 'True',
)
def convert_value(self, value, connection, context):
if value is None:
return value

View File

@ -340,6 +340,12 @@ class Expression(ExpressionNode):
self.lhs = lhs
self.rhs = rhs
def __repr__(self):
return "<{}: {}>".format(self.__class__.__name__, self)
def __str__(self):
return "{} {} {}".format(self.lhs, self.connector, self.rhs)
def get_source_expressions(self):
return [self.lhs, self.rhs]
@ -408,7 +414,7 @@ class DurationExpression(Expression):
return expression_wrapper % sql, expression_params
class F(CombinableMixin):
class F(Combinable):
"""
An object capable of resolving references to existing query objects.
"""
@ -419,6 +425,9 @@ class F(CombinableMixin):
"""
self.name = name
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.name)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
@ -446,6 +455,13 @@ class Func(ExpressionNode):
self.source_expressions = self._parse_expressions(*expressions)
self.extra = extra
def __repr__(self):
args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
extra = ', '.join(str(key) + '=' + str(val) for key, val in self.extra.items())
if extra:
return "{}({}, {})".format(self.__class__.__name__, args, extra)
return "{}({})".format(self.__class__.__name__, args)
def get_source_expressions(self):
return self.source_expressions
@ -504,6 +520,9 @@ class Value(ExpressionNode):
super(Value, self).__init__(output_field=output_field)
self.value = value
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.value)
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
val = self.value
@ -545,6 +564,9 @@ class RawSQL(ExpressionNode):
self.sql, self.params = sql, params
super(RawSQL, self).__init__(output_field=output_field)
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
def as_sql(self, compiler, connection):
return '(%s)' % self.sql, self.params
@ -556,6 +578,9 @@ class Random(ExpressionNode):
def __init__(self):
super(Random, self).__init__(output_field=fields.FloatField())
def __repr__(self):
return "Random()"
def as_sql(self, compiler, connection):
return connection.ops.random_function_sql(), []
@ -567,6 +592,10 @@ class Col(ExpressionNode):
super(Col, self).__init__(output_field=source)
self.alias, self.target = alias, target
def __repr__(self):
return "{}({}, {})".format(
self.__class__.__name__, self.alias, self.target)
def as_sql(self, compiler, connection):
qn = compiler.quote_name_unless_alias
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
@ -588,8 +617,10 @@ class Ref(ExpressionNode):
"""
def __init__(self, refs, source):
super(Ref, self).__init__()
self.source = source
self.refs = refs
self.refs, self.source = refs, source
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)
def get_source_expressions(self):
return [self.source]
@ -743,6 +774,9 @@ class Date(ExpressionNode):
self.col = None
self.lookup_type = lookup_type
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type)
def get_source_expressions(self):
return [self.col]
@ -792,6 +826,10 @@ class DateTime(ExpressionNode):
self.tzname = timezone._get_timezone_name(tzinfo)
self.tzinfo = tzinfo
def __repr__(self):
return "{}({}, {}, {})".format(
self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo)
def get_source_expressions(self):
return [self.col]
@ -833,8 +871,6 @@ class DateTime(ExpressionNode):
class OrderBy(BaseExpression):
template = '%(expression)s %(ordering)s'
descending_template = 'DESC'
ascending_template = 'ASC'
def __init__(self, expression, descending=False):
self.descending = descending
@ -842,6 +878,10 @@ class OrderBy(BaseExpression):
raise ValueError('expression must be an expression type')
self.expression = expression
def __repr__(self):
return "{}({}, descending={})".format(
self.__class__.__name__, self.expression, self.descending)
def set_source_expressions(self, exprs):
self.expression = exprs[0]

View File

@ -6,10 +6,17 @@ import uuid
from django.core.exceptions import FieldError
from django.db import connection, transaction, DatabaseError
from django.db.models import F, Value, TimeField, UUIDField
from django.db.models import TimeField, UUIDField
from django.db.models.aggregates import Avg, Count, Max, Min, StdDev, Sum, Variance
from django.db.models.expressions import (
Case, Col, Date, DateTime, F, Func, OrderBy,
Random, RawSQL, Ref, Value, When
)
from django.db.models.functions import Coalesce, Concat, Length, Lower, Substr, Upper
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate
from django.utils import six
from django.utils.timezone import utc
from .models import Company, Employee, Number, Experiment, Time, UUID
@ -812,3 +819,40 @@ class ValueTests(TestCase):
UUID.objects.create()
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
class ReprTests(TestCase):
def test_expressions(self):
self.assertEqual(
repr(Case(When(a=1))),
"<Case: CASE WHEN <Q: (AND: ('a', 1))> THEN Value(None), ELSE Value(None)>"
)
self.assertEqual(repr(Col('alias', 'field')), "Col(alias, field)")
self.assertEqual(repr(Date('published', 'exact')), "Date(published, exact)")
self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, UTC)")
self.assertEqual(repr(F('published')), "F(published)")
self.assertEqual(repr(F('cost') + F('tax')), "<Expression: F(cost) + F(tax)>")
self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")
self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)')
self.assertEqual(repr(Random()), "Random()")
self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])")
self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))")
self.assertEqual(repr(Value(1)), "Value(1)")
def test_functions(self):
self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))")
self.assertEqual(repr(Concat('a', 'b')), "Concat(ConcatPair(F(a), F(b)))")
self.assertEqual(repr(Length('a')), "Length(F(a))")
self.assertEqual(repr(Lower('a')), "Lower(F(a))")
self.assertEqual(repr(Substr('a', 1, 3)), "Substr(F(a), Value(1), Value(3))")
self.assertEqual(repr(Upper('a')), "Upper(F(a))")
def test_aggregates(self):
self.assertEqual(repr(Avg('a')), "Avg(F(a))")
self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)")
self.assertEqual(repr(Max('a')), "Max(F(a))")
self.assertEqual(repr(Min('a')), "Min(F(a))")
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")