Fixed #24020 -- Refactored SQL compiler to use expressions

Refactored compiler SELECT, GROUP BY and ORDER BY generation.
While there, also refactored select_related() implementation
(get_cached_row() and get_klass_info() are now gone!).

Made get_db_converters() method work on expressions instead of
internal_type. This allows the backend converters to target
specific expressions if need be.

Added query.context, this can be used to set per-query state.

Also changed the signature of database converters. They now accept
context as an argument.
This commit is contained in:
Anssi Kääriäinen 2014-12-01 09:28:01 +02:00 committed by Tim Graham
parent b8abfe141b
commit 0c7633178f
41 changed files with 970 additions and 1416 deletions

View File

@ -10,7 +10,6 @@ from django.db.models import signals, DO_NOTHING
from django.db.models.base import ModelBase from django.db.models.base import ModelBase
from django.db.models.fields.related import ForeignObject, ForeignObjectRel from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.db.models.expressions import Col
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.utils.encoding import smart_text, python_2_unicode_compatible from django.utils.encoding import smart_text, python_2_unicode_compatible
@ -367,7 +366,7 @@ class GenericRelation(ForeignObject):
field = self.rel.to._meta.get_field(self.content_type_field_name) field = self.rel.to._meta.get_field(self.content_type_field_name)
contenttype_pk = self.get_content_type().pk contenttype_pk = self.get_content_type().pk
cond = where_class() cond = where_class()
lookup = field.get_lookup('exact')(Col(remote_alias, field, field), contenttype_pk) lookup = field.get_lookup('exact')(field.get_col(remote_alias), contenttype_pk)
cond.add(lookup, 'AND') cond.add(lookup, 'AND')
return cond return cond

View File

@ -158,10 +158,10 @@ class BaseSpatialOperations(object):
# Default conversion functions for aggregates; will be overridden if implemented # Default conversion functions for aggregates; will be overridden if implemented
# for the spatial backend. # for the spatial backend.
def convert_extent(self, box): def convert_extent(self, box, srid):
raise NotImplementedError('Aggregate extent not implemented for this spatial backend.') raise NotImplementedError('Aggregate extent not implemented for this spatial backend.')
def convert_extent3d(self, box): def convert_extent3d(self, box, srid):
raise NotImplementedError('Aggregate 3D extent not implemented for this spatial backend.') raise NotImplementedError('Aggregate 3D extent not implemented for this spatial backend.')
def convert_geom(self, geom_val, geom_field): def convert_geom(self, geom_val, geom_field):

View File

@ -7,7 +7,6 @@ from django.contrib.gis.db.backends.utils import SpatialOperator
class MySQLOperations(DatabaseOperations, BaseSpatialOperations): class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
mysql = True mysql = True
name = 'mysql' name = 'mysql'
select = 'AsText(%s)' select = 'AsText(%s)'

View File

@ -1,24 +0,0 @@
from django.contrib.gis.db.models.sql.compiler import GeoSQLCompiler as BaseGeoSQLCompiler
from django.db.backends.oracle import compiler
SQLCompiler = compiler.SQLCompiler
class GeoSQLCompiler(BaseGeoSQLCompiler, SQLCompiler):
pass
class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
pass
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
pass
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
pass

View File

@ -52,7 +52,6 @@ class SDORelate(SpatialOperator):
class OracleOperations(DatabaseOperations, BaseSpatialOperations): class OracleOperations(DatabaseOperations, BaseSpatialOperations):
compiler_module = "django.contrib.gis.db.backends.oracle.compiler"
name = 'oracle' name = 'oracle'
oracle = True oracle = True
@ -111,8 +110,9 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
def geo_quote_name(self, name): def geo_quote_name(self, name):
return super(OracleOperations, self).geo_quote_name(name).upper() return super(OracleOperations, self).geo_quote_name(name).upper()
def get_db_converters(self, internal_type): def get_db_converters(self, expression):
converters = super(OracleOperations, self).get_db_converters(internal_type) converters = super(OracleOperations, self).get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
geometry_fields = ( geometry_fields = (
'PointField', 'GeometryField', 'LineStringField', 'PointField', 'GeometryField', 'LineStringField',
'PolygonField', 'MultiPointField', 'MultiLineStringField', 'PolygonField', 'MultiPointField', 'MultiLineStringField',
@ -121,14 +121,23 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
) )
if internal_type in geometry_fields: if internal_type in geometry_fields:
converters.append(self.convert_textfield_value) converters.append(self.convert_textfield_value)
if hasattr(expression.output_field, 'geom_type'):
converters.append(self.convert_geometry)
return converters return converters
def convert_extent(self, clob): def convert_geometry(self, value, expression, context):
if value:
value = Geometry(value)
if 'transformed_srid' in context:
value.srid = context['transformed_srid']
return value
def convert_extent(self, clob, srid):
if clob: if clob:
# Generally, Oracle returns a polygon for the extent -- however, # Generally, Oracle returns a polygon for the extent -- however,
# it can return a single point if there's only one Point in the # it can return a single point if there's only one Point in the
# table. # table.
ext_geom = Geometry(clob.read()) ext_geom = Geometry(clob.read(), srid)
gtype = str(ext_geom.geom_type) gtype = str(ext_geom.geom_type)
if gtype == 'Polygon': if gtype == 'Polygon':
# Construct the 4-tuple from the coordinates in the polygon. # Construct the 4-tuple from the coordinates in the polygon.
@ -226,7 +235,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
else: else:
sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
sql_function = getattr(self, agg_name) sql_function = getattr(self, agg_name)
return self.select % sql_template, sql_function return sql_template, sql_function
# Routines for getting the OGC-compliant models. # Routines for getting the OGC-compliant models.
def geometry_columns(self): def geometry_columns(self):

View File

@ -44,7 +44,6 @@ class PostGISDistanceOperator(PostGISOperator):
class PostGISOperations(DatabaseOperations, BaseSpatialOperations): class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
name = 'postgis' name = 'postgis'
postgis = True postgis = True
geography = True geography = True
@ -188,7 +187,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
agg_name = aggregate.__class__.__name__ agg_name = aggregate.__class__.__name__
return agg_name in self.valid_aggregates return agg_name in self.valid_aggregates
def convert_extent(self, box): def convert_extent(self, box, srid):
""" """
Returns a 4-tuple extent for the `Extent` aggregate by converting Returns a 4-tuple extent for the `Extent` aggregate by converting
the bounding box text returned by PostGIS (`box` argument), for the bounding box text returned by PostGIS (`box` argument), for
@ -199,7 +198,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
xmax, ymax = map(float, ur.split()) xmax, ymax = map(float, ur.split())
return (xmin, ymin, xmax, ymax) return (xmin, ymin, xmax, ymax)
def convert_extent3d(self, box3d): def convert_extent3d(self, box3d, srid):
""" """
Returns a 6-tuple extent for the `Extent3D` aggregate by converting Returns a 6-tuple extent for the `Extent3D` aggregate by converting
the 3d bounding-box text returned by PostGIS (`box3d` argument), for the 3d bounding-box text returned by PostGIS (`box3d` argument), for

View File

@ -14,7 +14,6 @@ from django.utils.functional import cached_property
class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
name = 'spatialite' name = 'spatialite'
spatialite = True spatialite = True
version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)') version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
@ -131,11 +130,11 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
agg_name = aggregate.__class__.__name__ agg_name = aggregate.__class__.__name__
return agg_name in self.valid_aggregates return agg_name in self.valid_aggregates
def convert_extent(self, box): def convert_extent(self, box, srid):
""" """
Convert the polygon data received from Spatialite to min/max values. Convert the polygon data received from Spatialite to min/max values.
""" """
shell = Geometry(box).shell shell = Geometry(box, srid).shell
xmin, ymin = shell[0][:2] xmin, ymin = shell[0][:2]
xmax, ymax = shell[2][:2] xmax, ymax = shell[2][:2]
return (xmin, ymin, xmax, ymax) return (xmin, ymin, xmax, ymax)
@ -256,7 +255,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
agg_name = agg_name.lower() agg_name = agg_name.lower()
if agg_name == 'union': if agg_name == 'union':
agg_name += 'agg' agg_name += 'agg'
sql_template = self.select % '%(function)s(%(expressions)s)' sql_template = '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name) sql_function = getattr(self, agg_name)
return sql_template, sql_function return sql_template, sql_function
@ -268,3 +267,16 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
def spatial_ref_sys(self): def spatial_ref_sys(self):
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys
return SpatialiteSpatialRefSys return SpatialiteSpatialRefSys
def get_db_converters(self, expression):
converters = super(SpatiaLiteOperations, self).get_db_converters(expression)
if hasattr(expression.output_field, 'geom_type'):
converters.append(self.convert_geometry)
return converters
def convert_geometry(self, value, expression, context):
if value:
value = Geometry(value)
if 'transformed_srid' in context:
value.srid = context['transformed_srid']
return value

View File

@ -28,7 +28,7 @@ class GeoAggregate(Aggregate):
raise ValueError('Geospatial aggregates only allowed on geometry fields.') raise ValueError('Geospatial aggregates only allowed on geometry fields.')
return c return c
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
return connection.ops.convert_geom(value, self.output_field) return connection.ops.convert_geom(value, self.output_field)
@ -43,8 +43,8 @@ class Extent(GeoAggregate):
def __init__(self, expression, **extra): def __init__(self, expression, **extra):
super(Extent, self).__init__(expression, output_field=ExtentField(), **extra) super(Extent, self).__init__(expression, output_field=ExtentField(), **extra)
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
return connection.ops.convert_extent(value) return connection.ops.convert_extent(value, context.get('transformed_srid'))
class Extent3D(GeoAggregate): class Extent3D(GeoAggregate):
@ -54,8 +54,8 @@ class Extent3D(GeoAggregate):
def __init__(self, expression, **extra): def __init__(self, expression, **extra):
super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra) super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra)
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
return connection.ops.convert_extent3d(value) return connection.ops.convert_extent3d(value, context.get('transformed_srid'))
class MakeLine(GeoAggregate): class MakeLine(GeoAggregate):

View File

@ -42,7 +42,30 @@ def get_srid_info(srid, connection):
return _srid_cache[connection.alias][srid] return _srid_cache[connection.alias][srid]
class GeometryField(Field): class GeoSelectFormatMixin(object):
def select_format(self, compiler, sql, params):
"""
Returns the selection format string, depending on the requirements
of the spatial backend. For example, Oracle and MySQL require custom
selection formats in order to retrieve geometries in OGC WKT. For all
other fields a simple '%s' format string is returned.
"""
connection = compiler.connection
srid = compiler.query.get_context('transformed_srid')
if srid:
sel_fmt = '%s(%%s, %s)' % (connection.ops.transform, srid)
else:
sel_fmt = '%s'
if connection.ops.select:
# This allows operations to be done on fields in the SELECT,
# overriding their values -- used by the Oracle and MySQL
# spatial backends to get database values as WKT, and by the
# `transform` method.
sel_fmt = connection.ops.select % sel_fmt
return sel_fmt % sql, params
class GeometryField(GeoSelectFormatMixin, Field):
"The base GIS field -- maps to the OpenGIS Specification Geometry type." "The base GIS field -- maps to the OpenGIS Specification Geometry type."
# The OpenGIS Geometry name. # The OpenGIS Geometry name.
@ -196,7 +219,7 @@ class GeometryField(Field):
else: else:
return geom return geom
def from_db_value(self, value, connection): def from_db_value(self, value, connection, context):
if value and not isinstance(value, Geometry): if value and not isinstance(value, Geometry):
value = Geometry(value) value = Geometry(value)
return value return value
@ -337,7 +360,7 @@ class GeometryCollectionField(GeometryField):
description = _("Geometry collection") description = _("Geometry collection")
class ExtentField(Field): class ExtentField(GeoSelectFormatMixin, Field):
"Used as a return value from an extent aggregate" "Used as a return value from an extent aggregate"
description = _("Extent Aggregate Field") description = _("Extent Aggregate Field")

View File

@ -1,9 +1,16 @@
from django.db import connections from django.db import connections
from django.db.models.expressions import RawSQL
from django.db.models.fields import Field
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.contrib.gis.db.models import aggregates from django.contrib.gis.db.models import aggregates
from django.contrib.gis.db.models.fields import get_srid_info, PointField, LineStringField from django.contrib.gis.db.models.fields import (
from django.contrib.gis.db.models.sql import AreaField, DistanceField, GeomField, GeoQuery, GMLField get_srid_info, LineStringField, GeometryField, PointField,
)
from django.contrib.gis.db.models.lookups import GISLookup
from django.contrib.gis.db.models.sql import (
AreaField, DistanceField, GeomField, GMLField,
)
from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Area, Distance from django.contrib.gis.measure import Area, Distance
@ -13,11 +20,6 @@ from django.utils import six
class GeoQuerySet(QuerySet): class GeoQuerySet(QuerySet):
"The Geographic QuerySet." "The Geographic QuerySet."
### Methods overloaded from QuerySet ###
def __init__(self, model=None, query=None, using=None, hints=None):
super(GeoQuerySet, self).__init__(model=model, query=query, using=using, hints=hints)
self.query = query or GeoQuery(self.model)
### GeoQuerySet Methods ### ### GeoQuerySet Methods ###
def area(self, tolerance=0.05, **kwargs): def area(self, tolerance=0.05, **kwargs):
""" """
@ -26,7 +28,8 @@ class GeoQuerySet(QuerySet):
""" """
# Performing setup here rather than in `_spatial_attribute` so that # Performing setup here rather than in `_spatial_attribute` so that
# we can get the units for `AreaField`. # we can get the units for `AreaField`.
procedure_args, geo_field = self._spatial_setup('area', field_name=kwargs.get('field_name', None)) procedure_args, geo_field = self._spatial_setup(
'area', field_name=kwargs.get('field_name', None))
s = {'procedure_args': procedure_args, s = {'procedure_args': procedure_args,
'geo_field': geo_field, 'geo_field': geo_field,
'setup': False, 'setup': False,
@ -378,24 +381,8 @@ class GeoQuerySet(QuerySet):
if not isinstance(srid, six.integer_types): if not isinstance(srid, six.integer_types):
raise TypeError('An integer SRID must be provided.') raise TypeError('An integer SRID must be provided.')
field_name = kwargs.get('field_name', None) field_name = kwargs.get('field_name', None)
tmp, geo_field = self._spatial_setup('transform', field_name=field_name) self._spatial_setup('transform', field_name=field_name)
self.query.add_context('transformed_srid', srid)
# Getting the selection SQL for the given geographic field.
field_col = self._geocol_select(geo_field, field_name)
# Why cascading substitutions? Because spatial backends like
# Oracle and MySQL already require a function call to convert to text, thus
# when there's also a transformation we need to cascade the substitutions.
# For example, 'SDO_UTIL.TO_WKTGEOMETRY(SDO_CS.TRANSFORM( ... )'
geo_col = self.query.custom_select.get(geo_field, field_col)
# Setting the key for the field's column with the custom SELECT SQL to
# override the geometry column returned from the database.
custom_sel = '%s(%s, %s)' % (connections[self.db].ops.transform, geo_col, srid)
# TODO: Should we have this as an alias?
# custom_sel = '(%s(%s, %s)) AS %s' % (SpatialBackend.transform, geo_col, srid, qn(geo_field.name))
self.query.transformed_srid = srid # So other GeoQuerySet methods
self.query.custom_select[geo_field] = custom_sel
return self._clone() return self._clone()
def union(self, geom, **kwargs): def union(self, geom, **kwargs):
@ -433,7 +420,7 @@ class GeoQuerySet(QuerySet):
# Is there a geographic field in the model to perform this # Is there a geographic field in the model to perform this
# operation on? # operation on?
geo_field = self.query._geo_field(field_name) geo_field = self._geo_field(field_name)
if not geo_field: if not geo_field:
raise TypeError('%s output only available on GeometryFields.' % func) raise TypeError('%s output only available on GeometryFields.' % func)
@ -454,7 +441,7 @@ class GeoQuerySet(QuerySet):
returning their result to the caller of the function. returning their result to the caller of the function.
""" """
# Getting the field the geographic aggregate will be called on. # Getting the field the geographic aggregate will be called on.
geo_field = self.query._geo_field(field_name) geo_field = self._geo_field(field_name)
if not geo_field: if not geo_field:
raise TypeError('%s aggregate only available on GeometryFields.' % aggregate.name) raise TypeError('%s aggregate only available on GeometryFields.' % aggregate.name)
@ -509,11 +496,11 @@ class GeoQuerySet(QuerySet):
settings.setdefault('select_params', []) settings.setdefault('select_params', [])
connection = connections[self.db] connection = connections[self.db]
backend = connection.ops
# Performing setup for the spatial column, unless told not to. # Performing setup for the spatial column, unless told not to.
if settings.get('setup', True): if settings.get('setup', True):
default_args, geo_field = self._spatial_setup(att, desc=settings['desc'], field_name=field_name, default_args, geo_field = self._spatial_setup(
att, desc=settings['desc'], field_name=field_name,
geo_field_type=settings.get('geo_field_type', None)) geo_field_type=settings.get('geo_field_type', None))
for k, v in six.iteritems(default_args): for k, v in six.iteritems(default_args):
settings['procedure_args'].setdefault(k, v) settings['procedure_args'].setdefault(k, v)
@ -544,18 +531,19 @@ class GeoQuerySet(QuerySet):
# If the result of this function needs to be converted. # If the result of this function needs to be converted.
if settings.get('select_field', False): if settings.get('select_field', False):
sel_fld = settings['select_field'] select_field = settings['select_field']
if isinstance(sel_fld, GeomField) and backend.select:
self.query.custom_select[model_att] = backend.select
if connection.ops.oracle: if connection.ops.oracle:
sel_fld.empty_strings_allowed = False select_field.empty_strings_allowed = False
self.query.extra_select_fields[model_att] = sel_fld else:
select_field = Field()
# Finally, setting the extra selection attribute with # Finally, setting the extra selection attribute with
# the format string expanded with the stored procedure # the format string expanded with the stored procedure
# arguments. # arguments.
return self.extra(select={model_att: fmt % settings['procedure_args']}, self.query.add_annotation(
select_params=settings['select_params']) RawSQL(fmt % settings['procedure_args'], settings['select_params'], select_field),
model_att)
return self
def _distance_attribute(self, func, geom=None, tolerance=0.05, spheroid=False, **kwargs): def _distance_attribute(self, func, geom=None, tolerance=0.05, spheroid=False, **kwargs):
""" """
@ -616,8 +604,9 @@ class GeoQuerySet(QuerySet):
else: else:
# Getting whether this field is in units of degrees since the field may have # Getting whether this field is in units of degrees since the field may have
# been transformed via the `transform` GeoQuerySet method. # been transformed via the `transform` GeoQuerySet method.
if self.query.transformed_srid: srid = self.query.get_context('transformed_srid')
u, unit_name, s = get_srid_info(self.query.transformed_srid, connection) if srid:
u, unit_name, s = get_srid_info(srid, connection)
geodetic = unit_name.lower() in geo_field.geodetic_units geodetic = unit_name.lower() in geo_field.geodetic_units
if geodetic and not connection.features.supports_distance_geodetic: if geodetic and not connection.features.supports_distance_geodetic:
@ -627,20 +616,20 @@ class GeoQuerySet(QuerySet):
) )
if distance: if distance:
if self.query.transformed_srid: if srid:
# Setting the `geom_args` flag to false because we want to handle # Setting the `geom_args` flag to false because we want to handle
# transformation SQL here, rather than the way done by default # transformation SQL here, rather than the way done by default
# (which will transform to the original SRID of the field rather # (which will transform to the original SRID of the field rather
# than to what was transformed to). # than to what was transformed to).
geom_args = False geom_args = False
procedure_fmt = '%s(%%(geo_col)s, %s)' % (backend.transform, self.query.transformed_srid) procedure_fmt = '%s(%%(geo_col)s, %s)' % (backend.transform, srid)
if geom.srid is None or geom.srid == self.query.transformed_srid: if geom.srid is None or geom.srid == srid:
# If the geom parameter srid is None, it is assumed the coordinates # If the geom parameter srid is None, it is assumed the coordinates
# are in the transformed units. A placeholder is used for the # are in the transformed units. A placeholder is used for the
# geometry parameter. `GeomFromText` constructor is also needed # geometry parameter. `GeomFromText` constructor is also needed
# to wrap geom placeholder for SpatiaLite. # to wrap geom placeholder for SpatiaLite.
if backend.spatialite: if backend.spatialite:
procedure_fmt += ', %s(%%%%s, %s)' % (backend.from_text, self.query.transformed_srid) procedure_fmt += ', %s(%%%%s, %s)' % (backend.from_text, srid)
else: else:
procedure_fmt += ', %%s' procedure_fmt += ', %%s'
else: else:
@ -649,10 +638,11 @@ class GeoQuerySet(QuerySet):
# SpatiaLite also needs geometry placeholder wrapped in `GeomFromText` # SpatiaLite also needs geometry placeholder wrapped in `GeomFromText`
# constructor. # constructor.
if backend.spatialite: if backend.spatialite:
procedure_fmt += ', %s(%s(%%%%s, %s), %s)' % (backend.transform, backend.from_text, procedure_fmt += (', %s(%s(%%%%s, %s), %s)' % (
geom.srid, self.query.transformed_srid) backend.transform, backend.from_text,
geom.srid, srid))
else: else:
procedure_fmt += ', %s(%%%%s, %s)' % (backend.transform, self.query.transformed_srid) procedure_fmt += ', %s(%%%%s, %s)' % (backend.transform, srid)
else: else:
# `transform()` was not used on this GeoQuerySet. # `transform()` was not used on this GeoQuerySet.
procedure_fmt = '%(geo_col)s,%(geom)s' procedure_fmt = '%(geo_col)s,%(geom)s'
@ -743,22 +733,57 @@ class GeoQuerySet(QuerySet):
column. Takes into account if the geographic field is in a column. Takes into account if the geographic field is in a
ForeignKey relation to the current model. ForeignKey relation to the current model.
""" """
compiler = self.query.get_compiler(self.db)
opts = self.model._meta opts = self.model._meta
if geo_field not in opts.fields: if geo_field not in opts.fields:
# Is this operation going to be on a related geographic field? # Is this operation going to be on a related geographic field?
# If so, it'll have to be added to the select related information # If so, it'll have to be added to the select related information
# (e.g., if 'location__point' was given as the field name). # (e.g., if 'location__point' was given as the field name).
# Note: the operation really is defined as "must add select related!"
self.query.add_select_related([field_name]) self.query.add_select_related([field_name])
compiler = self.query.get_compiler(self.db) # Call pre_sql_setup() so that compiler.select gets populated.
compiler.pre_sql_setup() compiler.pre_sql_setup()
for (rel_table, rel_col), field in self.query.related_select_cols: for col, _, _ in compiler.select:
if field == geo_field: if col.output_field == geo_field:
return compiler._field_column(geo_field, rel_table) return col.as_sql(compiler, compiler.connection)[0]
raise ValueError("%r not in self.query.related_select_cols" % geo_field) raise ValueError("%r not in compiler's related_select_cols" % geo_field)
elif geo_field not in opts.local_fields: elif geo_field not in opts.local_fields:
# This geographic field is inherited from another model, so we have to # This geographic field is inherited from another model, so we have to
# use the db table for the _parent_ model instead. # use the db table for the _parent_ model instead.
parent_model = geo_field.model._meta.concrete_model parent_model = geo_field.model._meta.concrete_model
return self.query.get_compiler(self.db)._field_column(geo_field, parent_model._meta.db_table) return self._field_column(compiler, geo_field, parent_model._meta.db_table)
else: else:
return self.query.get_compiler(self.db)._field_column(geo_field) return self._field_column(compiler, geo_field)
# Private API utilities, subject to change.
def _geo_field(self, field_name=None):
"""
Returns the first Geometry field encountered or the one specified via
the `field_name` keyword. The `field_name` may be a string specifying
the geometry field on this GeoQuerySet's model, or a lookup string
to a geometry field via a ForeignKey relation.
"""
if field_name is None:
# Incrementing until the first geographic field is found.
for field in self.model._meta.fields:
if isinstance(field, GeometryField):
return field
return False
else:
# Otherwise, check by the given field name -- which may be
# a lookup to a _related_ geographic field.
return GISLookup._check_geo_field(self.model._meta, field_name)
def _field_column(self, compiler, field, table_alias=None, column=None):
"""
Helper function that returns the database column for the given field.
The table and column are returned (quoted) in the proper format, e.g.,
`"geoapp_city"."point"`. If `table_alias` is not specified, the
database table associated with the model of this `GeoQuerySet` will be
used. If `column` is specified, it will be used instead of the value
in `field.column`.
"""
if table_alias is None:
table_alias = compiler.query.get_meta().db_table
return "%s.%s" % (compiler.quote_name_unless_alias(table_alias),
compiler.connection.ops.quote_name(column or field.column))

View File

@ -1,6 +1,5 @@
from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField, GMLField from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField, GMLField
from django.contrib.gis.db.models.sql.query import GeoQuery
__all__ = [ __all__ = [
'AreaField', 'DistanceField', 'GeomField', 'GMLField', 'GeoQuery', 'AreaField', 'DistanceField', 'GeomField', 'GMLField'
] ]

View File

@ -1,240 +0,0 @@
from django.db.backends.utils import truncate_name
from django.db.models.sql import compiler
from django.utils import six
SQLCompiler = compiler.SQLCompiler
class GeoSQLCompiler(compiler.SQLCompiler):
def get_columns(self, with_aliases=False):
"""
Return the list of columns to use in the select statement. If no
columns have been specified, returns all columns relating to fields in
the model.
If 'with_aliases' is true, any column names that are duplicated
(without the table names) are given unique aliases. This is needed in
some cases to avoid ambiguity with nested queries.
This routine is overridden from Query to handle customized selection of
geometry columns.
"""
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
for alias, col in six.iteritems(self.query.extra_select)]
params = []
aliases = set(self.query.extra_select.keys())
if with_aliases:
col_aliases = aliases.copy()
else:
col_aliases = set()
if self.query.select:
only_load = self.deferred_to_columns()
# This loop customized for GeoQuery.
for col, field in self.query.select:
if isinstance(col, (list, tuple)):
alias, column = col
table = self.query.alias_map[alias].table_name
if table in only_load and column not in only_load[table]:
continue
r = self.get_field_select(field, alias, column)
if with_aliases:
if col[1] in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias)
col_aliases.add(c_alias)
else:
result.append('%s AS %s' % (r, qn2(col[1])))
aliases.add(r)
col_aliases.add(col[1])
else:
result.append(r)
aliases.add(r)
col_aliases.add(col[1])
else:
col_sql, col_params = col.as_sql(self, self.connection)
result.append(col_sql)
params.extend(col_params)
if hasattr(col, 'alias'):
aliases.add(col.alias)
col_aliases.add(col.alias)
elif self.query.default_cols:
cols, new_aliases = self.get_default_columns(with_aliases,
col_aliases)
result.extend(cols)
aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length()
for alias, annotation in self.query.annotation_select.items():
agg_sql, agg_params = self.compile(annotation)
if alias is None:
result.append(agg_sql)
else:
result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
params.extend(agg_params)
# This loop customized for GeoQuery.
for (table, col), field in self.query.related_select_cols:
r = self.get_field_select(field, table, col)
if with_aliases and col in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias)
col_aliases.add(c_alias)
else:
result.append(r)
aliases.add(r)
col_aliases.add(col)
self._select_aliases = aliases
return result, params
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, from_parent=None):
"""
Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via
select_related), in which case "opts" and "start_alias" will be given
to provide a starting point for the traversal.
Returns a list of strings, quoted appropriately for use in SQL
directly, as well as a set of aliases used in the select statement (if
'as_pairs' is True, returns a list of (alias, col_name) pairs instead
of strings as the first component and None as the second component).
This routine is overridden from Query to handle customized selection of
geometry columns.
"""
result = []
if opts is None:
opts = self.query.get_meta()
aliases = set()
only_load = self.deferred_to_columns()
seen = self.query.included_inherited_models.copy()
if start_alias:
seen[None] = start_alias
for field in opts.concrete_fields:
model = field.model._meta.concrete_model
if model is opts.model:
model = None
if from_parent and model is not None and issubclass(from_parent, model):
# Avoid loading data for already loaded parents.
continue
alias = self.query.join_parent_model(opts, model, start_alias, seen)
table = self.query.alias_map[alias].table_name
if table in only_load and field.column not in only_load[table]:
continue
if as_pairs:
result.append((alias, field))
aliases.add(alias)
continue
# This part of the function is customized for GeoQuery. We
# see if there was any custom selection specified in the
# dictionary, and set up the selection format appropriately.
field_sel = self.get_field_select(field, alias)
if with_aliases and field.column in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (field_sel, c_alias))
col_aliases.add(c_alias)
aliases.add(c_alias)
else:
r = field_sel
result.append(r)
aliases.add(r)
if with_aliases:
col_aliases.add(field.column)
return result, aliases
def get_converters(self, fields):
converters = super(GeoSQLCompiler, self).get_converters(fields)
for i, alias in enumerate(self.query.extra_select):
field = self.query.extra_select_fields.get(alias)
if field:
backend_converters = self.connection.ops.get_db_converters(field.get_internal_type())
converters[i] = (backend_converters, [field.from_db_value], field)
return converters
#### Routines unique to GeoQuery ####
def get_extra_select_format(self, alias):
sel_fmt = '%s'
if hasattr(self.query, 'custom_select') and alias in self.query.custom_select:
sel_fmt = sel_fmt % self.query.custom_select[alias]
return sel_fmt
def get_field_select(self, field, alias=None, column=None):
"""
Returns the SELECT SQL string for the given field. Figures out
if any custom selection SQL is needed for the column The `alias`
keyword may be used to manually specify the database table where
the column exists, if not in the model associated with this
`GeoQuery`. Similarly, `column` may be used to specify the exact
column name, rather than using the `column` attribute on `field`.
"""
sel_fmt = self.get_select_format(field)
if field in self.query.custom_select:
field_sel = sel_fmt % self.query.custom_select[field]
else:
field_sel = sel_fmt % self._field_column(field, alias, column)
return field_sel
def get_select_format(self, fld):
"""
Returns the selection format string, depending on the requirements
of the spatial backend. For example, Oracle and MySQL require custom
selection formats in order to retrieve geometries in OGC WKT. For all
other fields a simple '%s' format string is returned.
"""
if self.connection.ops.select and hasattr(fld, 'geom_type'):
# This allows operations to be done on fields in the SELECT,
# overriding their values -- used by the Oracle and MySQL
# spatial backends to get database values as WKT, and by the
# `transform` method.
sel_fmt = self.connection.ops.select
# Because WKT doesn't contain spatial reference information,
# the SRID is prefixed to the returned WKT to ensure that the
# transformed geometries have an SRID different than that of the
# field -- this is only used by `transform` for Oracle and
# SpatiaLite backends.
if self.query.transformed_srid and (self.connection.ops.oracle or
self.connection.ops.spatialite):
sel_fmt = "'SRID=%d;'||%s" % (self.query.transformed_srid, sel_fmt)
else:
sel_fmt = '%s'
return sel_fmt
# Private API utilities, subject to change.
def _field_column(self, field, table_alias=None, column=None):
"""
Helper function that returns the database column for the given field.
The table and column are returned (quoted) in the proper format, e.g.,
`"geoapp_city"."point"`. If `table_alias` is not specified, the
database table associated with the model of this `GeoQuery` will be
used. If `column` is specified, it will be used instead of the value
in `field.column`.
"""
if table_alias is None:
table_alias = self.query.get_meta().db_table
return "%s.%s" % (self.quote_name_unless_alias(table_alias),
self.connection.ops.quote_name(column or field.column))
class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
pass
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
pass
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
pass

View File

@ -3,6 +3,7 @@ This module holds simple classes to convert geospatial values from the
database. database.
""" """
from django.contrib.gis.db.models.fields import GeoSelectFormatMixin
from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Area, Distance from django.contrib.gis.measure import Area, Distance
@ -10,13 +11,19 @@ from django.contrib.gis.measure import Area, Distance
class BaseField(object): class BaseField(object):
empty_strings_allowed = True empty_strings_allowed = True
def get_db_converters(self, connection):
return [self.from_db_value]
def select_format(self, compiler, sql, params):
return sql, params
class AreaField(BaseField): class AreaField(BaseField):
"Wrapper for Area values." "Wrapper for Area values."
def __init__(self, area_att): def __init__(self, area_att):
self.area_att = area_att self.area_att = area_att
def from_db_value(self, value, connection): def from_db_value(self, value, connection, context):
if value is not None: if value is not None:
value = Area(**{self.area_att: value}) value = Area(**{self.area_att: value})
return value return value
@ -30,7 +37,7 @@ class DistanceField(BaseField):
def __init__(self, distance_att): def __init__(self, distance_att):
self.distance_att = distance_att self.distance_att = distance_att
def from_db_value(self, value, connection): def from_db_value(self, value, connection, context):
if value is not None: if value is not None:
value = Distance(**{self.distance_att: value}) value = Distance(**{self.distance_att: value})
return value return value
@ -39,12 +46,15 @@ class DistanceField(BaseField):
return 'DistanceField' return 'DistanceField'
class GeomField(BaseField): class GeomField(GeoSelectFormatMixin, BaseField):
""" """
Wrapper for Geometry values. It is a lightweight alternative to Wrapper for Geometry values. It is a lightweight alternative to
using GeometryField (which requires an SQL query upon instantiation). using GeometryField (which requires an SQL query upon instantiation).
""" """
def from_db_value(self, value, connection): # Hacky marker for get_db_converters()
geom_type = None
def from_db_value(self, value, connection, context):
if value is not None: if value is not None:
value = Geometry(value) value = Geometry(value)
return value return value
@ -61,5 +71,5 @@ class GMLField(BaseField):
def get_internal_type(self): def get_internal_type(self):
return 'GMLField' return 'GMLField'
def from_db_value(self, value, connection): def from_db_value(self, value, connection, context):
return value return value

View File

@ -1,65 +0,0 @@
from django.db import connections
from django.db.models.query import sql
from django.db.models.sql.constants import QUERY_TERMS
from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.lookups import GISLookup
from django.contrib.gis.db.models import aggregates as gis_aggregates
from django.contrib.gis.db.models.sql.conversion import GeomField
class GeoQuery(sql.Query):
"""
A single spatial SQL query.
"""
# Overriding the valid query terms.
query_terms = QUERY_TERMS | set(GeometryField.class_lookups.keys())
compiler = 'GeoSQLCompiler'
#### Methods overridden from the base Query class ####
def __init__(self, model):
super(GeoQuery, self).__init__(model)
# The following attributes are customized for the GeoQuerySet.
# The SpatialBackend classes contain backend-specific routines and functions.
self.custom_select = {}
self.transformed_srid = None
self.extra_select_fields = {}
def clone(self, *args, **kwargs):
obj = super(GeoQuery, self).clone(*args, **kwargs)
# Customized selection dictionary and transformed srid flag have
# to also be added to obj.
obj.custom_select = self.custom_select.copy()
obj.transformed_srid = self.transformed_srid
obj.extra_select_fields = self.extra_select_fields.copy()
return obj
def get_aggregation(self, using, force_subq=False):
# Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery.
connection = connections[using]
for alias, annotation in self.annotation_select.items():
if isinstance(annotation, gis_aggregates.GeoAggregate):
if not getattr(annotation, 'is_extent', False) or connection.ops.oracle:
self.extra_select_fields[alias] = GeomField()
return super(GeoQuery, self).get_aggregation(using, force_subq)
# Private API utilities, subject to change.
def _geo_field(self, field_name=None):
"""
Returns the first Geometry field encountered; or specified via the
`field_name` keyword. The `field_name` may be a string specifying
the geometry field on this GeoQuery's model, or a lookup string
to a geometry field via a ForeignKey relation.
"""
if field_name is None:
# Incrementing until the first geographic field is found.
for fld in self.model._meta.fields:
if isinstance(fld, GeometryField):
return fld
return False
else:
# Otherwise, check by the given field name -- which may be
# a lookup to a _related_ geographic field.
return GISLookup._check_geo_field(self.model._meta, field_name)

View File

@ -231,15 +231,6 @@ class RelatedGeoModelTest(TestCase):
self.assertIn('Aurora', names) self.assertIn('Aurora', names)
self.assertIn('Kecksburg', names) self.assertIn('Kecksburg', names)
def test11_geoquery_pickle(self):
"Ensuring GeoQuery objects are unpickled correctly. See #10839."
import pickle
from django.contrib.gis.db.models.sql import GeoQuery
qs = City.objects.all()
q_str = pickle.dumps(qs.query)
q = pickle.loads(q_str)
self.assertEqual(GeoQuery, q.__class__)
# TODO: fix on Oracle -- get the following error because the SQL is ordered # TODO: fix on Oracle -- get the following error because the SQL is ordered
# by a geometry object, which Oracle apparently doesn't like: # by a geometry object, which Oracle apparently doesn't like:
# ORA-22901: cannot compare nested table or VARRAY or LOB attributes of an object type # ORA-22901: cannot compare nested table or VARRAY or LOB attributes of an object type

View File

@ -1262,7 +1262,7 @@ class BaseDatabaseOperations(object):
second = timezone.make_aware(second, tz) second = timezone.make_aware(second, tz)
return [first, second] return [first, second]
def get_db_converters(self, internal_type): def get_db_converters(self, expression):
"""Get a list of functions needed to convert field data. """Get a list of functions needed to convert field data.
Some field types on some backends do not provide data in the correct Some field types on some backends do not provide data in the correct
@ -1270,7 +1270,7 @@ class BaseDatabaseOperations(object):
""" """
return [] return []
def convert_durationfield_value(self, value, field): def convert_durationfield_value(self, value, expression, context):
if value is not None: if value is not None:
value = str(decimal.Decimal(value) / decimal.Decimal(1000000)) value = str(decimal.Decimal(value) / decimal.Decimal(1000000))
value = parse_duration(value) value = parse_duration(value)

View File

@ -302,7 +302,7 @@ class DatabaseOperations(BaseDatabaseOperations):
columns. If no ordering would otherwise be applied, we don't want any columns. If no ordering would otherwise be applied, we don't want any
implicit sorting going on. implicit sorting going on.
""" """
return ["NULL"] return [(None, ("NULL", [], 'asc', False))]
def fulltext_search_sql(self, field_name): def fulltext_search_sql(self, field_name):
return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name
@ -387,8 +387,9 @@ class DatabaseOperations(BaseDatabaseOperations):
return 'POW(%s)' % ','.join(sub_expressions) return 'POW(%s)' % ','.join(sub_expressions)
return super(DatabaseOperations, self).combine_expression(connector, sub_expressions) return super(DatabaseOperations, self).combine_expression(connector, sub_expressions)
def get_db_converters(self, internal_type): def get_db_converters(self, expression):
converters = super(DatabaseOperations, self).get_db_converters(internal_type) converters = super(DatabaseOperations, self).get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type in ['BooleanField', 'NullBooleanField']: if internal_type in ['BooleanField', 'NullBooleanField']:
converters.append(self.convert_booleanfield_value) converters.append(self.convert_booleanfield_value)
if internal_type == 'UUIDField': if internal_type == 'UUIDField':
@ -397,17 +398,17 @@ class DatabaseOperations(BaseDatabaseOperations):
converters.append(self.convert_textfield_value) converters.append(self.convert_textfield_value)
return converters return converters
def convert_booleanfield_value(self, value, field): def convert_booleanfield_value(self, value, expression, context):
if value in (0, 1): if value in (0, 1):
value = bool(value) value = bool(value)
return value return value
def convert_uuidfield_value(self, value, field): def convert_uuidfield_value(self, value, expression, context):
if value is not None: if value is not None:
value = uuid.UUID(value) value = uuid.UUID(value)
return value return value
def convert_textfield_value(self, value, field): def convert_textfield_value(self, value, expression, context):
if value is not None: if value is not None:
value = force_text(value) value = force_text(value)
return value return value

View File

@ -268,8 +268,9 @@ WHEN (new.%(col_name)s IS NULL)
sql = field_name # Cast to DATE removes sub-second precision. sql = field_name # Cast to DATE removes sub-second precision.
return sql, [] return sql, []
def get_db_converters(self, internal_type): def get_db_converters(self, expression):
converters = super(DatabaseOperations, self).get_db_converters(internal_type) converters = super(DatabaseOperations, self).get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == 'TextField': if internal_type == 'TextField':
converters.append(self.convert_textfield_value) converters.append(self.convert_textfield_value)
elif internal_type == 'BinaryField': elif internal_type == 'BinaryField':
@ -285,28 +286,29 @@ WHEN (new.%(col_name)s IS NULL)
converters.append(self.convert_empty_values) converters.append(self.convert_empty_values)
return converters return converters
def convert_empty_values(self, value, field): def convert_empty_values(self, value, expression, context):
# Oracle stores empty strings as null. We need to undo this in # Oracle stores empty strings as null. We need to undo this in
# order to adhere to the Django convention of using the empty # order to adhere to the Django convention of using the empty
# string instead of null, but only if the field accepts the # string instead of null, but only if the field accepts the
# empty string. # empty string.
field = expression.output_field
if value is None and field.empty_strings_allowed: if value is None and field.empty_strings_allowed:
value = '' value = ''
if field.get_internal_type() == 'BinaryField': if field.get_internal_type() == 'BinaryField':
value = b'' value = b''
return value return value
def convert_textfield_value(self, value, field): def convert_textfield_value(self, value, expression, context):
if isinstance(value, Database.LOB): if isinstance(value, Database.LOB):
value = force_text(value.read()) value = force_text(value.read())
return value return value
def convert_binaryfield_value(self, value, field): def convert_binaryfield_value(self, value, expression, context):
if isinstance(value, Database.LOB): if isinstance(value, Database.LOB):
value = force_bytes(value.read()) value = force_bytes(value.read())
return value return value
def convert_booleanfield_value(self, value, field): def convert_booleanfield_value(self, value, expression, context):
if value in (1, 0): if value in (1, 0):
value = bool(value) value = bool(value)
return value return value
@ -314,16 +316,16 @@ WHEN (new.%(col_name)s IS NULL)
# cx_Oracle always returns datetime.datetime objects for # cx_Oracle always returns datetime.datetime objects for
# DATE and TIMESTAMP columns, but Django wants to see a # DATE and TIMESTAMP columns, but Django wants to see a
# python datetime.date, .time, or .datetime. # python datetime.date, .time, or .datetime.
def convert_datefield_value(self, value, field): def convert_datefield_value(self, value, expression, context):
if isinstance(value, Database.Timestamp): if isinstance(value, Database.Timestamp):
return value.date() return value.date()
def convert_timefield_value(self, value, field): def convert_timefield_value(self, value, expression, context):
if isinstance(value, Database.Timestamp): if isinstance(value, Database.Timestamp):
value = value.time() value = value.time()
return value return value
def convert_uuidfield_value(self, value, field): def convert_uuidfield_value(self, value, expression, context):
if value is not None: if value is not None:
value = uuid.UUID(value) value = uuid.UUID(value)
return value return value

View File

@ -269,8 +269,9 @@ class DatabaseOperations(BaseDatabaseOperations):
return six.text_type(value) return six.text_type(value)
def get_db_converters(self, internal_type): def get_db_converters(self, expression):
converters = super(DatabaseOperations, self).get_db_converters(internal_type) converters = super(DatabaseOperations, self).get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == 'DateTimeField': if internal_type == 'DateTimeField':
converters.append(self.convert_datetimefield_value) converters.append(self.convert_datetimefield_value)
elif internal_type == 'DateField': elif internal_type == 'DateField':
@ -283,25 +284,25 @@ class DatabaseOperations(BaseDatabaseOperations):
converters.append(self.convert_uuidfield_value) converters.append(self.convert_uuidfield_value)
return converters return converters
def convert_decimalfield_value(self, value, field): def convert_decimalfield_value(self, value, expression, context):
return backend_utils.typecast_decimal(field.format_number(value)) return backend_utils.typecast_decimal(expression.output_field.format_number(value))
def convert_datefield_value(self, value, field): def convert_datefield_value(self, value, expression, context):
if value is not None and not isinstance(value, datetime.date): if value is not None and not isinstance(value, datetime.date):
value = parse_date(value) value = parse_date(value)
return value return value
def convert_datetimefield_value(self, value, field): def convert_datetimefield_value(self, value, expression, context):
if value is not None and not isinstance(value, datetime.datetime): if value is not None and not isinstance(value, datetime.datetime):
value = parse_datetime_with_timezone_support(value) value = parse_datetime_with_timezone_support(value)
return value return value
def convert_timefield_value(self, value, field): def convert_timefield_value(self, value, expression, context):
if value is not None and not isinstance(value, datetime.time): if value is not None and not isinstance(value, datetime.time):
value = parse_time(value) value = parse_time(value)
return value return value
def convert_uuidfield_value(self, value, field): def convert_uuidfield_value(self, value, expression, context):
if value is not None: if value is not None:
value = uuid.UUID(value) value = uuid.UUID(value)
return value return value

View File

@ -87,7 +87,7 @@ class Avg(Aggregate):
def __init__(self, expression, **extra): def __init__(self, expression, **extra):
super(Avg, self).__init__(expression, output_field=FloatField(), **extra) super(Avg, self).__init__(expression, output_field=FloatField(), **extra)
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
if value is None: if value is None:
return value return value
return float(value) return float(value)
@ -105,7 +105,7 @@ class Count(Aggregate):
super(Count, self).__init__( super(Count, self).__init__(
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra) expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
if value is None: if value is None:
return 0 return 0
return int(value) return int(value)
@ -128,7 +128,7 @@ class StdDev(Aggregate):
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra) super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
if value is None: if value is None:
return value return value
return float(value) return float(value)
@ -146,7 +146,7 @@ class Variance(Aggregate):
self.function = 'VAR_SAMP' if sample else 'VAR_POP' self.function = 'VAR_SAMP' if sample else 'VAR_POP'
super(Variance, self).__init__(expression, output_field=FloatField(), **extra) super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
if value is None: if value is None:
return value return value
return float(value) return float(value)

View File

@ -127,7 +127,7 @@ class ExpressionNode(CombinableMixin):
is_summary = False is_summary = False
def get_db_converters(self, connection): def get_db_converters(self, connection):
return [self.convert_value] return [self.convert_value] + self.output_field.get_db_converters(connection)
def __init__(self, output_field=None): def __init__(self, output_field=None):
self._output_field = output_field self._output_field = output_field
@ -240,7 +240,7 @@ class ExpressionNode(CombinableMixin):
raise FieldError( raise FieldError(
"Expression contains mixed types. You must set output_field") "Expression contains mixed types. You must set output_field")
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
""" """
Expressions provide their own converters because users have the option Expressions provide their own converters because users have the option
of manually specifying the output_field which may be a different type of manually specifying the output_field which may be a different type
@ -305,6 +305,8 @@ class ExpressionNode(CombinableMixin):
return self return self
def get_group_by_cols(self): def get_group_by_cols(self):
if not self.contains_aggregate:
return [self]
cols = [] cols = []
for source in self.get_source_expressions(): for source in self.get_source_expressions():
cols.extend(source.get_group_by_cols()) cols.extend(source.get_group_by_cols())
@ -490,6 +492,9 @@ class Value(ExpressionNode):
return 'NULL', [] return 'NULL', []
return '%s', [self.value] return '%s', [self.value]
def get_group_by_cols(self):
return []
class DurationValue(Value): class DurationValue(Value):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
@ -499,6 +504,37 @@ class DurationValue(Value):
return connection.ops.date_interval_sql(self.value) return connection.ops.date_interval_sql(self.value)
class RawSQL(ExpressionNode):
def __init__(self, sql, params, output_field=None):
if output_field is None:
output_field = fields.Field()
self.sql, self.params = sql, params
super(RawSQL, self).__init__(output_field=output_field)
def as_sql(self, compiler, connection):
return '(%s)' % self.sql, self.params
def get_group_by_cols(self):
return [self]
class Random(ExpressionNode):
def __init__(self):
super(Random, self).__init__(output_field=fields.FloatField())
def as_sql(self, compiler, connection):
return connection.ops.random_function_sql(), []
class ColIndexRef(ExpressionNode):
def __init__(self, idx):
self.idx = idx
super(ColIndexRef, self).__init__()
def as_sql(self, compiler, connection):
return str(self.idx), []
class Col(ExpressionNode): class Col(ExpressionNode):
def __init__(self, alias, target, source=None): def __init__(self, alias, target, source=None):
if source is None: if source is None:
@ -516,6 +552,9 @@ class Col(ExpressionNode):
def get_group_by_cols(self): def get_group_by_cols(self):
return [self] return [self]
def get_db_converters(self, connection):
return self.output_field.get_db_converters(connection)
class Ref(ExpressionNode): class Ref(ExpressionNode):
""" """
@ -537,7 +576,7 @@ class Ref(ExpressionNode):
return self return self
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
return "%s" % compiler.quote_name_unless_alias(self.refs), [] return "%s" % connection.ops.quote_name(self.refs), []
def get_group_by_cols(self): def get_group_by_cols(self):
return [self] return [self]
@ -581,7 +620,7 @@ class Date(ExpressionNode):
copy.lookup_type = self.lookup_type copy.lookup_type = self.lookup_type
return copy return copy
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
value = value.date() value = value.date()
return value return value
@ -629,7 +668,7 @@ class DateTime(ExpressionNode):
copy.tzname = self.tzname copy.tzname = self.tzname
return copy return copy
def convert_value(self, value, connection): def convert_value(self, value, connection, context):
if settings.USE_TZ: if settings.USE_TZ:
if value is None: if value is None:
raise ValueError( raise ValueError(

View File

@ -333,6 +333,28 @@ class Field(RegisterLookupMixin):
] ]
return [] return []
def get_col(self, alias, source=None):
if source is None:
source = self
if alias != self.model._meta.db_table or source != self:
from django.db.models.expressions import Col
return Col(alias, self, source)
else:
return self.cached_col
@cached_property
def cached_col(self):
from django.db.models.expressions import Col
return Col(self.model._meta.db_table, self)
def select_format(self, compiler, sql, params):
"""
Custom format for select clauses. For example, GIS columns need to be
selected as AsText(table.col) on MySQL as the table.col data can't be used
by Django.
"""
return sql, params
def deconstruct(self): def deconstruct(self):
""" """
Returns enough information to recreate the field as a 4-tuple: Returns enough information to recreate the field as a 4-tuple:

View File

@ -15,7 +15,6 @@ from django.db.models.fields import (AutoField, Field, IntegerField,
from django.db.models.lookups import IsNull from django.db.models.lookups import IsNull
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.db.models.query_utils import PathInfo from django.db.models.query_utils import PathInfo
from django.db.models.expressions import Col
from django.utils.encoding import force_text, smart_text from django.utils.encoding import force_text, smart_text
from django.utils import six from django.utils import six
from django.utils.deprecation import RemovedInDjango20Warning from django.utils.deprecation import RemovedInDjango20Warning
@ -1738,26 +1737,26 @@ class ForeignObject(RelatedField):
[source.name for source in sources], raw_value), [source.name for source in sources], raw_value),
AND) AND)
elif lookup_type == 'isnull': elif lookup_type == 'isnull':
root_constraint.add(IsNull(Col(alias, targets[0], sources[0]), raw_value), AND) root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND)
elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte'] elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
and not is_multicolumn)): and not is_multicolumn)):
value = get_normalized_value(raw_value) value = get_normalized_value(raw_value)
for target, source, val in zip(targets, sources, value): for target, source, val in zip(targets, sources, value):
lookup_class = target.get_lookup(lookup_type) lookup_class = target.get_lookup(lookup_type)
root_constraint.add( root_constraint.add(
lookup_class(Col(alias, target, source), val), AND) lookup_class(target.get_col(alias, source), val), AND)
elif lookup_type in ['range', 'in'] and not is_multicolumn: elif lookup_type in ['range', 'in'] and not is_multicolumn:
values = [get_normalized_value(value) for value in raw_value] values = [get_normalized_value(value) for value in raw_value]
value = [val[0] for val in values] value = [val[0] for val in values]
lookup_class = targets[0].get_lookup(lookup_type) lookup_class = targets[0].get_lookup(lookup_type)
root_constraint.add(lookup_class(Col(alias, targets[0], sources[0]), value), AND) root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND)
elif lookup_type == 'in': elif lookup_type == 'in':
values = [get_normalized_value(value) for value in raw_value] values = [get_normalized_value(value) for value in raw_value]
for value in values: for value in values:
value_constraint = constraint_class() value_constraint = constraint_class()
for source, target, val in zip(sources, targets, value): for source, target, val in zip(sources, targets, value):
lookup_class = target.get_lookup('exact') lookup_class = target.get_lookup('exact')
lookup = lookup_class(Col(alias, target, source), val) lookup = lookup_class(target.get_col(alias, source), val)
value_constraint.add(lookup, AND) value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR) root_constraint.add(value_constraint, OR)
else: else:

View File

@ -13,8 +13,7 @@ from django.db import (connections, router, transaction, IntegrityError,
DJANGO_VERSION_PICKLE_KEY) DJANGO_VERSION_PICKLE_KEY)
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields import AutoField, Empty from django.db.models.fields import AutoField, Empty
from django.db.models.query_utils import (Q, select_related_descend, from django.db.models.query_utils import Q, deferred_class_factory, InvalidQuery
deferred_class_factory, InvalidQuery)
from django.db.models.deletion import Collector from django.db.models.deletion import Collector
from django.db.models.sql.constants import CURSOR from django.db.models.sql.constants import CURSOR
from django.db.models import sql from django.db.models import sql
@ -233,76 +232,34 @@ class QuerySet(object):
An iterator over the results from applying this QuerySet to the An iterator over the results from applying this QuerySet to the
database. database.
""" """
fill_cache = False
if connections[self.db].features.supports_select_related:
fill_cache = self.query.select_related
if isinstance(fill_cache, dict):
requested = fill_cache
else:
requested = None
max_depth = self.query.max_depth
extra_select = list(self.query.extra_select)
annotation_select = list(self.query.annotation_select)
only_load = self.query.get_loaded_field_names()
fields = self.model._meta.concrete_fields
load_fields = []
# If only/defer clauses have been specified,
# build the list of fields that are to be loaded.
if only_load:
for field in self.model._meta.concrete_fields:
model = field.model._meta.model
try:
if field.name in only_load[model]:
# Add a field that has been explicitly included
load_fields.append(field.name)
except KeyError:
# Model wasn't explicitly listed in the only_load table
# Therefore, we need to load all fields from this model
load_fields.append(field.name)
skip = None
if load_fields:
# Some fields have been deferred, so we have to initialize
# via keyword arguments.
skip = set()
init_list = []
for field in fields:
if field.name not in load_fields:
skip.add(field.attname)
else:
init_list.append(field.attname)
model_cls = deferred_class_factory(self.model, skip)
else:
model_cls = self.model
init_list = [f.attname for f in fields]
# Cache db and model outside the loop
db = self.db db = self.db
compiler = self.query.get_compiler(using=db) compiler = self.query.get_compiler(using=db)
index_start = len(extra_select) # Execute the query. This will also fill compiler.select, klass_info,
annotation_start = index_start + len(init_list) # and annotations.
results = compiler.execute_sql()
if fill_cache: select, klass_info, annotation_col_map = (compiler.select, compiler.klass_info,
klass_info = get_klass_info(model_cls, max_depth=max_depth, compiler.annotation_col_map)
requested=requested, only_load=only_load) if klass_info is None:
for row in compiler.results_iter(): return
if fill_cache: model_cls = klass_info['model']
obj, _ = get_cached_row(row, index_start, db, klass_info, select_fields = klass_info['select_fields']
offset=len(annotation_select)) model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
else: init_list = [f[0].output_field.attname
obj = model_cls.from_db(db, init_list, row[index_start:annotation_start]) for f in select[model_fields_start:model_fields_end]]
if len(init_list) != len(model_cls._meta.concrete_fields):
if extra_select: init_set = set(init_list)
for i, k in enumerate(extra_select): skip = [f.attname for f in model_cls._meta.concrete_fields
setattr(obj, k, row[i]) if f.attname not in init_set]
model_cls = deferred_class_factory(model_cls, skip)
# Add the annotations to the model related_populators = get_related_populators(klass_info, select, db)
if annotation_select: for row in compiler.results_iter(results):
for i, annotation in enumerate(annotation_select): obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end])
setattr(obj, annotation, row[i + annotation_start]) if related_populators:
for rel_populator in related_populators:
rel_populator.populate(row, obj)
if annotation_col_map:
for attr_name, col_pos in annotation_col_map.items():
setattr(obj, attr_name, row[col_pos])
# Add the known related objects to the model, if there are any # Add the known related objects to the model, if there are any
if self._known_related_objects: if self._known_related_objects:
@ -1032,11 +989,8 @@ class QuerySet(object):
""" """
Prepare the query for computing a result that contains aggregate annotations. Prepare the query for computing a result that contains aggregate annotations.
""" """
opts = self.model._meta
if self.query.group_by is None: if self.query.group_by is None:
field_names = [f.attname for f in opts.concrete_fields] self.query.group_by = True
self.query.add_fields(field_names, False)
self.query.set_group_by()
def _prepare(self): def _prepare(self):
return self return self
@ -1135,9 +1089,11 @@ class ValuesQuerySet(QuerySet):
Called by the _clone() method after initializing the rest of the Called by the _clone() method after initializing the rest of the
instance. instance.
""" """
if self.query.group_by is True:
self.query.add_fields([f.attname for f in self.model._meta.concrete_fields], False)
self.query.set_group_by()
self.query.clear_deferred_loading() self.query.clear_deferred_loading()
self.query.clear_select_fields() self.query.clear_select_fields()
if self._fields: if self._fields:
self.extra_names = [] self.extra_names = []
self.annotation_names = [] self.annotation_names = []
@ -1246,11 +1202,12 @@ class ValuesQuerySet(QuerySet):
class ValuesListQuerySet(ValuesQuerySet): class ValuesListQuerySet(ValuesQuerySet):
def iterator(self): def iterator(self):
compiler = self.query.get_compiler(self.db)
if self.flat and len(self._fields) == 1: if self.flat and len(self._fields) == 1:
for row in self.query.get_compiler(self.db).results_iter(): for row in compiler.results_iter():
yield row[0] yield row[0]
elif not self.query.extra_select and not self.query.annotation_select: elif not self.query.extra_select and not self.query.annotation_select:
for row in self.query.get_compiler(self.db).results_iter(): for row in compiler.results_iter():
yield tuple(row) yield tuple(row)
else: else:
# When extra(select=...) or an annotation is involved, the extra # When extra(select=...) or an annotation is involved, the extra
@ -1269,7 +1226,7 @@ class ValuesListQuerySet(ValuesQuerySet):
else: else:
fields = names fields = names
for row in self.query.get_compiler(self.db).results_iter(): for row in compiler.results_iter():
data = dict(zip(names, row)) data = dict(zip(names, row))
yield tuple(data[f] for f in fields) yield tuple(data[f] for f in fields)
@ -1281,244 +1238,6 @@ class ValuesListQuerySet(ValuesQuerySet):
return clone return clone
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
only_load=None, from_parent=None):
"""
Helper function that recursively returns an information for a klass, to be
used in get_cached_row. It exists just to compute this information only
once for entire queryset. Otherwise it would be computed for each row, which
leads to poor performance on large querysets.
Arguments:
* klass - the class to retrieve (and instantiate)
* max_depth - the maximum depth to which a select_related()
relationship should be explored.
* cur_depth - the current depth in the select_related() tree.
Used in recursive calls to determine if we should dig deeper.
* requested - A dictionary describing the select_related() tree
that is to be retrieved. keys are field names; values are
dictionaries describing the keys on that related object that
are themselves to be select_related().
* only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed.
* from_parent - the parent model used to get to this model
Note that when travelling from parent to child, we will only load child
fields which aren't in the parent.
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
return None
if only_load:
load_fields = only_load.get(klass) or set()
# When we create the object, we will also be creating populating
# all the parent classes, so traverse the parent classes looking
# for fields that must be included on load.
for parent in klass._meta.get_parent_list():
fields = only_load.get(parent)
if fields:
load_fields.update(fields)
else:
load_fields = None
if load_fields:
# Handle deferred fields.
skip = set()
init_list = []
# Build the list of fields that *haven't* been requested
for field in klass._meta.concrete_fields:
model = field.model._meta.concrete_model
if from_parent and model and issubclass(from_parent, model):
# Avoid loading fields already loaded for parent model for
# child models.
continue
elif field.name not in load_fields:
skip.add(field.attname)
else:
init_list.append(field.attname)
# Retrieve all the requested fields
field_count = len(init_list)
if skip:
klass = deferred_class_factory(klass, skip)
field_names = init_list
else:
field_names = ()
else:
# Load all fields on klass
field_count = len(klass._meta.concrete_fields)
# Check if we need to skip some parent fields.
if from_parent and len(klass._meta.local_concrete_fields) != len(klass._meta.concrete_fields):
# Only load those fields which haven't been already loaded into
# 'from_parent'.
non_seen_models = [p for p in klass._meta.get_parent_list()
if not issubclass(from_parent, p)]
# Load local fields, too...
non_seen_models.append(klass)
field_names = [f.attname for f in klass._meta.concrete_fields
if f.model in non_seen_models]
field_count = len(field_names)
# Try to avoid populating field_names variable for performance reasons.
# If field_names variable is set, we use **kwargs based model init
# which is slower than normal init.
if field_count == len(klass._meta.concrete_fields):
field_names = ()
restricted = requested is not None
related_fields = []
for f in klass._meta.fields:
if select_related_descend(f, restricted, requested, load_fields):
if restricted:
next = requested[f.name]
else:
next = None
klass_info = get_klass_info(f.rel.to, max_depth=max_depth, cur_depth=cur_depth + 1,
requested=next, only_load=only_load)
related_fields.append((f, klass_info))
reverse_related_fields = []
if restricted:
for o in klass._meta.related_objects:
if o.field.unique and select_related_descend(o.field, restricted, requested,
only_load.get(o.related_model), reverse=True):
next = requested[o.field.related_query_name()]
parent = klass if issubclass(o.related_model, klass) else None
klass_info = get_klass_info(o.related_model, max_depth=max_depth, cur_depth=cur_depth + 1,
requested=next, only_load=only_load, from_parent=parent)
reverse_related_fields.append((o.field, klass_info))
if field_names:
pk_idx = field_names.index(klass._meta.pk.attname)
else:
meta = klass._meta
pk_idx = meta.concrete_fields.index(meta.pk)
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
def reorder_for_init(model, field_names, values):
"""
Reorders given field names and values for those fields
to be in the same order as model.__init__() expects to find them.
"""
new_names, new_values = [], []
for f in model._meta.concrete_fields:
if f.attname not in field_names:
continue
new_names.append(f.attname)
new_values.append(values[field_names.index(f.attname)])
assert len(new_names) == len(field_names)
return new_names, new_values
def get_cached_row(row, index_start, using, klass_info, offset=0,
parent_data=()):
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
This method may be called recursively to populate deep select_related()
clauses.
Arguments:
* row - the row of data returned by the database cursor
* index_start - the index of the row at which data for this
object is known to start
* offset - the number of additional fields that are known to
exist in row for `klass`. This usually means the number of
annotated results on `klass`.
* using - the database alias on which the query is being executed.
* klass_info - result of the get_klass_info function
* parent_data - parent model data in format (field, value). Used
to populate the non-local fields of child models.
"""
if klass_info is None:
return None
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
fields = row[index_start:index_start + field_count]
# If the pk column is None (or the equivalent '' in the case the
# connection interprets empty strings as nulls), then the related
# object must be non-existent - set the relation to None.
if (fields[pk_idx] is None or
(connections[using].features.interprets_empty_strings_as_nulls and
fields[pk_idx] == '')):
obj = None
elif field_names:
values = list(fields)
parent_values = []
parent_field_names = []
for rel_field, value in parent_data:
parent_field_names.append(rel_field.attname)
parent_values.append(value)
field_names, values = reorder_for_init(
klass, parent_field_names + field_names,
parent_values + values)
obj = klass.from_db(using, field_names, values)
else:
field_names = [f.attname for f in klass._meta.concrete_fields]
obj = klass.from_db(using, field_names, fields)
# Instantiate related fields
index_end = index_start + field_count + offset
# Iterate over each related object, populating any
# select_related() fields
for f, klass_info in related_fields:
# Recursively retrieve the data for the related object
cached_row = get_cached_row(row, index_end, using, klass_info)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
# If the base object exists, populate the
# descriptor cache
setattr(obj, f.get_cache_name(), rel_obj)
if f.unique and rel_obj is not None:
# If the field is unique, populate the
# reverse descriptor cache on the related object
setattr(rel_obj, f.rel.get_cache_name(), obj)
# Now do the same, but for reverse related objects.
# Only handle the restricted case - i.e., don't do a depth
# descent into reverse relations unless explicitly requested
for f, klass_info in reverse_related_fields:
# Transfer data from this object to childs.
parent_data = []
for rel_field in klass_info[0]._meta.fields:
rel_model = rel_field.model._meta.concrete_model
if rel_model == klass_info[0]._meta.model:
rel_model = None
if rel_model is not None and isinstance(obj, rel_model):
parent_data.append((rel_field, getattr(obj, rel_field.attname)))
# Recursively retrieve the data for the related object
cached_row = get_cached_row(row, index_end, using, klass_info,
parent_data=parent_data)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
# populate the reverse descriptor cache
setattr(obj, f.rel.get_cache_name(), rel_obj)
if rel_obj is not None:
# If the related object exists, populate
# the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj)
# Populate related object caches using parent data.
for rel_field, _ in parent_data:
if rel_field.rel:
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
try:
cached_obj = getattr(obj, rel_field.get_cache_name())
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
except AttributeError:
# Related object hasn't been cached yet
pass
return obj, index_end
class RawQuerySet(object): class RawQuerySet(object):
""" """
Provides an iterator which converts the results of raw SQL queries into Provides an iterator which converts the results of raw SQL queries into
@ -1569,7 +1288,9 @@ class RawQuerySet(object):
else: else:
model_cls = self.model model_cls = self.model
fields = [self.model_fields.get(c, None) for c in self.columns] fields = [self.model_fields.get(c, None) for c in self.columns]
converters = compiler.get_converters(fields) converters = compiler.get_converters([
f.get_col(f.model._meta.db_table) if f else None for f in fields
])
for values in query: for values in query:
if converters: if converters:
values = compiler.apply_converters(values, converters) values = compiler.apply_converters(values, converters)
@ -1920,3 +1641,120 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
qs._prefetch_done = True qs._prefetch_done = True
obj._prefetched_objects_cache[cache_name] = qs obj._prefetched_objects_cache[cache_name] = qs
return all_related_objects, additional_lookups return all_related_objects, additional_lookups
class RelatedPopulator(object):
"""
RelatedPopulator is used for select_related() object instantiation.
The idea is that each select_related() model will be populated by a
different RelatedPopulator instance. The RelatedPopulator instances get
klass_info and select (computed in SQLCompiler) plus the used db as
input for initialization. That data is used to compute which columns
to use, how to instantiate the model, and how to populate the links
between the objects.
The actual creation of the objects is done in populate() method. This
method gets row and from_obj as input and populates the select_related()
model instance.
"""
def __init__(self, klass_info, select, db):
self.db = db
# Pre-compute needed attributes. The attributes are:
# - model_cls: the possibly deferred model class to instantiate
# - either:
# - cols_start, cols_end: usually the columns in the row are
# in the same order model_cls.__init__ expects them, so we
# can instantiate by model_cls(*row[cols_start:cols_end])
# - reorder_for_init: When select_related descends to a child
# class, then we want to reuse the already selected parent
# data. However, in this case the parent data isn't necessarily
# in the same order that Model.__init__ expects it to be, so
# we have to reorder the parent data. The reorder_for_init
# attribute contains a function used to reorder the field data
# in the order __init__ expects it.
# - pk_idx: the index of the primary key field in the reordered
# model data. Used to check if a related object exists at all.
# - init_list: the field attnames fetched from the database. For
# deferred models this isn't the same as all attnames of the
# model's fields.
# - related_populators: a list of RelatedPopulator instances if
# select_related() descends to related models from this model.
# - cache_name, reverse_cache_name: the names to use for setattr
# when assigning the fetched object to the from_obj. If the
# reverse_cache_name is set, then we also set the reverse link.
select_fields = klass_info['select_fields']
from_parent = klass_info['from_parent']
if not from_parent:
self.cols_start = select_fields[0]
self.cols_end = select_fields[-1] + 1
self.init_list = [
f[0].output_field.attname for f in select[self.cols_start:self.cols_end]
]
self.reorder_for_init = None
else:
model_init_attnames = [
f.attname for f in klass_info['model']._meta.concrete_fields
]
reorder_map = []
for idx in select_fields:
field = select[idx][0].output_field
init_pos = model_init_attnames.index(field.attname)
reorder_map.append((init_pos, field.attname, idx))
reorder_map.sort()
self.init_list = [v[1] for v in reorder_map]
pos_list = [row_pos for _, _, row_pos in reorder_map]
def reorder_for_init(row):
return [row[row_pos] for row_pos in pos_list]
self.reorder_for_init = reorder_for_init
self.model_cls = self.get_deferred_cls(klass_info, self.init_list)
self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
self.related_populators = get_related_populators(klass_info, select, self.db)
field = klass_info['field']
reverse = klass_info['reverse']
self.reverse_cache_name = None
if reverse:
self.cache_name = field.rel.get_cache_name()
self.reverse_cache_name = field.get_cache_name()
else:
self.cache_name = field.get_cache_name()
if field.unique:
self.reverse_cache_name = field.rel.get_cache_name()
def get_deferred_cls(self, klass_info, init_list):
model_cls = klass_info['model']
if len(init_list) != len(model_cls._meta.concrete_fields):
init_set = set(init_list)
skip = [
f.attname for f in model_cls._meta.concrete_fields
if f.attname not in init_set
]
model_cls = deferred_class_factory(model_cls, skip)
return model_cls
def populate(self, row, from_obj):
if self.reorder_for_init:
obj_data = self.reorder_for_init(row)
else:
obj_data = row[self.cols_start:self.cols_end]
if obj_data[self.pk_idx] is None:
obj = None
else:
obj = self.model_cls.from_db(self.db, self.init_list, obj_data)
if obj and self.related_populators:
for rel_iter in self.related_populators:
rel_iter.populate(row, obj)
setattr(from_obj, self.cache_name, obj)
if obj and self.reverse_cache_name:
setattr(obj, self.reverse_cache_name, from_obj)
def get_related_populators(klass_info, select, db):
iterators = []
related_klass_infos = klass_info.get('related_klass_infos', [])
for rel_klass_info in related_klass_infos:
rel_cls = RelatedPopulator(rel_klass_info, select, db)
iterators.append(rel_cls)
return iterators

View File

@ -170,7 +170,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
if not restricted and field.null: if not restricted and field.null:
return False return False
if load_fields: if load_fields:
if field.name not in load_fields: if field.attname not in load_fields:
if restricted and field.name in requested: if restricted and field.name in requested:
raise InvalidQuery("Field %s.%s cannot be both deferred" raise InvalidQuery("Field %s.%s cannot be both deferred"
" and traversed using select_related" " and traversed using select_related"

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,6 @@
Constants specific to the SQL storage portion of the ORM. Constants specific to the SQL storage portion of the ORM.
""" """
from collections import namedtuple
import re import re
# Valid query types (a set is used for speedy lookups). These are (currently) # Valid query types (a set is used for speedy lookups). These are (currently)
@ -21,9 +20,6 @@ GET_ITERATOR_CHUNK_SIZE = 100
# Namedtuples for sql.* internal use. # Namedtuples for sql.* internal use.
# Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field')
# How many results to expect from a cursor.execute call # How many results to expect from a cursor.execute call
MULTI = 'multi' MULTI = 'multi'
SINGLE = 'single' SINGLE = 'single'

View File

@ -21,7 +21,7 @@ from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref from django.db.models.expressions import Col, Ref
from django.db.models.query_utils import PathInfo, Q, refs_aggregate from django.db.models.query_utils import PathInfo, Q, refs_aggregate
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
ORDER_PATTERN, SelectInfo, INNER, LOUTER) ORDER_PATTERN, INNER, LOUTER)
from django.db.models.sql.datastructures import ( from django.db.models.sql.datastructures import (
EmptyResultSet, Empty, MultiJoin, Join, BaseTable) EmptyResultSet, Empty, MultiJoin, Join, BaseTable)
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@ -46,7 +46,7 @@ class RawQuery(object):
A single raw SQL query A single raw SQL query
""" """
def __init__(self, sql, using, params=None): def __init__(self, sql, using, params=None, context=None):
self.params = params or () self.params = params or ()
self.sql = sql self.sql = sql
self.using = using self.using = using
@ -57,9 +57,10 @@ class RawQuery(object):
self.low_mark, self.high_mark = 0, None # Used for offset/limit self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.extra_select = {} self.extra_select = {}
self.annotation_select = {} self.annotation_select = {}
self.context = context or {}
def clone(self, using): def clone(self, using):
return RawQuery(self.sql, using, params=self.params) return RawQuery(self.sql, using, params=self.params, context=self.context.copy())
def get_columns(self): def get_columns(self):
if self.cursor is None: if self.cursor is None:
@ -122,20 +123,23 @@ class Query(object):
self.standard_ordering = True self.standard_ordering = True
self.used_aliases = set() self.used_aliases = set()
self.filter_is_sticky = False self.filter_is_sticky = False
self.included_inherited_models = {}
# SQL-related attributes # SQL-related attributes
# Select and related select clauses as SelectInfo instances. # Select and related select clauses are expressions to use in the
# SELECT clause of the query.
# The select is used for cases where we want to set up the select # The select is used for cases where we want to set up the select
# clause to contain other than default fields (values(), annotate(), # clause to contain other than default fields (values(), subqueries...)
# subqueries...) # Note that annotations go to annotations dictionary.
self.select = [] self.select = []
# The related_select_cols is used for columns needed for
# select_related - this is populated in the compile stage.
self.related_select_cols = []
self.tables = [] # Aliases in the order they are created. self.tables = [] # Aliases in the order they are created.
self.where = where() self.where = where()
self.where_class = where self.where_class = where
# The group_by attribute can have one of the following forms:
# - None: no group by at all in the query
# - A list of expressions: group by (at least) those expressions.
# String refs are also allowed for now.
# - True: group by all select fields of the model
# See compiler.get_group_by() for details.
self.group_by = None self.group_by = None
self.having = where() self.having = where()
self.order_by = [] self.order_by = []
@ -174,6 +178,8 @@ class Query(object):
# load. # load.
self.deferred_loading = (set(), True) self.deferred_loading = (set(), True)
self.context = {}
@property @property
def extra(self): def extra(self):
if self._extra is None: if self._extra is None:
@ -254,14 +260,14 @@ class Query(object):
obj.default_cols = self.default_cols obj.default_cols = self.default_cols
obj.default_ordering = self.default_ordering obj.default_ordering = self.default_ordering
obj.standard_ordering = self.standard_ordering obj.standard_ordering = self.standard_ordering
obj.included_inherited_models = self.included_inherited_models.copy()
obj.select = self.select[:] obj.select = self.select[:]
obj.related_select_cols = []
obj.tables = self.tables[:] obj.tables = self.tables[:]
obj.where = self.where.clone() obj.where = self.where.clone()
obj.where_class = self.where_class obj.where_class = self.where_class
if self.group_by is None: if self.group_by is None:
obj.group_by = None obj.group_by = None
elif self.group_by is True:
obj.group_by = True
else: else:
obj.group_by = self.group_by[:] obj.group_by = self.group_by[:]
obj.having = self.having.clone() obj.having = self.having.clone()
@ -272,7 +278,6 @@ class Query(object):
obj.select_for_update = self.select_for_update obj.select_for_update = self.select_for_update
obj.select_for_update_nowait = self.select_for_update_nowait obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_related = self.select_related obj.select_related = self.select_related
obj.related_select_cols = []
obj._annotations = self._annotations.copy() if self._annotations is not None else None obj._annotations = self._annotations.copy() if self._annotations is not None else None
if self.annotation_select_mask is None: if self.annotation_select_mask is None:
obj.annotation_select_mask = None obj.annotation_select_mask = None
@ -310,8 +315,15 @@ class Query(object):
obj.__dict__.update(kwargs) obj.__dict__.update(kwargs)
if hasattr(obj, '_setup_query'): if hasattr(obj, '_setup_query'):
obj._setup_query() obj._setup_query()
obj.context = self.context.copy()
return obj return obj
def add_context(self, key, value):
self.context[key] = value
def get_context(self, key, default=None):
return self.context.get(key, default)
def relabeled_clone(self, change_map): def relabeled_clone(self, change_map):
clone = self.clone() clone = self.clone()
clone.change_aliases(change_map) clone.change_aliases(change_map)
@ -375,7 +387,8 @@ class Query(object):
# done in a subquery so that we are aggregating on the limit and/or # done in a subquery so that we are aggregating on the limit and/or
# distinct results instead of applying the distinct and limit after the # distinct results instead of applying the distinct and limit after the
# aggregation. # aggregation.
if (self.group_by or has_limit or has_existing_annotations or self.distinct): if (isinstance(self.group_by, list) or has_limit or has_existing_annotations or
self.distinct):
from django.db.models.sql.subqueries import AggregateQuery from django.db.models.sql.subqueries import AggregateQuery
outer_query = AggregateQuery(self.model) outer_query = AggregateQuery(self.model)
inner_query = self.clone() inner_query = self.clone()
@ -383,7 +396,6 @@ class Query(object):
inner_query.clear_ordering(True) inner_query.clear_ordering(True)
inner_query.select_for_update = False inner_query.select_for_update = False
inner_query.select_related = False inner_query.select_related = False
inner_query.related_select_cols = []
relabels = {t: 'subquery' for t in inner_query.tables} relabels = {t: 'subquery' for t in inner_query.tables}
relabels[None] = 'subquery' relabels[None] = 'subquery'
@ -407,26 +419,17 @@ class Query(object):
self.select = [] self.select = []
self.default_cols = False self.default_cols = False
self._extra = {} self._extra = {}
self.remove_inherited_models()
outer_query.clear_ordering(True) outer_query.clear_ordering(True)
outer_query.clear_limits() outer_query.clear_limits()
outer_query.select_for_update = False outer_query.select_for_update = False
outer_query.select_related = False outer_query.select_related = False
outer_query.related_select_cols = []
compiler = outer_query.get_compiler(using) compiler = outer_query.get_compiler(using)
result = compiler.execute_sql(SINGLE) result = compiler.execute_sql(SINGLE)
if result is None: if result is None:
result = [None for q in outer_query.annotation_select.items()] result = [None for q in outer_query.annotation_select.items()]
fields = [annotation.output_field converters = compiler.get_converters(outer_query.annotation_select.values())
for alias, annotation in outer_query.annotation_select.items()]
converters = compiler.get_converters(fields)
for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()):
if position in converters:
converters[position][1].insert(0, annotation.convert_value)
else:
converters[position] = ([], [annotation.convert_value], annotation.output_field)
result = compiler.apply_converters(result, converters) result = compiler.apply_converters(result, converters)
return { return {
@ -476,7 +479,6 @@ class Query(object):
assert self.distinct_fields == rhs.distinct_fields, \ assert self.distinct_fields == rhs.distinct_fields, \
"Cannot combine queries with different distinct fields." "Cannot combine queries with different distinct fields."
self.remove_inherited_models()
# Work out how to relabel the rhs aliases, if necessary. # Work out how to relabel the rhs aliases, if necessary.
change_map = {} change_map = {}
conjunction = (connector == AND) conjunction = (connector == AND)
@ -545,13 +547,8 @@ class Query(object):
# Selection columns and extra extensions are those provided by 'rhs'. # Selection columns and extra extensions are those provided by 'rhs'.
self.select = [] self.select = []
for col, field in rhs.select: for col in rhs.select:
if isinstance(col, (list, tuple)): self.add_select(col.relabeled_clone(change_map))
new_col = change_map.get(col[0], col[0]), col[1]
self.select.append(SelectInfo(new_col, field))
else:
new_col = col.relabeled_clone(change_map)
self.select.append(SelectInfo(new_col, field))
if connector == OR: if connector == OR:
# It would be nice to be able to handle this, but the queries don't # It would be nice to be able to handle this, but the queries don't
@ -661,17 +658,6 @@ class Query(object):
for model, values in six.iteritems(seen): for model, values in six.iteritems(seen):
callback(target, model, values) callback(target, model, values)
def deferred_to_columns_cb(self, target, model, fields):
"""
Callback used by deferred_to_columns(). The "target" parameter should
be a set instance.
"""
table = model._meta.db_table
if table not in target:
target[table] = set()
for field in fields:
target[table].add(field.column)
def table_alias(self, table_name, create=False): def table_alias(self, table_name, create=False):
""" """
Returns a table alias for the given table_name and whether this is a Returns a table alias for the given table_name and whether this is a
@ -788,10 +774,9 @@ class Query(object):
# "group by", "where" and "having". # "group by", "where" and "having".
self.where.relabel_aliases(change_map) self.where.relabel_aliases(change_map)
self.having.relabel_aliases(change_map) self.having.relabel_aliases(change_map)
if self.group_by: if isinstance(self.group_by, list):
self.group_by = [relabel_column(col) for col in self.group_by] self.group_by = [relabel_column(col) for col in self.group_by]
self.select = [SelectInfo(relabel_column(s.col), s.field) self.select = [col.relabeled_clone(change_map) for col in self.select]
for s in self.select]
if self._annotations: if self._annotations:
self._annotations = OrderedDict( self._annotations = OrderedDict(
(key, relabel_column(col)) for key, col in self._annotations.items()) (key, relabel_column(col)) for key, col in self._annotations.items())
@ -815,9 +800,6 @@ class Query(object):
if alias == old_alias: if alias == old_alias:
self.tables[pos] = new_alias self.tables[pos] = new_alias
break break
for key, alias in self.included_inherited_models.items():
if alias in change_map:
self.included_inherited_models[key] = change_map[alias]
self.external_aliases = {change_map.get(alias, alias) self.external_aliases = {change_map.get(alias, alias)
for alias in self.external_aliases} for alias in self.external_aliases}
@ -930,28 +912,6 @@ class Query(object):
self.alias_map[alias] = join self.alias_map[alias] = join
return alias return alias
def setup_inherited_models(self):
"""
If the model that is the basis for this QuerySet inherits other models,
we need to ensure that those other models have their tables included in
the query.
We do this as a separate step so that subclasses know which
tables are going to be active in the query, without needing to compute
all the select columns (this method is called from pre_sql_setup(),
whereas column determination is a later part, and side-effect, of
as_sql()).
"""
opts = self.get_meta()
root_alias = self.tables[0]
seen = {None: root_alias}
for field in opts.fields:
model = field.model._meta.concrete_model
if model is not opts.model and model not in seen:
self.join_parent_model(opts, model, root_alias, seen)
self.included_inherited_models = seen
def join_parent_model(self, opts, model, alias, seen): def join_parent_model(self, opts, model, alias, seen):
""" """
Makes sure the given 'model' is joined in the query. If 'model' isn't Makes sure the given 'model' is joined in the query. If 'model' isn't
@ -969,7 +929,9 @@ class Query(object):
curr_opts = opts curr_opts = opts
for int_model in chain: for int_model in chain:
if int_model in seen: if int_model in seen:
return seen[int_model] curr_opts = int_model._meta
alias = seen[int_model]
continue
# Proxy model have elements in base chain # Proxy model have elements in base chain
# with no parents, assign the new options # with no parents, assign the new options
# object and skip to the next base in that # object and skip to the next base in that
@ -984,23 +946,13 @@ class Query(object):
alias = seen[int_model] = joins[-1] alias = seen[int_model] = joins[-1]
return alias or seen[None] return alias or seen[None]
def remove_inherited_models(self):
"""
Undoes the effects of setup_inherited_models(). Should be called
whenever select columns (self.select) are set explicitly.
"""
for key, alias in self.included_inherited_models.items():
if key:
self.unref_alias(alias)
self.included_inherited_models = {}
def add_aggregate(self, aggregate, model, alias, is_summary): def add_aggregate(self, aggregate, model, alias, is_summary):
warnings.warn( warnings.warn(
"add_aggregate() is deprecated. Use add_annotation() instead.", "add_aggregate() is deprecated. Use add_annotation() instead.",
RemovedInDjango20Warning, stacklevel=2) RemovedInDjango20Warning, stacklevel=2)
self.add_annotation(aggregate, alias, is_summary) self.add_annotation(aggregate, alias, is_summary)
def add_annotation(self, annotation, alias, is_summary): def add_annotation(self, annotation, alias, is_summary=False):
""" """
Adds a single annotation expression to the Query Adds a single annotation expression to the Query
""" """
@ -1011,6 +963,7 @@ class Query(object):
def prepare_lookup_value(self, value, lookups, can_reuse): def prepare_lookup_value(self, value, lookups, can_reuse):
# Default lookup if none given is exact. # Default lookup if none given is exact.
used_joins = []
if len(lookups) == 0: if len(lookups) == 0:
lookups = ['exact'] lookups = ['exact']
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
@ -1026,7 +979,9 @@ class Query(object):
RemovedInDjango19Warning, stacklevel=2) RemovedInDjango19Warning, stacklevel=2)
value = value() value = value()
elif hasattr(value, 'resolve_expression'): elif hasattr(value, 'resolve_expression'):
pre_joins = self.alias_refcount.copy()
value = value.resolve_expression(self, reuse=can_reuse) value = value.resolve_expression(self, reuse=can_reuse)
used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
# Subqueries need to use a different set of aliases than the # Subqueries need to use a different set of aliases than the
# outer query. Call bump_prefix to change aliases of the inner # outer query. Call bump_prefix to change aliases of the inner
# query (the value). # query (the value).
@ -1044,7 +999,7 @@ class Query(object):
lookups[-1] == 'exact' and value == ''): lookups[-1] == 'exact' and value == ''):
value = True value = True
lookups[-1] = 'isnull' lookups[-1] = 'isnull'
return value, lookups return value, lookups, used_joins
def solve_lookup_type(self, lookup): def solve_lookup_type(self, lookup):
""" """
@ -1173,8 +1128,7 @@ class Query(object):
# Work out the lookup type and remove it from the end of 'parts', # Work out the lookup type and remove it from the end of 'parts',
# if necessary. # if necessary.
value, lookups = self.prepare_lookup_value(value, lookups, can_reuse) value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse)
used_joins = getattr(value, '_used_joins', [])
clause = self.where_class() clause = self.where_class()
if reffed_aggregate: if reffed_aggregate:
@ -1223,7 +1177,7 @@ class Query(object):
# handle Expressions as annotations # handle Expressions as annotations
col = targets[0] col = targets[0]
else: else:
col = Col(alias, targets[0], field) col = targets[0].get_col(alias, field)
condition = self.build_lookup(lookups, col, value) condition = self.build_lookup(lookups, col, value)
if not condition: if not condition:
# Backwards compat for custom lookups # Backwards compat for custom lookups
@ -1258,7 +1212,7 @@ class Query(object):
# <=> # <=>
# NOT (col IS NOT NULL AND col = someval). # NOT (col IS NOT NULL AND col = someval).
lookup_class = targets[0].get_lookup('isnull') lookup_class = targets[0].get_lookup('isnull')
clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND) clause.add(lookup_class(targets[0].get_col(alias, sources[0]), False), AND)
return clause, used_joins if not require_outer else () return clause, used_joins if not require_outer else ()
def add_filter(self, filter_clause): def add_filter(self, filter_clause):
@ -1535,7 +1489,7 @@ class Query(object):
self.unref_alias(joins.pop()) self.unref_alias(joins.pop())
return targets, joins[-1], joins return targets, joins[-1], joins
def resolve_ref(self, name, allow_joins, reuse, summarize): def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False):
if not allow_joins and LOOKUP_SEP in name: if not allow_joins and LOOKUP_SEP in name:
raise FieldError("Joined field references are not permitted in this query") raise FieldError("Joined field references are not permitted in this query")
if name in self.annotations: if name in self.annotations:
@ -1558,8 +1512,7 @@ class Query(object):
"isn't supported") "isn't supported")
if reuse is not None: if reuse is not None:
reuse.update(join_list) reuse.update(join_list)
col = Col(join_list[-1], targets[0], sources[0]) col = targets[0].get_col(join_list[-1], sources[0])
col._used_joins = join_list
return col return col
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
@ -1588,26 +1541,28 @@ class Query(object):
# Try to have as simple as possible subquery -> trim leading joins from # Try to have as simple as possible subquery -> trim leading joins from
# the subquery. # the subquery.
trimmed_prefix, contains_louter = query.trim_start(names_with_path) trimmed_prefix, contains_louter = query.trim_start(names_with_path)
query.remove_inherited_models()
# Add extra check to make sure the selected field will not be null # Add extra check to make sure the selected field will not be null
# since we are adding an IN <subquery> clause. This prevents the # since we are adding an IN <subquery> clause. This prevents the
# database from tripping over IN (...,NULL,...) selects and returning # database from tripping over IN (...,NULL,...) selects and returning
# nothing # nothing
alias, col = query.select[0].col col = query.select[0]
if self.is_nullable(query.select[0].field): select_field = col.field
lookup_class = query.select[0].field.get_lookup('isnull') alias = col.alias
lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False) if self.is_nullable(select_field):
lookup_class = select_field.get_lookup('isnull')
lookup = lookup_class(select_field.get_col(alias), False)
query.where.add(lookup, AND) query.where.add(lookup, AND)
if alias in can_reuse: if alias in can_reuse:
select_field = query.select[0].field
pk = select_field.model._meta.pk pk = select_field.model._meta.pk
# Need to add a restriction so that outer query's filters are in effect for # Need to add a restriction so that outer query's filters are in effect for
# the subquery, too. # the subquery, too.
query.bump_prefix(self) query.bump_prefix(self)
lookup_class = select_field.get_lookup('exact') lookup_class = select_field.get_lookup('exact')
lookup = lookup_class(Col(query.select[0].col[0], pk, pk), # Note that the query.select[0].alias is different from alias
Col(alias, pk, pk)) # due to bump_prefix above.
lookup = lookup_class(pk.get_col(query.select[0].alias),
pk.get_col(alias))
query.where.add(lookup, AND) query.where.add(lookup, AND)
query.external_aliases.add(alias) query.external_aliases.add(alias)
@ -1687,6 +1642,14 @@ class Query(object):
""" """
self.select = [] self.select = []
def add_select(self, col):
self.default_cols = False
self.select.append(col)
def set_select(self, cols):
self.default_cols = False
self.select = cols
def add_distinct_fields(self, *field_names): def add_distinct_fields(self, *field_names):
""" """
Adds and resolves the given fields to the query's "distinct on" clause. Adds and resolves the given fields to the query's "distinct on" clause.
@ -1710,7 +1673,7 @@ class Query(object):
name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m) name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m)
targets, final_alias, joins = self.trim_joins(targets, joins, path) targets, final_alias, joins = self.trim_joins(targets, joins, path)
for target in targets: for target in targets:
self.select.append(SelectInfo((final_alias, target.column), target)) self.add_select(target.get_col(final_alias))
except MultiJoin: except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name) raise FieldError("Invalid field name: '%s'" % name)
except FieldError: except FieldError:
@ -1723,7 +1686,6 @@ class Query(object):
+ list(self.annotation_select)) + list(self.annotation_select))
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names))) "Choices are: %s" % (name, ", ".join(names)))
self.remove_inherited_models()
def add_ordering(self, *ordering): def add_ordering(self, *ordering):
""" """
@ -1766,7 +1728,7 @@ class Query(object):
""" """
self.group_by = [] self.group_by = []
for col, _ in self.select: for col in self.select:
self.group_by.append(col) self.group_by.append(col)
if self._annotations: if self._annotations:
@ -1789,7 +1751,6 @@ class Query(object):
for part in field.split(LOOKUP_SEP): for part in field.split(LOOKUP_SEP):
d = d.setdefault(part, {}) d = d.setdefault(part, {})
self.select_related = field_dict self.select_related = field_dict
self.related_select_cols = []
def add_extra(self, select, select_params, where, params, tables, order_by): def add_extra(self, select, select_params, where, params, tables, order_by):
""" """
@ -1897,7 +1858,7 @@ class Query(object):
""" """
Callback used by get_deferred_field_names(). Callback used by get_deferred_field_names().
""" """
target[model] = set(f.name for f in fields) target[model] = {f.attname for f in fields}
def set_aggregate_mask(self, names): def set_aggregate_mask(self, names):
warnings.warn( warnings.warn(
@ -2041,7 +2002,7 @@ class Query(object):
if self.alias_refcount[table] > 0: if self.alias_refcount[table] > 0:
self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table) self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table)
break break
self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields] self.set_select([f.get_col(select_alias) for f in select_fields])
return trimmed_prefix, contains_louter return trimmed_prefix, contains_louter
def is_nullable(self, field): def is_nullable(self, field):

View File

@ -5,7 +5,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connections from django.db import connections
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.utils import six from django.utils import six
@ -71,7 +71,7 @@ class DeleteQuery(Query):
else: else:
innerq.clear_select_clause() innerq.clear_select_clause()
innerq.select = [ innerq.select = [
SelectInfo((self.get_initial_alias(), pk.column), None) pk.get_col(self.get_initial_alias())
] ]
values = innerq values = innerq
self.where = self.where_class() self.where = self.where_class()

View File

@ -483,7 +483,7 @@ instances::
class HandField(models.Field): class HandField(models.Field):
# ... # ...
def from_db_value(self, value, connection): def from_db_value(self, value, connection, context):
if value is None: if value is None:
return value return value
return parse_hand(value) return parse_hand(value)

View File

@ -399,7 +399,7 @@ calling the appropriate methods on the wrapped expression.
clone.expression = self.expression.relabeled_clone(change_map) clone.expression = self.expression.relabeled_clone(change_map)
return clone return clone
.. method:: convert_value(self, value, connection) .. method:: convert_value(self, value, connection, context)
A hook allowing the expression to coerce ``value`` into a more A hook allowing the expression to coerce ``value`` into a more
appropriate type. appropriate type.

View File

@ -1670,7 +1670,7 @@ Field API reference
When loading data, :meth:`from_db_value` is used: When loading data, :meth:`from_db_value` is used:
.. method:: from_db_value(value, connection) .. method:: from_db_value(value, connection, context)
.. versionadded:: 1.8 .. versionadded:: 1.8

View File

@ -679,7 +679,7 @@ class BaseAggregateTestCase(TestCase):
# the only "ORDER BY" clause present in the query. # the only "ORDER BY" clause present in the query.
self.assertEqual( self.assertEqual(
re.findall(r'order by (\w+)', qstr), re.findall(r'order by (\w+)', qstr),
[', '.join(forced_ordering).lower()] [', '.join(f[1][0] for f in forced_ordering).lower()]
) )
else: else:
self.assertNotIn('order by', qstr) self.assertNotIn('order by', qstr)

View File

@ -490,9 +490,10 @@ class AggregationTests(TestCase):
# Regression for #15709 - Ensure each group_by field only exists once # Regression for #15709 - Ensure each group_by field only exists once
# per query # per query
qs = Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by() qstr = str(Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by().query)
grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping([], []) # Check that there is just one GROUP BY clause (zero commas means at
self.assertEqual(len(grouping), 1) # most one clause)
self.assertEqual(qstr[qstr.index('GROUP BY'):].count(', '), 0)
def test_duplicate_alias(self): def test_duplicate_alias(self):
# Regression for #11256 - duplicating a default alias raises ValueError. # Regression for #11256 - duplicating a default alias raises ValueError.
@ -924,14 +925,11 @@ class AggregationTests(TestCase):
# There should only be one GROUP BY clause, for the `id` column. # There should only be one GROUP BY clause, for the `id` column.
# `name` and `age` should not be grouped on. # `name` and `age` should not be grouped on.
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], []) _, _, group_by = results.query.get_compiler(using='default').pre_sql_setup()
self.assertEqual(len(grouping), 1) self.assertEqual(len(group_by), 1)
assert 'id' in grouping[0] self.assertIn('id', group_by[0][0])
assert 'name' not in grouping[0] self.assertNotIn('name', group_by[0][0])
assert 'age' not in grouping[0] self.assertNotIn('age', group_by[0][0])
# The query group_by property should also only show the `id`.
self.assertEqual(results.query.group_by, [('aggregation_regress_author', 'id')])
# Ensure that we get correct results. # Ensure that we get correct results.
self.assertEqual( self.assertEqual(
@ -953,14 +951,11 @@ class AggregationTests(TestCase):
def test_aggregate_duplicate_columns_only(self): def test_aggregate_duplicate_columns_only(self):
# Works with only() too. # Works with only() too.
results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set')) results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set'))
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], []) _, _, grouping = results.query.get_compiler(using='default').pre_sql_setup()
self.assertEqual(len(grouping), 1) self.assertEqual(len(grouping), 1)
assert 'id' in grouping[0] self.assertIn('id', grouping[0][0])
assert 'name' not in grouping[0] self.assertNotIn('name', grouping[0][0])
assert 'age' not in grouping[0] self.assertNotIn('age', grouping[0][0])
# The query group_by property should also only show the `id`.
self.assertEqual(results.query.group_by, [('aggregation_regress_author', 'id')])
# Ensure that we get correct results. # Ensure that we get correct results.
self.assertEqual( self.assertEqual(
@ -983,14 +978,11 @@ class AggregationTests(TestCase):
# And select_related() # And select_related()
results = Book.objects.select_related('contact').annotate( results = Book.objects.select_related('contact').annotate(
num_authors=Count('authors')) num_authors=Count('authors'))
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], []) _, _, grouping = results.query.get_compiler(using='default').pre_sql_setup()
self.assertEqual(len(grouping), 1) self.assertEqual(len(grouping), 1)
assert 'id' in grouping[0] self.assertIn('id', grouping[0][0])
assert 'name' not in grouping[0] self.assertNotIn('name', grouping[0][0])
assert 'contact' not in grouping[0] self.assertNotIn('contact', grouping[0][0])
# The query group_by property should also only show the `id`.
self.assertEqual(results.query.group_by, [('aggregation_regress_book', 'id')])
# Ensure that we get correct results. # Ensure that we get correct results.
self.assertEqual( self.assertEqual(

View File

@ -43,7 +43,7 @@ class MyAutoField(models.CharField):
value = MyWrapper(value) value = MyWrapper(value)
return value return value
def from_db_value(self, value, connection): def from_db_value(self, value, connection, context):
if not value: if not value:
return return
return MyWrapper(value) return MyWrapper(value)

View File

@ -96,3 +96,11 @@ class Request(models.Model):
request2 = models.CharField(default='request2', max_length=1000) request2 = models.CharField(default='request2', max_length=1000)
request3 = models.CharField(default='request3', max_length=1000) request3 = models.CharField(default='request3', max_length=1000)
request4 = models.CharField(default='request4', max_length=1000) request4 = models.CharField(default='request4', max_length=1000)
class Base(models.Model):
text = models.TextField()
class Derived(Base):
other_text = models.TextField()

View File

@ -12,7 +12,7 @@ from django.test import TestCase, override_settings
from .models import ( from .models import (
ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, SimpleItem, Feature, ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, SimpleItem, Feature,
ItemAndSimpleItem, OneToOneItem, SpecialFeature, Location, Request, ItemAndSimpleItem, OneToOneItem, SpecialFeature, Location, Request,
ProxyRelated, ProxyRelated, Derived, Base,
) )
@ -145,6 +145,15 @@ class DeferRegressionTest(TestCase):
list(SimpleItem.objects.annotate(Count('feature')).only('name')), list(SimpleItem.objects.annotate(Count('feature')).only('name')),
list) list)
def test_ticket_23270(self):
Derived.objects.create(text="foo", other_text="bar")
with self.assertNumQueries(1):
obj = Base.objects.select_related("derived").defer("text")[0]
self.assertIsInstance(obj.derived, Derived)
self.assertEqual("bar", obj.derived.other_text)
self.assertNotIn("text", obj.__dict__)
self.assertEqual(1, obj.derived.base_ptr_id)
def test_only_and_defer_usage_on_proxy_models(self): def test_only_and_defer_usage_on_proxy_models(self):
# Regression for #15790 - only() broken for proxy models # Regression for #15790 - only() broken for proxy models
proxy = Proxy.objects.create(name="proxy", value=42) proxy = Proxy.objects.create(name="proxy", value=42)

View File

@ -18,7 +18,7 @@ class CashField(models.DecimalField):
kwargs['decimal_places'] = 2 kwargs['decimal_places'] = 2
super(CashField, self).__init__(**kwargs) super(CashField, self).__init__(**kwargs)
def from_db_value(self, value, connection): def from_db_value(self, value, connection, context):
cash = Cash(value) cash = Cash(value)
cash.vendor = connection.vendor cash.vendor = connection.vendor
return cash return cash

View File

@ -3563,8 +3563,10 @@ class Ticket20955Tests(TestCase):
# version's queries. # version's queries.
task_get.creator.staffuser.staff task_get.creator.staffuser.staff
task_get.owner.staffuser.staff task_get.owner.staffuser.staff
task_select_related = Task.objects.select_related( qs = Task.objects.select_related(
'creator__staffuser__staff', 'owner__staffuser__staff').get(pk=task.pk) 'creator__staffuser__staff', 'owner__staffuser__staff')
self.assertEqual(str(qs.query).count(' JOIN '), 6)
task_select_related = qs.get(pk=task.pk)
with self.assertNumQueries(0): with self.assertNumQueries(0):
self.assertEqual(task_select_related.creator.staffuser.staff, self.assertEqual(task_select_related.creator.staffuser.staff,
task_get.creator.staffuser.staff) task_get.creator.staffuser.staff)

View File

@ -83,6 +83,7 @@ class ReverseSelectRelatedTestCase(TestCase):
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
self.assertEqual(stat.advanceduserstat.posts, 200) self.assertEqual(stat.advanceduserstat.posts, 200)
self.assertEqual(stat.user.username, 'bob') self.assertEqual(stat.user.username, 'bob')
with self.assertNumQueries(1):
self.assertEqual(stat.advanceduserstat.user.username, 'bob') self.assertEqual(stat.advanceduserstat.user.username, 'bob')
def test_nullable_relation(self): def test_nullable_relation(self):

View File

@ -112,7 +112,7 @@ class TeamField(models.CharField):
return value return value
return Team(value) return Team(value)
def from_db_value(self, value, connection): def from_db_value(self, value, connection, context):
return Team(value) return Team(value)
def value_to_string(self, obj): def value_to_string(self, obj):