Refs #25367 -- Moved select_format hook to BaseExpression.

This will expose an intermediary hook for expressions that need special
formatting when used in a SELECT clause.
This commit is contained in:
Simon Charette 2019-08-09 14:30:23 -04:00 committed by Mariusz Felisiak
parent 4f7328ce8a
commit fff5186d32
3 changed files with 16 additions and 8 deletions

View File

@ -272,7 +272,9 @@ class GeometryField(BaseSpatialField):
of the spatial backend. For example, Oracle and MySQL require custom of the spatial backend. For example, Oracle and MySQL require custom
selection formats in order to retrieve geometries in OGC WKB. selection formats in order to retrieve geometries in OGC WKB.
""" """
return compiler.connection.ops.select % sql, params if not compiler.query.subquery:
return compiler.connection.ops.select % sql, params
return sql, params
# The OpenGIS Geometry Type Fields # The OpenGIS Geometry Type Fields

View File

@ -366,6 +366,13 @@ class BaseExpression:
if expr: if expr:
yield from expr.flatten() yield from expr.flatten()
def select_format(self, compiler, sql, params):
"""
Custom format for select clauses. For example, EXISTS expressions need
to be wrapped in CASE WHEN on Oracle.
"""
return self.output_field.select_format(compiler, sql, params)
@cached_property @cached_property
def identity(self): def identity(self):
constructor_signature = inspect.signature(self.__init__) constructor_signature = inspect.signature(self.__init__)

View File

@ -17,8 +17,6 @@ from django.db.utils import DatabaseError, NotSupportedError
from django.utils.deprecation import RemovedInDjango31Warning from django.utils.deprecation import RemovedInDjango31Warning
from django.utils.hashable import make_hashable from django.utils.hashable import make_hashable
FORCE = object()
class SQLCompiler: class SQLCompiler:
def __init__(self, query, connection, using): def __init__(self, query, connection, using):
@ -244,10 +242,12 @@ class SQLCompiler:
ret = [] ret = []
for col, alias in select: for col, alias in select:
try: try:
sql, params = self.compile(col, select_format=True) sql, params = self.compile(col)
except EmptyResultSet: except EmptyResultSet:
# Select a predicate that's always False. # Select a predicate that's always False.
sql, params = '0', () sql, params = '0', ()
else:
sql, params = col.select_format(self, sql, params)
ret.append((col, (sql, params), alias)) ret.append((col, (sql, params), alias))
return ret, klass_info, annotations return ret, klass_info, annotations
@ -402,14 +402,12 @@ class SQLCompiler:
self.quote_cache[name] = r self.quote_cache[name] = r
return r return r
def compile(self, node, select_format=False): def compile(self, node):
vendor_impl = getattr(node, 'as_' + self.connection.vendor, None) vendor_impl = getattr(node, 'as_' + self.connection.vendor, None)
if vendor_impl: if vendor_impl:
sql, params = vendor_impl(self, self.connection) sql, params = vendor_impl(self, self.connection)
else: else:
sql, params = node.as_sql(self, self.connection) sql, params = node.as_sql(self, self.connection)
if select_format is FORCE or (select_format and not self.query.subquery):
return node.output_field.select_format(self, sql, params)
return sql, params return sql, params
def get_combinator_sql(self, combinator, all): def get_combinator_sql(self, combinator, all):
@ -1503,7 +1501,8 @@ class SQLAggregateCompiler(SQLCompiler):
""" """
sql, params = [], [] sql, params = [], []
for annotation in self.query.annotation_select.values(): for annotation in self.query.annotation_select.values():
ann_sql, ann_params = self.compile(annotation, select_format=FORCE) ann_sql, ann_params = self.compile(annotation)
ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params)
sql.append(ann_sql) sql.append(ann_sql)
params.extend(ann_params) params.extend(ann_params)
self.col_count = len(self.query.annotation_select) self.col_count = len(self.query.annotation_select)