1122 lines
39 KiB
Python
1122 lines
39 KiB
Python
import copy
|
|
import datetime
|
|
from contextlib import suppress
|
|
from decimal import Decimal
|
|
|
|
from django.core.exceptions import EmptyResultSet, FieldError
|
|
from django.db.models import fields
|
|
from django.db.models.query_utils import Q
|
|
from django.utils.deconstruct import deconstructible
|
|
from django.utils.functional import cached_property
|
|
|
|
|
|
class SQLiteNumericMixin:
    """
    Mixin adding an ``as_sqlite()`` that wraps the generated SQL in a
    NUMERIC cast when the expression's output field is a DecimalField.

    Some expressions with output_field=DecimalField() must be cast to
    numeric to be properly filtered.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            needs_cast = self.output_field.get_internal_type() == 'DecimalField'
        except FieldError:
            # Output type can't be determined: emit the SQL unchanged.
            needs_cast = False
        if needs_cast:
            sql = 'CAST(%s AS NUMERIC)' % sql
        return sql, params
|
|
|
|
|
|
class Combinable:
    """
    Provide the ability to combine one or two objects with
    some connector. For example F('foo') + F('bar').

    Arithmetic operators (and their reflected forms) build a
    CombinedExpression. Python's ``&`` and ``|`` are refused; bitwise
    operations must be spelled with the explicit methods below.
    """

    # Arithmetic connectors
    ADD = '+'
    SUB = '-'
    MUL = '*'
    DIV = '/'
    POW = '^'
    # The following is a quoted % operator - it is quoted because it can be
    # used in strings that also have parameter substitution.
    MOD = '%%'

    # Bitwise operators - note that these are generated by .bitand()
    # and .bitor(), the '&' and '|' are reserved for boolean operator
    # usage.
    BITAND = '&'
    BITOR = '|'
    BITLEFTSHIFT = '<<'
    BITRIGHTSHIFT = '>>'

    def _combine(self, other, connector, reversed):
        """
        Return a CombinedExpression joining ``self`` and ``other``.

        Plain values (anything without ``resolve_expression``) are wrapped
        first: a ``datetime.timedelta`` becomes a DurationValue typed as
        DurationField, everything else a Value. ``reversed`` puts ``other``
        on the left-hand side (used by the reflected operators).
        """
        if not hasattr(other, 'resolve_expression'):
            other = (
                DurationValue(other, output_field=fields.DurationField())
                if isinstance(other, datetime.timedelta)
                else Value(other)
            )
        lhs, rhs = (other, self) if reversed else (self, other)
        return CombinedExpression(lhs, connector, rhs)

    #############
    # OPERATORS #
    #############

    # Forward arithmetic: ``self <op> other``.
    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    # Reflected arithmetic: ``other <op> self``.
    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # Explicit bitwise operations.
    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    # '&' and '|' are reserved for boolean usage and always refused here.
    def __and__(self, other):
        raise NotImplementedError("Use .bitand() and .bitor() for bitwise logical operations.")

    def __or__(self, other):
        raise NotImplementedError("Use .bitand() and .bitor() for bitwise logical operations.")

    def __rand__(self, other):
        raise NotImplementedError("Use .bitand() and .bitor() for bitwise logical operations.")

    def __ror__(self, other):
        raise NotImplementedError("Use .bitand() and .bitor() for bitwise logical operations.")
|
|
|
|
|
|
@deconstructible
class BaseExpression:
    """Base class for all query expressions."""

    # aggregate specific fields
    is_summary = False
    # Set to True by the output_field property when _resolve_output_field()
    # returned None, so _output_field_or_none can return None for that case
    # instead of re-raising the FieldError.
    _output_field_resolved_to_none = False

    def __init__(self, output_field=None):
        # An explicit output_field is stored on the instance, which takes
        # precedence over the inferring output_field cached_property.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # convert_value is a cached_property that may cache a lambda (see
        # convert_value below). Lambdas can't be pickled, so drop any cached
        # value; it is recomputed on first access after unpickling.
        state = self.__dict__.copy()
        state.pop('convert_value', None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters applied to values read from the database:
        this expression's own convert_value (omitted when it is the no-op)
        followed by the output field's converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop else
            [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions (none for a leaf expression)."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; a leaf accepts only an empty list."""
        assert len(exprs) == 0

    def _parse_expressions(self, *expressions):
        """
        Coerce each argument into an expression: objects with
        resolve_expression() are kept, strings become F() references and
        anything else is wrapped in Value().
        """
        return [
            arg if hasattr(arg, 'resolve_expression') else (
                F(arg) if isinstance(arg, str) else Value(arg)
            ) for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """True if any (non-empty) source expression contains an aggregate."""
        for expr in self.get_source_expressions():
            if expr and expr.contains_aggregate:
                return True
        return False

    @cached_property
    def contains_column_references(self):
        """True if any (non-empty) source expression references a column."""
        for expr in self.get_source_expressions():
            if expr and expr.contains_column_references:
                return True
        return False

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.

        This default resolves a copy of self and each of its source
        expressions; note that for_save is not forwarded to the children.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions([
            expr.resolve_expression(query, allow_joins, reuse, summarize)
            for expr in c.get_source_expressions()
        ])
        return c

    def _prepare(self, field):
        """Hook used by Lookup.get_prep_lookup() to do custom preparation."""
        return self

    @property
    def field(self):
        """Alias of output_field."""
        return self.output_field

    @cached_property
    def output_field(self):
        """
        Return the output type of this expression.

        Raise FieldError when no type can be inferred from the sources.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError('Cannot resolve expression type, unknown output_field')
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.

        Other FieldErrors (e.g. mixed source types) still propagate.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression. If the output
        fields of all source fields match then, simply infer the same type
        here. This isn't always correct, but it makes sense most of the time.

        Consider the difference between `2 + 2` and `2 / 3`. Inferring
        the type here is a convenience for the common case. The user should
        supply their own output_field with more complex computations.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # A single generator is shared on purpose: the for-loop takes the
        # first non-None source field, and any() consumes the rest of the
        # same generator, so the loop body runs at most once. With no
        # non-None sources the method falls through and returns None.
        sources_iter = (source for source in self.get_source_fields() if source is not None)
        for output_field in sources_iter:
            if any(not isinstance(output_field, source.__class__) for source in sources_iter):
                raise FieldError('Expression contains mixed types. You must set output_field.')
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        """Converter that returns the value unchanged."""
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.

        Float, *IntegerField and Decimal output types get a None-preserving
        coercion to float/int/Decimal; every other type gets the no-op.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == 'FloatField':
            return lambda value, expression, connection: None if value is None else float(value)
        elif internal_type.endswith('IntegerField'):
            return lambda value, expression, connection: None if value is None else int(value)
        elif internal_type == 'DecimalField':
            return lambda value, expression, connection: None if value is None else Decimal(value)
        return self._convert_value_noop

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose source expressions are relabeled via change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [e.relabeled_clone(change_map) for e in self.get_source_expressions()])
        return clone

    def copy(self):
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def get_group_by_cols(self):
        """
        Return the expressions to GROUP BY: the expression itself when it
        contains no aggregate, otherwise the group-by columns of its sources.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the output fields (or None) of the source expressions."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                yield from expr.flatten()

    def __eq__(self, other):
        # Equal when the class and deconstructed constructor arguments match;
        # output_field is compared by type only, not by instance.
        if self.__class__ != other.__class__:
            return False
        path, args, kwargs = self.deconstruct()
        other_path, other_args, other_kwargs = other.deconstruct()
        if (path, args) == (other_path, other_args):
            kwargs = kwargs.copy()
            other_kwargs = other_kwargs.copy()
            output_field = type(kwargs.pop('output_field', None))
            other_output_field = type(other_kwargs.pop('output_field', None))
            if output_field == other_output_field:
                return kwargs == other_kwargs
        return False

    def __hash__(self):
        path, args, kwargs = self.deconstruct()
        kwargs = kwargs.copy()
        # __eq__() compares output_field by type only, so hash its type too.
        # Hashing the field instance could give equal expressions different
        # hashes, breaking set/dict behavior.
        output_field = type(kwargs.pop('output_field', None))
        h = hash(path) ^ hash(args) ^ hash(output_field)
        for kwarg in kwargs.items():
            h ^= hash(kwarg)
        return h
|
|
|
|
|
|
class Expression(BaseExpression, Combinable):
    """
    An expression that can be combined with other expressions.

    Joins the BaseExpression machinery with Combinable's operator support.
    """
|
|
|
|
|
|
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    Binary expression ``lhs <connector> rhs``, as built by Combinable.

    When the backend has no native duration type and either side is a
    DurationField, rendering is delegated to DurationExpression. Subtracting
    two sides of the same Date/DateTime/Time field type is delegated to
    TemporalSubtraction.
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def __str__(self):
        return "%s %s %s" % (self.lhs, self.connector, self.rhs)

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def as_sql(self, compiler, connection):
        def internal_type_of(side):
            # Internal type name of the side's output field, or None when
            # the field cannot be resolved.
            try:
                output = side.output_field
            except FieldError:
                return None
            return output.get_internal_type() if output else None

        lhs_type = internal_type_of(self.lhs)
        rhs_type = internal_type_of(self.rhs)

        if not connection.features.has_native_duration_field and 'DurationField' in (lhs_type, rhs_type):
            return DurationExpression(self.lhs, self.connector, self.rhs).as_sql(compiler, connection)

        is_temporal = lhs_type in {'DateField', 'DateTimeField', 'TimeField'}
        if self.connector == self.SUB and is_temporal and lhs_type == rhs_type:
            return TemporalSubtraction(self.lhs, self.rhs).as_sql(compiler, connection)

        expressions = []
        expression_params = []
        for side in (self.lhs, self.rhs):
            side_sql, side_params = compiler.compile(side)
            expressions.append(side_sql)
            expression_params.extend(side_params)
        # Parenthesize to preserve order of precedence.
        combined = connection.ops.combine_expression(self.connector, expressions)
        return '(%s)' % combined, expression_params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.lhs = resolved.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        resolved.rhs = resolved.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return resolved
|
|
|
|
|
|
class DurationExpression(CombinedExpression):
    """
    CombinedExpression rendered through the backend's duration-arithmetic
    hooks, used when the database lacks a native duration type.
    """

    def compile(self, side, compiler, connection):
        """
        Compile one operand; DurationField operands (other than
        DurationValue) are passed through format_for_duration_arithmetic().
        """
        if isinstance(side, DurationValue):
            return compiler.compile(side)
        try:
            output = side.output_field
        except FieldError:
            return compiler.compile(side)
        if output.get_internal_type() != 'DurationField':
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        expressions = []
        expression_params = []
        for side in (self.lhs, self.rhs):
            side_sql, side_params = self.compile(side, compiler, connection)
            expressions.append(side_sql)
            expression_params.extend(side_params)
        # Parenthesize to preserve order of precedence.
        combined = connection.ops.combine_duration_expression(self.connector, expressions)
        return '(%s)' % combined, expression_params
|
|
|
|
|
|
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` for two values of the same date/datetime/time type,
    producing a DurationField.

    The SQL comes from connection.ops.subtract_temporals(), which receives
    the lhs internal type name and the compiled (sql, params) of each side.
    """

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs, output_field=fields.DurationField())

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        # Compile with the node only, as every other call site in this module
        # does. NOTE(review): the second positional parameter of Django's
        # SQLCompiler.compile() is select_format, so passing `connection`
        # there wrongly enabled select formatting.
        lhs = compiler.compile(self.lhs)
        rhs = compiler.compile(self.rhs)
        return connection.ops.subtract_temporals(self.lhs.output_field.get_internal_type(), lhs, rhs)
|
|
|
|
|
|
@deconstructible
class F(Combinable):
    """An object capable of resolving references to existing query objects."""

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, self.name)

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        # Resolution is delegated entirely to the query's name lookup.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Same concrete class and same referenced name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)
|
|
|
|
|
|
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    def as_sql(self, *args, **kwargs):
        # Compiling this node directly means it is being used outside a
        # subquery, which is an error.
        message = ('This queryset contains a reference to an outer query and may '
                   'only be used in a subquery.')
        raise ValueError(message)

    def _prepare(self, output_field=None):
        return self
|
|
|
|
|
|
class OuterRef(F):
    """
    Reference to a field of the outer query, for use inside a Subquery.

    Resolving yields a ResolvedOuterRef; when the wrapped name is itself an
    instance of this class, that inner reference is returned instead.
    """

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        name = self.name
        return name if isinstance(name, self.__class__) else ResolvedOuterRef(name)

    def _prepare(self, output_field=None):
        return self
|
|
|
|
|
|
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    The SQL is produced by filling ``template`` with ``function`` and the
    compiled source expressions joined by ``arg_joiner``. Extra keyword
    arguments passed to __init__() are stored in ``self.extra`` and also
    take part in the template substitution.
    """
    function = None
    template = '%(function)s(%(expressions)s)'
    arg_joiner = ', '
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                "'%s' takes exactly %s %s (%s given)"
                % (self.__class__.__name__, self.arity, noun, len(expressions))
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        class_name = self.__class__.__name__
        args = self.arg_joiner.join(map(str, self.source_expressions))
        options = dict(self.extra, **self._get_repr_options())
        if not options:
            return "{}({})".format(class_name, args)
        rendered = ', '.join('%s=%s' % (key, val) for key, val in sorted(options.items()))
        return "{}({}, {})".format(class_name, args, rendered)

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = self.copy()
        c.is_summary = summarize
        # copy() gave c its own list, so it is safe to fill it in place.
        c.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in c.source_expressions
        ]
        return c

    def as_sql(self, compiler, connection, function=None, template=None, arg_joiner=None, **extra_context):
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            arg_sql, arg_params = compiler.compile(arg)
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = self.extra.copy()
        data.update(extra_context)
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        data['function'] = function if function is not None else data.get('function', self.function)
        template = template or data.get('template', self.template)
        arg_joiner = arg_joiner or data.get('arg_joiner', self.arg_joiner)
        data['expressions'] = data['field'] = arg_joiner.join(sql_parts)
        return template % data, params

    def copy(self):
        clone = super().copy()
        # Give the clone its own containers so later mutation doesn't leak back.
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
|
|
|
|
|
|
class Value(Expression):
    """Represent a wrapped value as a node within an expression."""

    # resolve_expression() sets for_save on the resolved copy. This
    # class-level default lets an unresolved Value be compiled as well;
    # without it as_sql() raised AttributeError on self.for_save.
    for_save = False

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.value)

    def as_sql(self, compiler, connection):
        """
        Return (sql, params) for the wrapped value.

        When an output field is known, the value is prepared with its
        get_db_prep_save() (if for_save) or get_db_prep_value(), and the
        field's get_placeholder() is used when it defines one. A None value
        renders as a literal NULL; anything else as a single '%s' parameter.
        """
        connection.ops.check_expression_support(self)
        val = self.value
        output_field = self._output_field_or_none
        if output_field is not None:
            if self.for_save:
                val = output_field.get_db_prep_save(val, connection=connection)
            else:
                val = output_field.get_db_prep_value(val, connection=connection)
            if hasattr(output_field, 'get_placeholder'):
                return output_field.get_placeholder(val, compiler, connection), [val]
        if val is None:
            # cx_Oracle does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return 'NULL', []
        return '%s', [val]

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
        c.for_save = for_save
        return c

    def get_group_by_cols(self):
        # A constant never needs to appear in GROUP BY.
        return []
|
|
|
|
|
|
class DurationValue(Value):
    """
    Value holding a duration (Combinable wraps timedelta operands in it).

    On backends without a native duration type it is rendered through
    connection.ops.date_interval_sql() with no parameters.
    """

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        if not connection.features.has_native_duration_field:
            return connection.ops.date_interval_sql(self.value), []
        return super().as_sql(compiler, connection)
|
|
|
|
|
|
class RawSQL(Expression):
    """
    A raw SQL fragment with its own parameters, rendered in parentheses.

    When no output_field is given, a generic fields.Field() is used.
    """

    def __init__(self, sql, params, output_field=None):
        if output_field is None:
            output_field = fields.Field()
        self.sql, self.params = sql, params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)

    def as_sql(self, compiler, connection):
        return '(%s)' % self.sql, self.params

    def get_group_by_cols(self):
        return [self]

    def __hash__(self):
        # The inherited __eq__() compares output_field by type only, so hash
        # the type as well. Hashing the field instance (a fresh fields.Field()
        # when none was given) can give equal RawSQL objects different hashes.
        h = hash(self.sql) ^ hash(type(self.output_field))
        for param in self.params:
            h ^= hash(param)
        return h
|
|
|
|
|
|
class Star(Expression):
    """The SQL ``*`` wildcard, rendered without parameters."""

    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        sql = '*'
        return sql, []
|
|
|
|
|
|
class Random(Expression):
    """The backend's random-number function, typed as a FloatField."""

    def __init__(self):
        super().__init__(output_field=fields.FloatField())

    def __repr__(self):
        return "Random()"

    def as_sql(self, compiler, connection):
        sql = connection.ops.random_function_sql()
        return sql, []
|
|
|
|
|
|
class Col(Expression):
    """
    A column of a table alias, rendered as the quoted ``alias.column`` pair.

    output_field defaults to the target field itself.
    """

    contains_column_references = True

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias, self.target = alias, target

    def __repr__(self):
        return "%s(%s, %s)" % (self.__class__.__name__, self.alias, self.target)

    def as_sql(self, compiler, connection):
        quote = compiler.quote_name_unless_alias
        table = quote(self.alias)
        column = quote(self.target.column)
        return "%s.%s" % (table, column), []

    def relabeled_clone(self, relabels):
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return converters
        # A distinct output_field: also apply the target column's converters.
        return converters + self.target.get_db_converters(connection)
|
|
|
|
|
|
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs, self.source = refs, source

    def __repr__(self):
        return "%s(%s, %s)" % (self.__class__.__name__, self.refs, self.source)

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        [self.source] = exprs

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        # The sub-expression `source` has already been resolved, as this is
        # just a reference to the name of `source`.
        return self

    def relabeled_clone(self, relabels):
        # Only an alias name is held here, so there is nothing to relabel.
        return self

    def as_sql(self, compiler, connection):
        quoted = connection.ops.quote_name(self.refs)
        return "%s" % quoted, []

    def get_group_by_cols(self):
        return [self]
|
|
|
|
|
|
class ExpressionWrapper(Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection):
        # Compile through the compiler rather than calling
        # self.expression.as_sql() directly. The compiler is what dispatches
        # to vendor-specific as_<vendor>() implementations (e.g. as_sqlite
        # from SQLiteNumericMixin), which a direct as_sql() call bypassed.
        return compiler.compile(self.expression)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.expression)
|
|
|
|
|
|
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` clause of a Case expression.

    The condition is a Q object, given directly or built from keyword
    lookups; ``then`` is coerced to an expression.
    """
    template = 'WHEN %(condition)s THEN %(result)s'

    def __init__(self, condition=None, then=None, **lookups):
        if condition is None and lookups:
            condition, lookups = Q(**lookups), None
        if lookups or condition is None or not isinstance(condition, Q):
            raise TypeError("__init__() takes either a Q object or lookups as keyword arguments")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return "WHEN %r THEN %r" % (self.condition, self.result)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = self.copy()
        c.is_summary = summarize
        condition = c.condition
        if hasattr(condition, 'resolve_expression'):
            # The condition is always resolved with for_save=False.
            c.condition = condition.resolve_expression(query, allow_joins, reuse, summarize, False)
        c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return c

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        extra_context['condition'] = condition_sql
        extra_context['result'] = result_sql
        sql_params = list(condition_params) + list(result_params)
        return (template or self.template) % extra_context, sql_params

    def get_group_by_cols(self):
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
|
|
|
|
|
|
class Case(Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """
    template = 'CASE %(cases)s ELSE %(default)s END'
    case_joiner = ' '

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE %s, ELSE %r" % (', '.join(map(str, self.cases)), self.default)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        # The last expression is the default; everything before it is a case.
        self.cases = exprs[:-1]
        self.default = exprs[-1]

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        c = self.copy()
        c.is_summary = summarize
        # copy() gave c its own cases list, so it is filled in place.
        c.cases[:] = [
            case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for case in c.cases
        ]
        c.default = c.default.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return c

    def copy(self):
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context):
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        template_params = self.extra.copy()
        template_params.update(extra_context)
        case_parts = []
        sql_params = []
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # A WHEN that compiles to an empty result set is dropped.
                continue
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        default_sql, default_params = compiler.compile(self.default)
        if not case_parts:
            # Every WHEN was dropped: the expression is just the default.
            return default_sql, default_params
        template_params['cases'] = (case_joiner or self.case_joiner).join(case_parts)
        template_params['default'] = default_sql
        sql_params.extend(default_params)
        sql = (template or template_params.get('template', self.template)) % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params
|
|
|
|
|
|
class Subquery(Expression):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """
    template = '(%(subquery)s)'

    def __init__(self, queryset, output_field=None, **extra):
        """
        Wrap ``queryset`` as a subquery expression.

        If ``output_field`` is omitted and the queryset's query selects
        exactly one column, that column's field is used as the output field.
        Extra keyword arguments are stored in ``self.extra`` and merged into
        the template parameters by as_sql().
        """
        self.queryset = queryset
        self.extra = extra
        if output_field is None and len(self.queryset.query.select) == 1:
            output_field = self.queryset.query.select[0].field
        super().__init__(output_field)

    def copy(self):
        """Return a copy whose queryset is re-cloned with ``.all()``."""
        clone = super().copy()
        clone.queryset = clone.queryset.all()
        return clone

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Return a resolved copy of this subquery for use inside ``query``.

        On the copy, ``bump_prefix(query)`` is called on the inner query, every
        ``rhs`` in the inner WHERE tree is resolved against ``query``, and so
        is every annotation that is itself a Subquery. ``self`` is not
        modified.
        """
        clone = self.copy()
        clone.is_summary = summarize
        clone.queryset.query.bump_prefix(query)

        def resolve(child):
            # Resolve ``child`` against the outer query if it is an
            # expression; plain values are returned unchanged.
            if not hasattr(child, 'resolve_expression'):
                return child
            resolved = child.resolve_expression(
                query=query, allow_joins=allow_joins, reuse=reuse,
                summarize=summarize, for_save=for_save,
            )
            # Add table alias to the parent query's aliases to prevent
            # quoting.
            if hasattr(resolved, 'alias'):
                clone.queryset.query.external_aliases.add(resolved.alias)
            return resolved

        def resolve_all(child):
            # Need to recursively resolve these: walk nested nodes first, then
            # resolve this node's right-hand side (if it has one).
            if hasattr(child, 'children'):
                for grandchild in child.children:
                    resolve_all(grandchild)
            if hasattr(child, 'rhs'):
                child.rhs = resolve(child.rhs)

        resolve_all(clone.queryset.query.where)

        annotations = clone.queryset.query.annotations
        # Iterate over a snapshot because values are reassigned in the loop.
        for key, value in list(annotations.items()):
            if isinstance(value, Subquery):
                annotations[key] = resolve(value)

        return clone

    def get_source_expressions(self):
        """Return the truthy ``lhs`` of each top-level WHERE child."""
        lhs_values = (
            getattr(expr, 'lhs', None)
            for expr in self.queryset.query.where.children
        )
        return [lhs for lhs in lhs_values if lhs]

    def relabeled_clone(self, change_map):
        """
        Return a copy whose inner query is relabeled with ``change_map``.

        New aliases from ``change_map`` that the inner query does not itself
        define (not in its ``alias_map``) are recorded in its
        ``external_aliases``.
        """
        clone = self.copy()
        clone.queryset.query = clone.queryset.query.relabeled_clone(change_map)
        clone.queryset.query.external_aliases.update(
            alias for alias in change_map.values()
            if alias not in clone.queryset.query.alias_map
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """
        Compile the inner query with its own compiler for ``connection`` and
        substitute the SQL into ``template`` (or the ``template`` entry in
        ``self.extra``/``extra_context``, or the class template).
        """
        connection.ops.check_expression_support(self)
        template_params = self.extra.copy()
        template_params.update(extra_context)
        inner_compiler = self.queryset.query.get_compiler(connection=connection)
        template_params['subquery'], sql_params = inner_compiler.as_sql()

        template = template or template_params.get('template', self.template)
        sql = template % template_params
        return sql, sql_params

    def _prepare(self, output_field):
        # This method will only be called if this instance is the "rhs" in an
        # expression: the wrapping () must be removed (as the expression that
        # contains this will provide them). SQLite evaluates ((subquery))
        # differently than the other databases.
        if self.template == '(%(subquery)s)':
            clone = self.copy()
            clone.template = '%(subquery)s'
            return clone
        return self
|
|
|
|
|
|
class Exists(Subquery):
    """
    An ``EXISTS(subquery)`` expression.

    When ``negated`` is true, the SQL is prefixed with ``NOT``; ``~exists``
    returns a new instance with the flag flipped.
    """
    template = 'EXISTS(%(subquery)s)'

    def __init__(self, *args, negated=False, **kwargs):
        self.negated = negated
        super().__init__(*args, **kwargs)

    def __invert__(self):
        """Return a new Exists over the same queryset with ``negated`` flipped."""
        return type(self)(self.queryset, negated=(not self.negated), **self.extra)

    @property
    def output_field(self):
        # Always a boolean, independent of what the inner query selects.
        return fields.BooleanField()

    def resolve_expression(self, query=None, *args, **kwargs):
        """
        Resolve an unordered copy of this expression against ``query``.

        Positional arguments after ``query`` are accepted (and forwarded) so
        that this override keeps the calling convention of
        Subquery.resolve_expression(). ``self`` is left untouched.
        """
        # As a performance optimization, remove ordering since EXISTS doesn't
        # care about it, just whether or not a row matches. Strip it on a copy
        # so that resolving does not mutate this unresolved expression.
        clone = self.copy()
        clone.queryset = clone.queryset.order_by()
        return super(Exists, clone).resolve_expression(query, *args, **kwargs)

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """Render ``EXISTS(...)``, prefixed with ``NOT`` when negated."""
        sql, params = super().as_sql(compiler, connection, template, **extra_context)
        if self.negated:
            sql = 'NOT {}'.format(sql)
        return sql, params

    def as_oracle(self, compiler, connection, template=None, **extra_context):
        # Oracle doesn't allow EXISTS() in the SELECT list, so wrap the
        # (possibly negated) EXISTS in a CASE WHEN expression yielding 1 or 0.
        sql, params = self.as_sql(compiler, connection, template, **extra_context)
        sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql)
        return sql, params
|
|
|
|
|
|
class OrderBy(BaseExpression):
    """
    Order by a wrapped expression, ascending or descending.

    ``nulls_first`` / ``nulls_last`` (mutually exclusive) control where NULLs
    are placed: by default via a ``NULLS FIRST`` / ``NULLS LAST`` suffix, with
    emulations for SQLite and MySQL in as_sqlite() / as_mysql().
    """
    template = '%(expression)s %(ordering)s'

    def __init__(self, expression, descending=False, nulls_first=False, nulls_last=False):
        """
        Raise ValueError if both NULL placements are requested, or if
        ``expression`` has no ``resolve_expression`` attribute.
        """
        if nulls_first and nulls_last:
            raise ValueError('nulls_first and nulls_last are mutually exclusive')
        self.nulls_first, self.nulls_last, self.descending = (
            nulls_first, nulls_last, descending,
        )
        if not hasattr(expression, 'resolve_expression'):
            raise ValueError('expression must be an expression type')
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            type(self).__name__, self.expression, self.descending,
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """
        Render ``<expression> ASC|DESC``.

        When no ``template`` is supplied, a NULLS LAST / NULLS FIRST suffix is
        appended according to the flags. Entries in ``extra_context`` override
        the computed placeholders. Trailing whitespace is stripped from the
        result.
        """
        if not template and self.nulls_last:
            template = '%s NULLS LAST' % self.template
        elif not template and self.nulls_first:
            template = '%s NULLS FIRST' % self.template

        connection.ops.check_expression_support(self)
        compiled_sql, compiled_params = compiler.compile(self.expression)

        direction = 'DESC' if self.descending else 'ASC'
        placeholders = {'expression': compiled_sql, 'ordering': direction}
        placeholders.update(extra_context)

        rendered = (template or self.template) % placeholders
        return rendered.rstrip(), compiled_params

    def as_sqlite(self, compiler, connection):
        # Emulate NULL placement by first sorting on an IS NULL / IS NOT NULL
        # key instead of using the NULLS FIRST/LAST suffix.
        if self.nulls_last:
            sqlite_template = '%(expression)s IS NULL, %(expression)s %(ordering)s'
        elif self.nulls_first:
            sqlite_template = '%(expression)s IS NOT NULL, %(expression)s %(ordering)s'
        else:
            sqlite_template = None
        return self.as_sql(compiler, connection, template=sqlite_template)

    def as_mysql(self, compiler, connection):
        # Emulate NULL placement by first sorting on an IF(ISNULL(...)) key
        # instead of using the NULLS FIRST/LAST suffix.
        if self.nulls_last:
            mysql_template = 'IF(ISNULL(%(expression)s),1,0), %(expression)s %(ordering)s '
        elif self.nulls_first:
            mysql_template = 'IF(ISNULL(%(expression)s),0,1), %(expression)s %(ordering)s '
        else:
            mysql_template = None
        return self.as_sql(compiler, connection, template=mysql_template)

    def get_group_by_cols(self):
        """Collect the GROUP BY columns of every source expression, in order."""
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip the sort direction in place and return ``self``."""
        self.descending = not self.descending
        return self

    def asc(self):
        """Switch to ascending order in place."""
        self.descending = False

    def desc(self):
        """Switch to descending order in place."""
        self.descending = True
|