Refs #14030 -- Added repr methods to all expressions

This commit is contained in:
Josh Smeaton 2015-01-27 13:40:32 +11:00
parent f218a2ff45
commit 7171bf755b
3 changed files with 111 additions and 6 deletions

View File

@ -94,6 +94,13 @@ class Count(Aggregate):
super(Count, self).__init__( super(Count, self).__init__(
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra) expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
def __repr__(self):
return "{}({}, distinct={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.extra['distinct'] == '' else 'True',
)
def convert_value(self, value, connection, context): def convert_value(self, value, connection, context):
if value is None: if value is None:
return 0 return 0
@ -117,6 +124,13 @@ class StdDev(Aggregate):
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra) super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
def __repr__(self):
return "{}({}, sample={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.function == 'STDDEV_POP' else 'True',
)
def convert_value(self, value, connection, context): def convert_value(self, value, connection, context):
if value is None: if value is None:
return value return value
@ -135,6 +149,13 @@ class Variance(Aggregate):
self.function = 'VAR_SAMP' if sample else 'VAR_POP' self.function = 'VAR_SAMP' if sample else 'VAR_POP'
super(Variance, self).__init__(expression, output_field=FloatField(), **extra) super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
def __repr__(self):
return "{}({}, sample={})".format(
self.__class__.__name__,
self.arg_joiner.join(str(arg) for arg in self.source_expressions),
'False' if self.function == 'VAR_POP' else 'True',
)
def convert_value(self, value, connection, context): def convert_value(self, value, connection, context):
if value is None: if value is None:
return value return value

View File

@ -340,6 +340,12 @@ class Expression(ExpressionNode):
self.lhs = lhs self.lhs = lhs
self.rhs = rhs self.rhs = rhs
def __repr__(self):
return "<{}: {}>".format(self.__class__.__name__, self)
def __str__(self):
return "{} {} {}".format(self.lhs, self.connector, self.rhs)
def get_source_expressions(self): def get_source_expressions(self):
return [self.lhs, self.rhs] return [self.lhs, self.rhs]
@ -408,7 +414,7 @@ class DurationExpression(Expression):
return expression_wrapper % sql, expression_params return expression_wrapper % sql, expression_params
class F(CombinableMixin): class F(Combinable):
""" """
An object capable of resolving references to existing query objects. An object capable of resolving references to existing query objects.
""" """
@ -419,6 +425,9 @@ class F(CombinableMixin):
""" """
self.name = name self.name = name
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.name)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
return query.resolve_ref(self.name, allow_joins, reuse, summarize) return query.resolve_ref(self.name, allow_joins, reuse, summarize)
@ -446,6 +455,13 @@ class Func(ExpressionNode):
self.source_expressions = self._parse_expressions(*expressions) self.source_expressions = self._parse_expressions(*expressions)
self.extra = extra self.extra = extra
def __repr__(self):
args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
extra = ', '.join(str(key) + '=' + str(val) for key, val in self.extra.items())
if extra:
return "{}({}, {})".format(self.__class__.__name__, args, extra)
return "{}({})".format(self.__class__.__name__, args)
def get_source_expressions(self): def get_source_expressions(self):
return self.source_expressions return self.source_expressions
@ -504,6 +520,9 @@ class Value(ExpressionNode):
super(Value, self).__init__(output_field=output_field) super(Value, self).__init__(output_field=output_field)
self.value = value self.value = value
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.value)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
val = self.value val = self.value
@ -545,6 +564,9 @@ class RawSQL(ExpressionNode):
self.sql, self.params = sql, params self.sql, self.params = sql, params
super(RawSQL, self).__init__(output_field=output_field) super(RawSQL, self).__init__(output_field=output_field)
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return '(%s)' % self.sql, self.params return '(%s)' % self.sql, self.params
@ -556,6 +578,9 @@ class Random(ExpressionNode):
def __init__(self): def __init__(self):
super(Random, self).__init__(output_field=fields.FloatField()) super(Random, self).__init__(output_field=fields.FloatField())
def __repr__(self):
return "Random()"
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return connection.ops.random_function_sql(), [] return connection.ops.random_function_sql(), []
@ -567,6 +592,10 @@ class Col(ExpressionNode):
super(Col, self).__init__(output_field=source) super(Col, self).__init__(output_field=source)
self.alias, self.target = alias, target self.alias, self.target = alias, target
def __repr__(self):
return "{}({}, {})".format(
self.__class__.__name__, self.alias, self.target)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
qn = compiler.quote_name_unless_alias qn = compiler.quote_name_unless_alias
return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
@ -588,8 +617,10 @@ class Ref(ExpressionNode):
""" """
def __init__(self, refs, source): def __init__(self, refs, source):
super(Ref, self).__init__() super(Ref, self).__init__()
self.source = source self.refs, self.source = refs, source
self.refs = refs
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)
def get_source_expressions(self): def get_source_expressions(self):
return [self.source] return [self.source]
@ -743,6 +774,9 @@ class Date(ExpressionNode):
self.col = None self.col = None
self.lookup_type = lookup_type self.lookup_type = lookup_type
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type)
def get_source_expressions(self): def get_source_expressions(self):
return [self.col] return [self.col]
@ -792,6 +826,10 @@ class DateTime(ExpressionNode):
self.tzname = timezone._get_timezone_name(tzinfo) self.tzname = timezone._get_timezone_name(tzinfo)
self.tzinfo = tzinfo self.tzinfo = tzinfo
def __repr__(self):
return "{}({}, {}, {})".format(
self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo)
def get_source_expressions(self): def get_source_expressions(self):
return [self.col] return [self.col]
@ -833,8 +871,6 @@ class DateTime(ExpressionNode):
class OrderBy(BaseExpression): class OrderBy(BaseExpression):
template = '%(expression)s %(ordering)s' template = '%(expression)s %(ordering)s'
descending_template = 'DESC'
ascending_template = 'ASC'
def __init__(self, expression, descending=False): def __init__(self, expression, descending=False):
self.descending = descending self.descending = descending
@ -842,6 +878,10 @@ class OrderBy(BaseExpression):
raise ValueError('expression must be an expression type') raise ValueError('expression must be an expression type')
self.expression = expression self.expression = expression
def __repr__(self):
return "{}({}, descending={})".format(
self.__class__.__name__, self.expression, self.descending)
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.expression = exprs[0] self.expression = exprs[0]

View File

@ -6,10 +6,17 @@ import uuid
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection, transaction, DatabaseError from django.db import connection, transaction, DatabaseError
from django.db.models import F, Value, TimeField, UUIDField from django.db.models import TimeField, UUIDField
from django.db.models.aggregates import Avg, Count, Max, Min, StdDev, Sum, Variance
from django.db.models.expressions import (
Case, Col, Date, DateTime, F, Func, OrderBy,
Random, RawSQL, Ref, Value, When
)
from django.db.models.functions import Coalesce, Concat, Length, Lower, Substr, Upper
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate from django.test.utils import Approximate
from django.utils import six from django.utils import six
from django.utils.timezone import utc
from .models import Company, Employee, Number, Experiment, Time, UUID from .models import Company, Employee, Number, Experiment, Time, UUID
@ -812,3 +819,40 @@ class ValueTests(TestCase):
UUID.objects.create() UUID.objects.create()
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField())) UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012')) self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
class ReprTests(TestCase):
def test_expressions(self):
self.assertEqual(
repr(Case(When(a=1))),
"<Case: CASE WHEN <Q: (AND: ('a', 1))> THEN Value(None), ELSE Value(None)>"
)
self.assertEqual(repr(Col('alias', 'field')), "Col(alias, field)")
self.assertEqual(repr(Date('published', 'exact')), "Date(published, exact)")
self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, UTC)")
self.assertEqual(repr(F('published')), "F(published)")
self.assertEqual(repr(F('cost') + F('tax')), "<Expression: F(cost) + F(tax)>")
self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")
self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)')
self.assertEqual(repr(Random()), "Random()")
self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])")
self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))")
self.assertEqual(repr(Value(1)), "Value(1)")
def test_functions(self):
self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))")
self.assertEqual(repr(Concat('a', 'b')), "Concat(ConcatPair(F(a), F(b)))")
self.assertEqual(repr(Length('a')), "Length(F(a))")
self.assertEqual(repr(Lower('a')), "Lower(F(a))")
self.assertEqual(repr(Substr('a', 1, 3)), "Substr(F(a), Value(1), Value(3))")
self.assertEqual(repr(Upper('a')), "Upper(F(a))")
def test_aggregates(self):
self.assertEqual(repr(Avg('a')), "Avg(F(a))")
self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)")
self.assertEqual(repr(Max('a')), "Max(F(a))")
self.assertEqual(repr(Min('a')), "Min(F(a))")
self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)")
self.assertEqual(repr(Sum('a')), "Sum(F(a))")
self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")