Fixed #25759 -- Added keyword arguments to customize Expressions' as_sql().

This commit is contained in:
Kai Feldhoff 2016-02-15 21:42:24 +01:00 committed by Tim Graham
parent f1db8c36e9
commit 5336158990
4 changed files with 47 additions and 24 deletions

View File

@ -534,7 +534,7 @@ class Func(Expression):
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save) c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return c return c
def as_sql(self, compiler, connection, function=None, template=None): def as_sql(self, compiler, connection, function=None, template=None, arg_joiner=None, **extra_context):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
sql_parts = [] sql_parts = []
params = [] params = []
@ -542,13 +542,19 @@ class Func(Expression):
arg_sql, arg_params = compiler.compile(arg) arg_sql, arg_params = compiler.compile(arg)
sql_parts.append(arg_sql) sql_parts.append(arg_sql)
params.extend(arg_params) params.extend(arg_params)
if function is None: data = self.extra.copy()
self.extra['function'] = self.extra.get('function', self.function) data.update(**extra_context)
# Use the first supplied value in this order: the parameter to this
# method, a value supplied in __init__()'s **extra (the value in
# `data`), or the value defined on the class.
if function is not None:
data['function'] = function
else: else:
self.extra['function'] = function data.setdefault('function', self.function)
self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts) template = template or data.get('template', self.template)
template = template or self.extra.get('template', self.template) arg_joiner = arg_joiner or data.get('arg_joiner', self.arg_joiner)
return template % self.extra, params data['expressions'] = data['field'] = arg_joiner.join(sql_parts)
return template % data, params
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection):
sql, params = self.as_sql(compiler, connection) sql, params = self.as_sql(compiler, connection)
@ -778,9 +784,9 @@ class When(Expression):
c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save) c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
return c return c
def as_sql(self, compiler, connection, template=None): def as_sql(self, compiler, connection, template=None, **extra_context):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
template_params = {} template_params = extra_context
sql_params = [] sql_params = []
condition_sql, condition_params = compiler.compile(self.condition) condition_sql, condition_params = compiler.compile(self.condition)
template_params['condition'] = condition_sql template_params['condition'] = condition_sql
@ -822,6 +828,7 @@ class Case(Expression):
super(Case, self).__init__(output_field) super(Case, self).__init__(output_field)
self.cases = list(cases) self.cases = list(cases)
self.default = self._parse_expressions(default)[0] self.default = self._parse_expressions(default)[0]
self.extra = extra
def __str__(self): def __str__(self):
return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default) return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
@ -849,22 +856,24 @@ class Case(Expression):
c.cases = c.cases[:] c.cases = c.cases[:]
return c return c
def as_sql(self, compiler, connection, template=None, extra=None): def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
if not self.cases: if not self.cases:
return compiler.compile(self.default) return compiler.compile(self.default)
template_params = dict(extra) if extra else {} template_params = self.extra.copy()
template_params.update(extra_context)
case_parts = [] case_parts = []
sql_params = [] sql_params = []
for case in self.cases: for case in self.cases:
case_sql, case_params = compiler.compile(case) case_sql, case_params = compiler.compile(case)
case_parts.append(case_sql) case_parts.append(case_sql)
sql_params.extend(case_params) sql_params.extend(case_params)
template_params['cases'] = self.case_joiner.join(case_parts) case_joiner = case_joiner or self.case_joiner
template_params['cases'] = case_joiner.join(case_parts)
default_sql, default_params = compiler.compile(self.default) default_sql, default_params = compiler.compile(self.default)
template_params['default'] = default_sql template_params['default'] = default_sql
sql_params.extend(default_params) sql_params.extend(default_params)
template = template or self.template template = template or template_params.get('template', self.template)
sql = template % template_params sql = template % template_params
if self._output_field_or_none is not None: if self._output_field_or_none is not None:
sql = connection.ops.unification_cast_sql(self.output_field) % sql sql = connection.ops.unification_cast_sql(self.output_field) % sql
@ -995,14 +1004,16 @@ class OrderBy(BaseExpression):
def get_source_expressions(self): def get_source_expressions(self):
return [self.expression] return [self.expression]
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection, template=None, **extra_context):
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
expression_sql, params = compiler.compile(self.expression) expression_sql, params = compiler.compile(self.expression)
placeholders = { placeholders = {
'expression': expression_sql, 'expression': expression_sql,
'ordering': 'DESC' if self.descending else 'ASC', 'ordering': 'DESC' if self.descending else 'ASC',
} }
return (self.template % placeholders).rstrip(), params placeholders.update(extra_context)
template = template or self.template
return (template % placeholders).rstrip(), params
def get_group_by_cols(self): def get_group_by_cols(self):
cols = [] cols = []

View File

@ -43,9 +43,8 @@ class ConcatPair(Func):
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection):
coalesced = self.coalesce() coalesced = self.coalesce()
coalesced.arg_joiner = ' || '
return super(ConcatPair, coalesced).as_sql( return super(ConcatPair, coalesced).as_sql(
compiler, connection, template='%(expressions)s', compiler, connection, template='%(expressions)s', arg_joiner=' || '
) )
def as_mysql(self, compiler, connection): def as_mysql(self, compiler, connection):

View File

@ -261,12 +261,13 @@ The ``Func`` API is as follows:
different number of expressions, ``TypeError`` will be raised. Defaults different number of expressions, ``TypeError`` will be raised. Defaults
to ``None``. to ``None``.
.. method:: as_sql(compiler, connection, function=None, template=None) .. method:: as_sql(compiler, connection, function=None, template=None, arg_joiner=None, **extra_context)
Generates the SQL for the database function. Generates the SQL for the database function.
The ``as_vendor()`` methods should use the ``function`` and The ``as_vendor()`` methods should use the ``function``, ``template``,
``template`` parameters to customize the SQL as needed. For example: ``arg_joiner``, and any other ``**extra_context`` parameters to
customize the SQL as needed. For example:
.. snippet:: .. snippet::
:filename: django/db/models/functions.py :filename: django/db/models/functions.py
@ -283,6 +284,11 @@ The ``Func`` API is as follows:
template="%(function)s('', %(expressions)s)", template="%(function)s('', %(expressions)s)",
) )
.. versionchanged:: 1.10
Support for the ``arg_joiner`` and ``**extra_context`` parameters
was added.
The ``*expressions`` argument is a list of positional expressions that the The ``*expressions`` argument is a list of positional expressions that the
function will be applied to. The expressions will be converted to strings, function will be applied to. The expressions will be converted to strings,
joined together with ``arg_joiner``, and then interpolated into the ``template`` joined together with ``arg_joiner``, and then interpolated into the ``template``
@ -293,10 +299,10 @@ assumed to be column references and will be wrapped in ``F()`` expressions
while other values will be wrapped in ``Value()`` expressions. while other values will be wrapped in ``Value()`` expressions.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. Note that the keywords ``function`` and into the ``template`` attribute. The ``function``, ``template``, and
``template`` can be used to replace the ``function`` and ``template`` ``arg_joiner`` keywords can be used to replace the attributes of the same name
attributes respectively, without having to define your own class. without having to define your own class. ``output_field`` can be used to define
``output_field`` can be used to define the expected return type. the expected return type.
``Aggregate()`` expressions ``Aggregate()`` expressions
--------------------------- ---------------------------

View File

@ -212,6 +212,13 @@ Database backends
``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys ``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys
on objects created using ``QuerySet.bulk_create()``. on objects created using ``QuerySet.bulk_create()``.
* Added keyword arguments to the ``as_sql()`` methods of various expressions
(``Func``, ``When``, ``Case``, and ``OrderBy``) to allow database backends to
customize them without mutating ``self``, which isn't safe when using
different database backends. See the ``arg_joiner`` and ``**extra_context``
parameters of :meth:`Func.as_sql() <django.db.models.Func.as_sql>` for an
example.
Email Email
~~~~~ ~~~~~