From f61256da3a266c75c2f75c35172832bf2d605939 Mon Sep 17 00:00:00 2001 From: Josh Smeaton Date: Sun, 16 Nov 2014 12:56:42 +1100 Subject: [PATCH] Renamed qn to compiler --- django/contrib/gis/db/backends/base.py | 2 +- .../gis/db/backends/mysql/operations.py | 4 +- .../gis/db/backends/oracle/operations.py | 4 +- .../gis/db/backends/postgis/operations.py | 4 +- .../gis/db/backends/spatialite/operations.py | 4 +- django/contrib/gis/db/models/fields.py | 4 +- django/contrib/gis/db/models/lookups.py | 12 ++-- .../contrib/gis/db/models/sql/aggregates.py | 6 +- django/contrib/postgres/fields/array.py | 30 ++++---- django/contrib/postgres/fields/hstore.py | 36 +++++----- django/db/models/expressions.py | 11 +-- django/db/models/fields/related.py | 3 +- django/db/models/lookups.py | 68 +++++++++--------- django/db/models/query_utils.py | 2 +- django/db/models/sql/aggregates.py | 8 +-- django/db/models/sql/compiler.py | 6 +- django/db/models/sql/where.py | 26 +++---- docs/howto/custom-lookups.txt | 56 +++++++-------- docs/ref/models/lookups.txt | 38 +++++----- tests/aggregation/tests.py | 14 ++-- tests/custom_lookups/tests.py | 72 +++++++++---------- tests/foreign_object/models.py | 3 +- tests/queries/tests.py | 64 ++++++++--------- 23 files changed, 240 insertions(+), 237 deletions(-) diff --git a/django/contrib/gis/db/backends/base.py b/django/contrib/gis/db/backends/base.py index 0d2a9db870..9e9612cc0f 100644 --- a/django/contrib/gis/db/backends/base.py +++ b/django/contrib/gis/db/backends/base.py @@ -186,7 +186,7 @@ class BaseSpatialOperations(object): """ raise NotImplementedError('Distance operations not available on this spatial backend.') - def get_geom_placeholder(self, f, value, qn): + def get_geom_placeholder(self, f, value, compiler): """ Returns the placeholder for the given geometry field with the given value. Depending on the spatial backend, the placeholder may contain a diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index 191e2c8956..ccbe0542f4 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -35,14 +35,14 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations): def geo_db_type(self, f): return f.geom_type - def get_geom_placeholder(self, f, value, qn): + def get_geom_placeholder(self, f, value, compiler): """ The placeholder here has to include MySQL's WKT constructor. Because MySQL does not support spatial transformations, there is no need to modify the placeholder based on the contents of the given value. """ if hasattr(value, 'as_sql'): - placeholder, _ = qn.compile(value) + placeholder, _ = compiler.compile(value) else: placeholder = '%s(%%s)' % self.from_text return placeholder diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index aa002d3b82..c9e3e8ee31 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -186,7 +186,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): return [dist_param] - def get_geom_placeholder(self, f, value, qn): + def get_geom_placeholder(self, f, value, compiler): """ Provides a proper substitution value for Geometries that are not in the SRID of the field. Specifically, this routine will substitute in the @@ -205,7 +205,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): placeholder = '%s' # No geometry value used for F expression, substitute in # the column name instead. - sql, _ = qn.compile(value) + sql, _ = compiler.compile(value) return placeholder % sql else: if transform_value(value, f.srid): diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index d78b081950..2cd088c59d 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -284,7 +284,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): else: return [dist_param] - def get_geom_placeholder(self, f, value, qn): + def get_geom_placeholder(self, f, value, compiler): """ Provides a proper substitution value for Geometries that are not in the SRID of the field. Specifically, this routine will substitute in the @@ -300,7 +300,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): # If this is an F expression, then we don't really want # a placeholder and instead substitute in the column # of the expression. - sql, _ = qn.compile(value) + sql, _ = compiler.compile(value) placeholder = placeholder % sql return placeholder diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 4ec98c402d..8e9bdd57d2 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -178,7 +178,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): dist_param = value return [dist_param] - def get_geom_placeholder(self, f, value, qn): + def get_geom_placeholder(self, f, value, compiler): """ Provides a proper substitution value for Geometries that are not in the SRID of the field. Specifically, this routine will substitute in the @@ -193,7 +193,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): placeholder = '%s' # No geometry value used for F expression, substitute in # the column name instead. - sql, _ = qn.compile(value) + sql, _ = compiler.compile(value) return placeholder % sql else: if transform_value(value, f.srid): diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index 1d64a06be2..f2c91aa8ab 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -282,12 +282,12 @@ class GeometryField(Field): else: return connection.ops.Adapter(self.get_prep_value(value)) - def get_placeholder(self, value, qn, connection): + def get_placeholder(self, value, compiler, connection): """ Returns the placeholder for the geometry column for the given value. """ - return connection.ops.get_geom_placeholder(self, value, qn) + return connection.ops.get_geom_placeholder(self, value, compiler) for klass in gis_lookups.values(): diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index 889237751a..434e6fe6ab 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -64,8 +64,8 @@ class GISLookup(Lookup): params = [connection.ops.Adapter(value)] return ('%s', params) - def process_rhs(self, qn, connection): - rhs, rhs_params = super(GISLookup, self).process_rhs(qn, connection) + def process_rhs(self, compiler, connection): + rhs, rhs_params = super(GISLookup, self).process_rhs(compiler, connection) geom = self.rhs if isinstance(self.rhs, Col): @@ -80,12 +80,12 @@ class GISLookup(Lookup): raise ValueError('Complex expressions not supported for GeometryField') elif isinstance(self.rhs, (list, tuple)): geom = self.rhs[0] - rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, qn) + rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) return rhs, rhs_params - def as_sql(self, qn, connection): - lhs_sql, sql_params = self.process_lhs(qn, connection) - rhs_sql, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs_sql, sql_params = self.process_lhs(compiler, connection) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) sql_params.extend(rhs_params) template_params = {'lhs': lhs_sql, 'rhs': rhs_sql} diff --git a/django/contrib/gis/db/models/sql/aggregates.py b/django/contrib/gis/db/models/sql/aggregates.py index c3943eb9f6..65ccc960df 100644 --- a/django/contrib/gis/db/models/sql/aggregates.py +++ b/django/contrib/gis/db/models/sql/aggregates.py @@ -29,7 +29,7 @@ class GeoAggregate(Aggregate): if not isinstance(self.source, GeometryField): raise ValueError('Geospatial aggregates only allowed on geometry fields.') - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): "Return the aggregate, rendered as SQL with parameters." if connection.ops.oracle: @@ -38,9 +38,9 @@ class GeoAggregate(Aggregate): params = [] if hasattr(self.col, 'as_sql'): - field_name, params = self.col.as_sql(qn, connection) + field_name, params = self.col.as_sql(compiler, connection) elif isinstance(self.col, (list, tuple)): - field_name = '.'.join(qn(c) for c in self.col) + field_name = '.'.join(compiler.quote_name_unless_alias(c) for c in self.col) else: field_name = self.col diff --git a/django/contrib/postgres/fields/array.py b/django/contrib/postgres/fields/array.py index 87c2de43b4..bae64abfc1 100644 --- a/django/contrib/postgres/fields/array.py +++ b/django/contrib/postgres/fields/array.py @@ -164,9 +164,9 @@ class ArrayField(Field): class ArrayContainsLookup(Lookup): lookup_name = 'contains' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params type_cast = self.lhs.output_field.db_type(connection) return '%s @> %s::%s' % (lhs, rhs, type_cast), params @@ -176,9 +176,9 @@ class ArrayContainsLookup(Lookup): class ArrayContainedByLookup(Lookup): lookup_name = 'contained_by' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params return '%s <@ %s' % (lhs, rhs), params @@ -187,9 +187,9 @@ class ArrayContainedByLookup(Lookup): class ArrayOverlapLookup(Lookup): lookup_name = 'overlap' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params return '%s && %s' % (lhs, rhs), params @@ -202,8 +202,8 @@ class ArrayLenTransform(Transform): def output_field(self): return IntegerField() - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return 'array_length(%s, 1)' % lhs, params @@ -214,8 +214,8 @@ class IndexTransform(Transform): self.index = index self.base_field = base_field - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return '%s[%s]' % (lhs, self.index), params @property @@ -240,8 +240,8 @@ class SliceTransform(Transform): self.start = start self.end = end - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return '%s[%s:%s]' % (lhs, self.start, self.end), params diff --git a/django/contrib/postgres/fields/hstore.py b/django/contrib/postgres/fields/hstore.py index ae51ce02ce..be998488fb 100644 --- a/django/contrib/postgres/fields/hstore.py +++ b/django/contrib/postgres/fields/hstore.py @@ -64,9 +64,9 @@ class HStoreField(Field): class HStoreContainsLookup(Lookup): lookup_name = 'contains' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params return '%s @> %s' % (lhs, rhs), params @@ -75,9 +75,9 @@ class HStoreContainsLookup(Lookup): class HStoreContainedByLookup(Lookup): lookup_name = 'contained_by' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params return '%s <@ %s' % (lhs, rhs), params @@ -86,9 +86,9 @@ class HStoreContainedByLookup(Lookup): class HasKeyLookup(Lookup): lookup_name = 'has_key' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params return '%s ? %s' % (lhs, rhs), params @@ -97,9 +97,9 @@ class HasKeyLookup(Lookup): class HasKeysLookup(Lookup): lookup_name = 'has_keys' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params return '%s ?& %s' % (lhs, rhs), params @@ -111,8 +111,8 @@ class KeyTransform(Transform): super(KeyTransform, self).__init__(*args, **kwargs) self.key_name = key_name - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return "%s -> '%s'" % (lhs, self.key_name), params @@ -130,8 +130,8 @@ class KeysTransform(Transform): lookup_name = 'keys' output_field = ArrayField(TextField()) - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return 'akeys(%s)' % lhs, params @@ -140,6 +140,6 @@ class ValuesTransform(Transform): lookup_name = 'values' output_field = ArrayField(TextField()) - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return 'avals(%s)' % lhs, params diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 22a5e3ab1e..1d71ecd195 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -490,7 +490,8 @@ class Col(ExpressionNode): super(Col, self).__init__(output_field=source) self.alias, self.target = alias, target - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): + qn = compiler.quote_name_unless_alias return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] def relabeled_clone(self, relabels): @@ -541,8 +542,8 @@ class Date(ExpressionNode): def set_source_expressions(self, exprs): self.col, = self.exprs - def as_sql(self, qn, connection): - sql, params = self.col.as_sql(qn, connection) + def as_sql(self, compiler, connection): + sql, params = self.col.as_sql(compiler, connection) assert not(params) return connection.ops.date_trunc_sql(self.lookup_type, sql), [] @@ -563,7 +564,7 @@ class DateTime(ExpressionNode): def set_source_expressions(self, exprs): self.col, = exprs - def as_sql(self, qn, connection): - sql, params = self.col.as_sql(qn, connection) + def as_sql(self, compiler, connection): + sql, params = self.col.as_sql(compiler, connection) assert not(params) return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index f7699f5152..beb9122e48 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -1553,7 +1553,8 @@ class ForeignObject(RelatedField): def get_extra_restriction(self, where_class, alias, related_alias): """ Returns a pair condition used for joining and subquery pushdown. The - condition is something that responds to as_sql(qn, connection) method. + condition is something that responds to as_sql(compiler, connection) + method. Note that currently referring both the 'alias' and 'related_alias' will not work in some conditions, like subquery pushdown. diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index abb5645147..af966723f6 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -65,7 +65,7 @@ class Transform(RegisterLookupMixin): self.lhs = lhs self.init_lookups = lookups[:] - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): raise NotImplementedError @cached_property @@ -101,7 +101,7 @@ class Lookup(RegisterLookupMixin): value = transform(value, lookups) return value - def batch_process_rhs(self, qn, connection, rhs=None): + def batch_process_rhs(self, compiler, connection, rhs=None): if rhs is None: rhs = self.rhs if self.bilateral_transforms: @@ -110,7 +110,7 @@ class Lookup(RegisterLookupMixin): value = QueryWrapper('%s', [self.lhs.output_field.get_db_prep_value(p, connection)]) value = self.apply_bilateral_transforms(value) - sql, sql_params = qn.compile(value) + sql, sql_params = compiler.compile(value) sqls.append(sql) sqls_params.extend(sql_params) else: @@ -127,11 +127,11 @@ class Lookup(RegisterLookupMixin): '%s', self.lhs.output_field.get_db_prep_lookup( self.lookup_name, value, connection, prepared=True)) - def process_lhs(self, qn, connection, lhs=None): + def process_lhs(self, compiler, connection, lhs=None): lhs = lhs or self.lhs - return qn.compile(lhs) + return compiler.compile(lhs) - def process_rhs(self, qn, connection): + def process_rhs(self, compiler, connection): value = self.rhs if self.bilateral_transforms: if self.rhs_is_direct_value(): @@ -148,7 +148,7 @@ class Lookup(RegisterLookupMixin): if hasattr(value, 'get_compiler'): value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql'): - sql, params = qn.compile(value) + sql, params = compiler.compile(value) return '(' + sql + ')', params if hasattr(value, '_as_sql'): sql, params = value._as_sql(connection=connection) @@ -175,14 +175,14 @@ class Lookup(RegisterLookupMixin): cols.extend(self.rhs.get_group_by_cols()) return cols - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): raise NotImplementedError class BuiltinLookup(Lookup): - def process_lhs(self, qn, connection, lhs=None): + def process_lhs(self, compiler, connection, lhs=None): lhs_sql, params = super(BuiltinLookup, self).process_lhs( - qn, connection, lhs) + compiler, connection, lhs) field_internal_type = self.lhs.output_field.get_internal_type() db_type = self.lhs.output_field.db_type(connection=connection) lhs_sql = connection.ops.field_cast_sql( @@ -190,9 +190,9 @@ class BuiltinLookup(Lookup): lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql return lhs_sql, params - def as_sql(self, qn, connection): - lhs_sql, params = self.process_lhs(qn, connection) - rhs_sql, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs_sql, params = self.process_lhs(compiler, connection) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) params.extend(rhs_params) rhs_sql = self.get_rhs_op(connection, rhs_sql) return '%s %s' % (lhs_sql, rhs_sql), params @@ -247,7 +247,7 @@ default_lookups['lte'] = LessThanOrEqual class In(BuiltinLookup): lookup_name = 'in' - def process_rhs(self, qn, connection): + def process_rhs(self, compiler, connection): if self.rhs_is_direct_value(): # rhs should be an iterable, we use batch_process_rhs # to prepare/transform those values @@ -255,23 +255,23 @@ class In(BuiltinLookup): if not rhs: from django.db.models.sql.datastructures import EmptyResultSet raise EmptyResultSet - sqls, sqls_params = self.batch_process_rhs(qn, connection, rhs) + sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs) placeholder = '(' + ', '.join(sqls) + ')' return (placeholder, sqls_params) else: - return super(In, self).process_rhs(qn, connection) + return super(In, self).process_rhs(compiler, connection) def get_rhs_op(self, connection, rhs): return 'IN %s' % rhs - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): max_in_list_size = connection.ops.max_in_list_size() if self.rhs_is_direct_value() and (max_in_list_size and len(self.rhs) > max_in_list_size): # This is a special case for Oracle which limits the number of elements # which can appear in an 'IN' clause. - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.batch_process_rhs(qn, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.batch_process_rhs(compiler, connection) in_clause_elements = ['('] params = [] for offset in xrange(0, len(rhs_params), max_in_list_size): @@ -288,7 +288,7 @@ class In(BuiltinLookup): in_clause_elements.append(')') return ''.join(in_clause_elements), params else: - return super(In, self).as_sql(qn, connection) + return super(In, self).as_sql(compiler, connection) default_lookups['in'] = In @@ -348,21 +348,21 @@ class Range(BuiltinLookup): def get_rhs_op(self, connection, rhs): return "BETWEEN %s AND %s" % (rhs[0], rhs[1]) - def process_rhs(self, qn, connection): + def process_rhs(self, compiler, connection): if self.rhs_is_direct_value(): # rhs should be an iterable of 2 values, we use batch_process_rhs # to prepare/transform those values - return self.batch_process_rhs(qn, connection) + return self.batch_process_rhs(compiler, connection) else: - return super(Range, self).process_rhs(qn, connection) + return super(Range, self).process_rhs(compiler, connection) default_lookups['range'] = Range class DateLookup(BuiltinLookup): - def process_lhs(self, qn, connection, lhs=None): + def process_lhs(self, compiler, connection, lhs=None): from django.db.models import DateTimeField - lhs, params = super(DateLookup, self).process_lhs(qn, connection, lhs) + lhs, params = super(DateLookup, self).process_lhs(compiler, connection, lhs) if isinstance(self.lhs.output_field, DateTimeField): tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname) @@ -413,8 +413,8 @@ default_lookups['second'] = Second class IsNull(BuiltinLookup): lookup_name = 'isnull' - def as_sql(self, qn, connection): - sql, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + sql, params = compiler.compile(self.lhs) if self.rhs: return "%s IS NULL" % sql, params else: @@ -425,9 +425,9 @@ default_lookups['isnull'] = IsNull class Search(BuiltinLookup): lookup_name = 'search' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) sql_template = connection.ops.fulltext_search_sql(field_name=lhs) return sql_template, lhs_params + rhs_params @@ -437,12 +437,12 @@ default_lookups['search'] = Search class Regex(BuiltinLookup): lookup_name = 'regex' - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): if self.lookup_name in connection.operators: - return super(Regex, self).as_sql(qn, connection) + return super(Regex, self).as_sql(compiler, connection) else: - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) sql_template = connection.ops.regex_lookup(self.lookup_name) return sql_template % (lhs, rhs), lhs_params + rhs_params default_lookups['regex'] = Regex diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 59cd453722..6dbeb855d1 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -29,7 +29,7 @@ class QueryWrapper(object): def __init__(self, sql, params): self.data = sql, list(params) - def as_sql(self, qn=None, connection=None): + def as_sql(self, compiler=None, connection=None): return self.data diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 6ebf5fb966..713f530a91 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -88,16 +88,16 @@ class Aggregate(RegisterLookupMixin): clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) return clone - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): "Return the aggregate, rendered as SQL with parameters." params = [] if hasattr(self.col, 'as_sql'): - field_name, params = self.col.as_sql(qn, connection) + field_name, params = self.col.as_sql(compiler, connection) elif isinstance(self.col, (list, tuple)): - field_name = '.'.join(qn(c) for c in self.col) + field_name = '.'.join(compiler(c) for c in self.col) else: - field_name = qn(self.col) + field_name = compiler(self.col) substitutions = { 'function': self.sql_function, diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 5f425a7543..4f75702dc7 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1098,13 +1098,13 @@ class SQLUpdateCompiler(SQLCompiler): class SQLAggregateCompiler(SQLCompiler): - def as_sql(self, qn=None): + def as_sql(self, compiler=None): """ Creates the SQL for this query. Returns the SQL string and list of parameters. """ - if qn is None: - qn = self + if compiler is None: + compiler = self sql, params = [], [] for annotation in self.query.annotation_select.values(): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 13815cb68c..ea53f71f86 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -81,7 +81,7 @@ class WhereNode(tree.Node): value = obj.prepare(lookup_type, value) return (obj, lookup_type, value_annotation, value) - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): """ Returns the SQL version of the where clause and the value to be substituted in. Returns '', [] if this node matches everything, @@ -102,10 +102,10 @@ class WhereNode(tree.Node): for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = qn.compile(child) + sql, params = compiler.compile(child) else: # A leaf node in the tree. - sql, params = self.make_atom(child, qn, connection) + sql, params = self.make_atom(child, compiler, connection) except EmptyResultSet: nothing_childs += 1 else: @@ -165,7 +165,7 @@ class WhereNode(tree.Node): cols.extend(child[3].get_group_by_cols()) return cols - def make_atom(self, child, qn, connection): + def make_atom(self, child, compiler, connection): """ Turn a tuple (Constraint(table_alias, column_name, db_type), lookup_type, value_annotation, params) into valid SQL. @@ -192,16 +192,16 @@ class WhereNode(tree.Node): if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), [] + field_sql, field_params = self.sql_for_columns(lvalue, compiler, connection, field_internal_type), [] else: # A smart object with an as_sql() method. - field_sql, field_params = qn.compile(lvalue) + field_sql, field_params = compiler.compile(lvalue) is_datetime_field = value_annotation is datetime.datetime cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' if hasattr(params, 'as_sql'): - extra, params = qn.compile(params) + extra, params = compiler.compile(params) cast_sql = '' else: extra = '' @@ -314,7 +314,7 @@ class EmptyWhere(WhereNode): def add(self, data, connector): return - def as_sql(self, qn=None, connection=None): + def as_sql(self, compiler=None, connection=None): raise EmptyResultSet @@ -323,7 +323,7 @@ class EverythingNode(object): A node that matches everything. """ - def as_sql(self, qn=None, connection=None): + def as_sql(self, compiler=None, connection=None): return '', [] @@ -331,7 +331,7 @@ class NothingNode(object): """ A node that matches nothing. """ - def as_sql(self, qn=None, connection=None): + def as_sql(self, compiler=None, connection=None): raise EmptyResultSet @@ -340,7 +340,7 @@ class ExtraWhere(object): self.sqls = sqls self.params = params - def as_sql(self, qn=None, connection=None): + def as_sql(self, compiler=None, connection=None): sqls = ["(%s)" % sql for sql in self.sqls] return " AND ".join(sqls), list(self.params or ()) @@ -402,7 +402,7 @@ class SubqueryConstraint(object): self.targets = targets self.query_object = query_object - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): query = self.query_object # QuerySet was sent @@ -420,7 +420,7 @@ class SubqueryConstraint(object): query.clear_ordering(True) query_compiler = query.get_compiler(connection=connection) - return query_compiler.as_subquery_condition(self.alias, self.columns, qn) + return query_compiler.as_subquery_condition(self.alias, self.columns, compiler) def relabel_aliases(self, change_map): self.alias = change_map.get(self.alias, self.alias) diff --git a/docs/howto/custom-lookups.txt b/docs/howto/custom-lookups.txt index da536c4070..0fe678f9fe 100644 --- a/docs/howto/custom-lookups.txt +++ b/docs/howto/custom-lookups.txt @@ -32,9 +32,9 @@ straightforward:: class NotEqual(Lookup): lookup_name = 'ne' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params return '%s <> %s' % (lhs, rhs), params @@ -70,12 +70,12 @@ lowercase strings containing only letters, but the only hard requirement is that it must not contain the string ``__``. We then need to define the ``as_sql`` method. This takes a ``SQLCompiler`` -object, called ``qn``, and the active database connection. ``SQLCompiler`` -objects are not documented, but the only thing we need to know about them is -that they have a ``compile()`` method which returns a tuple containing a SQL -string, and the parameters to be interpolated into that string. In most cases, -you don't need to use it directly and can pass it on to ``process_lhs()`` and -``process_rhs()``. +object, called ``compiler``, and the active database connection. +``SQLCompiler`` objects are not documented, but the only thing we need to know +about them is that they have a ``compile()`` method which returns a tuple +containing a SQL string, and the parameters to be interpolated into that +string. In most cases, you don't need to use it directly and can pass it on to +``process_lhs()`` and ``process_rhs()``. A ``Lookup`` works against two values, ``lhs`` and ``rhs``, standing for left-hand side and right-hand side. The left-hand side is usually a field @@ -86,13 +86,13 @@ reference to the ``name`` field of the ``Author`` model, and ``'Jack'`` is the right-hand side. We call ``process_lhs`` and ``process_rhs`` to convert them into the values we -need for SQL using the ``qn`` object described before. These methods return -tuples containing some SQL and the parameters to be interpolated into that SQL, -just as we need to return from our ``as_sql`` method. In the above example, -``process_lhs`` returns ``('"author"."name"', [])`` and ``process_rhs`` returns -``('"%s"', ['Jack'])``. In this example there were no parameters for the left -hand side, but this would depend on the object we have, so we still need to -include them in the parameters we return. +need for SQL using the ``compiler`` object described before. These methods +return tuples containing some SQL and the parameters to be interpolated into +that SQL, just as we need to return from our ``as_sql`` method. In the above +example, ``process_lhs`` returns ``('"author"."name"', [])`` and +``process_rhs`` returns ``('"%s"', ['Jack'])``. In this example there were no +parameters for the left hand side, but this would depend on the object we have, +so we still need to include them in the parameters we return. Finally we combine the parts into a SQL expression with ``<>``, and supply all the parameters for the query. We then return a tuple containing the generated @@ -123,8 +123,8 @@ function ``ABS()`` to transform the value before comparison:: class AbsoluteValue(Transform): lookup_name = 'abs' - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return "ABS(%s)" % lhs, params Next, let's register it for ``IntegerField``:: @@ -160,8 +160,8 @@ be done by adding an ``output_field`` attribute to the transform:: class AbsoluteValue(Transform): lookup_name = 'abs' - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return "ABS(%s)" % lhs, params @property @@ -191,9 +191,9 @@ The implementation is:: class AbsoluteValueLessThan(Lookup): lookup_name = 'lt' - def as_sql(self, qn, connection): - lhs, lhs_params = qn.compile(self.lhs.lhs) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = compiler.compile(self.lhs.lhs) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params + lhs_params + rhs_params return '%s < %s AND %s > -%s' % (lhs, rhs, lhs, rhs), params @@ -247,8 +247,8 @@ this transformation should apply to both ``lhs`` and ``rhs``:: lookup_name = 'upper' bilateral = True - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return "UPPER(%s)" % lhs, params Next, let's register it:: @@ -275,9 +275,9 @@ We can change the behavior on a specific backend by creating a subclass of ``NotEqual`` with a ``as_mysql`` method:: class MySQLNotEqual(NotEqual): - def as_mysql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_mysql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params return '%s != %s' % (lhs, rhs), params Field.register_lookup(MySQLNotExact) diff --git a/docs/ref/models/lookups.txt b/docs/ref/models/lookups.txt index da338b7cb2..23980eddc5 100644 --- a/docs/ref/models/lookups.txt +++ b/docs/ref/models/lookups.txt @@ -80,24 +80,24 @@ field references, aggregates, and ``Transform`` are examples that follow this API. A class is said to follow the query expression API when it implements the following methods: -.. method:: as_sql(self, qn, connection) +.. method:: as_sql(self, compiler, connection) Responsible for producing the query string and parameters for the expression. - The ``qn`` is an ``SQLCompiler`` object, which has a ``compile()`` method - that can be used to compile other expressions. The ``connection`` is the - connection used to execute the query. + The ``compiler`` is an ``SQLCompiler`` object, which has a ``compile()`` + method that can be used to compile other expressions. The ``connection`` is + the connection used to execute the query. Calling ``expression.as_sql()`` is usually incorrect - instead - ``qn.compile(expression)`` should be used. The ``qn.compile()`` method will - take care of calling vendor-specific methods of the expression. + ``compiler.compile(expression)`` should be used. The ``compiler.compile()`` + method will take care of calling vendor-specific methods of the expression. -.. method:: as_vendorname(self, qn, connection) +.. method:: as_vendorname(self, compiler, connection) Works like ``as_sql()`` method. When an expression is compiled by - ``qn.compile()``, Django will first try to call ``as_vendorname()``, where - ``vendorname`` is the vendor name of the backend used for executing the - query. The ``vendorname`` is one of ``postgresql``, ``oracle``, ``sqlite``, - or ``mysql`` for Django's built-in backends. + ``compiler.compile()``, Django will first try to call ``as_vendorname()``, + where ``vendorname`` is the vendor name of the backend used for executing + the query. The ``vendorname`` is one of ``postgresql``, ``oracle``, + ``sqlite``, or ``mysql`` for Django's built-in backends. .. method:: get_lookup(lookup_name) @@ -200,17 +200,17 @@ Lookup reference The name of this lookup, used to identify it on parsing query expressions. It cannot contain the string ``"__"``. - .. method:: process_lhs(qn, connection[, lhs=None]) + .. method:: process_lhs(compiler, connection[, lhs=None]) Returns a tuple ``(lhs_string, lhs_params)``, as returned by - ``qn.compile(lhs)``. This method can be overridden to tune how the - ``lhs`` is processed. + ``compiler.compile(lhs)``. This method can be overridden to tune how + the ``lhs`` is processed. - ``qn`` is an ``SQLCompiler`` object, to be used like ``qn.compile(lhs)`` - for compiling ``lhs``. The ``connection`` can be used for compiling - vendor specific SQL. If ``lhs`` is not ``None``, use it as the - processed ``lhs`` instead of ``self.lhs``. + ``compiler`` is an ``SQLCompiler`` object, to be used like + ``compiler.compile(lhs)`` for compiling ``lhs``. The ``connection`` + can be used for compiling vendor specific SQL. If ``lhs`` is not + ``None``, use it as the processed ``lhs`` instead of ``self.lhs``. - .. method:: process_rhs(qn, connection) + .. method:: process_rhs(compiler, connection) Behaves the same way as :meth:`process_lhs`, for the right-hand side. diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 8c6529c73b..e4b821b43d 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -863,8 +863,8 @@ class ComplexAggregateTestCase(TestCase): def test_add_implementation(self): try: # test completely changing how the output is rendered - def lower_case_function_override(self, qn, connection): - sql, params = qn.compile(self.source_expressions[0]) + def lower_case_function_override(self, compiler, connection): + sql, params = compiler.compile(self.source_expressions[0]) substitutions = dict(function=self.function.lower(), expressions=sql) substitutions.update(self.extra) return self.template % substitutions, params @@ -877,9 +877,9 @@ class ComplexAggregateTestCase(TestCase): self.assertEqual(b1.sums, 383) # test changing the dict and delegating - def lower_case_function_super(self, qn, connection): + def lower_case_function_super(self, compiler, connection): self.extra['function'] = self.function.lower() - return super(Sum, self).as_sql(qn, connection) + return super(Sum, self).as_sql(compiler, connection) setattr(Sum, 'as_' + connection.vendor, lower_case_function_super) qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), @@ -889,7 +889,7 @@ class ComplexAggregateTestCase(TestCase): self.assertEqual(b1.sums, 383) # test overriding all parts of the template - def be_evil(self, qn, connection): + def be_evil(self, compiler, connection): substitutions = dict(function='MAX', expressions='2') substitutions.update(self.extra) return self.template % substitutions, () @@ -921,8 +921,8 @@ class ComplexAggregateTestCase(TestCase): class Greatest(Func): function = 'GREATEST' - def as_sqlite(self, qn, connection): - return super(Greatest, self).as_sql(qn, connection, function='MAX') + def as_sqlite(self, compiler, connection): + return super(Greatest, self).as_sql(compiler, connection, function='MAX') qs = Publisher.objects.annotate( price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py index 872427e883..620db613a1 100644 --- a/tests/custom_lookups/tests.py +++ b/tests/custom_lookups/tests.py @@ -14,15 +14,15 @@ from .models import Author, MySQLUnixTimestamp class Div3Lookup(models.Lookup): lookup_name = 'div3' - def as_sql(self, qn, connection): - lhs, params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params.extend(rhs_params) return '(%s) %%%% 3 = %s' % (lhs, rhs), params - def as_oracle(self, qn, connection): - lhs, params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_oracle(self, compiler, connection): + lhs, params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) params.extend(rhs_params) return 'mod(%s, 3) = %s' % (lhs, rhs), params @@ -30,12 +30,12 @@ class Div3Lookup(models.Lookup): class Div3Transform(models.Transform): lookup_name = 'div3' - def as_sql(self, qn, connection): - lhs, lhs_params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, lhs_params = compiler.compile(self.lhs) return '(%s) %%%% 3' % lhs, lhs_params - def as_oracle(self, qn, connection): - lhs, lhs_params = qn.compile(self.lhs) + def as_oracle(self, compiler, connection): + lhs, lhs_params = compiler.compile(self.lhs) return 'mod(%s, 3)' % lhs, lhs_params @@ -47,8 +47,8 @@ class Mult3BilateralTransform(models.Transform): bilateral = True lookup_name = 'mult3' - def as_sql(self, qn, connection): - lhs, lhs_params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, lhs_params = compiler.compile(self.lhs) return '3 * (%s)' % lhs, lhs_params @@ -56,16 +56,16 @@ class UpperBilateralTransform(models.Transform): bilateral = True lookup_name = 'upper' - def as_sql(self, qn, connection): - lhs, lhs_params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, lhs_params = compiler.compile(self.lhs) return 'UPPER(%s)' % lhs, lhs_params class YearTransform(models.Transform): lookup_name = 'year' - def as_sql(self, qn, connection): - lhs_sql, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs_sql, params = compiler.compile(self.lhs) return connection.ops.date_extract_sql('year', lhs_sql), params @property @@ -77,11 +77,11 @@ class YearTransform(models.Transform): class YearExact(models.lookups.Lookup): lookup_name = 'exact' - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): # We will need to skip the extract part, and instead go # directly with the originating field, that is self.lhs.lhs - lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) - rhs_sql, rhs_params = self.process_rhs(qn, connection) + lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) # Note that we must be careful so that we have params in the # same order as we have the parts in the SQL. params = lhs_params + rhs_params + lhs_params + rhs_params @@ -98,12 +98,12 @@ class YearLte(models.lookups.LessThanOrEqual): The purpose of this lookup is to efficiently compare the year of the field. """ - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): # Skip the YearTransform above us (no possibility for efficient # lookup otherwise). real_lhs = self.lhs.lhs - lhs_sql, params = self.process_lhs(qn, connection, real_lhs) - rhs_sql, rhs_params = self.process_rhs(qn, connection) + lhs_sql, params = self.process_lhs(compiler, connection, real_lhs) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) params.extend(rhs_params) # Build SQL where the integer year is concatenated with last month # and day, then convert that to date. (We try to have SQL like: @@ -117,7 +117,7 @@ class SQLFunc(models.Lookup): super(SQLFunc, self).__init__(*args, **kwargs) self.name = name - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): return '%s()', [self.name] @property @@ -162,9 +162,9 @@ class InMonth(models.lookups.Lookup): """ lookup_name = 'inmonth' - def as_sql(self, qn, connection): - lhs, lhs_params = self.process_lhs(qn, connection) - rhs, rhs_params = self.process_rhs(qn, connection) + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) # We need to be careful so that we get the params in right # places. params = lhs_params + rhs_params + lhs_params + rhs_params @@ -180,8 +180,8 @@ class DateTimeTransform(models.Transform): def output_field(self): return models.DateTimeField() - def as_sql(self, qn, connection): - lhs, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) return 'from_unixtime({})'.format(lhs), params @@ -448,9 +448,9 @@ class YearLteTests(TestCase): try: # Two ways to add a customized implementation for different backends: # First is MonkeyPatch of the class. - def as_custom_sql(self, qn, connection): - lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) - rhs_sql, rhs_params = self.process_rhs(qn, connection) + def as_custom_sql(self, compiler, connection): + lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params + lhs_params + rhs_params return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % @@ -468,9 +468,9 @@ class YearLteTests(TestCase): # This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres # and so on, but as we don't know which DB we are running on, we need to use # setattr. - def as_custom_sql(self, qn, connection): - lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) - rhs_sql, rhs_params = self.process_rhs(qn, connection) + def as_custom_sql(self, compiler, connection): + lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(compiler, connection) params = lhs_params + rhs_params + lhs_params + rhs_params return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % @@ -489,8 +489,8 @@ class TrackCallsYearTransform(YearTransform): lookup_name = 'year' call_order = [] - def as_sql(self, qn, connection): - lhs_sql, params = qn.compile(self.lhs) + def as_sql(self, compiler, connection): + lhs_sql, params = compiler.compile(self.lhs) return connection.ops.date_extract_sql('year', lhs_sql), params @property diff --git a/tests/foreign_object/models.py b/tests/foreign_object/models.py index fc51118149..07d9ff4450 100644 --- a/tests/foreign_object/models.py +++ b/tests/foreign_object/models.py @@ -117,7 +117,8 @@ class ColConstraint(object): def __init__(self, alias, col, value): self.alias, self.col, self.value = alias, col, value - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): + qn = compiler.quote_name_unless_alias return '%s.%s = %%s' % (qn(self.alias), qn(self.col)), [self.value] diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 5ee7c85005..7b4766519a 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -2817,7 +2817,7 @@ class ProxyQueryCleanupTest(TestCase): class WhereNodeTest(TestCase): class DummyNode(object): - def as_sql(self, qn, connection): + def as_sql(self, compiler, connection): return 'dummy', [] class MockCompiler(object): @@ -2828,70 +2828,70 @@ class WhereNodeTest(TestCase): return connection.ops.quote_name(name) def test_empty_full_handling_conjunction(self): - qn = WhereNodeTest.MockCompiler() + compiler = WhereNodeTest.MockCompiler() w = WhereNode(children=[EverythingNode()]) - self.assertEqual(w.as_sql(qn, connection), ('', [])) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) w.negate() - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) w = WhereNode(children=[NothingNode()]) - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) w.negate() - self.assertEqual(w.as_sql(qn, connection), ('', [])) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) w = WhereNode(children=[EverythingNode(), EverythingNode()]) - self.assertEqual(w.as_sql(qn, connection), ('', [])) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) w.negate() - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) w = WhereNode(children=[EverythingNode(), self.DummyNode()]) - self.assertEqual(w.as_sql(qn, connection), ('dummy', [])) + self.assertEqual(w.as_sql(compiler, connection), ('dummy', [])) w = WhereNode(children=[self.DummyNode(), self.DummyNode()]) - self.assertEqual(w.as_sql(qn, connection), ('(dummy AND dummy)', [])) + self.assertEqual(w.as_sql(compiler, connection), ('(dummy AND dummy)', [])) w.negate() - self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy AND dummy)', [])) + self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy AND dummy)', [])) w = WhereNode(children=[NothingNode(), self.DummyNode()]) - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) w.negate() - self.assertEqual(w.as_sql(qn, connection), ('', [])) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) def test_empty_full_handling_disjunction(self): - qn = WhereNodeTest.MockCompiler() + compiler = WhereNodeTest.MockCompiler() w = WhereNode(children=[EverythingNode()], connector='OR') - self.assertEqual(w.as_sql(qn, connection), ('', [])) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) w.negate() - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) w = WhereNode(children=[NothingNode()], connector='OR') - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) w.negate() - self.assertEqual(w.as_sql(qn, connection), ('', [])) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) w = WhereNode(children=[EverythingNode(), EverythingNode()], connector='OR') - self.assertEqual(w.as_sql(qn, connection), ('', [])) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) w.negate() - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) w = WhereNode(children=[EverythingNode(), self.DummyNode()], connector='OR') - self.assertEqual(w.as_sql(qn, connection), ('', [])) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) w.negate() - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR') - self.assertEqual(w.as_sql(qn, connection), ('(dummy OR dummy)', [])) + self.assertEqual(w.as_sql(compiler, connection), ('(dummy OR dummy)', [])) w.negate() - self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy OR dummy)', [])) + self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy OR dummy)', [])) w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR') - self.assertEqual(w.as_sql(qn, connection), ('dummy', [])) + self.assertEqual(w.as_sql(compiler, connection), ('dummy', [])) w.negate() - self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', [])) + self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy)', [])) def test_empty_nodes(self): - qn = WhereNodeTest.MockCompiler() + compiler = WhereNodeTest.MockCompiler() empty_w = WhereNode() w = WhereNode(children=[empty_w, empty_w]) - self.assertEqual(w.as_sql(qn, connection), (None, [])) + self.assertEqual(w.as_sql(compiler, connection), (None, [])) w.negate() - self.assertEqual(w.as_sql(qn, connection), (None, [])) + self.assertEqual(w.as_sql(compiler, connection), (None, [])) w.connector = 'OR' - self.assertEqual(w.as_sql(qn, connection), (None, [])) + self.assertEqual(w.as_sql(compiler, connection), (None, [])) w.negate() - self.assertEqual(w.as_sql(qn, connection), (None, [])) + self.assertEqual(w.as_sql(compiler, connection), (None, [])) w = WhereNode(children=[empty_w, NothingNode()], connector='OR') - self.assertRaises(EmptyResultSet, w.as_sql, qn, connection) + self.assertRaises(EmptyResultSet, w.as_sql, compiler, connection) class IteratorExceptionsTest(TestCase):