django1/django/db/models/expressions.py

637 lines
21 KiB
Python

import copy
import datetime
from django.conf import settings
from django.core.exceptions import FieldError
from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import refs_aggregate
from django.utils import timezone
from django.utils.functional import cached_property
class CombinableMixin(object):
    """
    Mixin giving an object the ability to be joined with another operand
    through a connector, producing a combined expression.
    For example F('foo') + F('bar').
    """

    # Arithmetic connectors
    ADD = '+'
    SUB = '-'
    MUL = '*'
    DIV = '/'
    POW = '^'
    # MOD is a quoted % operator - it is quoted because it can end up in
    # strings that are also subject to parameter substitution.
    MOD = '%%'

    # Bitwise connectors - these are only produced by .bitand() and
    # .bitor(); the '&' and '|' Python operators are reserved for boolean
    # operator usage and therefore raise NotImplementedError below.
    BITAND = '&'
    BITOR = '|'

    def _combine(self, other, connector, reversed, node=None):
        """
        Build the Expression joining ``self`` and ``other`` with ``connector``.

        Operands lacking a ``resolve_expression`` attribute are wrapped
        first: timedeltas in a DurationValue, anything else in a Value.
        When ``reversed`` is true, ``other`` becomes the left-hand side.
        ``node`` is accepted but not used here.
        """
        if not hasattr(other, 'resolve_expression'):
            # everything must be resolvable to an expression
            if isinstance(other, datetime.timedelta):
                other = DurationValue(other, output_field=fields.DurationField())
            else:
                other = Value(other)

        lhs, rhs = (other, self) if reversed else (self, other)
        return Expression(lhs, connector, rhs)

    #############
    # OPERATORS #
    #############

    # -- self <op> other ---------------------------------------------------

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __div__(self, other):  # Python 2 compatibility
        return type(self).__truediv__(self, other)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    def __and__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )

    def bitand(self, other):
        """Combine with ``other`` using the bitwise AND connector."""
        return self._combine(other, self.BITAND, False)

    def __or__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )

    def bitor(self, other):
        """Combine with ``other`` using the bitwise OR connector."""
        return self._combine(other, self.BITOR, False)

    # -- other <op> self (reflected operands) --------------------------------

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rdiv__(self, other):  # Python 2 compatibility
        return type(self).__rtruediv__(self, other)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    def __rand__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )

    def __ror__(self, other):
        raise NotImplementedError(
            "Use .bitand() and .bitor() for bitwise logical operations."
        )
class ExpressionNode(CombinableMixin):
    """
    Base class for all query expressions.

    Subclasses implement as_sql() and, when they wrap other expressions,
    override get_source_expressions() / set_source_expressions().
    """

    # aggregate specific fields
    is_summary = False

    def get_db_converters(self, connection):
        """Return the list of converters to apply to this expression's values."""
        return [self.convert_value]

    def __init__(self, output_field=None):
        """
        Arguments:
         * output_field: the field describing the type of this expression,
           or None to have it inferred from the source expressions when it
           is first needed.
        """
        self._output_field = output_field

    def get_source_expressions(self):
        """Return the child expressions of this node (none by default)."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; the base node accepts none."""
        assert len(exprs) == 0

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super(ExpressionNode, self).as_sql(compiler, connection)
        setattr(ExpressionNode, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Returns: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """True if any (truthy) source expression contains an aggregate."""
        for expr in self.get_source_expressions():
            if expr and expr.contains_aggregate:
                return True
        return False

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
        """
        Provides the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause

        Returns: an ExpressionNode to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        return c

    def _prepare(self):
        """
        Hook used by Field.get_prep_lookup() to do custom preparation.
        """
        return self

    @property
    def field(self):
        """Alias for output_field."""
        return self.output_field

    @cached_property
    def output_field(self):
        """
        Returns the output type of this expression.

        Raises FieldError if no output type can be resolved.
        """
        if self._output_field_or_none is None:
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return self._output_field_or_none

    @cached_property
    def _output_field_or_none(self):
        """
        Returns the output field of this expression, or None if no output type
        can be resolved. Note that the 'output_field' property will raise
        FieldError if no type can be resolved, but this attribute allows for
        None values.
        """
        if self._output_field is None:
            self._resolve_output_field()
        return self._output_field

    def _resolve_output_field(self):
        """
        Attempts to infer the output type of the expression. If the output
        fields of all source fields match then we can simply infer the same
        type here.

        Sources whose own type is unknown (None) are skipped: the first
        known source field becomes the output field and every other known
        source field is checked against it. Raises FieldError when the
        known source fields are of mixed types. With no known source
        field the output field stays None.
        """
        if self._output_field is None:
            for source in self.get_source_fields():
                if source is None:
                    # An untyped source tells us nothing; seeding the
                    # result with it would make every typed source look
                    # like a type mismatch.
                    continue
                if self._output_field is None:
                    self._output_field = source
                elif not isinstance(self._output_field, source.__class__):
                    raise FieldError(
                        "Expression contains mixed types. You must set output_field")

    def convert_value(self, value, connection):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if value is None:
            return value
        elif internal_type == 'FloatField':
            return float(value)
        elif internal_type.endswith('IntegerField'):
            return int(value)
        elif internal_type == 'DecimalField':
            return backend_utils.typecast_decimal(value)
        return value

    def get_lookup(self, lookup):
        """Delegate lookup resolution to the output field."""
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        """Delegate transform resolution to the output field."""
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose source expressions are relabeled with change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [e.relabeled_clone(change_map) for e in self.get_source_expressions()])
        return clone

    def copy(self):
        """Return a shallow copy of this node, flagged with ``copied = True``."""
        c = copy.copy(self)
        c.copied = True
        return c

    def refs_aggregate(self, existing_aggregates):
        """
        Does this expression contain a reference to some of the
        existing aggregates? If so, returns the aggregate and also
        the lookup parts that *weren't* found. So, if
            existing_aggregates = {'max_id': Max('id')}
            self.name = 'max_id'
            queryset.filter(max_id__range=[10,100])
        then this method will return Max('id') and those parts of the
        name that weren't found. In this case `max_id` is found and the range
        portion is returned as ('range',).
        """
        for node in self.get_source_expressions():
            agg, lookup = node.refs_aggregate(existing_aggregates)
            if agg:
                return agg, lookup
        return False, ()

    def refs_field(self, aggregate_types, field_types):
        """
        Helper method for check_aggregate_support on backends
        """
        return any(
            node.refs_field(aggregate_types, field_types)
            for node in self.get_source_expressions())

    def prepare_database_save(self, field):
        """Return the object to use when saving to ``field``: the node itself."""
        return self

    def get_group_by_cols(self):
        """Collect the group-by columns of all source expressions."""
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """
        Returns the underlying field types used by this
        aggregate (None for sources whose type cannot be resolved).
        """
        return [e._output_field_or_none for e in self.get_source_expressions()]
class Expression(ExpressionNode):
    """
    Two expressions (``lhs`` and ``rhs``) joined by a connector.
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super(Expression, self).__init__(output_field=output_field)
        self.lhs = lhs
        self.rhs = rhs
        self.connector = connector

    def get_source_expressions(self):
        """Return the two operands as ``[lhs, rhs]``."""
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        """Replace the operands; ``exprs`` must unpack to exactly two items."""
        self.lhs, self.rhs = exprs

    def as_sql(self, compiler, connection):
        """
        Compile both operands and combine them with the connector,
        wrapped in parentheses. Returns (sql, params).

        When the backend has no native duration field and either operand
        resolves to a DurationField, rendering is handed over to
        DurationExpression instead.
        """
        # Resolve both operand types up front; an operand whose type
        # cannot be resolved is treated as "unknown" (None).
        operand_fields = []
        for operand in (self.lhs, self.rhs):
            try:
                resolved = operand.output_field
            except FieldError:
                resolved = None
            operand_fields.append(resolved)

        if not connection.features.has_native_duration_field and any(
                f and f.get_internal_type() == 'DurationField'
                for f in operand_fields):
            duration_expr = DurationExpression(self.lhs, self.connector, self.rhs)
            return duration_expr.as_sql(compiler, connection)

        sql_parts = []
        all_params = []
        for operand in (self.lhs, self.rhs):
            operand_sql, operand_params = compiler.compile(operand)
            sql_parts.append(operand_sql)
            all_params.extend(operand_params)

        combined = connection.ops.combine_expression(self.connector, sql_parts)
        # parenthesise to preserve the order of precedence
        return '(%s)' % combined, all_params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
        """Return a copy of this expression with both operands resolved."""
        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.lhs = resolved.lhs.resolve_expression(query, allow_joins, reuse, summarize)
        resolved.rhs = resolved.rhs.resolve_expression(query, allow_joins, reuse, summarize)
        return resolved
class DurationExpression(Expression):
    """
    Expression variant that renders through the backend's duration-specific
    operations (format_for_duration_arithmetic, combine_duration_expression).
    """

    def compile(self, side, compiler, connection):
        """
        Compile one operand and return (sql, params).

        Operands that are not DurationValue instances but resolve to a
        DurationField have their SQL passed through
        connection.ops.format_for_duration_arithmetic(). Everything else,
        including operands whose output field cannot be resolved, is
        compiled as is.
        """
        if not isinstance(side, DurationValue):
            try:
                output = side.output_field
            except FieldError:
                # Unknown type: fall through to the plain compilation
                # below. The check lives in the else clause so `output` is
                # never read when it was not assigned.
                pass
            else:
                if output.get_internal_type() == 'DurationField':
                    sql, params = compiler.compile(side)
                    return connection.ops.format_for_duration_arithmetic(sql), params
        return compiler.compile(side)

    def as_sql(self, compiler, connection):
        """
        Compile both operands via compile() and join them with
        connection.ops.combine_duration_expression(), wrapped in
        parentheses. Returns (sql, params).
        """
        expressions = []
        expression_params = []
        for side in (self.lhs, self.rhs):
            sql, params = self.compile(side, compiler, connection)
            expressions.append(sql)
            expression_params.extend(params)
        # order of precedence
        expression_wrapper = '(%s)'
        sql = connection.ops.combine_duration_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params
class F(CombinableMixin):
    """
    An object capable of resolving references to existing query objects.
    """

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
        """Resolve the stored name by delegating to ``query.resolve_ref()``."""
        resolved = query.resolve_ref(self.name, allow_joins, reuse, summarize)
        return resolved

    def refs_aggregate(self, existing_aggregates):
        """
        Split the stored name on LOOKUP_SEP and hand the parts, together
        with ``existing_aggregates``, to the module-level refs_aggregate().
        """
        lookup_parts = self.name.split(LOOKUP_SEP)
        return refs_aggregate(lookup_parts, existing_aggregates)
class Func(ExpressionNode):
    """
    A SQL function call.

    Class attributes:
     * function: the SQL function name used when none is supplied
       elsewhere.
     * template: the %-format string the call is rendered with; it is
       interpolated with a mapping providing at least 'function',
       'expressions' and 'field'.
     * arg_joiner: the string placed between the compiled arguments.
    """
    function = None
    template = '%(function)s(%(expressions)s)'
    arg_joiner = ', '

    def __init__(self, *expressions, **extra):
        """
        Arguments:
         * expressions: the function arguments; anything without a
           resolve_expression attribute is wrapped in F().
         * extra: additional template context. 'output_field' is popped
           from it and passed to the base class.
        """
        output_field = extra.pop('output_field', None)
        super(Func, self).__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def get_source_expressions(self):
        """Return the list of argument expressions."""
        return self.source_expressions

    def set_source_expressions(self, exprs):
        """Replace the list of argument expressions."""
        self.source_expressions = exprs

    def _parse_expressions(self, *expressions):
        """Wrap every argument lacking resolve_expression in F()."""
        return [
            arg if hasattr(arg, 'resolve_expression') else F(arg)
            for arg in expressions
        ]

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
        """Return a copy of this call with every argument resolved."""
        c = self.copy()
        c.is_summary = summarize
        for pos, arg in enumerate(c.source_expressions):
            c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize)
        return c

    def as_sql(self, compiler, connection, function=None, template=None):
        """
        Compile the arguments and render the template. Returns (sql, params).

        The function name is taken, in order of preference, from the
        ``function`` argument, the 'function' entry of self.extra, or the
        class attribute. The template is taken from the ``template``
        argument, the 'template' entry of self.extra, or the class
        attribute.

        Rendering works on a per-call copy of self.extra, so per-call
        overrides never leak into the instance or into later calls.
        """
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            arg_sql, arg_params = compiler.compile(arg)
            sql_parts.append(arg_sql)
            params.extend(arg_params)

        # Copy so that rendering does not mutate the instance's context.
        data = self.extra.copy()
        if function is None:
            data.setdefault('function', self.function)
        else:
            data['function'] = function
        data['expressions'] = data['field'] = self.arg_joiner.join(sql_parts)
        template = template or data.get('template', self.template)
        return template % data, params

    def copy(self):
        """
        Return a copy that owns its own argument list and extra context
        mapping, so changing them does not affect the original.
        """
        clone = super(Func, self).copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
class Value(ExpressionNode):
    """
    Represents a wrapped value as a node within an expression
    """

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super(Value, self).__init__(output_field=output_field)
        self.value = value

    def as_sql(self, compiler, connection):
        """Return a single placeholder with the wrapped value as its only parameter."""
        placeholder = '%s'
        params = [self.value]
        return placeholder, params
class DurationValue(Value):
    """
    A Value that is rendered through connection.ops.date_interval_sql()
    unless the backend both has a native duration field and a driver
    supporting timedelta arguments.
    """

    def as_sql(self, compiler, connection):
        features = connection.features
        passes_timedelta_directly = (
            features.has_native_duration_field and
            features.driver_supports_timedelta_args
        )
        if not passes_timedelta_directly:
            return connection.ops.date_interval_sql(self.value)
        return super(DurationValue, self).as_sql(compiler, connection)
class Col(ExpressionNode):
    """
    A column reference rendered as ``<alias>.<column>``.
    """

    def __init__(self, alias, target, source=None):
        """
        Arguments:
         * alias: the table alias the column is qualified with.
         * target: the object whose ``column`` attribute names the column.
         * source: the output field of this expression; falls back to
           ``target`` when None.
        """
        output = target if source is None else source
        super(Col, self).__init__(output_field=output)
        self.alias = alias
        self.target = target

    def as_sql(self, compiler, connection):
        """Return the quoted ``alias.column`` reference with no parameters."""
        quote = compiler.quote_name_unless_alias
        quoted_alias = quote(self.alias)
        quoted_column = quote(self.target.column)
        return "%s.%s" % (quoted_alias, quoted_column), []

    def relabeled_clone(self, relabels):
        """Return a new instance whose alias is mapped through ``relabels``."""
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        """A column groups by itself."""
        return [self]
class Ref(ExpressionNode):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs, source):
        """
        Arguments:
         * refs: the alias name this node renders.
         * source: the expression the alias stands for.
        """
        super(Ref, self).__init__()
        self.refs = refs
        self.source = source

    def get_source_expressions(self):
        """Return the referenced expression as a one-element list."""
        return [self.source]

    def set_source_expressions(self, exprs):
        """Replace the referenced expression; ``exprs`` must hold exactly one item."""
        (self.source,) = exprs

    def relabeled_clone(self, relabels):
        """Relabeling does not affect an alias reference; return self."""
        return self

    def as_sql(self, compiler, connection):
        """Return the quoted alias name with no parameters."""
        quoted = compiler.quote_name_unless_alias(self.refs)
        return "%s" % quoted, []

    def get_group_by_cols(self):
        """A reference groups by itself."""
        return [self]
class Date(ExpressionNode):
    """
    Add a date selection column.
    """

    def __init__(self, lookup, lookup_type):
        """
        Arguments:
         * lookup: the reference resolved to a column in resolve_expression().
         * lookup_type: passed to connection.ops.date_trunc_sql() in as_sql().
        """
        super(Date, self).__init__(output_field=fields.DateField())
        self.lookup = lookup
        self.lookup_type = lookup_type
        # Filled in by resolve_expression().
        self.col = None

    def get_source_expressions(self):
        """Return the resolved column as a one-element list."""
        return [self.col]

    def set_source_expressions(self, exprs):
        """Replace the resolved column; ``exprs`` must hold exactly one item."""
        (self.col,) = exprs

    def resolve_expression(self, query, allow_joins, reuse, summarize):
        """
        Return a copy whose ``col`` is the lookup resolved against ``query``.

        Asserts that the resolved column's output field is a DateField
        and, when settings.USE_TZ is on, that it is not a DateTimeField.
        """
        clone = self.copy()
        clone.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
        resolved_field = clone.col.output_field
        assert isinstance(resolved_field, fields.DateField), (
            "%r isn't a DateField." % resolved_field.name
        )
        if settings.USE_TZ:
            assert not isinstance(resolved_field, fields.DateTimeField), (
                "%r is a DateTimeField, not a DateField." % resolved_field.name
            )
        return clone

    def as_sql(self, compiler, connection):
        """
        Render the column through connection.ops.date_trunc_sql().
        The column itself must compile without parameters.
        """
        col_sql, col_params = self.col.as_sql(compiler, connection)
        assert not col_params
        truncated = connection.ops.date_trunc_sql(self.lookup_type, col_sql)
        return truncated, []

    def copy(self):
        """Return a copy carrying over ``lookup`` and ``lookup_type``."""
        clone = super(Date, self).copy()
        clone.lookup = self.lookup
        clone.lookup_type = self.lookup_type
        return clone

    def convert_value(self, value, connection):
        """Reduce datetime values to their date part; pass anything else through."""
        if not isinstance(value, datetime.datetime):
            return value
        return value.date()
class DateTime(ExpressionNode):
    """
    Add a datetime selection column.
    """

    def __init__(self, lookup, lookup_type, tzinfo):
        """
        Arguments:
         * lookup: the reference resolved to a column in resolve_expression().
         * lookup_type: passed to connection.ops.datetime_trunc_sql().
         * tzinfo: the time zone used to make converted values aware, or
           None; its name (when given) is passed along to
           datetime_trunc_sql().
        """
        super(DateTime, self).__init__(output_field=fields.DateTimeField())
        self.lookup = lookup
        self.lookup_type = lookup_type
        # Filled in by resolve_expression().
        self.col = None
        self.tzname = None if tzinfo is None else timezone._get_timezone_name(tzinfo)
        self.tzinfo = tzinfo

    def get_source_expressions(self):
        """Return the resolved column as a one-element list."""
        return [self.col]

    def set_source_expressions(self, exprs):
        """Replace the resolved column; ``exprs`` must hold exactly one item."""
        (self.col,) = exprs

    def resolve_expression(self, query, allow_joins, reuse, summarize):
        """
        Return a copy whose ``col`` is the lookup resolved against ``query``.

        Asserts that the resolved column's output field is a DateTimeField.
        """
        clone = self.copy()
        clone.col = query.resolve_ref(self.lookup, allow_joins, reuse, summarize)
        resolved_field = clone.col.output_field
        assert isinstance(resolved_field, fields.DateTimeField), (
            "%r isn't a DateTimeField." % resolved_field.name
        )
        return clone

    def as_sql(self, compiler, connection):
        """
        Return whatever connection.ops.datetime_trunc_sql() produces for
        the column. The column itself must compile without parameters.
        """
        col_sql, col_params = self.col.as_sql(compiler, connection)
        assert not col_params
        return connection.ops.datetime_trunc_sql(self.lookup_type, col_sql, self.tzname)

    def copy(self):
        """Return a copy carrying over ``lookup``, ``lookup_type`` and ``tzname``."""
        clone = super(DateTime, self).copy()
        clone.lookup = self.lookup
        clone.lookup_type = self.lookup_type
        clone.tzname = self.tzname
        return clone

    def convert_value(self, value, connection):
        """
        With settings.USE_TZ on, strip any tzinfo from the value and make
        it aware in ``self.tzinfo``; a None value raises ValueError.
        Without USE_TZ the value is returned untouched.
        """
        if not settings.USE_TZ:
            return value
        if value is None:
            raise ValueError(
                "Database returned an invalid value in QuerySet.datetimes(). "
                "Are time zone definitions for your database and pytz installed?"
            )
        naive = value.replace(tzinfo=None)
        return timezone.make_aware(naive, self.tzinfo)