Fixed #24031 -- Added CASE expressions to the ORM.

This commit is contained in:
Michał Modzelewski 2015-01-02 02:39:31 +01:00 committed by Tim Graham
parent aa8ee6a573
commit 65246de7b1
20 changed files with 1659 additions and 33 deletions

View File

@ -475,6 +475,7 @@ answer newbie questions, and generally made Django that much better:
Michael Thornhill <michael.thornhill@gmail.com> Michael Thornhill <michael.thornhill@gmail.com>
Michal Chruszcz <troll@pld-linux.org> Michal Chruszcz <troll@pld-linux.org>
michal@plovarna.cz michal@plovarna.cz
Michał Modzelewski <michal.modzelewski@gmail.com>
Mihai Damian <yang_damian@yahoo.com> Mihai Damian <yang_damian@yahoo.com>
Mihai Preda <mihai_preda@yahoo.com> Mihai Preda <mihai_preda@yahoo.com>
Mikaël Barbero <mikael.barbero nospam at nospam free.fr> Mikaël Barbero <mikael.barbero nospam at nospam free.fr>

View File

@ -821,6 +821,14 @@ class BaseDatabaseOperations(object):
""" """
return "SELECT cache_key FROM %s ORDER BY cache_key LIMIT 1 OFFSET %%s" return "SELECT cache_key FROM %s ORDER BY cache_key LIMIT 1 OFFSET %%s"
def unification_cast_sql(self, output_field):
"""
Given a field instance, returns the SQL necessary to cast the result of
a union to that type. Note that the resulting string should contain a
'%s' placeholder for the expression being cast.
"""
return '%s'
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, field_name):
""" """
Given a lookup_type of 'year', 'month' or 'day', returns the SQL that Given a lookup_type of 'year', 'month' or 'day', returns the SQL that

View File

@ -5,6 +5,18 @@ from django.db.backends import BaseDatabaseOperations
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
# PostgreSQL will resolve a union as type 'text' if input types are
# 'unknown'.
# http://www.postgresql.org/docs/9.4/static/typeconv-union-case.html
# These fields cannot be implicitly cast back in the default
# PostgreSQL configuration so we need to explicitly cast them.
# We must also remove components of the type within brackets:
# varchar(255) -> varchar.
return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
return '%s'
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, field_name):
# http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT # http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT

View File

@ -4,7 +4,7 @@ import warnings
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA
from django.db.models.query import Q, QuerySet, Prefetch # NOQA from django.db.models.query import Q, QuerySet, Prefetch # NOQA
from django.db.models.expressions import ExpressionNode, F, Value, Func # NOQA from django.db.models.expressions import ExpressionNode, F, Value, Func, Case, When # NOQA
from django.db.models.manager import Manager # NOQA from django.db.models.manager import Manager # NOQA
from django.db.models.base import Model # NOQA from django.db.models.base import Model # NOQA
from django.db.models.aggregates import * # NOQA from django.db.models.aggregates import * # NOQA

View File

@ -14,8 +14,9 @@ class Aggregate(Func):
contains_aggregate = True contains_aggregate = True
name = None name = None
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
assert len(self.source_expressions) == 1 assert len(self.source_expressions) == 1
# Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize) c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
if c.source_expressions[0].contains_aggregate and not summarize: if c.source_expressions[0].contains_aggregate and not summarize:
name = self.source_expressions[0].name name = self.source_expressions[0].name
@ -101,7 +102,6 @@ class Count(Aggregate):
def __init__(self, expression, distinct=False, **extra): def __init__(self, expression, distinct=False, **extra):
if expression == '*': if expression == '*':
expression = Value(expression) expression = Value(expression)
expression._output_field = IntegerField()
super(Count, self).__init__( super(Count, self).__init__(
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra) expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)

View File

@ -6,7 +6,7 @@ from django.core.exceptions import FieldError
from django.db.backends import utils as backend_utils from django.db.backends import utils as backend_utils
from django.db.models import fields from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import refs_aggregate from django.db.models.query_utils import refs_aggregate, Q
from django.utils import timezone from django.utils import timezone
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -173,7 +173,7 @@ class BaseExpression(object):
return True return True
return False return False
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
""" """
Provides the chance to do any preprocessing or validation before being Provides the chance to do any preprocessing or validation before being
added to the query. added to the query.
@ -380,11 +380,11 @@ class Expression(ExpressionNode):
sql = connection.ops.combine_expression(self.connector, expressions) sql = connection.ops.combine_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params return expression_wrapper % sql, expression_params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy() c = self.copy()
c.is_summary = summarize c.is_summary = summarize
c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize) c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize) c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return c return c
@ -426,7 +426,7 @@ class F(CombinableMixin):
""" """
self.name = name self.name = name
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
return query.resolve_ref(self.name, allow_joins, reuse, summarize) return query.resolve_ref(self.name, allow_joins, reuse, summarize)
def refs_aggregate(self, existing_aggregates): def refs_aggregate(self, existing_aggregates):
@ -465,11 +465,11 @@ class Func(ExpressionNode):
for arg in expressions for arg in expressions
] ]
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy() c = self.copy()
c.is_summary = summarize c.is_summary = summarize
for pos, arg in enumerate(c.source_expressions): for pos, arg in enumerate(c.source_expressions):
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize) c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return c return c
def as_sql(self, compiler, connection, function=None, template=None): def as_sql(self, compiler, connection, function=None, template=None):
@ -511,12 +511,24 @@ class Value(ExpressionNode):
self.value = value self.value = value
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if self.value is None: val = self.value
# check _output_field to avoid triggering an exception
if self._output_field is not None:
if self.for_save:
val = self.output_field.get_db_prep_save(val, connection=connection)
else:
val = self.output_field.get_db_prep_value(val, connection=connection)
if val is None:
# cx_Oracle does not always convert None to the appropriate # cx_Oracle does not always convert None to the appropriate
# NULL type (like in case expressions using numbers), so we # NULL type (like in case expressions using numbers), so we
# use a literal SQL NULL # use a literal SQL NULL
return 'NULL', [] return 'NULL', []
return '%s', [self.value] return '%s', [val]
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = super(Value, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
c.for_save = for_save
return c
def get_group_by_cols(self): def get_group_by_cols(self):
return [] return []
@ -599,6 +611,130 @@ class Ref(ExpressionNode):
return [self] return [self]
class When(ExpressionNode):
template = 'WHEN %(condition)s THEN %(result)s'
def __init__(self, condition=None, then=Value(None), **lookups):
if lookups and condition is None:
condition, lookups = Q(**lookups), None
if condition is None or not isinstance(condition, Q) or lookups:
raise TypeError("__init__() takes either a Q object or lookups as keyword arguments")
super(When, self).__init__(output_field=None)
self.condition = condition
self.result = self._parse_expression(then)
def __str__(self):
return "WHEN %r THEN %r" % (self.condition, self.result)
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self)
def get_source_expressions(self):
return [self.condition, self.result]
def set_source_expressions(self, exprs):
self.condition, self.result = exprs
def get_source_fields(self):
# We're only interested in the fields of the result expressions.
return [self.result._output_field_or_none]
def _parse_expression(self, expression):
return expression if hasattr(expression, 'resolve_expression') else F(expression)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy()
c.is_summary = summarize
c.condition = c.condition.resolve_expression(query, allow_joins, reuse, summarize, False)
c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return c
def as_sql(self, compiler, connection, template=None):
template_params = {}
sql_params = []
condition_sql, condition_params = compiler.compile(self.condition)
template_params['condition'] = condition_sql
sql_params.extend(condition_params)
result_sql, result_params = compiler.compile(self.result)
template_params['result'] = result_sql
sql_params.extend(result_params)
template = template or self.template
return template % template_params, sql_params
def get_group_by_cols(self):
# This is not a complete expression and cannot be used in GROUP BY.
cols = []
for source in self.get_source_expressions():
cols.extend(source.get_group_by_cols())
return cols
class Case(ExpressionNode):
"""
An SQL searched CASE expression:
CASE
WHEN n > 0
THEN 'positive'
WHEN n < 0
THEN 'negative'
ELSE 'zero'
END
"""
template = 'CASE %(cases)s ELSE %(default)s END'
case_joiner = ' '
def __init__(self, *cases, **extra):
if not all(isinstance(case, When) for case in cases):
raise TypeError("Positional arguments must all be When objects.")
default = extra.pop('default', Value(None))
output_field = extra.pop('output_field', None)
super(Case, self).__init__(output_field)
self.cases = list(cases)
self.default = default if hasattr(default, 'resolve_expression') else F(default)
def __str__(self):
return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self)
def get_source_expressions(self):
return self.cases + [self.default]
def set_source_expressions(self, exprs):
self.cases = exprs[:-1]
self.default = exprs[-1]
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy()
c.is_summary = summarize
for pos, case in enumerate(c.cases):
c.cases[pos] = case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
c.default = c.default.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return c
def as_sql(self, compiler, connection, template=None, extra=None):
if not self.cases:
return compiler.compile(self.default)
template_params = dict(extra) if extra else {}
case_parts = []
sql_params = []
for case in self.cases:
case_sql, case_params = compiler.compile(case)
case_parts.append(case_sql)
sql_params.extend(case_params)
template_params['cases'] = self.case_joiner.join(case_parts)
default_sql, default_params = compiler.compile(self.default)
template_params['default'] = default_sql
sql_params.extend(default_params)
template = template or self.template
sql = template % template_params
if self._output_field_or_none is not None:
sql = connection.ops.unification_cast_sql(self.output_field) % sql
return sql, sql_params
class Date(ExpressionNode): class Date(ExpressionNode):
""" """
Add a date selection column. Add a date selection column.
@ -615,7 +751,7 @@ class Date(ExpressionNode):
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.col, = exprs self.col, = exprs
def resolve_expression(self, query, allow_joins, reuse, summarize): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = self.copy() copy = self.copy()
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize) copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
field = copy.col.output_field field = copy.col.output_field
@ -664,7 +800,7 @@ class DateTime(ExpressionNode):
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.col, = exprs self.col, = exprs
def resolve_expression(self, query, allow_joins, reuse, summarize): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = self.copy() copy = self.copy()
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize) copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
field = copy.col.output_field field = copy.col.output_field

View File

@ -86,6 +86,27 @@ class Q(tree.Node):
clone.children.append(child) clone.children.append(child)
return clone return clone
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
clause, _ = query._add_q(self, reuse, allow_joins=allow_joins)
return clause
def refs_aggregate(self, existing_aggregates):
def _refs_aggregate(obj, existing_aggregates):
if not isinstance(obj, tree.Node):
aggregate, aggregate_lookups = refs_aggregate(obj[0].split(LOOKUP_SEP), existing_aggregates)
if not aggregate and hasattr(obj[1], 'refs_aggregate'):
return obj[1].refs_aggregate(existing_aggregates)
return aggregate, aggregate_lookups
for c in obj.children:
aggregate, aggregate_lookups = _refs_aggregate(c, existing_aggregates)
if aggregate:
return aggregate, aggregate_lookups
return False, ()
if not existing_aggregates:
return False
return _refs_aggregate(self, existing_aggregates)
class DeferredAttribute(object): class DeferredAttribute(object):
""" """

View File

@ -998,7 +998,9 @@ class SQLUpdateCompiler(SQLCompiler):
values, update_params = [], [] values, update_params = [], []
for field, model, val in self.query.values: for field, model, val in self.query.values:
if hasattr(val, 'resolve_expression'): if hasattr(val, 'resolve_expression'):
val = val.resolve_expression(self.query, allow_joins=False) val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
if val.contains_aggregate:
raise FieldError("Aggregate functions are not allowed in this query")
elif hasattr(val, 'prepare_database_save'): elif hasattr(val, 'prepare_database_save'):
if field.rel: if field.rel:
val = val.prepare_database_save(field) val = val.prepare_database_save(field)

View File

@ -961,7 +961,7 @@ class Query(object):
self.append_annotation_mask([alias]) self.append_annotation_mask([alias])
self.annotations[alias] = annotation self.annotations[alias] = annotation
def prepare_lookup_value(self, value, lookups, can_reuse): def prepare_lookup_value(self, value, lookups, can_reuse, allow_joins=True):
# Default lookup if none given is exact. # Default lookup if none given is exact.
used_joins = [] used_joins = []
if len(lookups) == 0: if len(lookups) == 0:
@ -980,7 +980,7 @@ class Query(object):
value = value() value = value()
elif hasattr(value, 'resolve_expression'): elif hasattr(value, 'resolve_expression'):
pre_joins = self.alias_refcount.copy() pre_joins = self.alias_refcount.copy()
value = value.resolve_expression(self, reuse=can_reuse) value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)] used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
# Subqueries need to use a different set of aliases than the # Subqueries need to use a different set of aliases than the
# outer query. Call bump_prefix to change aliases of the inner # outer query. Call bump_prefix to change aliases of the inner
@ -1095,7 +1095,7 @@ class Query(object):
(name, lhs.output_field.__class__.__name__)) (name, lhs.output_field.__class__.__name__))
def build_filter(self, filter_expr, branch_negated=False, current_negated=False, def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None, connector=AND): can_reuse=None, connector=AND, allow_joins=True):
""" """
Builds a WhereNode for a single filter clause, but doesn't add it Builds a WhereNode for a single filter clause, but doesn't add it
to this Query. Query.add_q() will then add this filter to the where to this Query. Query.add_q() will then add this filter to the where
@ -1125,10 +1125,12 @@ class Query(object):
if not arg: if not arg:
raise FieldError("Cannot parse keyword query %r" % arg) raise FieldError("Cannot parse keyword query %r" % arg)
lookups, parts, reffed_aggregate = self.solve_lookup_type(arg) lookups, parts, reffed_aggregate = self.solve_lookup_type(arg)
if not allow_joins and len(parts) > 1:
raise FieldError("Joined field references are not permitted in this query")
# Work out the lookup type and remove it from the end of 'parts', # Work out the lookup type and remove it from the end of 'parts',
# if necessary. # if necessary.
value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse) value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse, allow_joins)
clause = self.where_class() clause = self.where_class()
if reffed_aggregate: if reffed_aggregate:
@ -1225,11 +1227,11 @@ class Query(object):
""" """
if not self._annotations: if not self._annotations:
return False return False
if not isinstance(obj, Node): if hasattr(obj, 'refs_aggregate'):
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0] return obj.refs_aggregate(self.annotations)[0]
or (hasattr(obj[1], 'refs_aggregate') return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0]
and obj[1].refs_aggregate(self.annotations)[0])) or (hasattr(obj[1], 'refs_aggregate')
return any(self.need_having(c) for c in obj.children) and obj[1].refs_aggregate(self.annotations)[0]))
def split_having_parts(self, q_object, negated=False): def split_having_parts(self, q_object, negated=False):
""" """
@ -1287,7 +1289,7 @@ class Query(object):
self.demote_joins(existing_inner) self.demote_joins(existing_inner)
def _add_q(self, q_object, used_aliases, branch_negated=False, def _add_q(self, q_object, used_aliases, branch_negated=False,
current_negated=False): current_negated=False, allow_joins=True):
""" """
Adds a Q-object to the current filter. Adds a Q-object to the current filter.
""" """
@ -1301,12 +1303,12 @@ class Query(object):
if isinstance(child, Node): if isinstance(child, Node):
child_clause, needed_inner = self._add_q( child_clause, needed_inner = self._add_q(
child, used_aliases, branch_negated, child, used_aliases, branch_negated,
current_negated) current_negated, allow_joins)
joinpromoter.add_votes(needed_inner) joinpromoter.add_votes(needed_inner)
else: else:
child_clause, needed_inner = self.build_filter( child_clause, needed_inner = self.build_filter(
child, can_reuse=used_aliases, branch_negated=branch_negated, child, can_reuse=used_aliases, branch_negated=branch_negated,
current_negated=current_negated, connector=connector) current_negated=current_negated, connector=connector, allow_joins=allow_joins)
joinpromoter.add_votes(needed_inner) joinpromoter.add_votes(needed_inner)
target_clause.add(child_clause, connector) target_clause.add(child_clause, connector)
needed_inner = joinpromoter.update_join_types(self) needed_inner = joinpromoter.update_join_types(self)

View File

@ -11,6 +11,7 @@ from django.conf import settings
from django.db.models.fields import DateTimeField, Field from django.db.models.fields import DateTimeField, Field
from django.db.models.sql.datastructures import EmptyResultSet, Empty from django.db.models.sql.datastructures import EmptyResultSet, Empty
from django.utils.deprecation import RemovedInDjango19Warning from django.utils.deprecation import RemovedInDjango19Warning
from django.utils.functional import cached_property
from django.utils.six.moves import range from django.utils.six.moves import range
from django.utils import timezone from django.utils import timezone
from django.utils import tree from django.utils import tree
@ -309,6 +310,30 @@ class WhereNode(tree.Node):
clone.children.append(child) clone.children.append(child)
return clone return clone
def relabeled_clone(self, change_map):
clone = self.clone()
clone.relabel_aliases(change_map)
return clone
@cached_property
def contains_aggregate(self):
def _contains_aggregate(obj):
if not isinstance(obj, tree.Node):
return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False)
return any(_contains_aggregate(c) for c in obj.children)
return _contains_aggregate(self)
def refs_field(self, aggregate_types, field_types):
def _refs_field(obj, aggregate_types, field_types):
if not isinstance(obj, tree.Node):
if hasattr(obj.rhs, 'refs_field'):
return obj.rhs.refs_field(aggregate_types, field_types)
return False
return any(_refs_field(c, aggregate_types, field_types) for c in obj.children)
return _refs_field(self, aggregate_types, field_types)
class EmptyWhere(WhereNode): class EmptyWhere(WhereNode):
def add(self, data, connector): def add(self, data, connector):

View File

@ -87,6 +87,7 @@ manipulating the data of your Web application. Learn more about it below:
:doc:`Multiple databases <topics/db/multi-db>` | :doc:`Multiple databases <topics/db/multi-db>` |
:doc:`Custom lookups <howto/custom-lookups>` | :doc:`Custom lookups <howto/custom-lookups>` |
:doc:`Query Expressions <ref/models/expressions>` | :doc:`Query Expressions <ref/models/expressions>` |
:doc:`Conditional Expressions <ref/models/conditional-expressions>` |
:doc:`Database Functions <ref/models/database-functions>` :doc:`Database Functions <ref/models/database-functions>`
* **Other:** * **Other:**

View File

@ -0,0 +1,212 @@
=======================
Conditional Expressions
=======================
.. currentmodule:: django.db.models.expressions
.. versionadded:: 1.8
Conditional expressions let you use :keyword:`if` ... :keyword:`elif` ...
:keyword:`else` logic within filters, annotations, aggregations, and updates. A
conditional expression evaluates a series of conditions for each row of a
table and returns the matching result expression. Conditional expressions can
also be combined and nested like other :doc:`expressions <expressions>`.
The conditional expression classes
==================================
We'll be using the following model in the subsequent examples::
from django.db import models
class Client(models.Model):
REGULAR = 'R'
GOLD = 'G'
PLATINUM = 'P'
ACCOUNT_TYPE_CHOICES = (
(REGULAR, 'Regular'),
(GOLD, 'Gold'),
(PLATINUM, 'Platinum'),
)
name = models.CharField(max_length=50)
registered_on = models.DateField()
account_type = models.CharField(
max_length=1,
choices=ACCOUNT_TYPE_CHOICES,
default=REGULAR,
)
When
----
.. class:: When(condition=None, then=Value(None), **lookups)
A ``When()`` object is used to encapsulate a condition and its result for use
in the conditional expression. Using a ``When()`` object is similar to using
the :meth:`~django.db.models.query.QuerySet.filter` method. The condition can
be specified using :ref:`field lookups <field-lookups>` or
:class:`~django.db.models.Q` objects. The result is provided using the ``then``
keyword.
Some examples::
>>> from django.db.models import When, F, Q
>>> # String arguments refer to fields; the following two examples are equivalent:
>>> When(account_type=Client.GOLD, then='name')
>>> When(account_type=Client.GOLD, then=F('name'))
>>> # You can use field lookups in the condition
>>> from datetime import date
>>> When(registered_on__gt=date(2014, 1, 1),
... registered_on__lt=date(2015, 1, 1),
... then='account_type')
>>> # Complex conditions can be created using Q objects
>>> When(Q(name__startswith="John") | Q(name__startswith="Paul"),
... then='name')
Keep in mind that each of these values can be an expression.
.. note::
Since the ``then`` keyword argument is reserved for the result of the
``When()``, there is a potential conflict if a
:class:`~django.db.models.Model` has a field named ``then``. This can be
resolved in two ways::
>>> from django.db.models import Value
>>> When(then__exact=0, then=Value(1))
>>> When(Q(then=0), then=Value(1))
Case
----
.. class:: Case(*cases, **extra)
A ``Case()`` expression is like the :keyword:`if` ... :keyword:`elif` ...
:keyword:`else` statement in ``Python``. Each ``condition`` in the provided
``When()`` objects is evaluated in order, until one evaluates to a
truthful value. The ``result`` expression from the matching ``When()`` object
is returned.
A simple example::
>>>
>>> from datetime import date, timedelta
>>> from django.db.models import CharField, Case, Value, When
>>> Client.objects.create(
... name='Jane Doe',
... account_type=Client.REGULAR,
... registered_on=date.today() - timedelta(days=36))
>>> Client.objects.create(
... name='James Smith',
... account_type=Client.GOLD,
... registered_on=date.today() - timedelta(days=5))
>>> Client.objects.create(
... name='Jack Black',
... account_type=Client.PLATINUM,
... registered_on=date.today() - timedelta(days=10 * 365))
>>> # Get the discount for each Client based on the account type
>>> Client.objects.annotate(
... discount=Case(
... When(account_type=Client.GOLD, then=Value('5%')),
... When(account_type=Client.PLATINUM, then=Value('10%')),
... default=Value('0%'),
... output_field=CharField(),
... ),
... ).values_list('name', 'discount')
[('Jane Doe', '0%'), ('James Smith', '5%'), ('Jack Black', '10%')]
``Case()`` accepts any number of ``When()`` objects as individual arguments.
Other options are provided using keyword arguments. If none of the conditions
evaluate to ``TRUE``, then the expression given with the ``default`` keyword
argument is returned. If no ``default`` argument is provided, ``Value(None)``
is used.
If we wanted to change our previous query to get the discount based on how long
the ``Client`` has been with us, we could do so using lookups::
>>> a_month_ago = date.today() - timedelta(days=30)
>>> a_year_ago = date.today() - timedelta(days=365)
>>> # Get the discount for each Client based on the registration date
>>> Client.objects.annotate(
... discount=Case(
... When(registered_on__lte=a_year_ago, then=Value('10%')),
... When(registered_on__lte=a_month_ago, then=Value('5%')),
... default=Value('0%'),
... output_field=CharField(),
... )
... ).values_list('name', 'discount')
[('Jane Doe', '5%'), ('James Smith', '0%'), ('Jack Black', '10%')]
.. note::
Remember that the conditions are evaluated in order, so in the above
example we get the correct result even though the second condition matches
both Jane Doe and Jack Black. This works just like an :keyword:`if` ...
:keyword:`elif` ... :keyword:`else` statement in ``Python``.
Advanced queries
================
Conditional expressions can be used in annotations, aggregations, lookups, and
updates. They can also be combined and nested with other expressions. This
allows you to make powerful conditional queries.
Conditional update
------------------
Let's say we want to change the ``account_type`` for our clients to match
their registration dates. We can do this using a conditional expression and the
:meth:`~django.db.models.query.QuerySet.update` method::
>>> a_month_ago = date.today() - timedelta(days=30)
>>> a_year_ago = date.today() - timedelta(days=365)
>>> # Update the account_type for each Client from the registration date
>>> Client.objects.update(
... account_type=Case(
... When(registered_on__lte=a_year_ago,
... then=Value(Client.PLATINUM)),
... When(registered_on__lte=a_month_ago,
... then=Value(Client.GOLD)),
... default=Value(Client.REGULAR)
... ),
... )
>>> Client.objects.values_list('name', 'account_type')
[('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')]
Conditional aggregation
-----------------------
What if we want to find out how many clients there are for each
``account_type``? We can nest conditional expression within
:ref:`aggregate functions <aggregation-functions>` to achieve this::
>>> # Create some more Clients first so we can have something to count
>>> Client.objects.create(
... name='Jean Grey',
... account_type=Client.REGULAR,
... registered_on=date.today())
>>> Client.objects.create(
... name='James Bond',
... account_type=Client.PLATINUM,
... registered_on=date.today())
>>> Client.objects.create(
... name='Jane Porter',
... account_type=Client.PLATINUM,
... registered_on=date.today())
>>> # Get counts for each value of account_type
>>> from django.db.models import IntegerField, Sum
>>> Client.objects.aggregate(
... regular=Sum(
... Case(When(account_type=Client.REGULAR, then=Value(1)),
... output_field=IntegerField())
... ),
... gold=Sum(
... Case(When(account_type=Client.GOLD, then=Value(1)),
... output_field=IntegerField())
... ),
... platinum=Sum(
... Case(When(account_type=Client.PLATINUM, then=Value(1)),
... output_field=IntegerField())
... )
... )
{'regular': 2, 'gold': 1, 'platinum': 3}

View File

@ -332,6 +332,15 @@ instantiating the model field as any arguments relating to data validation
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
output value. output value.
Conditional expressions
-----------------------
.. versionadded:: 1.8
Conditional expressions allow you to use :keyword:`if` ... :keyword:`elif` ...
:keyword:`else` logic in queries. Django natively supports SQL ``CASE``
expressions. For more details see :doc:`conditional-expressions`.
Technical Information Technical Information
===================== =====================

View File

@ -16,4 +16,5 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`.
querysets querysets
lookups lookups
expressions expressions
conditional-expressions
database-functions database-functions

View File

@ -93,16 +93,20 @@ New data types
backends. There is a corresponding :class:`form field backends. There is a corresponding :class:`form field
<django.forms.DurationField>`. <django.forms.DurationField>`.
Query Expressions and Database Functions Query Expressions, Conditional Expressions, and Database Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
:doc:`Query Expressions </ref/models/expressions>` allow users to create, :doc:`Query Expressions </ref/models/expressions>` allow you to create,
customize, and compose complex SQL expressions. This has enabled annotate customize, and compose complex SQL expressions. This has enabled annotate
to accept expressions other than aggregates. Aggregates are now able to to accept expressions other than aggregates. Aggregates are now able to
reference multiple fields, as well as perform arithmetic, similar to ``F()`` reference multiple fields, as well as perform arithmetic, similar to ``F()``
objects. :meth:`~django.db.models.query.QuerySet.order_by` has also gained the objects. :meth:`~django.db.models.query.QuerySet.order_by` has also gained the
ability to accept expressions. ability to accept expressions.
:doc:`Conditional Expressions </ref/models/conditional-expressions>` allow
you to use :keyword:`if` ... :keyword:`elif` ... :keyword:`else` logic within
queries.
A collection of :doc:`database functions </ref/models/database-functions>` is A collection of :doc:`database functions </ref/models/database-functions>` is
also included with functionality such as also included with functionality such as
:class:`~django.db.models.functions.Coalesce`, :class:`~django.db.models.functions.Coalesce`,

View File

@ -56,3 +56,19 @@ class Experiment(models.Model):
def duration(self): def duration(self):
return self.end - self.start return self.end - self.start
@python_2_unicode_compatible
class Time(models.Model):
time = models.TimeField(null=True)
def __str__(self):
return "%s" % self.time
@python_2_unicode_compatible
class UUID(models.Model):
uuid = models.UUIDField(null=True)
def __str__(self):
return "%s" % self.uuid

View File

@ -2,15 +2,16 @@ from __future__ import unicode_literals
from copy import deepcopy from copy import deepcopy
import datetime import datetime
import uuid
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection, transaction, DatabaseError from django.db import connection, transaction, DatabaseError
from django.db.models import F, Value from django.db.models import F, Value, TimeField, UUIDField
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import Approximate from django.test.utils import Approximate
from django.utils import six from django.utils import six
from .models import Company, Employee, Number, Experiment from .models import Company, Employee, Number, Experiment, Time, UUID
class BasicExpressionsTests(TestCase): class BasicExpressionsTests(TestCase):
@ -799,3 +800,15 @@ class FTimeDeltaTests(TestCase):
over_estimate = [e.name for e in over_estimate = [e.name for e in
Experiment.objects.filter(estimated_time__lt=F('end') - F('start'))] Experiment.objects.filter(estimated_time__lt=F('end') - F('start'))]
self.assertEqual(over_estimate, ['e4']) self.assertEqual(over_estimate, ['e4'])
class ValueTests(TestCase):
def test_update_TimeField_using_Value(self):
Time.objects.create()
Time.objects.update(time=Value(datetime.time(1), output_field=TimeField()))
self.assertEqual(Time.objects.get().time, datetime.time(1))
def test_update_UUIDField_using_Value(self):
UUID.objects.create()
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))

View File

View File

@ -0,0 +1,80 @@
from __future__ import unicode_literals
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
@python_2_unicode_compatible
class CaseTestModel(models.Model):
integer = models.IntegerField()
integer2 = models.IntegerField(null=True)
string = models.CharField(max_length=100, default='')
big_integer = models.BigIntegerField(null=True)
binary = models.BinaryField(default=b'')
boolean = models.BooleanField(default=False)
comma_separated_integer = models.CommaSeparatedIntegerField(max_length=100, default='')
date = models.DateField(null=True, db_column='date_field')
date_time = models.DateTimeField(null=True)
decimal = models.DecimalField(max_digits=2, decimal_places=1, null=True, db_column='decimal_field')
duration = models.DurationField(null=True)
email = models.EmailField(default='')
file = models.FileField(null=True, db_column='file_field')
file_path = models.FilePathField(null=True)
float = models.FloatField(null=True, db_column='float_field')
image = models.ImageField(null=True)
ip_address = models.IPAddressField(null=True)
generic_ip_address = models.GenericIPAddressField(null=True)
null_boolean = models.NullBooleanField()
positive_integer = models.PositiveIntegerField(null=True)
positive_small_integer = models.PositiveSmallIntegerField(null=True)
slug = models.SlugField(default='')
small_integer = models.SmallIntegerField(null=True)
text = models.TextField(default='')
time = models.TimeField(null=True, db_column='time_field')
url = models.URLField(default='')
uuid = models.UUIDField(null=True)
fk = models.ForeignKey('self', null=True)
def __str__(self):
return "%i, %s" % (self.integer, self.string)
@python_2_unicode_compatible
class O2OCaseTestModel(models.Model):
o2o = models.OneToOneField(CaseTestModel, related_name='o2o_rel')
integer = models.IntegerField()
def __str__(self):
return "%i, %s" % (self.id, self.o2o)
@python_2_unicode_compatible
class FKCaseTestModel(models.Model):
fk = models.ForeignKey(CaseTestModel, related_name='fk_rel')
integer = models.IntegerField()
def __str__(self):
return "%i, %s" % (self.id, self.fk)
@python_2_unicode_compatible
class Client(models.Model):
REGULAR = 'R'
GOLD = 'G'
PLATINUM = 'P'
ACCOUNT_TYPE_CHOICES = (
(REGULAR, 'Regular'),
(GOLD, 'Gold'),
(PLATINUM, 'Platinum'),
)
name = models.CharField(max_length=50)
registered_on = models.DateField()
account_type = models.CharField(
max_length=1,
choices=ACCOUNT_TYPE_CHOICES,
default=REGULAR,
)
def __str__(self):
return self.name

File diff suppressed because it is too large Load Diff