Refs #25367 -- Moved select_format hook to BaseExpression.
This will expose an intermediary hook for expressions that need special formatting when used in a SELECT clause.
This commit is contained in:
parent
4f7328ce8a
commit
fff5186d32
|
@ -272,7 +272,9 @@ class GeometryField(BaseSpatialField):
|
|||
of the spatial backend. For example, Oracle and MySQL require custom
|
||||
selection formats in order to retrieve geometries in OGC WKB.
|
||||
"""
|
||||
return compiler.connection.ops.select % sql, params
|
||||
if not compiler.query.subquery:
|
||||
return compiler.connection.ops.select % sql, params
|
||||
return sql, params
|
||||
|
||||
|
||||
# The OpenGIS Geometry Type Fields
|
||||
|
|
|
@ -366,6 +366,13 @@ class BaseExpression:
|
|||
if expr:
|
||||
yield from expr.flatten()
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
"""
|
||||
Custom format for select clauses. For example, EXISTS expressions need
|
||||
to be wrapped in CASE WHEN on Oracle.
|
||||
"""
|
||||
return self.output_field.select_format(compiler, sql, params)
|
||||
|
||||
@cached_property
|
||||
def identity(self):
|
||||
constructor_signature = inspect.signature(self.__init__)
|
||||
|
|
|
@ -17,8 +17,6 @@ from django.db.utils import DatabaseError, NotSupportedError
|
|||
from django.utils.deprecation import RemovedInDjango31Warning
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
FORCE = object()
|
||||
|
||||
|
||||
class SQLCompiler:
|
||||
def __init__(self, query, connection, using):
|
||||
|
@ -244,10 +242,12 @@ class SQLCompiler:
|
|||
ret = []
|
||||
for col, alias in select:
|
||||
try:
|
||||
sql, params = self.compile(col, select_format=True)
|
||||
sql, params = self.compile(col)
|
||||
except EmptyResultSet:
|
||||
# Select a predicate that's always False.
|
||||
sql, params = '0', ()
|
||||
else:
|
||||
sql, params = col.select_format(self, sql, params)
|
||||
ret.append((col, (sql, params), alias))
|
||||
return ret, klass_info, annotations
|
||||
|
||||
|
@ -402,14 +402,12 @@ class SQLCompiler:
|
|||
self.quote_cache[name] = r
|
||||
return r
|
||||
|
||||
def compile(self, node, select_format=False):
|
||||
def compile(self, node):
|
||||
vendor_impl = getattr(node, 'as_' + self.connection.vendor, None)
|
||||
if vendor_impl:
|
||||
sql, params = vendor_impl(self, self.connection)
|
||||
else:
|
||||
sql, params = node.as_sql(self, self.connection)
|
||||
if select_format is FORCE or (select_format and not self.query.subquery):
|
||||
return node.output_field.select_format(self, sql, params)
|
||||
return sql, params
|
||||
|
||||
def get_combinator_sql(self, combinator, all):
|
||||
|
@ -1503,7 +1501,8 @@ class SQLAggregateCompiler(SQLCompiler):
|
|||
"""
|
||||
sql, params = [], []
|
||||
for annotation in self.query.annotation_select.values():
|
||||
ann_sql, ann_params = self.compile(annotation, select_format=FORCE)
|
||||
ann_sql, ann_params = self.compile(annotation)
|
||||
ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params)
|
||||
sql.append(ann_sql)
|
||||
params.extend(ann_params)
|
||||
self.col_count = len(self.query.annotation_select)
|
||||
|
|
Loading…
Reference in New Issue