Added support for parameters in SELECT clauses.
This commit is contained in:
parent
b4351d2890
commit
924a144ef8
|
@ -56,12 +56,13 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
|
||||
lookup_info = self.geometry_functions.get(lookup_type, False)
|
||||
if lookup_info:
|
||||
return "%s(%s, %s)" % (lookup_info, geo_col,
|
||||
sql = "%s(%s, %s)" % (lookup_info, geo_col,
|
||||
self.get_geom_placeholder(value, field.srid))
|
||||
return sql, []
|
||||
|
||||
# TODO: Is this really necessary? MySQL can't handle NULL geometries
|
||||
# in its spatial indexes anyways.
|
||||
if lookup_type == 'isnull':
|
||||
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
|
||||
return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
|
||||
|
||||
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
|
||||
|
|
|
@ -262,7 +262,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
return lookup_info.as_sql(geo_col, self.get_geom_placeholder(field, value))
|
||||
elif lookup_type == 'isnull':
|
||||
# Handling 'isnull' lookup type
|
||||
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
|
||||
return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
|
||||
|
||||
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
|
||||
|
||||
|
|
|
@ -560,7 +560,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
|
||||
elif lookup_type == 'isnull':
|
||||
# Handling 'isnull' lookup type
|
||||
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
|
||||
return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
|
||||
|
||||
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
|
||||
|
||||
|
|
|
@ -358,7 +358,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
return op.as_sql(geo_col, self.get_geom_placeholder(field, geom))
|
||||
elif lookup_type == 'isnull':
|
||||
# Handling 'isnull' lookup type
|
||||
return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
|
||||
return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
|
||||
|
||||
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ class SpatialOperation(object):
|
|||
self.extra = kwargs
|
||||
|
||||
def as_sql(self, geo_col, geometry='%s'):
|
||||
return self.sql_template % self.params(geo_col, geometry)
|
||||
return self.sql_template % self.params(geo_col, geometry), []
|
||||
|
||||
def params(self, geo_col, geometry):
|
||||
params = {'function' : self.function,
|
||||
|
|
|
@ -22,13 +22,15 @@ class GeoAggregate(Aggregate):
|
|||
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
"Return the aggregate, rendered as SQL."
|
||||
"Return the aggregate, rendered as SQL with parameters."
|
||||
|
||||
if connection.ops.oracle:
|
||||
self.extra['tolerance'] = self.tolerance
|
||||
|
||||
params = []
|
||||
|
||||
if hasattr(self.col, 'as_sql'):
|
||||
field_name = self.col.as_sql(qn, connection)
|
||||
field_name, params = self.col.as_sql(qn, connection)
|
||||
elif isinstance(self.col, (list, tuple)):
|
||||
field_name = '.'.join([qn(c) for c in self.col])
|
||||
else:
|
||||
|
@ -36,13 +38,13 @@ class GeoAggregate(Aggregate):
|
|||
|
||||
sql_template, sql_function = connection.ops.spatial_aggregate_sql(self)
|
||||
|
||||
params = {
|
||||
substitutions = {
|
||||
'function': sql_function,
|
||||
'field': field_name
|
||||
}
|
||||
params.update(self.extra)
|
||||
substitutions.update(self.extra)
|
||||
|
||||
return sql_template % params
|
||||
return sql_template % substitutions, params
|
||||
|
||||
class Collect(GeoAggregate):
|
||||
pass
|
||||
|
|
|
@ -33,6 +33,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
|
|||
qn2 = self.connection.ops.quote_name
|
||||
result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
|
||||
for alias, col in six.iteritems(self.query.extra_select)]
|
||||
params = []
|
||||
aliases = set(self.query.extra_select.keys())
|
||||
if with_aliases:
|
||||
col_aliases = aliases.copy()
|
||||
|
@ -63,7 +64,9 @@ class GeoSQLCompiler(compiler.SQLCompiler):
|
|||
aliases.add(r)
|
||||
col_aliases.add(col[1])
|
||||
else:
|
||||
result.append(col.as_sql(qn, self.connection))
|
||||
col_sql, col_params = col.as_sql(qn, self.connection)
|
||||
result.append(col_sql)
|
||||
params.extend(col_params)
|
||||
|
||||
if hasattr(col, 'alias'):
|
||||
aliases.add(col.alias)
|
||||
|
@ -76,15 +79,13 @@ class GeoSQLCompiler(compiler.SQLCompiler):
|
|||
aliases.update(new_aliases)
|
||||
|
||||
max_name_length = self.connection.ops.max_name_length()
|
||||
result.extend([
|
||||
'%s%s' % (
|
||||
self.get_extra_select_format(alias) % aggregate.as_sql(qn, self.connection),
|
||||
alias is not None
|
||||
and ' AS %s' % qn(truncate_name(alias, max_name_length))
|
||||
or ''
|
||||
)
|
||||
for alias, aggregate in self.query.aggregate_select.items()
|
||||
])
|
||||
for alias, aggregate in self.query.aggregate_select.items():
|
||||
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
|
||||
if alias is None:
|
||||
result.append(agg_sql)
|
||||
else:
|
||||
result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
|
||||
params.extend(agg_params)
|
||||
|
||||
# This loop customized for GeoQuery.
|
||||
for (table, col), field in self.query.related_select_cols:
|
||||
|
@ -100,7 +101,7 @@ class GeoSQLCompiler(compiler.SQLCompiler):
|
|||
col_aliases.add(col)
|
||||
|
||||
self._select_aliases = aliases
|
||||
return result
|
||||
return result, params
|
||||
|
||||
def get_default_columns(self, with_aliases=False, col_aliases=None,
|
||||
start_alias=None, opts=None, as_pairs=False, from_parent=None):
|
||||
|
|
|
@ -44,8 +44,9 @@ class GeoWhereNode(WhereNode):
|
|||
lvalue, lookup_type, value_annot, params_or_value = child
|
||||
if isinstance(lvalue, GeoConstraint):
|
||||
data, params = lvalue.process(lookup_type, params_or_value, connection)
|
||||
spatial_sql = connection.ops.spatial_lookup_sql(data, lookup_type, params_or_value, lvalue.field, qn)
|
||||
return spatial_sql, params
|
||||
spatial_sql, spatial_params = connection.ops.spatial_lookup_sql(
|
||||
data, lookup_type, params_or_value, lvalue.field, qn)
|
||||
return spatial_sql, spatial_params + params
|
||||
else:
|
||||
return super(GeoWhereNode, self).make_atom(child, qn, connection)
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ class QueryWrapper(object):
|
|||
parameters. Can be used to pass opaque data to a where-clause, for example.
|
||||
"""
|
||||
def __init__(self, sql, params):
|
||||
self.data = sql, params
|
||||
self.data = sql, list(params)
|
||||
|
||||
def as_sql(self, qn=None, connection=None):
|
||||
return self.data
|
||||
|
|
|
@ -73,22 +73,23 @@ class Aggregate(object):
|
|||
self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
|
||||
|
||||
def as_sql(self, qn, connection):
|
||||
"Return the aggregate, rendered as SQL."
|
||||
"Return the aggregate, rendered as SQL with parameters."
|
||||
params = []
|
||||
|
||||
if hasattr(self.col, 'as_sql'):
|
||||
field_name = self.col.as_sql(qn, connection)
|
||||
field_name, params = self.col.as_sql(qn, connection)
|
||||
elif isinstance(self.col, (list, tuple)):
|
||||
field_name = '.'.join([qn(c) for c in self.col])
|
||||
else:
|
||||
field_name = self.col
|
||||
|
||||
params = {
|
||||
substitutions = {
|
||||
'function': self.sql_function,
|
||||
'field': field_name
|
||||
}
|
||||
params.update(self.extra)
|
||||
substitutions.update(self.extra)
|
||||
|
||||
return self.sql_template % params
|
||||
return self.sql_template % substitutions, params
|
||||
|
||||
|
||||
class Avg(Aggregate):
|
||||
|
|
|
@ -74,7 +74,7 @@ class SQLCompiler(object):
|
|||
# as the pre_sql_setup will modify query state in a way that forbids
|
||||
# another run of it.
|
||||
self.refcounts_before = self.query.alias_refcount.copy()
|
||||
out_cols = self.get_columns(with_col_aliases)
|
||||
out_cols, s_params = self.get_columns(with_col_aliases)
|
||||
ordering, ordering_group_by = self.get_ordering()
|
||||
|
||||
distinct_fields = self.get_distinct()
|
||||
|
@ -97,6 +97,7 @@ class SQLCompiler(object):
|
|||
result.append(self.connection.ops.distinct_sql(distinct_fields))
|
||||
|
||||
result.append(', '.join(out_cols + self.query.ordering_aliases))
|
||||
params.extend(s_params)
|
||||
|
||||
result.append('FROM')
|
||||
result.extend(from_)
|
||||
|
@ -164,9 +165,10 @@ class SQLCompiler(object):
|
|||
|
||||
def get_columns(self, with_aliases=False):
|
||||
"""
|
||||
Returns the list of columns to use in the select statement. If no
|
||||
columns have been specified, returns all columns relating to fields in
|
||||
the model.
|
||||
Returns the list of columns to use in the select statement, as well as
|
||||
a list any extra parameters that need to be included. If no columns
|
||||
have been specified, returns all columns relating to fields in the
|
||||
model.
|
||||
|
||||
If 'with_aliases' is true, any column names that are duplicated
|
||||
(without the table names) are given unique aliases. This is needed in
|
||||
|
@ -175,6 +177,7 @@ class SQLCompiler(object):
|
|||
qn = self.quote_name_unless_alias
|
||||
qn2 = self.connection.ops.quote_name
|
||||
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
|
||||
params = []
|
||||
aliases = set(self.query.extra_select.keys())
|
||||
if with_aliases:
|
||||
col_aliases = aliases.copy()
|
||||
|
@ -204,7 +207,9 @@ class SQLCompiler(object):
|
|||
aliases.add(r)
|
||||
col_aliases.add(col[1])
|
||||
else:
|
||||
result.append(col.as_sql(qn, self.connection))
|
||||
col_sql, col_params = col.as_sql(qn, self.connection)
|
||||
result.append(col_sql)
|
||||
params.extend(col_params)
|
||||
|
||||
if hasattr(col, 'alias'):
|
||||
aliases.add(col.alias)
|
||||
|
@ -217,15 +222,13 @@ class SQLCompiler(object):
|
|||
aliases.update(new_aliases)
|
||||
|
||||
max_name_length = self.connection.ops.max_name_length()
|
||||
result.extend([
|
||||
'%s%s' % (
|
||||
aggregate.as_sql(qn, self.connection),
|
||||
alias is not None
|
||||
and ' AS %s' % qn(truncate_name(alias, max_name_length))
|
||||
or ''
|
||||
)
|
||||
for alias, aggregate in self.query.aggregate_select.items()
|
||||
])
|
||||
for alias, aggregate in self.query.aggregate_select.items():
|
||||
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
|
||||
if alias is None:
|
||||
result.append(agg_sql)
|
||||
else:
|
||||
result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
|
||||
params.extend(agg_params)
|
||||
|
||||
for (table, col), _ in self.query.related_select_cols:
|
||||
r = '%s.%s' % (qn(table), qn(col))
|
||||
|
@ -240,7 +243,7 @@ class SQLCompiler(object):
|
|||
col_aliases.add(col)
|
||||
|
||||
self._select_aliases = aliases
|
||||
return result
|
||||
return result, params
|
||||
|
||||
def get_default_columns(self, with_aliases=False, col_aliases=None,
|
||||
start_alias=None, opts=None, as_pairs=False, from_parent=None):
|
||||
|
@ -545,14 +548,16 @@ class SQLCompiler(object):
|
|||
seen = set()
|
||||
cols = self.query.group_by + select_cols
|
||||
for col in cols:
|
||||
col_params = ()
|
||||
if isinstance(col, (list, tuple)):
|
||||
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
|
||||
elif hasattr(col, 'as_sql'):
|
||||
sql = col.as_sql(qn, self.connection)
|
||||
sql, col_params = col.as_sql(qn, self.connection)
|
||||
else:
|
||||
sql = '(%s)' % str(col)
|
||||
if sql not in seen:
|
||||
result.append(sql)
|
||||
params.extend(col_params)
|
||||
seen.add(sql)
|
||||
|
||||
# Still, we need to add all stuff in ordering (except if the backend can
|
||||
|
@ -991,15 +996,17 @@ class SQLAggregateCompiler(SQLCompiler):
|
|||
if qn is None:
|
||||
qn = self.quote_name_unless_alias
|
||||
|
||||
sql = ('SELECT %s FROM (%s) subquery' % (
|
||||
', '.join([
|
||||
aggregate.as_sql(qn, self.connection)
|
||||
for aggregate in self.query.aggregate_select.values()
|
||||
]),
|
||||
self.query.subquery)
|
||||
)
|
||||
params = self.query.sub_params
|
||||
return (sql, params)
|
||||
sql, params = [], []
|
||||
for aggregate in self.query.aggregate_select.values():
|
||||
agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
|
||||
sql.append(agg_sql)
|
||||
params.extend(agg_params)
|
||||
sql = ', '.join(sql)
|
||||
params = tuple(params)
|
||||
|
||||
sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
|
||||
params = params + self.query.sub_params
|
||||
return sql, params
|
||||
|
||||
class SQLDateCompiler(SQLCompiler):
|
||||
def results_iter(self):
|
||||
|
|
|
@ -42,7 +42,7 @@ class Date(object):
|
|||
col = '%s.%s' % tuple([qn(c) for c in self.col])
|
||||
else:
|
||||
col = self.col
|
||||
return getattr(connection.ops, self.trunc_func)(self.lookup_type, col)
|
||||
return getattr(connection.ops, self.trunc_func)(self.lookup_type, col), []
|
||||
|
||||
class DateTime(Date):
|
||||
"""
|
||||
|
|
|
@ -94,9 +94,9 @@ class SQLEvaluator(object):
|
|||
if col is None:
|
||||
raise ValueError("Given node not found")
|
||||
if hasattr(col, 'as_sql'):
|
||||
return col.as_sql(qn, connection), ()
|
||||
return col.as_sql(qn, connection)
|
||||
else:
|
||||
return '%s.%s' % (qn(col[0]), qn(col[1])), ()
|
||||
return '%s.%s' % (qn(col[0]), qn(col[1])), []
|
||||
|
||||
def evaluate_date_modifier_node(self, node, qn, connection):
|
||||
timedelta = node.children.pop()
|
||||
|
|
|
@ -172,10 +172,10 @@ class WhereNode(tree.Node):
|
|||
|
||||
if isinstance(lvalue, tuple):
|
||||
# A direct database column lookup.
|
||||
field_sql = self.sql_for_columns(lvalue, qn, connection)
|
||||
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), []
|
||||
else:
|
||||
# A smart object with an as_sql() method.
|
||||
field_sql = lvalue.as_sql(qn, connection)
|
||||
field_sql, field_params = lvalue.as_sql(qn, connection)
|
||||
|
||||
is_datetime_field = value_annotation is datetime.datetime
|
||||
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
|
||||
|
@ -186,6 +186,8 @@ class WhereNode(tree.Node):
|
|||
else:
|
||||
extra = ''
|
||||
|
||||
params = field_params + params
|
||||
|
||||
if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
|
||||
and connection.features.interprets_empty_strings_as_nulls):
|
||||
lookup_type = 'isnull'
|
||||
|
@ -245,7 +247,7 @@ class WhereNode(tree.Node):
|
|||
"""
|
||||
Returns the SQL fragment used for the left-hand side of a column
|
||||
constraint (for example, the "T1.foo" portion in the clause
|
||||
"WHERE ... T1.foo = 6").
|
||||
"WHERE ... T1.foo = 6") and a list of parameters.
|
||||
"""
|
||||
table_alias, name, db_type = data
|
||||
if table_alias:
|
||||
|
@ -338,7 +340,7 @@ class ExtraWhere(object):
|
|||
|
||||
def as_sql(self, qn=None, connection=None):
|
||||
sqls = ["(%s)" % sql for sql in self.sqls]
|
||||
return " AND ".join(sqls), tuple(self.params or ())
|
||||
return " AND ".join(sqls), list(self.params or ())
|
||||
|
||||
def clone(self):
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue