mirror of https://github.com/django/django.git
Fixed #25759 -- Added keyword arguments to customize Expressions' as_sql().
This commit is contained in:
parent
f1db8c36e9
commit
5336158990
|
@ -534,7 +534,7 @@ class Func(Expression):
|
||||||
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, function=None, template=None):
|
def as_sql(self, compiler, connection, function=None, template=None, arg_joiner=None, **extra_context):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
sql_parts = []
|
sql_parts = []
|
||||||
params = []
|
params = []
|
||||||
|
@ -542,13 +542,19 @@ class Func(Expression):
|
||||||
arg_sql, arg_params = compiler.compile(arg)
|
arg_sql, arg_params = compiler.compile(arg)
|
||||||
sql_parts.append(arg_sql)
|
sql_parts.append(arg_sql)
|
||||||
params.extend(arg_params)
|
params.extend(arg_params)
|
||||||
if function is None:
|
data = self.extra.copy()
|
||||||
self.extra['function'] = self.extra.get('function', self.function)
|
data.update(**extra_context)
|
||||||
|
# Use the first supplied value in this order: the parameter to this
|
||||||
|
# method, a value supplied in __init__()'s **extra (the value in
|
||||||
|
# `data`), or the value defined on the class.
|
||||||
|
if function is not None:
|
||||||
|
data['function'] = function
|
||||||
else:
|
else:
|
||||||
self.extra['function'] = function
|
data.setdefault('function', self.function)
|
||||||
self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts)
|
template = template or data.get('template', self.template)
|
||||||
template = template or self.extra.get('template', self.template)
|
arg_joiner = arg_joiner or data.get('arg_joiner', self.arg_joiner)
|
||||||
return template % self.extra, params
|
data['expressions'] = data['field'] = arg_joiner.join(sql_parts)
|
||||||
|
return template % data, params
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
sql, params = self.as_sql(compiler, connection)
|
sql, params = self.as_sql(compiler, connection)
|
||||||
|
@ -778,9 +784,9 @@ class When(Expression):
|
||||||
c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, template=None):
|
def as_sql(self, compiler, connection, template=None, **extra_context):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
template_params = {}
|
template_params = extra_context
|
||||||
sql_params = []
|
sql_params = []
|
||||||
condition_sql, condition_params = compiler.compile(self.condition)
|
condition_sql, condition_params = compiler.compile(self.condition)
|
||||||
template_params['condition'] = condition_sql
|
template_params['condition'] = condition_sql
|
||||||
|
@ -822,6 +828,7 @@ class Case(Expression):
|
||||||
super(Case, self).__init__(output_field)
|
super(Case, self).__init__(output_field)
|
||||||
self.cases = list(cases)
|
self.cases = list(cases)
|
||||||
self.default = self._parse_expressions(default)[0]
|
self.default = self._parse_expressions(default)[0]
|
||||||
|
self.extra = extra
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
|
return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)
|
||||||
|
@ -849,22 +856,24 @@ class Case(Expression):
|
||||||
c.cases = c.cases[:]
|
c.cases = c.cases[:]
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, template=None, extra=None):
|
def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
if not self.cases:
|
if not self.cases:
|
||||||
return compiler.compile(self.default)
|
return compiler.compile(self.default)
|
||||||
template_params = dict(extra) if extra else {}
|
template_params = self.extra.copy()
|
||||||
|
template_params.update(extra_context)
|
||||||
case_parts = []
|
case_parts = []
|
||||||
sql_params = []
|
sql_params = []
|
||||||
for case in self.cases:
|
for case in self.cases:
|
||||||
case_sql, case_params = compiler.compile(case)
|
case_sql, case_params = compiler.compile(case)
|
||||||
case_parts.append(case_sql)
|
case_parts.append(case_sql)
|
||||||
sql_params.extend(case_params)
|
sql_params.extend(case_params)
|
||||||
template_params['cases'] = self.case_joiner.join(case_parts)
|
case_joiner = case_joiner or self.case_joiner
|
||||||
|
template_params['cases'] = case_joiner.join(case_parts)
|
||||||
default_sql, default_params = compiler.compile(self.default)
|
default_sql, default_params = compiler.compile(self.default)
|
||||||
template_params['default'] = default_sql
|
template_params['default'] = default_sql
|
||||||
sql_params.extend(default_params)
|
sql_params.extend(default_params)
|
||||||
template = template or self.template
|
template = template or template_params.get('template', self.template)
|
||||||
sql = template % template_params
|
sql = template % template_params
|
||||||
if self._output_field_or_none is not None:
|
if self._output_field_or_none is not None:
|
||||||
sql = connection.ops.unification_cast_sql(self.output_field) % sql
|
sql = connection.ops.unification_cast_sql(self.output_field) % sql
|
||||||
|
@ -995,14 +1004,16 @@ class OrderBy(BaseExpression):
|
||||||
def get_source_expressions(self):
|
def get_source_expressions(self):
|
||||||
return [self.expression]
|
return [self.expression]
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection, template=None, **extra_context):
|
||||||
connection.ops.check_expression_support(self)
|
connection.ops.check_expression_support(self)
|
||||||
expression_sql, params = compiler.compile(self.expression)
|
expression_sql, params = compiler.compile(self.expression)
|
||||||
placeholders = {
|
placeholders = {
|
||||||
'expression': expression_sql,
|
'expression': expression_sql,
|
||||||
'ordering': 'DESC' if self.descending else 'ASC',
|
'ordering': 'DESC' if self.descending else 'ASC',
|
||||||
}
|
}
|
||||||
return (self.template % placeholders).rstrip(), params
|
placeholders.update(extra_context)
|
||||||
|
template = template or self.template
|
||||||
|
return (template % placeholders).rstrip(), params
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self):
|
||||||
cols = []
|
cols = []
|
||||||
|
|
|
@ -43,9 +43,8 @@ class ConcatPair(Func):
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
coalesced = self.coalesce()
|
coalesced = self.coalesce()
|
||||||
coalesced.arg_joiner = ' || '
|
|
||||||
return super(ConcatPair, coalesced).as_sql(
|
return super(ConcatPair, coalesced).as_sql(
|
||||||
compiler, connection, template='%(expressions)s',
|
compiler, connection, template='%(expressions)s', arg_joiner=' || '
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_mysql(self, compiler, connection):
|
def as_mysql(self, compiler, connection):
|
||||||
|
|
|
@ -261,12 +261,13 @@ The ``Func`` API is as follows:
|
||||||
different number of expressions, ``TypeError`` will be raised. Defaults
|
different number of expressions, ``TypeError`` will be raised. Defaults
|
||||||
to ``None``.
|
to ``None``.
|
||||||
|
|
||||||
.. method:: as_sql(compiler, connection, function=None, template=None)
|
.. method:: as_sql(compiler, connection, function=None, template=None, arg_joiner=None, **extra_context)
|
||||||
|
|
||||||
Generates the SQL for the database function.
|
Generates the SQL for the database function.
|
||||||
|
|
||||||
The ``as_vendor()`` methods should use the ``function`` and
|
The ``as_vendor()`` methods should use the ``function``, ``template``,
|
||||||
``template`` parameters to customize the SQL as needed. For example:
|
``arg_joiner``, and any other ``**extra_context`` parameters to
|
||||||
|
customize the SQL as needed. For example:
|
||||||
|
|
||||||
.. snippet::
|
.. snippet::
|
||||||
:filename: django/db/models/functions.py
|
:filename: django/db/models/functions.py
|
||||||
|
@ -283,6 +284,11 @@ The ``Func`` API is as follows:
|
||||||
template="%(function)s('', %(expressions)s)",
|
template="%(function)s('', %(expressions)s)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
.. versionchanged:: 1.10
|
||||||
|
|
||||||
|
Support for the ``arg_joiner`` and ``**extra_context`` parameters
|
||||||
|
was added.
|
||||||
|
|
||||||
The ``*expressions`` argument is a list of positional expressions that the
|
The ``*expressions`` argument is a list of positional expressions that the
|
||||||
function will be applied to. The expressions will be converted to strings,
|
function will be applied to. The expressions will be converted to strings,
|
||||||
joined together with ``arg_joiner``, and then interpolated into the ``template``
|
joined together with ``arg_joiner``, and then interpolated into the ``template``
|
||||||
|
@ -293,10 +299,10 @@ assumed to be column references and will be wrapped in ``F()`` expressions
|
||||||
while other values will be wrapped in ``Value()`` expressions.
|
while other values will be wrapped in ``Value()`` expressions.
|
||||||
|
|
||||||
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
|
||||||
into the ``template`` attribute. Note that the keywords ``function`` and
|
into the ``template`` attribute. The ``function``, ``template``, and
|
||||||
``template`` can be used to replace the ``function`` and ``template``
|
``arg_joiner`` keywords can be used to replace the attributes of the same name
|
||||||
attributes respectively, without having to define your own class.
|
without having to define your own class. ``output_field`` can be used to define
|
||||||
``output_field`` can be used to define the expected return type.
|
the expected return type.
|
||||||
|
|
||||||
``Aggregate()`` expressions
|
``Aggregate()`` expressions
|
||||||
---------------------------
|
---------------------------
|
||||||
|
|
|
@ -212,6 +212,13 @@ Database backends
|
||||||
``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys
|
``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys
|
||||||
on objects created using ``QuerySet.bulk_create()``.
|
on objects created using ``QuerySet.bulk_create()``.
|
||||||
|
|
||||||
|
* Added keyword arguments to the ``as_sql()`` methods of various expressions
|
||||||
|
(``Func``, ``When``, ``Case``, and ``OrderBy``) to allow database backends to
|
||||||
|
customize them without mutating ``self``, which isn't safe when using
|
||||||
|
different database backends. See the ``arg_joiner`` and ``**extra_context``
|
||||||
|
parameters of :meth:`Func.as_sql() <django.db.models.Func.as_sql>` for an
|
||||||
|
example.
|
||||||
|
|
||||||
Email
|
Email
|
||||||
~~~~~
|
~~~~~
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue