Refs #25367 -- Moved select_format hook to BaseExpression.

This will expose an intermediary hook for expressions that need special
formatting when used in a SELECT clause.
This commit is contained in:
Simon Charette 2019-08-09 14:30:23 -04:00 committed by Mariusz Felisiak
parent 4f7328ce8a
commit fff5186d32
3 changed files with 16 additions and 8 deletions

View File

@ -272,7 +272,9 @@ class GeometryField(BaseSpatialField):
of the spatial backend. For example, Oracle and MySQL require custom
selection formats in order to retrieve geometries in OGC WKB.
"""
if not compiler.query.subquery:
return compiler.connection.ops.select % sql, params
return sql, params
# The OpenGIS Geometry Type Fields

View File

@ -366,6 +366,13 @@ class BaseExpression:
if expr:
yield from expr.flatten()
def select_format(self, compiler, sql, params):
"""
Custom format for select clauses. For example, EXISTS expressions need
to be wrapped in CASE WHEN on Oracle.
"""
return self.output_field.select_format(compiler, sql, params)
@cached_property
def identity(self):
constructor_signature = inspect.signature(self.__init__)

View File

@ -17,8 +17,6 @@ from django.db.utils import DatabaseError, NotSupportedError
from django.utils.deprecation import RemovedInDjango31Warning
from django.utils.hashable import make_hashable
FORCE = object()
class SQLCompiler:
def __init__(self, query, connection, using):
@ -244,10 +242,12 @@ class SQLCompiler:
ret = []
for col, alias in select:
try:
sql, params = self.compile(col, select_format=True)
sql, params = self.compile(col)
except EmptyResultSet:
# Select a predicate that's always False.
sql, params = '0', ()
else:
sql, params = col.select_format(self, sql, params)
ret.append((col, (sql, params), alias))
return ret, klass_info, annotations
@ -402,14 +402,12 @@ class SQLCompiler:
self.quote_cache[name] = r
return r
def compile(self, node, select_format=False):
def compile(self, node):
vendor_impl = getattr(node, 'as_' + self.connection.vendor, None)
if vendor_impl:
sql, params = vendor_impl(self, self.connection)
else:
sql, params = node.as_sql(self, self.connection)
if select_format is FORCE or (select_format and not self.query.subquery):
return node.output_field.select_format(self, sql, params)
return sql, params
def get_combinator_sql(self, combinator, all):
@ -1503,7 +1501,8 @@ class SQLAggregateCompiler(SQLCompiler):
"""
sql, params = [], []
for annotation in self.query.annotation_select.values():
ann_sql, ann_params = self.compile(annotation, select_format=FORCE)
ann_sql, ann_params = self.compile(annotation)
ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params)
sql.append(ann_sql)
params.extend(ann_params)
self.col_count = len(self.query.annotation_select)