From 4d219d4cdef21d9c14e5d6b9299d583d1975fcba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Wed, 27 Nov 2013 22:07:30 +0200 Subject: [PATCH] Initial implementation of custom lookups --- django/db/backends/__init__.py | 3 + django/db/backends/utils.py | 24 +++ django/db/models/aggregates.py | 4 +- django/db/models/fields/__init__.py | 32 +++- django/db/models/fields/related.py | 5 + django/db/models/lookups.py | 242 +++++++++++++++++++++++++ django/db/models/sql/aggregates.py | 7 + django/db/models/sql/compiler.py | 59 +++--- django/db/models/sql/datastructures.py | 21 ++- django/db/models/sql/query.py | 126 +++++++------ django/db/models/sql/where.py | 10 +- tests/aggregation/tests.py | 2 +- tests/custom_lookups/__init__.py | 0 tests/custom_lookups/models.py | 7 + tests/custom_lookups/tests.py | 136 ++++++++++++++ tests/queries/tests.py | 13 +- 16 files changed, 594 insertions(+), 97 deletions(-) create mode 100644 django/db/models/lookups.py create mode 100644 tests/custom_lookups/__init__.py create mode 100644 tests/custom_lookups/models.py create mode 100644 tests/custom_lookups/tests.py diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index a23751caf51..39a774132ec 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -67,6 +67,9 @@ class BaseDatabaseWrapper(object): self.allow_thread_sharing = allow_thread_sharing self._thread_ident = thread.get_ident() + # Compile implementations, used by compiler.compile(someelem) + self.compile_implementations = utils.get_implementations(self.vendor) + def __eq__(self, other): if isinstance(other, BaseDatabaseWrapper): return self.alias == other.alias diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py index 8bed6415c3f..a1c9f6ec084 100644 --- a/django/db/backends/utils.py +++ b/django/db/backends/utils.py @@ -194,3 +194,27 @@ def format_number(value, max_digits, decimal_places): return "{0:f}".format(value.quantize(decimal.Decimal(".1") ** decimal_places, context=context)) else: return "%.*f" % (decimal_places, value) + +# Map of vendor name -> map of query element class -> implementation function +compile_implementations = {} + + +def get_implementations(vendor): + try: + implementation = compile_implementations[vendor] + except KeyError: + # TODO: do we need thread safety here? We could easily use an lock... + implementation = {} + compile_implementations[vendor] = implementation + return implementation + + +class add_implementation(object): + def __init__(self, klass, vendor): + self.klass = klass + self.vendor = vendor + + def __call__(self, func): + implementations = get_implementations(self.vendor) + implementations[self.klass] = func + return func diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 1ec11b4acbf..d663e40f263 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -17,8 +17,8 @@ def refs_aggregate(lookup_parts, aggregates): """ for i in range(len(lookup_parts) + 1): if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates: - return True - return False + return aggregates[LOOKUP_SEP.join(lookup_parts[0:i])], lookup_parts[i:] + return False, () class Aggregate(object): diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index dec2e2dd9e6..85985ab2e5b 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -4,6 +4,7 @@ import collections import copy import datetime import decimal +import inspect import math import warnings from base64 import b64decode, b64encode @@ -11,6 +12,7 @@ from itertools import tee from django.db import connection from django.db.models.loading import get_model +from django.db.models.lookups import default_lookups from django.db.models.query_utils import QueryWrapper from django.conf import settings from django import forms @@ -101,6 +103,7 @@ class Field(object): 'unique': _('%(model_name)s with this %(field_label)s ' 'already exists.'), } + class_lookups = default_lookups.copy() # Generic field type description, usually overridden by subclasses def _description(self): @@ -446,6 +449,30 @@ class Field(object): def get_internal_type(self): return self.__class__.__name__ + def get_lookup(self, lookup_name): + try: + return self.class_lookups[lookup_name] + except KeyError: + for parent in inspect.getmro(self.__class__): + if not 'class_lookups' in parent.__dict__: + continue + if lookup_name in parent.class_lookups: + return parent.class_lookups[lookup_name] + + @classmethod + def register_lookup(cls, lookup): + if not 'class_lookups' in cls.__dict__: + cls.class_lookups = {} + cls.class_lookups[lookup.lookup_name] = lookup + + @classmethod + def _unregister_lookup(cls, lookup): + """ + Removes given lookup from cls lookups. Meant to be used in + tests only. + """ + del cls.class_lookups[lookup.lookup_name] + def pre_save(self, model_instance, add): """ Returns field's value just before saving. @@ -504,8 +531,7 @@ class Field(object): except ValueError: raise ValueError("The __year lookup type requires an integer " "argument") - - raise TypeError("Field has invalid lookup: %s" % lookup_type) + return self.get_prep_value(value) def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): @@ -554,6 +580,8 @@ class Field(object): return connection.ops.year_lookup_bounds_for_date_field(value) else: return [value] # this isn't supposed to happen + else: + return [value] def has_default(self): """ diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 4a945e0f452..cb0402007e3 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -934,6 +934,11 @@ class ForeignObjectRel(object): # example custom multicolumn joins currently have no remote field). self.field_name = None + def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, + raw_value): + return self.field.get_lookup_constraint(constraint_class, alias, targets, sources, + lookup_type, raw_value) + class ManyToOneRel(ForeignObjectRel): def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None, diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py new file mode 100644 index 00000000000..85acb24a730 --- /dev/null +++ b/django/db/models/lookups.py @@ -0,0 +1,242 @@ +from copy import copy + +from django.conf import settings +from django.utils import timezone + + +class Lookup(object): + def __init__(self, constraint_class, lhs, rhs): + self.constraint_class, self.lhs, self.rhs = constraint_class, lhs, rhs + self.rhs = self.get_prep_lookup() + + def get_db_prep_lookup(self, value, connection): + return ( + '%s', self.lhs.output_type.get_db_prep_lookup( + self.lookup_name, value, connection, prepared=True)) + + def get_prep_lookup(self): + return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) + + def process_lhs(self, qn, connection): + return qn.compile(self.lhs) + + def process_rhs(self, qn, connection): + value = self.rhs + # Due to historical reasons there are a couple of different + # ways to produce sql here. get_compiler is likely a Query + # instance, _as_sql QuerySet and as_sql just something with + # as_sql. Finally the value can of course be just plain + # Python value. + if hasattr(value, 'get_compiler'): + value = value.get_compiler(connection=connection) + if hasattr(value, 'as_sql'): + sql, params = qn.compile(value) + return '(' + sql + ')', params + if hasattr(value, '_as_sql'): + sql, params = value._as_sql(connection=connection) + return '(' + sql + ')', params + else: + return self.get_db_prep_lookup(value, connection) + + def relabeled_clone(self, relabels): + new = copy(self) + new.lhs = new.lhs.relabeled_clone(relabels) + if hasattr(new.rhs, 'relabeled_clone'): + new.rhs = new.rhs.relabeled_clone(relabels) + return new + + def get_cols(self): + cols = self.lhs.get_cols() + if hasattr(self.rhs, 'get_cols'): + cols.extend(self.rhs.get_cols()) + return cols + + def as_sql(self, qn, connection): + raise NotImplementedError + + +class DjangoLookup(Lookup): + def as_sql(self, qn, connection): + lhs_sql, params = self.process_lhs(qn, connection) + field_internal_type = self.lhs.output_type.get_internal_type() + db_type = self.lhs.output_type + lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql + lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params.extend(rhs_params) + operator_plus_rhs = self.get_rhs_op(connection, rhs_sql) + return '%s %s' % (lhs_sql, operator_plus_rhs), params + + def get_rhs_op(self, connection, rhs): + return connection.operators[self.lookup_name] % rhs + + +default_lookups = {} + + +class Exact(DjangoLookup): + lookup_name = 'exact' +default_lookups['exact'] = Exact + + +class IExact(DjangoLookup): + lookup_name = 'iexact' +default_lookups['iexact'] = IExact + + +class Contains(DjangoLookup): + lookup_name = 'contains' +default_lookups['contains'] = Contains + + +class IContains(DjangoLookup): + lookup_name = 'icontains' +default_lookups['icontains'] = IContains + + +class GreaterThan(DjangoLookup): + lookup_name = 'gt' +default_lookups['gt'] = GreaterThan + + +class GreaterThanOrEqual(DjangoLookup): + lookup_name = 'gte' +default_lookups['gte'] = GreaterThanOrEqual + + +class LessThan(DjangoLookup): + lookup_name = 'lt' +default_lookups['lt'] = LessThan + + +class LessThanOrEqual(DjangoLookup): + lookup_name = 'lte' +default_lookups['lte'] = LessThanOrEqual + + +class In(DjangoLookup): + lookup_name = 'in' + + def get_db_prep_lookup(self, value, connection): + params = self.lhs.field.get_db_prep_lookup( + self.lookup_name, value, connection, prepared=True) + if not params: + # TODO: check why this leads to circular import + from django.db.models.sql.datastructures import EmptyResultSet + raise EmptyResultSet + placeholder = '(' + ', '.join('%s' for p in params) + ')' + return (placeholder, params) + + def get_rhs_op(self, connection, rhs): + return 'IN %s' % rhs +default_lookups['in'] = In + + +class StartsWith(DjangoLookup): + lookup_name = 'startswith' +default_lookups['startswith'] = StartsWith + + +class IStartsWith(DjangoLookup): + lookup_name = 'istartswith' +default_lookups['istartswith'] = IStartsWith + + +class EndsWith(DjangoLookup): + lookup_name = 'endswith' +default_lookups['endswith'] = EndsWith + + +class IEndsWith(DjangoLookup): + lookup_name = 'iendswith' +default_lookups['iendswith'] = IEndsWith + + +class Between(DjangoLookup): + def get_rhs_op(self, connection, rhs): + return "BETWEEN %s AND %s" % (rhs, rhs) + + +class Year(Between): + lookup_name = 'year' +default_lookups['year'] = Year + + +class Range(Between): + lookup_name = 'range' +default_lookups['range'] = Range + + +class DateLookup(DjangoLookup): + + def process_lhs(self, qn, connection): + lhs, params = super(DateLookup, self).process_lhs(qn, connection) + tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None + sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname) + return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params + + def get_rhs_op(self, connection, rhs): + return '= %s' % rhs + + +class Month(DateLookup): + lookup_name = 'month' + extract_type = 'month' +default_lookups['month'] = Month + + +class Day(DateLookup): + lookup_name = 'day' + extract_type = 'day' +default_lookups['day'] = Day + + +class WeekDay(DateLookup): + lookup_name = 'week_day' + extract_type = 'week_day' +default_lookups['week_day'] = WeekDay + + +class Hour(DateLookup): + lookup_name = 'hour' + extract_type = 'hour' +default_lookups['hour'] = Hour + + +class Minute(DateLookup): + lookup_name = 'minute' + extract_type = 'minute' +default_lookups['minute'] = Minute + + +class Second(DateLookup): + lookup_name = 'second' + extract_type = 'second' +default_lookups['second'] = Second + + +class IsNull(DjangoLookup): + lookup_name = 'isnull' + + def as_sql(self, qn, connection): + sql, params = qn.compile(self.lhs) + if self.rhs: + return "%s IS NULL" % sql, params + else: + return "%s IS NOT NULL" % sql, params +default_lookups['isnull'] = IsNull + + +class Search(DjangoLookup): + lookup_name = 'search' +default_lookups['search'] = Search + + +class Regex(DjangoLookup): + lookup_name = 'regex' +default_lookups['regex'] = Regex + + +class IRegex(DjangoLookup): + lookup_name = 'iregex' +default_lookups['iregex'] = IRegex diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 8542a330c64..dcf04d4b78b 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -93,6 +93,13 @@ class Aggregate(object): return self.sql_template % substitutions, params + def get_cols(self): + return [] + + @property + def output_type(self): + return self.field + class Avg(Aggregate): is_computed = True diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 98b953937f2..d8b9c0ccb93 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -45,7 +45,7 @@ class SQLCompiler(object): if self.query.select_related and not self.query.related_select_cols: self.fill_related_selections() - def quote_name_unless_alias(self, name): + def __call__(self, name): """ A wrapper around connection.ops.quote_name that doesn't quote aliases for table names. This avoids problems with some SQL dialects that treat @@ -61,6 +61,20 @@ class SQLCompiler(object): self.quote_cache[name] = r return r + def quote_name_unless_alias(self, name): + """ + A wrapper around connection.ops.quote_name that doesn't quote aliases + for table names. This avoids problems with some SQL dialects that treat + quoted strings specially (e.g. PostgreSQL). + """ + return self(name) + + def compile(self, node): + if node.__class__ in self.connection.compile_implementations: + return self.connection.compile_implementations[node.__class__](node, self) + else: + return node.as_sql(self, self.connection) + def as_sql(self, with_limits=True, with_col_aliases=False): """ Creates the SQL for this query. Returns the SQL string and list of @@ -88,10 +102,8 @@ class SQLCompiler(object): # docstring of get_from_clause() for details. from_, f_params = self.get_from_clause() - qn = self.quote_name_unless_alias - - where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) - having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) + where, w_params = self.compile(self.query.where) + having, h_params = self.compile(self.query.having) having_group_by = self.query.having.get_cols() params = [] for val in six.itervalues(self.query.extra_select): @@ -180,7 +192,7 @@ class SQLCompiler(object): (without the table names) are given unique aliases. This is needed in some cases to avoid ambiguity with nested queries. """ - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] params = [] @@ -213,7 +225,7 @@ class SQLCompiler(object): aliases.add(r) col_aliases.add(col[1]) else: - col_sql, col_params = col.as_sql(qn, self.connection) + col_sql, col_params = self.compile(col) result.append(col_sql) params.extend(col_params) @@ -229,7 +241,7 @@ class SQLCompiler(object): max_name_length = self.connection.ops.max_name_length() for alias, aggregate in self.query.aggregate_select.items(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + agg_sql, agg_params = self.compile(aggregate) if alias is None: result.append(agg_sql) else: @@ -267,7 +279,7 @@ class SQLCompiler(object): result = [] if opts is None: opts = self.query.get_meta() - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name aliases = set() only_load = self.deferred_to_columns() @@ -312,7 +324,7 @@ class SQLCompiler(object): Note that this method can alter the tables in the query, and thus it must be called before get_from_clause(). """ - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name result = [] opts = self.query.get_meta() @@ -345,7 +357,7 @@ class SQLCompiler(object): ordering = (self.query.order_by or self.query.get_meta().ordering or []) - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name distinct = self.query.distinct select_aliases = self._select_aliases @@ -483,7 +495,7 @@ class SQLCompiler(object): ordering and distinct must be done first. """ result = [] - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name first = True from_params = [] @@ -501,8 +513,7 @@ class SQLCompiler(object): extra_cond = join_field.get_extra_restriction( self.query.where_class, alias, lhs) if extra_cond: - extra_sql, extra_params = extra_cond.as_sql( - qn, self.connection) + extra_sql, extra_params = self.compile(extra_cond) extra_sql = 'AND (%s)' % extra_sql from_params.extend(extra_params) else: @@ -534,7 +545,7 @@ class SQLCompiler(object): """ Returns a tuple representing the SQL elements in the "group by" clause. """ - qn = self.quote_name_unless_alias + qn = self result, params = [], [] if self.query.group_by is not None: select_cols = self.query.select + self.query.related_select_cols @@ -553,7 +564,7 @@ class SQLCompiler(object): if isinstance(col, (list, tuple)): sql = '%s.%s' % (qn(col[0]), qn(col[1])) elif hasattr(col, 'as_sql'): - sql, col_params = col.as_sql(qn, self.connection) + self.compile(col) else: sql = '(%s)' % str(col) if sql not in seen: @@ -776,7 +787,7 @@ class SQLCompiler(object): return result def as_subquery_condition(self, alias, columns, qn): - inner_qn = self.quote_name_unless_alias + inner_qn = self qn2 = self.connection.ops.quote_name if len(columns) == 1: sql, params = self.as_sql() @@ -887,9 +898,9 @@ class SQLDeleteCompiler(SQLCompiler): """ assert len(self.query.tables) == 1, \ "Can only delete from one table at a time." - qn = self.quote_name_unless_alias + qn = self result = ['DELETE FROM %s' % qn(self.query.tables[0])] - where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + where, params = self.compile(self.query.where) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(params) @@ -905,7 +916,7 @@ class SQLUpdateCompiler(SQLCompiler): if not self.query.values: return '', () table = self.query.tables[0] - qn = self.quote_name_unless_alias + qn = self result = ['UPDATE %s' % qn(table)] result.append('SET') values, update_params = [], [] @@ -925,7 +936,7 @@ class SQLUpdateCompiler(SQLCompiler): val = SQLEvaluator(val, self.query, allow_joins=False) name = field.column if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn, self.connection) + sql, params = self.compile(val) values.append('%s = %s' % (qn(name), sql)) update_params.extend(params) elif val is not None: @@ -936,7 +947,7 @@ class SQLUpdateCompiler(SQLCompiler): if not values: return '', () result.append(', '.join(values)) - where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + where, params = self.compile(self.query.where) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(update_params + params) @@ -1016,11 +1027,11 @@ class SQLAggregateCompiler(SQLCompiler): parameters. """ if qn is None: - qn = self.quote_name_unless_alias + qn = self sql, params = [], [] for aggregate in self.query.aggregate_select.values(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + agg_sql, agg_params = self.compile(aggregate) sql.append(agg_sql) params.extend(agg_params) sql = ', '.join(sql) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index f45ecaf76d0..dd273f51c7c 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -5,18 +5,25 @@ the SQL domain. class Col(object): - def __init__(self, alias, col): - self.alias = alias - self.col = col + def __init__(self, alias, target, source): + self.alias, self.target, self.source = alias, target, source def as_sql(self, qn, connection): - return '%s.%s' % (qn(self.alias), self.col), [] + return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] - def prepare(self): - return self + @property + def output_type(self): + return self.source + + @property + def field(self): + return self.source def relabeled_clone(self, relabels): - return self.__class__(relabels.get(self.alias, self.alias), self.col) + return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source) + + def get_cols(self): + return [(self.alias, self.target.column)] class EmptyResultSet(Exception): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index e0593ef0a73..a7fafd88b54 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1030,6 +1030,12 @@ class Query(object): def prepare_lookup_value(self, value, lookup_type, can_reuse): # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # uses of None as a query value. + if len(lookup_type) > 1: + raise FieldError('Nested lookups not allowed') + elif len(lookup_type) == 0: + lookup_type = 'exact' + else: + lookup_type = lookup_type[0] if value is None: if lookup_type != 'exact': raise ValueError("Cannot use None as a query value") @@ -1060,31 +1066,39 @@ class Query(object): """ Solve the lookup type from the lookup (eg: 'foobar__id__icontains') """ - lookup_type = 'exact' # Default lookup type - lookup_parts = lookup.split(LOOKUP_SEP) - num_parts = len(lookup_parts) - if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms - and (not self._aggregates or lookup not in self._aggregates)): - # Traverse the lookup query to distinguish related fields from - # lookup types. - lookup_model = self.model - for counter, field_name in enumerate(lookup_parts): - try: - lookup_field = lookup_model._meta.get_field(field_name) - except FieldDoesNotExist: - # Not a field. Bail out. - lookup_type = lookup_parts.pop() - break - # Unless we're at the end of the list of lookups, let's attempt - # to continue traversing relations. - if (counter + 1) < num_parts: - try: - lookup_model = lookup_field.rel.to - except AttributeError: - # Not a related field. Bail out. - lookup_type = lookup_parts.pop() - break - return lookup_type, lookup_parts + lookup_splitted = lookup.split(LOOKUP_SEP) + aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) + if aggregate: + if len(aggregate_lookups) > 1: + raise FieldError("Nested lookups not allowed.") + return aggregate_lookups, (), aggregate + _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) + field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] + if len(lookup_parts) == 0: + lookup_parts = ['exact'] + elif len(lookup_parts) > 1: + if field_parts: + raise FieldError( + 'Only one lookup part allowed (found path "%s" from "%s").' % + (LOOKUP_SEP.join(field_parts), lookup)) + else: + raise FieldError( + 'Invalid lookup "%s" for model %s".' % + (lookup, self.get_meta().model.__name__)) + else: + if not hasattr(field, 'get_lookup_constraint'): + lookup_class = field.get_lookup(lookup_parts[0]) + if lookup_class is None and lookup_parts[0] not in self.query_terms: + raise FieldError( + 'Invalid lookup name %s' % lookup_parts[0]) + return lookup_parts, field_parts, False + + def build_lookup(self, lookup_type, lhs, rhs): + if hasattr(lhs.output_type, 'get_lookup'): + lookup = lhs.output_type.get_lookup(lookup_type) + if lookup: + return lookup(self.where_class, lhs, rhs) + return None def build_filter(self, filter_expr, branch_negated=False, current_negated=False, can_reuse=None, connector=AND): @@ -1114,9 +1128,9 @@ class Query(object): is responsible for unreffing the joins used. """ arg, value = filter_expr - lookup_type, parts = self.solve_lookup_type(arg) - if not parts: + if not arg: raise FieldError("Cannot parse keyword query %r" % arg) + lookup_type, parts, reffed_aggregate = self.solve_lookup_type(arg) # Work out the lookup type and remove it from the end of 'parts', # if necessary. @@ -1124,11 +1138,13 @@ class Query(object): used_joins = getattr(value, '_used_joins', []) clause = self.where_class() - if self._aggregates: - for alias, aggregate in self.aggregates.items(): - if alias in (parts[0], LOOKUP_SEP.join(parts)): - clause.add((aggregate, lookup_type, value), AND) - return clause, [] + if reffed_aggregate: + condition = self.build_lookup(lookup_type, reffed_aggregate, value) + if not condition: + # Backwards compat for custom lookups + condition = (reffed_aggregate, lookup_type, value) + clause.add(condition, AND) + return clause, [] opts = self.get_meta() alias = self.get_initial_alias() @@ -1150,11 +1166,18 @@ class Query(object): targets, alias, join_list = self.trim_joins(sources, join_list, path) if hasattr(field, 'get_lookup_constraint'): - constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources, - lookup_type, value) + # For now foreign keys get special treatment. This should be + # refactored when composite fields lands. + condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, + lookup_type, value) else: - constraint = (Constraint(alias, targets[0].column, field), lookup_type, value) - clause.add(constraint, AND) + assert(len(targets) == 1) + col = Col(alias, targets[0], field) + condition = self.build_lookup(lookup_type, col, value) + if not condition: + # Backwards compat for custom lookups + condition = (Constraint(alias, targets[0].column, field), lookup_type, value) + clause.add(condition, AND) require_outer = lookup_type == 'isnull' and value is True and not current_negated if current_negated and (lookup_type != 'isnull' or value is False): @@ -1185,7 +1208,7 @@ class Query(object): if not self._aggregates: return False if not isinstance(obj, Node): - return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates) + return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0] or (hasattr(obj[1], 'contains_aggregate') and obj[1].contains_aggregate(self.aggregates))) return any(self.need_having(c) for c in obj.children) @@ -1273,7 +1296,7 @@ class Query(object): needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner - def names_to_path(self, names, opts, allow_many): + def names_to_path(self, names, opts, allow_many=True): """ Walks the names path and turns them PathInfo tuples. Note that a single name in 'names' can generate multiple PathInfos (m2m for @@ -1293,9 +1316,10 @@ class Query(object): try: field, model, direct, m2m = opts.get_field_by_name(name) except FieldDoesNotExist: - available = opts.get_all_field_names() + list(self.aggregate_select) - raise FieldError("Cannot resolve keyword %r into field. " - "Choices are: %s" % (name, ", ".join(available))) + # We didn't found the current field, so move position back + # one step. + pos -= 1 + break # Check if we need any joins for concrete inheritance cases (the # field lives in parent, but we are currently in one of its # children) @@ -1330,15 +1354,9 @@ class Query(object): final_field = field targets = (field,) break - - if pos != len(names) - 1: - if pos == len(names) - 2: - raise FieldError( - "Join on field %r not permitted. Did you misspell %r for " - "the lookup type?" % (name, names[pos + 1])) - else: - raise FieldError("Join on field %r not permitted." % name) - return path, final_field, targets + if pos == -1: + raise FieldError('Whazaa') + return path, final_field, targets, names[pos + 1:] def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): """ @@ -1367,8 +1385,10 @@ class Query(object): """ joins = [alias] # First, generate the path for the names - path, final_field, targets = self.names_to_path( + path, final_field, targets, rest = self.names_to_path( names, opts, allow_many) + if rest: + raise FieldError('Invalid lookup') # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. @@ -1383,8 +1403,6 @@ class Query(object): alias = self.join( connection, reuse=reuse, nullable=nullable, join_field=join.join_field) joins.append(alias) - if hasattr(final_field, 'field'): - final_field = final_field.field return final_field, targets, opts, joins, path def trim_joins(self, targets, joins, path): @@ -1455,7 +1473,7 @@ class Query(object): query.bump_prefix(self) query.where.add( (Constraint(query.select[0].col[0], pk.column, pk), - 'exact', Col(alias, pk.column)), + 'exact', Col(alias, pk, pk)), AND ) diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 7b71580370a..637b8518304 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -101,7 +101,7 @@ class WhereNode(tree.Node): for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = child.as_sql(qn=qn, connection=connection) + sql, params = qn.compile(child) else: # A leaf node in the tree. sql, params = self.make_atom(child, qn, connection) @@ -193,13 +193,13 @@ class WhereNode(tree.Node): field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), [] else: # A smart object with an as_sql() method. - field_sql, field_params = lvalue.as_sql(qn, connection) + field_sql, field_params = qn.compile(lvalue) is_datetime_field = value_annotation is datetime.datetime cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' if hasattr(params, 'as_sql'): - extra, params = params.as_sql(qn, connection) + extra, params = qn.compile(params) cast_sql = '' else: extra = '' @@ -282,6 +282,8 @@ class WhereNode(tree.Node): if hasattr(child, 'relabel_aliases'): # For example another WhereNode child.relabel_aliases(change_map) + elif hasattr(child, 'relabeled_clone'): + self.children[pos] = child.relabeled_clone(change_map) elif isinstance(child, (list, tuple)): # tuple starting with Constraint child = (child[0].relabeled_clone(change_map),) + child[1:] @@ -350,7 +352,7 @@ class Constraint(object): self.alias, self.col, self.field = alias, col, field def prepare(self, lookup_type, value): - if self.field: + if self.field and not hasattr(value, 'as_sql'): return self.field.get_prep_lookup(lookup_type, value) return value diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index eee61654bc1..6ea10278f23 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -443,7 +443,7 @@ class BaseAggregateTestCase(TestCase): vals = Author.objects.filter(pk=1).aggregate(Count("friends__id")) self.assertEqual(vals, {"friends__id__count": 2}) - books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__ge=2).order_by("pk") + books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk") self.assertQuerysetEqual( books, [ "The Definitive Guide to Django: Web Development Done Right", diff --git a/tests/custom_lookups/__init__.py b/tests/custom_lookups/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/custom_lookups/models.py b/tests/custom_lookups/models.py new file mode 100644 index 00000000000..5152bc65028 --- /dev/null +++ b/tests/custom_lookups/models.py @@ -0,0 +1,7 @@ +from django.db import models + + +class Author(models.Model): + name = models.CharField(max_length=20) + age = models.IntegerField(null=True) + birthdate = models.DateField(null=True) diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py new file mode 100644 index 00000000000..a6086931378 --- /dev/null +++ b/tests/custom_lookups/tests.py @@ -0,0 +1,136 @@ +from copy import copy +from datetime import date +import unittest + +from django.test import TestCase +from .models import Author +from django.db import models +from django.db import connection +from django.db.backends.utils import add_implementation + + +class Div3Lookup(models.lookups.Lookup): + lookup_name = 'div3' + + def as_sql(self, qn, connection): + lhs, params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params.extend(rhs_params) + return '%s %%%% 3 = %s' % (lhs, rhs), params + + +class InMonth(models.lookups.Lookup): + """ + InMonth matches if the column's month is contained in the value's month. + """ + lookup_name = 'inmonth' + + def as_sql(self, qn, connection): + lhs, params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + # We need to be careful so that we get the params in right + # places. + full_params = params[:] + full_params.extend(rhs_params) + full_params.extend(params) + full_params.extend(rhs_params) + return ("%s >= date_trunc('month', %s) and " + "%s < date_trunc('month', %s) + interval '1 months'" % + (lhs, rhs, lhs, rhs), full_params) + + +class LookupTests(TestCase): + def test_basic_lookup(self): + a1 = Author.objects.create(name='a1', age=1) + a2 = Author.objects.create(name='a2', age=2) + a3 = Author.objects.create(name='a3', age=3) + a4 = Author.objects.create(name='a4', age=4) + models.IntegerField.register_lookup(Div3Lookup) + try: + self.assertQuerysetEqual( + Author.objects.filter(age__div3=0), + [a3], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(age__div3=1).order_by('age'), + [a1, a4], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(age__div3=2), + [a2], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(age__div3=3), + [], lambda x: x + ) + finally: + models.IntegerField._unregister_lookup(Div3Lookup) + + @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") + def test_birthdate_month(self): + a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) + a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) + a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31)) + a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1)) + models.DateField.register_lookup(InMonth) + try: + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), + [a3], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), + [a2], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), + [a1], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), + [a4], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), + [], lambda x: x + ) + finally: + models.DateField._unregister_lookup(InMonth) + + def test_custom_compiles(self): + a1 = Author.objects.create(name='a1', age=1) + a2 = Author.objects.create(name='a2', age=2) + a3 = Author.objects.create(name='a3', age=3) + a4 = Author.objects.create(name='a4', age=4) + + class AnotherEqual(models.lookups.Exact): + lookup_name = 'anotherequal' + models.Field.register_lookup(AnotherEqual) + try: + @add_implementation(AnotherEqual, connection.vendor) + def custom_eq_sql(node, compiler): + return '1 = 1', [] + + self.assertIn('1 = 1', str(Author.objects.filter(name__anotherequal='asdf').query)) + self.assertQuerysetEqual( + Author.objects.filter(name__anotherequal='asdf').order_by('name'), + [a1, a2, a3, a4], lambda x: x) + + @add_implementation(AnotherEqual, connection.vendor) + def another_custom_eq_sql(node, compiler): + # If you need to override one method, it seems this is the best + # option. + node = copy(node) + + class OverriddenAnotherEqual(AnotherEqual): + def get_rhs_op(self, connection, rhs): + return ' <> %s' + node.__class__ = OverriddenAnotherEqual + return node.as_sql(compiler, compiler.connection) + self.assertIn(' <> ', str(Author.objects.filter(name__anotherequal='a1').query)) + self.assertQuerysetEqual( + Author.objects.filter(name__anotherequal='a1').order_by('name'), + [a2, a3, a4], lambda x: x + ) + finally: + models.Field._unregister_lookup(AnotherEqual) diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 02c7b334803..7e381830bbe 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -2620,8 +2620,15 @@ class WhereNodeTest(TestCase): def as_sql(self, qn, connection): return 'dummy', [] + class MockCompiler(object): + def compile(self, node): + return node.as_sql(self, connection) + + def __call__(self, name): + return connection.ops.quote_name(name) + def test_empty_full_handling_conjunction(self): - qn = connection.ops.quote_name + qn = WhereNodeTest.MockCompiler() w = WhereNode(children=[EverythingNode()]) self.assertEqual(w.as_sql(qn, connection), ('', [])) w.negate() @@ -2646,7 +2653,7 @@ class WhereNodeTest(TestCase): self.assertEqual(w.as_sql(qn, connection), ('', [])) def test_empty_full_handling_disjunction(self): - qn = connection.ops.quote_name + qn = WhereNodeTest.MockCompiler() w = WhereNode(children=[EverythingNode()], connector='OR') self.assertEqual(w.as_sql(qn, connection), ('', [])) w.negate() @@ -2673,7 +2680,7 @@ class WhereNodeTest(TestCase): self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', [])) def test_empty_nodes(self): - qn = connection.ops.quote_name + qn = WhereNodeTest.MockCompiler() empty_w = WhereNode() w = WhereNode(children=[empty_w, empty_w]) self.assertEqual(w.as_sql(qn, connection), (None, []))