Refs #25367 -- Moved select_format hook to BaseExpression.
This will expose an intermediary hook for expressions that need special formatting when used in a SELECT clause.
This commit is contained in:
parent
4f7328ce8a
commit
fff5186d32
|
@ -272,7 +272,9 @@ class GeometryField(BaseSpatialField):
|
||||||
of the spatial backend. For example, Oracle and MySQL require custom
|
of the spatial backend. For example, Oracle and MySQL require custom
|
||||||
selection formats in order to retrieve geometries in OGC WKB.
|
selection formats in order to retrieve geometries in OGC WKB.
|
||||||
"""
|
"""
|
||||||
return compiler.connection.ops.select % sql, params
|
if not compiler.query.subquery:
|
||||||
|
return compiler.connection.ops.select % sql, params
|
||||||
|
return sql, params
|
||||||
|
|
||||||
|
|
||||||
# The OpenGIS Geometry Type Fields
|
# The OpenGIS Geometry Type Fields
|
||||||
|
|
|
@ -366,6 +366,13 @@ class BaseExpression:
|
||||||
if expr:
|
if expr:
|
||||||
yield from expr.flatten()
|
yield from expr.flatten()
|
||||||
|
|
||||||
|
def select_format(self, compiler, sql, params):
|
||||||
|
"""
|
||||||
|
Custom format for select clauses. For example, EXISTS expressions need
|
||||||
|
to be wrapped in CASE WHEN on Oracle.
|
||||||
|
"""
|
||||||
|
return self.output_field.select_format(compiler, sql, params)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def identity(self):
|
def identity(self):
|
||||||
constructor_signature = inspect.signature(self.__init__)
|
constructor_signature = inspect.signature(self.__init__)
|
||||||
|
|
|
@ -17,8 +17,6 @@ from django.db.utils import DatabaseError, NotSupportedError
|
||||||
from django.utils.deprecation import RemovedInDjango31Warning
|
from django.utils.deprecation import RemovedInDjango31Warning
|
||||||
from django.utils.hashable import make_hashable
|
from django.utils.hashable import make_hashable
|
||||||
|
|
||||||
FORCE = object()
|
|
||||||
|
|
||||||
|
|
||||||
class SQLCompiler:
|
class SQLCompiler:
|
||||||
def __init__(self, query, connection, using):
|
def __init__(self, query, connection, using):
|
||||||
|
@ -244,10 +242,12 @@ class SQLCompiler:
|
||||||
ret = []
|
ret = []
|
||||||
for col, alias in select:
|
for col, alias in select:
|
||||||
try:
|
try:
|
||||||
sql, params = self.compile(col, select_format=True)
|
sql, params = self.compile(col)
|
||||||
except EmptyResultSet:
|
except EmptyResultSet:
|
||||||
# Select a predicate that's always False.
|
# Select a predicate that's always False.
|
||||||
sql, params = '0', ()
|
sql, params = '0', ()
|
||||||
|
else:
|
||||||
|
sql, params = col.select_format(self, sql, params)
|
||||||
ret.append((col, (sql, params), alias))
|
ret.append((col, (sql, params), alias))
|
||||||
return ret, klass_info, annotations
|
return ret, klass_info, annotations
|
||||||
|
|
||||||
|
@ -402,14 +402,12 @@ class SQLCompiler:
|
||||||
self.quote_cache[name] = r
|
self.quote_cache[name] = r
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def compile(self, node, select_format=False):
|
def compile(self, node):
|
||||||
vendor_impl = getattr(node, 'as_' + self.connection.vendor, None)
|
vendor_impl = getattr(node, 'as_' + self.connection.vendor, None)
|
||||||
if vendor_impl:
|
if vendor_impl:
|
||||||
sql, params = vendor_impl(self, self.connection)
|
sql, params = vendor_impl(self, self.connection)
|
||||||
else:
|
else:
|
||||||
sql, params = node.as_sql(self, self.connection)
|
sql, params = node.as_sql(self, self.connection)
|
||||||
if select_format is FORCE or (select_format and not self.query.subquery):
|
|
||||||
return node.output_field.select_format(self, sql, params)
|
|
||||||
return sql, params
|
return sql, params
|
||||||
|
|
||||||
def get_combinator_sql(self, combinator, all):
|
def get_combinator_sql(self, combinator, all):
|
||||||
|
@ -1503,7 +1501,8 @@ class SQLAggregateCompiler(SQLCompiler):
|
||||||
"""
|
"""
|
||||||
sql, params = [], []
|
sql, params = [], []
|
||||||
for annotation in self.query.annotation_select.values():
|
for annotation in self.query.annotation_select.values():
|
||||||
ann_sql, ann_params = self.compile(annotation, select_format=FORCE)
|
ann_sql, ann_params = self.compile(annotation)
|
||||||
|
ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params)
|
||||||
sql.append(ann_sql)
|
sql.append(ann_sql)
|
||||||
params.extend(ann_params)
|
params.extend(ann_params)
|
||||||
self.col_count = len(self.query.annotation_select)
|
self.col_count = len(self.query.annotation_select)
|
||||||
|
|
Loading…
Reference in New Issue