diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py index a92f0f0dc8..7d130bec24 100644 --- a/django/contrib/contenttypes/generic.py +++ b/django/contrib/contenttypes/generic.py @@ -12,7 +12,7 @@ from django.db import models, router, transaction, DEFAULT_DB_ALIAS from django.db.models import signals from django.db.models.fields.related import ForeignObject, ForeignObjectRel from django.db.models.related import PathInfo -from django.db.models.sql.where import Constraint +from django.db.models.sql.datastructures import Col from django.forms import ModelForm, ALL_FIELDS from django.forms.models import (BaseModelFormSet, modelformset_factory, modelform_defines_fields) @@ -236,7 +236,8 @@ class GenericRelation(ForeignObject): field = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0] contenttype_pk = self.get_content_type().pk cond = where_class() - cond.add((Constraint(remote_alias, field.column, field), 'exact', contenttype_pk), 'AND') + lookup = field.get_lookup('exact')(Col(remote_alias, field, field), contenttype_pk) + cond.add(lookup, 'AND') return cond def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index 81aa1e2fb6..989b5a1292 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -49,9 +49,7 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations): return placeholder def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn): - alias, col, db_type = lvalue - - geo_col = '%s.%s' % (qn(alias), qn(col)) + geo_col, db_type = lvalue lookup_info = self.geometry_functions.get(lookup_type, False) if lookup_info: diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index 8212938f01..b3aa191ccc 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -231,10 +231,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn): "Returns the SQL WHERE clause for use in Oracle spatial SQL construction." - alias, col, db_type = lvalue - - # Getting the quoted table name as `geo_col`. - geo_col = '%s.%s' % (qn(alias), qn(col)) + geo_col, db_type = lvalue # See if a Oracle Geometry function matches the lookup type next lookup_info = self.geometry_functions.get(lookup_type, False) diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 40ca09f0fa..e8da6a52a6 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -478,10 +478,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): (alias, col, db_type), the lookup type string, lookup value, and the geometry field. """ - alias, col, db_type = lvalue - - # Getting the quoted geometry column. - geo_col = '%s.%s' % (qn(alias), qn(col)) + geo_col, db_type = lvalue if lookup_type in self.geometry_operators: if field.geography and not lookup_type in self.geography_operators: diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 5149b541a2..ad44b470b3 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -324,10 +324,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): [a tuple of (alias, column, db_type)], lookup type, lookup value, the model field, and the quoting function. """ - alias, col, db_type = lvalue - - # Getting the quoted field as `geo_col`. - geo_col = '%s.%s' % (qn(alias), qn(col)) + geo_col, db_type = lvalue if lookup_type in self.geometry_functions: # See if a SpatiaLite geometry function matches the lookup type. diff --git a/django/contrib/gis/db/models/constants.py b/django/contrib/gis/db/models/constants.py new file mode 100644 index 0000000000..4ece41415c --- /dev/null +++ b/django/contrib/gis/db/models/constants.py @@ -0,0 +1,15 @@ +from django.db.models.sql.constants import QUERY_TERMS + +GIS_LOOKUPS = { + 'bbcontains', 'bboverlaps', 'contained', 'contains', + 'contains_properly', 'coveredby', 'covers', 'crosses', 'disjoint', + 'distance_gt', 'distance_gte', 'distance_lt', 'distance_lte', + 'dwithin', 'equals', 'exact', + 'intersects', 'overlaps', 'relate', 'same_as', 'touches', 'within', + 'left', 'right', 'overlaps_left', 'overlaps_right', + 'overlaps_above', 'overlaps_below', + 'strictly_above', 'strictly_below' +} +ALL_TERMS = GIS_LOOKUPS | QUERY_TERMS + +__all__ = ['ALL_TERMS', 'GIS_LOOKUPS'] diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index 734f805c8d..2daf08147b 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -2,6 +2,8 @@ from django.db.models.fields import Field from django.db.models.sql.expressions import SQLEvaluator from django.utils.translation import ugettext_lazy as _ from django.contrib.gis import forms +from django.contrib.gis.db.models.constants import GIS_LOOKUPS +from django.contrib.gis.db.models.lookups import GISLookup from django.contrib.gis.db.models.proxy import GeometryProxy from django.contrib.gis.geometry.backend import Geometry, GeometryException from django.utils import six @@ -284,6 +286,10 @@ class GeometryField(Field): """ return connection.ops.get_geom_placeholder(self, value) +for lookup_name in GIS_LOOKUPS: + lookup = type(lookup_name, (GISLookup,), {'lookup_name': lookup_name}) + GeometryField.register_lookup(lookup) + # The OpenGIS Geometry Type Fields class PointField(GeometryField): diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py new file mode 100644 index 0000000000..cad6c69000 --- /dev/null +++ b/django/contrib/gis/db/models/lookups.py @@ -0,0 +1,28 @@ +from django.db.models.lookups import Lookup +from django.db.models.sql.expressions import SQLEvaluator + + +class GISLookup(Lookup): + def as_sql(self, qn, connection): + from django.contrib.gis.db.models.sql import GeoWhereNode + # We use the same approach as was used by GeoWhereNode. It would + # be a good idea to upgrade GIS to use similar code that is used + # for other lookups. + if isinstance(self.rhs, SQLEvaluator): + # Make sure the F Expression destination field exists, and + # set an `srid` attribute with the same as that of the + # destination. + geo_fld = GeoWhereNode._check_geo_field(self.rhs.opts, self.rhs.expression.name) + if not geo_fld: + raise ValueError('No geographic field found in expression.') + self.rhs.srid = geo_fld.srid + db_type = self.lhs.output_type.db_type(connection=connection) + params = self.lhs.output_type.get_db_prep_lookup( + self.lookup_name, self.rhs, connection=connection) + lhs_sql, lhs_params = self.process_lhs(qn, connection) + # lhs_params not currently supported. + assert not lhs_params + data = (lhs_sql, db_type) + spatial_sql, spatial_params = connection.ops.spatial_lookup_sql( + data, self.lookup_name, self.rhs, self.lhs.output_type, qn) + return spatial_sql, spatial_params + params diff --git a/django/contrib/gis/db/models/sql/query.py b/django/contrib/gis/db/models/sql/query.py index 3923bf460e..f3fa1f6322 100644 --- a/django/contrib/gis/db/models/sql/query.py +++ b/django/contrib/gis/db/models/sql/query.py @@ -1,6 +1,7 @@ from django.db import connections from django.db.models.query import sql +from django.contrib.gis.db.models.constants import ALL_TERMS from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models.sql import aggregates as gis_aggregates from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField @@ -9,19 +10,6 @@ from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import Area, Distance -ALL_TERMS = set([ - 'bbcontains', 'bboverlaps', 'contained', 'contains', - 'contains_properly', 'coveredby', 'covers', 'crosses', 'disjoint', - 'distance_gt', 'distance_gte', 'distance_lt', 'distance_lte', - 'dwithin', 'equals', 'exact', - 'intersects', 'overlaps', 'relate', 'same_as', 'touches', 'within', - 'left', 'right', 'overlaps_left', 'overlaps_right', - 'overlaps_above', 'overlaps_below', - 'strictly_above', 'strictly_below' -]) -ALL_TERMS.update(sql.constants.QUERY_TERMS) - - class GeoQuery(sql.Query): """ A single spatial SQL query. diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 2dbb8b3aae..99d68221e8 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -673,6 +673,9 @@ class BaseDatabaseFeatures(object): # What kind of error does the backend throw when accessing closed cursor? closed_cursor_error_class = ProgrammingError + # Does 'a' LIKE 'A' match? + has_case_insensitive_like = True + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index 7725b0c7a0..33f885d50c 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -62,6 +62,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_combined_alters = True nulls_order_largest = True closed_cursor_error_class = InterfaceError + has_case_insensitive_like = False class DatabaseWrapper(BaseDatabaseWrapper): @@ -83,6 +84,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'iendswith': 'LIKE UPPER(%s)', } + pattern_ops = { + 'startswith': "LIKE %s || '%%%%'", + 'istartswith': "LIKE UPPER(%s) || '%%%%'", + } + Database = Database def __init__(self, *args, **kwargs): diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 6c9728889f..e55973ea39 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -334,6 +334,11 @@ class DatabaseWrapper(BaseDatabaseWrapper): 'iendswith': "LIKE %s ESCAPE '\\'", } + pattern_ops = { + 'startswith': "LIKE %s || '%%%%'", + 'istartswith': "LIKE UPPER(%s) || '%%%%'", + } + Database = Database def __init__(self, *args, **kwargs): diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 454a693be7..bdbdd5fd91 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -17,6 +17,7 @@ from django.db.models.fields.related import ( # NOQA from django.db.models.fields.proxy import OrderWrt # NOQA from django.db.models.deletion import ( # NOQA CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError) +from django.db.models.lookups import Lookup, Transform # NOQA from django.db.models import signals # NOQA diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 1ec11b4acb..e31d228aa5 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -15,10 +15,11 @@ def refs_aggregate(lookup_parts, aggregates): default annotation names we must check each prefix of the lookup_parts for match. """ - for i in range(len(lookup_parts) + 1): - if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates: - return True - return False + for n in range(len(lookup_parts) + 1): + level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n]) + if level_n_lookup in aggregates: + return aggregates[level_n_lookup], lookup_parts[n:] + return False, () class Aggregate(object): diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 7172cf1a55..7ace8878aa 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -11,6 +11,7 @@ from itertools import tee from django.apps import apps from django.db import connection +from django.db.models.lookups import default_lookups, RegisterLookupMixin from django.db.models.query_utils import QueryWrapper from django.conf import settings from django import forms @@ -80,7 +81,7 @@ def _empty(of_cls): @total_ordering -class Field(object): +class Field(RegisterLookupMixin): """Base class for all field types""" # Designates whether empty strings fundamentally are allowed at the @@ -101,6 +102,7 @@ class Field(object): 'unique': _('%(model_name)s with this %(field_label)s ' 'already exists.'), } + class_lookups = default_lookups.copy() # Generic field type description, usually overridden by subclasses def _description(self): @@ -514,8 +516,7 @@ class Field(object): except ValueError: raise ValueError("The __year lookup type requires an integer " "argument") - - raise TypeError("Field has invalid lookup: %s" % lookup_type) + return self.get_prep_value(value) def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): @@ -564,6 +565,8 @@ class Field(object): return connection.ops.year_lookup_bounds_for_date_field(value) else: return [value] # this isn't supposed to happen + else: + return [value] def has_default(self): """ diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 82e7725dac..69fb3f8492 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -5,9 +5,11 @@ from django.db.backends import utils from django.db.models import signals, Q from django.db.models.fields import (AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) +from django.db.models.lookups import IsNull from django.db.models.related import RelatedObject, PathInfo from django.db.models.query import QuerySet from django.db.models.deletion import CASCADE +from django.db.models.sql.datastructures import Col from django.utils.encoding import smart_text from django.utils import six from django.utils.deprecation import RenameMethodsBase @@ -987,6 +989,11 @@ class ForeignObjectRel(object): # example custom multicolumn joins currently have no remote field). self.field_name = None + def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, + raw_value): + return self.field.get_lookup_constraint(constraint_class, alias, targets, sources, + lookup_type, raw_value) + class ManyToOneRel(ForeignObjectRel): def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None, @@ -1193,14 +1200,16 @@ class ForeignObject(RelatedField): pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] return pathinfos - def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, + def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookups, raw_value): - from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR + from django.db.models.sql.where import SubqueryConstraint, AND, OR root_constraint = constraint_class() assert len(targets) == len(sources) + if len(lookups) > 1: + raise exceptions.FieldError('Relation fields do not support nested lookups') + lookup_type = lookups[0] def get_normalized_value(value): - from django.db.models import Model if isinstance(value, Model): value_list = [] @@ -1221,28 +1230,27 @@ class ForeignObject(RelatedField): [source.name for source in sources], raw_value), AND) elif lookup_type == 'isnull': - root_constraint.add( - (Constraint(alias, targets[0].column, targets[0]), lookup_type, raw_value), AND) + root_constraint.add(IsNull(Col(alias, targets[0], sources[0]), raw_value), AND) elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte'] and not is_multicolumn)): value = get_normalized_value(raw_value) - for index, source in enumerate(sources): + for target, source, val in zip(targets, sources, value): + lookup_class = target.get_lookup(lookup_type) root_constraint.add( - (Constraint(alias, targets[index].column, sources[index]), lookup_type, - value[index]), AND) + lookup_class(Col(alias, target, source), val), AND) elif lookup_type in ['range', 'in'] and not is_multicolumn: values = [get_normalized_value(value) for value in raw_value] value = [val[0] for val in values] - root_constraint.add( - (Constraint(alias, targets[0].column, sources[0]), lookup_type, value), AND) + lookup_class = targets[0].get_lookup(lookup_type) + root_constraint.add(lookup_class(Col(alias, targets[0], sources[0]), value), AND) elif lookup_type == 'in': values = [get_normalized_value(value) for value in raw_value] for value in values: value_constraint = constraint_class() - for index, target in enumerate(targets): - value_constraint.add( - (Constraint(alias, target.column, sources[index]), 'exact', value[index]), - AND) + for source, target, val in zip(sources, targets, value): + lookup_class = target.get_lookup('exact') + lookup = lookup_class(Col(alias, target, source), val) + value_constraint.add(lookup, AND) root_constraint.add(value_constraint, OR) else: raise TypeError('Related Field got invalid lookup: %s' % lookup_type) diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py new file mode 100644 index 0000000000..5369994bbc --- /dev/null +++ b/django/db/models/lookups.py @@ -0,0 +1,317 @@ +from copy import copy +import inspect + +from django.conf import settings +from django.utils import timezone +from django.utils.functional import cached_property + + +class RegisterLookupMixin(object): + def get_lookup(self, lookup_name): + try: + return self.class_lookups[lookup_name] + except KeyError: + # To allow for inheritance, check parent class class lookups. + for parent in inspect.getmro(self.__class__): + if not 'class_lookups' in parent.__dict__: + continue + if lookup_name in parent.class_lookups: + return parent.class_lookups[lookup_name] + except AttributeError: + # This class didn't have any class_lookups + pass + if hasattr(self, 'output_type'): + return self.output_type.get_lookup(lookup_name) + return None + + @classmethod + def register_lookup(cls, lookup): + if not 'class_lookups' in cls.__dict__: + cls.class_lookups = {} + cls.class_lookups[lookup.lookup_name] = lookup + + @classmethod + def _unregister_lookup(cls, lookup): + """ + Removes given lookup from cls lookups. Meant to be used in + tests only. + """ + del cls.class_lookups[lookup.lookup_name] + + +class Transform(RegisterLookupMixin): + def __init__(self, lhs, lookups): + self.lhs = lhs + self.init_lookups = lookups[:] + + def as_sql(self, qn, connection): + raise NotImplementedError + + @cached_property + def output_type(self): + return self.lhs.output_type + + def relabeled_clone(self, relabels): + return self.__class__(self.lhs.relabeled_clone(relabels)) + + def get_group_by_cols(self): + return self.lhs.get_group_by_cols() + + +class Lookup(RegisterLookupMixin): + lookup_name = None + + def __init__(self, lhs, rhs): + self.lhs, self.rhs = lhs, rhs + self.rhs = self.get_prep_lookup() + + def get_prep_lookup(self): + return self.lhs.output_type.get_prep_lookup(self.lookup_name, self.rhs) + + def get_db_prep_lookup(self, value, connection): + return ( + '%s', self.lhs.output_type.get_db_prep_lookup( + self.lookup_name, value, connection, prepared=True)) + + def process_lhs(self, qn, connection, lhs=None): + lhs = lhs or self.lhs + return qn.compile(lhs) + + def process_rhs(self, qn, connection, rhs=None): + value = rhs or self.rhs + # Due to historical reasons there are a couple of different + # ways to produce sql here. get_compiler is likely a Query + # instance, _as_sql QuerySet and as_sql just something with + # as_sql. Finally the value can of course be just plain + # Python value. + if hasattr(value, 'get_compiler'): + value = value.get_compiler(connection=connection) + if hasattr(value, 'as_sql'): + sql, params = qn.compile(value) + return '(' + sql + ')', params + if hasattr(value, '_as_sql'): + sql, params = value._as_sql(connection=connection) + return '(' + sql + ')', params + else: + return self.get_db_prep_lookup(value, connection) + + def relabeled_clone(self, relabels): + new = copy(self) + new.lhs = new.lhs.relabeled_clone(relabels) + if hasattr(new.rhs, 'relabeled_clone'): + new.rhs = new.rhs.relabeled_clone(relabels) + return new + + def get_group_by_cols(self): + cols = self.lhs.get_group_by_cols() + if hasattr(self.rhs, 'get_group_by_cols'): + cols.extend(self.rhs.get_group_by_cols()) + return cols + + def as_sql(self, qn, connection): + raise NotImplementedError + + +class BuiltinLookup(Lookup): + def as_sql(self, qn, connection): + lhs_sql, params = self.process_lhs(qn, connection) + field_internal_type = self.lhs.output_type.get_internal_type() + db_type = self.lhs.output_type + lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql + lhs_sql = connection.ops.lookup_cast(self.lookup_name) % lhs_sql + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params.extend(rhs_params) + operator_plus_rhs = self.get_rhs_op(connection, rhs_sql) + return '%s %s' % (lhs_sql, operator_plus_rhs), params + + def get_rhs_op(self, connection, rhs): + return connection.operators[self.lookup_name] % rhs + + +default_lookups = {} + + +class Exact(BuiltinLookup): + lookup_name = 'exact' +default_lookups['exact'] = Exact + + +class IExact(BuiltinLookup): + lookup_name = 'iexact' +default_lookups['iexact'] = IExact + + +class Contains(BuiltinLookup): + lookup_name = 'contains' +default_lookups['contains'] = Contains + + +class IContains(BuiltinLookup): + lookup_name = 'icontains' +default_lookups['icontains'] = IContains + + +class GreaterThan(BuiltinLookup): + lookup_name = 'gt' +default_lookups['gt'] = GreaterThan + + +class GreaterThanOrEqual(BuiltinLookup): + lookup_name = 'gte' +default_lookups['gte'] = GreaterThanOrEqual + + +class LessThan(BuiltinLookup): + lookup_name = 'lt' +default_lookups['lt'] = LessThan + + +class LessThanOrEqual(BuiltinLookup): + lookup_name = 'lte' +default_lookups['lte'] = LessThanOrEqual + + +class In(BuiltinLookup): + lookup_name = 'in' + + def get_db_prep_lookup(self, value, connection): + params = self.lhs.output_type.get_db_prep_lookup( + self.lookup_name, value, connection, prepared=True) + if not params: + # TODO: check why this leads to circular import + from django.db.models.sql.datastructures import EmptyResultSet + raise EmptyResultSet + placeholder = '(' + ', '.join('%s' for p in params) + ')' + return (placeholder, params) + + def get_rhs_op(self, connection, rhs): + return 'IN %s' % rhs +default_lookups['in'] = In + + +class PatternLookup(BuiltinLookup): + def get_rhs_op(self, connection, rhs): + # Assume we are in startswith. We need to produce SQL like: + # col LIKE %s, ['thevalue%'] + # For python values we can (and should) do that directly in Python, + # but if the value is for example reference to other column, then + # we need to add the % pattern match to the lookup by something like + # col LIKE othercol || '%%' + # So, for Python values we don't need any special pattern, but for + # SQL reference values we need the correct pattern added. + value = self.rhs + if (hasattr(value, 'get_compiler') or hasattr(value, 'as_sql') + or hasattr(value, '_as_sql')): + return connection.pattern_ops[self.lookup_name] % rhs + else: + return super(PatternLookup, self).get_rhs_op(connection, rhs) + + +class StartsWith(PatternLookup): + lookup_name = 'startswith' +default_lookups['startswith'] = StartsWith + + +class IStartsWith(PatternLookup): + lookup_name = 'istartswith' +default_lookups['istartswith'] = IStartsWith + + +class EndsWith(BuiltinLookup): + lookup_name = 'endswith' +default_lookups['endswith'] = EndsWith + + +class IEndsWith(BuiltinLookup): + lookup_name = 'iendswith' +default_lookups['iendswith'] = IEndsWith + + +class Between(BuiltinLookup): + def get_rhs_op(self, connection, rhs): + return "BETWEEN %s AND %s" % (rhs, rhs) + + +class Year(Between): + lookup_name = 'year' +default_lookups['year'] = Year + + +class Range(Between): + lookup_name = 'range' +default_lookups['range'] = Range + + +class DateLookup(BuiltinLookup): + + def process_lhs(self, qn, connection): + lhs, params = super(DateLookup, self).process_lhs(qn, connection) + tzname = timezone.get_current_timezone_name() if settings.USE_TZ else None + sql, tz_params = connection.ops.datetime_extract_sql(self.extract_type, lhs, tzname) + return connection.ops.lookup_cast(self.lookup_name) % sql, tz_params + + def get_rhs_op(self, connection, rhs): + return '= %s' % rhs + + +class Month(DateLookup): + lookup_name = 'month' + extract_type = 'month' +default_lookups['month'] = Month + + +class Day(DateLookup): + lookup_name = 'day' + extract_type = 'day' +default_lookups['day'] = Day + + +class WeekDay(DateLookup): + lookup_name = 'week_day' + extract_type = 'week_day' +default_lookups['week_day'] = WeekDay + + +class Hour(DateLookup): + lookup_name = 'hour' + extract_type = 'hour' +default_lookups['hour'] = Hour + + +class Minute(DateLookup): + lookup_name = 'minute' + extract_type = 'minute' +default_lookups['minute'] = Minute + + +class Second(DateLookup): + lookup_name = 'second' + extract_type = 'second' +default_lookups['second'] = Second + + +class IsNull(BuiltinLookup): + lookup_name = 'isnull' + + def as_sql(self, qn, connection): + sql, params = qn.compile(self.lhs) + if self.rhs: + return "%s IS NULL" % sql, params + else: + return "%s IS NOT NULL" % sql, params +default_lookups['isnull'] = IsNull + + +class Search(BuiltinLookup): + lookup_name = 'search' +default_lookups['search'] = Search + + +class Regex(BuiltinLookup): + lookup_name = 'regex' +default_lookups['regex'] = Regex + + +class IRegex(BuiltinLookup): + lookup_name = 'iregex' +default_lookups['iregex'] = IRegex diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 8542a330c6..aef8b493bb 100644 --- a/django/db/models/sql/aggregates.py +++ b/django/db/models/sql/aggregates.py @@ -4,6 +4,7 @@ Classes to represent the default SQL aggregate functions import copy from django.db.models.fields import IntegerField, FloatField +from django.db.models.lookups import RegisterLookupMixin __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance'] @@ -14,7 +15,7 @@ ordinal_aggregate_field = IntegerField() computed_aggregate_field = FloatField() -class Aggregate(object): +class Aggregate(RegisterLookupMixin): """ Default SQL Aggregate. """ @@ -93,6 +94,13 @@ class Aggregate(object): return self.sql_template % substitutions, params + def get_group_by_cols(self): + return [] + + @property + def output_type(self): + return self.field + class Avg(Aggregate): is_computed = True diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 41bba93206..123427cf8b 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -45,7 +45,7 @@ class SQLCompiler(object): if self.query.select_related and not self.query.related_select_cols: self.fill_related_selections() - def quote_name_unless_alias(self, name): + def __call__(self, name): """ A wrapper around connection.ops.quote_name that doesn't quote aliases for table names. This avoids problems with some SQL dialects that treat @@ -61,6 +61,22 @@ class SQLCompiler(object): self.quote_cache[name] = r return r + def quote_name_unless_alias(self, name): + """ + A wrapper around connection.ops.quote_name that doesn't quote aliases + for table names. This avoids problems with some SQL dialects that treat + quoted strings specially (e.g. PostgreSQL). + """ + return self(name) + + def compile(self, node): + vendor_impl = getattr( + node, 'as_' + self.connection.vendor, None) + if vendor_impl: + return vendor_impl(self, self.connection) + else: + return node.as_sql(self, self.connection) + def as_sql(self, with_limits=True, with_col_aliases=False): """ Creates the SQL for this query. Returns the SQL string and list of @@ -88,11 +104,9 @@ class SQLCompiler(object): # docstring of get_from_clause() for details. from_, f_params = self.get_from_clause() - qn = self.quote_name_unless_alias - - where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection) - having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection) - having_group_by = self.query.having.get_cols() + where, w_params = self.compile(self.query.where) + having, h_params = self.compile(self.query.having) + having_group_by = self.query.having.get_group_by_cols() params = [] for val in six.itervalues(self.query.extra_select): params.extend(val[1]) @@ -180,7 +194,7 @@ class SQLCompiler(object): (without the table names) are given unique aliases. This is needed in some cases to avoid ambiguity with nested queries. """ - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] params = [] @@ -213,7 +227,7 @@ class SQLCompiler(object): aliases.add(r) col_aliases.add(col[1]) else: - col_sql, col_params = col.as_sql(qn, self.connection) + col_sql, col_params = self.compile(col) result.append(col_sql) params.extend(col_params) @@ -229,7 +243,7 @@ class SQLCompiler(object): max_name_length = self.connection.ops.max_name_length() for alias, aggregate in self.query.aggregate_select.items(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + agg_sql, agg_params = self.compile(aggregate) if alias is None: result.append(agg_sql) else: @@ -267,7 +281,7 @@ class SQLCompiler(object): result = [] if opts is None: opts = self.query.get_meta() - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name aliases = set() only_load = self.deferred_to_columns() @@ -319,7 +333,7 @@ class SQLCompiler(object): Note that this method can alter the tables in the query, and thus it must be called before get_from_clause(). """ - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name result = [] opts = self.query.get_meta() @@ -352,7 +366,7 @@ class SQLCompiler(object): ordering = (self.query.order_by or self.query.get_meta().ordering or []) - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name distinct = self.query.distinct select_aliases = self._select_aliases @@ -490,7 +504,7 @@ class SQLCompiler(object): ordering and distinct must be done first. """ result = [] - qn = self.quote_name_unless_alias + qn = self qn2 = self.connection.ops.quote_name first = True from_params = [] @@ -508,8 +522,7 @@ class SQLCompiler(object): extra_cond = join_field.get_extra_restriction( self.query.where_class, alias, lhs) if extra_cond: - extra_sql, extra_params = extra_cond.as_sql( - qn, self.connection) + extra_sql, extra_params = self.compile(extra_cond) extra_sql = 'AND (%s)' % extra_sql from_params.extend(extra_params) else: @@ -541,7 +554,7 @@ class SQLCompiler(object): """ Returns a tuple representing the SQL elements in the "group by" clause. """ - qn = self.quote_name_unless_alias + qn = self result, params = [], [] if self.query.group_by is not None: select_cols = self.query.select + self.query.related_select_cols @@ -560,7 +573,7 @@ class SQLCompiler(object): if isinstance(col, (list, tuple)): sql = '%s.%s' % (qn(col[0]), qn(col[1])) elif hasattr(col, 'as_sql'): - sql, col_params = col.as_sql(qn, self.connection) + self.compile(col) else: sql = '(%s)' % str(col) if sql not in seen: @@ -784,7 +797,7 @@ class SQLCompiler(object): return result def as_subquery_condition(self, alias, columns, qn): - inner_qn = self.quote_name_unless_alias + inner_qn = self qn2 = self.connection.ops.quote_name if len(columns) == 1: sql, params = self.as_sql() @@ -895,9 +908,9 @@ class SQLDeleteCompiler(SQLCompiler): """ assert len(self.query.tables) == 1, \ "Can only delete from one table at a time." - qn = self.quote_name_unless_alias + qn = self result = ['DELETE FROM %s' % qn(self.query.tables[0])] - where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + where, params = self.compile(self.query.where) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(params) @@ -913,7 +926,7 @@ class SQLUpdateCompiler(SQLCompiler): if not self.query.values: return '', () table = self.query.tables[0] - qn = self.quote_name_unless_alias + qn = self result = ['UPDATE %s' % qn(table)] result.append('SET') values, update_params = [], [] @@ -933,7 +946,7 @@ class SQLUpdateCompiler(SQLCompiler): val = SQLEvaluator(val, self.query, allow_joins=False) name = field.column if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn, self.connection) + sql, params = self.compile(val) values.append('%s = %s' % (qn(name), sql)) update_params.extend(params) elif val is not None: @@ -944,7 +957,7 @@ class SQLUpdateCompiler(SQLCompiler): if not values: return '', () result.append(', '.join(values)) - where, params = self.query.where.as_sql(qn=qn, connection=self.connection) + where, params = self.compile(self.query.where) if where: result.append('WHERE %s' % where) return ' '.join(result), tuple(update_params + params) @@ -1024,11 +1037,11 @@ class SQLAggregateCompiler(SQLCompiler): parameters. """ if qn is None: - qn = self.quote_name_unless_alias + qn = self sql, params = [], [] for aggregate in self.query.aggregate_select.values(): - agg_sql, agg_params = aggregate.as_sql(qn, self.connection) + agg_sql, agg_params = self.compile(aggregate) sql.append(agg_sql) params.extend(agg_params) sql = ', '.join(sql) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index f45ecaf76d..421c3cd860 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -5,19 +5,28 @@ the SQL domain. class Col(object): - def __init__(self, alias, col): - self.alias = alias - self.col = col + def __init__(self, alias, target, source): + self.alias, self.target, self.source = alias, target, source def as_sql(self, qn, connection): - return '%s.%s' % (qn(self.alias), self.col), [] + return "%s.%s" % (qn(self.alias), qn(self.target.column)), [] + + @property + def output_type(self): + return self.source + + def relabeled_clone(self, relabels): + return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source) + + def get_group_by_cols(self): + return [(self.alias, self.target.column)] + + def get_lookup(self, name): + return self.output_type.get_lookup(name) def prepare(self): return self - def relabeled_clone(self, relabels): - return self.__class__(relabels.get(self.alias, self.alias), self.col) - class EmptyResultSet(Exception): pass diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 9f29e2ace5..e31eaa8a2f 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -24,11 +24,11 @@ class SQLEvaluator(object): (change_map.get(col[0], col[0]), col[1]))) return clone - def get_cols(self): + def get_group_by_cols(self): cols = [] for node, col in self.cols: - if hasattr(node, 'get_cols'): - cols.extend(node.get_cols()) + if hasattr(node, 'get_group_by_cols'): + cols.extend(node.get_group_by_cols()) elif isinstance(col, tuple): cols.append(col) return cols diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index c3c8e55793..db4e6744bf 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -19,6 +19,7 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.aggregates import refs_aggregate from django.db.models.expressions import ExpressionNode from django.db.models.fields import FieldDoesNotExist +from django.db.models.lookups import Transform from django.db.models.query_utils import Q from django.db.models.related import PathInfo from django.db.models.sql import aggregates as base_aggregates_module @@ -1028,13 +1029,16 @@ class Query(object): # Add the aggregate to the query aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) - def prepare_lookup_value(self, value, lookup_type, can_reuse): + def prepare_lookup_value(self, value, lookups, can_reuse): + # Default lookup if none given is exact. + if len(lookups) == 0: + lookups = ['exact'] # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # uses of None as a query value. if value is None: - if lookup_type not in ('exact', 'iexact'): + if lookups[-1] not in ('exact', 'iexact'): raise ValueError("Cannot use None as a query value") - lookup_type = 'isnull' + lookups[-1] = 'isnull' value = True elif callable(value): warnings.warn( @@ -1055,40 +1059,54 @@ class Query(object): # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we # can do here. Similar thing is done in is_nullable(), too. if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and - lookup_type == 'exact' and value == ''): + lookups[-1] == 'exact' and value == ''): value = True - lookup_type = 'isnull' - return value, lookup_type + lookups[-1] = ['isnull'] + return value, lookups def solve_lookup_type(self, lookup): """ Solve the lookup type from the lookup (eg: 'foobar__id__icontains') """ - lookup_type = 'exact' # Default lookup type - lookup_parts = lookup.split(LOOKUP_SEP) - num_parts = len(lookup_parts) - if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms - and (not self._aggregates or lookup not in self._aggregates)): - # Traverse the lookup query to distinguish related fields from - # lookup types. - lookup_model = self.model - for counter, field_name in enumerate(lookup_parts): - try: - lookup_field = lookup_model._meta.get_field(field_name) - except FieldDoesNotExist: - # Not a field. Bail out. - lookup_type = lookup_parts.pop() - break - # Unless we're at the end of the list of lookups, let's attempt - # to continue traversing relations. - if (counter + 1) < num_parts: - try: - lookup_model = lookup_field.rel.to - except AttributeError: - # Not a related field. Bail out. - lookup_type = lookup_parts.pop() - break - return lookup_type, lookup_parts + lookup_splitted = lookup.split(LOOKUP_SEP) + if self._aggregates: + aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) + if aggregate: + return aggregate_lookups, (), aggregate + _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) + field_parts = lookup_splitted[0:len(lookup_splitted) - len(lookup_parts)] + if len(lookup_parts) == 0: + lookup_parts = ['exact'] + elif len(lookup_parts) > 1: + if not field_parts: + raise FieldError( + 'Invalid lookup "%s" for model %s".' % + (lookup, self.get_meta().model.__name__)) + return lookup_parts, field_parts, False + + def build_lookup(self, lookups, lhs, rhs): + lookups = lookups[:] + while lookups: + lookup = lookups[0] + next = lhs.get_lookup(lookup) + if next: + if len(lookups) == 1: + # This was the last lookup, so return value lookup. + if issubclass(next, Transform): + lookups.append('exact') + lhs = next(lhs, lookups) + else: + return next(lhs, rhs) + else: + lhs = next(lhs, lookups) + # A field's get_lookup() can return None to opt for backwards + # compatibility path. + elif len(lookups) > 2: + raise FieldError( + "Unsupported lookup for field '%s'" % lhs.output_type.name) + else: + return None + lookups = lookups[1:] def build_filter(self, filter_expr, branch_negated=False, current_negated=False, can_reuse=None, connector=AND): @@ -1118,21 +1136,24 @@ class Query(object): is responsible for unreffing the joins used. """ arg, value = filter_expr - lookup_type, parts = self.solve_lookup_type(arg) - if not parts: + if not arg: raise FieldError("Cannot parse keyword query %r" % arg) + lookups, parts, reffed_aggregate = self.solve_lookup_type(arg) # Work out the lookup type and remove it from the end of 'parts', # if necessary. - value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse) + value, lookups = self.prepare_lookup_value(value, lookups, can_reuse) used_joins = getattr(value, '_used_joins', []) clause = self.where_class() - if self._aggregates: - for alias, aggregate in self.aggregates.items(): - if alias in (parts[0], LOOKUP_SEP.join(parts)): - clause.add((aggregate, lookup_type, value), AND) - return clause, [] + if reffed_aggregate: + condition = self.build_lookup(lookups, reffed_aggregate, value) + if not condition: + # Backwards compat for custom lookups + assert len(lookups) == 1 + condition = (reffed_aggregate, lookups[0], value) + clause.add(condition, AND) + return clause, [] opts = self.get_meta() alias = self.get_initial_alias() @@ -1154,11 +1175,31 @@ class Query(object): targets, alias, join_list = self.trim_joins(sources, join_list, path) if hasattr(field, 'get_lookup_constraint'): - constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources, - lookup_type, value) + # For now foreign keys get special treatment. This should be + # refactored when composite fields lands. + condition = field.get_lookup_constraint(self.where_class, alias, targets, sources, + lookups, value) + lookup_type = lookups[-1] else: - constraint = (Constraint(alias, targets[0].column, field), lookup_type, value) - clause.add(constraint, AND) + assert(len(targets) == 1) + col = Col(alias, targets[0], field) + condition = self.build_lookup(lookups, col, value) + if not condition: + # Backwards compat for custom lookups + if lookups[0] not in self.query_terms: + raise FieldError( + "Join on field '%s' not permitted. Did you " + "misspell '%s' for the lookup type?" % + (col.output_type.name, lookups[0])) + if len(lookups) > 1: + raise FieldError("Nested lookup '%s' not supported." % + LOOKUP_SEP.join(lookups)) + condition = (Constraint(alias, targets[0].column, field), lookups[0], value) + lookup_type = lookups[-1] + else: + lookup_type = condition.lookup_name + + clause.add(condition, AND) require_outer = lookup_type == 'isnull' and value is True and not current_negated if current_negated and (lookup_type != 'isnull' or value is False): @@ -1175,7 +1216,8 @@ class Query(object): # (col IS NULL OR col != someval) # <=> # NOT (col IS NOT NULL AND col = someval). - clause.add((Constraint(alias, targets[0].column, None), 'isnull', False), AND) + lookup_class = targets[0].get_lookup('isnull') + clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND) return clause, used_joins if not require_outer else () def add_filter(self, filter_clause): @@ -1189,7 +1231,7 @@ class Query(object): if not self._aggregates: return False if not isinstance(obj, Node): - return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates) + return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0] or (hasattr(obj[1], 'contains_aggregate') and obj[1].contains_aggregate(self.aggregates))) return any(self.need_having(c) for c in obj.children) @@ -1277,7 +1319,7 @@ class Query(object): needed_inner = joinpromoter.update_join_types(self) return target_clause, needed_inner - def names_to_path(self, names, opts, allow_many): + def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): """ Walks the names path and turns them PathInfo tuples. Note that a single name in 'names' can generate multiple PathInfos (m2m for @@ -1297,9 +1339,10 @@ class Query(object): try: field, model, direct, m2m = opts.get_field_by_name(name) except FieldDoesNotExist: - available = opts.get_all_field_names() + list(self.aggregate_select) - raise FieldError("Cannot resolve keyword %r into field. " - "Choices are: %s" % (name, ", ".join(available))) + # We didn't found the current field, so move position back + # one step. + pos -= 1 + break # Check if we need any joins for concrete inheritance cases (the # field lives in parent, but we are currently in one of its # children) @@ -1334,15 +1377,14 @@ class Query(object): final_field = field targets = (field,) break + if pos == -1 or (fail_on_missing and pos + 1 != len(names)): + self.raise_field_error(opts, name) + return path, final_field, targets, names[pos + 1:] - if pos != len(names) - 1: - if pos == len(names) - 2: - raise FieldError( - "Join on field %r not permitted. Did you misspell %r for " - "the lookup type?" % (name, names[pos + 1])) - else: - raise FieldError("Join on field %r not permitted." % name) - return path, final_field, targets + def raise_field_error(self, opts, name): + available = opts.get_all_field_names() + list(self.aggregate_select) + raise FieldError("Cannot resolve keyword %r into field. " + "Choices are: %s" % (name, ", ".join(available))) def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): """ @@ -1371,8 +1413,9 @@ class Query(object): """ joins = [alias] # First, generate the path for the names - path, final_field, targets = self.names_to_path( - names, opts, allow_many) + path, final_field, targets, rest = self.names_to_path( + names, opts, allow_many, fail_on_missing=True) + # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type # of the trimmed joins. @@ -1387,8 +1430,6 @@ class Query(object): alias = self.join( connection, reuse=reuse, nullable=nullable, join_field=join.join_field) joins.append(alias) - if hasattr(final_field, 'field'): - final_field = final_field.field return final_field, targets, opts, joins, path def trim_joins(self, targets, joins, path): @@ -1451,17 +1492,19 @@ class Query(object): # nothing alias, col = query.select[0].col if self.is_nullable(query.select[0].field): - query.where.add((Constraint(alias, col, query.select[0].field), 'isnull', False), AND) + lookup_class = query.select[0].field.get_lookup('isnull') + lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False) + query.where.add(lookup, AND) if alias in can_reuse: - pk = query.select[0].field.model._meta.pk + select_field = query.select[0].field + pk = select_field.model._meta.pk # Need to add a restriction so that outer query's filters are in effect for # the subquery, too. query.bump_prefix(self) - query.where.add( - (Constraint(query.select[0].col[0], pk.column, pk), - 'exact', Col(alias, pk.column)), - AND - ) + lookup_class = select_field.get_lookup('exact') + lookup = lookup_class(Col(query.select[0].col[0], pk, pk), + Col(alias, pk, pk)) + query.where.add(lookup, AND) condition, needed_inner = self.build_filter( ('%s__in' % trimmed_prefix, query), diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index e9e292e787..86b1efd3f8 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -5,12 +5,12 @@ Query subclasses which provide extra functionality beyond simple data retrieval. from django.conf import settings from django.core.exceptions import FieldError from django.db import connections +from django.db.models.query_utils import Q from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo from django.db.models.sql.datastructures import Date, DateTime from django.db.models.sql.query import Query -from django.db.models.sql.where import AND, Constraint from django.utils import six from django.utils import timezone @@ -42,10 +42,10 @@ class DeleteQuery(Query): if not field: field = self.get_meta().pk for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class() - where.add((Constraint(None, field.column, field), 'in', - pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(self.get_meta().db_table, where, using=using) + self.where = self.where_class() + self.add_q(Q( + **{field.attname + '__in': pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]})) + self.do_query(self.get_meta().db_table, self.where, using=using) def delete_qs(self, query, using): """ @@ -80,9 +80,8 @@ class DeleteQuery(Query): SelectInfo((self.get_initial_alias(), pk.column), None) ] values = innerq - where = self.where_class() - where.add((Constraint(None, pk.column, pk), 'in', values), AND) - self.where = where + self.where = self.where_class() + self.add_q(Q(pk__in=values)) self.get_compiler(using).execute_sql(None) @@ -113,13 +112,10 @@ class UpdateQuery(Query): related_updates=self.related_updates.copy(), **kwargs) def update_batch(self, pk_list, values, using): - pk_field = self.get_meta().pk self.add_update_values(values) for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): self.where = self.where_class() - self.where.add((Constraint(None, pk_field.column, pk_field), 'in', - pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]), - AND) + self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE])) self.get_compiler(using).execute_sql(None) def add_update_values(self, values): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 44a4ce9d1d..be0c559c1b 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -5,6 +5,7 @@ Code to manage the creation and SQL rendering of 'where' constraints. import collections import datetime from itertools import repeat +import warnings from django.conf import settings from django.db.models.fields import DateTimeField, Field @@ -101,7 +102,7 @@ class WhereNode(tree.Node): for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = child.as_sql(qn=qn, connection=connection) + sql, params = qn.compile(child) else: # A leaf node in the tree. sql, params = self.make_atom(child, qn, connection) @@ -152,16 +153,16 @@ class WhereNode(tree.Node): sql_string = '(%s)' % sql_string return sql_string, result_params - def get_cols(self): + def get_group_by_cols(self): cols = [] for child in self.children: - if hasattr(child, 'get_cols'): - cols.extend(child.get_cols()) + if hasattr(child, 'get_group_by_cols'): + cols.extend(child.get_group_by_cols()) else: if isinstance(child[0], Constraint): cols.append((child[0].alias, child[0].col)) - if hasattr(child[3], 'get_cols'): - cols.extend(child[3].get_cols()) + if hasattr(child[3], 'get_group_by_cols'): + cols.extend(child[3].get_group_by_cols()) return cols def make_atom(self, child, qn, connection): @@ -174,6 +175,9 @@ class WhereNode(tree.Node): Returns the string for the SQL fragment and the parameters to use for it. """ + warnings.warn( + "The make_atom() method will be removed in Django 1.9. Use Lookup class instead.", + PendingDeprecationWarning) lvalue, lookup_type, value_annotation, params_or_value = child field_internal_type = lvalue.field.get_internal_type() if lvalue.field else None @@ -193,13 +197,13 @@ class WhereNode(tree.Node): field_sql, field_params = self.sql_for_columns(lvalue, qn, connection, field_internal_type), [] else: # A smart object with an as_sql() method. - field_sql, field_params = lvalue.as_sql(qn, connection) + field_sql, field_params = qn.compile(lvalue) is_datetime_field = value_annotation is datetime.datetime cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s' if hasattr(params, 'as_sql'): - extra, params = params.as_sql(qn, connection) + extra, params = qn.compile(params) cast_sql = '' else: extra = '' @@ -282,6 +286,8 @@ class WhereNode(tree.Node): if hasattr(child, 'relabel_aliases'): # For example another WhereNode child.relabel_aliases(change_map) + elif hasattr(child, 'relabeled_clone'): + self.children[pos] = child.relabeled_clone(change_map) elif isinstance(child, (list, tuple)): # tuple starting with Constraint child = (child[0].relabeled_clone(change_map),) + child[1:] @@ -347,10 +353,13 @@ class Constraint(object): pre-process itself prior to including in the WhereNode. """ def __init__(self, alias, col, field): + warnings.warn( + "The Constraint class will be removed in Django 1.9. Use Lookup class instead.", + PendingDeprecationWarning) self.alias, self.col, self.field = alias, col, field def prepare(self, lookup_type, value): - if self.field: + if self.field and not hasattr(value, 'as_sql'): return self.field.get_prep_lookup(lookup_type, value) return value diff --git a/docs/howto/custom-model-fields.txt b/docs/howto/custom-model-fields.txt index b7bcefad53..3d6eeadb61 100644 --- a/docs/howto/custom-model-fields.txt +++ b/docs/howto/custom-model-fields.txt @@ -662,6 +662,12 @@ Django filter lookups: ``exact``, ``iexact``, ``contains``, ``icontains``, ``endswith``, ``iendswith``, ``range``, ``year``, ``month``, ``day``, ``isnull``, ``search``, ``regex``, and ``iregex``. +.. versionadded:: 1.7 + + If you are using :doc:`Custom lookups ` the + ``lookup_type`` can be any ``lookup_name`` used by the project's custom + lookups. + Your method must be prepared to handle all of these ``lookup_type`` values and should raise either a ``ValueError`` if the ``value`` is of the wrong sort (a list when you were expecting an object, for example) or a ``TypeError`` if diff --git a/docs/index.txt b/docs/index.txt index f37a597276..1856beb65e 100644 --- a/docs/index.txt +++ b/docs/index.txt @@ -81,7 +81,8 @@ manipulating the data of your Web application. Learn more about it below: :doc:`Transactions ` | :doc:`Aggregation ` | :doc:`Custom fields ` | - :doc:`Multiple databases ` + :doc:`Multiple databases ` | + :doc:`Custom lookups ` * **Other:** :doc:`Supported databases ` | diff --git a/docs/ref/models/custom-lookups.txt b/docs/ref/models/custom-lookups.txt new file mode 100644 index 0000000000..7798e92d30 --- /dev/null +++ b/docs/ref/models/custom-lookups.txt @@ -0,0 +1,336 @@ +============== +Custom lookups +============== + +.. versionadded:: 1.7 + +.. module:: django.db.models.lookups + :synopsis: Custom lookups + +.. currentmodule:: django.db.models + +By default Django offers a wide variety of :ref:`built-in lookups +` for filtering (for example, ``exact`` and ``icontains``). This +documentation explains how to write custom lookups and how to alter the working +of existing lookups. + +A simple Lookup example +~~~~~~~~~~~~~~~~~~~~~~~ + +Let's start with a simple custom lookup. We will write a custom lookup ``ne`` +which works opposite to ``exact``. ``Author.objects.filter(name__ne='Jack')`` +will translate to the SQL:: + + "author"."name" <> 'Jack' + +This SQL is backend independent, so we don't need to worry about different +databases. + +There are two steps to making this work. Firstly we need to implement the +lookup, then we need to tell Django about it. The implementation is quite +straightforward:: + + from django.db.models import Lookup + + class NotEqual(Lookup): + lookup_name = 'ne' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s <> %s' % (lhs, rhs), params + +To register the ``NotEqual`` lookup we will just need to call +``register_lookup`` on the field class we want the lookup to be available. In +this case, the lookup makes sense on all ``Field`` subclasses, so we register +it with ``Field`` directly:: + + from django.db.models.fields import Field + Field.register_lookup(NotEqual) + +We can now use ``foo__ne`` for any field ``foo``. You will need to ensure that +this registration happens before you try to create any querysets using it. You +could place the implementation in a ``models.py`` file, or register the lookup +in the ``ready()`` method of an ``AppConfig``. + +Taking a closer look at the implementation, the first required attribute is +``lookup_name``. This allows the ORM to understand how to interpret ``name__ne`` +and use ``NotEqual`` to generate the SQL. By convention, these names are always +lowercase strings containing only letters, but the only hard requirement is +that it must not contain the string ``__``. + +A ``Lookup`` works against two values, ``lhs`` and ``rhs``, standing for +left-hand side and right-hand side. The left-hand side is usually a field +reference, but it can be anything implementing the :ref:`query expression API +`. The right-hand is the value given by the user. In the +example ``Author.objects.filter(name__ne='Jack')``, the left-hand side is a +reference to the ``name`` field of the ``Author`` model, and ``'Jack'`` is the +right-hand side. + +We call ``process_lhs`` and ``process_rhs`` to convert them into the values we +need for SQL. In the above example, ``process_lhs`` returns +``('"author"."name"', [])`` and ``process_rhs`` returns ``('"%s"', ['Jack'])``. +In this example there were no parameters for the left hand side, but this would +depend on the object we have, so we still need to include them in the +parameters we return. + +Finally we combine the parts into a SQL expression with ``<>``, and supply all +the parameters for the query. We then return a tuple containing the generated +SQL string and the parameters. + +A simple transformer example +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The custom lookup above is great, but in some cases you may want to be able to +chain lookups together. For example, let's suppose we are building an +application where we want to make use of the ``abs()`` operator. +We have an ``Experiment`` model which records a start value, end value and the +change (start - end). We would like to find all experiments where the change +was equal to a certain amount (``Experiment.objects.filter(change__abs=27)``), +or where it did not exceede a certain amount +(``Experiment.objects.filter(change__abs__lt=27)``). + +.. note:: + This example is somewhat contrived, but it demonstrates nicely the range of + functionality which is possible in a database backend independent manner, + and without duplicating functionality already in Django. + +We will start by writing a ``AbsoluteValue`` transformer. This will use the SQL +function ``ABS()`` to transform the value before comparison:: + + from django.db.models import Transform + + class AbsoluteValue(Transform): + lookup_name = 'abs' + + def as_sql(self, qn, connection): + lhs, params = qn.compile(self.lhs) + return "ABS(%s)" % lhs, params + +Next, lets register it for ``IntegerField``:: + + from django.db.models import IntegerField + IntegerField.register_lookup(AbsoluteValue) + +We can now run the queris we had before. +``Experiment.objects.filter(change__abs=27)`` will generate the following SQL:: + + SELECT ... WHERE ABS("experiments"."change") = 27 + +By using ``Transform`` instead of ``Lookup`` it means we are able to chain +further lookups afterwards. So +``Experiment.objects.filter(change__abs__lt=27)`` will generate the following +SQL:: + + SELECT ... WHERE ABS("experiments"."change") < 27 + +Subclasses of ``Transform`` usually only operate on the left-hand side of the +expression. Further lookups will work on the transformed value. Note that in +this case where there is no other lookup specified, Django interprets +``change__abs=27`` as ``change__abs__exact=27``. + +When looking for which lookups are allowable after the ``Transform`` has been +applied, Django uses the ``output_type`` attribute. We didn't need to specify +this here as it didn't change, but supposing we were applying ``AbsoluteValue`` +to some field which represents a more complex type (for example a point +relative to an origin, or a complex number) then we may have wanted to specify +``output_type = FloatField``, which will ensure that further lookups like +``abs__lte`` behave as they would for a ``FloatField``. + +Writing an efficient abs__lt lookup +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When using the above written ``abs`` lookup, the SQL produced will not use +indexes efficiently in some cases. In particular, when we use +``change__abs__lt=27``, this is equivalent to ``change__gt=-27`` AND +``change__lt=27``. (For the ``lte`` case we could use the SQL ``BETWEEN``). + +So we would like ``Experiment.objects.filter(change__abs__lt=27)`` to generate +the following SQL:: + + SELECT .. WHERE "experiments"."change" < 27 AND "experiments"."change" > -27 + +The implementation is:: + + from django.db.models import Lookup + + class AbsoluteValueLessThan(Lookup): + lookup_name = 'lt' + + def as_sql(self, qn, connection): + lhs, lhs_params = qn.compile(self.lhs.lhs) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + lhs_params + rhs_params + return '%s > %s AND %s < -%s % (lhs, rhs, lhs, rhs), params + + AbsoluteValue.register_lookup(AbsoluteValueLessThan) + +There are a couple of notable things going on. First, ``AbsoluteValueLessThan`` +isn't calling ``process_lhs()``. Instead it skips the transformation of the +``lhs`` done by ``AbsoluteValue`` and uses the original ``lhs``. That is, we +want to get ``27`` not ``ABS(27)``. Referring directly to ``self.lhs.lhs`` is +safe as ``AbsoluteValueLessThan`` can be accessed only from the +``AbsoluteValue`` lookup, that is the ``lhs`` is always an instance of +``AbsoluteValue``. + +Notice also that as both sides are used multiple times in the query the params +need to contain ``lhs_params`` and ``rhs_params`` multiple times. + +The final query does the inversion (``27`` to ``-27``) directly in the +database. The reason for doing this is that if the self.rhs is something else +than a plain integer value (for example an ``F()`` reference) we can't do the +transformations in Python. + +.. note:: + In fact, most lookups with ``__abs`` could be implemented as range queries + like this, and on most database backend it is likely to be more sensible to + do so as you can make use of the indexes. However with PostgreSQL you may + want to add an index on ``abs(change)`` which would allow these queries to + be very efficient. + +Writing alternative implemenatations for existing lookups +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Sometimes different database vendors require different SQL for the same +operation. For this example we will rewrite a custom implementation for +MySQL for the NotEqual operator. Instead of ``<>`` we will be using ``!=`` +operator. (Note that in reality almost all databases support both, including +all the official databases supported by Django). + +We can change the behaviour on a specific backend by creating a subclass of +``NotEqual`` with a ``as_mysql`` method:: + + class MySQLNotEqual(NotEqual): + def as_mysql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + return '%s != %s' % (lhs, rhs), params + Field.register_lookup(MySQLNotExact) + +We can then register it with ``Field``. It takes the place of the original +``NotEqual`` class as it has + +When compiling a query, Django first looks for ``as_%s % connection.vendor`` +methods, and then falls back to ``as_sql``. The vendor names for the in-built +backends are ``sqlite``, ``postgresql``, ``oracle`` and ``mysql``. + +.. _query-expression: + +The Query Expression API +~~~~~~~~~~~~~~~~~~~~~~~~ + +A lookup can assume that the lhs responds to the query expression API. +Currently direct field references, aggregates and ``Transform`` instances respond +to this API. + +.. method:: as_sql(qn, connection) + + Responsible for producing the query string and parameters for the + expression. The ``qn`` has a ``compile()`` method that can be used to + compile other expressions. The ``connection`` is the connection used to + execute the query. + + Calling expression.as_sql() directly is usually incorrect - instead + qn.compile(expression) should be used. The qn.compile() method will take + care of calling vendor-specific methods of the expression. + +.. method:: get_lookup(lookup_name) + + The ``get_lookup()`` method is used to fetch lookups. By default the + lookup is fetched from the expression's output type in the same way + described in registering and fetching lookup documentation below. + It is possible to override this method to alter that behaviour. + +.. method:: as_vendorname(qn, connection) + + Works like ``as_sql()`` method. When an expression is compiled by + ``qn.compile()``, Django will first try to call ``as_vendorname()``, where + vendorname is the vendor name of the backend used for executing the query. + The vendorname is one of ``postgresql``, ``oracle``, ``sqlite`` or + ``mysql`` for Django's built-in backends. + +.. attribute:: output_type + + The ``output_type`` attribute is used by the ``get_lookup()`` method to check for + lookups. The output_type should be a field. + +Note that this documentation lists only the public methods of the API. + +Lookup reference +~~~~~~~~~~~~~~~~ + +.. class:: Lookup + + In addition to the attributes and methods below, lookups also support + ``as_sql`` and ``as_vendorname`` from the query expression API. + +.. attribute:: lhs + + The ``lhs`` (left-hand side) of a lookup tells us what we are comparing the + rhs to. It is an object which implements the query expression API. This is + likely to be a field, an aggregate or a subclass of ``Transform``. + +.. attribute:: rhs + + The ``rhs`` (right-hand side) of a lookup is the value we are comparing the + left hand side to. It may be a plain value, or something which compiles + into SQL, for example an ``F()`` object or a ``Queryset``. + +.. attribute:: lookup_name + + This class level attribute is used when registering lookups. It determines + the name used in queries to trigger this lookup. For example, ``contains`` + or ``exact``. This should not contain the string ``__``. + +.. method:: process_lhs(qn, connection) + + This returns a tuple of ``(lhs_string, lhs_params)``. In some cases you may + wish to compile ``lhs`` directly in your ``as_sql`` methods using + ``qn.compile(self.lhs)``. + +.. method:: process_rhs(qn, connection) + + Behaves the same as ``process_lhs`` but acts on the right-hand side. + +Transform reference +~~~~~~~~~~~~~~~~~~~ + +.. class:: Transform + + In addition to implementing the query expression API Transforms have the + following methods and attributes. + +.. attribute:: lhs + + The ``lhs`` (left-hand-side) of a transform contains the value to be + transformed. The ``lhs`` implements the query expression API. + +.. attribute:: lookup_name + + This class level attribute is used when registering lookups. It determines + the name used in queries to trigger this lookup. For example, ``year`` + or ``dayofweek``. This should not contain the string ``__``. + +.. _lookup-registration-api: + +Registering and fetching lookups +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The lookup registration API is explained below. + +.. classmethod:: register_lookup(lookup) + + Registers the Lookup or Transform for the class. For example + ``DateField.register_lookup(YearExact)`` will register ``YearExact`` for + all ``DateFields`` in the project, but also for fields that are instances + of a subclass of ``DateField`` (for example ``DateTimeField``). + +.. method:: get_lookup(lookup_name) + + Django uses ``get_lookup(lookup_name)`` to fetch lookups or transforms. + The implementation of ``get_lookup()`` fetches lookups or transforms + registered for the current class based on their lookup_name attribute. + +The lookup registration API is available for ``Transform`` and ``Field`` classes. diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index cc2a4cdcf4..2f62a6448b 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -343,6 +343,13 @@ underscores to spaces. See :ref:`Verbose field names `. A list of validators to run for this field. See the :doc:`validators documentation ` for more information. +Registering and fetching lookups +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``Field`` implements the :ref:`lookup registration API `. +The API can be used to customize which lookups are available for a field class, and +how lookups are fetched from a field. + .. _model-field-types: Field types diff --git a/docs/ref/models/index.txt b/docs/ref/models/index.txt index c61bb35a5f..4716c03f95 100644 --- a/docs/ref/models/index.txt +++ b/docs/ref/models/index.txt @@ -13,3 +13,4 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`. instances querysets queries + custom-lookups diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index ca9bf1dbd1..c2a3981567 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -1995,6 +1995,9 @@ specified as keyword arguments to the ``QuerySet`` methods :meth:`filter()`, For an introduction, see :ref:`models and database queries documentation `. +Django's inbuilt lookups are listed below. It is also possible to write +:doc:`custom lookups ` for model fields. + .. fieldlookup:: exact exact diff --git a/docs/releases/1.7.txt b/docs/releases/1.7.txt index 671b5878b9..0525bb2e12 100644 --- a/docs/releases/1.7.txt +++ b/docs/releases/1.7.txt @@ -180,6 +180,27 @@ for the following, instead of backend specific behavior. finally: c.close() +Custom lookups +~~~~~~~~~~~~~~ + +It is now possible to write custom lookups and transforms for the ORM. +Custom lookups work just like Django's inbuilt lookups (e.g. ``lte``, +``icontains``) while transforms are a new concept. + +The :class:`django.db.models.Lookup` class provides a way to add lookup +operators for model fields. As an example it is possible to add ``day_lte`` +opertor for ``DateFields``. + +The :class:`django.db.models.Transform` class allows transformations of +database values prior to the final lookup. For example it is possible to +write a ``year`` transform that extracts year from the field's value. +Transforms allow for chaining. After the ``year`` transform has been added +to ``DateField`` it is possible to filter on the transformed value, for +example ``qs.filter(author__birthdate__year__lte=1981)``. + +For more information about both custom lookups and transforms refer to +:doc:`custom lookups ` documentation. + Minor features ~~~~~~~~~~~~~~ diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index eee61654bc..6ea10278f2 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -443,7 +443,7 @@ class BaseAggregateTestCase(TestCase): vals = Author.objects.filter(pk=1).aggregate(Count("friends__id")) self.assertEqual(vals, {"friends__id__count": 2}) - books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__ge=2).order_by("pk") + books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk") self.assertQuerysetEqual( books, [ "The Definitive Guide to Django: Web Development Done Right", diff --git a/tests/custom_lookups/__init__.py b/tests/custom_lookups/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/custom_lookups/models.py b/tests/custom_lookups/models.py new file mode 100644 index 0000000000..9841b36ce5 --- /dev/null +++ b/tests/custom_lookups/models.py @@ -0,0 +1,13 @@ +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Author(models.Model): + name = models.CharField(max_length=20) + age = models.IntegerField(null=True) + birthdate = models.DateField(null=True) + average_rating = models.FloatField(null=True) + + def __str__(self): + return self.name diff --git a/tests/custom_lookups/tests.py b/tests/custom_lookups/tests.py new file mode 100644 index 0000000000..9f1e7fd44a --- /dev/null +++ b/tests/custom_lookups/tests.py @@ -0,0 +1,279 @@ +from datetime import date +import unittest + +from django.test import TestCase +from .models import Author +from django.db import models +from django.db import connection + + +class Div3Lookup(models.Lookup): + lookup_name = 'div3' + + def as_sql(self, qn, connection): + lhs, params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + params.extend(rhs_params) + return '%s %%%% 3 = %s' % (lhs, rhs), params + + +class Div3Transform(models.Transform): + lookup_name = 'div3' + + def as_sql(self, qn, connection): + lhs, lhs_params = qn.compile(self.lhs) + return '%s %%%% 3' % (lhs,), lhs_params + + +class YearTransform(models.Transform): + lookup_name = 'year' + + def as_sql(self, qn, connection): + lhs_sql, params = qn.compile(self.lhs) + return connection.ops.date_extract_sql('year', lhs_sql), params + + @property + def output_type(self): + return models.IntegerField() + + +class YearExact(models.lookups.Lookup): + lookup_name = 'exact' + + def as_sql(self, qn, connection): + # We will need to skip the extract part, and instead go + # directly with the originating field, that is self.lhs.lhs + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + # Note that we must be careful so that we have params in the + # same order as we have the parts in the SQL. + params = lhs_params + rhs_params + lhs_params + rhs_params + # We use PostgreSQL specific SQL here. Note that we must do the + # conversions in SQL instead of in Python to support F() references. + return ("%(lhs)s >= (%(rhs)s || '-01-01')::date " + "AND %(lhs)s <= (%(rhs)s || '-12-31')::date" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) +YearTransform.register_lookup(YearExact) + + +class YearLte(models.lookups.LessThanOrEqual): + """ + The purpose of this lookup is to efficiently compare the year of the field. + """ + + def as_sql(self, qn, connection): + # Skip the YearTransform above us (no possibility for efficient + # lookup otherwise). + real_lhs = self.lhs.lhs + lhs_sql, params = self.process_lhs(qn, connection, real_lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params.extend(rhs_params) + # Build SQL where the integer year is concatenated with last month + # and day, then convert that to date. (We try to have SQL like: + # WHERE somecol <= '2013-12-31') + # but also make it work if the rhs_sql is field reference. + return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params +YearTransform.register_lookup(YearLte) + + +# We will register this class temporarily in the test method. + + +class InMonth(models.lookups.Lookup): + """ + InMonth matches if the column's month is the same as value's month. + """ + lookup_name = 'inmonth' + + def as_sql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + rhs, rhs_params = self.process_rhs(qn, connection) + # We need to be careful so that we get the params in right + # places. + params = lhs_params + rhs_params + lhs_params + rhs_params + return ("%s >= date_trunc('month', %s) and " + "%s < date_trunc('month', %s) + interval '1 months'" % + (lhs, rhs, lhs, rhs), params) + + +class LookupTests(TestCase): + def test_basic_lookup(self): + a1 = Author.objects.create(name='a1', age=1) + a2 = Author.objects.create(name='a2', age=2) + a3 = Author.objects.create(name='a3', age=3) + a4 = Author.objects.create(name='a4', age=4) + models.IntegerField.register_lookup(Div3Lookup) + try: + self.assertQuerysetEqual( + Author.objects.filter(age__div3=0), + [a3], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(age__div3=1).order_by('age'), + [a1, a4], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(age__div3=2), + [a2], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(age__div3=3), + [], lambda x: x + ) + finally: + models.IntegerField._unregister_lookup(Div3Lookup) + + @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") + def test_birthdate_month(self): + a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) + a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) + a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31)) + a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1)) + models.DateField.register_lookup(InMonth) + try: + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), + [a3], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), + [a2], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), + [a1], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), + [a4], lambda x: x + ) + self.assertQuerysetEqual( + Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), + [], lambda x: x + ) + finally: + models.DateField._unregister_lookup(InMonth) + + def test_div3_extract(self): + models.IntegerField.register_lookup(Div3Transform) + try: + a1 = Author.objects.create(name='a1', age=1) + a2 = Author.objects.create(name='a2', age=2) + a3 = Author.objects.create(name='a3', age=3) + a4 = Author.objects.create(name='a4', age=4) + baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(age__div3=2), + [a2], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__lte=3), + [a1, a2, a3, a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(age__div3__in=[0, 2]), + [a2, a3], lambda x: x) + finally: + models.IntegerField._unregister_lookup(Div3Transform) + + +class YearLteTests(TestCase): + def setUp(self): + models.DateField.register_lookup(YearTransform) + self.a1 = Author.objects.create(name='a1', birthdate=date(1981, 2, 16)) + self.a2 = Author.objects.create(name='a2', birthdate=date(2012, 2, 29)) + self.a3 = Author.objects.create(name='a3', birthdate=date(2012, 1, 31)) + self.a4 = Author.objects.create(name='a4', birthdate=date(2012, 3, 1)) + + def tearDown(self): + models.DateField._unregister_lookup(YearTransform) + + @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") + def test_year_lte(self): + baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lte=2012), + [self.a1, self.a2, self.a3, self.a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(birthdate__year=2012), + [self.a2, self.a3, self.a4], lambda x: x) + + self.assertNotIn('BETWEEN', str(baseqs.filter(birthdate__year=2012).query)) + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lte=2011), + [self.a1], lambda x: x) + # The non-optimized version works, too. + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lt=2012), + [self.a1], lambda x: x) + + @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific SQL used") + def test_year_lte_fexpr(self): + self.a2.age = 2011 + self.a2.save() + self.a3.age = 2012 + self.a3.save() + self.a4.age = 2013 + self.a4.save() + baseqs = Author.objects.order_by('name') + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lte=models.F('age')), + [self.a3, self.a4], lambda x: x) + self.assertQuerysetEqual( + baseqs.filter(birthdate__year__lt=models.F('age')), + [self.a4], lambda x: x) + + def test_year_lte_sql(self): + # This test will just check the generated SQL for __lte. This + # doesn't require running on PostgreSQL and spots the most likely + # error - not running YearLte SQL at all. + baseqs = Author.objects.order_by('name') + self.assertIn( + '<= (2011 || ', str(baseqs.filter(birthdate__year__lte=2011).query)) + self.assertIn( + '-12-31', str(baseqs.filter(birthdate__year__lte=2011).query)) + + def test_postgres_year_exact(self): + baseqs = Author.objects.order_by('name') + self.assertIn( + '= (2011 || ', str(baseqs.filter(birthdate__year=2011).query)) + self.assertIn( + '-12-31', str(baseqs.filter(birthdate__year=2011).query)) + + def test_custom_implementation_year_exact(self): + try: + # Two ways to add a customized implementation for different backends: + # First is MonkeyPatch of the class. + def as_custom_sql(self, qn, connection): + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + lhs_params + rhs_params + return ("%(lhs)s >= str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " + "AND %(lhs)s <= str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + setattr(YearExact, 'as_' + connection.vendor, as_custom_sql) + self.assertIn( + 'concat(', + str(Author.objects.filter(birthdate__year=2012).query)) + finally: + delattr(YearExact, 'as_' + connection.vendor) + try: + # The other way is to subclass the original lookup and register the subclassed + # lookup instead of the original. + class CustomYearExact(YearExact): + # This method should be named "as_mysql" for MySQL, "as_postgresql" for postgres + # and so on, but as we don't know which DB we are running on, we need to use + # setattr. + def as_custom_sql(self, qn, connection): + lhs_sql, lhs_params = self.process_lhs(qn, connection, self.lhs.lhs) + rhs_sql, rhs_params = self.process_rhs(qn, connection) + params = lhs_params + rhs_params + lhs_params + rhs_params + return ("%(lhs)s >= str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') " + "AND %(lhs)s <= str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')" % + {'lhs': lhs_sql, 'rhs': rhs_sql}, params) + setattr(CustomYearExact, 'as_' + connection.vendor, CustomYearExact.as_custom_sql) + YearTransform.register_lookup(CustomYearExact) + self.assertIn( + 'CONCAT(', + str(Author.objects.filter(birthdate__year=2012).query)) + finally: + YearTransform._unregister_lookup(CustomYearExact) + YearTransform.register_lookup(YearExact) diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 99f41024f8..320271b1dc 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -3,7 +3,7 @@ from __future__ import unicode_literals from django.core.exceptions import FieldError from django.db.models import F from django.db import transaction -from django.test import TestCase +from django.test import TestCase, skipIfDBFeature from django.utils import six from .models import Company, Employee @@ -224,6 +224,25 @@ class ExpressionsTests(TestCase): acme.num_employees = F("num_employees") + 16 self.assertRaises(TypeError, acme.save) + def test_ticket_11722_iexact_lookup(self): + Employee.objects.create(firstname="John", lastname="Doe") + Employee.objects.create(firstname="Test", lastname="test") + + queryset = Employee.objects.filter(firstname__iexact=F('lastname')) + self.assertQuerysetEqual(queryset, [""]) + + @skipIfDBFeature('has_case_insensitive_like') + def test_ticket_16731_startswith_lookup(self): + Employee.objects.create(firstname="John", lastname="Doe") + e2 = Employee.objects.create(firstname="Jack", lastname="Jackson") + e3 = Employee.objects.create(firstname="Jack", lastname="jackson") + self.assertQuerysetEqual( + Employee.objects.filter(lastname__startswith=F('firstname')), + [e2], lambda x: x) + self.assertQuerysetEqual( + Employee.objects.filter(lastname__istartswith=F('firstname')).order_by('pk'), + [e2, e3], lambda x: x) + def test_ticket_18375_join_reuse(self): # Test that reverse multijoin F() references and the lookup target # the same join. Pre #18375 the F() join was generated first, and the diff --git a/tests/null_queries/tests.py b/tests/null_queries/tests.py index f807ad88ce..1b73c977b4 100644 --- a/tests/null_queries/tests.py +++ b/tests/null_queries/tests.py @@ -45,9 +45,6 @@ class NullQueriesTests(TestCase): # Can't use None on anything other than __exact and __iexact self.assertRaises(ValueError, Choice.objects.filter, id__gt=None) - # Can't use None on anything other than __exact and __iexact - self.assertRaises(ValueError, Choice.objects.filter, foo__gt=None) - # Related managers use __exact=None implicitly if the object hasn't been saved. p2 = Poll(question="How?") self.assertEqual(repr(p2.choice_set.all()), '[]') diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 338ec06921..03cfc71afe 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -2632,8 +2632,15 @@ class WhereNodeTest(TestCase): def as_sql(self, qn, connection): return 'dummy', [] + class MockCompiler(object): + def compile(self, node): + return node.as_sql(self, connection) + + def __call__(self, name): + return connection.ops.quote_name(name) + def test_empty_full_handling_conjunction(self): - qn = connection.ops.quote_name + qn = WhereNodeTest.MockCompiler() w = WhereNode(children=[EverythingNode()]) self.assertEqual(w.as_sql(qn, connection), ('', [])) w.negate() @@ -2658,7 +2665,7 @@ class WhereNodeTest(TestCase): self.assertEqual(w.as_sql(qn, connection), ('', [])) def test_empty_full_handling_disjunction(self): - qn = connection.ops.quote_name + qn = WhereNodeTest.MockCompiler() w = WhereNode(children=[EverythingNode()], connector='OR') self.assertEqual(w.as_sql(qn, connection), ('', [])) w.negate() @@ -2685,7 +2692,7 @@ class WhereNodeTest(TestCase): self.assertEqual(w.as_sql(qn, connection), ('NOT (dummy)', [])) def test_empty_nodes(self): - qn = connection.ops.quote_name + qn = WhereNodeTest.MockCompiler() empty_w = WhereNode() w = WhereNode(children=[empty_w, empty_w]) self.assertEqual(w.as_sql(qn, connection), (None, []))