From 6c68e40e6e7f3ba36fa0e629d5724c7f4b279bb8 Mon Sep 17 00:00:00 2001 From: Josh Smeaton Date: Tue, 27 Jan 2015 13:40:32 +1100 Subject: [PATCH] [1.8.x] Refs #14030 -- Added repr methods to all expressions Backport of 7171bf755b0c4be85ddbcc164eaf87164c131021 from master --- django/db/models/aggregates.py | 21 ++++++++++++++ django/db/models/expressions.py | 50 +++++++++++++++++++++++++++++---- tests/expressions/tests.py | 46 +++++++++++++++++++++++++++++- 3 files changed, 111 insertions(+), 6 deletions(-) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 668f79f622..c1ddc75d4c 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -94,6 +94,13 @@ class Count(Aggregate): super(Count, self).__init__( expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra) + def __repr__(self): + return "{}({}, distinct={})".format( + self.__class__.__name__, + self.arg_joiner.join(str(arg) for arg in self.source_expressions), + 'False' if self.extra['distinct'] == '' else 'True', + ) + def convert_value(self, value, connection, context): if value is None: return 0 @@ -117,6 +124,13 @@ class StdDev(Aggregate): self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' super(StdDev, self).__init__(expression, output_field=FloatField(), **extra) + def __repr__(self): + return "{}({}, sample={})".format( + self.__class__.__name__, + self.arg_joiner.join(str(arg) for arg in self.source_expressions), + 'False' if self.function == 'STDDEV_POP' else 'True', + ) + def convert_value(self, value, connection, context): if value is None: return value @@ -135,6 +149,13 @@ class Variance(Aggregate): self.function = 'VAR_SAMP' if sample else 'VAR_POP' super(Variance, self).__init__(expression, output_field=FloatField(), **extra) + def __repr__(self): + return "{}({}, sample={})".format( + self.__class__.__name__, + self.arg_joiner.join(str(arg) for arg in self.source_expressions), + 'False' if self.function == 'VAR_POP' else 'True', + ) + def convert_value(self, value, connection, context): if value is None: return value diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index fb094fd4ef..9caa5dace4 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -340,6 +340,12 @@ class Expression(ExpressionNode): self.lhs = lhs self.rhs = rhs + def __repr__(self): + return "<{}: {}>".format(self.__class__.__name__, self) + + def __str__(self): + return "{} {} {}".format(self.lhs, self.connector, self.rhs) + def get_source_expressions(self): return [self.lhs, self.rhs] @@ -408,7 +414,7 @@ class DurationExpression(Expression): return expression_wrapper % sql, expression_params -class F(CombinableMixin): +class F(Combinable): """ An object capable of resolving references to existing query objects. """ @@ -419,6 +425,9 @@ class F(CombinableMixin): """ self.name = name + def __repr__(self): + return "{}({})".format(self.__class__.__name__, self.name) + def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): return query.resolve_ref(self.name, allow_joins, reuse, summarize) @@ -446,6 +455,13 @@ class Func(ExpressionNode): self.source_expressions = self._parse_expressions(*expressions) self.extra = extra + def __repr__(self): + args = self.arg_joiner.join(str(arg) for arg in self.source_expressions) + extra = ', '.join(str(key) + '=' + str(val) for key, val in self.extra.items()) + if extra: + return "{}({}, {})".format(self.__class__.__name__, args, extra) + return "{}({})".format(self.__class__.__name__, args) + def get_source_expressions(self): return self.source_expressions @@ -504,6 +520,9 @@ class Value(ExpressionNode): super(Value, self).__init__(output_field=output_field) self.value = value + def __repr__(self): + return "{}({})".format(self.__class__.__name__, self.value) + def as_sql(self, compiler, connection): connection.ops.check_expression_support(self) val = self.value @@ -545,6 +564,9 @@ class RawSQL(ExpressionNode): self.sql, self.params = sql, params super(RawSQL, self).__init__(output_field=output_field) + def __repr__(self): + return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params) + def as_sql(self, compiler, connection): return '(%s)' % self.sql, self.params @@ -556,6 +578,9 @@ class Random(ExpressionNode): def __init__(self): super(Random, self).__init__(output_field=fields.FloatField()) + def __repr__(self): + return "Random()" + def as_sql(self, compiler, connection): return connection.ops.random_function_sql(), [] @@ -567,6 +592,10 @@ class Col(ExpressionNode): super(Col, self).__init__(output_field=source) self.alias, self.target = alias, target + def __repr__(self): + return "{}({}, {})".format( + self.__class__.__name__, self.alias, self.target) + def as_sql(self, compiler, connection): qn = compiler.quote_name_unless_alias return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] @@ -588,8 +617,10 @@ class Ref(ExpressionNode): """ def __init__(self, refs, source): super(Ref, self).__init__() - self.source = source - self.refs = refs + self.refs, self.source = refs, source + + def __repr__(self): + return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source) def get_source_expressions(self): return [self.source] @@ -743,6 +774,9 @@ class Date(ExpressionNode): self.col = None self.lookup_type = lookup_type + def __repr__(self): + return "{}({}, {})".format(self.__class__.__name__, self.lookup, self.lookup_type) + def get_source_expressions(self): return [self.col] @@ -792,6 +826,10 @@ class DateTime(ExpressionNode): self.tzname = timezone._get_timezone_name(tzinfo) self.tzinfo = tzinfo + def __repr__(self): + return "{}({}, {}, {})".format( + self.__class__.__name__, self.lookup, self.lookup_type, self.tzinfo) + def get_source_expressions(self): return [self.col] @@ -833,8 +871,6 @@ class DateTime(ExpressionNode): class OrderBy(BaseExpression): template = '%(expression)s %(ordering)s' - descending_template = 'DESC' - ascending_template = 'ASC' def __init__(self, expression, descending=False): self.descending = descending @@ -842,6 +878,10 @@ class OrderBy(BaseExpression): raise ValueError('expression must be an expression type') self.expression = expression + def __repr__(self): + return "{}({}, descending={})".format( + self.__class__.__name__, self.expression, self.descending) + def set_source_expressions(self, exprs): self.expression = exprs[0] diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 3c508b8520..f7e8cae856 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -6,10 +6,17 @@ import uuid from django.core.exceptions import FieldError from django.db import connection, transaction, DatabaseError -from django.db.models import F, Value, TimeField, UUIDField +from django.db.models import TimeField, UUIDField +from django.db.models.aggregates import Avg, Count, Max, Min, StdDev, Sum, Variance +from django.db.models.expressions import ( + Case, Col, Date, DateTime, F, Func, OrderBy, + Random, RawSQL, Ref, Value, When +) +from django.db.models.functions import Coalesce, Concat, Length, Lower, Substr, Upper from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import Approximate from django.utils import six +from django.utils.timezone import utc from .models import Company, Employee, Number, Experiment, Time, UUID @@ -812,3 +819,40 @@ class ValueTests(TestCase): UUID.objects.create() UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField())) self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012')) + + +class ReprTests(TestCase): + + def test_expressions(self): + self.assertEqual( + repr(Case(When(a=1))), + " THEN Value(None), ELSE Value(None)>" + ) + self.assertEqual(repr(Col('alias', 'field')), "Col(alias, field)") + self.assertEqual(repr(Date('published', 'exact')), "Date(published, exact)") + self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, UTC)") + self.assertEqual(repr(F('published')), "F(published)") + self.assertEqual(repr(F('cost') + F('tax')), "") + self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)") + self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)') + self.assertEqual(repr(Random()), "Random()") + self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])") + self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))") + self.assertEqual(repr(Value(1)), "Value(1)") + + def test_functions(self): + self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))") + self.assertEqual(repr(Concat('a', 'b')), "Concat(ConcatPair(F(a), F(b)))") + self.assertEqual(repr(Length('a')), "Length(F(a))") + self.assertEqual(repr(Lower('a')), "Lower(F(a))") + self.assertEqual(repr(Substr('a', 1, 3)), "Substr(F(a), Value(1), Value(3))") + self.assertEqual(repr(Upper('a')), "Upper(F(a))") + + def test_aggregates(self): + self.assertEqual(repr(Avg('a')), "Avg(F(a))") + self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)") + self.assertEqual(repr(Max('a')), "Max(F(a))") + self.assertEqual(repr(Min('a')), "Min(F(a))") + self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)") + self.assertEqual(repr(Sum('a')), "Sum(F(a))") + self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)")