Added support for parameters in SELECT clauses.

This commit is contained in:
Aymeric Augustin 2013-02-13 14:47:44 +01:00
parent b4351d2890
commit 924a144ef8
14 changed files with 79 additions and 64 deletions

View File

@ -56,12 +56,13 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
lookup_info = self.geometry_functions.get(lookup_type, False)
if lookup_info:
return "%s(%s, %s)" % (lookup_info, geo_col,
self.get_geom_placeholder(value, field.srid))
sql = "%s(%s, %s)" % (lookup_info, geo_col,
self.get_geom_placeholder(value, field.srid))
return sql, []
# TODO: Is this really necessary? MySQL can't handle NULL geometries
# in its spatial indexes anyways.
if lookup_type == 'isnull':
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))

View File

@ -262,7 +262,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
return lookup_info.as_sql(geo_col, self.get_geom_placeholder(field, value))
elif lookup_type == 'isnull':
# Handling 'isnull' lookup type
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
@ -288,7 +288,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
def spatial_ref_sys(self):
from django.contrib.gis.db.backends.oracle.models import SpatialRefSys
return SpatialRefSys
def modify_insert_params(self, placeholders, params):
"""Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial
backend due to #10888

View File

@ -560,7 +560,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
elif lookup_type == 'isnull':
# Handling 'isnull' lookup type
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))

View File

@ -358,7 +358,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
return op.as_sql(geo_col, self.get_geom_placeholder(field, geom))
elif lookup_type == 'isnull':
# Handling 'isnull' lookup type
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))

View File

@ -16,7 +16,7 @@ class SpatialOperation(object):
self.extra = kwargs
def as_sql(self, geo_col, geometry='%s'):
return self.sql_template % self.params(geo_col, geometry)
return self.sql_template % self.params(geo_col, geometry), []
def params(self, geo_col, geometry):
params = {'function' : self.function,

View File

@ -22,13 +22,15 @@ class GeoAggregate(Aggregate):
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
def as_sql(self, qn, connection):
"Return the aggregate, rendered as SQL."
"Return the aggregate, rendered as SQL with parameters."
if connection.ops.oracle:
self.extra['tolerance'] = self.tolerance
params = []
if hasattr(self.col, 'as_sql'):
field_name = self.col.as_sql(qn, connection)
field_name, params = self.col.as_sql(qn, connection)
elif isinstance(self.col, (list, tuple)):
field_name = '.'.join([qn(c) for c in self.col])
else:
@ -36,13 +38,13 @@ class GeoAggregate(Aggregate):
sql_template, sql_function = connection.ops.spatial_aggregate_sql(self)
params = {
substitutions = {
'function': sql_function,
'field': field_name
}
params.update(self.extra)
substitutions.update(self.extra)
return sql_template % params
return sql_template % substitutions, params
class Collect(GeoAggregate):
pass

View File

@ -33,6 +33,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
for alias, col in six.iteritems(self.query.extra_select)]
params = []
aliases = set(self.query.extra_select.keys())
if with_aliases:
col_aliases = aliases.copy()
@ -63,7 +64,9 @@ class GeoSQLCompiler(compiler.SQLCompiler):
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(col.as_sql(qn, self.connection))
col_sql, col_params = col.as_sql(qn, self.connection)
result.append(col_sql)
params.extend(col_params)
if hasattr(col, 'alias'):
aliases.add(col.alias)
@ -76,15 +79,13 @@ class GeoSQLCompiler(compiler.SQLCompiler):
aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length()
result.extend([
'%s%s' % (
self.get_extra_select_format(alias) % aggregate.as_sql(qn, self.connection),
alias is not None
and ' AS %s' % qn(truncate_name(alias, max_name_length))
or ''
)
for alias, aggregate in self.query.aggregate_select.items()
])
for alias, aggregate in self.query.aggregate_select.items():
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
if alias is None:
result.append(agg_sql)
else:
result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
params.extend(agg_params)
# This loop customized for GeoQuery.
for (table, col), field in self.query.related_select_cols:
@ -100,7 +101,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
col_aliases.add(col)
self._select_aliases = aliases
return result
return result, params
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, from_parent=None):

View File

@ -44,8 +44,9 @@ class GeoWhereNode(WhereNode):
lvalue, lookup_type, value_annot, params_or_value = child
if isinstance(lvalue, GeoConstraint):
data, params = lvalue.process(lookup_type, params_or_value, connection)
spatial_sql = connection.ops.spatial_lookup_sql(data, lookup_type, params_or_value, lvalue.field, qn)
return spatial_sql, params
spatial_sql, spatial_params = connection.ops.spatial_lookup_sql(
data, lookup_type, params_or_value, lvalue.field, qn)
return spatial_sql, spatial_params + params
else:
return super(GeoWhereNode, self).make_atom(child, qn, connection)

View File

@ -25,7 +25,7 @@ class QueryWrapper(object):
parameters. Can be used to pass opaque data to a where-clause, for example.
"""
def __init__(self, sql, params):
self.data = sql, params
self.data = sql, list(params)
def as_sql(self, qn=None, connection=None):
return self.data

View File

@ -73,22 +73,23 @@ class Aggregate(object):
self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
def as_sql(self, qn, connection):
"Return the aggregate, rendered as SQL."
"Return the aggregate, rendered as SQL with parameters."
params = []
if hasattr(self.col, 'as_sql'):
field_name = self.col.as_sql(qn, connection)
field_name, params = self.col.as_sql(qn, connection)
elif isinstance(self.col, (list, tuple)):
field_name = '.'.join([qn(c) for c in self.col])
else:
field_name = self.col
params = {
substitutions = {
'function': self.sql_function,
'field': field_name
}
params.update(self.extra)
substitutions.update(self.extra)
return self.sql_template % params
return self.sql_template % substitutions, params
class Avg(Aggregate):

View File

@ -74,7 +74,7 @@ class SQLCompiler(object):
# as the pre_sql_setup will modify query state in a way that forbids
# another run of it.
self.refcounts_before = self.query.alias_refcount.copy()
out_cols = self.get_columns(with_col_aliases)
out_cols, s_params = self.get_columns(with_col_aliases)
ordering, ordering_group_by = self.get_ordering()
distinct_fields = self.get_distinct()
@ -97,6 +97,7 @@ class SQLCompiler(object):
result.append(self.connection.ops.distinct_sql(distinct_fields))
result.append(', '.join(out_cols + self.query.ordering_aliases))
params.extend(s_params)
result.append('FROM')
result.extend(from_)
@ -164,9 +165,10 @@ class SQLCompiler(object):
def get_columns(self, with_aliases=False):
"""
Returns the list of columns to use in the select statement. If no
columns have been specified, returns all columns relating to fields in
the model.
Returns the list of columns to use in the select statement, as well as
a list any extra parameters that need to be included. If no columns
have been specified, returns all columns relating to fields in the
model.
If 'with_aliases' is true, any column names that are duplicated
(without the table names) are given unique aliases. This is needed in
@ -175,6 +177,7 @@ class SQLCompiler(object):
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
params = []
aliases = set(self.query.extra_select.keys())
if with_aliases:
col_aliases = aliases.copy()
@ -204,7 +207,9 @@ class SQLCompiler(object):
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(col.as_sql(qn, self.connection))
col_sql, col_params = col.as_sql(qn, self.connection)
result.append(col_sql)
params.extend(col_params)
if hasattr(col, 'alias'):
aliases.add(col.alias)
@ -217,15 +222,13 @@ class SQLCompiler(object):
aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length()
result.extend([
'%s%s' % (
aggregate.as_sql(qn, self.connection),
alias is not None
and ' AS %s' % qn(truncate_name(alias, max_name_length))
or ''
)
for alias, aggregate in self.query.aggregate_select.items()
])
for alias, aggregate in self.query.aggregate_select.items():
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
if alias is None:
result.append(agg_sql)
else:
result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
params.extend(agg_params)
for (table, col), _ in self.query.related_select_cols:
r = '%s.%s' % (qn(table), qn(col))
@ -240,7 +243,7 @@ class SQLCompiler(object):
col_aliases.add(col)
self._select_aliases = aliases
return result
return result, params
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, from_parent=None):
@ -545,14 +548,16 @@ class SQLCompiler(object):
seen = set()
cols = self.query.group_by + select_cols
for col in cols:
col_params = ()
if isinstance(col, (list, tuple)):
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
elif hasattr(col, 'as_sql'):
sql = col.as_sql(qn, self.connection)
sql, col_params = col.as_sql(qn, self.connection)
else:
sql = '(%s)' % str(col)
if sql not in seen:
result.append(sql)
params.extend(col_params)
seen.add(sql)
# Still, we need to add all stuff in ordering (except if the backend can
@ -991,15 +996,17 @@ class SQLAggregateCompiler(SQLCompiler):
if qn is None:
qn = self.quote_name_unless_alias
sql = ('SELECT %s FROM (%s) subquery' % (
', '.join([
aggregate.as_sql(qn, self.connection)
for aggregate in self.query.aggregate_select.values()
]),
self.query.subquery)
)
params = self.query.sub_params
return (sql, params)
sql, params = [], []
for aggregate in self.query.aggregate_select.values():
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
sql.append(agg_sql)
params.extend(agg_params)
sql = ', '.join(sql)
params = tuple(params)
sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
params = params + self.query.sub_params
return sql, params
class SQLDateCompiler(SQLCompiler):
def results_iter(self):

View File

@ -42,7 +42,7 @@ class Date(object):
col = '%s.%s' % tuple([qn(c) for c in self.col])
else:
col = self.col
return getattr(connection.ops, self.trunc_func)(self.lookup_type, col)
return getattr(connection.ops, self.trunc_func)(self.lookup_type, col), []
class DateTime(Date):
"""

View File

@ -94,9 +94,9 @@ class SQLEvaluator(object):
if col is None:
raise ValueError("Given node not found")
if hasattr(col, 'as_sql'):
return col.as_sql(qn, connection), ()
return col.as_sql(qn, connection)
else:
return '%s.%s' % (qn(col[0]), qn(col[1])), ()
return '%s.%s' % (qn(col[0]), qn(col[1])), []
def evaluate_date_modifier_node(self, node, qn, connection):
timedelta = node.children.pop()

View File

@ -172,10 +172,10 @@ class WhereNode(tree.Node):
if isinstance(lvalue, tuple):
# A direct database column lookup.
field_sql = self.sql_for_columns(lvalue, qn, connection)
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), []
else:
# A smart object with an as_sql() method.
field_sql = lvalue.as_sql(qn, connection)
field_sql, field_params = lvalue.as_sql(qn, connection)
is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
@ -186,6 +186,8 @@ class WhereNode(tree.Node):
else:
extra = ''
params = field_params + params
if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
and connection.features.interprets_empty_strings_as_nulls):
lookup_type = 'isnull'
@ -245,7 +247,7 @@ class WhereNode(tree.Node):
"""
Returns the SQL fragment used for the left-hand side of a column
constraint (for example, the "T1.foo" portion in the clause
"WHERE ... T1.foo = 6").
"WHERE ... T1.foo = 6") and a list of parameters.
"""
table_alias, name, db_type = data
if table_alias:
@ -338,7 +340,7 @@ class ExtraWhere(object):
def as_sql(self, qn=None, connection=None):
sqls = ["(%s)" % sql for sql in self.sqls]
return " AND ".join(sqls), tuple(self.params or ())
return " AND ".join(sqls), list(self.params or ())
def clone(self):
return self