Renamed qn to compiler

This commit is contained in:
Josh Smeaton 2014-11-16 12:56:42 +11:00 committed by Simon Charette
parent 05e0e4674c
commit f61256da3a
23 changed files with 240 additions and 237 deletions

View File

@ -186,7 +186,7 @@ class BaseSpatialOperations(object):
""" """
raise NotImplementedError('Distance operations not available on this spatial backend.') raise NotImplementedError('Distance operations not available on this spatial backend.')
def get_geom_placeholder(self, f, value, qn): def get_geom_placeholder(self, f, value, compiler):
""" """
Returns the placeholder for the given geometry field with the given Returns the placeholder for the given geometry field with the given
value. Depending on the spatial backend, the placeholder may contain a value. Depending on the spatial backend, the placeholder may contain a

View File

@ -35,14 +35,14 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
def geo_db_type(self, f): def geo_db_type(self, f):
return f.geom_type return f.geom_type
def get_geom_placeholder(self, f, value, qn): def get_geom_placeholder(self, f, value, compiler):
""" """
The placeholder here has to include MySQL's WKT constructor. Because The placeholder here has to include MySQL's WKT constructor. Because
MySQL does not support spatial transformations, there is no need to MySQL does not support spatial transformations, there is no need to
modify the placeholder based on the contents of the given value. modify the placeholder based on the contents of the given value.
""" """
if hasattr(value, 'as_sql'): if hasattr(value, 'as_sql'):
placeholder, _ = qn.compile(value) placeholder, _ = compiler.compile(value)
else: else:
placeholder = '%s(%%s)' % self.from_text placeholder = '%s(%%s)' % self.from_text
return placeholder return placeholder

View File

@ -186,7 +186,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
return [dist_param] return [dist_param]
def get_geom_placeholder(self, f, value, qn): def get_geom_placeholder(self, f, value, compiler):
""" """
Provides a proper substitution value for Geometries that are not in the Provides a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the SRID of the field. Specifically, this routine will substitute in the
@ -205,7 +205,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
placeholder = '%s' placeholder = '%s'
# No geometry value used for F expression, substitute in # No geometry value used for F expression, substitute in
# the column name instead. # the column name instead.
sql, _ = qn.compile(value) sql, _ = compiler.compile(value)
return placeholder % sql return placeholder % sql
else: else:
if transform_value(value, f.srid): if transform_value(value, f.srid):

View File

@ -284,7 +284,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
else: else:
return [dist_param] return [dist_param]
def get_geom_placeholder(self, f, value, qn): def get_geom_placeholder(self, f, value, compiler):
""" """
Provides a proper substitution value for Geometries that are not in the Provides a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the SRID of the field. Specifically, this routine will substitute in the
@ -300,7 +300,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
# If this is an F expression, then we don't really want # If this is an F expression, then we don't really want
# a placeholder and instead substitute in the column # a placeholder and instead substitute in the column
# of the expression. # of the expression.
sql, _ = qn.compile(value) sql, _ = compiler.compile(value)
placeholder = placeholder % sql placeholder = placeholder % sql
return placeholder return placeholder

View File

@ -178,7 +178,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
dist_param = value dist_param = value
return [dist_param] return [dist_param]
def get_geom_placeholder(self, f, value, qn): def get_geom_placeholder(self, f, value, compiler):
""" """
Provides a proper substitution value for Geometries that are not in the Provides a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the SRID of the field. Specifically, this routine will substitute in the
@ -193,7 +193,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
placeholder = '%s' placeholder = '%s'
# No geometry value used for F expression, substitute in # No geometry value used for F expression, substitute in
# the column name instead. # the column name instead.
sql, _ = qn.compile(value) sql, _ = compiler.compile(value)
return placeholder % sql return placeholder % sql
else: else:
if transform_value(value, f.srid): if transform_value(value, f.srid):

View File

@ -282,12 +282,12 @@ class GeometryField(Field):
else: else:
return connection.ops.Adapter(self.get_prep_value(value)) return connection.ops.Adapter(self.get_prep_value(value))
def get_placeholder(self, value, qn, connection): def get_placeholder(self, value, compiler, connection):
""" """
Returns the placeholder for the geometry column for the Returns the placeholder for the geometry column for the
given value. given value.
""" """
return connection.ops.get_geom_placeholder(self, value, qn) return connection.ops.get_geom_placeholder(self, value, compiler)
for klass in gis_lookups.values(): for klass in gis_lookups.values():

View File

@ -64,8 +64,8 @@ class GISLookup(Lookup):
params = [connection.ops.Adapter(value)] params = [connection.ops.Adapter(value)]
return ('%s', params) return ('%s', params)
def process_rhs(self, qn, connection): def process_rhs(self, compiler, connection):
rhs, rhs_params = super(GISLookup, self).process_rhs(qn, connection) rhs, rhs_params = super(GISLookup, self).process_rhs(compiler, connection)
geom = self.rhs geom = self.rhs
if isinstance(self.rhs, Col): if isinstance(self.rhs, Col):
@ -80,12 +80,12 @@ class GISLookup(Lookup):
raise ValueError('Complex expressions not supported for GeometryField') raise ValueError('Complex expressions not supported for GeometryField')
elif isinstance(self.rhs, (list, tuple)): elif isinstance(self.rhs, (list, tuple)):
geom = self.rhs[0] geom = self.rhs[0]
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, qn) rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler)
return rhs, rhs_params return rhs, rhs_params
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs_sql, sql_params = self.process_lhs(qn, connection) lhs_sql, sql_params = self.process_lhs(compiler, connection)
rhs_sql, rhs_params = self.process_rhs(qn, connection) rhs_sql, rhs_params = self.process_rhs(compiler, connection)
sql_params.extend(rhs_params) sql_params.extend(rhs_params)
template_params = {'lhs': lhs_sql, 'rhs': rhs_sql} template_params = {'lhs': lhs_sql, 'rhs': rhs_sql}

View File

@ -29,7 +29,7 @@ class GeoAggregate(Aggregate):
if not isinstance(self.source, GeometryField): if not isinstance(self.source, GeometryField):
raise ValueError('Geospatial aggregates only allowed on geometry fields.') raise ValueError('Geospatial aggregates only allowed on geometry fields.')
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
"Return the aggregate, rendered as SQL with parameters." "Return the aggregate, rendered as SQL with parameters."
if connection.ops.oracle: if connection.ops.oracle:
@ -38,9 +38,9 @@ class GeoAggregate(Aggregate):
params = [] params = []
if hasattr(self.col, 'as_sql'): if hasattr(self.col, 'as_sql'):
field_name, params = self.col.as_sql(qn, connection) field_name, params = self.col.as_sql(compiler, connection)
elif isinstance(self.col, (list, tuple)): elif isinstance(self.col, (list, tuple)):
field_name = '.'.join(qn(c) for c in self.col) field_name = '.'.join(compiler.quote_name_unless_alias(c) for c in self.col)
else: else:
field_name = self.col field_name = self.col

View File

@ -164,9 +164,9 @@ class ArrayField(Field):
class ArrayContainsLookup(Lookup): class ArrayContainsLookup(Lookup):
lookup_name = 'contains' lookup_name = 'contains'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
type_cast = self.lhs.output_field.db_type(connection) type_cast = self.lhs.output_field.db_type(connection)
return '%s @> %s::%s' % (lhs, rhs, type_cast), params return '%s @> %s::%s' % (lhs, rhs, type_cast), params
@ -176,9 +176,9 @@ class ArrayContainsLookup(Lookup):
class ArrayContainedByLookup(Lookup): class ArrayContainedByLookup(Lookup):
lookup_name = 'contained_by' lookup_name = 'contained_by'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
return '%s <@ %s' % (lhs, rhs), params return '%s <@ %s' % (lhs, rhs), params
@ -187,9 +187,9 @@ class ArrayContainedByLookup(Lookup):
class ArrayOverlapLookup(Lookup): class ArrayOverlapLookup(Lookup):
lookup_name = 'overlap' lookup_name = 'overlap'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
return '%s && %s' % (lhs, rhs), params return '%s && %s' % (lhs, rhs), params
@ -202,8 +202,8 @@ class ArrayLenTransform(Transform):
def output_field(self): def output_field(self):
return IntegerField() return IntegerField()
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return 'array_length(%s, 1)' % lhs, params return 'array_length(%s, 1)' % lhs, params
@ -214,8 +214,8 @@ class IndexTransform(Transform):
self.index = index self.index = index
self.base_field = base_field self.base_field = base_field
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return '%s[%s]' % (lhs, self.index), params return '%s[%s]' % (lhs, self.index), params
@property @property
@ -240,8 +240,8 @@ class SliceTransform(Transform):
self.start = start self.start = start
self.end = end self.end = end
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return '%s[%s:%s]' % (lhs, self.start, self.end), params return '%s[%s:%s]' % (lhs, self.start, self.end), params

View File

@ -64,9 +64,9 @@ class HStoreField(Field):
class HStoreContainsLookup(Lookup): class HStoreContainsLookup(Lookup):
lookup_name = 'contains' lookup_name = 'contains'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
return '%s @> %s' % (lhs, rhs), params return '%s @> %s' % (lhs, rhs), params
@ -75,9 +75,9 @@ class HStoreContainsLookup(Lookup):
class HStoreContainedByLookup(Lookup): class HStoreContainedByLookup(Lookup):
lookup_name = 'contained_by' lookup_name = 'contained_by'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
return '%s <@ %s' % (lhs, rhs), params return '%s <@ %s' % (lhs, rhs), params
@ -86,9 +86,9 @@ class HStoreContainedByLookup(Lookup):
class HasKeyLookup(Lookup): class HasKeyLookup(Lookup):
lookup_name = 'has_key' lookup_name = 'has_key'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
return '%s ? %s' % (lhs, rhs), params return '%s ? %s' % (lhs, rhs), params
@ -97,9 +97,9 @@ class HasKeyLookup(Lookup):
class HasKeysLookup(Lookup): class HasKeysLookup(Lookup):
lookup_name = 'has_keys' lookup_name = 'has_keys'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
return '%s ?& %s' % (lhs, rhs), params return '%s ?& %s' % (lhs, rhs), params
@ -111,8 +111,8 @@ class KeyTransform(Transform):
super(KeyTransform, self).__init__(*args, **kwargs) super(KeyTransform, self).__init__(*args, **kwargs)
self.key_name = key_name self.key_name = key_name
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return "%s -> '%s'" % (lhs, self.key_name), params return "%s -> '%s'" % (lhs, self.key_name), params
@ -130,8 +130,8 @@ class KeysTransform(Transform):
lookup_name = 'keys' lookup_name = 'keys'
output_field = ArrayField(TextField()) output_field = ArrayField(TextField())
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return 'akeys(%s)' % lhs, params return 'akeys(%s)' % lhs, params
@ -140,6 +140,6 @@ class ValuesTransform(Transform):
lookup_name = 'values' lookup_name = 'values'
output_field = ArrayField(TextField()) output_field = ArrayField(TextField())
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return 'avals(%s)' % lhs, params return 'avals(%s)' % lhs, params

View File

@ -490,7 +490,8 @@ class Col(ExpressionNode):
super(Col, self).__init__(output_field=source) super(Col, self).__init__(output_field=source)
self.alias, self.target = alias, target self.alias, self.target = alias, target
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
qn = compiler.quote_name_unless_alias
return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
def relabeled_clone(self, relabels): def relabeled_clone(self, relabels):
@ -541,8 +542,8 @@ class Date(ExpressionNode):
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.col, = self.exprs self.col, = self.exprs
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
sql, params = self.col.as_sql(qn, connection) sql, params = self.col.as_sql(compiler, connection)
assert not(params) assert not(params)
return connection.ops.date_trunc_sql(self.lookup_type, sql), [] return connection.ops.date_trunc_sql(self.lookup_type, sql), []
@ -563,7 +564,7 @@ class DateTime(ExpressionNode):
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.col, = exprs self.col, = exprs
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
sql, params = self.col.as_sql(qn, connection) sql, params = self.col.as_sql(compiler, connection)
assert not(params) assert not(params)
return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname) return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)

View File

@ -1553,7 +1553,8 @@ class ForeignObject(RelatedField):
def get_extra_restriction(self, where_class, alias, related_alias): def get_extra_restriction(self, where_class, alias, related_alias):
""" """
Returns a pair condition used for joining and subquery pushdown. The Returns a pair condition used for joining and subquery pushdown. The
condition is something that responds to as_sql(qn, connection) method. condition is something that responds to as_sql(compiler, connection)
method.
Note that currently referring both the 'alias' and 'related_alias' Note that currently referring both the 'alias' and 'related_alias'
will not work in some conditions, like subquery pushdown. will not work in some conditions, like subquery pushdown.

View File

@ -65,7 +65,7 @@ class Transform(RegisterLookupMixin):
self.lhs = lhs self.lhs = lhs
self.init_lookups = lookups[:] self.init_lookups = lookups[:]
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
raise NotImplementedError raise NotImplementedError
@cached_property @cached_property
@ -101,7 +101,7 @@ class Lookup(RegisterLookupMixin):
value = transform(value, lookups) value = transform(value, lookups)
return value return value
def batch_process_rhs(self, qn, connection, rhs=None): def batch_process_rhs(self, compiler, connection, rhs=None):
if rhs is None: if rhs is None:
rhs = self.rhs rhs = self.rhs
if self.bilateral_transforms: if self.bilateral_transforms:
@ -110,7 +110,7 @@ class Lookup(RegisterLookupMixin):
value = QueryWrapper('%s', value = QueryWrapper('%s',
[self.lhs.output_field.get_db_prep_value(p, connection)]) [self.lhs.output_field.get_db_prep_value(p, connection)])
value = self.apply_bilateral_transforms(value) value = self.apply_bilateral_transforms(value)
sql, sql_params = qn.compile(value) sql, sql_params = compiler.compile(value)
sqls.append(sql) sqls.append(sql)
sqls_params.extend(sql_params) sqls_params.extend(sql_params)
else: else:
@ -127,11 +127,11 @@ class Lookup(RegisterLookupMixin):
'%s', self.lhs.output_field.get_db_prep_lookup( '%s', self.lhs.output_field.get_db_prep_lookup(
self.lookup_name, value, connection, prepared=True)) self.lookup_name, value, connection, prepared=True))
def process_lhs(self, qn, connection, lhs=None): def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs lhs = lhs or self.lhs
return qn.compile(lhs) return compiler.compile(lhs)
def process_rhs(self, qn, connection): def process_rhs(self, compiler, connection):
value = self.rhs value = self.rhs
if self.bilateral_transforms: if self.bilateral_transforms:
if self.rhs_is_direct_value(): if self.rhs_is_direct_value():
@ -148,7 +148,7 @@ class Lookup(RegisterLookupMixin):
if hasattr(value, 'get_compiler'): if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection) value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql'): if hasattr(value, 'as_sql'):
sql, params = qn.compile(value) sql, params = compiler.compile(value)
return '(' + sql + ')', params return '(' + sql + ')', params
if hasattr(value, '_as_sql'): if hasattr(value, '_as_sql'):
sql, params = value._as_sql(connection=connection) sql, params = value._as_sql(connection=connection)
@ -175,14 +175,14 @@ class Lookup(RegisterLookupMixin):
cols.extend(self.rhs.get_group_by_cols()) cols.extend(self.rhs.get_group_by_cols())
return cols return cols
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
raise NotImplementedError raise NotImplementedError
class BuiltinLookup(Lookup): class BuiltinLookup(Lookup):
def process_lhs(self, qn, connection, lhs=None): def process_lhs(self, compiler, connection, lhs=None):
lhs_sql, params = super(BuiltinLookup, self).process_lhs( lhs_sql, params = super(BuiltinLookup, self).process_lhs(
qn, connection, lhs) compiler, connection, lhs)
field_internal_type = self.lhs.output_field.get_internal_type() field_internal_type = self.lhs.output_field.get_internal_type()
db_type = self.lhs.output_field.db_type(connection=connection) db_type = self.lhs.output_field.db_type(connection=connection)
lhs_sql = connection.ops.field_cast_sql( lhs_sql = connection.ops.field_cast_sql(
@ -190,9 +190,9 @@ class BuiltinLookup(Lookup):
lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql
return lhs_sql, params return lhs_sql, params
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs_sql, params = self.process_lhs(qn, connection) lhs_sql, params = self.process_lhs(compiler, connection)
rhs_sql, rhs_params = self.process_rhs(qn, connection) rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params) params.extend(rhs_params)
rhs_sql = self.get_rhs_op(connection, rhs_sql) rhs_sql = self.get_rhs_op(connection, rhs_sql)
return '%s %s' % (lhs_sql, rhs_sql), params return '%s %s' % (lhs_sql, rhs_sql), params
@ -247,7 +247,7 @@ default_lookups['lte'] = LessThanOrEqual
class In(BuiltinLookup): class In(BuiltinLookup):
lookup_name = 'in' lookup_name = 'in'
def process_rhs(self, qn, connection): def process_rhs(self, compiler, connection):
if self.rhs_is_direct_value(): if self.rhs_is_direct_value():
# rhs should be an iterable, we use batch_process_rhs # rhs should be an iterable, we use batch_process_rhs
# to prepare/transform those values # to prepare/transform those values
@ -255,23 +255,23 @@ class In(BuiltinLookup):
if not rhs: if not rhs:
from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.datastructures import EmptyResultSet
raise EmptyResultSet raise EmptyResultSet
sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs) sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
placeholder = '(' + ', '.join(sqls) + ')' placeholder = '(' + ', '.join(sqls) + ')'
return (placeholder, sqls_params) return (placeholder, sqls_params)
else: else:
return super(In, self).process_rhs(qn, connection) return super(In, self).process_rhs(compiler, connection)
def get_rhs_op(self, connection, rhs): def get_rhs_op(self, connection, rhs):
return 'IN %s' % rhs return 'IN %s' % rhs
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
max_in_list_size = connection.ops.max_in_list_size() max_in_list_size = connection.ops.max_in_list_size()
if self.rhs_is_direct_value() and (max_in_list_size and if self.rhs_is_direct_value() and (max_in_list_size and
len(self.rhs) > max_in_list_size): len(self.rhs) > max_in_list_size):
# This is a special case for Oracle which limits the number of elements # This is a special case for Oracle which limits the number of elements
# which can appear in an 'IN' clause. # which can appear in an 'IN' clause.
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.batch_process_rhs(qn, connection) rhs, rhs_params = self.batch_process_rhs(compiler, connection)
in_clause_elements = ['('] in_clause_elements = ['(']
params = [] params = []
for offset in xrange(0, len(rhs_params), max_in_list_size): for offset in xrange(0, len(rhs_params), max_in_list_size):
@ -288,7 +288,7 @@ class In(BuiltinLookup):
in_clause_elements.append(')') in_clause_elements.append(')')
return ''.join(in_clause_elements), params return ''.join(in_clause_elements), params
else: else:
return super(In, self).as_sql(qn, connection) return super(In, self).as_sql(compiler, connection)
default_lookups['in'] = In default_lookups['in'] = In
@ -348,21 +348,21 @@ class Range(BuiltinLookup):
def get_rhs_op(self, connection, rhs): def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs[0], rhs[1]) return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
def process_rhs(self, qn, connection): def process_rhs(self, compiler, connection):
if self.rhs_is_direct_value(): if self.rhs_is_direct_value():
# rhs should be an iterable of 2 values, we use batch_process_rhs # rhs should be an iterable of 2 values, we use batch_process_rhs
# to prepare/transform those values # to prepare/transform those values
return self.batch_process_rhs(qn, connection) return self.batch_process_rhs(compiler, connection)
else: else:
return super(Range, self).process_rhs(qn, connection) return super(Range, self).process_rhs(compiler, connection)
default_lookups['range'] = Range default_lookups['range'] = Range
class DateLookup(BuiltinLookup): class DateLookup(BuiltinLookup):
def process_lhs(self, qn, connection, lhs=None): def process_lhs(self, compiler, connection, lhs=None):
from django.db.models import DateTimeField from django.db.models import DateTimeField
lhs, params = super(DateLookup, self).process_lhs(qn, connection, lhs) lhs, params = super(DateLookup, self).process_lhs(compiler, connection, lhs)
if isinstance(self.lhs.output_field, DateTimeField): if isinstance(self.lhs.output_field, DateTimeField):
tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None
sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname) sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname)
@ -413,8 +413,8 @@ default_lookups['second'] = Second
class IsNull(BuiltinLookup): class IsNull(BuiltinLookup):
lookup_name = 'isnull' lookup_name = 'isnull'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
sql, params = qn.compile(self.lhs) sql, params = compiler.compile(self.lhs)
if self.rhs: if self.rhs:
return "%s IS NULL" % sql, params return "%s IS NULL" % sql, params
else: else:
@ -425,9 +425,9 @@ default_lookups['isnull'] = IsNull
class Search(BuiltinLookup): class Search(BuiltinLookup):
lookup_name = 'search' lookup_name = 'search'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
sql_template = connection.ops.fulltext_search_sql(field_name=lhs) sql_template = connection.ops.fulltext_search_sql(field_name=lhs)
return sql_template, lhs_params + rhs_params return sql_template, lhs_params + rhs_params
@ -437,12 +437,12 @@ default_lookups['search'] = Search
class Regex(BuiltinLookup): class Regex(BuiltinLookup):
lookup_name = 'regex' lookup_name = 'regex'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
if self.lookup_name in connection.operators: if self.lookup_name in connection.operators:
return super(Regex, self).as_sql(qn, connection) return super(Regex, self).as_sql(compiler, connection)
else: else:
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
sql_template = connection.ops.regex_lookup(self.lookup_name) sql_template = connection.ops.regex_lookup(self.lookup_name)
return sql_template % (lhs, rhs), lhs_params + rhs_params return sql_template % (lhs, rhs), lhs_params + rhs_params
default_lookups['regex'] = Regex default_lookups['regex'] = Regex

View File

@ -29,7 +29,7 @@ class QueryWrapper(object):
def __init__(self, sql, params): def __init__(self, sql, params):
self.data = sql, list(params) self.data = sql, list(params)
def as_sql(self, qn=None, connection=None): def as_sql(self, compiler=None, connection=None):
return self.data return self.data

View File

@ -88,16 +88,16 @@ class Aggregate(RegisterLookupMixin):
clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
return clone return clone
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
"Return the aggregate, rendered as SQL with parameters." "Return the aggregate, rendered as SQL with parameters."
params = [] params = []
if hasattr(self.col, 'as_sql'): if hasattr(self.col, 'as_sql'):
field_name, params = self.col.as_sql(qn, connection) field_name, params = self.col.as_sql(compiler, connection)
elif isinstance(self.col, (list, tuple)): elif isinstance(self.col, (list, tuple)):
field_name = '.'.join(qn(c) for c in self.col) field_name = '.'.join(compiler(c) for c in self.col)
else: else:
field_name = qn(self.col) field_name = compiler(self.col)
substitutions = { substitutions = {
'function': self.sql_function, 'function': self.sql_function,

View File

@ -1098,13 +1098,13 @@ class SQLUpdateCompiler(SQLCompiler):
class SQLAggregateCompiler(SQLCompiler): class SQLAggregateCompiler(SQLCompiler):
def as_sql(self, qn=None): def as_sql(self, compiler=None):
""" """
Creates the SQL for this query. Returns the SQL string and list of Creates the SQL for this query. Returns the SQL string and list of
parameters. parameters.
""" """
if qn is None: if compiler is None:
qn = self compiler = self
sql, params = [], [] sql, params = [], []
for annotation in self.query.annotation_select.values(): for annotation in self.query.annotation_select.values():

View File

@ -81,7 +81,7 @@ class WhereNode(tree.Node):
value = obj.prepare(lookup_type, value) value = obj.prepare(lookup_type, value)
return (obj, lookup_type, value_annotation, value) return (obj, lookup_type, value_annotation, value)
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
""" """
Returns the SQL version of the where clause and the value to be Returns the SQL version of the where clause and the value to be
substituted in. Returns '', [] if this node matches everything, substituted in. Returns '', [] if this node matches everything,
@ -102,10 +102,10 @@ class WhereNode(tree.Node):
for child in self.children: for child in self.children:
try: try:
if hasattr(child, 'as_sql'): if hasattr(child, 'as_sql'):
sql, params = qn.compile(child) sql, params = compiler.compile(child)
else: else:
# A leaf node in the tree. # A leaf node in the tree.
sql, params = self.make_atom(child, qn, connection) sql, params = self.make_atom(child, compiler, connection)
except EmptyResultSet: except EmptyResultSet:
nothing_childs += 1 nothing_childs += 1
else: else:
@ -165,7 +165,7 @@ class WhereNode(tree.Node):
cols.extend(child[3].get_group_by_cols()) cols.extend(child[3].get_group_by_cols())
return cols return cols
def make_atom(self, child, qn, connection): def make_atom(self, child, compiler, connection):
""" """
Turn a tuple (Constraint(table_alias, column_name, db_type), Turn a tuple (Constraint(table_alias, column_name, db_type),
lookup_type, value_annotation, params) into valid SQL. lookup_type, value_annotation, params) into valid SQL.
@ -192,16 +192,16 @@ class WhereNode(tree.Node):
if isinstance(lvalue, tuple): if isinstance(lvalue, tuple):
# A direct database column lookup. # A direct database column lookup.
field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), [] field_sql, field_params = self.sql_for_columns(lvalue, compiler, connection, field_internal_type), []
else: else:
# A smart object with an as_sql() method. # A smart object with an as_sql() method.
field_sql, field_params = qn.compile(lvalue) field_sql, field_params = compiler.compile(lvalue)
is_datetime_field = value_annotation is datetime.datetime is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
if hasattr(params, 'as_sql'): if hasattr(params, 'as_sql'):
extra, params = qn.compile(params) extra, params = compiler.compile(params)
cast_sql = '' cast_sql = ''
else: else:
extra = '' extra = ''
@ -314,7 +314,7 @@ class EmptyWhere(WhereNode):
def add(self, data, connector): def add(self, data, connector):
return return
def as_sql(self, qn=None, connection=None): def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet raise EmptyResultSet
@ -323,7 +323,7 @@ class EverythingNode(object):
A node that matches everything. A node that matches everything.
""" """
def as_sql(self, qn=None, connection=None): def as_sql(self, compiler=None, connection=None):
return '', [] return '', []
@ -331,7 +331,7 @@ class NothingNode(object):
""" """
A node that matches nothing. A node that matches nothing.
""" """
def as_sql(self, qn=None, connection=None): def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet raise EmptyResultSet
@ -340,7 +340,7 @@ class ExtraWhere(object):
self.sqls = sqls self.sqls = sqls
self.params = params self.params = params
def as_sql(self, qn=None, connection=None): def as_sql(self, compiler=None, connection=None):
sqls = ["(%s)" % sql for sql in self.sqls] sqls = ["(%s)" % sql for sql in self.sqls]
return " AND ".join(sqls), list(self.params or ()) return " AND ".join(sqls), list(self.params or ())
@ -402,7 +402,7 @@ class SubqueryConstraint(object):
self.targets = targets self.targets = targets
self.query_object = query_object self.query_object = query_object
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
query = self.query_object query = self.query_object
# QuerySet was sent # QuerySet was sent
@ -420,7 +420,7 @@ class SubqueryConstraint(object):
query.clear_ordering(True) query.clear_ordering(True)
query_compiler = query.get_compiler(connection=connection) query_compiler = query.get_compiler(connection=connection)
return query_compiler.as_subquery_condition(self.alias, self.columns, qn) return query_compiler.as_subquery_condition(self.alias, self.columns, compiler)
def relabel_aliases(self, change_map): def relabel_aliases(self, change_map):
self.alias = change_map.get(self.alias, self.alias) self.alias = change_map.get(self.alias, self.alias)

View File

@ -32,9 +32,9 @@ straightforward::
class NotEqual(Lookup): class NotEqual(Lookup):
lookup_name = 'ne' lookup_name = 'ne'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
return '%s <> %s' % (lhs, rhs), params return '%s <> %s' % (lhs, rhs), params
@ -70,12 +70,12 @@ lowercase strings containing only letters, but the only hard requirement is
that it must not contain the string ``__``. that it must not contain the string ``__``.
We then need to define the ``as_sql`` method. This takes a ``SQLCompiler`` We then need to define the ``as_sql`` method. This takes a ``SQLCompiler``
object, called ``qn``, and the active database connection. ``SQLCompiler`` object, called ``compiler``, and the active database connection.
objects are not documented, but the only thing we need to know about them is ``SQLCompiler`` objects are not documented, but the only thing we need to know
that they have a ``compile()`` method which returns a tuple containing a SQL about them is that they have a ``compile()`` method which returns a tuple
string, and the parameters to be interpolated into that string. In most cases, containing a SQL string, and the parameters to be interpolated into that
you don't need to use it directly and can pass it on to ``process_lhs()`` and string. In most cases, you don't need to use it directly and can pass it on to
``process_rhs()``. ``process_lhs()`` and ``process_rhs()``.
A ``Lookup`` works against two values, ``lhs`` and ``rhs``, standing for A ``Lookup`` works against two values, ``lhs`` and ``rhs``, standing for
left-hand side and right-hand side. The left-hand side is usually a field left-hand side and right-hand side. The left-hand side is usually a field
@ -86,13 +86,13 @@ reference to the ``name`` field of the ``Author`` model, and ``'Jack'`` is the
right-hand side. right-hand side.
We call ``process_lhs`` and ``process_rhs`` to convert them into the values we We call ``process_lhs`` and ``process_rhs`` to convert them into the values we
need for SQL using the ``qn`` object described before. These methods return need for SQL using the ``compiler`` object described before. These methods
tuples containing some SQL and the parameters to be interpolated into that SQL, return tuples containing some SQL and the parameters to be interpolated into
just as we need to return from our ``as_sql`` method. In the above example, that SQL, just as we need to return from our ``as_sql`` method. In the above
``process_lhs`` returns ``('"author"."name"', [])`` and ``process_rhs`` returns example, ``process_lhs`` returns ``('"author"."name"', [])`` and
``('"%s"', ['Jack'])``. In this example there were no parameters for the left ``process_rhs`` returns ``('"%s"', ['Jack'])``. In this example there were no
hand side, but this would depend on the object we have, so we still need to parameters for the left hand side, but this would depend on the object we have,
include them in the parameters we return. so we still need to include them in the parameters we return.
Finally we combine the parts into a SQL expression with ``<>``, and supply all Finally we combine the parts into a SQL expression with ``<>``, and supply all
the parameters for the query. We then return a tuple containing the generated the parameters for the query. We then return a tuple containing the generated
@ -123,8 +123,8 @@ function ``ABS()`` to transform the value before comparison::
class AbsoluteValue(Transform): class AbsoluteValue(Transform):
lookup_name = 'abs' lookup_name = 'abs'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return "ABS(%s)" % lhs, params return "ABS(%s)" % lhs, params
Next, let's register it for ``IntegerField``:: Next, let's register it for ``IntegerField``::
@ -160,8 +160,8 @@ be done by adding an ``output_field`` attribute to the transform::
class AbsoluteValue(Transform): class AbsoluteValue(Transform):
lookup_name = 'abs' lookup_name = 'abs'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return "ABS(%s)" % lhs, params return "ABS(%s)" % lhs, params
@property @property
@ -191,9 +191,9 @@ The implementation is::
class AbsoluteValueLessThan(Lookup): class AbsoluteValueLessThan(Lookup):
lookup_name = 'lt' lookup_name = 'lt'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = qn.compile(self.lhs.lhs) lhs, lhs_params = compiler.compile(self.lhs.lhs)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params + lhs_params + rhs_params params = lhs_params + rhs_params + lhs_params + rhs_params
return '%s < %s AND %s > -%s' % (lhs, rhs, lhs, rhs), params return '%s < %s AND %s > -%s' % (lhs, rhs, lhs, rhs), params
@ -247,8 +247,8 @@ this transformation should apply to both ``lhs`` and ``rhs``::
lookup_name = 'upper' lookup_name = 'upper'
bilateral = True bilateral = True
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return "UPPER(%s)" % lhs, params return "UPPER(%s)" % lhs, params
Next, let's register it:: Next, let's register it::
@ -275,9 +275,9 @@ We can change the behavior on a specific backend by creating a subclass of
``NotEqual`` with a ``as_mysql`` method:: ``NotEqual`` with a ``as_mysql`` method::
class MySQLNotEqual(NotEqual): class MySQLNotEqual(NotEqual):
def as_mysql(self, qn, connection): def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params params = lhs_params + rhs_params
return '%s != %s' % (lhs, rhs), params return '%s != %s' % (lhs, rhs), params
Field.register_lookup(MySQLNotExact) Field.register_lookup(MySQLNotExact)

View File

@ -80,24 +80,24 @@ field references, aggregates, and ``Transform`` are examples that follow this
API. A class is said to follow the query expression API when it implements the API. A class is said to follow the query expression API when it implements the
following methods: following methods:
.. method:: as_sql(self, qn, connection) .. method:: as_sql(self, compiler, connection)
Responsible for producing the query string and parameters for the expression. Responsible for producing the query string and parameters for the expression.
The ``qn`` is an ``SQLCompiler`` object, which has a ``compile()`` method The ``compiler`` is an ``SQLCompiler`` object, which has a ``compile()``
that can be used to compile other expressions. The ``connection`` is the method that can be used to compile other expressions. The ``connection`` is
connection used to execute the query. the connection used to execute the query.
Calling ``expression.as_sql()`` is usually incorrect - instead Calling ``expression.as_sql()`` is usually incorrect - instead
``qn.compile(expression)`` should be used. The ``qn.compile()`` method will ``compiler.compile(expression)`` should be used. The ``compiler.compile()``
take care of calling vendor-specific methods of the expression. method will take care of calling vendor-specific methods of the expression.
.. method:: as_vendorname(self, qn, connection) .. method:: as_vendorname(self, compiler, connection)
Works like ``as_sql()`` method. When an expression is compiled by Works like ``as_sql()`` method. When an expression is compiled by
``qn.compile()``, Django will first try to call ``as_vendorname()``, where ``compiler.compile()``, Django will first try to call ``as_vendorname()``,
``vendorname`` is the vendor name of the backend used for executing the where ``vendorname`` is the vendor name of the backend used for executing
query. The ``vendorname`` is one of ``postgresql``, ``oracle``, ``sqlite``, the query. The ``vendorname`` is one of ``postgresql``, ``oracle``,
or ``mysql`` for Django's built-in backends. ``sqlite``, or ``mysql`` for Django's built-in backends.
.. method:: get_lookup(lookup_name) .. method:: get_lookup(lookup_name)
@ -200,17 +200,17 @@ Lookup reference
The name of this lookup, used to identify it on parsing query The name of this lookup, used to identify it on parsing query
expressions. It cannot contain the string ``"__"``. expressions. It cannot contain the string ``"__"``.
.. method:: process_lhs(qn, connection[, lhs=None]) .. method:: process_lhs(compiler, connection[, lhs=None])
Returns a tuple ``(lhs_string, lhs_params)``, as returned by Returns a tuple ``(lhs_string, lhs_params)``, as returned by
``qn.compile(lhs)``. This method can be overridden to tune how the ``compiler.compile(lhs)``. This method can be overridden to tune how
``lhs`` is processed. the ``lhs`` is processed.
``qn`` is an ``SQLCompiler`` object, to be used like ``qn.compile(lhs)`` ``compiler`` is an ``SQLCompiler`` object, to be used like
for compiling ``lhs``. The ``connection`` can be used for compiling ``compiler.compile(lhs)`` for compiling ``lhs``. The ``connection``
vendor specific SQL. If ``lhs`` is not ``None``, use it as the can be used for compiling vendor specific SQL. If ``lhs`` is not
processed ``lhs`` instead of ``self.lhs``. ``None``, use it as the processed ``lhs`` instead of ``self.lhs``.
.. method:: process_rhs(qn, connection) .. method:: process_rhs(compiler, connection)
Behaves the same way as :meth:`process_lhs`, for the right-hand side. Behaves the same way as :meth:`process_lhs`, for the right-hand side.

View File

@ -863,8 +863,8 @@ class ComplexAggregateTestCase(TestCase):
def test_add_implementation(self): def test_add_implementation(self):
try: try:
# test completely changing how the output is rendered # test completely changing how the output is rendered
def lower_case_function_override(self, qn, connection): def lower_case_function_override(self, compiler, connection):
sql, params = qn.compile(self.source_expressions[0]) sql, params = compiler.compile(self.source_expressions[0])
substitutions = dict(function=self.function.lower(), expressions=sql) substitutions = dict(function=self.function.lower(), expressions=sql)
substitutions.update(self.extra) substitutions.update(self.extra)
return self.template % substitutions, params return self.template % substitutions, params
@ -877,9 +877,9 @@ class ComplexAggregateTestCase(TestCase):
self.assertEqual(b1.sums, 383) self.assertEqual(b1.sums, 383)
# test changing the dict and delegating # test changing the dict and delegating
def lower_case_function_super(self, qn, connection): def lower_case_function_super(self, compiler, connection):
self.extra['function'] = self.function.lower() self.extra['function'] = self.function.lower()
return super(Sum, self).as_sql(qn, connection) return super(Sum, self).as_sql(compiler, connection)
setattr(Sum, 'as_' + connection.vendor, lower_case_function_super) setattr(Sum, 'as_' + connection.vendor, lower_case_function_super)
qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
@ -889,7 +889,7 @@ class ComplexAggregateTestCase(TestCase):
self.assertEqual(b1.sums, 383) self.assertEqual(b1.sums, 383)
# test overriding all parts of the template # test overriding all parts of the template
def be_evil(self, qn, connection): def be_evil(self, compiler, connection):
substitutions = dict(function='MAX', expressions='2') substitutions = dict(function='MAX', expressions='2')
substitutions.update(self.extra) substitutions.update(self.extra)
return self.template % substitutions, () return self.template % substitutions, ()
@ -921,8 +921,8 @@ class ComplexAggregateTestCase(TestCase):
class Greatest(Func): class Greatest(Func):
function = 'GREATEST' function = 'GREATEST'
def as_sqlite(self, qn, connection): def as_sqlite(self, compiler, connection):
return super(Greatest, self).as_sql(qn, connection, function='MAX') return super(Greatest, self).as_sql(compiler, connection, function='MAX')
qs = Publisher.objects.annotate( qs = Publisher.objects.annotate(
price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))

View File

@ -14,15 +14,15 @@ from .models import Author, MySQLUnixTimestamp
class Div3Lookup(models.Lookup): class Div3Lookup(models.Lookup):
lookup_name = 'div3' lookup_name = 'div3'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = self.process_lhs(qn, connection) lhs, params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params) params.extend(rhs_params)
return '(%s) %%%% 3 = %s' % (lhs, rhs), params return '(%s) %%%% 3 = %s' % (lhs, rhs), params
def as_oracle(self, qn, connection): def as_oracle(self, compiler, connection):
lhs, params = self.process_lhs(qn, connection) lhs, params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params) params.extend(rhs_params)
return 'mod(%s, 3) = %s' % (lhs, rhs), params return 'mod(%s, 3) = %s' % (lhs, rhs), params
@ -30,12 +30,12 @@ class Div3Lookup(models.Lookup):
class Div3Transform(models.Transform): class Div3Transform(models.Transform):
lookup_name = 'div3' lookup_name = 'div3'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = qn.compile(self.lhs) lhs, lhs_params = compiler.compile(self.lhs)
return '(%s) %%%% 3' % lhs, lhs_params return '(%s) %%%% 3' % lhs, lhs_params
def as_oracle(self, qn, connection): def as_oracle(self, compiler, connection):
lhs, lhs_params = qn.compile(self.lhs) lhs, lhs_params = compiler.compile(self.lhs)
return 'mod(%s, 3)' % lhs, lhs_params return 'mod(%s, 3)' % lhs, lhs_params
@ -47,8 +47,8 @@ class Mult3BilateralTransform(models.Transform):
bilateral = True bilateral = True
lookup_name = 'mult3' lookup_name = 'mult3'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = qn.compile(self.lhs) lhs, lhs_params = compiler.compile(self.lhs)
return '3 * (%s)' % lhs, lhs_params return '3 * (%s)' % lhs, lhs_params
@ -56,16 +56,16 @@ class UpperBilateralTransform(models.Transform):
bilateral = True bilateral = True
lookup_name = 'upper' lookup_name = 'upper'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = qn.compile(self.lhs) lhs, lhs_params = compiler.compile(self.lhs)
return 'UPPER(%s)' % lhs, lhs_params return 'UPPER(%s)' % lhs, lhs_params
class YearTransform(models.Transform): class YearTransform(models.Transform):
lookup_name = 'year' lookup_name = 'year'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs_sql, params = qn.compile(self.lhs) lhs_sql, params = compiler.compile(self.lhs)
return connection.ops.date_extract_sql('year', lhs_sql), params return connection.ops.date_extract_sql('year', lhs_sql), params
@property @property
@ -77,11 +77,11 @@ class YearTransform(models.Transform):
class YearExact(models.lookups.Lookup): class YearExact(models.lookups.Lookup):
lookup_name = 'exact' lookup_name = 'exact'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
# We will need to skip the extract part, and instead go # We will need to skip the extract part, and instead go
# directly with the originating field, that is self.lhs.lhs # directly with the originating field, that is self.lhs.lhs
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection) rhs_sql, rhs_params = self.process_rhs(compiler, connection)
# Note that we must be careful so that we have params in the # Note that we must be careful so that we have params in the
# same order as we have the parts in the SQL. # same order as we have the parts in the SQL.
params = lhs_params + rhs_params + lhs_params + rhs_params params = lhs_params + rhs_params + lhs_params + rhs_params
@ -98,12 +98,12 @@ class YearLte(models.lookups.LessThanOrEqual):
The purpose of this lookup is to efficiently compare the year of the field. The purpose of this lookup is to efficiently compare the year of the field.
""" """
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
# Skip the YearTransform above us (no possibility for efficient # Skip the YearTransform above us (no possibility for efficient
# lookup otherwise). # lookup otherwise).
real_lhs = self.lhs.lhs real_lhs = self.lhs.lhs
lhs_sql, params = self.process_lhs(qn, connection, real_lhs) lhs_sql, params = self.process_lhs(compiler, connection, real_lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection) rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params) params.extend(rhs_params)
# Build SQL where the integer year is concatenated with last month # Build SQL where the integer year is concatenated with last month
# and day, then convert that to date. (We try to have SQL like: # and day, then convert that to date. (We try to have SQL like:
@ -117,7 +117,7 @@ class SQLFunc(models.Lookup):
super(SQLFunc, self).__init__(*args, **kwargs) super(SQLFunc, self).__init__(*args, **kwargs)
self.name = name self.name = name
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
return '%s()', [self.name] return '%s()', [self.name]
@property @property
@ -162,9 +162,9 @@ class InMonth(models.lookups.Lookup):
""" """
lookup_name = 'inmonth' lookup_name = 'inmonth'
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(qn, connection) lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(qn, connection) rhs, rhs_params = self.process_rhs(compiler, connection)
# We need to be careful so that we get the params in right # We need to be careful so that we get the params in right
# places. # places.
params = lhs_params + rhs_params + lhs_params + rhs_params params = lhs_params + rhs_params + lhs_params + rhs_params
@ -180,8 +180,8 @@ class DateTimeTransform(models.Transform):
def output_field(self): def output_field(self):
return models.DateTimeField() return models.DateTimeField()
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs, params = qn.compile(self.lhs) lhs, params = compiler.compile(self.lhs)
return 'from_unixtime({})'.format(lhs), params return 'from_unixtime({})'.format(lhs), params
@ -448,9 +448,9 @@ class YearLteTests(TestCase):
try: try:
# Two ways to add a customized implementation for different backends: # Two ways to add a customized implementation for different backends:
# First is MonkeyPatch of the class. # First is MonkeyPatch of the class.
def as_custom_sql(self, qn, connection): def as_custom_sql(self, compiler, connection):
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection) rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params + lhs_params + rhs_params params = lhs_params + rhs_params + lhs_params + rhs_params
return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
"AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
@ -468,9 +468,9 @@ class YearLteTests(TestCase):
# This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres # This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres
# and so on, but as we don't know which DB we are running on, we need to use # and so on, but as we don't know which DB we are running on, we need to use
# setattr. # setattr.
def as_custom_sql(self, qn, connection): def as_custom_sql(self, compiler, connection):
lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
rhs_sql, rhs_params = self.process_rhs(qn, connection) rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params = lhs_params + rhs_params + lhs_params + rhs_params params = lhs_params + rhs_params + lhs_params + rhs_params
return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
"AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" %
@ -489,8 +489,8 @@ class TrackCallsYearTransform(YearTransform):
lookup_name = 'year' lookup_name = 'year'
call_order = [] call_order = []
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
lhs_sql, params = qn.compile(self.lhs) lhs_sql, params = compiler.compile(self.lhs)
return connection.ops.date_extract_sql('year', lhs_sql), params return connection.ops.date_extract_sql('year', lhs_sql), params
@property @property

View File

@ -117,7 +117,8 @@ class ColConstraint(object):
def __init__(self, alias, col, value): def __init__(self, alias, col, value):
self.alias, self.col, self.value = alias, col, value self.alias, self.col, self.value = alias, col, value
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
qn = compiler.quote_name_unless_alias
return '%s.%s = %%s' % (qn(self.alias), qn(self.col)), [self.value] return '%s.%s = %%s' % (qn(self.alias), qn(self.col)), [self.value]

View File

@ -2817,7 +2817,7 @@ class ProxyQueryCleanupTest(TestCase):
class WhereNodeTest(TestCase): class WhereNodeTest(TestCase):
class DummyNode(object): class DummyNode(object):
def as_sql(self, qn, connection): def as_sql(self, compiler, connection):
return 'dummy', [] return 'dummy', []
class MockCompiler(object): class MockCompiler(object):
@ -2828,70 +2828,70 @@ class WhereNodeTest(TestCase):
return connection.ops.quote_name(name) return connection.ops.quote_name(name)
def test_empty_full_handling_conjunction(self): def test_empty_full_handling_conjunction(self):
qn = WhereNodeTest.MockCompiler() compiler = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()]) w = WhereNode(children=[EverythingNode()])
self.assertEqual(w.as_sql(qn, connection), ('', [])) self.assertEqual(w.as_sql(compiler, connection), ('', []))
w.negate() w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
w = WhereNode(children=[NothingNode()]) w = WhereNode(children=[NothingNode()])
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
w.negate() w.negate()
self.assertEqual(w.as_sql(qn, connection), ('', [])) self.assertEqual(w.as_sql(compiler, connection), ('', []))
w = WhereNode(children=[EverythingNode(), EverythingNode()]) w = WhereNode(children=[EverythingNode(), EverythingNode()])
self.assertEqual(w.as_sql(qn, connection), ('', [])) self.assertEqual(w.as_sql(compiler, connection), ('', []))
w.negate() w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
w = WhereNode(children=[EverythingNode(), self.DummyNode()]) w = WhereNode(children=[EverythingNode(), self.DummyNode()])
self.assertEqual(w.as_sql(qn, connection), ('dummy', [])) self.assertEqual(w.as_sql(compiler, connection), ('dummy', []))
w = WhereNode(children=[self.DummyNode(), self.DummyNode()]) w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
self.assertEqual(w.as_sql(qn, connection), ('(dummy AND dummy)', [])) self.assertEqual(w.as_sql(compiler, connection), ('(dummy AND dummy)', []))
w.negate() w.negate()
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy AND dummy)', [])) self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy AND dummy)', []))
w = WhereNode(children=[NothingNode(), self.DummyNode()]) w = WhereNode(children=[NothingNode(), self.DummyNode()])
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
w.negate() w.negate()
self.assertEqual(w.as_sql(qn, connection), ('', [])) self.assertEqual(w.as_sql(compiler, connection), ('', []))
def test_empty_full_handling_disjunction(self): def test_empty_full_handling_disjunction(self):
qn = WhereNodeTest.MockCompiler() compiler = WhereNodeTest.MockCompiler()
w = WhereNode(children=[EverythingNode()], connector='OR') w = WhereNode(children=[EverythingNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('', [])) self.assertEqual(w.as_sql(compiler, connection), ('', []))
w.negate() w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
w = WhereNode(children=[NothingNode()], connector='OR') w = WhereNode(children=[NothingNode()], connector='OR')
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
w.negate() w.negate()
self.assertEqual(w.as_sql(qn, connection), ('', [])) self.assertEqual(w.as_sql(compiler, connection), ('', []))
w = WhereNode(children=[EverythingNode(), EverythingNode()], connector='OR') w = WhereNode(children=[EverythingNode(), EverythingNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('', [])) self.assertEqual(w.as_sql(compiler, connection), ('', []))
w.negate() w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
w = WhereNode(children=[EverythingNode(), self.DummyNode()], connector='OR') w = WhereNode(children=[EverythingNode(), self.DummyNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('', [])) self.assertEqual(w.as_sql(compiler, connection), ('', []))
w.negate() w.negate()
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR') w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('(dummy OR dummy)', [])) self.assertEqual(w.as_sql(compiler, connection), ('(dummy OR dummy)', []))
w.negate() w.negate()
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy OR dummy)', [])) self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy OR dummy)', []))
w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR') w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR')
self.assertEqual(w.as_sql(qn, connection), ('dummy', [])) self.assertEqual(w.as_sql(compiler, connection), ('dummy', []))
w.negate() w.negate()
self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', [])) self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy)', []))
def test_empty_nodes(self): def test_empty_nodes(self):
qn = WhereNodeTest.MockCompiler() compiler = WhereNodeTest.MockCompiler()
empty_w = WhereNode() empty_w = WhereNode()
w = WhereNode(children=[empty_w, empty_w]) w = WhereNode(children=[empty_w, empty_w])
self.assertEqual(w.as_sql(qn, connection), (None, [])) self.assertEqual(w.as_sql(compiler, connection), (None, []))
w.negate() w.negate()
self.assertEqual(w.as_sql(qn, connection), (None, [])) self.assertEqual(w.as_sql(compiler, connection), (None, []))
w.connector = 'OR' w.connector = 'OR'
self.assertEqual(w.as_sql(qn, connection), (None, [])) self.assertEqual(w.as_sql(compiler, connection), (None, []))
w.negate() w.negate()
self.assertEqual(w.as_sql(qn, connection), (None, [])) self.assertEqual(w.as_sql(compiler, connection), (None, []))
w = WhereNode(children=[empty_w, NothingNode()], connector='OR') w = WhereNode(children=[empty_w, NothingNode()], connector='OR')
self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection)
class IteratorExceptionsTest(TestCase): class IteratorExceptionsTest(TestCase):