Fixed #14030 -- Allowed annotations to accept all expressions

This commit is contained in:
Josh Smeaton 2013-12-26 00:13:18 +11:00 committed by Marc Tamlyn
parent 39e3ef88c2
commit f59fd15c49
43 changed files with 2572 additions and 801 deletions

View File

@ -347,6 +347,7 @@ answer newbie questions, and generally made Django that much better:
Jorge Bastida <me@jorgebastida.com> Jorge Bastida <me@jorgebastida.com>
Jorge Gajon <gajon@gajon.org> Jorge Gajon <gajon@gajon.org>
Joseph Kocherhans <joseph@jkocherhans.com> Joseph Kocherhans <joseph@jkocherhans.com>
Josh Smeaton <josh.smeaton@gmail.com>
Joshua Ginsberg <jag@flowtheory.net> Joshua Ginsberg <jag@flowtheory.net>
Jozko Skrablin <jozko.skrablin@gmail.com> Jozko Skrablin <jozko.skrablin@gmail.com>
J. Pablo Fernandez <pupeno@pupeno.com> J. Pablo Fernandez <pupeno@pupeno.com>

View File

@ -10,7 +10,7 @@ from django.db.models import signals, FieldDoesNotExist, DO_NOTHING
from django.db.models.base import ModelBase from django.db.models.base import ModelBase
from django.db.models.fields.related import ForeignObject, ForeignObjectRel from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.related import PathInfo from django.db.models.related import PathInfo
from django.db.models.sql.datastructures import Col from django.db.models.expressions import Col
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.utils.encoding import smart_text, python_2_unicode_compatible from django.utils.encoding import smart_text, python_2_unicode_compatible

View File

@ -186,7 +186,7 @@ class BaseSpatialOperations(object):
""" """
raise NotImplementedError('Distance operations not available on this spatial backend.') raise NotImplementedError('Distance operations not available on this spatial backend.')
def get_geom_placeholder(self, f, value): def get_geom_placeholder(self, f, value, qn):
""" """
Returns the placeholder for the given geometry field with the given Returns the placeholder for the given geometry field with the given
value. Depending on the spatial backend, the placeholder may contain a value. Depending on the spatial backend, the placeholder may contain a
@ -195,16 +195,6 @@ class BaseSpatialOperations(object):
""" """
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
def get_expression_column(self, evaluator):
"""
Helper method to return the quoted column string from the evaluator
for its expression.
"""
for expr, col_tup in evaluator.cols:
if expr is evaluator.expression:
return '%s.%s' % tuple(map(self.quote_name, col_tup))
raise Exception("Could not find the column for the expression.")
# Spatial SQL Construction # Spatial SQL Construction
def spatial_aggregate_sql(self, agg): def spatial_aggregate_sql(self, agg):
raise NotImplementedError('Aggregate support not implemented for this spatial backend.') raise NotImplementedError('Aggregate support not implemented for this spatial backend.')

View File

@ -35,14 +35,14 @@ class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
def geo_db_type(self, f): def geo_db_type(self, f):
return f.geom_type return f.geom_type
def get_geom_placeholder(self, f, value): def get_geom_placeholder(self, f, value, qn):
""" """
The placeholder here has to include MySQL's WKT constructor. Because The placeholder here has to include MySQL's WKT constructor. Because
MySQL does not support spatial transformations, there is no need to MySQL does not support spatial transformations, there is no need to
modify the placeholder based on the contents of the given value. modify the placeholder based on the contents of the given value.
""" """
if hasattr(value, 'expression'): if hasattr(value, 'as_sql'):
placeholder = self.get_expression_column(value) placeholder, _ = qn.compile(value)
else: else:
placeholder = '%s(%%s)' % self.from_text placeholder = '%s(%%s)' % self.from_text
return placeholder return placeholder

View File

@ -9,7 +9,7 @@
""" """
import re import re
from django.db.backends.oracle.base import DatabaseOperations from django.db.backends.oracle.base import DatabaseOperations, Database
from django.contrib.gis.db.backends.base import BaseSpatialOperations from django.contrib.gis.db.backends.base import BaseSpatialOperations
from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
@ -145,9 +145,11 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
else: else:
return None return None
def convert_geom(self, clob, geo_field): def convert_geom(self, value, geo_field):
if clob: if value:
return Geometry(clob.read(), geo_field.srid) if isinstance(value, Database.LOB):
value = value.read()
return Geometry(value, geo_field.srid)
else: else:
return None return None
@ -184,7 +186,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
return [dist_param] return [dist_param]
def get_geom_placeholder(self, f, value): def get_geom_placeholder(self, f, value, qn):
""" """
Provides a proper substitution value for Geometries that are not in the Provides a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the SRID of the field. Specifically, this routine will substitute in the
@ -196,14 +198,15 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
def transform_value(val, srid): def transform_value(val, srid):
return val.srid != srid return val.srid != srid
if hasattr(value, 'expression'): if hasattr(value, 'as_sql'):
if transform_value(value, f.srid): if transform_value(value, f.srid):
placeholder = '%s(%%s, %s)' % (self.transform, f.srid) placeholder = '%s(%%s, %s)' % (self.transform, f.srid)
else: else:
placeholder = '%s' placeholder = '%s'
# No geometry value used for F expression, substitute in # No geometry value used for F expression, substitute in
# the column name instead. # the column name instead.
return placeholder % self.get_expression_column(value) sql, _ = qn.compile(value)
return placeholder % sql
else: else:
if transform_value(value, f.srid): if transform_value(value, f.srid):
return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (self.transform, value.srid, f.srid) return '%s(SDO_GEOMETRY(%%s, %s), %s)' % (self.transform, value.srid, f.srid)
@ -219,9 +222,9 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
if agg_name == 'union': if agg_name == 'union':
agg_name += 'agg' agg_name += 'agg'
if agg.is_extent: if agg.is_extent:
sql_template = '%(function)s(%(field)s)' sql_template = '%(function)s(%(expressions)s)'
else: else:
sql_template = '%(function)s(SDOAGGRTYPE(%(field)s,%(tolerance)s))' sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
sql_function = getattr(self, agg_name) sql_function = getattr(self, agg_name)
return self.select % sql_template, sql_function return self.select % sql_template, sql_function

View File

@ -22,7 +22,7 @@ class PostGISOperator(SpatialOperator):
super(PostGISOperator, self).__init__(**kwargs) super(PostGISOperator, self).__init__(**kwargs)
def as_sql(self, connection, lookup, *args): def as_sql(self, connection, lookup, *args):
if lookup.lhs.source.geography and not self.geography: if lookup.lhs.output_field.geography and not self.geography:
raise ValueError('PostGIS geography does not support the "%s" ' raise ValueError('PostGIS geography does not support the "%s" '
'function/operator.' % (self.func or self.op,)) 'function/operator.' % (self.func or self.op,))
return super(PostGISOperator, self).as_sql(connection, lookup, *args) return super(PostGISOperator, self).as_sql(connection, lookup, *args)
@ -32,7 +32,7 @@ class PostGISDistanceOperator(PostGISOperator):
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s' sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s'
def as_sql(self, connection, lookup, template_params, sql_params): def as_sql(self, connection, lookup, template_params, sql_params):
if not lookup.lhs.source.geography and lookup.lhs.source.geodetic(connection): if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
sql_template = self.sql_template sql_template = self.sql_template
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid': if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'}) template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'})
@ -215,7 +215,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
Converts the geometry returned from PostGIS aggretates. Converts the geometry returned from PostGIS aggretates.
""" """
if hex: if hex:
return Geometry(hex) return Geometry(hex, srid=geo_field.srid)
else: else:
return None return None
@ -284,7 +284,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
else: else:
return [dist_param] return [dist_param]
def get_geom_placeholder(self, f, value): def get_geom_placeholder(self, f, value, qn):
""" """
Provides a proper substitution value for Geometries that are not in the Provides a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the SRID of the field. Specifically, this routine will substitute in the
@ -296,11 +296,12 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
# Adding Transform() to the SQL placeholder. # Adding Transform() to the SQL placeholder.
placeholder = '%s(%%s, %s)' % (self.transform, f.srid) placeholder = '%s(%%s, %s)' % (self.transform, f.srid)
if hasattr(value, 'expression'): if hasattr(value, 'as_sql'):
# If this is an F expression, then we don't really want # If this is an F expression, then we don't really want
# a placeholder and instead substitute in the column # a placeholder and instead substitute in the column
# of the expression. # of the expression.
placeholder = placeholder % self.get_expression_column(value) sql, _ = qn.compile(value)
placeholder = placeholder % sql
return placeholder return placeholder
@ -375,7 +376,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
agg_name = agg_name.lower() agg_name = agg_name.lower()
if agg_name == 'union': if agg_name == 'union':
agg_name += 'agg' agg_name += 'agg'
sql_template = '%(function)s(%(field)s)' sql_template = '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name) sql_function = getattr(self, agg_name)
return sql_template, sql_function return sql_template, sql_function

View File

@ -178,7 +178,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
dist_param = value dist_param = value
return [dist_param] return [dist_param]
def get_geom_placeholder(self, f, value): def get_geom_placeholder(self, f, value, qn):
""" """
Provides a proper substitution value for Geometries that are not in the Provides a proper substitution value for Geometries that are not in the
SRID of the field. Specifically, this routine will substitute in the SRID of the field. Specifically, this routine will substitute in the
@ -186,14 +186,15 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
""" """
def transform_value(value, srid): def transform_value(value, srid):
return not (value is None or value.srid == srid) return not (value is None or value.srid == srid)
if hasattr(value, 'expression'): if hasattr(value, 'as_sql'):
if transform_value(value, f.srid): if transform_value(value, f.srid):
placeholder = '%s(%%s, %s)' % (self.transform, f.srid) placeholder = '%s(%%s, %s)' % (self.transform, f.srid)
else: else:
placeholder = '%s' placeholder = '%s'
# No geometry value used for F expression, substitute in # No geometry value used for F expression, substitute in
# the column name instead. # the column name instead.
return placeholder % self.get_expression_column(value) sql, _ = qn.compile(value)
return placeholder % sql
else: else:
if transform_value(value, f.srid): if transform_value(value, f.srid):
# Adding Transform() to the SQL placeholder. # Adding Transform() to the SQL placeholder.
@ -255,7 +256,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
agg_name = agg_name.lower() agg_name = agg_name.lower()
if agg_name == 'union': if agg_name == 'union':
agg_name += 'agg' agg_name += 'agg'
sql_template = self.select % '%(function)s(%(field)s)' sql_template = self.select % '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name) sql_function = getattr(self, agg_name)
return sql_template, sql_function return sql_template, sql_function

View File

@ -1,23 +1,66 @@
from django.db.models import Aggregate from django.db.models.aggregates import Aggregate
from django.contrib.gis.db.models.fields import GeometryField, ExtentField
__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union']
class Collect(Aggregate): class GeoAggregate(Aggregate):
template = None
function = None
is_extent = False
def as_sql(self, compiler, connection):
if connection.ops.oracle:
if not hasattr(self, 'tolerance'):
self.tolerance = 0.05
self.extra['tolerance'] = self.tolerance
template, function = connection.ops.spatial_aggregate_sql(self)
if template is None:
template = '%(function)s(%(expressions)s)'
self.extra['template'] = self.extra.get('template', template)
self.extra['function'] = self.extra.get('function', function)
return super(GeoAggregate, self).as_sql(compiler, connection)
def prepare(self, query=None, allow_joins=True, reuse=None, summarize=False):
c = super(GeoAggregate, self).prepare(query, allow_joins, reuse, summarize)
if not isinstance(self.expressions[0].output_field, GeometryField):
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
return c
def convert_value(self, value, connection):
return connection.ops.convert_geom(value, self.output_field)
class Collect(GeoAggregate):
name = 'Collect' name = 'Collect'
class Extent(Aggregate): class Extent(GeoAggregate):
name = 'Extent' name = 'Extent'
is_extent = '2D'
def __init__(self, expression, **extra):
super(Extent, self).__init__(expression, output_field=ExtentField(), **extra)
def convert_value(self, value, connection):
return connection.ops.convert_extent(value)
class Extent3D(Aggregate): class Extent3D(GeoAggregate):
name = 'Extent3D' name = 'Extent3D'
is_extent = '3D'
def __init__(self, expression, **extra):
super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra)
def convert_value(self, value, connection):
return connection.ops.convert_extent3d(value)
class MakeLine(Aggregate): class MakeLine(GeoAggregate):
name = 'MakeLine' name = 'MakeLine'
class Union(Aggregate): class Union(GeoAggregate):
name = 'Union' name = 'Union'

View File

@ -1,5 +1,5 @@
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.sql.expressions import SQLEvaluator from django.db.models.expressions import ExpressionNode
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.contrib.gis import forms from django.contrib.gis import forms
from django.contrib.gis.db.models.lookups import gis_lookups from django.contrib.gis.db.models.lookups import gis_lookups
@ -165,7 +165,7 @@ class GeometryField(Field):
returning to the caller. returning to the caller.
""" """
value = super(GeometryField, self).get_prep_value(value) value = super(GeometryField, self).get_prep_value(value)
if isinstance(value, SQLEvaluator): if isinstance(value, ExpressionNode):
return value return value
elif isinstance(value, (tuple, list)): elif isinstance(value, (tuple, list)):
geom = value[0] geom = value[0]
@ -197,7 +197,7 @@ class GeometryField(Field):
return geom return geom
def from_db_value(self, value, connection): def from_db_value(self, value, connection):
if value: if value and not isinstance(value, Geometry):
value = Geometry(value) value = Geometry(value)
return value return value
@ -259,7 +259,7 @@ class GeometryField(Field):
pass pass
else: else:
params += value[1:] params += value[1:]
elif isinstance(value, SQLEvaluator): elif isinstance(value, ExpressionNode):
params = [] params = []
else: else:
params = [connection.ops.Adapter(value)] params = [connection.ops.Adapter(value)]
@ -282,12 +282,12 @@ class GeometryField(Field):
else: else:
return connection.ops.Adapter(self.get_prep_value(value)) return connection.ops.Adapter(self.get_prep_value(value))
def get_placeholder(self, value, connection): def get_placeholder(self, value, qn, connection):
""" """
Returns the placeholder for the geometry column for the Returns the placeholder for the geometry column for the
given value. given value.
""" """
return connection.ops.get_geom_placeholder(self, value) return connection.ops.get_geom_placeholder(self, value, qn)
for klass in gis_lookups.values(): for klass in gis_lookups.values():
@ -335,3 +335,12 @@ class GeometryCollectionField(GeometryField):
geom_type = 'GEOMETRYCOLLECTION' geom_type = 'GEOMETRYCOLLECTION'
form_class = forms.GeometryCollectionField form_class = forms.GeometryCollectionField
description = _("Geometry collection") description = _("Geometry collection")
class ExtentField(Field):
"Used as a return value from an extent aggregate"
description = _("Extent Aggregate Field")
def get_internal_type(self):
return "ExtentField"

View File

@ -4,7 +4,7 @@ import re
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
from django.db.models.lookups import Lookup from django.db.models.lookups import Lookup
from django.db.models.sql.expressions import SQLEvaluator from django.db.models.expressions import ExpressionNode, Col
from django.utils import six from django.utils import six
gis_lookups = {} gis_lookups = {}
@ -68,18 +68,19 @@ class GISLookup(Lookup):
rhs, rhs_params = super(GISLookup, self).process_rhs(qn, connection) rhs, rhs_params = super(GISLookup, self).process_rhs(qn, connection)
geom = self.rhs geom = self.rhs
if isinstance(self.rhs, SQLEvaluator): if isinstance(self.rhs, Col):
# Make sure the F Expression destination field exists, and # Make sure the F Expression destination field exists, and
# set an `srid` attribute with the same as that of the # set an `srid` attribute with the same as that of the
# destination. # destination.
geo_fld = self._check_geo_field(self.rhs.opts, self.rhs.expression.name) geo_fld = self.rhs.output_field
if not geo_fld: if not hasattr(geo_fld, 'srid'):
raise ValueError('No geographic field found in expression.') raise ValueError('No geographic field found in expression.')
self.rhs.srid = geo_fld.srid self.rhs.srid = geo_fld.srid
elif isinstance(self.rhs, ExpressionNode):
raise ValueError('Complex expressions not supported for GeometryField')
elif isinstance(self.rhs, (list, tuple)): elif isinstance(self.rhs, (list, tuple)):
geom = self.rhs[0] geom = self.rhs[0]
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, qn)
rhs = connection.ops.get_geom_placeholder(self.lhs.source, geom)
return rhs, rhs_params return rhs, rhs_params
def as_sql(self, qn, connection): def as_sql(self, qn, connection):

View File

@ -530,7 +530,7 @@ class GeoQuerySet(QuerySet):
# transformation SQL. # transformation SQL.
geom = geo_field.get_prep_value(settings['procedure_args'][name]) geom = geo_field.get_prep_value(settings['procedure_args'][name])
params = geo_field.get_db_prep_lookup('contains', geom, connection=connection) params = geo_field.get_db_prep_lookup('contains', geom, connection=connection)
geom_placeholder = geo_field.get_placeholder(geom, connection) geom_placeholder = geo_field.get_placeholder(geom, None, connection)
# Replacing the procedure format with that of any needed # Replacing the procedure format with that of any needed
# transformation SQL. # transformation SQL.

View File

@ -6,12 +6,15 @@ from django.contrib.gis.db.models.fields import GeometryField
__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] + aggregates.__all__ __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] + aggregates.__all__
warnings.warn(
"django.contrib.gis.db.models.sql.aggregates is deprecated. Use "
"django.contrib.gis.db.models.aggregates instead.",
RemovedInDjango20Warning, stacklevel=2)
class GeoAggregate(Aggregate): class GeoAggregate(Aggregate):
# Default SQL template for spatial aggregates. # Default SQL template for spatial aggregates.
sql_template = '%(function)s(%(field)s)' sql_template = '%(function)s(%(expressions)s)'
# Conversion class, if necessary.
conversion_class = None
# Flags for indicating the type of the aggregate. # Flags for indicating the type of the aggregate.
is_extent = False is_extent = False
@ -45,7 +48,7 @@ class GeoAggregate(Aggregate):
substitutions = { substitutions = {
'function': sql_function, 'function': sql_function,
'field': field_name 'expressions': field_name
} }
substitutions.update(self.extra) substitutions.update(self.extra)

View File

@ -70,8 +70,8 @@ class GeoSQLCompiler(compiler.SQLCompiler):
aliases.update(new_aliases) aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length() max_name_length = self.connection.ops.max_name_length()
for alias, aggregate in self.query.aggregate_select.items(): for alias, annotation in self.query.annotation_select.items():
agg_sql, agg_params = aggregate.as_sql(qn, self.connection) agg_sql, agg_params = self.compile(annotation)
if alias is None: if alias is None:
result.append(agg_sql) result.append(agg_sql)
else: else:

View File

@ -4,7 +4,7 @@ from django.db.models.sql.constants import QUERY_TERMS
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.lookups import GISLookup from django.contrib.gis.db.models.lookups import GISLookup
from django.contrib.gis.db.models.sql import aggregates as gis_aggregates from django.contrib.gis.db.models import aggregates as gis_aggregates
from django.contrib.gis.db.models.sql.conversion import GeomField from django.contrib.gis.db.models.sql.conversion import GeomField
@ -14,7 +14,6 @@ class GeoQuery(sql.Query):
""" """
# Overriding the valid query terms. # Overriding the valid query terms.
query_terms = QUERY_TERMS | set(GeometryField.class_lookups.keys()) query_terms = QUERY_TERMS | set(GeometryField.class_lookups.keys())
aggregates_module = gis_aggregates
compiler = 'GeoSQLCompiler' compiler = 'GeoSQLCompiler'
@ -40,28 +39,12 @@ class GeoQuery(sql.Query):
# Remove any aggregates marked for reduction from the subquery # Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery. # and move them to the outer AggregateQuery.
connection = connections[using] connection = connections[using]
for alias, aggregate in self.aggregate_select.items(): for alias, annotation in self.annotation_select.items():
if isinstance(aggregate, gis_aggregates.GeoAggregate): if isinstance(annotation, gis_aggregates.GeoAggregate):
if not getattr(aggregate, 'is_extent', False) or connection.ops.oracle: if not getattr(annotation, 'is_extent', False) or connection.ops.oracle:
self.extra_select_fields[alias] = GeomField() self.extra_select_fields[alias] = GeomField()
return super(GeoQuery, self).get_aggregation(using, force_subq) return super(GeoQuery, self).get_aggregation(using, force_subq)
def resolve_aggregate(self, value, aggregate, connection):
"""
Overridden from GeoQuery's normalize to handle the conversion of
GeoAggregate objects.
"""
if isinstance(aggregate, self.aggregates_module.GeoAggregate):
if aggregate.is_extent:
if aggregate.is_extent == '3D':
return connection.ops.convert_extent3d(value)
else:
return connection.ops.convert_extent(value)
else:
return connection.ops.convert_geom(value, aggregate.source)
else:
return super(GeoQuery, self).resolve_aggregate(value, aggregate, connection)
# Private API utilities, subject to change. # Private API utilities, subject to change.
def _geo_field(self, field_name=None): def _geo_field(self, field_name=None):
""" """

View File

@ -20,8 +20,7 @@ from django.db.backends.sqlite3.client import DatabaseClient
from django.db.backends.sqlite3.creation import DatabaseCreation from django.db.backends.sqlite3.creation import DatabaseCreation
from django.db.backends.sqlite3.introspection import DatabaseIntrospection from django.db.backends.sqlite3.introspection import DatabaseIntrospection
from django.db.backends.sqlite3.schema import DatabaseSchemaEditor from django.db.backends.sqlite3.schema import DatabaseSchemaEditor
from django.db.models import fields from django.db.models import fields, aggregates
from django.db.models.sql import aggregates
from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -163,8 +162,7 @@ class DatabaseOperations(BaseDatabaseOperations):
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField) bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
bad_aggregates = (aggregates.Sum, aggregates.Avg, bad_aggregates = (aggregates.Sum, aggregates.Avg,
aggregates.Variance, aggregates.StdDev) aggregates.Variance, aggregates.StdDev)
if (isinstance(aggregate.source, bad_fields) and if aggregate.refs_field(bad_aggregates, bad_fields):
isinstance(aggregate, bad_aggregates)):
raise NotImplementedError( raise NotImplementedError(
'You cannot use Sum, Avg, StdDev and Variance aggregations ' 'You cannot use Sum, Avg, StdDev and Variance aggregations '
'on date/time fields in sqlite3 ' 'on date/time fields in sqlite3 '

View File

@ -4,7 +4,7 @@ import warnings
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA
from django.db.models.query import Q, QuerySet, Prefetch # NOQA from django.db.models.query import Q, QuerySet, Prefetch # NOQA
from django.db.models.expressions import F # NOQA from django.db.models.expressions import ExpressionNode, F, Value, Func # NOQA
from django.db.models.manager import Manager # NOQA from django.db.models.manager import Manager # NOQA
from django.db.models.base import Model # NOQA from django.db.models.base import Model # NOQA
from django.db.models.aggregates import * # NOQA from django.db.models.aggregates import * # NOQA

View File

@ -1,94 +1,152 @@
""" """
Classes to represent the definitions of aggregate functions. Classes to represent the definitions of aggregate functions.
""" """
from django.db.models.constants import LOOKUP_SEP from django.core.exceptions import FieldError
from django.db.models.expressions import Func, Value
from django.db.models.fields import IntegerField, FloatField
__all__ = [ __all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
] ]
def refs_aggregate(lookup_parts, aggregates): class Aggregate(Func):
""" contains_aggregate = True
A little helper method to check if the lookup_parts contains references name = None
to the given aggregates set. Because the LOOKUP_SEP is contained in the
default annotation names we must check each prefix of the lookup_parts
for match.
"""
for n in range(len(lookup_parts) + 1):
level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
if level_n_lookup in aggregates:
return aggregates[level_n_lookup], lookup_parts[n:]
return False, ()
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
assert len(self.source_expressions) == 1
c = super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
if c.source_expressions[0].contains_aggregate and not summarize:
name = self.source_expressions[0].name
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
c.name, name, name))
c._patch_aggregate(query) # backward-compatibility support
return c
class Aggregate(object): def refs_field(self, aggregate_types, field_types):
try:
return (isinstance(self, aggregate_types) and
isinstance(self.input_field._output_field_or_none, field_types))
except FieldError:
# Sometimes we don't know the input_field's output type (for example,
# doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
# is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
# have any output field.
return False
@property
def input_field(self):
return self.source_expressions[0]
@property
def default_alias(self):
if hasattr(self.source_expressions[0], 'name'):
return '%s__%s' % (self.source_expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias")
def get_group_by_cols(self):
return []
def _patch_aggregate(self, query):
""" """
Default Aggregate definition. Helper method for patching 3rd party aggregates that do not yet support
""" the new way of subclassing. This method should be removed in 2.0
def __init__(self, lookup, **extra):
"""Instantiate a new aggregate.
* lookup is the field on which the aggregate operates. add_to_query(query, alias, col, source, is_summary) will be defined on
* extra is a dictionary of additional data to provide for the legacy aggregates which, in turn, instantiates the SQL implementation of
aggregate definition the aggregate. In all the cases found, the general implementation of
add_to_query looks like:
Also utilizes the class variables:
* name, the identifier for this aggregate function.
"""
self.lookup = lookup
self.extra = extra
def _default_alias(self):
return '%s__%s' % (self.lookup, self.name.lower())
default_alias = property(_default_alias)
def add_to_query(self, query, alias, col, source, is_summary): def add_to_query(self, query, alias, col, source, is_summary):
"""Add the aggregate to the nominated query. klass = SQLImplementationAggregate
This method is used to convert the generic Aggregate definition into a
backend-specific definition.
* query is the backend-specific query instance to which the aggregate
is to be added.
* col is a column reference describing the subject field
of the aggregate. It can be an alias, or a tuple describing
a table and column name.
* source is the underlying field or aggregate definition for
the column reference. If the aggregate is not an ordinal or
computed type, this reference is used to determine the coerced
output type of the aggregate.
* is_summary is a boolean that is set True if the aggregate is a
summary value rather than an annotation.
"""
klass = getattr(query.aggregates_module, self.name)
aggregate = klass(col, source=source, is_summary=is_summary, **self.extra) aggregate = klass(col, source=source, is_summary=is_summary, **self.extra)
query.aggregates[alias] = aggregate query.aggregates[alias] = aggregate
By supplying a known alias, we can get the SQLAggregate out of the
aggregates dict, and use the sql_function and sql_template attributes
to patch *this* aggregate.
"""
if not hasattr(self, 'add_to_query') or self.function is not None:
return
placeholder_alias = "_XXXXXXXX_"
self.add_to_query(query, placeholder_alias, None, None, None)
sql_aggregate = query.aggregates.pop(placeholder_alias)
if 'sql_function' not in self.extra and hasattr(sql_aggregate, 'sql_function'):
self.extra['function'] = sql_aggregate.sql_function
if hasattr(sql_aggregate, 'sql_template'):
self.extra['template'] = sql_aggregate.sql_template
class Avg(Aggregate): class Avg(Aggregate):
function = 'AVG'
name = 'Avg' name = 'Avg'
def __init__(self, expression, **extra):
super(Avg, self).__init__(expression, output_field=FloatField(), **extra)
def convert_value(self, value, connection):
if value is None:
return value
return float(value)
class Count(Aggregate): class Count(Aggregate):
function = 'COUNT'
name = 'Count' name = 'Count'
template = '%(function)s(%(distinct)s%(expressions)s)'
def __init__(self, expression, distinct=False, **extra):
if expression == '*':
expression = Value(expression)
expression._output_field = IntegerField()
super(Count, self).__init__(
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
def convert_value(self, value, connection):
if value is None:
return 0
return int(value)
class Max(Aggregate): class Max(Aggregate):
function = 'MAX'
name = 'Max' name = 'Max'
class Min(Aggregate): class Min(Aggregate):
function = 'MIN'
name = 'Min' name = 'Min'
class StdDev(Aggregate): class StdDev(Aggregate):
name = 'StdDev' name = 'StdDev'
def __init__(self, expression, sample=False, **extra):
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
def convert_value(self, value, connection):
if value is None:
return value
return float(value)
class Sum(Aggregate): class Sum(Aggregate):
function = 'SUM'
name = 'Sum' name = 'Sum'
class Variance(Aggregate): class Variance(Aggregate):
name = 'Variance' name = 'Variance'
def __init__(self, expression, sample=False, **extra):
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
def convert_value(self, value, connection):
if value is None:
return value
return float(value)

View File

@ -1,14 +1,20 @@
import copy
import datetime import datetime
from django.db.models.aggregates import refs_aggregate from django.core.exceptions import FieldError
from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.utils import tree from django.db.models.query_utils import refs_aggregate
from django.utils.functional import cached_property
class ExpressionNode(tree.Node): class CombinableMixin(object):
""" """
Base class for all query expressions. Provides the ability to combine one or two objects with
some connector. For example F('foo') + F('bar').
""" """
# Arithmetic connectors # Arithmetic connectors
ADD = '+' ADD = '+'
SUB = '-' SUB = '-'
@ -25,44 +31,17 @@ class ExpressionNode(tree.Node):
BITAND = '&' BITAND = '&'
BITOR = '|' BITOR = '|'
def __init__(self, children=None, connector=None, negated=False):
if children is not None and len(children) > 1 and connector is None:
raise TypeError('You have to specify a connector.')
super(ExpressionNode, self).__init__(children, connector, negated)
def _combine(self, other, connector, reversed, node=None): def _combine(self, other, connector, reversed, node=None):
if isinstance(other, datetime.timedelta): if isinstance(other, datetime.timedelta):
return DateModifierNode([self, other], connector) return DateModifierNode(self, connector, other)
if not hasattr(other, 'resolve_expression'):
# everything must be resolvable to an expression
other = Value(other)
if reversed: if reversed:
obj = ExpressionNode([other], connector) return Expression(other, connector, self)
obj.add(node or self, connector) return Expression(self, connector, other)
else:
obj = node or ExpressionNode([self], connector)
obj.add(other, connector)
return obj
def contains_aggregate(self, existing_aggregates):
if self.children:
return any(child.contains_aggregate(existing_aggregates)
for child in self.children
if hasattr(child, 'contains_aggregate'))
else:
return refs_aggregate(self.name.split(LOOKUP_SEP),
existing_aggregates)
def prepare_database_save(self, unused):
return self
###################
# VISITOR METHODS #
###################
def prepare(self, evaluator, query, allow_joins):
return evaluator.prepare_node(self, query, allow_joins)
def evaluate(self, evaluator, qn, connection):
return evaluator.evaluate_node(self, qn, connection)
############# #############
# OPERATORS # # OPERATORS #
@ -137,27 +116,240 @@ class ExpressionNode(tree.Node):
) )
class F(ExpressionNode): class ExpressionNode(CombinableMixin):
""" """
An expression representing the value of the given field. Base class for all query expressions.
""" """
def __init__(self, name):
super(F, self).__init__(None, None, False)
self.name = name
def __deepcopy__(self, memodict): # aggregate specific fields
obj = super(F, self).__deepcopy__(memodict) is_summary = False
obj.name = self.name
return obj
def prepare(self, evaluator, query, allow_joins): def __init__(self, output_field=None):
return evaluator.prepare_leaf(self, query, allow_joins) self._output_field = output_field
def evaluate(self, evaluator, qn, connection): def get_source_expressions(self):
return evaluator.evaluate_leaf(self, qn, connection) return []
def set_source_expressions(self, exprs):
assert len(exprs) == 0
def as_sql(self, compiler, connection):
"""
Responsible for returning a (sql, [params]) tuple to be included
in the current query.
Different backends can provide their own implementation, by
providing an `as_{vendor}` method and patching the Expression:
```
def override_as_sql(self, compiler, connection):
# custom logic
return super(ExpressionNode, self).as_sql(compiler, connection)
setattr(ExpressionNode, 'as_' + connection.vendor, override_as_sql)
```
Arguments:
* compiler: the query compiler responsible for generating the query.
Must have a compile method, returning a (sql, [params]) tuple.
Calling compiler(value) will return a quoted `value`.
* connection: the database connection used for the current query.
Returns: (sql, params)
Where `sql` is a string containing ordered sql parameters to be
replaced with the elements of the list `params`.
"""
raise NotImplementedError("Subclasses must implement as_sql()")
@cached_property
def contains_aggregate(self):
for expr in self.get_source_expressions():
if expr and expr.contains_aggregate:
return True
return False
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
"""
Provides the chance to do any preprocessing or validation before being
added to the query.
Arguments:
* query: the backend query implementation
* allow_joins: boolean allowing or denying use of joins
in this query
* reuse: a set of reusable joins for multijoins
* summarize: a terminal aggregate clause
Returns: an ExpressionNode to be added to the query.
"""
c = self.copy()
c.is_summary = summarize
return c
def _prepare(self):
"""
Hook used by Field.get_prep_lookup() to do custom preparation.
"""
return self
@property
def field(self):
return self.output_field
@cached_property
def output_field(self):
"""
Returns the output type of this expressions.
"""
if self._output_field_or_none is None:
raise FieldError("Cannot resolve expression type, unknown output_field")
return self._output_field_or_none
@cached_property
def _output_field_or_none(self):
"""
Returns the output field of this expression, or None if no output type
can be resolved. Note that the 'output_field' property will raise
FieldError if no type can be resolved, but this attribute allows for
None values.
"""
if self._output_field is None:
self._resolve_output_field()
return self._output_field
def _resolve_output_field(self):
"""
Attempts to infer the output type of the expression. If the output
fields of all source fields match then we can simply infer the same
type here.
"""
if self._output_field is None:
sources = self.get_source_fields()
num_sources = len(sources)
if num_sources == 0:
self._output_field = None
else:
self._output_field = sources[0]
for source in sources:
if source is not None and not isinstance(self._output_field, source.__class__):
raise FieldError(
"Expression contains mixed types. You must set output_field")
def convert_value(self, value, connection):
"""
Expressions provide their own converters because users have the option
of manually specifying the output_field which may be a different type
from the one the database returns.
"""
field = self.output_field
internal_type = field.get_internal_type()
if value is None:
return value
elif internal_type == 'FloatField':
return float(value)
elif internal_type.endswith('IntegerField'):
return int(value)
elif internal_type == 'DecimalField':
return backend_utils.typecast_decimal(field.format_number(value))
return value
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
def get_transform(self, name):
return self.output_field.get_transform(name)
def relabeled_clone(self, change_map):
clone = self.copy()
clone.set_source_expressions(
[e.relabeled_clone(change_map) for e in self.get_source_expressions()])
return clone
def copy(self):
c = copy.copy(self)
c.copied = True
return c
def refs_aggregate(self, existing_aggregates):
"""
Does this expression contain a reference to some of the
existing aggregates? If so, returns the aggregate and also
the lookup parts that *weren't* found. So, if
exsiting_aggregates = {'max_id': Max('id')}
self.name = 'max_id'
queryset.filter(max_id__range=[10,100])
then this method will return Max('id') and those parts of the
name that weren't found. In this case `max_id` is found and the range
portion is returned as ('range',).
"""
for node in self.get_source_expressions():
agg, lookup = node.refs_aggregate(existing_aggregates)
if agg:
return agg, lookup
return False, ()
def refs_field(self, aggregate_types, field_types):
"""
Helper method for check_aggregate_support on backends
"""
return any(
node.refs_field(aggregate_types, field_types)
for node in self.get_source_expressions())
def prepare_database_save(self, field):
return self
def get_group_by_cols(self):
cols = []
for source in self.get_source_expressions():
cols.extend(source.get_group_by_cols())
return cols
def get_source_fields(self):
"""
Returns the underlying field types used by this
aggregate.
"""
return [e._output_field_or_none for e in self.get_source_expressions()]
class DateModifierNode(ExpressionNode): class Expression(ExpressionNode):
def __init__(self, lhs, connector, rhs, output_field=None):
super(Expression, self).__init__(output_field=output_field)
self.connector = connector
self.lhs = lhs
self.rhs = rhs
def get_source_expressions(self):
return [self.lhs, self.rhs]
def set_source_expressions(self, exprs):
self.lhs, self.rhs = exprs
def as_sql(self, compiler, connection):
expressions = []
expression_params = []
sql, params = compiler.compile(self.lhs)
expressions.append(sql)
expression_params.extend(params)
sql, params = compiler.compile(self.rhs)
expressions.append(sql)
expression_params.extend(params)
# order of precedence
expression_wrapper = '(%s)'
sql = connection.ops.combine_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
c = self.copy()
c.is_summary = summarize
c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize)
c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize)
return c
class DateModifierNode(Expression):
""" """
Node that implements the following syntax: Node that implements the following syntax:
filter(end_date__gt=F('start_date') + datetime.timedelta(days=3, seconds=200)) filter(end_date__gt=F('start_date') + datetime.timedelta(days=3, seconds=200))
@ -183,14 +375,195 @@ class DateModifierNode(ExpressionNode):
Only adding and subtracting timedeltas is supported, attempts to use other Only adding and subtracting timedeltas is supported, attempts to use other
operations raise a TypeError. operations raise a TypeError.
""" """
def __init__(self, children, connector, negated=False): def __init__(self, lhs, connector, rhs):
if len(children) != 2: if not isinstance(rhs, datetime.timedelta):
raise TypeError('Must specify a node and a timedelta.') raise TypeError('rhs must be a timedelta.')
if not isinstance(children[1], datetime.timedelta):
raise TypeError('Second child must be a timedelta.')
if connector not in (self.ADD, self.SUB): if connector not in (self.ADD, self.SUB):
raise TypeError('Connector must be + or -, not %s' % connector) raise TypeError('Connector must be + or -, not %s' % connector)
super(DateModifierNode, self).__init__(children, connector, negated) super(DateModifierNode, self).__init__(lhs, connector, Value(rhs))
def evaluate(self, evaluator, qn, connection): def as_sql(self, compiler, connection):
return evaluator.evaluate_date_modifier_node(self, qn, connection) timedelta = self.rhs.value
sql, params = compiler.compile(self.lhs)
if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0):
return sql, params
return connection.ops.date_interval_sql(sql, self.connector, timedelta), params
class F(CombinableMixin):
"""
An object capable of resolving references to existing query objects.
"""
def __init__(self, name):
"""
Arguments:
* name: the name of the field this expression references
"""
self.name = name
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
return query.resolve_ref(self.name, allow_joins, reuse, summarize)
def refs_aggregate(self, existing_aggregates):
return refs_aggregate(self.name.split(LOOKUP_SEP), existing_aggregates)
class Func(ExpressionNode):
"""
A SQL function call.
"""
function = None
template = '%(function)s(%(expressions)s)'
arg_joiner = ', '
def __init__(self, *expressions, **extra):
output_field = extra.pop('output_field', None)
super(Func, self).__init__(output_field=output_field)
self.source_expressions = self._parse_expressions(*expressions)
self.extra = extra
def get_source_expressions(self):
return self.source_expressions
def set_source_expressions(self, exprs):
self.source_expressions = exprs
def _parse_expressions(self, *expressions):
return [
arg if hasattr(arg, 'resolve_expression') else F(arg)
for arg in expressions
]
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
c = self.copy()
c.is_summary = summarize
for pos, arg in enumerate(c.source_expressions):
c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize)
return c
def as_sql(self, compiler, connection, function=None, template=None):
sql_parts = []
params = []
for arg in self.source_expressions:
arg_sql, arg_params = compiler.compile(arg)
sql_parts.append(arg_sql)
params.extend(arg_params)
if function is None:
self.extra['function'] = self.extra.get('function', self.function)
else:
self.extra['function'] = function
self.extra['expressions'] = self.extra['field'] = self.arg_joiner.join(sql_parts)
template = template or self.extra.get('template', self.template)
return template % self.extra, params
def copy(self):
copy = super(Func, self).copy()
copy.source_expressions = self.source_expressions[:]
copy.extra = self.extra.copy()
return copy
class Value(ExpressionNode):
"""
Represents a wrapped value as a node within an expression
"""
def __init__(self, value, output_field=None):
"""
Arguments:
* value: the value this expression represents. The value will be
added into the sql parameter list and properly quoted.
* output_field: an instance of the model field type that this
expression will return, such as IntegerField() or CharField().
"""
super(Value, self).__init__(output_field=output_field)
self.value = value
def as_sql(self, compiler, connection):
return '%s', [self.value]
class Col(ExpressionNode):
def __init__(self, alias, target, source=None):
if source is None:
source = target
super(Col, self).__init__(output_field=source)
self.alias, self.target = alias, target
def as_sql(self, qn, connection):
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
def relabeled_clone(self, relabels):
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.output_field)
def get_group_by_cols(self):
return [(self.alias, self.target.column)]
class Ref(ExpressionNode):
"""
Reference to column alias of the query. For example, Ref('sum_cost') in
qs.annotate(sum_cost=Sum('cost')) query.
"""
def __init__(self, refs, source):
super(Ref, self).__init__()
self.source = source
self.refs = refs
def get_source_expressions(self):
return [self.source]
def set_source_expressions(self, exprs):
self.source, = exprs
def relabeled_clone(self, relabels):
return self
def as_sql(self, compiler, connection):
return "%s" % compiler(self.refs), []
def get_group_by_cols(self):
return [(None, self.refs)]
class Date(ExpressionNode):
"""
Add a date selection column.
"""
def __init__(self, col, lookup_type):
super(Date, self).__init__(output_field=fields.DateField())
self.col = col
self.lookup_type = lookup_type
def get_source_expressions(self):
return [self.col]
def set_source_expressions(self, exprs):
self.col, = self.exprs
def as_sql(self, qn, connection):
sql, params = self.col.as_sql(qn, connection)
assert not(params)
return connection.ops.date_trunc_sql(self.lookup_type, sql), []
class DateTime(ExpressionNode):
"""
Add a datetime selection column.
"""
def __init__(self, col, lookup_type, tzname):
super(DateTime, self).__init__(output_field=fields.DateTimeField())
self.col = col
self.lookup_type = lookup_type
self.tzname = tzname
def get_source_expressions(self):
return [self.col]
def set_source_expressions(self, exprs):
self.col, = exprs
def as_sql(self, qn, connection):
sql, params = self.col.as_sql(qn, connection)
assert not(params)
return connection.ops.datetime_trunc_sql(self.lookup_type, sql, self.tzname)

View File

@ -637,8 +637,6 @@ class Field(RegisterLookupMixin):
""" """
Perform preliminary non-db specific lookup checks and conversions Perform preliminary non-db specific lookup checks and conversions
""" """
if hasattr(value, 'prepare'):
return value.prepare()
if hasattr(value, '_prepare'): if hasattr(value, '_prepare'):
return value._prepare() return value._prepare()

View File

@ -13,7 +13,7 @@ from django.db.models.fields import (AutoField, Field, IntegerField,
from django.db.models.lookups import IsNull from django.db.models.lookups import IsNull
from django.db.models.related import RelatedObject, PathInfo from django.db.models.related import RelatedObject, PathInfo
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.db.models.sql.datastructures import Col from django.db.models.expressions import Col
from django.utils.encoding import force_text, smart_text from django.utils.encoding import force_text, smart_text
from django.utils import six from django.utils import six
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _

View File

@ -154,8 +154,7 @@ class QuerySet(object):
2. sql/compiler.results_iter() 2. sql/compiler.results_iter()
- Returns one row at time. At this point the rows are still just - Returns one row at time. At this point the rows are still just
tuples. In some cases the return values are converted to tuples. In some cases the return values are converted to
Python values at this location (see resolve_columns(), Python values at this location.
resolve_aggregate()).
3. self.iterator() 3. self.iterator()
- Responsible for turning the rows into model objects. - Responsible for turning the rows into model objects.
""" """
@ -241,7 +240,7 @@ class QuerySet(object):
max_depth = self.query.max_depth max_depth = self.query.max_depth
extra_select = list(self.query.extra_select) extra_select = list(self.query.extra_select)
aggregate_select = list(self.query.aggregate_select) annotation_select = list(self.query.annotation_select)
only_load = self.query.get_loaded_field_names() only_load = self.query.get_loaded_field_names()
fields = self.model._meta.concrete_fields fields = self.model._meta.concrete_fields
@ -282,7 +281,7 @@ class QuerySet(object):
db = self.db db = self.db
compiler = self.query.get_compiler(using=db) compiler = self.query.get_compiler(using=db)
index_start = len(extra_select) index_start = len(extra_select)
aggregate_start = index_start + len(init_list) annotation_start = index_start + len(init_list)
if fill_cache: if fill_cache:
klass_info = get_klass_info(model_cls, max_depth=max_depth, klass_info = get_klass_info(model_cls, max_depth=max_depth,
@ -290,18 +289,18 @@ class QuerySet(object):
for row in compiler.results_iter(): for row in compiler.results_iter():
if fill_cache: if fill_cache:
obj, _ = get_cached_row(row, index_start, db, klass_info, obj, _ = get_cached_row(row, index_start, db, klass_info,
offset=len(aggregate_select)) offset=len(annotation_select))
else: else:
obj = model_cls.from_db(db, init_list, row[index_start:aggregate_start]) obj = model_cls.from_db(db, init_list, row[index_start:annotation_start])
if extra_select: if extra_select:
for i, k in enumerate(extra_select): for i, k in enumerate(extra_select):
setattr(obj, k, row[i]) setattr(obj, k, row[i])
# Add the aggregates to the model # Add the annotations to the model
if aggregate_select: if annotation_select:
for i, aggregate in enumerate(aggregate_select): for i, annotation in enumerate(annotation_select):
setattr(obj, aggregate, row[i + aggregate_start]) setattr(obj, annotation, row[i + annotation_start])
# Add the known related objects to the model, if there are any # Add the known related objects to the model, if there are any
if self._known_related_objects: if self._known_related_objects:
@ -330,13 +329,16 @@ class QuerySet(object):
if self.query.distinct_fields: if self.query.distinct_fields:
raise NotImplementedError("aggregate() + distinct(fields) not implemented.") raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
for arg in args: for arg in args:
if not hasattr(arg, 'default_alias'):
raise TypeError("Complex aggregates require an alias")
kwargs[arg.default_alias] = arg kwargs[arg.default_alias] = arg
query = self.query.clone() query = self.query.clone()
force_subq = query.low_mark != 0 or query.high_mark is not None force_subq = query.low_mark != 0 or query.high_mark is not None
for (alias, aggregate_expr) in kwargs.items(): for (alias, aggregate_expr) in kwargs.items():
query.add_aggregate(aggregate_expr, self.model, alias, query.add_annotation(aggregate_expr, self.model, alias, is_summary=True)
is_summary=True) if not query.annotations[alias].contains_aggregate:
raise TypeError("%s is not an aggregate expression" % alias)
return query.get_aggregation(using=self.db, force_subq=force_subq) return query.get_aggregation(using=self.db, force_subq=force_subq)
def count(self): def count(self):
@ -787,33 +789,40 @@ class QuerySet(object):
def annotate(self, *args, **kwargs): def annotate(self, *args, **kwargs):
""" """
Return a query set in which the returned objects have been annotated Return a query set in which the returned objects have been annotated
with data aggregated from related fields. with extra data or aggregations.
""" """
aggrs = OrderedDict() # To preserve ordering of args annotations = OrderedDict() # To preserve ordering of args
for arg in args: for arg in args:
try:
# we can't do an hasattr here because py2 returns False
# if default_alias exists but throws a TypeError
if arg.default_alias in kwargs: if arg.default_alias in kwargs:
raise ValueError("The named annotation '%s' conflicts with the " raise ValueError("The named annotation '%s' conflicts with the "
"default name for another annotation." "default name for another annotation."
% arg.default_alias) % arg.default_alias)
aggrs[arg.default_alias] = arg except AttributeError: # default_alias
aggrs.update(kwargs) raise TypeError("Complex annotations require an alias")
annotations[arg.default_alias] = arg
annotations.update(kwargs)
obj = self._clone()
names = getattr(self, '_fields', None) names = getattr(self, '_fields', None)
if names is None: if names is None:
names = set(self.model._meta.get_all_field_names()) names = set(self.model._meta.get_all_field_names())
for aggregate in aggrs:
if aggregate in names: # Add the annotations to the query
for alias, annotation in annotations.items():
if alias in names:
raise ValueError("The annotation '%s' conflicts with a field on " raise ValueError("The annotation '%s' conflicts with a field on "
"the model." % aggregate) "the model." % alias)
obj.query.add_annotation(annotation, self.model, alias, is_summary=False)
obj = self._clone() # expressions need to be added to the query before we know if they contain aggregates
added_aggregates = []
obj._setup_aggregate_query(list(aggrs)) for alias, annotation in obj.query.annotations.items():
if alias in annotations and annotation.contains_aggregate:
# Add the aggregates to the query added_aggregates.append(alias)
for (alias, aggregate_expr) in aggrs.items(): if added_aggregates:
obj.query.add_aggregate(aggregate_expr, self.model, alias, obj._setup_aggregate_query(list(added_aggregates))
is_summary=False)
return obj return obj
@ -1096,9 +1105,9 @@ class ValuesQuerySet(QuerySet):
# Purge any extra columns that haven't been explicitly asked for # Purge any extra columns that haven't been explicitly asked for
extra_names = list(self.query.extra_select) extra_names = list(self.query.extra_select)
field_names = self.field_names field_names = self.field_names
aggregate_names = list(self.query.aggregate_select) annotation_names = list(self.query.annotation_select)
names = extra_names + field_names + aggregate_names names = extra_names + field_names + annotation_names
for row in self.query.get_compiler(self.db).results_iter(): for row in self.query.get_compiler(self.db).results_iter():
yield dict(zip(names, row)) yield dict(zip(names, row))
@ -1122,9 +1131,9 @@ class ValuesQuerySet(QuerySet):
if self._fields: if self._fields:
self.extra_names = [] self.extra_names = []
self.aggregate_names = [] self.annotation_names = []
if not self.query._extra and not self.query._aggregates: if not self.query._extra and not self.query._annotations:
# Short cut - if there are no extra or aggregates, then # Short cut - if there are no extra or annotations, then
# the values() clause must be just field names. # the values() clause must be just field names.
self.field_names = list(self._fields) self.field_names = list(self._fields)
else: else:
@ -1136,22 +1145,22 @@ class ValuesQuerySet(QuerySet):
# had selected previously. # had selected previously.
if self.query._extra and f in self.query._extra: if self.query._extra and f in self.query._extra:
self.extra_names.append(f) self.extra_names.append(f)
elif f in self.query.aggregate_select: elif f in self.query.annotation_select:
self.aggregate_names.append(f) self.annotation_names.append(f)
else: else:
self.field_names.append(f) self.field_names.append(f)
else: else:
# Default to all fields. # Default to all fields.
self.extra_names = None self.extra_names = None
self.field_names = [f.attname for f in self.model._meta.concrete_fields] self.field_names = [f.attname for f in self.model._meta.concrete_fields]
self.aggregate_names = None self.annotation_names = None
self.query.select = [] self.query.select = []
if self.extra_names is not None: if self.extra_names is not None:
self.query.set_extra_mask(self.extra_names) self.query.set_extra_mask(self.extra_names)
self.query.add_fields(self.field_names, True) self.query.add_fields(self.field_names, True)
if self.aggregate_names is not None: if self.annotation_names is not None:
self.query.set_aggregate_mask(self.aggregate_names) self.query.set_annotation_mask(self.annotation_names)
def _clone(self, klass=None, setup=False, **kwargs): def _clone(self, klass=None, setup=False, **kwargs):
""" """
@ -1164,7 +1173,7 @@ class ValuesQuerySet(QuerySet):
c._fields = self._fields[:] c._fields = self._fields[:]
c.field_names = self.field_names c.field_names = self.field_names
c.extra_names = self.extra_names c.extra_names = self.extra_names
c.aggregate_names = self.aggregate_names c.annotation_names = self.annotation_names
if setup and hasattr(c, '_setup_query'): if setup and hasattr(c, '_setup_query'):
c._setup_query() c._setup_query()
return c return c
@ -1173,7 +1182,7 @@ class ValuesQuerySet(QuerySet):
super(ValuesQuerySet, self)._merge_sanity_check(other) super(ValuesQuerySet, self)._merge_sanity_check(other)
if (set(self.extra_names) != set(other.extra_names) or if (set(self.extra_names) != set(other.extra_names) or
set(self.field_names) != set(other.field_names) or set(self.field_names) != set(other.field_names) or
self.aggregate_names != other.aggregate_names): self.annotation_names != other.annotation_names):
raise TypeError("Merging '%s' classes must involve the same values in each case." raise TypeError("Merging '%s' classes must involve the same values in each case."
% self.__class__.__name__) % self.__class__.__name__)
@ -1183,9 +1192,9 @@ class ValuesQuerySet(QuerySet):
""" """
self.query.set_group_by() self.query.set_group_by()
if self.aggregate_names is not None: if self.annotation_names is not None:
self.aggregate_names.extend(aggregates) self.annotation_names.extend(aggregates)
self.query.set_aggregate_mask(self.aggregate_names) self.query.set_annotation_mask(self.annotation_names)
super(ValuesQuerySet, self)._setup_aggregate_query(aggregates) super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
@ -1231,7 +1240,7 @@ class ValuesListQuerySet(ValuesQuerySet):
if self.flat and len(self._fields) == 1: if self.flat and len(self._fields) == 1:
for row in self.query.get_compiler(self.db).results_iter(): for row in self.query.get_compiler(self.db).results_iter():
yield row[0] yield row[0]
elif not self.query.extra_select and not self.query.aggregate_select: elif not self.query.extra_select and not self.query.annotation_select:
for row in self.query.get_compiler(self.db).results_iter(): for row in self.query.get_compiler(self.db).results_iter():
yield tuple(row) yield tuple(row)
else: else:
@ -1240,14 +1249,14 @@ class ValuesListQuerySet(ValuesQuerySet):
# the fields to match the order in self._fields. # the fields to match the order in self._fields.
extra_names = list(self.query.extra_select) extra_names = list(self.query.extra_select)
field_names = self.field_names field_names = self.field_names
aggregate_names = list(self.query.aggregate_select) annotation_names = list(self.query.annotation_select)
names = extra_names + field_names + aggregate_names names = extra_names + field_names + annotation_names
# If a field list has been specified, use it. Otherwise, use the # If a field list has been specified, use it. Otherwise, use the
# full list of fields, including extras and aggregates. # full list of fields, including extras and annotations.
if self._fields: if self._fields:
fields = list(self._fields) + [f for f in aggregate_names if f not in self._fields] fields = list(self._fields) + [f for f in annotation_names if f not in self._fields]
else: else:
fields = names fields = names

View File

@ -9,6 +9,7 @@ from __future__ import unicode_literals
from django.apps import apps from django.apps import apps
from django.db.backends import utils from django.db.backends import utils
from django.db.models.constants import LOOKUP_SEP
from django.utils import six from django.utils import six
from django.utils import tree from django.utils import tree
@ -220,3 +221,17 @@ def deferred_class_factory(model, attrs):
# The above function is also used to unpickle model instances with deferred # The above function is also used to unpickle model instances with deferred
# fields. # fields.
deferred_class_factory.__safe_for_unpickling__ = True deferred_class_factory.__safe_for_unpickling__ = True
def refs_aggregate(lookup_parts, aggregates):
"""
A little helper method to check if the lookup_parts contains references
to the given aggregates set. Because the LOOKUP_SEP is contained in the
default annotation names we must check each prefix of the lookup_parts
for a match.
"""
for n in range(len(lookup_parts) + 1):
level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n])
if level_n_lookup in aggregates and aggregates[level_n_lookup].contains_aggregate:
return aggregates[level_n_lookup], lookup_parts[n:]
return False, ()

View File

@ -2,15 +2,23 @@
Classes to represent the default SQL aggregate functions Classes to represent the default SQL aggregate functions
""" """
import copy import copy
import warnings
from django.db.models.fields import IntegerField, FloatField from django.db.models.fields import IntegerField, FloatField
from django.db.models.lookups import RegisterLookupMixin from django.db.models.lookups import RegisterLookupMixin
from django.utils.deprecation import RemovedInDjango20Warning
from django.utils.functional import cached_property from django.utils.functional import cached_property
__all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance'] __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance']
warnings.warn(
"django.db.models.sql.aggregates is deprecated. Use "
"django.db.models.aggregates instead.",
RemovedInDjango20Warning, stacklevel=2)
class Aggregate(RegisterLookupMixin): class Aggregate(RegisterLookupMixin):
""" """
Default SQL Aggregate. Default SQL Aggregate.

View File

@ -4,12 +4,10 @@ from django.conf import settings
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.backends.utils import truncate_name from django.db.backends.utils import truncate_name
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import ExpressionNode
from django.db.models.query_utils import select_related_descend, QueryWrapper from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS, from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo) ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query from django.db.models.sql.query import get_order_dir, Query
from django.db.transaction import TransactionManagementError from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseError from django.db.utils import DatabaseError
@ -248,8 +246,8 @@ class SQLCompiler(object):
aliases.update(new_aliases) aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length() max_name_length = self.connection.ops.max_name_length()
for alias, aggregate in self.query.aggregate_select.items(): for alias, annotation in self.query.annotation_select.items():
agg_sql, agg_params = self.compile(aggregate) agg_sql, agg_params = self.compile(annotation)
if alias is None: if alias is None:
result.append(agg_sql) result.append(agg_sql)
else: else:
@ -409,7 +407,7 @@ class SQLCompiler(object):
group_by.append((str(field), [])) group_by.append((str(field), []))
continue continue
col, order = get_order_dir(field, asc) col, order = get_order_dir(field, asc)
if col in self.query.aggregate_select: if col in self.query.annotation_select:
result.append('%s %s' % (qn(col), order)) result.append('%s %s' % (qn(col), order))
continue continue
if '.' in field: if '.' in field:
@ -718,25 +716,17 @@ class SQLCompiler(object):
""" """
fields = None fields = None
converters = None converters = None
has_aggregate_select = bool(self.query.aggregate_select) has_annotation_select = bool(self.query.annotation_select)
for rows in self.execute_sql(MULTI): for rows in self.execute_sql(MULTI):
for row in rows: for row in rows:
if has_aggregate_select:
loaded_fields = (
self.query.get_loaded_field_names().get(self.query.model, set()) or
self.query.select
)
aggregate_start = len(self.query.extra_select) + len(loaded_fields)
aggregate_end = aggregate_start + len(self.query.aggregate_select)
if fields is None: if fields is None:
# We only set this up here because # We only set this up here because
# related_select_cols isn't populated until # related_select_cols isn't populated until
# execute_sql() has been called. # execute_sql() has been called.
# We also include types of fields of related models that # If the field was deferred, exclude it from being passed
# will be included via select_related() for the benefit # into `get_converters` because it wasn't selected.
# of MySQL/MySQLdb when boolean fields are involved only_load = self.deferred_to_columns()
# (#15040).
# This code duplicates the logic for the order of fields # This code duplicates the logic for the order of fields
# found in get_columns(). It would be nice to clean this up. # found in get_columns(). It would be nice to clean this up.
@ -746,30 +736,45 @@ class SQLCompiler(object):
fields = self.query.get_meta().concrete_fields fields = self.query.get_meta().concrete_fields
else: else:
fields = [] fields = []
fields = fields + [f.field for f in self.query.related_select_cols]
# If the field was deferred, exclude it from being passed
# into `get_converters` because it wasn't selected.
only_load = self.deferred_to_columns()
if only_load: if only_load:
fields = [f for f in fields if f.model._meta.db_table not in only_load or # strip deferred fields
f.column in only_load[f.model._meta.db_table]] fields = [
if has_aggregate_select: f for f in fields if
# pad None in to fields for aggregates f.model._meta.db_table not in only_load or
fields = fields[:aggregate_start] + [ f.column in only_load[f.model._meta.db_table]
None for x in range(0, aggregate_end - aggregate_start) ]
] + fields[aggregate_start:]
# annotations come before the related cols
if has_annotation_select:
# extra is always at the start of the field list
prepended_cols = len(self.query.extra_select)
annotation_start = len(fields) + prepended_cols
fields = fields + [
anno.output_field for alias, anno in self.query.annotation_select.items()]
annotation_end = len(fields) + prepended_cols
# add related fields
fields = fields + [
# strip deferred
f.field for f in self.query.related_select_cols if
f.field.model._meta.db_table not in only_load or
f.field.column in only_load[f.field.model._meta.db_table]
]
converters = self.get_converters(fields) converters = self.get_converters(fields)
if has_annotation_select:
for (alias, annotation), position in zip(
self.query.annotation_select.items(),
range(annotation_start, annotation_end + 1)):
if position in converters:
# annotation conversions always run first
converters[position][1].insert(0, annotation.convert_value)
else:
converters[position] = ([], [annotation.convert_value], annotation.output_field)
if converters: if converters:
row = self.apply_converters(row, converters) row = self.apply_converters(row, converters)
if has_aggregate_select:
row = tuple(row[:aggregate_start]) + tuple(
self.query.resolve_aggregate(value, aggregate, self.connection)
for (alias, aggregate), value
in zip(self.query.aggregate_select.items(), row[aggregate_start:aggregate_end])
) + tuple(row[aggregate_end:])
yield row yield row
def has_results(self): def has_results(self):
@ -878,7 +883,7 @@ class SQLInsertCompiler(SQLCompiler):
elif hasattr(field, 'get_placeholder'): elif hasattr(field, 'get_placeholder'):
# Some fields (e.g. geo fields) need special munging before # Some fields (e.g. geo fields) need special munging before
# they can be inserted. # they can be inserted.
return field.get_placeholder(val, self.connection) return field.get_placeholder(val, self, self.connection)
else: else:
# Return the common case for the placeholder # Return the common case for the placeholder
return '%s' return '%s'
@ -985,8 +990,10 @@ class SQLUpdateCompiler(SQLCompiler):
result.append('SET') result.append('SET')
values, update_params = [], [] values, update_params = [], []
for field, model, val in self.query.values: for field, model, val in self.query.values:
if hasattr(val, 'prepare_database_save'): if hasattr(val, 'resolve_expression'):
if field.rel or isinstance(val, ExpressionNode): val = val.resolve_expression(self.query, allow_joins=False)
elif hasattr(val, 'prepare_database_save'):
if field.rel:
val = val.prepare_database_save(field) val = val.prepare_database_save(field)
else: else:
raise TypeError("Database is trying to update a relational field " raise TypeError("Database is trying to update a relational field "
@ -998,12 +1005,9 @@ class SQLUpdateCompiler(SQLCompiler):
# Getting the placeholder for the field. # Getting the placeholder for the field.
if hasattr(field, 'get_placeholder'): if hasattr(field, 'get_placeholder'):
placeholder = field.get_placeholder(val, self.connection) placeholder = field.get_placeholder(val, self, self.connection)
else: else:
placeholder = '%s' placeholder = '%s'
if hasattr(val, 'evaluate'):
val = SQLEvaluator(val, self.query, allow_joins=False)
name = field.column name = field.column
if hasattr(val, 'as_sql'): if hasattr(val, 'as_sql'):
sql, params = self.compile(val) sql, params = self.compile(val)
@ -1103,8 +1107,8 @@ class SQLAggregateCompiler(SQLCompiler):
qn = self qn = self
sql, params = [], [] sql, params = [], []
for aggregate in self.query.aggregate_select.values(): for annotation in self.query.annotation_select.values():
agg_sql, agg_params = self.compile(aggregate) agg_sql, agg_params = self.compile(annotation)
sql.append(agg_sql) sql.append(agg_sql)
params.extend(agg_params) params.extend(agg_params)
sql = ', '.join(sql) sql = ', '.join(sql)

View File

@ -4,33 +4,6 @@ the SQL domain.
""" """
class Col(object):
def __init__(self, alias, target, source):
self.alias, self.target, self.source = alias, target, source
def as_sql(self, qn, connection):
return "%s.%s" % (qn(self.alias), qn(self.target.column)), []
@property
def output_field(self):
return self.source
def relabeled_clone(self, relabels):
return self.__class__(relabels.get(self.alias, self.alias), self.target, self.source)
def get_group_by_cols(self):
return [(self.alias, self.target.column)]
def get_lookup(self, name):
return self.output_field.get_lookup(name)
def get_transform(self, name):
return self.output_field.get_transform(name)
def prepare(self):
return self
class EmptyResultSet(Exception): class EmptyResultSet(Exception):
pass pass
@ -49,42 +22,3 @@ class MultiJoin(Exception):
class Empty(object): class Empty(object):
pass pass
class Date(object):
"""
Add a date selection column.
"""
def __init__(self, col, lookup_type):
self.col = col
self.lookup_type = lookup_type
def relabeled_clone(self, change_map):
return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1]))
def as_sql(self, qn, connection):
if isinstance(self.col, (list, tuple)):
col = '%s.%s' % tuple(qn(c) for c in self.col)
else:
col = self.col
return connection.ops.date_trunc_sql(self.lookup_type, col), []
class DateTime(object):
"""
Add a datetime selection column.
"""
def __init__(self, col, lookup_type, tzname):
self.col = col
self.lookup_type = lookup_type
self.tzname = tzname
def relabeled_clone(self, change_map):
return self.__class__((change_map.get(self.col[0], self.col[0]), self.col[1]))
def as_sql(self, qn, connection):
if isinstance(self.col, (list, tuple)):
col = '%s.%s' % tuple(qn(c) for c in self.col)
else:
col = self.col
return connection.ops.datetime_trunc_sql(self.lookup_type, col, self.tzname)

View File

@ -1,119 +0,0 @@
import copy
from django.core.exceptions import FieldError
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import FieldDoesNotExist
class SQLEvaluator(object):
def __init__(self, expression, query, allow_joins=True, reuse=None):
self.expression = expression
self.opts = query.get_meta()
self.reuse = reuse
self.cols = []
self.expression.prepare(self, query, allow_joins)
def relabeled_clone(self, change_map):
clone = copy.copy(self)
clone.cols = []
for node, col in self.cols:
if hasattr(col, 'relabeled_clone'):
clone.cols.append((node, col.relabeled_clone(change_map)))
else:
clone.cols.append((node,
(change_map.get(col[0], col[0]), col[1])))
return clone
def get_group_by_cols(self):
cols = []
for node, col in self.cols:
if hasattr(node, 'get_group_by_cols'):
cols.extend(node.get_group_by_cols())
elif isinstance(col, tuple):
cols.append(col)
return cols
def prepare(self):
return self
def as_sql(self, qn, connection):
return self.expression.evaluate(self, qn, connection)
#####################################################
# Visitor methods for initial expression preparation #
#####################################################
def prepare_node(self, node, query, allow_joins):
for child in node.children:
if hasattr(child, 'prepare'):
child.prepare(self, query, allow_joins)
def prepare_leaf(self, node, query, allow_joins):
if not allow_joins and LOOKUP_SEP in node.name:
raise FieldError("Joined field references are not permitted in this query")
field_list = node.name.split(LOOKUP_SEP)
if node.name in query.aggregates:
self.cols.append((node, query.aggregate_select[node.name]))
else:
try:
_, sources, _, join_list, path = query.setup_joins(
field_list, query.get_meta(), query.get_initial_alias(),
can_reuse=self.reuse)
self._used_joins = join_list
targets, _, join_list = query.trim_joins(sources, join_list, path)
if self.reuse is not None:
self.reuse.update(join_list)
for t in targets:
self.cols.append((node, (join_list[-1], t.column)))
except FieldDoesNotExist:
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (self.name,
[f.name for f in self.opts.fields]))
##################################################
# Visitor methods for final expression evaluation #
##################################################
def evaluate_node(self, node, qn, connection):
expressions = []
expression_params = []
for child in node.children:
if hasattr(child, 'evaluate'):
sql, params = child.evaluate(self, qn, connection)
else:
sql, params = '%s', (child,)
if len(getattr(child, 'children', [])) > 1:
format = '(%s)'
else:
format = '%s'
if sql:
expressions.append(format % sql)
expression_params.extend(params)
return connection.ops.combine_expression(node.connector, expressions), expression_params
def evaluate_leaf(self, node, qn, connection):
col = None
for n, c in self.cols:
if n is node:
col = c
break
if col is None:
raise ValueError("Given node not found")
if hasattr(col, 'as_sql'):
return col.as_sql(qn, connection)
else:
return '%s.%s' % (qn(col[0]), qn(col[1])), []
def evaluate_date_modifier_node(self, node, qn, connection):
timedelta = node.children.pop()
sql, params = self.evaluate_node(node, qn, connection)
node.children.append(timedelta)
if (timedelta.days == timedelta.seconds == timedelta.microseconds == 0):
return sql, params
return connection.ops.date_interval_sql(sql, node.connector, timedelta), params

View File

@ -14,20 +14,18 @@ import warnings
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connections, DEFAULT_DB_ALIAS from django.db import connections, DEFAULT_DB_ALIAS
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.aggregates import refs_aggregate from django.db.models.expressions import Col, Ref
from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
from django.db.models.query_utils import Q from django.db.models.query_utils import Q, refs_aggregate
from django.db.models.related import PathInfo from django.db.models.related import PathInfo
from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.aggregates import Count
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
ORDER_PATTERN, JoinInfo, SelectInfo) ORDER_PATTERN, JoinInfo, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin, Col from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
ExtraWhere, AND, OR, EmptyWhere) ExtraWhere, AND, OR, EmptyWhere)
from django.utils import six from django.utils import six
from django.utils.deprecation import RemovedInDjango19Warning from django.utils.deprecation import RemovedInDjango19Warning, RemovedInDjango20Warning
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.tree import Node from django.utils.tree import Node
@ -49,7 +47,7 @@ class RawQuery(object):
# the compiler can be used to process results. # the compiler can be used to process results.
self.low_mark, self.high_mark = 0, None # Used for offset/limit self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.extra_select = {} self.extra_select = {}
self.aggregate_select = {} self.annotation_select = {}
def clone(self, using): def clone(self, using):
return RawQuery(self.sql, using, params=self.params) return RawQuery(self.sql, using, params=self.params)
@ -97,7 +95,6 @@ class Query(object):
alias_prefix = 'T' alias_prefix = 'T'
subq_aliases = frozenset([alias_prefix]) subq_aliases = frozenset([alias_prefix])
query_terms = QUERY_TERMS query_terms = QUERY_TERMS
aggregates_module = base_aggregates_module
compiler = 'SQLCompiler' compiler = 'SQLCompiler'
@ -140,13 +137,13 @@ class Query(object):
self.select_for_update_nowait = False self.select_for_update_nowait = False
self.select_related = False self.select_related = False
# SQL aggregate-related attributes # SQL annotation-related attributes
# The _aggregates will be an OrderedDict when used. Due to the cost # The _annotations will be an OrderedDict when used. Due to the cost
# of creating OrderedDict this attribute is created lazily (in # of creating OrderedDict this attribute is created lazily (in
# self.aggregates property). # self.annotations property).
self._aggregates = None # Maps alias -> SQL aggregate function self._annotations = None # Maps alias -> Annotation Expression
self.aggregate_select_mask = None self.annotation_select_mask = None
self._aggregate_select_cache = None self._annotation_select_cache = None
# Arbitrary maximum limit for select_related. Prevents infinite # Arbitrary maximum limit for select_related. Prevents infinite
# recursion. Can be changed by the depth parameter to select_related(). # recursion. Can be changed by the depth parameter to select_related().
@ -155,7 +152,7 @@ class Query(object):
# These are for extensions. The contents are more or less appended # These are for extensions. The contents are more or less appended
# verbatim to the appropriate clause. # verbatim to the appropriate clause.
# The _extra attribute is an OrderedDict, lazily created similarly to # The _extra attribute is an OrderedDict, lazily created similarly to
# .aggregates # .annotations
self._extra = None # Maps col_alias -> (col_sql, params). self._extra = None # Maps col_alias -> (col_sql, params).
self.extra_select_mask = None self.extra_select_mask = None
self._extra_select_cache = None self._extra_select_cache = None
@ -174,11 +171,18 @@ class Query(object):
self._extra = OrderedDict() self._extra = OrderedDict()
return self._extra return self._extra
@property
def annotations(self):
if self._annotations is None:
self._annotations = OrderedDict()
return self._annotations
@property @property
def aggregates(self): def aggregates(self):
if self._aggregates is None: warnings.warn(
self._aggregates = OrderedDict() "The aggregates property is deprecated. Use annotations instead.",
return self._aggregates RemovedInDjango20Warning, stacklevel=2)
return self.annotations
def __str__(self): def __str__(self):
""" """
@ -203,7 +207,7 @@ class Query(object):
memo[id(self)] = result memo[id(self)] = result
return result return result
def prepare(self): def _prepare(self):
return self return self
def get_compiler(self, using=None, connection=None): def get_compiler(self, using=None, connection=None):
@ -213,8 +217,8 @@ class Query(object):
connection = connections[using] connection = connections[using]
# Check that the compiler will be able to execute the query # Check that the compiler will be able to execute the query
for alias, aggregate in self.aggregate_select.items(): for alias, annotation in self.annotation_select.items():
connection.ops.check_aggregate_support(aggregate) connection.ops.check_aggregate_support(annotation)
return connection.ops.compiler(self.compiler)(self, connection, using) return connection.ops.compiler(self.compiler)(self, connection, using)
@ -260,17 +264,17 @@ class Query(object):
obj.select_for_update_nowait = self.select_for_update_nowait obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_related = self.select_related obj.select_related = self.select_related
obj.related_select_cols = [] obj.related_select_cols = []
obj._aggregates = self._aggregates.copy() if self._aggregates is not None else None obj._annotations = self._annotations.copy() if self._annotations is not None else None
if self.aggregate_select_mask is None: if self.annotation_select_mask is None:
obj.aggregate_select_mask = None obj.annotation_select_mask = None
else: else:
obj.aggregate_select_mask = self.aggregate_select_mask.copy() obj.annotation_select_mask = self.annotation_select_mask.copy()
# _aggregate_select_cache cannot be copied, as doing so breaks the # _annotation_select_cache cannot be copied, as doing so breaks the
# (necessary) state in which both aggregates and # (necessary) state in which both annotations and
# _aggregate_select_cache point to the same underlying objects. # _annotation_select_cache point to the same underlying objects.
# It will get re-populated in the cloned queryset the next time it's # It will get re-populated in the cloned queryset the next time it's
# used. # used.
obj._aggregate_select_cache = None obj._annotation_select_cache = None
obj.max_depth = self.max_depth obj.max_depth = self.max_depth
obj._extra = self._extra.copy() if self._extra is not None else None obj._extra = self._extra.copy() if self._extra is not None else None
if self.extra_select_mask is None: if self.extra_select_mask is None:
@ -299,94 +303,84 @@ class Query(object):
obj._setup_query() obj._setup_query()
return obj return obj
def resolve_aggregate(self, value, aggregate, connection):
"""Resolve the value of aggregates returned by the database to
consistent (and reasonable) types.
This is required because of the predisposition of certain backends
to return Decimal and long types when they are not needed.
"""
if value is None:
if aggregate.is_ordinal:
return 0
# Return None as-is
return value
elif aggregate.is_ordinal:
# Any ordinal aggregate (e.g., count) returns an int
return int(value)
elif aggregate.is_computed:
# Any computed aggregate (e.g., avg) returns a float
return float(value)
else:
# Return value depends on the type of the field being processed.
backend_converters = connection.ops.get_db_converters(aggregate.field.get_internal_type())
field_converters = aggregate.field.get_db_converters(connection)
for converter in backend_converters:
value = converter(value, aggregate.field)
for converter in field_converters:
value = converter(value, connection)
return value
def get_aggregation(self, using, force_subq=False): def get_aggregation(self, using, force_subq=False):
""" """
Returns the dictionary with the values of the existing aggregations. Returns the dictionary with the values of the existing aggregations.
""" """
if not self.aggregate_select: if not self.annotation_select:
return {} return {}
# annotations must be forced into subquery
has_annotation = any(
annotation for alias, annotation
in self.annotation_select.items()
if not annotation.contains_aggregate)
# If there is a group by clause, aggregating does not add useful # If there is a group by clause, aggregating does not add useful
# information but retrieves only the first row. Aggregate # information but retrieves only the first row. Aggregate
# over the subquery instead. # over the subquery instead.
if self.group_by is not None or force_subq: if self.group_by is not None or force_subq or has_annotation:
from django.db.models.sql.subqueries import AggregateQuery from django.db.models.sql.subqueries import AggregateQuery
query = AggregateQuery(self.model) outer_query = AggregateQuery(self.model)
obj = self.clone() inner_query = self.clone()
if not force_subq: if not force_subq:
# In forced subq case the ordering and limits will likely # In forced subq case the ordering and limits will likely
# affect the results. # affect the results.
obj.clear_ordering(True) inner_query.clear_ordering(True)
obj.clear_limits() inner_query.clear_limits()
obj.select_for_update = False inner_query.select_for_update = False
obj.select_related = False inner_query.select_related = False
obj.related_select_cols = [] inner_query.related_select_cols = []
relabels = dict((t, 'subquery') for t in self.tables) relabels = dict((t, 'subquery') for t in inner_query.tables)
relabels[None] = 'subquery'
# Remove any aggregates marked for reduction from the subquery # Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery. # and move them to the outer AggregateQuery.
for alias, aggregate in self.aggregate_select.items(): for alias, annotation in inner_query.annotation_select.items():
if aggregate.is_summary: if annotation.is_summary:
query.aggregates[alias] = aggregate.relabeled_clone(relabels) # The annotation is already referring the subquery alias, so we
del obj.aggregate_select[alias] # just need to move the annotation to the outer query.
outer_query.annotations[alias] = annotation.relabeled_clone(relabels)
del inner_query.annotation_select[alias]
try: try:
query.add_subquery(obj, using) outer_query.add_subquery(inner_query, using)
except EmptyResultSet: except EmptyResultSet:
return dict( return dict(
(alias, None) (alias, None)
for alias in query.aggregate_select for alias in outer_query.annotation_select
) )
else: else:
query = self outer_query = self
self.select = [] self.select = []
self.default_cols = False self.default_cols = False
self._extra = {} self._extra = {}
self.remove_inherited_models() self.remove_inherited_models()
query.clear_ordering(True) outer_query.clear_ordering(True)
query.clear_limits() outer_query.clear_limits()
query.select_for_update = False outer_query.select_for_update = False
query.select_related = False outer_query.select_related = False
query.related_select_cols = [] outer_query.related_select_cols = []
compiler = outer_query.get_compiler(using)
result = query.get_compiler(using).execute_sql(SINGLE) result = compiler.execute_sql(SINGLE)
if result is None: if result is None:
result = [None for q in query.aggregate_select.items()] result = [None for q in outer_query.annotation_select.items()]
fields = [annotation.output_field
for alias, annotation in outer_query.annotation_select.items()]
converters = compiler.get_converters(fields)
for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()):
if position in converters:
converters[position][1].insert(0, annotation.convert_value)
else:
converters[position] = ([], [annotation.convert_value], annotation.output_field)
result = compiler.apply_converters(result, converters)
return dict( return dict(
(alias, self.resolve_aggregate(val, aggregate, connection=connections[using])) (alias, val)
for (alias, aggregate), val for (alias, annotation), val
in zip(query.aggregate_select.items(), result) in zip(outer_query.annotation_select.items(), result)
) )
def get_count(self, using): def get_count(self, using):
@ -394,7 +388,7 @@ class Query(object):
Performs a COUNT() query using the current filter constraints. Performs a COUNT() query using the current filter constraints.
""" """
obj = self.clone() obj = self.clone()
if len(self.select) > 1 or self.aggregate_select or (self.distinct and self.distinct_fields): if len(self.select) > 1 or self.annotation_select or (self.distinct and self.distinct_fields):
# If a select clause exists, then the query has already started to # If a select clause exists, then the query has already started to
# specify the columns that are to be returned. # specify the columns that are to be returned.
# In this case, we need to use a subquery to evaluate the count. # In this case, we need to use a subquery to evaluate the count.
@ -769,9 +763,9 @@ class Query(object):
self.group_by = [relabel_column(col) for col in self.group_by] self.group_by = [relabel_column(col) for col in self.group_by]
self.select = [SelectInfo(relabel_column(s.col), s.field) self.select = [SelectInfo(relabel_column(s.col), s.field)
for s in self.select] for s in self.select]
if self._aggregates: if self._annotations:
self._aggregates = OrderedDict( self._annotations = OrderedDict(
(key, relabel_column(col)) for key, col in self._aggregates.items()) (key, relabel_column(col)) for key, col in self._annotations.items())
# 2. Rename the alias in the internal table/alias datastructures. # 2. Rename the alias in the internal table/alias datastructures.
for ident, aliases in self.join_map.items(): for ident, aliases in self.join_map.items():
@ -974,52 +968,18 @@ class Query(object):
self.included_inherited_models = {} self.included_inherited_models = {}
def add_aggregate(self, aggregate, model, alias, is_summary): def add_aggregate(self, aggregate, model, alias, is_summary):
warnings.warn(
"add_aggregate() is deprecated. Use add_annotation() instead.",
RemovedInDjango20Warning, stacklevel=2)
self.add_annotation(aggregate, model, alias, is_summary)
def add_annotation(self, annotation, model, alias, is_summary):
""" """
Adds a single aggregate expression to the Query Adds a single annotation expression to the Query
""" """
opts = model._meta annotation = annotation.resolve_expression(self, summarize=is_summary)
field_list = aggregate.lookup.split(LOOKUP_SEP) self.append_annotation_mask([alias])
if len(field_list) == 1 and self._aggregates and aggregate.lookup in self.aggregates: self.annotations[alias] = annotation
# Aggregate is over an annotation
field_name = field_list[0]
col = field_name
source = self.aggregates[field_name]
if not is_summary:
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
aggregate.name, field_name, field_name))
elif ((len(field_list) > 1) or
(field_list[0] not in [i.name for i in opts.fields]) or
self.group_by is None or
not is_summary):
# If:
# - the field descriptor has more than one part (foo__bar), or
# - the field descriptor is referencing an m2m/m2o field, or
# - this is a reference to a model field (possibly inherited), or
# - this is an annotation over a model field
# then we need to explore the joins that are required.
# Join promotion note - we must not remove any rows here, so use
# outer join if there isn't any existing join.
_, sources, opts, join_list, path = self.setup_joins(
field_list, opts, self.get_initial_alias())
# Process the join chain to see if it can be trimmed
targets, _, join_list = self.trim_joins(sources, join_list, path)
col = targets[0].column
source = sources[0]
col = (join_list[-1], col)
else:
# The simplest cases. No joins required -
# just reference the provided column alias.
field_name = field_list[0]
source = opts.get_field(field_name)
col = field_name
# We want to have the alias in SELECT clause even if mask is set.
self.append_aggregate_mask([alias])
# Add the aggregate to the query
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
def prepare_lookup_value(self, value, lookups, can_reuse): def prepare_lookup_value(self, value, lookups, can_reuse):
# Default lookup if none given is exact. # Default lookup if none given is exact.
@ -1037,9 +997,8 @@ class Query(object):
"Passing callable arguments to queryset is deprecated.", "Passing callable arguments to queryset is deprecated.",
RemovedInDjango19Warning, stacklevel=2) RemovedInDjango19Warning, stacklevel=2)
value = value() value = value()
elif isinstance(value, ExpressionNode): elif hasattr(value, 'resolve_expression'):
# If value is a query expression, evaluate it value = value.resolve_expression(self, reuse=can_reuse)
value = SQLEvaluator(value, self, reuse=can_reuse)
if hasattr(value, 'query') and hasattr(value.query, 'bump_prefix'): if hasattr(value, 'query') and hasattr(value.query, 'bump_prefix'):
value = value._clone() value = value._clone()
value.query.bump_prefix(self) value.query.bump_prefix(self)
@ -1061,8 +1020,8 @@ class Query(object):
Solve the lookup type from the lookup (eg: 'foobar__id__icontains') Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
""" """
lookup_splitted = lookup.split(LOOKUP_SEP) lookup_splitted = lookup.split(LOOKUP_SEP)
if self._aggregates: if self._annotations:
aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.aggregates) aggregate, aggregate_lookups = refs_aggregate(lookup_splitted, self.annotations)
if aggregate: if aggregate:
return aggregate_lookups, (), aggregate return aggregate_lookups, (), aggregate
_, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta()) _, field, _, lookup_parts = self.names_to_path(lookup_splitted, self.get_meta())
@ -1232,6 +1191,10 @@ class Query(object):
lookup_type = lookups[-1] lookup_type = lookups[-1]
else: else:
assert(len(targets) == 1) assert(len(targets) == 1)
if hasattr(targets[0], 'as_sql'):
# handle Expressions as annotations
col = targets[0]
else:
col = Col(alias, targets[0], field) col = Col(alias, targets[0], field)
condition = self.build_lookup(lookups, col, value) condition = self.build_lookup(lookups, col, value)
if not condition: if not condition:
@ -1278,12 +1241,12 @@ class Query(object):
Returns whether or not all elements of this q_object need to be put Returns whether or not all elements of this q_object need to be put
together in the HAVING clause. together in the HAVING clause.
""" """
if not self._aggregates: if not self._annotations:
return False return False
if not isinstance(obj, Node): if not isinstance(obj, Node):
return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.aggregates)[0] return (refs_aggregate(obj[0].split(LOOKUP_SEP), self.annotations)[0]
or (hasattr(obj[1], 'contains_aggregate') or (hasattr(obj[1], 'refs_aggregate')
and obj[1].contains_aggregate(self.aggregates))) and obj[1].refs_aggregate(self.annotations)[0]))
return any(self.need_having(c) for c in obj.children) return any(self.need_having(c) for c in obj.children)
def split_having_parts(self, q_object, negated=False): def split_having_parts(self, q_object, negated=False):
@ -1390,13 +1353,21 @@ class Query(object):
if name == 'pk': if name == 'pk':
name = opts.pk.name name = opts.pk.name
try: try:
field, model, direct, m2m = opts.get_field_by_name(name) field, model, _, _ = opts.get_field_by_name(name)
except FieldDoesNotExist: except FieldDoesNotExist:
# is it an annotation?
if self._annotations and name in self._annotations:
field, model = self._annotations[name], None
if not field.contains_aggregate:
# Local non-relational field.
final_field = field
targets = (field,)
break
# We didn't find the current field, so move position back # We didn't find the current field, so move position back
# one step. # one step.
pos -= 1 pos -= 1
if pos == -1 or fail_on_missing: if pos == -1 or fail_on_missing:
available = opts.get_all_field_names() + list(self.aggregate_select) available = opts.get_all_field_names() + list(self.annotation_select)
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(available))) "Choices are: %s" % (name, ", ".join(available)))
break break
@ -1445,6 +1416,11 @@ class Query(object):
break break
return path, final_field, targets, names[pos + 1:] return path, final_field, targets, names[pos + 1:]
def raise_field_error(self, opts, name):
available = opts.get_all_field_names() + list(self.annotation_select)
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(available)))
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True): def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True):
""" """
Compute the necessary table joins for the passage through the fields Compute the necessary table joins for the passage through the fields
@ -1519,6 +1495,29 @@ class Query(object):
self.unref_alias(joins.pop()) self.unref_alias(joins.pop())
return targets, joins[-1], joins return targets, joins[-1], joins
def resolve_ref(self, name, allow_joins, reuse, summarize):
if not allow_joins and LOOKUP_SEP in name:
raise FieldError("Joined field references are not permitted in this query")
if name in self.annotations:
if summarize:
return Ref(name, self.annotation_select[name])
else:
return self.annotation_select[name]
else:
field_list = name.split(LOOKUP_SEP)
field, sources, opts, join_list, path = self.setup_joins(
field_list, self.get_meta(),
self.get_initial_alias(), reuse)
targets, _, join_list = self.trim_joins(sources, join_list, path)
if len(targets) > 1:
raise FieldError("Referencing multicolumn fields with F() objects "
"isn't supported")
if reuse is not None:
reuse.update(join_list)
col = Col(join_list[-1], targets[0], sources[0])
col._used_joins = join_list
return col
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
""" """
When doing an exclude against any kind of N-to-many relation, we need When doing an exclude against any kind of N-to-many relation, we need
@ -1633,7 +1632,7 @@ class Query(object):
self.default_cols = False self.default_cols = False
self.select_related = False self.select_related = False
self.set_extra_mask(()) self.set_extra_mask(())
self.set_aggregate_mask(()) self.set_annotation_mask(())
def clear_select_fields(self): def clear_select_fields(self):
""" """
@ -1676,7 +1675,7 @@ class Query(object):
raise raise
else: else:
names = sorted(opts.get_all_field_names() + list(self.extra) names = sorted(opts.get_all_field_names() + list(self.extra)
+ list(self.aggregate_select)) + list(self.annotation_select))
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names))) "Choices are: %s" % (name, ", ".join(names)))
self.remove_inherited_models() self.remove_inherited_models()
@ -1725,39 +1724,55 @@ class Query(object):
for col, _ in self.select: for col, _ in self.select:
self.group_by.append(col) self.group_by.append(col)
if self._annotations:
for alias, annotation in six.iteritems(self.annotations):
for col in annotation.get_group_by_cols():
self.group_by.append(col)
def add_count_column(self): def add_count_column(self):
""" """
Converts the query to do count(...) or count(distinct(pk)) in order to Converts the query to do count(...) or count(distinct(pk)) in order to
get its size. get its size.
""" """
summarize = False
if not self.distinct: if not self.distinct:
if not self.select: if not self.select:
count = self.aggregates_module.Count('*', is_summary=True) count = Count('*')
summarize = True
else: else:
assert len(self.select) == 1, \ assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select': %r" % self.select "Cannot add count col with multiple cols in 'select': %r" % self.select
count = self.aggregates_module.Count(self.select[0].col) col = self.select[0].col
if isinstance(col, (tuple, list)):
count = Count(col[1])
else:
count = Count(col)
else: else:
opts = self.get_meta() opts = self.get_meta()
if not self.select: if not self.select:
count = self.aggregates_module.Count( lookup = self.join((None, opts.db_table, None)), opts.pk.column
(self.join((None, opts.db_table, None)), opts.pk.column), count = Count(lookup[1], distinct=True)
is_summary=True, distinct=True) summarize = True
else: else:
# Because of SQL portability issues, multi-column, distinct # Because of SQL portability issues, multi-column, distinct
# counts need a sub-query -- see get_count() for details. # counts need a sub-query -- see get_count() for details.
assert len(self.select) == 1, \ assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select'." "Cannot add count col with multiple cols in 'select'."
col = self.select[0].col
count = self.aggregates_module.Count(self.select[0].col, distinct=True) if isinstance(col, (tuple, list)):
count = Count(col[1], distinct=True)
else:
count = Count(col, distinct=True)
# Distinct handling is done in Count(), so don't do it at this # Distinct handling is done in Count(), so don't do it at this
# level. # level.
self.distinct = False self.distinct = False
# Set only aggregate to be the count column. # Set only aggregate to be the count column.
# Clear out the select cache to reflect the new unmasked aggregates. # Clear out the select cache to reflect the new unmasked annotations.
self._aggregates = {None: count} count = count.resolve_expression(self, summarize=summarize)
self.set_aggregate_mask(None) self._annotations = {None: count}
self.set_annotation_mask(None)
self.group_by = None self.group_by = None
def add_select_related(self, fields): def add_select_related(self, fields):
@ -1886,16 +1901,28 @@ class Query(object):
target[model] = set(f.name for f in fields) target[model] = set(f.name for f in fields)
def set_aggregate_mask(self, names): def set_aggregate_mask(self, names):
"Set the mask of aggregates that will actually be returned by the SELECT" warnings.warn(
"set_aggregate_mask() is deprecated. Use set_annotation_mask() instead.",
RemovedInDjango20Warning, stacklevel=2)
self.set_annotation_mask(names)
def set_annotation_mask(self, names):
"Set the mask of annotations that will actually be returned by the SELECT"
if names is None: if names is None:
self.aggregate_select_mask = None self.annotation_select_mask = None
else: else:
self.aggregate_select_mask = set(names) self.annotation_select_mask = set(names)
self._aggregate_select_cache = None self._annotation_select_cache = None
def append_aggregate_mask(self, names): def append_aggregate_mask(self, names):
if self.aggregate_select_mask is not None: warnings.warn(
self.set_aggregate_mask(set(names).union(self.aggregate_select_mask)) "append_aggregate_mask() is deprecated. Use append_annotation_mask() instead.",
RemovedInDjango20Warning, stacklevel=2)
self.append_annotation_mask(names)
def append_annotation_mask(self, names):
if self.annotation_select_mask is not None:
self.set_annotation_mask(set(names).union(self.annotation_select_mask))
def set_extra_mask(self, names): def set_extra_mask(self, names):
""" """
@ -1910,24 +1937,31 @@ class Query(object):
self._extra_select_cache = None self._extra_select_cache = None
@property @property
def aggregate_select(self): def annotation_select(self):
"""The OrderedDict of aggregate columns that are not masked, and should """The OrderedDict of aggregate columns that are not masked, and should
be used in the SELECT clause. be used in the SELECT clause.
This result is cached for optimization purposes. This result is cached for optimization purposes.
""" """
if self._aggregate_select_cache is not None: if self._annotation_select_cache is not None:
return self._aggregate_select_cache return self._annotation_select_cache
elif not self._aggregates: elif not self._annotations:
return {} return {}
elif self.aggregate_select_mask is not None: elif self.annotation_select_mask is not None:
self._aggregate_select_cache = OrderedDict( self._annotation_select_cache = OrderedDict(
(k, v) for k, v in self.aggregates.items() (k, v) for k, v in self.annotations.items()
if k in self.aggregate_select_mask if k in self.annotation_select_mask
) )
return self._aggregate_select_cache return self._annotation_select_cache
else: else:
return self.aggregates return self.annotations
@property
def aggregate_select(self):
warnings.warn(
"aggregate_select() is deprecated. Use annotation_select() instead.",
RemovedInDjango20Warning, stacklevel=2)
return self.annotation_select
@property @property
def extra_select(self): def extra_select(self):

View File

@ -7,9 +7,9 @@ from django.core.exceptions import FieldError
from django.db import connections from django.db import connections
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Date, DateTime, Col
from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
from django.db.models.sql.datastructures import Date, DateTime
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.utils import six from django.utils import six
from django.utils import timezone from django.utils import timezone
@ -229,7 +229,7 @@ class DateQuery(Query):
)) ))
self._check_field(field) # overridden in DateTimeQuery self._check_field(field) # overridden in DateTimeQuery
alias = joins[-1] alias = joins[-1]
select = self._get_select((alias, field.column), lookup_type) select = self._get_select(Col(alias, field), lookup_type)
self.clear_select_clause() self.clear_select_clause()
self.select = [SelectInfo(select, None)] self.select = [SelectInfo(select, None)]
self.distinct = True self.distinct = True

View File

@ -10,7 +10,6 @@ import warnings
from django.conf import settings from django.conf import settings
from django.db.models.fields import DateTimeField, Field from django.db.models.fields import DateTimeField, Field
from django.db.models.sql.datastructures import EmptyResultSet, Empty from django.db.models.sql.datastructures import EmptyResultSet, Empty
from django.db.models.sql.aggregates import Aggregate
from django.utils.deprecation import RemovedInDjango19Warning from django.utils.deprecation import RemovedInDjango19Warning
from django.utils.six.moves import xrange from django.utils.six.moves import xrange
from django.utils import timezone from django.utils import timezone
@ -78,7 +77,7 @@ class WhereNode(tree.Node):
else: else:
value_annotation = bool(value) value_annotation = bool(value)
if hasattr(obj, "prepare"): if hasattr(obj, 'prepare'):
value = obj.prepare(lookup_type, value) value = obj.prepare(lookup_type, value)
return (obj, lookup_type, value_annotation, value) return (obj, lookup_type, value_annotation, value)
@ -187,11 +186,9 @@ class WhereNode(tree.Node):
lvalue, params = lvalue.process(lookup_type, params_or_value, connection) lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
except EmptyShortCircuit: except EmptyShortCircuit:
raise EmptyResultSet raise EmptyResultSet
elif isinstance(lvalue, Aggregate):
params = lvalue.field.get_db_prep_lookup(lookup_type, params_or_value, connection)
else: else:
raise TypeError("'make_atom' expects a Constraint or an Aggregate " raise TypeError("'make_atom' expects a Constraint as the first "
"as the first item of its 'child' argument.") "item of its 'child' argument.")
if isinstance(lvalue, tuple): if isinstance(lvalue, tuple):
# A direct database column lookup. # A direct database column lookup.

View File

@ -86,7 +86,8 @@ manipulating the data of your Web application. Learn more about it below:
:doc:`Aggregation <topics/db/aggregation>` | :doc:`Aggregation <topics/db/aggregation>` |
:doc:`Custom fields <howto/custom-model-fields>` | :doc:`Custom fields <howto/custom-model-fields>` |
:doc:`Multiple databases <topics/db/multi-db>` | :doc:`Multiple databases <topics/db/multi-db>` |
:doc:`Custom lookups <howto/custom-lookups>` :doc:`Custom lookups <howto/custom-lookups>` |
:doc:`Query Expressions <ref/models/expressions>`
* **Other:** * **Other:**
:doc:`Supported databases <ref/databases>` | :doc:`Supported databases <ref/databases>` |

View File

@ -41,6 +41,17 @@ details on these changes.
:class:`~django.core.management.BaseCommand` instead, which takes no arguments :class:`~django.core.management.BaseCommand` instead, which takes no arguments
by default. by default.
* ``django.db.models.sql.aggregates`` module will be removed.
* ``django.contrib.gis.db.models.sql.aggregates`` module will be removed.
* The following methods and properties of ``django.db.sql.query.Query`` will
be removed:
* Properties: ``aggregates`` and ``aggregate_select``
* Methods: ``add_aggregate``, ``set_aggregate_mask``, and
``append_aggregate_mask``.
* ``django.template.resolve_variable`` will be removed. * ``django.template.resolve_variable`` will be removed.
* The ``error_message`` argument of ``django.forms.RegexField`` will be removed. * The ``error_message`` argument of ``django.forms.RegexField`` will be removed.

View File

@ -0,0 +1,522 @@
=================
Query Expressions
=================
.. currentmodule:: django.db.models
Query expressions describe a value or a computation that can be used as part of
a filter, an annotation, or an aggregation. There are a number of built-in
expressions (documented below) that can be used to help you write queries.
Expressions can be combined, or in some cases nested, to form more complex
computations.
Supported arithmetic
====================
Django supports addition, subtraction, multiplication, division, modulo
arithmetic, and the power operator on query expressions, using Python constants,
variables, and even other expressions.
.. versionadded:: 1.7
Support for the power operator ``**`` was added.
Some examples
=============
.. versionchanged:: 1.8
Some of the examples rely on functionality that is new in Django 1.8.
.. code-block:: python
# Find companies that have more employees than chairs.
Company.objects.filter(num_employees__gt=F('num_chairs'))
# Find companies that have at least twice as many employees
# as chairs. Both the querysets below are equivalent.
Company.objects.filter(num_employees__gt=F('num_chairs') * 2)
Company.objects.filter(
num_employees__gt=F('num_chairs') + F('num_chairs'))
# How many chairs are needed for each company to seat all employees?
>>> company = Company.objects.filter(
... num_employees__gt=F('num_chairs')).annotate(
... chairs_needed=F('num_employees') - F('num_chairs')).first()
>>> company.num_employees
120
>>> company.num_chairs
50
>>> company.chairs_needed
70
# Annotate models with an aggregated value. Both forms
# below are equivalent.
Company.objects.annotate(num_products=Count('products'))
Company.objects.annotate(num_products=Count(F('products')))
# Aggregates can contain complex computations also
Company.objects.annotate(num_offerings=Count(F('products') + F('services')))
Built-in Expressions
====================
``F()`` expressions
-------------------
.. class:: F
An ``F()`` object represents the value of a model field or annotated column. It
makes it possible to refer to model field values and perform database
operations using them without actually having to pull them out of the database
into Python memory.
Instead, Django uses the ``F()`` object to generate a SQL expression that
describes the required operation at the database level.
This is easiest to understand through an example. Normally, one might do
something like this::
# Tintin filed a news story!
reporter = Reporters.objects.get(name='Tintin')
reporter.stories_filed += 1
reporter.save()
Here, we have pulled the value of ``reporter.stories_filed`` from the database
into memory and manipulated it using familiar Python operators, and then saved
the object back to the database. But instead we could also have done::
from django.db.models import F
reporter = Reporters.objects.get(name='Tintin')
reporter.stories_filed = F('stories_filed') + 1
reporter.save()
Although ``reporter.stories_filed = F('stories_filed') + 1`` looks like a
normal Python assignment of value to an instance attribute, in fact it's an SQL
construct describing an operation on the database.
When Django encounters an instance of ``F()``, it overrides the standard Python
operators to create an encapsulated SQL expression; in this case, one which
instructs the database to increment the database field represented by
``reporter.stories_filed``.
Whatever value is or was on ``reporter.stories_filed``, Python never gets to
know about it - it is dealt with entirely by the database. All Python does,
through Django's ``F()`` class, is create the SQL syntax to refer to the field
and describe the operation.
.. note::
In order to access the new value that has been saved in this way, the object
will need to be reloaded::
reporter = Reporters.objects.get(pk=reporter.pk)
As well as being used in operations on single instances as above, ``F()`` can
be used on ``QuerySets`` of object instances, with ``update()``. This reduces
the two queries we were using above - the ``get()`` and the
:meth:`~Model.save()` - to just one::
reporter = Reporters.objects.filter(name='Tintin')
reporter.update(stories_filed=F('stories_filed') + 1)
We can also use :meth:`~django.db.models.query.QuerySet.update()` to increment
the field value on multiple objects - which could be very much faster than
pulling them all into Python from the database, looping over them, incrementing
the field value of each one, and saving each one back to the database::
Reporter.objects.all().update(stories_filed=F('stories_filed) + 1)
``F()`` therefore can offer performance advantages by:
* getting the database, rather than Python, to do work
* reducing the number of queries some operations require
.. _avoiding-race-conditions-using-f:
Avoiding race conditions using ``F()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Another useful benefit of ``F()`` is that having the database - rather than
Python - update a field's value avoids a *race condition*.
If two Python threads execute the code in the first example above, one thread
could retrieve, increment, and save a field's value after the other has
retrieved it from the database. The value that the second thread saves will be
based on the original value; the work of the first thread will simply be lost.
If the database is responsible for updating the field, the process is more
robust: it will only ever update the field based on the value of the field in
the database when the :meth:`~Model.save()` or ``update()`` is executed, rather
than based on its value when the instance was retrieved.
Using ``F()`` in filters
~~~~~~~~~~~~~~~~~~~~~~~~
``F()`` is also very useful in ``QuerySet`` filters, where they make it
possible to filter a set of objects against criteria based on their field
values, rather than on Python values.
This is documented in :ref:`using F() expressions in queries
<using-f-expressions-in-filters>`.
.. _func-expressions:
``Func()`` expressions
----------------------
.. versionadded:: 1.8
``Func()`` expressions are the base type of all expressions that involve
database functions like ``COALESCE`` and ``LOWER``, or aggregates like ``SUM``.
They can be used directly::
queryset.annotate(field_lower=Func(F('field'), function='LOWER'))
or they can be used to build a library of database functions::
class Lower(Func):
function = 'LOWER'
queryset.annotate(field_lower=Lower(F('field')))
But both cases will result in a queryset where each model is annotated with an
extra attribute ``field_lower`` produced, roughly, from the following SQL::
SELECT
...
LOWER("app_label"."field") as "field_lower"
The ``Func`` API is as follows:
.. class:: Func(*expressions, **extra)
.. attribute:: function
A class attribute describing the function that will be generated.
Specifically, the ``function`` will be interpolated as the ``function``
placeholder within :attr:`template`. Defaults to ``None``.
.. attribute:: template
A class attribute, as a format string, that describes the SQL that is
generated for this function. Defaults to
``'%(function)s(%(expressions)s)'``.
.. attribute:: arg_joiner
A class attribute that denotes the character used to join the list of
``expressions`` together. Defaults to ``', '``.
The ``*expressions`` argument is a list of positional expressions that the
function will be applied to. The expressions will be converted to strings,
joined together with ``arg_joiner``, and then interpolated into the ``template``
as the ``expressions`` placeholder.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. Note that the keywords ``function`` and
``template`` can be used to replace the ``function`` and ``template``
attributes respectively, without having to define your own class.
``output_field`` can be used to define the expected return type.
``Aggregate()`` expressions
---------------------------
An aggregate expression is a special case of a :ref:`Func() expression
<func-expressions>` that informs the query that a ``GROUP BY`` clause
is required. All of the :ref:`aggregate functions <aggregation-functions>`,
like ``Sum()`` and ``Count()``, inherit from ``Aggregate()``.
Since ``Aggregate``\s are expressions and wrap expressions, you can represent
some complex computations::
Company.objects.annotate(
managers_required=(Count('num_employees') / 4) + Count('num_managers'))
The ``Aggregate`` API is as follows:
.. class:: Aggregate(expression, output_field=None, **extra)
.. attribute:: template
A class attribute, as a format string, that describes the SQL that is
generated for this aggregate. Defaults to
``'%(function)s( %(expressions)s )'``.
.. attribute:: function
A class attribute describing the aggregate function that will be
generated. Specifically, the ``function`` will be interpolated as the
``function`` placeholder within :attr:`template`. Defaults to ``None``.
The ``expression`` argument can be the name of a field on the model, or another
expression. It will be converted to a string and used as the ``expressions``
placeholder within the ``template``.
The ``output_field`` argument requires a model field instance, like
``IntegerField()`` or ``BooleanField()``, into which Django will load the value
after it's retrieved from the database.
Note that ``output_field`` is only required when Django is unable to determine
what field type the result should be. Complex expressions that mix field types
should define the desired ``output_field``. For example, adding an
``IntegerField()`` and a ``FloatField()`` together should probably have
``output_field=FloatField()`` defined.
.. versionchanged:: 1.8
``output_field`` is a new parameter.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute.
.. versionadded:: 1.8
Aggregate functions can now use arithmetic and reference multiple
model fields in a single function.
Creating your own Aggregate Functions
-------------------------------------
Creating your own aggregate is extremely easy. At a minimum, you need
to define ``function``, but you can also completely customize the
SQL that is generated. Here's a brief example::
class Count(Aggregate):
# supports COUNT(distinct field)
function = 'COUNT'
template = '%(function)s(%(distinct)s%(expressions)s)'
def __init__(self, expression, distinct=False, **extra):
super(Count, self).__init__(
expression,
distinct='DISTINCT ' if distinct else '',
output_field=IntegerField(),
**extra)
``Value()`` expressions
-----------------------
.. class:: Value(value, output_field=None)
A ``Value()`` object represents the smallest possible component of an
expression: a simple value. When you need to represent the value of an integer,
boolean, or string within an expression, you can wrap that value within a
``Value()``.
You will rarely need to use ``Value()`` directly. When you write the expression
``F('field') + 1``, Django implicitly wraps the ``1`` in a ``Value()``,
allowing simple values to be used in more complex expressions.
The ``value`` argument describes the value to be included in the expression,
such as ``1``, ``True``, or ``None``. Django knows how to convert these Python
values into their corresponding database type.
The ``output_field`` argument should be a model field instance, like
``IntegerField()`` or ``BooleanField()``, into which Django will load the value
after it's retrieved from the database.
Technical Information
=====================
Below you'll find technical implementation details that may be useful to
library authors. The technical API and examples below will help with
creating generic query expressions that can extend the built-in functionality
that Django provides.
Expression API
--------------
Query expressions implement the :ref:`query expression API <query-expression>`,
but also expose a number of extra methods and attributes listed below. All
query expressions must inherit from ``ExpressionNode()`` or a relevant
subclass.
When a query expression wraps another expression, it is responsible for
calling the appropriate methods on the wrapped expression.
.. class:: ExpressionNode
.. attribute:: contains_aggregate
Tells Django that this expression contains an aggregate and that a
``GROUP BY`` clause needs to be added to the query.
.. method:: resolve_expression(query=None, allow_joins=True, reuse=None, summarize=False)
Provides the chance to do any pre-processing or validation of
the expression before it's added to the query. ``resolve_expression()``
must also be called on any nested expressions. A ``copy()`` of ``self``
should be returned with any necessary transformations.
``query`` is the backend query implementation.
``allow_joins`` is a boolean that allows or denies the use of
joins in the query.
``reuse`` is a set of reusable joins for multi-join scenarios.
``summarize`` is a boolean that, when ``True``, signals that the
query being computed is a terminal aggregate query.
.. method:: get_source_expressions()
Returns an ordered list of inner expressions. For example::
>>> Sum(F('foo')).get_source_expressions()
[F('foo')]
.. method:: set_source_expressions(expressions)
Takes a list of expressions and stores them such that
``get_source_expressions()`` can return them.
.. method:: relabeled_clone(change_map)
Returns a clone (copy) of ``self``, with any column aliases relabeled.
Column aliases are renamed when subqueries are created.
``relabeled_clone()`` should also be called on any nested expressions
and assigned to the clone.
``change_map`` is a dictionary mapping old aliases to new aliases.
Example::
def relabeled_clone(self, change_map):
clone = copy.copy(self)
clone.expression = self.expression.relabeled_clone(change_map)
return clone
.. method:: convert_value(self, value, connection)
A hook allowing the expression to coerce ``value`` into a more
appropriate type.
.. method:: refs_aggregate(existing_aggregates)
Returns a tuple containing the ``(aggregate, lookup_path)`` of the
first aggregate that this expression (or any nested expression)
references, or ``(False, ())`` if no aggregate is referenced.
For example::
queryset.filter(num_chairs__gt=F('sum__employees'))
The ``F()`` expression here references a previous ``Sum()``
computation which means that this filter expression should be
added to the ``HAVING`` clause rather than the ``WHERE`` clause.
In the majority of cases, returning the result of ``refs_aggregate``
on any nested expression should be appropriate, as the necessary
built-in expressions will return the correct values.
.. method:: get_group_by_cols()
Responsible for returning the list of columns references by
this expression. ``get_group_by_cols()`` should be called on any
nested expressions. ``F()`` objects, in particular, hold a reference
to a column.
Writing your own Query Expressions
----------------------------------
You can write your own query expression classes that use, and can integrate
with, other query expressions. Let's step through an example by writing an
implementation of the ``COALESCE`` SQL function, without using the built-in
:ref:`Func() expressions <func-expressions>`.
The ``COALESCE`` SQL function is defined as taking a list of columns or
values. It will return the first column or value that isn't ``NULL``.
We'll start by defining the template to be used for SQL generation and
an ``__init__()`` method to set some attributes::
import copy
from django.db.models import ExpressionNode
class Coalesce(ExpressionNode):
template = 'COALESCE( %(expressions)s )'
def __init__(self, expressions, output_field, **extra):
super(Coalesce, self).__init__(output_field=output_field)
if len(expressions) < 2:
raise ValueError('expressions must have at least 2 elements')
for expression in expressions:
if not hasattr(expression, 'resolve_expression'):
raise TypeError('%r is not an Expression' % expression)
self.expressions = expressions
self.extra = extra
We do some basic validation on the parameters, including requiring at least
2 columns or values, and ensuring they are expressions. We are requiring
``output_field`` here so that Django knows what kind of model field to assign
the eventual result to.
Now we implement the pre-processing and validation. Since we do not have
any of our own validation at this point, we just delegate to the nested
expressions::
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False):
c = self.copy()
c.is_summary = summarize
for pos, expression in enumerate(self.expressions):
c.expressions[pos] = expression.resolve_expression(query, allow_joins, reuse, summarize)
return c
Next, we write the method responsible for generating the SQL::
def as_sql(self, compiler, connection):
sql_expressions, sql_params = [], []
for expression in self.expressions:
sql, params = compiler.compile(expression)
sql_expressions.append(sql)
sql_params.extend(params)
self.extra['expressions'] = ','.join(sql_expressions)
return self.template % self.extra, sql_params
def as_oracle(self, compiler, connection):
"""
Example of vendor specific handling (Oracle in this case).
Let's make the function name lowercase.
"""
self.template = 'coalesce( %(expressions)s )'
return self.as_sql(compiler, connection)
We generate the SQL for each of the ``expressions`` by using the
``compiler.compile()`` method, and join the result together with commas.
Then the template is filled out with our data and the SQL and parameters
are returned.
We've also defined a custom implementation that is specific to the Oracle
backend. The ``as_oracle()`` function will be called instead of ``as_sql()``
if the Oracle backend is in use.
Finally, we implement the rest of the methods that allow our query expression
to play nice with other query expressions::
def get_source_expressions(self):
return self.expressions
def set_source_expressions(expressions):
self.expressions = expressions
Let's see how it works::
>>> qs = Company.objects.annotate(
... tagline=Coalesce([
... F('motto'),
... F('ticker_name'),
... F('description'),
... Value('No Tagline')
... ], output_field=CharField()))
>>> for c in qs:
... print("%s: %s" % (c.name, c.tagline))
...
Google: Do No Evil
Apple: AAPL
Yahoo: Internet Company
Django Software Foundation: No Tagline

View File

@ -15,3 +15,4 @@ Model API reference. For introductory material, see :doc:`/topics/db/models`.
querysets querysets
queries queries
lookups lookups
expressions

View File

@ -7,115 +7,6 @@ Query-related classes
This document provides reference material for query-related tools not This document provides reference material for query-related tools not
documented elsewhere. documented elsewhere.
``F()`` expressions
===================
.. class:: F
An ``F()`` object represents the value of a model field. It makes it possible
to refer to model field values and perform database operations using them
without actually having to pull them out of the database into Python memory.
Instead, Django uses the ``F()`` object to generate a SQL expression that
describes the required operation at the database level.
This is easiest to understand through an example. Normally, one might do
something like this::
# Tintin filed a news story!
reporter = Reporters.objects.get(name='Tintin')
reporter.stories_filed += 1
reporter.save()
Here, we have pulled the value of ``reporter.stories_filed`` from the database
into memory and manipulated it using familiar Python operators, and then saved
the object back to the database. But instead we could also have done::
from django.db.models import F
reporter = Reporters.objects.get(name='Tintin')
reporter.stories_filed = F('stories_filed') + 1
reporter.save()
Although ``reporter.stories_filed = F('stories_filed') + 1`` looks like a
normal Python assignment of value to an instance attribute, in fact it's an SQL
construct describing an operation on the database.
When Django encounters an instance of ``F()``, it overrides the standard Python
operators to create an encapsulated SQL expression; in this case, one which
instructs the database to increment the database field represented by
``reporter.stories_filed``.
Whatever value is or was on ``reporter.stories_filed``, Python never gets to
know about it - it is dealt with entirely by the database. All Python does,
through Django's ``F()`` class, is create the SQL syntax to refer to the field
and describe the operation.
.. note::
In order to access the new value that has been saved in this way, the object
will need to be reloaded::
reporter = Reporters.objects.get(pk=reporter.pk)
As well as being used in operations on single instances as above, ``F()`` can
be used on ``QuerySets`` of object instances, with ``update()``. This reduces
the two queries we were using above - the ``get()`` and the
:meth:`~Model.save()` - to just one::
reporter = Reporters.objects.filter(name='Tintin')
reporter.update(stories_filed=F('stories_filed') + 1)
We can also use :meth:`~django.db.models.query.QuerySet.update()` to increment
the field value on multiple objects - which could be very much faster than
pulling them all into Python from the database, looping over them, incrementing
the field value of each one, and saving each one back to the database::
Reporter.objects.all().update(stories_filed=F('stories_filed') + 1)
``F()`` therefore can offer performance advantages by:
* getting the database, rather than Python, to do work
* reducing the number of queries some operations require
.. _avoiding-race-conditions-using-f:
Avoiding race conditions using ``F()``
--------------------------------------
Another useful benefit of ``F()`` is that having the database - rather than
Python - update a field's value avoids a *race condition*.
If two Python threads execute the code in the first example above, one thread
could retrieve, increment, and save a field's value after the other has
retrieved it from the database. The value that the second thread saves will be
based on the original value; the work of the first thread will simply be lost.
If the database is responsible for updating the field, the process is more
robust: it will only ever update the field based on the value of the field in
the database when the :meth:`~Model.save()` or ``update()`` is executed, rather
than based on its value when the instance was retrieved.
Using ``F()`` in filters
------------------------
``F()`` is also very useful in ``QuerySet`` filters, where they make it
possible to filter a set of objects against criteria based on their field
values, rather than on Python values.
This is documented in :ref:`using F() expressions in queries
<using-f-expressions-in-filters>`
Supported operations with ``F()``
---------------------------------
As well as addition, Django supports subtraction, multiplication, division,
and modulo arithmetic with ``F()`` objects, using Python constants,
variables, and even other ``F()`` objects.
.. versionadded:: 1.7
The power operator ``**`` is also supported.
``Q()`` objects ``Q()`` objects
=============== ===============

View File

@ -220,9 +220,18 @@ annotate
.. method:: annotate(*args, **kwargs) .. method:: annotate(*args, **kwargs)
Annotates each object in the ``QuerySet`` with the provided list of Annotates each object in the ``QuerySet`` with the provided list of :doc:`query
aggregate values (averages, sums, etc) that have been computed over expressions </ref/models/expressions>`. An expression may be a simple value, a
the objects that are related to the objects in the ``QuerySet``. reference to a field on the model (or any related models), or an aggregate
expression (averages, sums, etc) that has been computed over the objects that
are related to the objects in the ``QuerySet``.
.. versionadded:: 1.8
Previous versions of Django only allowed aggregate functions to be used as
annotations. It is now possible to annotate a model with all kinds of
expressions.
Each argument to ``annotate()`` is an annotation that will be added Each argument to ``annotate()`` is an annotation that will be added
to each object in the ``QuerySet`` that is returned. to each object in the ``QuerySet`` that is returned.
@ -232,7 +241,9 @@ in `Aggregation Functions`_ below.
Annotations specified using keyword arguments will use the keyword as Annotations specified using keyword arguments will use the keyword as
the alias for the annotation. Anonymous arguments will have an alias the alias for the annotation. Anonymous arguments will have an alias
generated for them based upon the name of the aggregate function and generated for them based upon the name of the aggregate function and
the model field that is being aggregated. the model field that is being aggregated. Only aggregate expressions
that reference a single field can be anonymous arguments. Everything
else must be a keyword argument.
For example, if you were manipulating a list of blogs, you may want For example, if you were manipulating a list of blogs, you may want
to determine how many entries have been made in each blog:: to determine how many entries have been made in each blog::
@ -1886,12 +1897,15 @@ the ``QuerySet``. Each argument to ``aggregate()`` specifies a value that will
be included in the dictionary that is returned. be included in the dictionary that is returned.
The aggregation functions that are provided by Django are described in The aggregation functions that are provided by Django are described in
`Aggregation Functions`_ below. `Aggregation Functions`_ below. Since aggregates are also :doc:`query
expressions </ref/models/expressions>`, you may combine aggregates with other
aggregates or values to create complex aggregates.
Aggregates specified using keyword arguments will use the keyword as the name Aggregates specified using keyword arguments will use the keyword as the name
for the annotation. Anonymous arguments will have a name generated for them for the annotation. Anonymous arguments will have a name generated for them
based upon the name of the aggregate function and the model field that is being based upon the name of the aggregate function and the model field that is being
aggregated. aggregated. Complex aggregates cannot use anonymous arguments and must specify
a keyword argument as an alias.
For example, when you are working with blog entries, you may want to know the For example, when you are working with blog entries, you may want to know the
number of authors that have contributed blog entries:: number of authors that have contributed blog entries::
@ -2667,8 +2681,9 @@ Aggregation functions
Django provides the following aggregation functions in the Django provides the following aggregation functions in the
``django.db.models`` module. For details on how to use these ``django.db.models`` module. For details on how to use these
aggregate functions, see aggregate functions, see :doc:`the topic guide on aggregation
:doc:`the topic guide on aggregation </topics/db/aggregation>`. </topics/db/aggregation>`. See the :class:`~django.db.models.Aggregate`
documentation to learn how to create your aggregates.
.. warning:: .. warning::
@ -2685,12 +2700,47 @@ aggregate functions, see
instead of ``0`` if the ``QuerySet`` contains no entries. An exception is instead of ``0`` if the ``QuerySet`` contains no entries. An exception is
``Count``, which does return ``0`` if the ``QuerySet`` is empty. ``Count``, which does return ``0`` if the ``QuerySet`` is empty.
All aggregates have the following parameters in common:
``expression``
~~~~~~~~~~~~~~
A string that references a field on the model, or a :doc:`query expression
</ref/models/expressions>`.
.. versionadded:: 1.8
Aggregate functions are now able to reference multiple fields in complex
computations.
``output_field``
~~~~~~~~~~~~~~~~
An optional argument that represents the :doc:`model field </ref/models/fields>`
of the return value
.. versionadded:: 1.8
The ``output_field`` argument was added.
.. note::
When combining multiple field types, Django can only determine the
``output_field`` if all fields are of the same type. Otherwise, you
must provide the ``output_field`` yourself.
``**extra``
~~~~~~~~~~~
Keyword arguments that can provide extra context for the SQL generated
by the aggregate.
Avg Avg
~~~ ~~~
.. class:: Avg(field) .. class:: Avg(expression, output_field=None, **extra)
Returns the mean value of the given field, which must be numeric. Returns the mean value of the given expression, which must be numeric.
* Default alias: ``<field>__avg`` * Default alias: ``<field>__avg``
* Return type: ``float`` * Return type: ``float``
@ -2698,9 +2748,10 @@ Avg
Count Count
~~~~~ ~~~~~
.. class:: Count(field, distinct=False) .. class:: Count(expression, distinct=False, **extra)
Returns the number of objects that are related through the provided field. Returns the number of objects that are related through the provided
expression.
* Default alias: ``<field>__count`` * Default alias: ``<field>__count``
* Return type: ``int`` * Return type: ``int``
@ -2716,29 +2767,29 @@ Count
Max Max
~~~ ~~~
.. class:: Max(field) .. class:: Max(expression, output_field=None, **extra)
Returns the maximum value of the given field. Returns the maximum value of the given expression.
* Default alias: ``<field>__max`` * Default alias: ``<field>__max``
* Return type: same as input field * Return type: same as input field, or ``output_field`` if supplied
Min Min
~~~ ~~~
.. class:: Min(field) .. class:: Min(expression, output_field=None, **extra)
Returns the minimum value of the given field. Returns the minimum value of the given expression.
* Default alias: ``<field>__min`` * Default alias: ``<field>__min``
* Return type: same as input field * Return type: same as input field, or ``output_field`` if supplied
StdDev StdDev
~~~~~~ ~~~~~~
.. class:: StdDev(field, sample=False) .. class:: StdDev(expression, sample=False, **extra)
Returns the standard deviation of the data in the provided field. Returns the standard deviation of the data in the provided expression.
* Default alias: ``<field>__stddev`` * Default alias: ``<field>__stddev``
* Return type: ``float`` * Return type: ``float``
@ -2760,19 +2811,19 @@ StdDev
Sum Sum
~~~ ~~~
.. class:: Sum(field) .. class:: Sum(expression, output_field=None, **extra)
Computes the sum of all values of the given field. Computes the sum of all values of the given expression.
* Default alias: ``<field>__sum`` * Default alias: ``<field>__sum``
* Return type: same as input field * Return type: same as input field, or ``output_field`` if supplied
Variance Variance
~~~~~~~~ ~~~~~~~~
.. class:: Variance(field, sample=False) .. class:: Variance(expression, sample=False, **extra)
Returns the variance of the data in the provided field. Returns the variance of the data in the provided expression.
* Default alias: ``<field>__variance`` * Default alias: ``<field>__variance``
* Return type: ``float`` * Return type: ``float``

View File

@ -52,6 +52,15 @@ New data types
<django.forms.UUIDField>`. It is stored as the native ``uuid`` data type on <django.forms.UUIDField>`. It is stored as the native ``uuid`` data type on
PostgreSQL and as a fixed length character field on other backends. PostgreSQL and as a fixed length character field on other backends.
Query Expressions
~~~~~~~~~~~~~~~~~
:doc:`Query Expressions </ref/models/expressions>` allow users to create,
customize, and compose complex SQL expressions. This has enabled annotate
to accept expressions other than aggregates. Aggregates are now able to
reference multiple fields, as well as perform arithmetic, similar to ``F()``
objects.
Minor features Minor features
~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~
@ -857,6 +866,29 @@ or ``name='django.contrib.gis.sitemaps.views.kmz'``.
.. _security issue: https://www.djangoproject.com/weblog/2014/apr/21/security/#s-issue-unexpected-code-execution-using-reverse .. _security issue: https://www.djangoproject.com/weblog/2014/apr/21/security/#s-issue-unexpected-code-execution-using-reverse
Aggregate methods and modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The ``django.db.models.sql.aggregates`` and
``django.contrib.gis.db.models.sql.aggregates`` modules (both private API), have
been deprecated as ``django.db.models.aggregates`` and
``django.contrib.gis.db.models.aggregates`` are now also responsible
for SQL generation. The old modules will be removed in Django 2.0.
If you were using the old modules, see :doc:`Query Expressions
</ref/models/expressions>` for instructions on rewriting custom aggregates
using the new stable API.
The following methods and properties of ``django.db.models.sql.query.Query``
have also been deprecated and the backwards compatibility shims will be removed
in Django 2.0:
* ``Query.aggregates``, replaced by ``annotations``.
* ``Query.aggregate_select``, replaced by ``annotation_select``.
* ``Query.add_aggregate()``, replaced by ``add_annotation()``.
* ``Query.set_aggregate_mask()``, replaced by ``set_annotation_mask()``.
* ``Query.append_aggregate_mask()``, replaced by ``append_annotation_mask()``.
Extending management command arguments through ``Command.option_list`` Extending management command arguments through ``Command.option_list``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -67,6 +67,11 @@ In a hurry? Here's how to do common aggregate queries, assuming the models above
>>> Book.objects.all().aggregate(Max('price')) >>> Book.objects.all().aggregate(Max('price'))
{'price__max': Decimal('81.20')} {'price__max': Decimal('81.20')}
# Cost per page
>>> Book.objects.all().aggregate(
... price_per_page=Sum(F('price')/F('pages'), output_field=FloatField()))
{'price_per_page': 0.4470664529184653}
# All the following queries involve traversing the Book<->Publisher # All the following queries involve traversing the Book<->Publisher
# many-to-many relationship backward # many-to-many relationship backward

View File

@ -3,12 +3,21 @@ from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
import re import re
import warnings
from django.core.exceptions import FieldError
from django.db import connection from django.db import connection
from django.db.models import Avg, Sum, Count, Max, Min from django.db.models import (
Avg, Sum, Count, Max, Min,
Aggregate, F, Value, Func,
IntegerField, FloatField, DecimalField)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
from django.db.models.sql import aggregates as sql_aggregates
from django.test import TestCase from django.test import TestCase
from django.test.utils import Approximate from django.test.utils import Approximate
from django.test.utils import CaptureQueriesContext from django.test.utils import CaptureQueriesContext
from django.utils.deprecation import RemovedInDjango20Warning
from .models import Author, Publisher, Book, Store from .models import Author, Publisher, Book, Store
@ -678,3 +687,271 @@ class BaseAggregateTestCase(TestCase):
else: else:
self.assertNotIn('order by', qstr) self.assertNotIn('order by', qstr)
self.assertEqual(qstr.count(' join '), 0) self.assertEqual(qstr.count(' join '), 0)
class ComplexAggregateTestCase(TestCase):
fixtures = ["aggregation.json"]
def test_nonaggregate_aggregation_throws(self):
with self.assertRaisesRegexp(TypeError, 'fail is not an aggregate expression'):
Book.objects.aggregate(fail=F('price'))
def test_nonfield_annotation(self):
book = Book.objects.annotate(val=Max(Value(2, output_field=IntegerField())))[0]
self.assertEqual(book.val, 2)
book = Book.objects.annotate(val=Max(Value(2), output_field=IntegerField()))[0]
self.assertEqual(book.val, 2)
def test_missing_output_field_raises_error(self):
with self.assertRaisesRegexp(FieldError, 'Cannot resolve expression type, unknown output_field'):
Book.objects.annotate(val=Max(Value(2)))[0]
def test_annotation_expressions(self):
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name')
for qs in (authors, authors2):
self.assertEqual(len(qs), 9)
self.assertQuerysetEqual(
qs, [
('Adrian Holovaty', 132),
('Brad Dayley', None),
('Jacob Kaplan-Moss', 129),
('James Bennett', 63),
('Jeffrey Forcier', 128),
('Paul Bissex', 120),
('Peter Norvig', 103),
('Stuart Russell', 103),
('Wesley J. Chun', 176)
],
lambda a: (a.name, a.combined_ages)
)
def test_aggregation_expressions(self):
a1 = Author.objects.aggregate(av_age=Sum('age') / Count('*'))
a2 = Author.objects.aggregate(av_age=Sum('age') / Count('age'))
a3 = Author.objects.aggregate(av_age=Avg('age'))
self.assertEqual(a1, {'av_age': 37})
self.assertEqual(a2, {'av_age': 37})
self.assertEqual(a3, {'av_age': Approximate(37.4, places=1)})
def test_order_of_precedence(self):
p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3)
self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)})
p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3)
self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)})
def test_combine_different_types(self):
with self.assertRaisesRegexp(FieldError, 'Expression contains mixed types. You must set output_field'):
Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')).get(pk=4)
b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
output_field=IntegerField())).get(pk=4)
self.assertEqual(b1.sums, 383)
b2 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
output_field=FloatField())).get(pk=4)
self.assertEqual(b2.sums, 383.69)
b3 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
output_field=DecimalField(max_digits=6, decimal_places=2))).get(pk=4)
self.assertEqual(b3.sums, Decimal("383.69"))
def test_complex_aggregations_require_kwarg(self):
with self.assertRaisesRegexp(TypeError, 'Complex expressions require an alias'):
Author.objects.annotate(Sum(F('age') + F('friends__age')))
with self.assertRaisesRegexp(TypeError, 'Complex aggregates require an alias'):
Author.objects.aggregate(Sum('age') / Count('age'))
def test_aggregate_over_complex_annotation(self):
qs = Author.objects.annotate(
combined_ages=Sum(F('age') + F('friends__age')))
age = qs.aggregate(max_combined_age=Max('combined_ages'))
self.assertEqual(age['max_combined_age'], 176)
age = qs.aggregate(max_combined_age_doubled=Max('combined_ages') * 2)
self.assertEqual(age['max_combined_age_doubled'], 176 * 2)
age = qs.aggregate(
max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'))
self.assertEqual(age['max_combined_age_doubled'], 176 * 2)
age = qs.aggregate(
max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'),
sum_combined_age=Sum('combined_ages'))
self.assertEqual(age['max_combined_age_doubled'], 176 * 2)
self.assertEqual(age['sum_combined_age'], 954)
age = qs.aggregate(
max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'),
sum_combined_age_doubled=Sum('combined_ages') + Sum('combined_ages'))
self.assertEqual(age['max_combined_age_doubled'], 176 * 2)
self.assertEqual(age['sum_combined_age_doubled'], 954 * 2)
def test_values_annotation_with_expression(self):
# ensure the F() is promoted to the group by clause
qs = Author.objects.values('name').annotate(another_age=Sum('age') + F('age'))
a = qs.get(pk=1)
self.assertEqual(a['another_age'], 68)
qs = qs.annotate(friend_count=Count('friends'))
a = qs.get(pk=1)
self.assertEqual(a['friend_count'], 2)
qs = qs.annotate(combined_age=Sum('age') + F('friends__age')).filter(pk=1).order_by('-combined_age')
self.assertEqual(
list(qs), [
{
"name": 'Adrian Holovaty',
"another_age": 68,
"friend_count": 1,
"combined_age": 69
},
{
"name": 'Adrian Holovaty',
"another_age": 68,
"friend_count": 1,
"combined_age": 63
}
]
)
vals = qs.values('name', 'combined_age')
self.assertEqual(
list(vals), [
{
"name": 'Adrian Holovaty',
"combined_age": 69
},
{
"name": 'Adrian Holovaty',
"combined_age": 63
}
]
)
def test_annotate_values_aggregate(self):
alias_age = Author.objects.annotate(
age_alias=F('age')
).values(
'age_alias',
).aggregate(sum_age=Sum('age_alias'))
age = Author.objects.values('age').aggregate(sum_age=Sum('age'))
self.assertEqual(alias_age['sum_age'], age['sum_age'])
def test_annotate_over_annotate(self):
author = Author.objects.annotate(
age_alias=F('age')
).annotate(
sum_age=Sum('age_alias')
).get(pk=1)
other_author = Author.objects.annotate(
sum_age=Sum('age')
).get(pk=1)
self.assertEqual(author.sum_age, other_author.sum_age)
def test_annotated_aggregate_over_annotated_aggregate(self):
with self.assertRaisesRegexp(FieldError, "Cannot compute Sum\('id__max'\): 'id__max' is an aggregate"):
Book.objects.annotate(Max('id')).annotate(Sum('id__max'))
def test_add_implementation(self):
try:
# test completely changing how the output is rendered
def lower_case_function_override(self, qn, connection):
sql, params = qn.compile(self.source_expressions[0])
substitutions = dict(function=self.function.lower(), expressions=sql)
substitutions.update(self.extra)
return self.template % substitutions, params
setattr(Sum, 'as_' + connection.vendor, lower_case_function_override)
qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
output_field=IntegerField()))
self.assertEqual(str(qs.query).count('sum('), 1)
b1 = qs.get(pk=4)
self.assertEqual(b1.sums, 383)
# test changing the dict and delegating
def lower_case_function_super(self, qn, connection):
self.extra['function'] = self.function.lower()
return super(Sum, self).as_sql(qn, connection)
setattr(Sum, 'as_' + connection.vendor, lower_case_function_super)
qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
output_field=IntegerField()))
self.assertEqual(str(qs.query).count('sum('), 1)
b1 = qs.get(pk=4)
self.assertEqual(b1.sums, 383)
# test overriding all parts of the template
def be_evil(self, qn, connection):
substitutions = dict(function='MAX', expressions='2')
substitutions.update(self.extra)
return self.template % substitutions, ()
setattr(Sum, 'as_' + connection.vendor, be_evil)
qs = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'),
output_field=IntegerField()))
self.assertEqual(str(qs.query).count('MAX('), 1)
b1 = qs.get(pk=4)
self.assertEqual(b1.sums, 2)
finally:
delattr(Sum, 'as_' + connection.vendor)
def test_complex_values_aggregation(self):
max_rating = Book.objects.values('rating').aggregate(
double_max_rating=Max('rating') + Max('rating'))
self.assertEqual(max_rating['double_max_rating'], 5 * 2)
max_books_per_rating = Book.objects.values('rating').annotate(
books_per_rating=Count('id') + 5
).aggregate(Max('books_per_rating'))
self.assertEqual(
max_books_per_rating,
{'books_per_rating__max': 3 + 5})
def test_expression_on_aggregation(self):
# Create a plain expression
class Greatest(Func):
function = 'GREATEST'
def as_sqlite(self, qn, connection):
return super(Greatest, self).as_sql(qn, connection, function='MAX')
qs = Publisher.objects.annotate(
price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))
).filter(price_or_median__gte=F('num_awards')).order_by('pk')
self.assertQuerysetEqual(
qs, [1, 2, 3, 4], lambda v: v.pk)
qs2 = Publisher.objects.annotate(
rating_or_num_awards=Greatest(Avg('book__rating'), F('num_awards'),
output_field=FloatField())
).filter(rating_or_num_awards__gt=F('num_awards')).order_by('pk')
self.assertQuerysetEqual(
qs2, [1, 2], lambda v: v.pk)
def test_backwards_compatibility(self):
class SqlNewSum(sql_aggregates.Aggregate):
sql_function = 'SUM'
class NewSum(Aggregate):
name = 'Sum'
def add_to_query(self, query, alias, col, source, is_summary):
klass = SqlNewSum
aggregate = klass(
col, source=source, is_summary=is_summary, **self.extra)
query.annotations[alias] = aggregate
with warnings.catch_warnings():
warnings.simplefilter("ignore", RemovedInDjango20Warning)
qs = Author.objects.values('name').annotate(another_age=NewSum('age') + F('age'))
a = qs.get(pk=1)
self.assertEqual(a['another_age'], 68)

View File

View File

@ -0,0 +1,243 @@
[
{
"pk": 1,
"model": "annotations.publisher",
"fields": {
"name": "Apress",
"num_awards": 3
}
},
{
"pk": 2,
"model": "annotations.publisher",
"fields": {
"name": "Sams",
"num_awards": 1
}
},
{
"pk": 3,
"model": "annotations.publisher",
"fields": {
"name": "Prentice Hall",
"num_awards": 7
}
},
{
"pk": 4,
"model": "annotations.publisher",
"fields": {
"name": "Morgan Kaufmann",
"num_awards": 9
}
},
{
"pk": 5,
"model": "annotations.publisher",
"fields": {
"name": "Jonno's House of Books",
"num_awards": 0
}
},
{
"pk": 1,
"model": "annotations.book",
"fields": {
"publisher": 1,
"isbn": "159059725",
"name": "The Definitive Guide to Django: Web Development Done Right",
"price": "30.00",
"rating": 4.5,
"authors": [1, 2],
"contact": 1,
"pages": 447,
"pubdate": "2007-12-6"
}
},
{
"pk": 2,
"model": "annotations.book",
"fields": {
"publisher": 2,
"isbn": "067232959",
"name": "Sams Teach Yourself Django in 24 Hours",
"price": "23.09",
"rating": 3.0,
"authors": [3],
"contact": 3,
"pages": 528,
"pubdate": "2008-3-3"
}
},
{
"pk": 3,
"model": "annotations.book",
"fields": {
"publisher": 1,
"isbn": "159059996",
"name": "Practical Django Projects",
"price": "29.69",
"rating": 4.0,
"authors": [4],
"contact": 4,
"pages": 300,
"pubdate": "2008-6-23"
}
},
{
"pk": 4,
"model": "annotations.book",
"fields": {
"publisher": 3,
"isbn": "013235613",
"name": "Python Web Development with Django",
"price": "29.69",
"rating": 4.0,
"authors": [5, 6, 7],
"contact": 5,
"pages": 350,
"pubdate": "2008-11-3"
}
},
{
"pk": 5,
"model": "annotations.book",
"fields": {
"publisher": 3,
"isbn": "013790395",
"name": "Artificial Intelligence: A Modern Approach",
"price": "82.80",
"rating": 4.0,
"authors": [8, 9],
"contact": 8,
"pages": 1132,
"pubdate": "1995-1-15"
}
},
{
"pk": 6,
"model": "annotations.book",
"fields": {
"publisher": 4,
"isbn": "155860191",
"name": "Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp",
"price": "75.00",
"rating": 5.0,
"authors": [8],
"contact": 8,
"pages": 946,
"pubdate": "1991-10-15"
}
},
{
"pk": 1,
"model": "annotations.store",
"fields": {
"books": [1, 2, 3, 4, 5, 6],
"name": "Amazon.com",
"original_opening": "1994-4-23 9:17:42",
"friday_night_closing": "23:59:59"
}
},
{
"pk": 2,
"model": "annotations.store",
"fields": {
"books": [1, 3, 5, 6],
"name": "Books.com",
"original_opening": "2001-3-15 11:23:37",
"friday_night_closing": "23:59:59"
}
},
{
"pk": 3,
"model": "annotations.store",
"fields": {
"books": [3, 4, 6],
"name": "Mamma and Pappa's Books",
"original_opening": "1945-4-25 16:24:14",
"friday_night_closing": "21:30:00"
}
},
{
"pk": 1,
"model": "annotations.author",
"fields": {
"age": 34,
"friends": [2, 4],
"name": "Adrian Holovaty"
}
},
{
"pk": 2,
"model": "annotations.author",
"fields": {
"age": 35,
"friends": [1, 7],
"name": "Jacob Kaplan-Moss"
}
},
{
"pk": 3,
"model": "annotations.author",
"fields": {
"age": 45,
"friends": [],
"name": "Brad Dayley"
}
},
{
"pk": 4,
"model": "annotations.author",
"fields": {
"age": 29,
"friends": [1],
"name": "James Bennett"
}
},
{
"pk": 5,
"model": "annotations.author",
"fields": {
"age": 37,
"friends": [6, 7],
"name": "Jeffrey Forcier"
}
},
{
"pk": 6,
"model": "annotations.author",
"fields": {
"age": 29,
"friends": [5, 7],
"name": "Paul Bissex"
}
},
{
"pk": 7,
"model": "annotations.author",
"fields": {
"age": 25,
"friends": [2, 5, 6],
"name": "Wesley J. Chun"
}
},
{
"pk": 8,
"model": "annotations.author",
"fields": {
"age": 57,
"friends": [9],
"name": "Peter Norvig"
}
},
{
"pk": 9,
"model": "annotations.author",
"fields": {
"age": 46,
"friends": [8],
"name": "Stuart Russell"
}
}
]

View File

@ -0,0 +1,86 @@
# coding: utf-8
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
@python_2_unicode_compatible
class Author(models.Model):
name = models.CharField(max_length=100)
age = models.IntegerField()
friends = models.ManyToManyField('self', blank=True)
def __str__(self):
return self.name
@python_2_unicode_compatible
class Publisher(models.Model):
name = models.CharField(max_length=255)
num_awards = models.IntegerField()
def __str__(self):
return self.name
@python_2_unicode_compatible
class Book(models.Model):
isbn = models.CharField(max_length=9)
name = models.CharField(max_length=255)
pages = models.IntegerField()
rating = models.FloatField()
price = models.DecimalField(decimal_places=2, max_digits=6)
authors = models.ManyToManyField(Author)
contact = models.ForeignKey(Author, related_name='book_contact_set')
publisher = models.ForeignKey(Publisher)
pubdate = models.DateField()
def __str__(self):
return self.name
@python_2_unicode_compatible
class Store(models.Model):
name = models.CharField(max_length=255)
books = models.ManyToManyField(Book)
original_opening = models.DateTimeField()
friday_night_closing = models.TimeField()
def __str__(self):
return self.name
@python_2_unicode_compatible
class DepartmentStore(Store):
chain = models.CharField(max_length=255)
def __str__(self):
return '%s - %s ' % (self.chain, self.name)
@python_2_unicode_compatible
class Employee(models.Model):
# The order of these fields matter, do not change. Certain backends
# rely on field ordering to perform database conversions, and this
# model helps to test that.
first_name = models.CharField(max_length=20)
manager = models.BooleanField(default=False)
last_name = models.CharField(max_length=20)
store = models.ForeignKey(Store)
age = models.IntegerField()
salary = models.DecimalField(max_digits=8, decimal_places=2)
def __str__(self):
return '%s %s' % (self.first_name, self.last_name)
@python_2_unicode_compatible
class Company(models.Model):
name = models.CharField(max_length=200)
motto = models.CharField(max_length=200, null=True, blank=True)
ticker_name = models.CharField(max_length=10, null=True, blank=True)
description = models.CharField(max_length=200, null=True, blank=True)
def __str__(self):
return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)'
% (self.name, self.motto, self.ticker_name, self.description)
)

288
tests/annotations/tests.py Normal file
View File

@ -0,0 +1,288 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
from django.core.exceptions import FieldError
from django.db.models import (
Sum, Count,
F, Value, Func,
IntegerField, BooleanField, CharField)
from django.db.models.fields import FieldDoesNotExist
from django.test import TestCase
from .models import Author, Book, Store, DepartmentStore, Company, Employee
class NonAggregateAnnotationTestCase(TestCase):
fixtures = ["annotations.json"]
def test_basic_annotation(self):
books = Book.objects.annotate(
is_book=Value(1, output_field=IntegerField()))
for book in books:
self.assertEqual(book.is_book, 1)
def test_basic_f_annotation(self):
books = Book.objects.annotate(another_rating=F('rating'))
for book in books:
self.assertEqual(book.another_rating, book.rating)
def test_joined_annotation(self):
books = Book.objects.select_related('publisher').annotate(
num_awards=F('publisher__num_awards'))
for book in books:
self.assertEqual(book.num_awards, book.publisher.num_awards)
def test_annotate_with_aggregation(self):
books = Book.objects.annotate(
is_book=Value(1, output_field=IntegerField()),
rating_count=Count('rating'))
for book in books:
self.assertEqual(book.is_book, 1)
self.assertEqual(book.rating_count, 1)
def test_aggregate_over_annotation(self):
agg = Author.objects.annotate(other_age=F('age')).aggregate(otherage_sum=Sum('other_age'))
other_agg = Author.objects.aggregate(age_sum=Sum('age'))
self.assertEqual(agg['otherage_sum'], other_agg['age_sum'])
def test_filter_annotation(self):
books = Book.objects.annotate(
is_book=Value(1, output_field=IntegerField())
).filter(is_book=1)
for book in books:
self.assertEqual(book.is_book, 1)
def test_filter_annotation_with_f(self):
books = Book.objects.annotate(
other_rating=F('rating')
).filter(other_rating=3.5)
for book in books:
self.assertEqual(book.other_rating, 3.5)
def test_filter_annotation_with_double_f(self):
books = Book.objects.annotate(
other_rating=F('rating')
).filter(other_rating=F('rating'))
for book in books:
self.assertEqual(book.other_rating, book.rating)
def test_filter_agg_with_double_f(self):
books = Book.objects.annotate(
sum_rating=Sum('rating')
).filter(sum_rating=F('sum_rating'))
for book in books:
self.assertEqual(book.sum_rating, book.rating)
def test_filter_wrong_annotation(self):
with self.assertRaisesRegexp(FieldError, "Cannot resolve keyword .*"):
list(Book.objects.annotate(
sum_rating=Sum('rating')
).filter(sum_rating=F('nope')))
def test_update_with_annotation(self):
book_preupdate = Book.objects.get(pk=2)
Book.objects.annotate(other_rating=F('rating') - 1).update(rating=F('other_rating'))
book_postupdate = Book.objects.get(pk=2)
self.assertEqual(book_preupdate.rating - 1, book_postupdate.rating)
def test_annotation_with_m2m(self):
books = Book.objects.annotate(author_age=F('authors__age')).filter(pk=1).order_by('author_age')
self.assertEqual(books[0].author_age, 34)
self.assertEqual(books[1].author_age, 35)
def test_annotation_reverse_m2m(self):
books = Book.objects.annotate(
store_name=F('store__name')).filter(
name='Practical Django Projects').order_by(
'store_name')
self.assertQuerysetEqual(
books, [
'Amazon.com',
'Books.com',
'Mamma and Pappa\'s Books'
],
lambda b: b.store_name
)
def test_values_annotation(self):
"""
Annotations can reference fields in a values clause,
and contribute to an existing values clause.
"""
# annotate references a field in values()
qs = Book.objects.values('rating').annotate(other_rating=F('rating') - 1)
book = qs.get(pk=1)
self.assertEqual(book['rating'] - 1, book['other_rating'])
# filter refs the annotated value
book = qs.get(other_rating=4)
self.assertEqual(book['other_rating'], 4)
# can annotate an existing values with a new field
book = qs.annotate(other_isbn=F('isbn')).get(other_rating=4)
self.assertEqual(book['other_rating'], 4)
self.assertEqual(book['other_isbn'], '155860191')
def test_defer_annotation(self):
"""
Deferred attributes can be referenced by an annotation,
but they are not themselves deferred, and cannot be deferred.
"""
qs = Book.objects.defer('rating').annotate(other_rating=F('rating') - 1)
with self.assertNumQueries(2):
book = qs.get(other_rating=4)
self.assertEqual(book.rating, 5)
self.assertEqual(book.other_rating, 4)
with self.assertRaisesRegexp(FieldDoesNotExist, "\w has no field named u?'other_rating'"):
book = qs.defer('other_rating').get(other_rating=4)
def test_mti_annotations(self):
"""
Fields on an inherited model can be referenced by an
annotated field.
"""
d = DepartmentStore.objects.create(
name='Angus & Robinson',
original_opening=datetime.date(2014, 3, 8),
friday_night_closing=datetime.time(21, 00, 00),
chain='Westfield'
)
books = Book.objects.filter(rating__gt=4)
for b in books:
d.books.add(b)
qs = DepartmentStore.objects.annotate(
other_name=F('name'),
other_chain=F('chain'),
is_open=Value(True, BooleanField()),
book_isbn=F('books__isbn')
).select_related('store').order_by('book_isbn').filter(chain='Westfield')
self.assertQuerysetEqual(
qs, [
('Angus & Robinson', 'Westfield', True, '155860191'),
('Angus & Robinson', 'Westfield', True, '159059725')
],
lambda d: (d.other_name, d.other_chain, d.is_open, d.book_isbn)
)
def test_column_field_ordering(self):
"""
Test that columns are aligned in the correct order for
resolve_columns. This test will fail on mysql if column
ordering is out. Column fields should be aligned as:
1. extra_select
2. model_fields
3. annotation_fields
4. model_related_fields
"""
store = Store.objects.first()
Employee.objects.create(id=1, first_name='Max', manager=True, last_name='Paine',
store=store, age=23, salary=Decimal(50000.00))
Employee.objects.create(id=2, first_name='Buffy', manager=False, last_name='Summers',
store=store, age=18, salary=Decimal(40000.00))
qs = Employee.objects.extra(
select={'random_value': '42'}
).select_related('store').annotate(
annotated_value=Value(17, output_field=IntegerField())
)
rows = [
(1, 'Max', True, 42, 'Paine', 23, Decimal(50000.00), store.name, 17),
(2, 'Buffy', False, 42, 'Summers', 18, Decimal(40000.00), store.name, 17)
]
self.assertQuerysetEqual(
qs.order_by('id'), rows,
lambda e: (
e.id, e.first_name, e.manager, e.random_value, e.last_name, e.age,
e.salary, e.store.name, e.annotated_value))
def test_column_field_ordering_with_deferred(self):
store = Store.objects.first()
Employee.objects.create(id=1, first_name='Max', manager=True, last_name='Paine',
store=store, age=23, salary=Decimal(50000.00))
Employee.objects.create(id=2, first_name='Buffy', manager=False, last_name='Summers',
store=store, age=18, salary=Decimal(40000.00))
qs = Employee.objects.extra(
select={'random_value': '42'}
).select_related('store').annotate(
annotated_value=Value(17, output_field=IntegerField())
)
rows = [
(1, 'Max', True, 42, 'Paine', 23, Decimal(50000.00), store.name, 17),
(2, 'Buffy', False, 42, 'Summers', 18, Decimal(40000.00), store.name, 17)
]
# and we respect deferred columns!
self.assertQuerysetEqual(
qs.defer('age').order_by('id'), rows,
lambda e: (
e.id, e.first_name, e.manager, e.random_value, e.last_name, e.age,
e.salary, e.store.name, e.annotated_value))
def test_custom_functions(self):
Company(name='Apple', motto=None, ticker_name='APPL', description='Beautiful Devices').save()
Company(name='Django Software Foundation', motto=None, ticker_name=None, description=None).save()
Company(name='Google', motto='Do No Evil', ticker_name='GOOG', description='Internet Company').save()
Company(name='Yahoo', motto=None, ticker_name=None, description='Internet Company').save()
qs = Company.objects.annotate(
tagline=Func(
F('motto'),
F('ticker_name'),
F('description'),
Value('No Tag'),
function='COALESCE')
).order_by('name')
self.assertQuerysetEqual(
qs, [
('Apple', 'APPL'),
('Django Software Foundation', 'No Tag'),
('Google', 'Do No Evil'),
('Yahoo', 'Internet Company')
],
lambda c: (c.name, c.tagline)
)
def test_custom_functions_can_ref_other_functions(self):
Company(name='Apple', motto=None, ticker_name='APPL', description='Beautiful Devices').save()
Company(name='Django Software Foundation', motto=None, ticker_name=None, description=None).save()
Company(name='Google', motto='Do No Evil', ticker_name='GOOG', description='Internet Company').save()
Company(name='Yahoo', motto=None, ticker_name=None, description='Internet Company').save()
class Lower(Func):
function = 'LOWER'
qs = Company.objects.annotate(
tagline=Func(
F('motto'),
F('ticker_name'),
F('description'),
Value('No Tag'),
function='COALESCE')
).annotate(
tagline_lower=Lower(F('tagline'), output_field=CharField())
).order_by('name')
# LOWER function supported by:
# oracle, postgres, mysql, sqlite, sqlserver
self.assertQuerysetEqual(
qs, [
('Apple', 'APPL'.lower()),
('Django Software Foundation', 'No Tag'.lower()),
('Google', 'Do No Evil'.lower()),
('Yahoo', 'Internet Company'.lower())
],
lambda c: (c.name, c.tagline_lower)
)

View File

@ -296,6 +296,21 @@ class ExpressionsTests(TestCase):
g = deepcopy(f) g = deepcopy(f)
self.assertEqual(f.name, g.name) self.assertEqual(f.name, g.name)
def test_f_reuse(self):
f = F('id')
n = Number.objects.create(integer=-1)
c = Company.objects.create(
name="Example Inc.", num_employees=2300, num_chairs=5,
ceo=Employee.objects.create(firstname="Joe", lastname="Smith")
)
c_qs = Company.objects.filter(id=f)
self.assertEqual(c_qs.get(), c)
# Reuse the same F-object for another queryset
n_qs = Number.objects.filter(id=f)
self.assertEqual(n_qs.get(), n)
# The original query still works correctly
self.assertEqual(c_qs.get(), c)
class ExpressionsNumericTests(TestCase): class ExpressionsNumericTests(TestCase):
@ -362,12 +377,16 @@ class ExpressionsNumericTests(TestCase):
Complex expressions of different connection types are possible. Complex expressions of different connection types are possible.
""" """
n = Number.objects.create(integer=10, float=123.45) n = Number.objects.create(integer=10, float=123.45)
self.assertEqual(Number.objects.filter(pk=n.pk) self.assertEqual(Number.objects.filter(pk=n.pk).update(
.update(float=F('integer') + F('float') * 2), 1) float=F('integer') + F('float') * 2), 1)
self.assertEqual(Number.objects.get(pk=n.pk).integer, 10) self.assertEqual(Number.objects.get(pk=n.pk).integer, 10)
self.assertEqual(Number.objects.get(pk=n.pk).float, Approximate(256.900, places=3)) self.assertEqual(Number.objects.get(pk=n.pk).float, Approximate(256.900, places=3))
def test_incorrect_field_expression(self):
with self.assertRaisesRegexp(FieldError, "Cannot resolve keyword u?'nope' into field.*"):
list(Employee.objects.filter(firstname=F('nope')))
class ExpressionOperatorTests(TestCase): class ExpressionOperatorTests(TestCase):
def setUp(self): def setUp(self):