From 924a144ef8a80ba4daeeafbe9efaa826566e9d02 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Feb 2013 14:47:44 +0100 Subject: [PATCH] Added support for parameters in SELECT clauses. --- .../gis/db/backends/mysql/operations.py | 7 ++- .../gis/db/backends/oracle/operations.py | 4 +- .../gis/db/backends/postgis/operations.py | 2 +- .../gis/db/backends/spatialite/operations.py | 2 +- django/contrib/gis/db/backends/util.py | 2 +- .../contrib/gis/db/models/sql/aggregates.py | 12 ++-- django/contrib/gis/db/models/sql/compiler.py | 23 ++++---- django/contrib/gis/db/models/sql/where.py | 5 +- django/db/models/query_utils.py | 2 +- django/db/models/sql/aggregates.py | 11 ++-- django/db/models/sql/compiler.py | 57 +++++++++++-------- django/db/models/sql/datastructures.py | 2 +- django/db/models/sql/expressions.py | 4 +- django/db/models/sql/where.py | 10 ++-- 14 files changed, 79 insertions(+), 64 deletions(-) diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index fa20ca07f4..14402ec0a3 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -56,12 +56,13 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations): lookup_info = self.geometry_functions.get(lookup_type, False) if lookup_info: - return "%s(%s, %s)" % (lookup_info, geo_col, - self.get_geom_placeholder(value, field.srid)) + sql = "%s(%s, %s)" % (lookup_info, geo_col, + self.get_geom_placeholder(value, field.srid)) + return sql, [] # TODO: Is this really necessary? MySQL can't handle NULL geometries # in its spatial indexes anyways. if lookup_type == 'isnull': - return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) + return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), [] raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index 4e42b4cf00..18697ac8c0 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -262,7 +262,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): return lookup_info.as_sql(geo_col, self.get_geom_placeholder(field, value)) elif lookup_type == 'isnull': # Handling 'isnull' lookup type - return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) + return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), [] raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) @@ -288,7 +288,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): def spatial_ref_sys(self): from django.contrib.gis.db.backends.oracle.models import SpatialRefSys return SpatialRefSys - + def modify_insert_params(self, placeholders, params): """Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial backend due to #10888 diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index aa23b974db..fe90343411 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -560,7 +560,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): elif lookup_type == 'isnull': # Handling 'isnull' lookup type - return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) + return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), [] raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 773ac0b57d..d2d75c1fff 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -358,7 +358,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): return op.as_sql(geo_col, self.get_geom_placeholder(field, geom)) elif lookup_type == 'isnull': # Handling 'isnull' lookup type - return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or '')) + return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), [] raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type)) diff --git a/django/contrib/gis/db/backends/util.py b/django/contrib/gis/db/backends/util.py index 2fc9123d26..2612810659 100644 --- a/django/contrib/gis/db/backends/util.py +++ b/django/contrib/gis/db/backends/util.py @@ -16,7 +16,7 @@ class SpatialOperation(object): self.extra = kwargs def as_sql(self, geo_col, geometry='%s'): - return self.sql_template % self.params(geo_col, geometry) + return self.sql_template % self.params(geo_col, geometry), [] def params(self, geo_col, geometry): params = {'function' : self.function, diff --git a/django/contrib/gis/db/models/sql/aggregates.py b/django/contrib/gis/db/models/sql/aggregates.py index 9fcbb516d6..ae848c0894 100644 --- a/django/contrib/gis/db/models/sql/aggregates.py +++ b/django/contrib/gis/db/models/sql/aggregates.py @@ -22,13 +22,15 @@ class GeoAggregate(Aggregate): raise ValueError('Geospatial aggregates only allowed on geometry fields.') def as_sql(self, qn, connection): - "Return the aggregate, rendered as SQL." + "Return the aggregate, rendered as SQL with parameters." if connection.ops.oracle: self.extra['tolerance'] = self.tolerance + params = [] + if hasattr(self.col, 'as_sql'): - field_name = self.col.as_sql(qn, connection) + field_name, params = self.col.as_sql(qn, connection) elif isinstance(self.col, (list, tuple)): field_name = '.'.join([qn(c) for c in self.col]) else: @@ -36,13 +38,13 @@ class GeoAggregate(Aggregate): sql_template, sql_function = connection.ops.spatial_aggregate_sql(self) - params = { + substitutions = { 'function': sql_function, 'field': field_name } - params.update(self.extra) + substitutions.update(self.extra) - return sql_template % params + return sql_template % substitutions, params class Collect(GeoAggregate): pass diff --git a/django/contrib/gis/db/models/sql/compiler.py b/django/contrib/gis/db/models/sql/compiler.py index 5e4a504c72..b488f59362 100644 --- a/django/contrib/gis/db/models/sql/compiler.py +++ b/django/contrib/gis/db/models/sql/compiler.py @@ -33,6 +33,7 @@ class GeoSQLCompiler(compiler.SQLCompiler): qn2 = self.connection.ops.quote_name result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] + params = [] aliases = set(self.query.extra_select.keys()) if with_aliases: col_aliases = aliases.copy() @@ -63,7 +64,9 @@ class GeoSQLCompiler(compiler.SQLCompiler): aliases.add(r) col_aliases.add(col[1]) else: - result.append(col.as_sql(qn, self.connection)) + col_sql, col_params = col.as_sql(qn, self.connection) + result.append(col_sql) + params.extend(col_params) if hasattr(col, 'alias'): aliases.add(col.alias) @@ -76,15 +79,13 @@ class GeoSQLCompiler(compiler.SQLCompiler): aliases.update(new_aliases) max_name_length = self.connection.ops.max_name_length() - result.extend([ - '%s%s' % ( - self.get_extra_select_format(alias) % aggregate.as_sql(qn, self.connection), - alias is not None - and ' AS %s' % qn(truncate_name(alias, max_name_length)) - or '' - ) - for alias, aggregate in self.query.aggregate_select.items() - ]) + for alias, aggregate in self.query.aggregate_select.items(): + agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + if alias is None: + result.append(agg_sql) + else: + result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length)))) + params.extend(agg_params) # This loop customized for GeoQuery. for (table, col), field in self.query.related_select_cols: @@ -100,7 +101,7 @@ class GeoSQLCompiler(compiler.SQLCompiler): col_aliases.add(col) self._select_aliases = aliases - return result + return result, params def get_default_columns(self, with_aliases=False, col_aliases=None, start_alias=None, opts=None, as_pairs=False, from_parent=None): diff --git a/django/contrib/gis/db/models/sql/where.py b/django/contrib/gis/db/models/sql/where.py index ec078aebed..6ef34db0a3 100644 --- a/django/contrib/gis/db/models/sql/where.py +++ b/django/contrib/gis/db/models/sql/where.py @@ -44,8 +44,9 @@ class GeoWhereNode(WhereNode): lvalue, lookup_type, value_annot, params_or_value = child if isinstance(lvalue, GeoConstraint): data, params = lvalue.process(lookup_type, params_or_value, connection) - spatial_sql = connection.ops.spatial_lookup_sql(data, lookup_type, params_or_value, lvalue.field, qn) - return spatial_sql, params + spatial_sql, spatial_params = connection.ops.spatial_lookup_sql( + data, lookup_type, params_or_value, lvalue.field, qn) + return spatial_sql, spatial_params + params else: return super(GeoWhereNode, self).make_atom(child, qn, connection) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index c1a690a524..c82cc45617 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -25,7 +25,7 @@ class QueryWrapper(object): parameters. Can be used to pass opaque data to a where-clause, for example. """ def __init__(self, sql, params): - self.data = sql, params + self.data = sql, list(params) def as_sql(self, qn=None, connection=None): return self.data diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 75a330f22a..3c8720210b 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -73,22 +73,23 @@ class Aggregate(object): self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) def as_sql(self, qn, connection): - "Return the aggregate, rendered as SQL." + "Return the aggregate, rendered as SQL with parameters." + params = [] if hasattr(self.col, 'as_sql'): - field_name = self.col.as_sql(qn, connection) + field_name, params = self.col.as_sql(qn, connection) elif isinstance(self.col, (list, tuple)): field_name = '.'.join([qn(c) for c in self.col]) else: field_name = self.col - params = { + substitutions = { 'function': self.sql_function, 'field': field_name } - params.update(self.extra) + substitutions.update(self.extra) - return self.sql_template % params + return self.sql_template % substitutions, params class Avg(Aggregate): diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 306b40e801..1f53810755 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -74,7 +74,7 @@ class SQLCompiler(object): # as the pre_sql_setup will modify query state in a way that forbids # another run of it. self.refcounts_before = self.query.alias_refcount.copy() - out_cols = self.get_columns(with_col_aliases) + out_cols, s_params = self.get_columns(with_col_aliases) ordering, ordering_group_by = self.get_ordering() distinct_fields = self.get_distinct() @@ -97,6 +97,7 @@ class SQLCompiler(object): result.append(self.connection.ops.distinct_sql(distinct_fields)) result.append(', '.join(out_cols + self.query.ordering_aliases)) + params.extend(s_params) result.append('FROM') result.extend(from_) @@ -164,9 +165,10 @@ class SQLCompiler(object): def get_columns(self, with_aliases=False): """ - Returns the list of columns to use in the select statement. If no - columns have been specified, returns all columns relating to fields in - the model. + Returns the list of columns to use in the select statement, as well as + a list any extra parameters that need to be included. If no columns + have been specified, returns all columns relating to fields in the + model. If 'with_aliases' is true, any column names that are duplicated (without the table names) are given unique aliases. This is needed in @@ -175,6 +177,7 @@ class SQLCompiler(object): qn = self.quote_name_unless_alias qn2 = self.connection.ops.quote_name result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] + params = [] aliases = set(self.query.extra_select.keys()) if with_aliases: col_aliases = aliases.copy() @@ -204,7 +207,9 @@ class SQLCompiler(object): aliases.add(r) col_aliases.add(col[1]) else: - result.append(col.as_sql(qn, self.connection)) + col_sql, col_params = col.as_sql(qn, self.connection) + result.append(col_sql) + params.extend(col_params) if hasattr(col, 'alias'): aliases.add(col.alias) @@ -217,15 +222,13 @@ class SQLCompiler(object): aliases.update(new_aliases) max_name_length = self.connection.ops.max_name_length() - result.extend([ - '%s%s' % ( - aggregate.as_sql(qn, self.connection), - alias is not None - and ' AS %s' % qn(truncate_name(alias, max_name_length)) - or '' - ) - for alias, aggregate in self.query.aggregate_select.items() - ]) + for alias, aggregate in self.query.aggregate_select.items(): + agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + if alias is None: + result.append(agg_sql) + else: + result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length)))) + params.extend(agg_params) for (table, col), _ in self.query.related_select_cols: r = '%s.%s' % (qn(table), qn(col)) @@ -240,7 +243,7 @@ class SQLCompiler(object): col_aliases.add(col) self._select_aliases = aliases - return result + return result, params def get_default_columns(self, with_aliases=False, col_aliases=None, start_alias=None, opts=None, as_pairs=False, from_parent=None): @@ -545,14 +548,16 @@ class SQLCompiler(object): seen = set() cols = self.query.group_by + select_cols for col in cols: + col_params = () if isinstance(col, (list, tuple)): sql = '%s.%s' % (qn(col[0]), qn(col[1])) elif hasattr(col, 'as_sql'): - sql = col.as_sql(qn, self.connection) + sql, col_params = col.as_sql(qn, self.connection) else: sql = '(%s)' % str(col) if sql not in seen: result.append(sql) + params.extend(col_params) seen.add(sql) # Still, we need to add all stuff in ordering (except if the backend can @@ -991,15 +996,17 @@ class SQLAggregateCompiler(SQLCompiler): if qn is None: qn = self.quote_name_unless_alias - sql = ('SELECT %s FROM (%s) subquery' % ( - ', '.join([ - aggregate.as_sql(qn, self.connection) - for aggregate in self.query.aggregate_select.values() - ]), - self.query.subquery) - ) - params = self.query.sub_params - return (sql, params) + sql, params = [], [] + for aggregate in self.query.aggregate_select.values(): + agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + sql.append(agg_sql) + params.extend(agg_params) + sql = ', '.join(sql) + params = tuple(params) + + sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery) + params = params + self.query.sub_params + return sql, params class SQLDateCompiler(SQLCompiler): def results_iter(self): diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 96e1ddcdfb..51d4410f59 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -42,7 +42,7 @@ class Date(object): col = '%s.%s' % tuple([qn(c) for c in self.col]) else: col = self.col - return getattr(connection.ops, self.trunc_func)(self.lookup_type, col) + return getattr(connection.ops, self.trunc_func)(self.lookup_type, col), [] class DateTime(Date): """ diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index a4c1d85c65..2a5008f067 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -94,9 +94,9 @@ class SQLEvaluator(object): if col is None: raise ValueError("Given node not found") if hasattr(col, 'as_sql'): - return col.as_sql(qn, connection), () + return col.as_sql(qn, connection) else: - return '%s.%s' % (qn(col[0]), qn(col[1])), () + return '%s.%s' % (qn(col[0]), qn(col[1])), [] def evaluate_date_modifier_node(self, node, qn, connection): timedelta = node.children.pop() diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 3735ff0e8c..a61bc0b929 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -172,10 +172,10 @@ class WhereNode(tree.Node): if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql = self.sql_for_columns(lvalue, qn, connection) + field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), [] else: # A smart object with an as_sql() method. - field_sql = lvalue.as_sql(qn, connection) + field_sql, field_params = lvalue.as_sql(qn, connection) is_datetime_field = value_annotation is datetime.datetime cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' @@ -186,6 +186,8 @@ class WhereNode(tree.Node): else: extra = '' + params = field_params + params + if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' and connection.features.interprets_empty_strings_as_nulls): lookup_type = 'isnull' @@ -245,7 +247,7 @@ class WhereNode(tree.Node): """ Returns the SQL fragment used for the left-hand side of a column constraint (for example, the "T1.foo" portion in the clause - "WHERE ... T1.foo = 6"). + "WHERE ... T1.foo = 6") and a list of parameters. """ table_alias, name, db_type = data if table_alias: @@ -338,7 +340,7 @@ class ExtraWhere(object): def as_sql(self, qn=None, connection=None): sqls = ["(%s)" % sql for sql in self.sqls] - return " AND ".join(sqls), tuple(self.params or ()) + return " AND ".join(sqls), list(self.params or ()) def clone(self): return self