From fff5186d3215e0ba06e47090226169f2230786b0 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Fri, 9 Aug 2019 14:30:23 -0400 Subject: [PATCH] Refs #25367 -- Moved select_format hook to BaseExpression. This will expose an intermediary hook for expressions that need special formatting when used in a SELECT clause. --- django/contrib/gis/db/models/fields.py | 4 +++- django/db/models/expressions.py | 7 +++++++ django/db/models/sql/compiler.py | 13 ++++++------- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index f73e26be5e..08186d8933 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -272,7 +272,9 @@ class GeometryField(BaseSpatialField): of the spatial backend. For example, Oracle and MySQL require custom selection formats in order to retrieve geometries in OGC WKB. """ - return compiler.connection.ops.select % sql, params + if not compiler.query.subquery: + return compiler.connection.ops.select % sql, params + return sql, params # The OpenGIS Geometry Type Fields diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 9d14caf4cb..16924be9f6 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -366,6 +366,13 @@ class BaseExpression: if expr: yield from expr.flatten() + def select_format(self, compiler, sql, params): + """ + Custom format for select clauses. For example, EXISTS expressions need + to be wrapped in CASE WHEN on Oracle. + """ + return self.output_field.select_format(compiler, sql, params) + @cached_property def identity(self): constructor_signature = inspect.signature(self.__init__) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 52ea717ca6..77e023b92f 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -17,8 +17,6 @@ from django.db.utils import DatabaseError, NotSupportedError from django.utils.deprecation import RemovedInDjango31Warning from django.utils.hashable import make_hashable -FORCE = object() - class SQLCompiler: def __init__(self, query, connection, using): @@ -244,10 +242,12 @@ class SQLCompiler: ret = [] for col, alias in select: try: - sql, params = self.compile(col, select_format=True) + sql, params = self.compile(col) except EmptyResultSet: # Select a predicate that's always False. sql, params = '0', () + else: + sql, params = col.select_format(self, sql, params) ret.append((col, (sql, params), alias)) return ret, klass_info, annotations @@ -402,14 +402,12 @@ class SQLCompiler: self.quote_cache[name] = r return r - def compile(self, node, select_format=False): + def compile(self, node): vendor_impl = getattr(node, 'as_' + self.connection.vendor, None) if vendor_impl: sql, params = vendor_impl(self, self.connection) else: sql, params = node.as_sql(self, self.connection) - if select_format is FORCE or (select_format and not self.query.subquery): - return node.output_field.select_format(self, sql, params) return sql, params def get_combinator_sql(self, combinator, all): @@ -1503,7 +1501,8 @@ class SQLAggregateCompiler(SQLCompiler): """ sql, params = [], [] for annotation in self.query.annotation_select.values(): - ann_sql, ann_params = self.compile(annotation, select_format=FORCE) + ann_sql, ann_params = self.compile(annotation) + ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params) sql.append(ann_sql) params.extend(ann_params) self.col_count = len(self.query.annotation_select)