mirror of https://github.com/django/django.git
Fixed #24031 -- Added CASE expressions to the ORM.
This commit is contained in:
parent
aa8ee6a573
commit
65246de7b1
1
AUTHORS
1
AUTHORS
|
@ -475,6 +475,7 @@ answer newbie questions, and generally made Django that much better:
|
|||
Michael Thornhill <michael.thornhill@gmail.com>
|
||||
Michal Chruszcz <troll@pld-linux.org>
|
||||
michal@plovarna.cz
|
||||
Michał Modzelewski <michal.modzelewski@gmail.com>
|
||||
Mihai Damian <yang_damian@yahoo.com>
|
||||
Mihai Preda <mihai_preda@yahoo.com>
|
||||
Mikaël Barbero <mikael.barbero nospam at nospam free.fr>
|
||||
|
|
|
@ -821,6 +821,14 @@ class BaseDatabaseOperations(object):
|
|||
"""
|
||||
return "SELECT cache_key FROM %s ORDER BY cache_key LIMIT 1 OFFSET %%s"
|
||||
|
||||
def unification_cast_sql(self, output_field):
|
||||
"""
|
||||
Given a field instance, returns the SQL necessary to cast the result of
|
||||
a union to that type. Note that the resulting string should contain a
|
||||
'%s' placeholder for the expression being cast.
|
||||
"""
|
||||
return '%s'
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
"""
|
||||
Given a lookup_type of 'year', 'month' or 'day', returns the SQL that
|
||||
|
|
|
@ -5,6 +5,18 @@ from django.db.backends import BaseDatabaseOperations
|
|||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
def unification_cast_sql(self, output_field):
|
||||
internal_type = output_field.get_internal_type()
|
||||
if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
|
||||
# PostgreSQL will resolve a union as type 'text' if input types are
|
||||
# 'unknown'.
|
||||
# http://www.postgresql.org/docs/9.4/static/typeconv-union-case.html
|
||||
# These fields cannot be implicitly cast back in the default
|
||||
# PostgreSQL configuration so we need to explicitly cast them.
|
||||
# We must also remove components of the type within brackets:
|
||||
# varchar(255) -> varchar.
|
||||
return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
|
||||
return '%s'
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
# http://www.postgresql.org/docs/current/static/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
|
||||
|
|
|
@ -4,7 +4,7 @@ import warnings
|
|||
|
||||
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA
|
||||
from django.db.models.query import Q, QuerySet, Prefetch # NOQA
|
||||
from django.db.models.expressions import ExpressionNode, F, Value, Func # NOQA
|
||||
from django.db.models.expressions import ExpressionNode, F, Value, Func, Case, When # NOQA
|
||||
from django.db.models.manager import Manager # NOQA
|
||||
from django.db.models.base import Model # NOQA
|
||||
from django.db.models.aggregates import * # NOQA
|
||||
|
|
|
@ -14,8 +14,9 @@ class Aggregate(Func):
|
|||
contains_aggregate = True
|
||||
name = None
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
assert len(self.source_expressions) == 1
|
||||
# Aggregates are not allowed in UPDATE queries, so ignore for_save
|
||||
c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if c.source_expressions[0].contains_aggregate and not summarize:
|
||||
name = self.source_expressions[0].name
|
||||
|
@ -101,7 +102,6 @@ class Count(Aggregate):
|
|||
def __init__(self, expression, distinct=False, **extra):
|
||||
if expression == '*':
|
||||
expression = Value(expression)
|
||||
expression._output_field = IntegerField()
|
||||
super(Count, self).__init__(
|
||||
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from django.core.exceptions import FieldError
|
|||
from django.db.backends import utils as backend_utils
|
||||
from django.db.models import fields
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.query_utils import refs_aggregate
|
||||
from django.db.models.query_utils import refs_aggregate, Q
|
||||
from django.utils import timezone
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
@ -173,7 +173,7 @@ class BaseExpression(object):
|
|||
return True
|
||||
return False
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
"""
|
||||
Provides the chance to do any preprocessing or validation before being
|
||||
added to the query.
|
||||
|
@ -380,11 +380,11 @@ class Expression(ExpressionNode):
|
|||
sql = connection.ops.combine_expression(self.connector, expressions)
|
||||
return expression_wrapper % sql, expression_params
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
return c
|
||||
|
||||
|
||||
|
@ -426,7 +426,7 @@ class F(CombinableMixin):
|
|||
"""
|
||||
self.name = name
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
|
||||
|
||||
def refs_aggregate(self, existing_aggregates):
|
||||
|
@ -465,11 +465,11 @@ class Func(ExpressionNode):
|
|||
for arg in expressions
|
||||
]
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
for pos, arg in enumerate(c.source_expressions):
|
||||
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
return c
|
||||
|
||||
def as_sql(self, compiler, connection, function=None, template=None):
|
||||
|
@ -511,12 +511,24 @@ class Value(ExpressionNode):
|
|||
self.value = value
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if self.value is None:
|
||||
val = self.value
|
||||
# check _output_field to avoid triggering an exception
|
||||
if self._output_field is not None:
|
||||
if self.for_save:
|
||||
val = self.output_field.get_db_prep_save(val, connection=connection)
|
||||
else:
|
||||
val = self.output_field.get_db_prep_value(val, connection=connection)
|
||||
if val is None:
|
||||
# cx_Oracle does not always convert None to the appropriate
|
||||
# NULL type (like in case expressions using numbers), so we
|
||||
# use a literal SQL NULL
|
||||
return 'NULL', []
|
||||
return '%s', [self.value]
|
||||
return '%s', [val]
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
c = super(Value, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
c.for_save = for_save
|
||||
return c
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
@ -599,6 +611,130 @@ class Ref(ExpressionNode):
|
|||
return [self]
|
||||
|
||||
|
||||
class When(ExpressionNode):
|
||||
template = 'WHEN %(condition)s THEN %(result)s'
|
||||
|
||||
def __init__(self, condition=None, then=Value(None), **lookups):
|
||||
if lookups and condition is None:
|
||||
condition, lookups = Q(**lookups), None
|
||||
if condition is None or not isinstance(condition, Q) or lookups:
|
||||
raise TypeError("__init__() takes either a Q object or lookups as keyword arguments")
|
||||
super(When, self).__init__(output_field=None)
|
||||
self.condition = condition
|
||||
self.result = self._parse_expression(then)
|
||||
|
||||
def __str__(self):
|
||||
return "WHEN %r THEN %r" % (self.condition, self.result)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s>" % (self.__class__.__name__, self)
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.condition, self.result]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.condition, self.result = exprs
|
||||
|
||||
def get_source_fields(self):
|
||||
# We're only interested in the fields of the result expressions.
|
||||
return [self.result._output_field_or_none]
|
||||
|
||||
def _parse_expression(self, expression):
|
||||
return expression if hasattr(expression, 'resolve_expression') else F(expression)
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
c.condition = c.condition.resolve_expression(query, allow_joins, reuse, summarize, False)
|
||||
c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
return c
|
||||
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
template_params = {}
|
||||
sql_params = []
|
||||
condition_sql, condition_params = compiler.compile(self.condition)
|
||||
template_params['condition'] = condition_sql
|
||||
sql_params.extend(condition_params)
|
||||
result_sql, result_params = compiler.compile(self.result)
|
||||
template_params['result'] = result_sql
|
||||
sql_params.extend(result_params)
|
||||
template = template or self.template
|
||||
return template % template_params, sql_params
|
||||
|
||||
def get_group_by_cols(self):
|
||||
# This is not a complete expression and cannot be used in GROUP BY.
|
||||
cols = []
|
||||
for source in self.get_source_expressions():
|
||||
cols.extend(source.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
|
||||
class Case(ExpressionNode):
|
||||
"""
|
||||
An SQL searched CASE expression:
|
||||
|
||||
CASE
|
||||
WHEN n > 0
|
||||
THEN 'positive'
|
||||
WHEN n < 0
|
||||
THEN 'negative'
|
||||
ELSE 'zero'
|
||||
END
|
||||
"""
|
||||
template = 'CASE %(cases)s ELSE %(default)s END'
|
||||
case_joiner = ' '
|
||||
|
||||
def __init__(self, *cases, **extra):
|
||||
if not all(isinstance(case, When) for case in cases):
|
||||
raise TypeError("Positional arguments must all be When objects.")
|
||||
default = extra.pop('default', Value(None))
|
||||
output_field = extra.pop('output_field', None)
|
||||
super(Case, self).__init__(output_field)
|
||||
self.cases = list(cases)
|
||||
self.default = default if hasattr(default, 'resolve_expression') else F(default)
|
||||
|
||||
def __str__(self):
|
||||
return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s>" % (self.__class__.__name__, self)
|
||||
|
||||
def get_source_expressions(self):
|
||||
return self.cases + [self.default]
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.cases = exprs[:-1]
|
||||
self.default = exprs[-1]
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
for pos, case in enumerate(c.cases):
|
||||
c.cases[pos] = case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
c.default = c.default.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
return c
|
||||
|
||||
def as_sql(self, compiler, connection, template=None, extra=None):
|
||||
if not self.cases:
|
||||
return compiler.compile(self.default)
|
||||
template_params = dict(extra) if extra else {}
|
||||
case_parts = []
|
||||
sql_params = []
|
||||
for case in self.cases:
|
||||
case_sql, case_params = compiler.compile(case)
|
||||
case_parts.append(case_sql)
|
||||
sql_params.extend(case_params)
|
||||
template_params['cases'] = self.case_joiner.join(case_parts)
|
||||
default_sql, default_params = compiler.compile(self.default)
|
||||
template_params['default'] = default_sql
|
||||
sql_params.extend(default_params)
|
||||
template = template or self.template
|
||||
sql = template % template_params
|
||||
if self._output_field_or_none is not None:
|
||||
sql = connection.ops.unification_cast_sql(self.output_field) % sql
|
||||
return sql, sql_params
|
||||
|
||||
|
||||
class Date(ExpressionNode):
|
||||
"""
|
||||
Add a date selection column.
|
||||
|
@ -615,7 +751,7 @@ class Date(ExpressionNode):
|
|||
def set_source_expressions(self, exprs):
|
||||
self.col, = exprs
|
||||
|
||||
def resolve_expression(self, query, allow_joins, reuse, summarize):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
copy = self.copy()
|
||||
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
|
||||
field = copy.col.output_field
|
||||
|
@ -664,7 +800,7 @@ class DateTime(ExpressionNode):
|
|||
def set_source_expressions(self, exprs):
|
||||
self.col, = exprs
|
||||
|
||||
def resolve_expression(self, query, allow_joins, reuse, summarize):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
copy = self.copy()
|
||||
copy.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
|
||||
field = copy.col.output_field
|
||||
|
|
|
@ -86,6 +86,27 @@ class Q(tree.Node):
|
|||
clone.children.append(child)
|
||||
return clone
|
||||
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
clause, _ = query._add_q(self, reuse, allow_joins=allow_joins)
|
||||
return clause
|
||||
|
||||
def refs_aggregate(self, existing_aggregates):
|
||||
def _refs_aggregate(obj, existing_aggregates):
|
||||
if not isinstance(obj, tree.Node):
|
||||
aggregate, aggregate_lookups = refs_aggregate(obj[0].split(LOOKUP_SEP), existing_aggregates)
|
||||
if not aggregate and hasattr(obj[1], 'refs_aggregate'):
|
||||
return obj[1].refs_aggregate(existing_aggregates)
|
||||
return aggregate, aggregate_lookups
|
||||
for c in obj.children:
|
||||
aggregate, aggregate_lookups = _refs_aggregate(c, existing_aggregates)
|
||||
if aggregate:
|
||||
return aggregate, aggregate_lookups
|
||||
return False, ()
|
||||
|
||||
if not existing_aggregates:
|
||||
return False
|
||||
return _refs_aggregate(self, existing_aggregates)
|
||||
|
||||
|
||||
class DeferredAttribute(object):
|
||||
"""
|
||||
|
|
|
@ -998,7 +998,9 @@ class SQLUpdateCompiler(SQLCompiler):
|
|||
values, update_params = [], []
|
||||
for field, model, val in self.query.values:
|
||||
if hasattr(val, 'resolve_expression'):
|
||||
val = val.resolve_expression(self.query, allow_joins=False)
|
||||
val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
|
||||
if val.contains_aggregate:
|
||||
raise FieldError("Aggregate functions are not allowed in this query")
|
||||
elif hasattr(val, 'prepare_database_save'):
|
||||
if field.rel:
|
||||
val = val.prepare_database_save(field)
|
||||
|
|
|
@ -961,7 +961,7 @@ class Query(object):
|
|||
self.append_annotation_mask([alias])
|
||||
self.annotations[alias] = annotation
|
||||
|
||||
def prepare_lookup_value(self, value, lookups, can_reuse):
|
||||
def prepare_lookup_value(self, value, lookups, can_reuse, allow_joins=True):
|
||||
# Default lookup if none given is exact.
|
||||
used_joins = []
|
||||
if len(lookups) == 0:
|
||||
|
@ -980,7 +980,7 @@ class Query(object):
|
|||
value = value()
|
||||
elif hasattr(value, 'resolve_expression'):
|
||||
pre_joins = self.alias_refcount.copy()
|
||||
value = value.resolve_expression(self, reuse=can_reuse)
|
||||
value = value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins)
|
||||
used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
|
||||
# Subqueries need to use a different set of aliases than the
|
||||
# outer query. Call bump_prefix to change aliases of the inner
|
||||
|
@ -1095,7 +1095,7 @@ class Query(object):
|
|||
(name, lhs.output_field.__class__.__name__))
|
||||
|
||||
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
|
||||
can_reuse=None, connector=AND):
|
||||
can_reuse=None, connector=AND, allow_joins=True):
|
||||
"""
|
||||
Builds a WhereNode for a single filter clause, but doesn't add it
|
||||
to this Query. Query.add_q() will then add this filter to the where
|
||||
|
@ -1125,10 +1125,12 @@ class Query(object):
|
|||
if not arg:
|
||||
raise FieldError("Cannot parse keyword query %r" % arg)
|
||||
lookups, parts, reffed_aggregate = self.solve_lookup_type(arg)
|
||||
if not allow_joins and len(parts) > 1:
|
||||
raise FieldError("Joined field references are not permitted in this query")
|
||||
|
||||
# Work out the lookup type and remove it from the end of 'parts',
|
||||
# if necessary.
|
||||
value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse)
|
||||
value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse, allow_joins)
|
||||
|
||||
clause = self.where_class()
|
||||
if reffed_aggregate:
|
||||
|
@ -1225,11 +1227,11 @@ class Query(object):
|
|||
"""
|
||||
if not self._annotations:
|
||||
return False
|
||||
if not isinstance(obj, Node):
|
||||
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0]
|
||||
or (hasattr(obj[1], 'refs_aggregate')
|
||||
and obj[1].refs_aggregate(self.annotations)[0]))
|
||||
return any(self.need_having(c) for c in obj.children)
|
||||
if hasattr(obj, 'refs_aggregate'):
|
||||
return obj.refs_aggregate(self.annotations)[0]
|
||||
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0]
|
||||
or (hasattr(obj[1], 'refs_aggregate')
|
||||
and obj[1].refs_aggregate(self.annotations)[0]))
|
||||
|
||||
def split_having_parts(self, q_object, negated=False):
|
||||
"""
|
||||
|
@ -1287,7 +1289,7 @@ class Query(object):
|
|||
self.demote_joins(existing_inner)
|
||||
|
||||
def _add_q(self, q_object, used_aliases, branch_negated=False,
|
||||
current_negated=False):
|
||||
current_negated=False, allow_joins=True):
|
||||
"""
|
||||
Adds a Q-object to the current filter.
|
||||
"""
|
||||
|
@ -1301,12 +1303,12 @@ class Query(object):
|
|||
if isinstance(child, Node):
|
||||
child_clause, needed_inner = self._add_q(
|
||||
child, used_aliases, branch_negated,
|
||||
current_negated)
|
||||
current_negated, allow_joins)
|
||||
joinpromoter.add_votes(needed_inner)
|
||||
else:
|
||||
child_clause, needed_inner = self.build_filter(
|
||||
child, can_reuse=used_aliases, branch_negated=branch_negated,
|
||||
current_negated=current_negated, connector=connector)
|
||||
current_negated=current_negated, connector=connector, allow_joins=allow_joins)
|
||||
joinpromoter.add_votes(needed_inner)
|
||||
target_clause.add(child_clause, connector)
|
||||
needed_inner = joinpromoter.update_join_types(self)
|
||||
|
|
|
@ -11,6 +11,7 @@ from django.conf import settings
|
|||
from django.db.models.fields import DateTimeField, Field
|
||||
from django.db.models.sql.datastructures import EmptyResultSet, Empty
|
||||
from django.utils.deprecation import RemovedInDjango19Warning
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.six.moves import range
|
||||
from django.utils import timezone
|
||||
from django.utils import tree
|
||||
|
@ -309,6 +310,30 @@ class WhereNode(tree.Node):
|
|||
clone.children.append(child)
|
||||
return clone
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
clone = self.clone()
|
||||
clone.relabel_aliases(change_map)
|
||||
return clone
|
||||
|
||||
@cached_property
|
||||
def contains_aggregate(self):
|
||||
def _contains_aggregate(obj):
|
||||
if not isinstance(obj, tree.Node):
|
||||
return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False)
|
||||
return any(_contains_aggregate(c) for c in obj.children)
|
||||
|
||||
return _contains_aggregate(self)
|
||||
|
||||
def refs_field(self, aggregate_types, field_types):
|
||||
def _refs_field(obj, aggregate_types, field_types):
|
||||
if not isinstance(obj, tree.Node):
|
||||
if hasattr(obj.rhs, 'refs_field'):
|
||||
return obj.rhs.refs_field(aggregate_types, field_types)
|
||||
return False
|
||||
return any(_refs_field(c, aggregate_types, field_types) for c in obj.children)
|
||||
|
||||
return _refs_field(self, aggregate_types, field_types)
|
||||
|
||||
|
||||
class EmptyWhere(WhereNode):
|
||||
def add(self, data, connector):
|
||||
|
|
|
@ -87,6 +87,7 @@ manipulating the data of your Web application. Learn more about it below:
|
|||
:doc:`Multiple databases <topics/db/multi-db>` |
|
||||
:doc:`Custom lookups <howto/custom-lookups>` |
|
||||
:doc:`Query Expressions <ref/models/expressions>` |
|
||||
:doc:`Conditional Expressions <ref/models/conditional-expressions>` |
|
||||
:doc:`Database Functions <ref/models/database-functions>`
|
||||
|
||||
* **Other:**
|
||||
|
|
|
@ -0,0 +1,212 @@
|
|||
=======================
|
||||
Conditional Expressions
|
||||
=======================
|
||||
|
||||
.. currentmodule:: django.db.models.expressions
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
Conditional expressions let you use :keyword:`if` ... :keyword:`elif` ...
|
||||
:keyword:`else` logic within filters, annotations, aggregations, and updates. A
|
||||
conditional expression evaluates a series of conditions for each row of a
|
||||
table and returns the matching result expression. Conditional expressions can
|
||||
also be combined and nested like other :doc:`expressions <expressions>`.
|
||||
|
||||
The conditional expression classes
|
||||
==================================
|
||||
|
||||
We'll be using the following model in the subsequent examples::
|
||||
|
||||
from django.db import models
|
||||
|
||||
class Client(models.Model):
|
||||
REGULAR = 'R'
|
||||
GOLD = 'G'
|
||||
PLATINUM = 'P'
|
||||
ACCOUNT_TYPE_CHOICES = (
|
||||
(REGULAR, 'Regular'),
|
||||
(GOLD, 'Gold'),
|
||||
(PLATINUM, 'Platinum'),
|
||||
)
|
||||
name = models.CharField(max_length=50)
|
||||
registered_on = models.DateField()
|
||||
account_type = models.CharField(
|
||||
max_length=1,
|
||||
choices=ACCOUNT_TYPE_CHOICES,
|
||||
default=REGULAR,
|
||||
)
|
||||
|
||||
When
|
||||
----
|
||||
|
||||
.. class:: When(condition=None, then=Value(None), **lookups)
|
||||
|
||||
A ``When()`` object is used to encapsulate a condition and its result for use
|
||||
in the conditional expression. Using a ``When()`` object is similar to using
|
||||
the :meth:`~django.db.models.query.QuerySet.filter` method. The condition can
|
||||
be specified using :ref:`field lookups <field-lookups>` or
|
||||
:class:`~django.db.models.Q` objects. The result is provided using the ``then``
|
||||
keyword.
|
||||
|
||||
Some examples::
|
||||
|
||||
>>> from django.db.models import When, F, Q
|
||||
>>> # String arguments refer to fields; the following two examples are equivalent:
|
||||
>>> When(account_type=Client.GOLD, then='name')
|
||||
>>> When(account_type=Client.GOLD, then=F('name'))
|
||||
>>> # You can use field lookups in the condition
|
||||
>>> from datetime import date
|
||||
>>> When(registered_on__gt=date(2014, 1, 1),
|
||||
... registered_on__lt=date(2015, 1, 1),
|
||||
... then='account_type')
|
||||
>>> # Complex conditions can be created using Q objects
|
||||
>>> When(Q(name__startswith="John") | Q(name__startswith="Paul"),
|
||||
... then='name')
|
||||
|
||||
Keep in mind that each of these values can be an expression.
|
||||
|
||||
.. note::
|
||||
|
||||
Since the ``then`` keyword argument is reserved for the result of the
|
||||
``When()``, there is a potential conflict if a
|
||||
:class:`~django.db.models.Model` has a field named ``then``. This can be
|
||||
resolved in two ways::
|
||||
|
||||
>>> from django.db.models import Value
|
||||
>>> When(then__exact=0, then=Value(1))
|
||||
>>> When(Q(then=0), then=Value(1))
|
||||
|
||||
Case
|
||||
----
|
||||
|
||||
.. class:: Case(*cases, **extra)
|
||||
|
||||
A ``Case()`` expression is like the :keyword:`if` ... :keyword:`elif` ...
|
||||
:keyword:`else` statement in ``Python``. Each ``condition`` in the provided
|
||||
``When()`` objects is evaluated in order, until one evaluates to a
|
||||
truthful value. The ``result`` expression from the matching ``When()`` object
|
||||
is returned.
|
||||
|
||||
A simple example::
|
||||
|
||||
>>>
|
||||
>>> from datetime import date, timedelta
|
||||
>>> from django.db.models import CharField, Case, Value, When
|
||||
>>> Client.objects.create(
|
||||
... name='Jane Doe',
|
||||
... account_type=Client.REGULAR,
|
||||
... registered_on=date.today() - timedelta(days=36))
|
||||
>>> Client.objects.create(
|
||||
... name='James Smith',
|
||||
... account_type=Client.GOLD,
|
||||
... registered_on=date.today() - timedelta(days=5))
|
||||
>>> Client.objects.create(
|
||||
... name='Jack Black',
|
||||
... account_type=Client.PLATINUM,
|
||||
... registered_on=date.today() - timedelta(days=10 * 365))
|
||||
>>> # Get the discount for each Client based on the account type
|
||||
>>> Client.objects.annotate(
|
||||
... discount=Case(
|
||||
... When(account_type=Client.GOLD, then=Value('5%')),
|
||||
... When(account_type=Client.PLATINUM, then=Value('10%')),
|
||||
... default=Value('0%'),
|
||||
... output_field=CharField(),
|
||||
... ),
|
||||
... ).values_list('name', 'discount')
|
||||
[('Jane Doe', '0%'), ('James Smith', '5%'), ('Jack Black', '10%')]
|
||||
|
||||
``Case()`` accepts any number of ``When()`` objects as individual arguments.
|
||||
Other options are provided using keyword arguments. If none of the conditions
|
||||
evaluate to ``TRUE``, then the expression given with the ``default`` keyword
|
||||
argument is returned. If no ``default`` argument is provided, ``Value(None)``
|
||||
is used.
|
||||
|
||||
If we wanted to change our previous query to get the discount based on how long
|
||||
the ``Client`` has been with us, we could do so using lookups::
|
||||
|
||||
>>> a_month_ago = date.today() - timedelta(days=30)
|
||||
>>> a_year_ago = date.today() - timedelta(days=365)
|
||||
>>> # Get the discount for each Client based on the registration date
|
||||
>>> Client.objects.annotate(
|
||||
... discount=Case(
|
||||
... When(registered_on__lte=a_year_ago, then=Value('10%')),
|
||||
... When(registered_on__lte=a_month_ago, then=Value('5%')),
|
||||
... default=Value('0%'),
|
||||
... output_field=CharField(),
|
||||
... )
|
||||
... ).values_list('name', 'discount')
|
||||
[('Jane Doe', '5%'), ('James Smith', '0%'), ('Jack Black', '10%')]
|
||||
|
||||
.. note::
|
||||
|
||||
Remember that the conditions are evaluated in order, so in the above
|
||||
example we get the correct result even though the second condition matches
|
||||
both Jane Doe and Jack Black. This works just like an :keyword:`if` ...
|
||||
:keyword:`elif` ... :keyword:`else` statement in ``Python``.
|
||||
|
||||
Advanced queries
|
||||
================
|
||||
|
||||
Conditional expressions can be used in annotations, aggregations, lookups, and
|
||||
updates. They can also be combined and nested with other expressions. This
|
||||
allows you to make powerful conditional queries.
|
||||
|
||||
Conditional update
|
||||
------------------
|
||||
|
||||
Let's say we want to change the ``account_type`` for our clients to match
|
||||
their registration dates. We can do this using a conditional expression and the
|
||||
:meth:`~django.db.models.query.QuerySet.update` method::
|
||||
|
||||
>>> a_month_ago = date.today() - timedelta(days=30)
|
||||
>>> a_year_ago = date.today() - timedelta(days=365)
|
||||
>>> # Update the account_type for each Client from the registration date
|
||||
>>> Client.objects.update(
|
||||
... account_type=Case(
|
||||
... When(registered_on__lte=a_year_ago,
|
||||
... then=Value(Client.PLATINUM)),
|
||||
... When(registered_on__lte=a_month_ago,
|
||||
... then=Value(Client.GOLD)),
|
||||
... default=Value(Client.REGULAR)
|
||||
... ),
|
||||
... )
|
||||
>>> Client.objects.values_list('name', 'account_type')
|
||||
[('Jane Doe', 'G'), ('James Smith', 'R'), ('Jack Black', 'P')]
|
||||
|
||||
Conditional aggregation
|
||||
-----------------------
|
||||
|
||||
What if we want to find out how many clients there are for each
|
||||
``account_type``? We can nest conditional expression within
|
||||
:ref:`aggregate functions <aggregation-functions>` to achieve this::
|
||||
|
||||
>>> # Create some more Clients first so we can have something to count
|
||||
>>> Client.objects.create(
|
||||
... name='Jean Grey',
|
||||
... account_type=Client.REGULAR,
|
||||
... registered_on=date.today())
|
||||
>>> Client.objects.create(
|
||||
... name='James Bond',
|
||||
... account_type=Client.PLATINUM,
|
||||
... registered_on=date.today())
|
||||
>>> Client.objects.create(
|
||||
... name='Jane Porter',
|
||||
... account_type=Client.PLATINUM,
|
||||
... registered_on=date.today())
|
||||
>>> # Get counts for each value of account_type
|
||||
>>> from django.db.models import IntegerField, Sum
|
||||
>>> Client.objects.aggregate(
|
||||
... regular=Sum(
|
||||
... Case(When(account_type=Client.REGULAR, then=Value(1)),
|
||||
... output_field=IntegerField())
|
||||
... ),
|
||||
... gold=Sum(
|
||||
... Case(When(account_type=Client.GOLD, then=Value(1)),
|
||||
... output_field=IntegerField())
|
||||
... ),
|
||||
... platinum=Sum(
|
||||
... Case(When(account_type=Client.PLATINUM, then=Value(1)),
|
||||
... output_field=IntegerField())
|
||||
... )
|
||||
... )
|
||||
{'regular': 2, 'gold': 1, 'platinum': 3}
|
|
@ -332,6 +332,15 @@ instantiating the model field as any arguments relating to data validation
|
|||
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
|
||||
output value.
|
||||
|
||||
Conditional expressions
|
||||
-----------------------
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
Conditional expressions allow you to use :keyword:`if` ... :keyword:`elif` ...
|
||||
:keyword:`else` logic in queries. Django natively supports SQL ``CASE``
|
||||
expressions. For more details see :doc:`conditional-expressions`.
|
||||
|
||||
Technical Information
|
||||
=====================
|
||||
|
||||
|
|
|
@ -16,4 +16,5 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`.
|
|||
querysets
|
||||
lookups
|
||||
expressions
|
||||
conditional-expressions
|
||||
database-functions
|
||||
|
|
|
@ -93,16 +93,20 @@ New data types
|
|||
backends. There is a corresponding :class:`form field
|
||||
<django.forms.DurationField>`.
|
||||
|
||||
Query Expressions and Database Functions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Query Expressions, Conditional Expressions, and Database Functions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
:doc:`Query Expressions </ref/models/expressions>` allow users to create,
|
||||
:doc:`Query Expressions </ref/models/expressions>` allow you to create,
|
||||
customize, and compose complex SQL expressions. This has enabled annotate
|
||||
to accept expressions other than aggregates. Aggregates are now able to
|
||||
reference multiple fields, as well as perform arithmetic, similar to ``F()``
|
||||
objects. :meth:`~django.db.models.query.QuerySet.order_by` has also gained the
|
||||
ability to accept expressions.
|
||||
|
||||
:doc:`Conditional Expressions </ref/models/conditional-expressions>` allow
|
||||
you to use :keyword:`if` ... :keyword:`elif` ... :keyword:`else` logic within
|
||||
queries.
|
||||
|
||||
A collection of :doc:`database functions </ref/models/database-functions>` is
|
||||
also included with functionality such as
|
||||
:class:`~django.db.models.functions.Coalesce`,
|
||||
|
|
|
@ -56,3 +56,19 @@ class Experiment(models.Model):
|
|||
|
||||
def duration(self):
|
||||
return self.end - self.start
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Time(models.Model):
|
||||
time = models.TimeField(null=True)
|
||||
|
||||
def __str__(self):
|
||||
return "%s" % self.time
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class UUID(models.Model):
|
||||
uuid = models.UUIDField(null=True)
|
||||
|
||||
def __str__(self):
|
||||
return "%s" % self.uuid
|
||||
|
|
|
@ -2,15 +2,16 @@ from __future__ import unicode_literals
|
|||
|
||||
from copy import deepcopy
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import connection, transaction, DatabaseError
|
||||
from django.db.models import F, Value
|
||||
from django.db.models import F, Value, TimeField, UUIDField
|
||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||
from django.test.utils import Approximate
|
||||
from django.utils import six
|
||||
|
||||
from .models import Company, Employee, Number, Experiment
|
||||
from .models import Company, Employee, Number, Experiment, Time, UUID
|
||||
|
||||
|
||||
class BasicExpressionsTests(TestCase):
|
||||
|
@ -799,3 +800,15 @@ class FTimeDeltaTests(TestCase):
|
|||
over_estimate = [e.name for e in
|
||||
Experiment.objects.filter(estimated_time__lt=F('end') - F('start'))]
|
||||
self.assertEqual(over_estimate, ['e4'])
|
||||
|
||||
|
||||
class ValueTests(TestCase):
|
||||
def test_update_TimeField_using_Value(self):
|
||||
Time.objects.create()
|
||||
Time.objects.update(time=Value(datetime.time(1), output_field=TimeField()))
|
||||
self.assertEqual(Time.objects.get().time, datetime.time(1))
|
||||
|
||||
def test_update_UUIDField_using_Value(self):
|
||||
UUID.objects.create()
|
||||
UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField()))
|
||||
self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012'))
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import models
|
||||
from django.utils.encoding import python_2_unicode_compatible
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class CaseTestModel(models.Model):
|
||||
integer = models.IntegerField()
|
||||
integer2 = models.IntegerField(null=True)
|
||||
string = models.CharField(max_length=100, default='')
|
||||
|
||||
big_integer = models.BigIntegerField(null=True)
|
||||
binary = models.BinaryField(default=b'')
|
||||
boolean = models.BooleanField(default=False)
|
||||
comma_separated_integer = models.CommaSeparatedIntegerField(max_length=100, default='')
|
||||
date = models.DateField(null=True, db_column='date_field')
|
||||
date_time = models.DateTimeField(null=True)
|
||||
decimal = models.DecimalField(max_digits=2, decimal_places=1, null=True, db_column='decimal_field')
|
||||
duration = models.DurationField(null=True)
|
||||
email = models.EmailField(default='')
|
||||
file = models.FileField(null=True, db_column='file_field')
|
||||
file_path = models.FilePathField(null=True)
|
||||
float = models.FloatField(null=True, db_column='float_field')
|
||||
image = models.ImageField(null=True)
|
||||
ip_address = models.IPAddressField(null=True)
|
||||
generic_ip_address = models.GenericIPAddressField(null=True)
|
||||
null_boolean = models.NullBooleanField()
|
||||
positive_integer = models.PositiveIntegerField(null=True)
|
||||
positive_small_integer = models.PositiveSmallIntegerField(null=True)
|
||||
slug = models.SlugField(default='')
|
||||
small_integer = models.SmallIntegerField(null=True)
|
||||
text = models.TextField(default='')
|
||||
time = models.TimeField(null=True, db_column='time_field')
|
||||
url = models.URLField(default='')
|
||||
uuid = models.UUIDField(null=True)
|
||||
fk = models.ForeignKey('self', null=True)
|
||||
|
||||
def __str__(self):
|
||||
return "%i, %s" % (self.integer, self.string)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class O2OCaseTestModel(models.Model):
|
||||
o2o = models.OneToOneField(CaseTestModel, related_name='o2o_rel')
|
||||
integer = models.IntegerField()
|
||||
|
||||
def __str__(self):
|
||||
return "%i, %s" % (self.id, self.o2o)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class FKCaseTestModel(models.Model):
|
||||
fk = models.ForeignKey(CaseTestModel, related_name='fk_rel')
|
||||
integer = models.IntegerField()
|
||||
|
||||
def __str__(self):
|
||||
return "%i, %s" % (self.id, self.fk)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Client(models.Model):
|
||||
REGULAR = 'R'
|
||||
GOLD = 'G'
|
||||
PLATINUM = 'P'
|
||||
ACCOUNT_TYPE_CHOICES = (
|
||||
(REGULAR, 'Regular'),
|
||||
(GOLD, 'Gold'),
|
||||
(PLATINUM, 'Platinum'),
|
||||
)
|
||||
name = models.CharField(max_length=50)
|
||||
registered_on = models.DateField()
|
||||
account_type = models.CharField(
|
||||
max_length=1,
|
||||
choices=ACCOUNT_TYPE_CHOICES,
|
||||
default=REGULAR,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue