Fixed #24020 -- Refactored SQL compiler to use expressions
Refactored compiler SELECT, GROUP BY and ORDER BY generation. While there, also refactored select_related() implementation (get_cached_row() and get_klass_info() are now gone!). Made get_db_converters() method work on expressions instead of internal_type. This allows the backend converters to target specific expressions if need be. Added query.context, this can be used to set per-query state. Also changed the signature of database converters. They now accept context as an argument.
This commit is contained in:
parent
b8abfe141b
commit
0c7633178f
|
@ -10,7 +10,6 @@ from django.db.models import signals, DO_NOTHING
|
||||||
from django.db.models.base import ModelBase
|
from django.db.models.base import ModelBase
|
||||||
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
|
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
|
||||||
from django.db.models.query_utils import PathInfo
|
from django.db.models.query_utils import PathInfo
|
||||||
from django.db.models.expressions import Col
|
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.utils.encoding import smart_text, python_2_unicode_compatible
|
from django.utils.encoding import smart_text, python_2_unicode_compatible
|
||||||
|
|
||||||
|
@ -367,7 +366,7 @@ class GenericRelation(ForeignObject):
|
||||||
field = self.rel.to._meta.get_field(self.content_type_field_name)
|
field = self.rel.to._meta.get_field(self.content_type_field_name)
|
||||||
contenttype_pk = self.get_content_type().pk
|
contenttype_pk = self.get_content_type().pk
|
||||||
cond = where_class()
|
cond = where_class()
|
||||||
lookup = field.get_lookup('exact')(Col(remote_alias, field, field), contenttype_pk)
|
lookup = field.get_lookup('exact')(field.get_col(remote_alias), contenttype_pk)
|
||||||
cond.add(lookup, 'AND')
|
cond.add(lookup, 'AND')
|
||||||
return cond
|
return cond
|
||||||
|
|
||||||
|
|
|
@ -158,10 +158,10 @@ class BaseSpatialOperations(object):
|
||||||
|
|
||||||
# Default conversion functions for aggregates; will be overridden if implemented
|
# Default conversion functions for aggregates; will be overridden if implemented
|
||||||
# for the spatial backend.
|
# for the spatial backend.
|
||||||
def convert_extent(self, box):
|
def convert_extent(self, box, srid):
|
||||||
raise NotImplementedError('Aggregate extent not implemented for this spatial backend.')
|
raise NotImplementedError('Aggregate extent not implemented for this spatial backend.')
|
||||||
|
|
||||||
def convert_extent3d(self, box):
|
def convert_extent3d(self, box, srid):
|
||||||
raise NotImplementedError('Aggregate 3D extent not implemented for this spatial backend.')
|
raise NotImplementedError('Aggregate 3D extent not implemented for this spatial backend.')
|
||||||
|
|
||||||
def convert_geom(self, geom_val, geom_field):
|
def convert_geom(self, geom_val, geom_field):
|
||||||
|
|
|
@ -7,7 +7,6 @@ from django.contrib.gis.db.backends.utils import SpatialOperator
|
||||||
|
|
||||||
class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
|
class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
|
|
||||||
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
|
|
||||||
mysql = True
|
mysql = True
|
||||||
name = 'mysql'
|
name = 'mysql'
|
||||||
select = 'AsText(%s)'
|
select = 'AsText(%s)'
|
||||||
|
|
|
@ -1,24 +0,0 @@
|
||||||
from django.contrib.gis.db.models.sql.compiler import GeoSQLCompiler as BaseGeoSQLCompiler
|
|
||||||
from django.db.backends.oracle import compiler
|
|
||||||
|
|
||||||
SQLCompiler = compiler.SQLCompiler
|
|
||||||
|
|
||||||
|
|
||||||
class GeoSQLCompiler(BaseGeoSQLCompiler, SQLCompiler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
|
|
||||||
pass
|
|
|
@ -52,7 +52,6 @@ class SDORelate(SpatialOperator):
|
||||||
|
|
||||||
|
|
||||||
class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
compiler_module = "django.contrib.gis.db.backends.oracle.compiler"
|
|
||||||
|
|
||||||
name = 'oracle'
|
name = 'oracle'
|
||||||
oracle = True
|
oracle = True
|
||||||
|
@ -111,8 +110,9 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
def geo_quote_name(self, name):
|
def geo_quote_name(self, name):
|
||||||
return super(OracleOperations, self).geo_quote_name(name).upper()
|
return super(OracleOperations, self).geo_quote_name(name).upper()
|
||||||
|
|
||||||
def get_db_converters(self, internal_type):
|
def get_db_converters(self, expression):
|
||||||
converters = super(OracleOperations, self).get_db_converters(internal_type)
|
converters = super(OracleOperations, self).get_db_converters(expression)
|
||||||
|
internal_type = expression.output_field.get_internal_type()
|
||||||
geometry_fields = (
|
geometry_fields = (
|
||||||
'PointField', 'GeometryField', 'LineStringField',
|
'PointField', 'GeometryField', 'LineStringField',
|
||||||
'PolygonField', 'MultiPointField', 'MultiLineStringField',
|
'PolygonField', 'MultiPointField', 'MultiLineStringField',
|
||||||
|
@ -121,14 +121,23 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
)
|
)
|
||||||
if internal_type in geometry_fields:
|
if internal_type in geometry_fields:
|
||||||
converters.append(self.convert_textfield_value)
|
converters.append(self.convert_textfield_value)
|
||||||
|
if hasattr(expression.output_field, 'geom_type'):
|
||||||
|
converters.append(self.convert_geometry)
|
||||||
return converters
|
return converters
|
||||||
|
|
||||||
def convert_extent(self, clob):
|
def convert_geometry(self, value, expression, context):
|
||||||
|
if value:
|
||||||
|
value = Geometry(value)
|
||||||
|
if 'transformed_srid' in context:
|
||||||
|
value.srid = context['transformed_srid']
|
||||||
|
return value
|
||||||
|
|
||||||
|
def convert_extent(self, clob, srid):
|
||||||
if clob:
|
if clob:
|
||||||
# Generally, Oracle returns a polygon for the extent -- however,
|
# Generally, Oracle returns a polygon for the extent -- however,
|
||||||
# it can return a single point if there's only one Point in the
|
# it can return a single point if there's only one Point in the
|
||||||
# table.
|
# table.
|
||||||
ext_geom = Geometry(clob.read())
|
ext_geom = Geometry(clob.read(), srid)
|
||||||
gtype = str(ext_geom.geom_type)
|
gtype = str(ext_geom.geom_type)
|
||||||
if gtype == 'Polygon':
|
if gtype == 'Polygon':
|
||||||
# Construct the 4-tuple from the coordinates in the polygon.
|
# Construct the 4-tuple from the coordinates in the polygon.
|
||||||
|
@ -226,7 +235,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
else:
|
else:
|
||||||
sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
|
sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
|
||||||
sql_function = getattr(self, agg_name)
|
sql_function = getattr(self, agg_name)
|
||||||
return self.select % sql_template, sql_function
|
return sql_template, sql_function
|
||||||
|
|
||||||
# Routines for getting the OGC-compliant models.
|
# Routines for getting the OGC-compliant models.
|
||||||
def geometry_columns(self):
|
def geometry_columns(self):
|
||||||
|
|
|
@ -44,7 +44,6 @@ class PostGISDistanceOperator(PostGISOperator):
|
||||||
|
|
||||||
|
|
||||||
class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
|
|
||||||
name = 'postgis'
|
name = 'postgis'
|
||||||
postgis = True
|
postgis = True
|
||||||
geography = True
|
geography = True
|
||||||
|
@ -188,7 +187,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
agg_name = aggregate.__class__.__name__
|
agg_name = aggregate.__class__.__name__
|
||||||
return agg_name in self.valid_aggregates
|
return agg_name in self.valid_aggregates
|
||||||
|
|
||||||
def convert_extent(self, box):
|
def convert_extent(self, box, srid):
|
||||||
"""
|
"""
|
||||||
Returns a 4-tuple extent for the `Extent` aggregate by converting
|
Returns a 4-tuple extent for the `Extent` aggregate by converting
|
||||||
the bounding box text returned by PostGIS (`box` argument), for
|
the bounding box text returned by PostGIS (`box` argument), for
|
||||||
|
@ -199,7 +198,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
xmax, ymax = map(float, ur.split())
|
xmax, ymax = map(float, ur.split())
|
||||||
return (xmin, ymin, xmax, ymax)
|
return (xmin, ymin, xmax, ymax)
|
||||||
|
|
||||||
def convert_extent3d(self, box3d):
|
def convert_extent3d(self, box3d, srid):
|
||||||
"""
|
"""
|
||||||
Returns a 6-tuple extent for the `Extent3D` aggregate by converting
|
Returns a 6-tuple extent for the `Extent3D` aggregate by converting
|
||||||
the 3d bounding-box text returned by PostGIS (`box3d` argument), for
|
the 3d bounding-box text returned by PostGIS (`box3d` argument), for
|
||||||
|
|
|
@ -14,7 +14,6 @@ from django.utils.functional import cached_property
|
||||||
|
|
||||||
|
|
||||||
class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
|
|
||||||
name = 'spatialite'
|
name = 'spatialite'
|
||||||
spatialite = True
|
spatialite = True
|
||||||
version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
|
version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
|
||||||
|
@ -131,11 +130,11 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
agg_name = aggregate.__class__.__name__
|
agg_name = aggregate.__class__.__name__
|
||||||
return agg_name in self.valid_aggregates
|
return agg_name in self.valid_aggregates
|
||||||
|
|
||||||
def convert_extent(self, box):
|
def convert_extent(self, box, srid):
|
||||||
"""
|
"""
|
||||||
Convert the polygon data received from Spatialite to min/max values.
|
Convert the polygon data received from Spatialite to min/max values.
|
||||||
"""
|
"""
|
||||||
shell = Geometry(box).shell
|
shell = Geometry(box, srid).shell
|
||||||
xmin, ymin = shell[0][:2]
|
xmin, ymin = shell[0][:2]
|
||||||
xmax, ymax = shell[2][:2]
|
xmax, ymax = shell[2][:2]
|
||||||
return (xmin, ymin, xmax, ymax)
|
return (xmin, ymin, xmax, ymax)
|
||||||
|
@ -256,7 +255,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
agg_name = agg_name.lower()
|
agg_name = agg_name.lower()
|
||||||
if agg_name == 'union':
|
if agg_name == 'union':
|
||||||
agg_name += 'agg'
|
agg_name += 'agg'
|
||||||
sql_template = self.select % '%(function)s(%(expressions)s)'
|
sql_template = '%(function)s(%(expressions)s)'
|
||||||
sql_function = getattr(self, agg_name)
|
sql_function = getattr(self, agg_name)
|
||||||
return sql_template, sql_function
|
return sql_template, sql_function
|
||||||
|
|
||||||
|
@ -268,3 +267,16 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
||||||
def spatial_ref_sys(self):
|
def spatial_ref_sys(self):
|
||||||
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys
|
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys
|
||||||
return SpatialiteSpatialRefSys
|
return SpatialiteSpatialRefSys
|
||||||
|
|
||||||
|
def get_db_converters(self, expression):
|
||||||
|
converters = super(SpatiaLiteOperations, self).get_db_converters(expression)
|
||||||
|
if hasattr(expression.output_field, 'geom_type'):
|
||||||
|
converters.append(self.convert_geometry)
|
||||||
|
return converters
|
||||||
|
|
||||||
|
def convert_geometry(self, value, expression, context):
|
||||||
|
if value:
|
||||||
|
value = Geometry(value)
|
||||||
|
if 'transformed_srid' in context:
|
||||||
|
value.srid = context['transformed_srid']
|
||||||
|
return value
|
||||||
|
|
|
@ -28,7 +28,7 @@ class GeoAggregate(Aggregate):
|
||||||
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
|
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
return connection.ops.convert_geom(value, self.output_field)
|
return connection.ops.convert_geom(value, self.output_field)
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,8 +43,8 @@ class Extent(GeoAggregate):
|
||||||
def __init__(self, expression, **extra):
|
def __init__(self, expression, **extra):
|
||||||
super(Extent, self).__init__(expression, output_field=ExtentField(), **extra)
|
super(Extent, self).__init__(expression, output_field=ExtentField(), **extra)
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
return connection.ops.convert_extent(value)
|
return connection.ops.convert_extent(value, context.get('transformed_srid'))
|
||||||
|
|
||||||
|
|
||||||
class Extent3D(GeoAggregate):
|
class Extent3D(GeoAggregate):
|
||||||
|
@ -54,8 +54,8 @@ class Extent3D(GeoAggregate):
|
||||||
def __init__(self, expression, **extra):
|
def __init__(self, expression, **extra):
|
||||||
super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra)
|
super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra)
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
return connection.ops.convert_extent3d(value)
|
return connection.ops.convert_extent3d(value, context.get('transformed_srid'))
|
||||||
|
|
||||||
|
|
||||||
class MakeLine(GeoAggregate):
|
class MakeLine(GeoAggregate):
|
||||||
|
|
|
@ -42,7 +42,30 @@ def get_srid_info(srid, connection):
|
||||||
return _srid_cache[connection.alias][srid]
|
return _srid_cache[connection.alias][srid]
|
||||||
|
|
||||||
|
|
||||||
class GeometryField(Field):
|
class GeoSelectFormatMixin(object):
|
||||||
|
def select_format(self, compiler, sql, params):
|
||||||
|
"""
|
||||||
|
Returns the selection format string, depending on the requirements
|
||||||
|
of the spatial backend. For example, Oracle and MySQL require custom
|
||||||
|
selection formats in order to retrieve geometries in OGC WKT. For all
|
||||||
|
other fields a simple '%s' format string is returned.
|
||||||
|
"""
|
||||||
|
connection = compiler.connection
|
||||||
|
srid = compiler.query.get_context('transformed_srid')
|
||||||
|
if srid:
|
||||||
|
sel_fmt = '%s(%%s, %s)' % (connection.ops.transform, srid)
|
||||||
|
else:
|
||||||
|
sel_fmt = '%s'
|
||||||
|
if connection.ops.select:
|
||||||
|
# This allows operations to be done on fields in the SELECT,
|
||||||
|
# overriding their values -- used by the Oracle and MySQL
|
||||||
|
# spatial backends to get database values as WKT, and by the
|
||||||
|
# `transform` method.
|
||||||
|
sel_fmt = connection.ops.select % sel_fmt
|
||||||
|
return sel_fmt % sql, params
|
||||||
|
|
||||||
|
|
||||||
|
class GeometryField(GeoSelectFormatMixin, Field):
|
||||||
"The base GIS field -- maps to the OpenGIS Specification Geometry type."
|
"The base GIS field -- maps to the OpenGIS Specification Geometry type."
|
||||||
|
|
||||||
# The OpenGIS Geometry name.
|
# The OpenGIS Geometry name.
|
||||||
|
@ -196,7 +219,7 @@ class GeometryField(Field):
|
||||||
else:
|
else:
|
||||||
return geom
|
return geom
|
||||||
|
|
||||||
def from_db_value(self, value, connection):
|
def from_db_value(self, value, connection, context):
|
||||||
if value and not isinstance(value, Geometry):
|
if value and not isinstance(value, Geometry):
|
||||||
value = Geometry(value)
|
value = Geometry(value)
|
||||||
return value
|
return value
|
||||||
|
@ -337,7 +360,7 @@ class GeometryCollectionField(GeometryField):
|
||||||
description = _("Geometry collection")
|
description = _("Geometry collection")
|
||||||
|
|
||||||
|
|
||||||
class ExtentField(Field):
|
class ExtentField(GeoSelectFormatMixin, Field):
|
||||||
"Used as a return value from an extent aggregate"
|
"Used as a return value from an extent aggregate"
|
||||||
|
|
||||||
description = _("Extent Aggregate Field")
|
description = _("Extent Aggregate Field")
|
||||||
|
|
|
@ -1,9 +1,16 @@
|
||||||
from django.db import connections
|
from django.db import connections
|
||||||
|
from django.db.models.expressions import RawSQL
|
||||||
|
from django.db.models.fields import Field
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
|
|
||||||
from django.contrib.gis.db.models import aggregates
|
from django.contrib.gis.db.models import aggregates
|
||||||
from django.contrib.gis.db.models.fields import get_srid_info, PointField, LineStringField
|
from django.contrib.gis.db.models.fields import (
|
||||||
from django.contrib.gis.db.models.sql import AreaField, DistanceField, GeomField, GeoQuery, GMLField
|
get_srid_info, LineStringField, GeometryField, PointField,
|
||||||
|
)
|
||||||
|
from django.contrib.gis.db.models.lookups import GISLookup
|
||||||
|
from django.contrib.gis.db.models.sql import (
|
||||||
|
AreaField, DistanceField, GeomField, GMLField,
|
||||||
|
)
|
||||||
from django.contrib.gis.geometry.backend import Geometry
|
from django.contrib.gis.geometry.backend import Geometry
|
||||||
from django.contrib.gis.measure import Area, Distance
|
from django.contrib.gis.measure import Area, Distance
|
||||||
|
|
||||||
|
@ -13,11 +20,6 @@ from django.utils import six
|
||||||
class GeoQuerySet(QuerySet):
|
class GeoQuerySet(QuerySet):
|
||||||
"The Geographic QuerySet."
|
"The Geographic QuerySet."
|
||||||
|
|
||||||
### Methods overloaded from QuerySet ###
|
|
||||||
def __init__(self, model=None, query=None, using=None, hints=None):
|
|
||||||
super(GeoQuerySet, self).__init__(model=model, query=query, using=using, hints=hints)
|
|
||||||
self.query = query or GeoQuery(self.model)
|
|
||||||
|
|
||||||
### GeoQuerySet Methods ###
|
### GeoQuerySet Methods ###
|
||||||
def area(self, tolerance=0.05, **kwargs):
|
def area(self, tolerance=0.05, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -26,7 +28,8 @@ class GeoQuerySet(QuerySet):
|
||||||
"""
|
"""
|
||||||
# Performing setup here rather than in `_spatial_attribute` so that
|
# Performing setup here rather than in `_spatial_attribute` so that
|
||||||
# we can get the units for `AreaField`.
|
# we can get the units for `AreaField`.
|
||||||
procedure_args, geo_field = self._spatial_setup('area', field_name=kwargs.get('field_name', None))
|
procedure_args, geo_field = self._spatial_setup(
|
||||||
|
'area', field_name=kwargs.get('field_name', None))
|
||||||
s = {'procedure_args': procedure_args,
|
s = {'procedure_args': procedure_args,
|
||||||
'geo_field': geo_field,
|
'geo_field': geo_field,
|
||||||
'setup': False,
|
'setup': False,
|
||||||
|
@ -378,24 +381,8 @@ class GeoQuerySet(QuerySet):
|
||||||
if not isinstance(srid, six.integer_types):
|
if not isinstance(srid, six.integer_types):
|
||||||
raise TypeError('An integer SRID must be provided.')
|
raise TypeError('An integer SRID must be provided.')
|
||||||
field_name = kwargs.get('field_name', None)
|
field_name = kwargs.get('field_name', None)
|
||||||
tmp, geo_field = self._spatial_setup('transform', field_name=field_name)
|
self._spatial_setup('transform', field_name=field_name)
|
||||||
|
self.query.add_context('transformed_srid', srid)
|
||||||
# Getting the selection SQL for the given geographic field.
|
|
||||||
field_col = self._geocol_select(geo_field, field_name)
|
|
||||||
|
|
||||||
# Why cascading substitutions? Because spatial backends like
|
|
||||||
# Oracle and MySQL already require a function call to convert to text, thus
|
|
||||||
# when there's also a transformation we need to cascade the substitutions.
|
|
||||||
# For example, 'SDO_UTIL.TO_WKTGEOMETRY(SDO_CS.TRANSFORM( ... )'
|
|
||||||
geo_col = self.query.custom_select.get(geo_field, field_col)
|
|
||||||
|
|
||||||
# Setting the key for the field's column with the custom SELECT SQL to
|
|
||||||
# override the geometry column returned from the database.
|
|
||||||
custom_sel = '%s(%s, %s)' % (connections[self.db].ops.transform, geo_col, srid)
|
|
||||||
# TODO: Should we have this as an alias?
|
|
||||||
# custom_sel = '(%s(%s, %s)) AS %s' % (SpatialBackend.transform, geo_col, srid, qn(geo_field.name))
|
|
||||||
self.query.transformed_srid = srid # So other GeoQuerySet methods
|
|
||||||
self.query.custom_select[geo_field] = custom_sel
|
|
||||||
return self._clone()
|
return self._clone()
|
||||||
|
|
||||||
def union(self, geom, **kwargs):
|
def union(self, geom, **kwargs):
|
||||||
|
@ -433,7 +420,7 @@ class GeoQuerySet(QuerySet):
|
||||||
|
|
||||||
# Is there a geographic field in the model to perform this
|
# Is there a geographic field in the model to perform this
|
||||||
# operation on?
|
# operation on?
|
||||||
geo_field = self.query._geo_field(field_name)
|
geo_field = self._geo_field(field_name)
|
||||||
if not geo_field:
|
if not geo_field:
|
||||||
raise TypeError('%s output only available on GeometryFields.' % func)
|
raise TypeError('%s output only available on GeometryFields.' % func)
|
||||||
|
|
||||||
|
@ -454,7 +441,7 @@ class GeoQuerySet(QuerySet):
|
||||||
returning their result to the caller of the function.
|
returning their result to the caller of the function.
|
||||||
"""
|
"""
|
||||||
# Getting the field the geographic aggregate will be called on.
|
# Getting the field the geographic aggregate will be called on.
|
||||||
geo_field = self.query._geo_field(field_name)
|
geo_field = self._geo_field(field_name)
|
||||||
if not geo_field:
|
if not geo_field:
|
||||||
raise TypeError('%s aggregate only available on GeometryFields.' % aggregate.name)
|
raise TypeError('%s aggregate only available on GeometryFields.' % aggregate.name)
|
||||||
|
|
||||||
|
@ -509,12 +496,12 @@ class GeoQuerySet(QuerySet):
|
||||||
settings.setdefault('select_params', [])
|
settings.setdefault('select_params', [])
|
||||||
|
|
||||||
connection = connections[self.db]
|
connection = connections[self.db]
|
||||||
backend = connection.ops
|
|
||||||
|
|
||||||
# Performing setup for the spatial column, unless told not to.
|
# Performing setup for the spatial column, unless told not to.
|
||||||
if settings.get('setup', True):
|
if settings.get('setup', True):
|
||||||
default_args, geo_field = self._spatial_setup(att, desc=settings['desc'], field_name=field_name,
|
default_args, geo_field = self._spatial_setup(
|
||||||
geo_field_type=settings.get('geo_field_type', None))
|
att, desc=settings['desc'], field_name=field_name,
|
||||||
|
geo_field_type=settings.get('geo_field_type', None))
|
||||||
for k, v in six.iteritems(default_args):
|
for k, v in six.iteritems(default_args):
|
||||||
settings['procedure_args'].setdefault(k, v)
|
settings['procedure_args'].setdefault(k, v)
|
||||||
else:
|
else:
|
||||||
|
@ -544,18 +531,19 @@ class GeoQuerySet(QuerySet):
|
||||||
|
|
||||||
# If the result of this function needs to be converted.
|
# If the result of this function needs to be converted.
|
||||||
if settings.get('select_field', False):
|
if settings.get('select_field', False):
|
||||||
sel_fld = settings['select_field']
|
select_field = settings['select_field']
|
||||||
if isinstance(sel_fld, GeomField) and backend.select:
|
|
||||||
self.query.custom_select[model_att] = backend.select
|
|
||||||
if connection.ops.oracle:
|
if connection.ops.oracle:
|
||||||
sel_fld.empty_strings_allowed = False
|
select_field.empty_strings_allowed = False
|
||||||
self.query.extra_select_fields[model_att] = sel_fld
|
else:
|
||||||
|
select_field = Field()
|
||||||
|
|
||||||
# Finally, setting the extra selection attribute with
|
# Finally, setting the extra selection attribute with
|
||||||
# the format string expanded with the stored procedure
|
# the format string expanded with the stored procedure
|
||||||
# arguments.
|
# arguments.
|
||||||
return self.extra(select={model_att: fmt % settings['procedure_args']},
|
self.query.add_annotation(
|
||||||
select_params=settings['select_params'])
|
RawSQL(fmt % settings['procedure_args'], settings['select_params'], select_field),
|
||||||
|
model_att)
|
||||||
|
return self
|
||||||
|
|
||||||
def _distance_attribute(self, func, geom=None, tolerance=0.05, spheroid=False, **kwargs):
|
def _distance_attribute(self, func, geom=None, tolerance=0.05, spheroid=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -616,8 +604,9 @@ class GeoQuerySet(QuerySet):
|
||||||
else:
|
else:
|
||||||
# Getting whether this field is in units of degrees since the field may have
|
# Getting whether this field is in units of degrees since the field may have
|
||||||
# been transformed via the `transform` GeoQuerySet method.
|
# been transformed via the `transform` GeoQuerySet method.
|
||||||
if self.query.transformed_srid:
|
srid = self.query.get_context('transformed_srid')
|
||||||
u, unit_name, s = get_srid_info(self.query.transformed_srid, connection)
|
if srid:
|
||||||
|
u, unit_name, s = get_srid_info(srid, connection)
|
||||||
geodetic = unit_name.lower() in geo_field.geodetic_units
|
geodetic = unit_name.lower() in geo_field.geodetic_units
|
||||||
|
|
||||||
if geodetic and not connection.features.supports_distance_geodetic:
|
if geodetic and not connection.features.supports_distance_geodetic:
|
||||||
|
@ -627,20 +616,20 @@ class GeoQuerySet(QuerySet):
|
||||||
)
|
)
|
||||||
|
|
||||||
if distance:
|
if distance:
|
||||||
if self.query.transformed_srid:
|
if srid:
|
||||||
# Setting the `geom_args` flag to false because we want to handle
|
# Setting the `geom_args` flag to false because we want to handle
|
||||||
# transformation SQL here, rather than the way done by default
|
# transformation SQL here, rather than the way done by default
|
||||||
# (which will transform to the original SRID of the field rather
|
# (which will transform to the original SRID of the field rather
|
||||||
# than to what was transformed to).
|
# than to what was transformed to).
|
||||||
geom_args = False
|
geom_args = False
|
||||||
procedure_fmt = '%s(%%(geo_col)s, %s)' % (backend.transform, self.query.transformed_srid)
|
procedure_fmt = '%s(%%(geo_col)s, %s)' % (backend.transform, srid)
|
||||||
if geom.srid is None or geom.srid == self.query.transformed_srid:
|
if geom.srid is None or geom.srid == srid:
|
||||||
# If the geom parameter srid is None, it is assumed the coordinates
|
# If the geom parameter srid is None, it is assumed the coordinates
|
||||||
# are in the transformed units. A placeholder is used for the
|
# are in the transformed units. A placeholder is used for the
|
||||||
# geometry parameter. `GeomFromText` constructor is also needed
|
# geometry parameter. `GeomFromText` constructor is also needed
|
||||||
# to wrap geom placeholder for SpatiaLite.
|
# to wrap geom placeholder for SpatiaLite.
|
||||||
if backend.spatialite:
|
if backend.spatialite:
|
||||||
procedure_fmt += ', %s(%%%%s, %s)' % (backend.from_text, self.query.transformed_srid)
|
procedure_fmt += ', %s(%%%%s, %s)' % (backend.from_text, srid)
|
||||||
else:
|
else:
|
||||||
procedure_fmt += ', %%s'
|
procedure_fmt += ', %%s'
|
||||||
else:
|
else:
|
||||||
|
@ -649,10 +638,11 @@ class GeoQuerySet(QuerySet):
|
||||||
# SpatiaLite also needs geometry placeholder wrapped in `GeomFromText`
|
# SpatiaLite also needs geometry placeholder wrapped in `GeomFromText`
|
||||||
# constructor.
|
# constructor.
|
||||||
if backend.spatialite:
|
if backend.spatialite:
|
||||||
procedure_fmt += ', %s(%s(%%%%s, %s), %s)' % (backend.transform, backend.from_text,
|
procedure_fmt += (', %s(%s(%%%%s, %s), %s)' % (
|
||||||
geom.srid, self.query.transformed_srid)
|
backend.transform, backend.from_text,
|
||||||
|
geom.srid, srid))
|
||||||
else:
|
else:
|
||||||
procedure_fmt += ', %s(%%%%s, %s)' % (backend.transform, self.query.transformed_srid)
|
procedure_fmt += ', %s(%%%%s, %s)' % (backend.transform, srid)
|
||||||
else:
|
else:
|
||||||
# `transform()` was not used on this GeoQuerySet.
|
# `transform()` was not used on this GeoQuerySet.
|
||||||
procedure_fmt = '%(geo_col)s,%(geom)s'
|
procedure_fmt = '%(geo_col)s,%(geom)s'
|
||||||
|
@ -743,22 +733,57 @@ class GeoQuerySet(QuerySet):
|
||||||
column. Takes into account if the geographic field is in a
|
column. Takes into account if the geographic field is in a
|
||||||
ForeignKey relation to the current model.
|
ForeignKey relation to the current model.
|
||||||
"""
|
"""
|
||||||
|
compiler = self.query.get_compiler(self.db)
|
||||||
opts = self.model._meta
|
opts = self.model._meta
|
||||||
if geo_field not in opts.fields:
|
if geo_field not in opts.fields:
|
||||||
# Is this operation going to be on a related geographic field?
|
# Is this operation going to be on a related geographic field?
|
||||||
# If so, it'll have to be added to the select related information
|
# If so, it'll have to be added to the select related information
|
||||||
# (e.g., if 'location__point' was given as the field name).
|
# (e.g., if 'location__point' was given as the field name).
|
||||||
|
# Note: the operation really is defined as "must add select related!"
|
||||||
self.query.add_select_related([field_name])
|
self.query.add_select_related([field_name])
|
||||||
compiler = self.query.get_compiler(self.db)
|
# Call pre_sql_setup() so that compiler.select gets populated.
|
||||||
compiler.pre_sql_setup()
|
compiler.pre_sql_setup()
|
||||||
for (rel_table, rel_col), field in self.query.related_select_cols:
|
for col, _, _ in compiler.select:
|
||||||
if field == geo_field:
|
if col.output_field == geo_field:
|
||||||
return compiler._field_column(geo_field, rel_table)
|
return col.as_sql(compiler, compiler.connection)[0]
|
||||||
raise ValueError("%r not in self.query.related_select_cols" % geo_field)
|
raise ValueError("%r not in compiler's related_select_cols" % geo_field)
|
||||||
elif geo_field not in opts.local_fields:
|
elif geo_field not in opts.local_fields:
|
||||||
# This geographic field is inherited from another model, so we have to
|
# This geographic field is inherited from another model, so we have to
|
||||||
# use the db table for the _parent_ model instead.
|
# use the db table for the _parent_ model instead.
|
||||||
parent_model = geo_field.model._meta.concrete_model
|
parent_model = geo_field.model._meta.concrete_model
|
||||||
return self.query.get_compiler(self.db)._field_column(geo_field, parent_model._meta.db_table)
|
return self._field_column(compiler, geo_field, parent_model._meta.db_table)
|
||||||
else:
|
else:
|
||||||
return self.query.get_compiler(self.db)._field_column(geo_field)
|
return self._field_column(compiler, geo_field)
|
||||||
|
|
||||||
|
# Private API utilities, subject to change.
|
||||||
|
def _geo_field(self, field_name=None):
|
||||||
|
"""
|
||||||
|
Returns the first Geometry field encountered or the one specified via
|
||||||
|
the `field_name` keyword. The `field_name` may be a string specifying
|
||||||
|
the geometry field on this GeoQuerySet's model, or a lookup string
|
||||||
|
to a geometry field via a ForeignKey relation.
|
||||||
|
"""
|
||||||
|
if field_name is None:
|
||||||
|
# Incrementing until the first geographic field is found.
|
||||||
|
for field in self.model._meta.fields:
|
||||||
|
if isinstance(field, GeometryField):
|
||||||
|
return field
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# Otherwise, check by the given field name -- which may be
|
||||||
|
# a lookup to a _related_ geographic field.
|
||||||
|
return GISLookup._check_geo_field(self.model._meta, field_name)
|
||||||
|
|
||||||
|
def _field_column(self, compiler, field, table_alias=None, column=None):
|
||||||
|
"""
|
||||||
|
Helper function that returns the database column for the given field.
|
||||||
|
The table and column are returned (quoted) in the proper format, e.g.,
|
||||||
|
`"geoapp_city"."point"`. If `table_alias` is not specified, the
|
||||||
|
database table associated with the model of this `GeoQuerySet` will be
|
||||||
|
used. If `column` is specified, it will be used instead of the value
|
||||||
|
in `field.column`.
|
||||||
|
"""
|
||||||
|
if table_alias is None:
|
||||||
|
table_alias = compiler.query.get_meta().db_table
|
||||||
|
return "%s.%s" % (compiler.quote_name_unless_alias(table_alias),
|
||||||
|
compiler.connection.ops.quote_name(column or field.column))
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField, GMLField
|
from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField, GMLField
|
||||||
from django.contrib.gis.db.models.sql.query import GeoQuery
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AreaField', 'DistanceField', 'GeomField', 'GMLField', 'GeoQuery',
|
'AreaField', 'DistanceField', 'GeomField', 'GMLField'
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,240 +0,0 @@
|
||||||
from django.db.backends.utils import truncate_name
|
|
||||||
from django.db.models.sql import compiler
|
|
||||||
from django.utils import six
|
|
||||||
|
|
||||||
SQLCompiler = compiler.SQLCompiler
|
|
||||||
|
|
||||||
|
|
||||||
class GeoSQLCompiler(compiler.SQLCompiler):
|
|
||||||
|
|
||||||
def get_columns(self, with_aliases=False):
|
|
||||||
"""
|
|
||||||
Return the list of columns to use in the select statement. If no
|
|
||||||
columns have been specified, returns all columns relating to fields in
|
|
||||||
the model.
|
|
||||||
|
|
||||||
If 'with_aliases' is true, any column names that are duplicated
|
|
||||||
(without the table names) are given unique aliases. This is needed in
|
|
||||||
some cases to avoid ambiguity with nested queries.
|
|
||||||
|
|
||||||
This routine is overridden from Query to handle customized selection of
|
|
||||||
geometry columns.
|
|
||||||
"""
|
|
||||||
qn = self.quote_name_unless_alias
|
|
||||||
qn2 = self.connection.ops.quote_name
|
|
||||||
result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
|
|
||||||
for alias, col in six.iteritems(self.query.extra_select)]
|
|
||||||
params = []
|
|
||||||
aliases = set(self.query.extra_select.keys())
|
|
||||||
if with_aliases:
|
|
||||||
col_aliases = aliases.copy()
|
|
||||||
else:
|
|
||||||
col_aliases = set()
|
|
||||||
if self.query.select:
|
|
||||||
only_load = self.deferred_to_columns()
|
|
||||||
# This loop customized for GeoQuery.
|
|
||||||
for col, field in self.query.select:
|
|
||||||
if isinstance(col, (list, tuple)):
|
|
||||||
alias, column = col
|
|
||||||
table = self.query.alias_map[alias].table_name
|
|
||||||
if table in only_load and column not in only_load[table]:
|
|
||||||
continue
|
|
||||||
r = self.get_field_select(field, alias, column)
|
|
||||||
if with_aliases:
|
|
||||||
if col[1] in col_aliases:
|
|
||||||
c_alias = 'Col%d' % len(col_aliases)
|
|
||||||
result.append('%s AS %s' % (r, c_alias))
|
|
||||||
aliases.add(c_alias)
|
|
||||||
col_aliases.add(c_alias)
|
|
||||||
else:
|
|
||||||
result.append('%s AS %s' % (r, qn2(col[1])))
|
|
||||||
aliases.add(r)
|
|
||||||
col_aliases.add(col[1])
|
|
||||||
else:
|
|
||||||
result.append(r)
|
|
||||||
aliases.add(r)
|
|
||||||
col_aliases.add(col[1])
|
|
||||||
else:
|
|
||||||
col_sql, col_params = col.as_sql(self, self.connection)
|
|
||||||
result.append(col_sql)
|
|
||||||
params.extend(col_params)
|
|
||||||
|
|
||||||
if hasattr(col, 'alias'):
|
|
||||||
aliases.add(col.alias)
|
|
||||||
col_aliases.add(col.alias)
|
|
||||||
|
|
||||||
elif self.query.default_cols:
|
|
||||||
cols, new_aliases = self.get_default_columns(with_aliases,
|
|
||||||
col_aliases)
|
|
||||||
result.extend(cols)
|
|
||||||
aliases.update(new_aliases)
|
|
||||||
|
|
||||||
max_name_length = self.connection.ops.max_name_length()
|
|
||||||
for alias, annotation in self.query.annotation_select.items():
|
|
||||||
agg_sql, agg_params = self.compile(annotation)
|
|
||||||
if alias is None:
|
|
||||||
result.append(agg_sql)
|
|
||||||
else:
|
|
||||||
result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
|
|
||||||
params.extend(agg_params)
|
|
||||||
|
|
||||||
# This loop customized for GeoQuery.
|
|
||||||
for (table, col), field in self.query.related_select_cols:
|
|
||||||
r = self.get_field_select(field, table, col)
|
|
||||||
if with_aliases and col in col_aliases:
|
|
||||||
c_alias = 'Col%d' % len(col_aliases)
|
|
||||||
result.append('%s AS %s' % (r, c_alias))
|
|
||||||
aliases.add(c_alias)
|
|
||||||
col_aliases.add(c_alias)
|
|
||||||
else:
|
|
||||||
result.append(r)
|
|
||||||
aliases.add(r)
|
|
||||||
col_aliases.add(col)
|
|
||||||
|
|
||||||
self._select_aliases = aliases
|
|
||||||
return result, params
|
|
||||||
|
|
||||||
def get_default_columns(self, with_aliases=False, col_aliases=None,
|
|
||||||
start_alias=None, opts=None, as_pairs=False, from_parent=None):
|
|
||||||
"""
|
|
||||||
Computes the default columns for selecting every field in the base
|
|
||||||
model. Will sometimes be called to pull in related models (e.g. via
|
|
||||||
select_related), in which case "opts" and "start_alias" will be given
|
|
||||||
to provide a starting point for the traversal.
|
|
||||||
|
|
||||||
Returns a list of strings, quoted appropriately for use in SQL
|
|
||||||
directly, as well as a set of aliases used in the select statement (if
|
|
||||||
'as_pairs' is True, returns a list of (alias, col_name) pairs instead
|
|
||||||
of strings as the first component and None as the second component).
|
|
||||||
|
|
||||||
This routine is overridden from Query to handle customized selection of
|
|
||||||
geometry columns.
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
if opts is None:
|
|
||||||
opts = self.query.get_meta()
|
|
||||||
aliases = set()
|
|
||||||
only_load = self.deferred_to_columns()
|
|
||||||
seen = self.query.included_inherited_models.copy()
|
|
||||||
if start_alias:
|
|
||||||
seen[None] = start_alias
|
|
||||||
for field in opts.concrete_fields:
|
|
||||||
model = field.model._meta.concrete_model
|
|
||||||
if model is opts.model:
|
|
||||||
model = None
|
|
||||||
if from_parent and model is not None and issubclass(from_parent, model):
|
|
||||||
# Avoid loading data for already loaded parents.
|
|
||||||
continue
|
|
||||||
alias = self.query.join_parent_model(opts, model, start_alias, seen)
|
|
||||||
table = self.query.alias_map[alias].table_name
|
|
||||||
if table in only_load and field.column not in only_load[table]:
|
|
||||||
continue
|
|
||||||
if as_pairs:
|
|
||||||
result.append((alias, field))
|
|
||||||
aliases.add(alias)
|
|
||||||
continue
|
|
||||||
# This part of the function is customized for GeoQuery. We
|
|
||||||
# see if there was any custom selection specified in the
|
|
||||||
# dictionary, and set up the selection format appropriately.
|
|
||||||
field_sel = self.get_field_select(field, alias)
|
|
||||||
if with_aliases and field.column in col_aliases:
|
|
||||||
c_alias = 'Col%d' % len(col_aliases)
|
|
||||||
result.append('%s AS %s' % (field_sel, c_alias))
|
|
||||||
col_aliases.add(c_alias)
|
|
||||||
aliases.add(c_alias)
|
|
||||||
else:
|
|
||||||
r = field_sel
|
|
||||||
result.append(r)
|
|
||||||
aliases.add(r)
|
|
||||||
if with_aliases:
|
|
||||||
col_aliases.add(field.column)
|
|
||||||
return result, aliases
|
|
||||||
|
|
||||||
def get_converters(self, fields):
|
|
||||||
converters = super(GeoSQLCompiler, self).get_converters(fields)
|
|
||||||
for i, alias in enumerate(self.query.extra_select):
|
|
||||||
field = self.query.extra_select_fields.get(alias)
|
|
||||||
if field:
|
|
||||||
backend_converters = self.connection.ops.get_db_converters(field.get_internal_type())
|
|
||||||
converters[i] = (backend_converters, [field.from_db_value], field)
|
|
||||||
return converters
|
|
||||||
|
|
||||||
#### Routines unique to GeoQuery ####
|
|
||||||
def get_extra_select_format(self, alias):
|
|
||||||
sel_fmt = '%s'
|
|
||||||
if hasattr(self.query, 'custom_select') and alias in self.query.custom_select:
|
|
||||||
sel_fmt = sel_fmt % self.query.custom_select[alias]
|
|
||||||
return sel_fmt
|
|
||||||
|
|
||||||
def get_field_select(self, field, alias=None, column=None):
|
|
||||||
"""
|
|
||||||
Returns the SELECT SQL string for the given field. Figures out
|
|
||||||
if any custom selection SQL is needed for the column The `alias`
|
|
||||||
keyword may be used to manually specify the database table where
|
|
||||||
the column exists, if not in the model associated with this
|
|
||||||
`GeoQuery`. Similarly, `column` may be used to specify the exact
|
|
||||||
column name, rather than using the `column` attribute on `field`.
|
|
||||||
"""
|
|
||||||
sel_fmt = self.get_select_format(field)
|
|
||||||
if field in self.query.custom_select:
|
|
||||||
field_sel = sel_fmt % self.query.custom_select[field]
|
|
||||||
else:
|
|
||||||
field_sel = sel_fmt % self._field_column(field, alias, column)
|
|
||||||
return field_sel
|
|
||||||
|
|
||||||
def get_select_format(self, fld):
|
|
||||||
"""
|
|
||||||
Returns the selection format string, depending on the requirements
|
|
||||||
of the spatial backend. For example, Oracle and MySQL require custom
|
|
||||||
selection formats in order to retrieve geometries in OGC WKT. For all
|
|
||||||
other fields a simple '%s' format string is returned.
|
|
||||||
"""
|
|
||||||
if self.connection.ops.select and hasattr(fld, 'geom_type'):
|
|
||||||
# This allows operations to be done on fields in the SELECT,
|
|
||||||
# overriding their values -- used by the Oracle and MySQL
|
|
||||||
# spatial backends to get database values as WKT, and by the
|
|
||||||
# `transform` method.
|
|
||||||
sel_fmt = self.connection.ops.select
|
|
||||||
|
|
||||||
# Because WKT doesn't contain spatial reference information,
|
|
||||||
# the SRID is prefixed to the returned WKT to ensure that the
|
|
||||||
# transformed geometries have an SRID different than that of the
|
|
||||||
# field -- this is only used by `transform` for Oracle and
|
|
||||||
# SpatiaLite backends.
|
|
||||||
if self.query.transformed_srid and (self.connection.ops.oracle or
|
|
||||||
self.connection.ops.spatialite):
|
|
||||||
sel_fmt = "'SRID=%d;'||%s" % (self.query.transformed_srid, sel_fmt)
|
|
||||||
else:
|
|
||||||
sel_fmt = '%s'
|
|
||||||
return sel_fmt
|
|
||||||
|
|
||||||
# Private API utilities, subject to change.
|
|
||||||
def _field_column(self, field, table_alias=None, column=None):
|
|
||||||
"""
|
|
||||||
Helper function that returns the database column for the given field.
|
|
||||||
The table and column are returned (quoted) in the proper format, e.g.,
|
|
||||||
`"geoapp_city"."point"`. If `table_alias` is not specified, the
|
|
||||||
database table associated with the model of this `GeoQuery` will be
|
|
||||||
used. If `column` is specified, it will be used instead of the value
|
|
||||||
in `field.column`.
|
|
||||||
"""
|
|
||||||
if table_alias is None:
|
|
||||||
table_alias = self.query.get_meta().db_table
|
|
||||||
return "%s.%s" % (self.quote_name_unless_alias(table_alias),
|
|
||||||
self.connection.ops.quote_name(column or field.column))
|
|
||||||
|
|
||||||
|
|
||||||
class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
|
|
||||||
pass
|
|
|
@ -3,6 +3,7 @@ This module holds simple classes to convert geospatial values from the
|
||||||
database.
|
database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from django.contrib.gis.db.models.fields import GeoSelectFormatMixin
|
||||||
from django.contrib.gis.geometry.backend import Geometry
|
from django.contrib.gis.geometry.backend import Geometry
|
||||||
from django.contrib.gis.measure import Area, Distance
|
from django.contrib.gis.measure import Area, Distance
|
||||||
|
|
||||||
|
@ -10,13 +11,19 @@ from django.contrib.gis.measure import Area, Distance
|
||||||
class BaseField(object):
|
class BaseField(object):
|
||||||
empty_strings_allowed = True
|
empty_strings_allowed = True
|
||||||
|
|
||||||
|
def get_db_converters(self, connection):
|
||||||
|
return [self.from_db_value]
|
||||||
|
|
||||||
|
def select_format(self, compiler, sql, params):
|
||||||
|
return sql, params
|
||||||
|
|
||||||
|
|
||||||
class AreaField(BaseField):
|
class AreaField(BaseField):
|
||||||
"Wrapper for Area values."
|
"Wrapper for Area values."
|
||||||
def __init__(self, area_att):
|
def __init__(self, area_att):
|
||||||
self.area_att = area_att
|
self.area_att = area_att
|
||||||
|
|
||||||
def from_db_value(self, value, connection):
|
def from_db_value(self, value, connection, context):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = Area(**{self.area_att: value})
|
value = Area(**{self.area_att: value})
|
||||||
return value
|
return value
|
||||||
|
@ -30,7 +37,7 @@ class DistanceField(BaseField):
|
||||||
def __init__(self, distance_att):
|
def __init__(self, distance_att):
|
||||||
self.distance_att = distance_att
|
self.distance_att = distance_att
|
||||||
|
|
||||||
def from_db_value(self, value, connection):
|
def from_db_value(self, value, connection, context):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = Distance(**{self.distance_att: value})
|
value = Distance(**{self.distance_att: value})
|
||||||
return value
|
return value
|
||||||
|
@ -39,12 +46,15 @@ class DistanceField(BaseField):
|
||||||
return 'DistanceField'
|
return 'DistanceField'
|
||||||
|
|
||||||
|
|
||||||
class GeomField(BaseField):
|
class GeomField(GeoSelectFormatMixin, BaseField):
|
||||||
"""
|
"""
|
||||||
Wrapper for Geometry values. It is a lightweight alternative to
|
Wrapper for Geometry values. It is a lightweight alternative to
|
||||||
using GeometryField (which requires an SQL query upon instantiation).
|
using GeometryField (which requires an SQL query upon instantiation).
|
||||||
"""
|
"""
|
||||||
def from_db_value(self, value, connection):
|
# Hacky marker for get_db_converters()
|
||||||
|
geom_type = None
|
||||||
|
|
||||||
|
def from_db_value(self, value, connection, context):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = Geometry(value)
|
value = Geometry(value)
|
||||||
return value
|
return value
|
||||||
|
@ -61,5 +71,5 @@ class GMLField(BaseField):
|
||||||
def get_internal_type(self):
|
def get_internal_type(self):
|
||||||
return 'GMLField'
|
return 'GMLField'
|
||||||
|
|
||||||
def from_db_value(self, value, connection):
|
def from_db_value(self, value, connection, context):
|
||||||
return value
|
return value
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
from django.db import connections
|
|
||||||
from django.db.models.query import sql
|
|
||||||
from django.db.models.sql.constants import QUERY_TERMS
|
|
||||||
|
|
||||||
from django.contrib.gis.db.models.fields import GeometryField
|
|
||||||
from django.contrib.gis.db.models.lookups import GISLookup
|
|
||||||
from django.contrib.gis.db.models import aggregates as gis_aggregates
|
|
||||||
from django.contrib.gis.db.models.sql.conversion import GeomField
|
|
||||||
|
|
||||||
|
|
||||||
class GeoQuery(sql.Query):
|
|
||||||
"""
|
|
||||||
A single spatial SQL query.
|
|
||||||
"""
|
|
||||||
# Overriding the valid query terms.
|
|
||||||
query_terms = QUERY_TERMS | set(GeometryField.class_lookups.keys())
|
|
||||||
|
|
||||||
compiler = 'GeoSQLCompiler'
|
|
||||||
|
|
||||||
#### Methods overridden from the base Query class ####
|
|
||||||
def __init__(self, model):
|
|
||||||
super(GeoQuery, self).__init__(model)
|
|
||||||
# The following attributes are customized for the GeoQuerySet.
|
|
||||||
# The SpatialBackend classes contain backend-specific routines and functions.
|
|
||||||
self.custom_select = {}
|
|
||||||
self.transformed_srid = None
|
|
||||||
self.extra_select_fields = {}
|
|
||||||
|
|
||||||
def clone(self, *args, **kwargs):
|
|
||||||
obj = super(GeoQuery, self).clone(*args, **kwargs)
|
|
||||||
# Customized selection dictionary and transformed srid flag have
|
|
||||||
# to also be added to obj.
|
|
||||||
obj.custom_select = self.custom_select.copy()
|
|
||||||
obj.transformed_srid = self.transformed_srid
|
|
||||||
obj.extra_select_fields = self.extra_select_fields.copy()
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def get_aggregation(self, using, force_subq=False):
|
|
||||||
# Remove any aggregates marked for reduction from the subquery
|
|
||||||
# and move them to the outer AggregateQuery.
|
|
||||||
connection = connections[using]
|
|
||||||
for alias, annotation in self.annotation_select.items():
|
|
||||||
if isinstance(annotation, gis_aggregates.GeoAggregate):
|
|
||||||
if not getattr(annotation, 'is_extent', False) or connection.ops.oracle:
|
|
||||||
self.extra_select_fields[alias] = GeomField()
|
|
||||||
return super(GeoQuery, self).get_aggregation(using, force_subq)
|
|
||||||
|
|
||||||
# Private API utilities, subject to change.
|
|
||||||
def _geo_field(self, field_name=None):
|
|
||||||
"""
|
|
||||||
Returns the first Geometry field encountered; or specified via the
|
|
||||||
`field_name` keyword. The `field_name` may be a string specifying
|
|
||||||
the geometry field on this GeoQuery's model, or a lookup string
|
|
||||||
to a geometry field via a ForeignKey relation.
|
|
||||||
"""
|
|
||||||
if field_name is None:
|
|
||||||
# Incrementing until the first geographic field is found.
|
|
||||||
for fld in self.model._meta.fields:
|
|
||||||
if isinstance(fld, GeometryField):
|
|
||||||
return fld
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
# Otherwise, check by the given field name -- which may be
|
|
||||||
# a lookup to a _related_ geographic field.
|
|
||||||
return GISLookup._check_geo_field(self.model._meta, field_name)
|
|
|
@ -231,15 +231,6 @@ class RelatedGeoModelTest(TestCase):
|
||||||
self.assertIn('Aurora', names)
|
self.assertIn('Aurora', names)
|
||||||
self.assertIn('Kecksburg', names)
|
self.assertIn('Kecksburg', names)
|
||||||
|
|
||||||
def test11_geoquery_pickle(self):
|
|
||||||
"Ensuring GeoQuery objects are unpickled correctly. See #10839."
|
|
||||||
import pickle
|
|
||||||
from django.contrib.gis.db.models.sql import GeoQuery
|
|
||||||
qs = City.objects.all()
|
|
||||||
q_str = pickle.dumps(qs.query)
|
|
||||||
q = pickle.loads(q_str)
|
|
||||||
self.assertEqual(GeoQuery, q.__class__)
|
|
||||||
|
|
||||||
# TODO: fix on Oracle -- get the following error because the SQL is ordered
|
# TODO: fix on Oracle -- get the following error because the SQL is ordered
|
||||||
# by a geometry object, which Oracle apparently doesn't like:
|
# by a geometry object, which Oracle apparently doesn't like:
|
||||||
# ORA-22901: cannot compare nested table or VARRAY or LOB attributes of an object type
|
# ORA-22901: cannot compare nested table or VARRAY or LOB attributes of an object type
|
||||||
|
|
|
@ -1262,7 +1262,7 @@ class BaseDatabaseOperations(object):
|
||||||
second = timezone.make_aware(second, tz)
|
second = timezone.make_aware(second, tz)
|
||||||
return [first, second]
|
return [first, second]
|
||||||
|
|
||||||
def get_db_converters(self, internal_type):
|
def get_db_converters(self, expression):
|
||||||
"""Get a list of functions needed to convert field data.
|
"""Get a list of functions needed to convert field data.
|
||||||
|
|
||||||
Some field types on some backends do not provide data in the correct
|
Some field types on some backends do not provide data in the correct
|
||||||
|
@ -1270,7 +1270,7 @@ class BaseDatabaseOperations(object):
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def convert_durationfield_value(self, value, field):
|
def convert_durationfield_value(self, value, expression, context):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = str(decimal.Decimal(value) / decimal.Decimal(1000000))
|
value = str(decimal.Decimal(value) / decimal.Decimal(1000000))
|
||||||
value = parse_duration(value)
|
value = parse_duration(value)
|
||||||
|
|
|
@ -302,7 +302,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
columns. If no ordering would otherwise be applied, we don't want any
|
columns. If no ordering would otherwise be applied, we don't want any
|
||||||
implicit sorting going on.
|
implicit sorting going on.
|
||||||
"""
|
"""
|
||||||
return ["NULL"]
|
return [(None, ("NULL", [], 'asc', False))]
|
||||||
|
|
||||||
def fulltext_search_sql(self, field_name):
|
def fulltext_search_sql(self, field_name):
|
||||||
return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name
|
return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name
|
||||||
|
@ -387,8 +387,9 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
return 'POW(%s)' % ','.join(sub_expressions)
|
return 'POW(%s)' % ','.join(sub_expressions)
|
||||||
return super(DatabaseOperations, self).combine_expression(connector, sub_expressions)
|
return super(DatabaseOperations, self).combine_expression(connector, sub_expressions)
|
||||||
|
|
||||||
def get_db_converters(self, internal_type):
|
def get_db_converters(self, expression):
|
||||||
converters = super(DatabaseOperations, self).get_db_converters(internal_type)
|
converters = super(DatabaseOperations, self).get_db_converters(expression)
|
||||||
|
internal_type = expression.output_field.get_internal_type()
|
||||||
if internal_type in ['BooleanField', 'NullBooleanField']:
|
if internal_type in ['BooleanField', 'NullBooleanField']:
|
||||||
converters.append(self.convert_booleanfield_value)
|
converters.append(self.convert_booleanfield_value)
|
||||||
if internal_type == 'UUIDField':
|
if internal_type == 'UUIDField':
|
||||||
|
@ -397,17 +398,17 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
converters.append(self.convert_textfield_value)
|
converters.append(self.convert_textfield_value)
|
||||||
return converters
|
return converters
|
||||||
|
|
||||||
def convert_booleanfield_value(self, value, field):
|
def convert_booleanfield_value(self, value, expression, context):
|
||||||
if value in (0, 1):
|
if value in (0, 1):
|
||||||
value = bool(value)
|
value = bool(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_uuidfield_value(self, value, field):
|
def convert_uuidfield_value(self, value, expression, context):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = uuid.UUID(value)
|
value = uuid.UUID(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_textfield_value(self, value, field):
|
def convert_textfield_value(self, value, expression, context):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = force_text(value)
|
value = force_text(value)
|
||||||
return value
|
return value
|
||||||
|
|
|
@ -268,8 +268,9 @@ WHEN (new.%(col_name)s IS NULL)
|
||||||
sql = field_name # Cast to DATE removes sub-second precision.
|
sql = field_name # Cast to DATE removes sub-second precision.
|
||||||
return sql, []
|
return sql, []
|
||||||
|
|
||||||
def get_db_converters(self, internal_type):
|
def get_db_converters(self, expression):
|
||||||
converters = super(DatabaseOperations, self).get_db_converters(internal_type)
|
converters = super(DatabaseOperations, self).get_db_converters(expression)
|
||||||
|
internal_type = expression.output_field.get_internal_type()
|
||||||
if internal_type == 'TextField':
|
if internal_type == 'TextField':
|
||||||
converters.append(self.convert_textfield_value)
|
converters.append(self.convert_textfield_value)
|
||||||
elif internal_type == 'BinaryField':
|
elif internal_type == 'BinaryField':
|
||||||
|
@ -285,28 +286,29 @@ WHEN (new.%(col_name)s IS NULL)
|
||||||
converters.append(self.convert_empty_values)
|
converters.append(self.convert_empty_values)
|
||||||
return converters
|
return converters
|
||||||
|
|
||||||
def convert_empty_values(self, value, field):
|
def convert_empty_values(self, value, expression, context):
|
||||||
# Oracle stores empty strings as null. We need to undo this in
|
# Oracle stores empty strings as null. We need to undo this in
|
||||||
# order to adhere to the Django convention of using the empty
|
# order to adhere to the Django convention of using the empty
|
||||||
# string instead of null, but only if the field accepts the
|
# string instead of null, but only if the field accepts the
|
||||||
# empty string.
|
# empty string.
|
||||||
|
field = expression.output_field
|
||||||
if value is None and field.empty_strings_allowed:
|
if value is None and field.empty_strings_allowed:
|
||||||
value = ''
|
value = ''
|
||||||
if field.get_internal_type() == 'BinaryField':
|
if field.get_internal_type() == 'BinaryField':
|
||||||
value = b''
|
value = b''
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_textfield_value(self, value, field):
|
def convert_textfield_value(self, value, expression, context):
|
||||||
if isinstance(value, Database.LOB):
|
if isinstance(value, Database.LOB):
|
||||||
value = force_text(value.read())
|
value = force_text(value.read())
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_binaryfield_value(self, value, field):
|
def convert_binaryfield_value(self, value, expression, context):
|
||||||
if isinstance(value, Database.LOB):
|
if isinstance(value, Database.LOB):
|
||||||
value = force_bytes(value.read())
|
value = force_bytes(value.read())
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_booleanfield_value(self, value, field):
|
def convert_booleanfield_value(self, value, expression, context):
|
||||||
if value in (1, 0):
|
if value in (1, 0):
|
||||||
value = bool(value)
|
value = bool(value)
|
||||||
return value
|
return value
|
||||||
|
@ -314,16 +316,16 @@ WHEN (new.%(col_name)s IS NULL)
|
||||||
# cx_Oracle always returns datetime.datetime objects for
|
# cx_Oracle always returns datetime.datetime objects for
|
||||||
# DATE and TIMESTAMP columns, but Django wants to see a
|
# DATE and TIMESTAMP columns, but Django wants to see a
|
||||||
# python datetime.date, .time, or .datetime.
|
# python datetime.date, .time, or .datetime.
|
||||||
def convert_datefield_value(self, value, field):
|
def convert_datefield_value(self, value, expression, context):
|
||||||
if isinstance(value, Database.Timestamp):
|
if isinstance(value, Database.Timestamp):
|
||||||
return value.date()
|
return value.date()
|
||||||
|
|
||||||
def convert_timefield_value(self, value, field):
|
def convert_timefield_value(self, value, expression, context):
|
||||||
if isinstance(value, Database.Timestamp):
|
if isinstance(value, Database.Timestamp):
|
||||||
value = value.time()
|
value = value.time()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_uuidfield_value(self, value, field):
|
def convert_uuidfield_value(self, value, expression, context):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = uuid.UUID(value)
|
value = uuid.UUID(value)
|
||||||
return value
|
return value
|
||||||
|
|
|
@ -269,8 +269,9 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
|
|
||||||
return six.text_type(value)
|
return six.text_type(value)
|
||||||
|
|
||||||
def get_db_converters(self, internal_type):
|
def get_db_converters(self, expression):
|
||||||
converters = super(DatabaseOperations, self).get_db_converters(internal_type)
|
converters = super(DatabaseOperations, self).get_db_converters(expression)
|
||||||
|
internal_type = expression.output_field.get_internal_type()
|
||||||
if internal_type == 'DateTimeField':
|
if internal_type == 'DateTimeField':
|
||||||
converters.append(self.convert_datetimefield_value)
|
converters.append(self.convert_datetimefield_value)
|
||||||
elif internal_type == 'DateField':
|
elif internal_type == 'DateField':
|
||||||
|
@ -283,25 +284,25 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||||
converters.append(self.convert_uuidfield_value)
|
converters.append(self.convert_uuidfield_value)
|
||||||
return converters
|
return converters
|
||||||
|
|
||||||
def convert_decimalfield_value(self, value, field):
|
def convert_decimalfield_value(self, value, expression, context):
|
||||||
return backend_utils.typecast_decimal(field.format_number(value))
|
return backend_utils.typecast_decimal(expression.output_field.format_number(value))
|
||||||
|
|
||||||
def convert_datefield_value(self, value, field):
|
def convert_datefield_value(self, value, expression, context):
|
||||||
if value is not None and not isinstance(value, datetime.date):
|
if value is not None and not isinstance(value, datetime.date):
|
||||||
value = parse_date(value)
|
value = parse_date(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_datetimefield_value(self, value, field):
|
def convert_datetimefield_value(self, value, expression, context):
|
||||||
if value is not None and not isinstance(value, datetime.datetime):
|
if value is not None and not isinstance(value, datetime.datetime):
|
||||||
value = parse_datetime_with_timezone_support(value)
|
value = parse_datetime_with_timezone_support(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_timefield_value(self, value, field):
|
def convert_timefield_value(self, value, expression, context):
|
||||||
if value is not None and not isinstance(value, datetime.time):
|
if value is not None and not isinstance(value, datetime.time):
|
||||||
value = parse_time(value)
|
value = parse_time(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def convert_uuidfield_value(self, value, field):
|
def convert_uuidfield_value(self, value, expression, context):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = uuid.UUID(value)
|
value = uuid.UUID(value)
|
||||||
return value
|
return value
|
||||||
|
|
|
@ -87,7 +87,7 @@ class Avg(Aggregate):
|
||||||
def __init__(self, expression, **extra):
|
def __init__(self, expression, **extra):
|
||||||
super(Avg, self).__init__(expression, output_field=FloatField(), **extra)
|
super(Avg, self).__init__(expression, output_field=FloatField(), **extra)
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
return float(value)
|
return float(value)
|
||||||
|
@ -105,7 +105,7 @@ class Count(Aggregate):
|
||||||
super(Count, self).__init__(
|
super(Count, self).__init__(
|
||||||
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
|
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
if value is None:
|
if value is None:
|
||||||
return 0
|
return 0
|
||||||
return int(value)
|
return int(value)
|
||||||
|
@ -128,7 +128,7 @@ class StdDev(Aggregate):
|
||||||
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
|
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
|
||||||
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
|
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
return float(value)
|
return float(value)
|
||||||
|
@ -146,7 +146,7 @@ class Variance(Aggregate):
|
||||||
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
|
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
|
||||||
super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
|
super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
return float(value)
|
return float(value)
|
||||||
|
|
|
@ -127,7 +127,7 @@ class ExpressionNode(CombinableMixin):
|
||||||
is_summary = False
|
is_summary = False
|
||||||
|
|
||||||
def get_db_converters(self, connection):
|
def get_db_converters(self, connection):
|
||||||
return [self.convert_value]
|
return [self.convert_value] + self.output_field.get_db_converters(connection)
|
||||||
|
|
||||||
def __init__(self, output_field=None):
|
def __init__(self, output_field=None):
|
||||||
self._output_field = output_field
|
self._output_field = output_field
|
||||||
|
@ -240,7 +240,7 @@ class ExpressionNode(CombinableMixin):
|
||||||
raise FieldError(
|
raise FieldError(
|
||||||
"Expression contains mixed types. You must set output_field")
|
"Expression contains mixed types. You must set output_field")
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
"""
|
"""
|
||||||
Expressions provide their own converters because users have the option
|
Expressions provide their own converters because users have the option
|
||||||
of manually specifying the output_field which may be a different type
|
of manually specifying the output_field which may be a different type
|
||||||
|
@ -305,6 +305,8 @@ class ExpressionNode(CombinableMixin):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self):
|
||||||
|
if not self.contains_aggregate:
|
||||||
|
return [self]
|
||||||
cols = []
|
cols = []
|
||||||
for source in self.get_source_expressions():
|
for source in self.get_source_expressions():
|
||||||
cols.extend(source.get_group_by_cols())
|
cols.extend(source.get_group_by_cols())
|
||||||
|
@ -490,6 +492,9 @@ class Value(ExpressionNode):
|
||||||
return 'NULL', []
|
return 'NULL', []
|
||||||
return '%s', [self.value]
|
return '%s', [self.value]
|
||||||
|
|
||||||
|
def get_group_by_cols(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class DurationValue(Value):
|
class DurationValue(Value):
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
|
@ -499,6 +504,37 @@ class DurationValue(Value):
|
||||||
return connection.ops.date_interval_sql(self.value)
|
return connection.ops.date_interval_sql(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
class RawSQL(ExpressionNode):
|
||||||
|
def __init__(self, sql, params, output_field=None):
|
||||||
|
if output_field is None:
|
||||||
|
output_field = fields.Field()
|
||||||
|
self.sql, self.params = sql, params
|
||||||
|
super(RawSQL, self).__init__(output_field=output_field)
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
return '(%s)' % self.sql, self.params
|
||||||
|
|
||||||
|
def get_group_by_cols(self):
|
||||||
|
return [self]
|
||||||
|
|
||||||
|
|
||||||
|
class Random(ExpressionNode):
|
||||||
|
def __init__(self):
|
||||||
|
super(Random, self).__init__(output_field=fields.FloatField())
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
return connection.ops.random_function_sql(), []
|
||||||
|
|
||||||
|
|
||||||
|
class ColIndexRef(ExpressionNode):
|
||||||
|
def __init__(self, idx):
|
||||||
|
self.idx = idx
|
||||||
|
super(ColIndexRef, self).__init__()
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
return str(self.idx), []
|
||||||
|
|
||||||
|
|
||||||
class Col(ExpressionNode):
|
class Col(ExpressionNode):
|
||||||
def __init__(self, alias, target, source=None):
|
def __init__(self, alias, target, source=None):
|
||||||
if source is None:
|
if source is None:
|
||||||
|
@ -516,6 +552,9 @@ class Col(ExpressionNode):
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self):
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
|
def get_db_converters(self, connection):
|
||||||
|
return self.output_field.get_db_converters(connection)
|
||||||
|
|
||||||
|
|
||||||
class Ref(ExpressionNode):
|
class Ref(ExpressionNode):
|
||||||
"""
|
"""
|
||||||
|
@ -537,7 +576,7 @@ class Ref(ExpressionNode):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
return "%s" % compiler.quote_name_unless_alias(self.refs), []
|
return "%s" % connection.ops.quote_name(self.refs), []
|
||||||
|
|
||||||
def get_group_by_cols(self):
|
def get_group_by_cols(self):
|
||||||
return [self]
|
return [self]
|
||||||
|
@ -581,7 +620,7 @@ class Date(ExpressionNode):
|
||||||
copy.lookup_type = self.lookup_type
|
copy.lookup_type = self.lookup_type
|
||||||
return copy
|
return copy
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
if isinstance(value, datetime.datetime):
|
if isinstance(value, datetime.datetime):
|
||||||
value = value.date()
|
value = value.date()
|
||||||
return value
|
return value
|
||||||
|
@ -629,7 +668,7 @@ class DateTime(ExpressionNode):
|
||||||
copy.tzname = self.tzname
|
copy.tzname = self.tzname
|
||||||
return copy
|
return copy
|
||||||
|
|
||||||
def convert_value(self, value, connection):
|
def convert_value(self, value, connection, context):
|
||||||
if settings.USE_TZ:
|
if settings.USE_TZ:
|
||||||
if value is None:
|
if value is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -333,6 +333,28 @@ class Field(RegisterLookupMixin):
|
||||||
]
|
]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def get_col(self, alias, source=None):
|
||||||
|
if source is None:
|
||||||
|
source = self
|
||||||
|
if alias != self.model._meta.db_table or source != self:
|
||||||
|
from django.db.models.expressions import Col
|
||||||
|
return Col(alias, self, source)
|
||||||
|
else:
|
||||||
|
return self.cached_col
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def cached_col(self):
|
||||||
|
from django.db.models.expressions import Col
|
||||||
|
return Col(self.model._meta.db_table, self)
|
||||||
|
|
||||||
|
def select_format(self, compiler, sql, params):
|
||||||
|
"""
|
||||||
|
Custom format for select clauses. For example, GIS columns need to be
|
||||||
|
selected as AsText(table.col) on MySQL as the table.col data can't be used
|
||||||
|
by Django.
|
||||||
|
"""
|
||||||
|
return sql, params
|
||||||
|
|
||||||
def deconstruct(self):
|
def deconstruct(self):
|
||||||
"""
|
"""
|
||||||
Returns enough information to recreate the field as a 4-tuple:
|
Returns enough information to recreate the field as a 4-tuple:
|
||||||
|
|
|
@ -15,7 +15,6 @@ from django.db.models.fields import (AutoField, Field, IntegerField,
|
||||||
from django.db.models.lookups import IsNull
|
from django.db.models.lookups import IsNull
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
from django.db.models.query_utils import PathInfo
|
from django.db.models.query_utils import PathInfo
|
||||||
from django.db.models.expressions import Col
|
|
||||||
from django.utils.encoding import force_text, smart_text
|
from django.utils.encoding import force_text, smart_text
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
from django.utils.deprecation import RemovedInDjango20Warning
|
from django.utils.deprecation import RemovedInDjango20Warning
|
||||||
|
@ -1738,26 +1737,26 @@ class ForeignObject(RelatedField):
|
||||||
[source.name for source in sources], raw_value),
|
[source.name for source in sources], raw_value),
|
||||||
AND)
|
AND)
|
||||||
elif lookup_type == 'isnull':
|
elif lookup_type == 'isnull':
|
||||||
root_constraint.add(IsNull(Col(alias, targets[0], sources[0]), raw_value), AND)
|
root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND)
|
||||||
elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
|
elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
|
||||||
and not is_multicolumn)):
|
and not is_multicolumn)):
|
||||||
value = get_normalized_value(raw_value)
|
value = get_normalized_value(raw_value)
|
||||||
for target, source, val in zip(targets, sources, value):
|
for target, source, val in zip(targets, sources, value):
|
||||||
lookup_class = target.get_lookup(lookup_type)
|
lookup_class = target.get_lookup(lookup_type)
|
||||||
root_constraint.add(
|
root_constraint.add(
|
||||||
lookup_class(Col(alias, target, source), val), AND)
|
lookup_class(target.get_col(alias, source), val), AND)
|
||||||
elif lookup_type in ['range', 'in'] and not is_multicolumn:
|
elif lookup_type in ['range', 'in'] and not is_multicolumn:
|
||||||
values = [get_normalized_value(value) for value in raw_value]
|
values = [get_normalized_value(value) for value in raw_value]
|
||||||
value = [val[0] for val in values]
|
value = [val[0] for val in values]
|
||||||
lookup_class = targets[0].get_lookup(lookup_type)
|
lookup_class = targets[0].get_lookup(lookup_type)
|
||||||
root_constraint.add(lookup_class(Col(alias, targets[0], sources[0]), value), AND)
|
root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND)
|
||||||
elif lookup_type == 'in':
|
elif lookup_type == 'in':
|
||||||
values = [get_normalized_value(value) for value in raw_value]
|
values = [get_normalized_value(value) for value in raw_value]
|
||||||
for value in values:
|
for value in values:
|
||||||
value_constraint = constraint_class()
|
value_constraint = constraint_class()
|
||||||
for source, target, val in zip(sources, targets, value):
|
for source, target, val in zip(sources, targets, value):
|
||||||
lookup_class = target.get_lookup('exact')
|
lookup_class = target.get_lookup('exact')
|
||||||
lookup = lookup_class(Col(alias, target, source), val)
|
lookup = lookup_class(target.get_col(alias, source), val)
|
||||||
value_constraint.add(lookup, AND)
|
value_constraint.add(lookup, AND)
|
||||||
root_constraint.add(value_constraint, OR)
|
root_constraint.add(value_constraint, OR)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -13,8 +13,7 @@ from django.db import (connections, router, transaction, IntegrityError,
|
||||||
DJANGO_VERSION_PICKLE_KEY)
|
DJANGO_VERSION_PICKLE_KEY)
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.db.models.fields import AutoField, Empty
|
from django.db.models.fields import AutoField, Empty
|
||||||
from django.db.models.query_utils import (Q, select_related_descend,
|
from django.db.models.query_utils import Q, deferred_class_factory, InvalidQuery
|
||||||
deferred_class_factory, InvalidQuery)
|
|
||||||
from django.db.models.deletion import Collector
|
from django.db.models.deletion import Collector
|
||||||
from django.db.models.sql.constants import CURSOR
|
from django.db.models.sql.constants import CURSOR
|
||||||
from django.db.models import sql
|
from django.db.models import sql
|
||||||
|
@ -233,76 +232,34 @@ class QuerySet(object):
|
||||||
An iterator over the results from applying this QuerySet to the
|
An iterator over the results from applying this QuerySet to the
|
||||||
database.
|
database.
|
||||||
"""
|
"""
|
||||||
fill_cache = False
|
|
||||||
if connections[self.db].features.supports_select_related:
|
|
||||||
fill_cache = self.query.select_related
|
|
||||||
if isinstance(fill_cache, dict):
|
|
||||||
requested = fill_cache
|
|
||||||
else:
|
|
||||||
requested = None
|
|
||||||
max_depth = self.query.max_depth
|
|
||||||
|
|
||||||
extra_select = list(self.query.extra_select)
|
|
||||||
annotation_select = list(self.query.annotation_select)
|
|
||||||
|
|
||||||
only_load = self.query.get_loaded_field_names()
|
|
||||||
fields = self.model._meta.concrete_fields
|
|
||||||
|
|
||||||
load_fields = []
|
|
||||||
# If only/defer clauses have been specified,
|
|
||||||
# build the list of fields that are to be loaded.
|
|
||||||
if only_load:
|
|
||||||
for field in self.model._meta.concrete_fields:
|
|
||||||
model = field.model._meta.model
|
|
||||||
try:
|
|
||||||
if field.name in only_load[model]:
|
|
||||||
# Add a field that has been explicitly included
|
|
||||||
load_fields.append(field.name)
|
|
||||||
except KeyError:
|
|
||||||
# Model wasn't explicitly listed in the only_load table
|
|
||||||
# Therefore, we need to load all fields from this model
|
|
||||||
load_fields.append(field.name)
|
|
||||||
|
|
||||||
skip = None
|
|
||||||
if load_fields:
|
|
||||||
# Some fields have been deferred, so we have to initialize
|
|
||||||
# via keyword arguments.
|
|
||||||
skip = set()
|
|
||||||
init_list = []
|
|
||||||
for field in fields:
|
|
||||||
if field.name not in load_fields:
|
|
||||||
skip.add(field.attname)
|
|
||||||
else:
|
|
||||||
init_list.append(field.attname)
|
|
||||||
model_cls = deferred_class_factory(self.model, skip)
|
|
||||||
else:
|
|
||||||
model_cls = self.model
|
|
||||||
init_list = [f.attname for f in fields]
|
|
||||||
|
|
||||||
# Cache db and model outside the loop
|
|
||||||
db = self.db
|
db = self.db
|
||||||
compiler = self.query.get_compiler(using=db)
|
compiler = self.query.get_compiler(using=db)
|
||||||
index_start = len(extra_select)
|
# Execute the query. This will also fill compiler.select, klass_info,
|
||||||
annotation_start = index_start + len(init_list)
|
# and annotations.
|
||||||
|
results = compiler.execute_sql()
|
||||||
if fill_cache:
|
select, klass_info, annotation_col_map = (compiler.select, compiler.klass_info,
|
||||||
klass_info = get_klass_info(model_cls, max_depth=max_depth,
|
compiler.annotation_col_map)
|
||||||
requested=requested, only_load=only_load)
|
if klass_info is None:
|
||||||
for row in compiler.results_iter():
|
return
|
||||||
if fill_cache:
|
model_cls = klass_info['model']
|
||||||
obj, _ = get_cached_row(row, index_start, db, klass_info,
|
select_fields = klass_info['select_fields']
|
||||||
offset=len(annotation_select))
|
model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
|
||||||
else:
|
init_list = [f[0].output_field.attname
|
||||||
obj = model_cls.from_db(db, init_list, row[index_start:annotation_start])
|
for f in select[model_fields_start:model_fields_end]]
|
||||||
|
if len(init_list) != len(model_cls._meta.concrete_fields):
|
||||||
if extra_select:
|
init_set = set(init_list)
|
||||||
for i, k in enumerate(extra_select):
|
skip = [f.attname for f in model_cls._meta.concrete_fields
|
||||||
setattr(obj, k, row[i])
|
if f.attname not in init_set]
|
||||||
|
model_cls = deferred_class_factory(model_cls, skip)
|
||||||
# Add the annotations to the model
|
related_populators = get_related_populators(klass_info, select, db)
|
||||||
if annotation_select:
|
for row in compiler.results_iter(results):
|
||||||
for i, annotation in enumerate(annotation_select):
|
obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end])
|
||||||
setattr(obj, annotation, row[i + annotation_start])
|
if related_populators:
|
||||||
|
for rel_populator in related_populators:
|
||||||
|
rel_populator.populate(row, obj)
|
||||||
|
if annotation_col_map:
|
||||||
|
for attr_name, col_pos in annotation_col_map.items():
|
||||||
|
setattr(obj, attr_name, row[col_pos])
|
||||||
|
|
||||||
# Add the known related objects to the model, if there are any
|
# Add the known related objects to the model, if there are any
|
||||||
if self._known_related_objects:
|
if self._known_related_objects:
|
||||||
|
@ -1032,11 +989,8 @@ class QuerySet(object):
|
||||||
"""
|
"""
|
||||||
Prepare the query for computing a result that contains aggregate annotations.
|
Prepare the query for computing a result that contains aggregate annotations.
|
||||||
"""
|
"""
|
||||||
opts = self.model._meta
|
|
||||||
if self.query.group_by is None:
|
if self.query.group_by is None:
|
||||||
field_names = [f.attname for f in opts.concrete_fields]
|
self.query.group_by = True
|
||||||
self.query.add_fields(field_names, False)
|
|
||||||
self.query.set_group_by()
|
|
||||||
|
|
||||||
def _prepare(self):
|
def _prepare(self):
|
||||||
return self
|
return self
|
||||||
|
@ -1135,9 +1089,11 @@ class ValuesQuerySet(QuerySet):
|
||||||
Called by the _clone() method after initializing the rest of the
|
Called by the _clone() method after initializing the rest of the
|
||||||
instance.
|
instance.
|
||||||
"""
|
"""
|
||||||
|
if self.query.group_by is True:
|
||||||
|
self.query.add_fields([f.attname for f in self.model._meta.concrete_fields], False)
|
||||||
|
self.query.set_group_by()
|
||||||
self.query.clear_deferred_loading()
|
self.query.clear_deferred_loading()
|
||||||
self.query.clear_select_fields()
|
self.query.clear_select_fields()
|
||||||
|
|
||||||
if self._fields:
|
if self._fields:
|
||||||
self.extra_names = []
|
self.extra_names = []
|
||||||
self.annotation_names = []
|
self.annotation_names = []
|
||||||
|
@ -1246,11 +1202,12 @@ class ValuesQuerySet(QuerySet):
|
||||||
|
|
||||||
class ValuesListQuerySet(ValuesQuerySet):
|
class ValuesListQuerySet(ValuesQuerySet):
|
||||||
def iterator(self):
|
def iterator(self):
|
||||||
|
compiler = self.query.get_compiler(self.db)
|
||||||
if self.flat and len(self._fields) == 1:
|
if self.flat and len(self._fields) == 1:
|
||||||
for row in self.query.get_compiler(self.db).results_iter():
|
for row in compiler.results_iter():
|
||||||
yield row[0]
|
yield row[0]
|
||||||
elif not self.query.extra_select and not self.query.annotation_select:
|
elif not self.query.extra_select and not self.query.annotation_select:
|
||||||
for row in self.query.get_compiler(self.db).results_iter():
|
for row in compiler.results_iter():
|
||||||
yield tuple(row)
|
yield tuple(row)
|
||||||
else:
|
else:
|
||||||
# When extra(select=...) or an annotation is involved, the extra
|
# When extra(select=...) or an annotation is involved, the extra
|
||||||
|
@ -1269,7 +1226,7 @@ class ValuesListQuerySet(ValuesQuerySet):
|
||||||
else:
|
else:
|
||||||
fields = names
|
fields = names
|
||||||
|
|
||||||
for row in self.query.get_compiler(self.db).results_iter():
|
for row in compiler.results_iter():
|
||||||
data = dict(zip(names, row))
|
data = dict(zip(names, row))
|
||||||
yield tuple(data[f] for f in fields)
|
yield tuple(data[f] for f in fields)
|
||||||
|
|
||||||
|
@ -1281,244 +1238,6 @@ class ValuesListQuerySet(ValuesQuerySet):
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
|
||||||
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
|
||||||
only_load=None, from_parent=None):
|
|
||||||
"""
|
|
||||||
Helper function that recursively returns an information for a klass, to be
|
|
||||||
used in get_cached_row. It exists just to compute this information only
|
|
||||||
once for entire queryset. Otherwise it would be computed for each row, which
|
|
||||||
leads to poor performance on large querysets.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
* klass - the class to retrieve (and instantiate)
|
|
||||||
* max_depth - the maximum depth to which a select_related()
|
|
||||||
relationship should be explored.
|
|
||||||
* cur_depth - the current depth in the select_related() tree.
|
|
||||||
Used in recursive calls to determine if we should dig deeper.
|
|
||||||
* requested - A dictionary describing the select_related() tree
|
|
||||||
that is to be retrieved. keys are field names; values are
|
|
||||||
dictionaries describing the keys on that related object that
|
|
||||||
are themselves to be select_related().
|
|
||||||
* only_load - if the query has had only() or defer() applied,
|
|
||||||
this is the list of field names that will be returned. If None,
|
|
||||||
the full field list for `klass` can be assumed.
|
|
||||||
* from_parent - the parent model used to get to this model
|
|
||||||
|
|
||||||
Note that when travelling from parent to child, we will only load child
|
|
||||||
fields which aren't in the parent.
|
|
||||||
"""
|
|
||||||
if max_depth and requested is None and cur_depth > max_depth:
|
|
||||||
# We've recursed deeply enough; stop now.
|
|
||||||
return None
|
|
||||||
|
|
||||||
if only_load:
|
|
||||||
load_fields = only_load.get(klass) or set()
|
|
||||||
# When we create the object, we will also be creating populating
|
|
||||||
# all the parent classes, so traverse the parent classes looking
|
|
||||||
# for fields that must be included on load.
|
|
||||||
for parent in klass._meta.get_parent_list():
|
|
||||||
fields = only_load.get(parent)
|
|
||||||
if fields:
|
|
||||||
load_fields.update(fields)
|
|
||||||
else:
|
|
||||||
load_fields = None
|
|
||||||
|
|
||||||
if load_fields:
|
|
||||||
# Handle deferred fields.
|
|
||||||
skip = set()
|
|
||||||
init_list = []
|
|
||||||
# Build the list of fields that *haven't* been requested
|
|
||||||
for field in klass._meta.concrete_fields:
|
|
||||||
model = field.model._meta.concrete_model
|
|
||||||
if from_parent and model and issubclass(from_parent, model):
|
|
||||||
# Avoid loading fields already loaded for parent model for
|
|
||||||
# child models.
|
|
||||||
continue
|
|
||||||
elif field.name not in load_fields:
|
|
||||||
skip.add(field.attname)
|
|
||||||
else:
|
|
||||||
init_list.append(field.attname)
|
|
||||||
# Retrieve all the requested fields
|
|
||||||
field_count = len(init_list)
|
|
||||||
if skip:
|
|
||||||
klass = deferred_class_factory(klass, skip)
|
|
||||||
field_names = init_list
|
|
||||||
else:
|
|
||||||
field_names = ()
|
|
||||||
else:
|
|
||||||
# Load all fields on klass
|
|
||||||
|
|
||||||
field_count = len(klass._meta.concrete_fields)
|
|
||||||
# Check if we need to skip some parent fields.
|
|
||||||
if from_parent and len(klass._meta.local_concrete_fields) != len(klass._meta.concrete_fields):
|
|
||||||
# Only load those fields which haven't been already loaded into
|
|
||||||
# 'from_parent'.
|
|
||||||
non_seen_models = [p for p in klass._meta.get_parent_list()
|
|
||||||
if not issubclass(from_parent, p)]
|
|
||||||
# Load local fields, too...
|
|
||||||
non_seen_models.append(klass)
|
|
||||||
field_names = [f.attname for f in klass._meta.concrete_fields
|
|
||||||
if f.model in non_seen_models]
|
|
||||||
field_count = len(field_names)
|
|
||||||
# Try to avoid populating field_names variable for performance reasons.
|
|
||||||
# If field_names variable is set, we use **kwargs based model init
|
|
||||||
# which is slower than normal init.
|
|
||||||
if field_count == len(klass._meta.concrete_fields):
|
|
||||||
field_names = ()
|
|
||||||
|
|
||||||
restricted = requested is not None
|
|
||||||
|
|
||||||
related_fields = []
|
|
||||||
for f in klass._meta.fields:
|
|
||||||
if select_related_descend(f, restricted, requested, load_fields):
|
|
||||||
if restricted:
|
|
||||||
next = requested[f.name]
|
|
||||||
else:
|
|
||||||
next = None
|
|
||||||
klass_info = get_klass_info(f.rel.to, max_depth=max_depth, cur_depth=cur_depth + 1,
|
|
||||||
requested=next, only_load=only_load)
|
|
||||||
related_fields.append((f, klass_info))
|
|
||||||
|
|
||||||
reverse_related_fields = []
|
|
||||||
if restricted:
|
|
||||||
for o in klass._meta.related_objects:
|
|
||||||
if o.field.unique and select_related_descend(o.field, restricted, requested,
|
|
||||||
only_load.get(o.related_model), reverse=True):
|
|
||||||
next = requested[o.field.related_query_name()]
|
|
||||||
parent = klass if issubclass(o.related_model, klass) else None
|
|
||||||
klass_info = get_klass_info(o.related_model, max_depth=max_depth, cur_depth=cur_depth + 1,
|
|
||||||
requested=next, only_load=only_load, from_parent=parent)
|
|
||||||
reverse_related_fields.append((o.field, klass_info))
|
|
||||||
if field_names:
|
|
||||||
pk_idx = field_names.index(klass._meta.pk.attname)
|
|
||||||
else:
|
|
||||||
meta = klass._meta
|
|
||||||
pk_idx = meta.concrete_fields.index(meta.pk)
|
|
||||||
|
|
||||||
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
|
|
||||||
|
|
||||||
|
|
||||||
def reorder_for_init(model, field_names, values):
|
|
||||||
"""
|
|
||||||
Reorders given field names and values for those fields
|
|
||||||
to be in the same order as model.__init__() expects to find them.
|
|
||||||
"""
|
|
||||||
new_names, new_values = [], []
|
|
||||||
for f in model._meta.concrete_fields:
|
|
||||||
if f.attname not in field_names:
|
|
||||||
continue
|
|
||||||
new_names.append(f.attname)
|
|
||||||
new_values.append(values[field_names.index(f.attname)])
|
|
||||||
assert len(new_names) == len(field_names)
|
|
||||||
return new_names, new_values
|
|
||||||
|
|
||||||
|
|
||||||
def get_cached_row(row, index_start, using, klass_info, offset=0,
|
|
||||||
parent_data=()):
|
|
||||||
"""
|
|
||||||
Helper function that recursively returns an object with the specified
|
|
||||||
related attributes already populated.
|
|
||||||
|
|
||||||
This method may be called recursively to populate deep select_related()
|
|
||||||
clauses.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
* row - the row of data returned by the database cursor
|
|
||||||
* index_start - the index of the row at which data for this
|
|
||||||
object is known to start
|
|
||||||
* offset - the number of additional fields that are known to
|
|
||||||
exist in row for `klass`. This usually means the number of
|
|
||||||
annotated results on `klass`.
|
|
||||||
* using - the database alias on which the query is being executed.
|
|
||||||
* klass_info - result of the get_klass_info function
|
|
||||||
* parent_data - parent model data in format (field, value). Used
|
|
||||||
to populate the non-local fields of child models.
|
|
||||||
"""
|
|
||||||
if klass_info is None:
|
|
||||||
return None
|
|
||||||
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
|
|
||||||
|
|
||||||
fields = row[index_start:index_start + field_count]
|
|
||||||
# If the pk column is None (or the equivalent '' in the case the
|
|
||||||
# connection interprets empty strings as nulls), then the related
|
|
||||||
# object must be non-existent - set the relation to None.
|
|
||||||
if (fields[pk_idx] is None or
|
|
||||||
(connections[using].features.interprets_empty_strings_as_nulls and
|
|
||||||
fields[pk_idx] == '')):
|
|
||||||
obj = None
|
|
||||||
elif field_names:
|
|
||||||
values = list(fields)
|
|
||||||
parent_values = []
|
|
||||||
parent_field_names = []
|
|
||||||
for rel_field, value in parent_data:
|
|
||||||
parent_field_names.append(rel_field.attname)
|
|
||||||
parent_values.append(value)
|
|
||||||
field_names, values = reorder_for_init(
|
|
||||||
klass, parent_field_names + field_names,
|
|
||||||
parent_values + values)
|
|
||||||
obj = klass.from_db(using, field_names, values)
|
|
||||||
else:
|
|
||||||
field_names = [f.attname for f in klass._meta.concrete_fields]
|
|
||||||
obj = klass.from_db(using, field_names, fields)
|
|
||||||
# Instantiate related fields
|
|
||||||
index_end = index_start + field_count + offset
|
|
||||||
# Iterate over each related object, populating any
|
|
||||||
# select_related() fields
|
|
||||||
for f, klass_info in related_fields:
|
|
||||||
# Recursively retrieve the data for the related object
|
|
||||||
cached_row = get_cached_row(row, index_end, using, klass_info)
|
|
||||||
# If the recursive descent found an object, populate the
|
|
||||||
# descriptor caches relevant to the object
|
|
||||||
if cached_row:
|
|
||||||
rel_obj, index_end = cached_row
|
|
||||||
if obj is not None:
|
|
||||||
# If the base object exists, populate the
|
|
||||||
# descriptor cache
|
|
||||||
setattr(obj, f.get_cache_name(), rel_obj)
|
|
||||||
if f.unique and rel_obj is not None:
|
|
||||||
# If the field is unique, populate the
|
|
||||||
# reverse descriptor cache on the related object
|
|
||||||
setattr(rel_obj, f.rel.get_cache_name(), obj)
|
|
||||||
|
|
||||||
# Now do the same, but for reverse related objects.
|
|
||||||
# Only handle the restricted case - i.e., don't do a depth
|
|
||||||
# descent into reverse relations unless explicitly requested
|
|
||||||
for f, klass_info in reverse_related_fields:
|
|
||||||
# Transfer data from this object to childs.
|
|
||||||
parent_data = []
|
|
||||||
for rel_field in klass_info[0]._meta.fields:
|
|
||||||
rel_model = rel_field.model._meta.concrete_model
|
|
||||||
if rel_model == klass_info[0]._meta.model:
|
|
||||||
rel_model = None
|
|
||||||
if rel_model is not None and isinstance(obj, rel_model):
|
|
||||||
parent_data.append((rel_field, getattr(obj, rel_field.attname)))
|
|
||||||
# Recursively retrieve the data for the related object
|
|
||||||
cached_row = get_cached_row(row, index_end, using, klass_info,
|
|
||||||
parent_data=parent_data)
|
|
||||||
# If the recursive descent found an object, populate the
|
|
||||||
# descriptor caches relevant to the object
|
|
||||||
if cached_row:
|
|
||||||
rel_obj, index_end = cached_row
|
|
||||||
if obj is not None:
|
|
||||||
# populate the reverse descriptor cache
|
|
||||||
setattr(obj, f.rel.get_cache_name(), rel_obj)
|
|
||||||
if rel_obj is not None:
|
|
||||||
# If the related object exists, populate
|
|
||||||
# the descriptor cache.
|
|
||||||
setattr(rel_obj, f.get_cache_name(), obj)
|
|
||||||
# Populate related object caches using parent data.
|
|
||||||
for rel_field, _ in parent_data:
|
|
||||||
if rel_field.rel:
|
|
||||||
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
|
|
||||||
try:
|
|
||||||
cached_obj = getattr(obj, rel_field.get_cache_name())
|
|
||||||
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
|
|
||||||
except AttributeError:
|
|
||||||
# Related object hasn't been cached yet
|
|
||||||
pass
|
|
||||||
return obj, index_end
|
|
||||||
|
|
||||||
|
|
||||||
class RawQuerySet(object):
|
class RawQuerySet(object):
|
||||||
"""
|
"""
|
||||||
Provides an iterator which converts the results of raw SQL queries into
|
Provides an iterator which converts the results of raw SQL queries into
|
||||||
|
@ -1569,7 +1288,9 @@ class RawQuerySet(object):
|
||||||
else:
|
else:
|
||||||
model_cls = self.model
|
model_cls = self.model
|
||||||
fields = [self.model_fields.get(c, None) for c in self.columns]
|
fields = [self.model_fields.get(c, None) for c in self.columns]
|
||||||
converters = compiler.get_converters(fields)
|
converters = compiler.get_converters([
|
||||||
|
f.get_col(f.model._meta.db_table) if f else None for f in fields
|
||||||
|
])
|
||||||
for values in query:
|
for values in query:
|
||||||
if converters:
|
if converters:
|
||||||
values = compiler.apply_converters(values, converters)
|
values = compiler.apply_converters(values, converters)
|
||||||
|
@ -1920,3 +1641,120 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
|
||||||
qs._prefetch_done = True
|
qs._prefetch_done = True
|
||||||
obj._prefetched_objects_cache[cache_name] = qs
|
obj._prefetched_objects_cache[cache_name] = qs
|
||||||
return all_related_objects, additional_lookups
|
return all_related_objects, additional_lookups
|
||||||
|
|
||||||
|
|
||||||
|
class RelatedPopulator(object):
|
||||||
|
"""
|
||||||
|
RelatedPopulator is used for select_related() object instantiation.
|
||||||
|
|
||||||
|
The idea is that each select_related() model will be populated by a
|
||||||
|
different RelatedPopulator instance. The RelatedPopulator instances get
|
||||||
|
klass_info and select (computed in SQLCompiler) plus the used db as
|
||||||
|
input for initialization. That data is used to compute which columns
|
||||||
|
to use, how to instantiate the model, and how to populate the links
|
||||||
|
between the objects.
|
||||||
|
|
||||||
|
The actual creation of the objects is done in populate() method. This
|
||||||
|
method gets row and from_obj as input and populates the select_related()
|
||||||
|
model instance.
|
||||||
|
"""
|
||||||
|
def __init__(self, klass_info, select, db):
|
||||||
|
self.db = db
|
||||||
|
# Pre-compute needed attributes. The attributes are:
|
||||||
|
# - model_cls: the possibly deferred model class to instantiate
|
||||||
|
# - either:
|
||||||
|
# - cols_start, cols_end: usually the columns in the row are
|
||||||
|
# in the same order model_cls.__init__ expects them, so we
|
||||||
|
# can instantiate by model_cls(*row[cols_start:cols_end])
|
||||||
|
# - reorder_for_init: When select_related descends to a child
|
||||||
|
# class, then we want to reuse the already selected parent
|
||||||
|
# data. However, in this case the parent data isn't necessarily
|
||||||
|
# in the same order that Model.__init__ expects it to be, so
|
||||||
|
# we have to reorder the parent data. The reorder_for_init
|
||||||
|
# attribute contains a function used to reorder the field data
|
||||||
|
# in the order __init__ expects it.
|
||||||
|
# - pk_idx: the index of the primary key field in the reordered
|
||||||
|
# model data. Used to check if a related object exists at all.
|
||||||
|
# - init_list: the field attnames fetched from the database. For
|
||||||
|
# deferred models this isn't the same as all attnames of the
|
||||||
|
# model's fields.
|
||||||
|
# - related_populators: a list of RelatedPopulator instances if
|
||||||
|
# select_related() descends to related models from this model.
|
||||||
|
# - cache_name, reverse_cache_name: the names to use for setattr
|
||||||
|
# when assigning the fetched object to the from_obj. If the
|
||||||
|
# reverse_cache_name is set, then we also set the reverse link.
|
||||||
|
select_fields = klass_info['select_fields']
|
||||||
|
from_parent = klass_info['from_parent']
|
||||||
|
if not from_parent:
|
||||||
|
self.cols_start = select_fields[0]
|
||||||
|
self.cols_end = select_fields[-1] + 1
|
||||||
|
self.init_list = [
|
||||||
|
f[0].output_field.attname for f in select[self.cols_start:self.cols_end]
|
||||||
|
]
|
||||||
|
self.reorder_for_init = None
|
||||||
|
else:
|
||||||
|
model_init_attnames = [
|
||||||
|
f.attname for f in klass_info['model']._meta.concrete_fields
|
||||||
|
]
|
||||||
|
reorder_map = []
|
||||||
|
for idx in select_fields:
|
||||||
|
field = select[idx][0].output_field
|
||||||
|
init_pos = model_init_attnames.index(field.attname)
|
||||||
|
reorder_map.append((init_pos, field.attname, idx))
|
||||||
|
reorder_map.sort()
|
||||||
|
self.init_list = [v[1] for v in reorder_map]
|
||||||
|
pos_list = [row_pos for _, _, row_pos in reorder_map]
|
||||||
|
|
||||||
|
def reorder_for_init(row):
|
||||||
|
return [row[row_pos] for row_pos in pos_list]
|
||||||
|
self.reorder_for_init = reorder_for_init
|
||||||
|
|
||||||
|
self.model_cls = self.get_deferred_cls(klass_info, self.init_list)
|
||||||
|
self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
|
||||||
|
self.related_populators = get_related_populators(klass_info, select, self.db)
|
||||||
|
field = klass_info['field']
|
||||||
|
reverse = klass_info['reverse']
|
||||||
|
self.reverse_cache_name = None
|
||||||
|
if reverse:
|
||||||
|
self.cache_name = field.rel.get_cache_name()
|
||||||
|
self.reverse_cache_name = field.get_cache_name()
|
||||||
|
else:
|
||||||
|
self.cache_name = field.get_cache_name()
|
||||||
|
if field.unique:
|
||||||
|
self.reverse_cache_name = field.rel.get_cache_name()
|
||||||
|
|
||||||
|
def get_deferred_cls(self, klass_info, init_list):
|
||||||
|
model_cls = klass_info['model']
|
||||||
|
if len(init_list) != len(model_cls._meta.concrete_fields):
|
||||||
|
init_set = set(init_list)
|
||||||
|
skip = [
|
||||||
|
f.attname for f in model_cls._meta.concrete_fields
|
||||||
|
if f.attname not in init_set
|
||||||
|
]
|
||||||
|
model_cls = deferred_class_factory(model_cls, skip)
|
||||||
|
return model_cls
|
||||||
|
|
||||||
|
def populate(self, row, from_obj):
|
||||||
|
if self.reorder_for_init:
|
||||||
|
obj_data = self.reorder_for_init(row)
|
||||||
|
else:
|
||||||
|
obj_data = row[self.cols_start:self.cols_end]
|
||||||
|
if obj_data[self.pk_idx] is None:
|
||||||
|
obj = None
|
||||||
|
else:
|
||||||
|
obj = self.model_cls.from_db(self.db, self.init_list, obj_data)
|
||||||
|
if obj and self.related_populators:
|
||||||
|
for rel_iter in self.related_populators:
|
||||||
|
rel_iter.populate(row, obj)
|
||||||
|
setattr(from_obj, self.cache_name, obj)
|
||||||
|
if obj and self.reverse_cache_name:
|
||||||
|
setattr(obj, self.reverse_cache_name, from_obj)
|
||||||
|
|
||||||
|
|
||||||
|
def get_related_populators(klass_info, select, db):
|
||||||
|
iterators = []
|
||||||
|
related_klass_infos = klass_info.get('related_klass_infos', [])
|
||||||
|
for rel_klass_info in related_klass_infos:
|
||||||
|
rel_cls = RelatedPopulator(rel_klass_info, select, db)
|
||||||
|
iterators.append(rel_cls)
|
||||||
|
return iterators
|
||||||
|
|
|
@ -170,7 +170,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
|
||||||
if not restricted and field.null:
|
if not restricted and field.null:
|
||||||
return False
|
return False
|
||||||
if load_fields:
|
if load_fields:
|
||||||
if field.name not in load_fields:
|
if field.attname not in load_fields:
|
||||||
if restricted and field.name in requested:
|
if restricted and field.name in requested:
|
||||||
raise InvalidQuery("Field %s.%s cannot be both deferred"
|
raise InvalidQuery("Field %s.%s cannot be both deferred"
|
||||||
" and traversed using select_related"
|
" and traversed using select_related"
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -2,7 +2,6 @@
|
||||||
Constants specific to the SQL storage portion of the ORM.
|
Constants specific to the SQL storage portion of the ORM.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# Valid query types (a set is used for speedy lookups). These are (currently)
|
# Valid query types (a set is used for speedy lookups). These are (currently)
|
||||||
|
@ -21,9 +20,6 @@ GET_ITERATOR_CHUNK_SIZE = 100
|
||||||
|
|
||||||
# Namedtuples for sql.* internal use.
|
# Namedtuples for sql.* internal use.
|
||||||
|
|
||||||
# Pairs of column clauses to select, and (possibly None) field for the clause.
|
|
||||||
SelectInfo = namedtuple('SelectInfo', 'col field')
|
|
||||||
|
|
||||||
# How many results to expect from a cursor.execute call
|
# How many results to expect from a cursor.execute call
|
||||||
MULTI = 'multi'
|
MULTI = 'multi'
|
||||||
SINGLE = 'single'
|
SINGLE = 'single'
|
||||||
|
|
|
@ -21,7 +21,7 @@ from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.db.models.expressions import Col, Ref
|
from django.db.models.expressions import Col, Ref
|
||||||
from django.db.models.query_utils import PathInfo, Q, refs_aggregate
|
from django.db.models.query_utils import PathInfo, Q, refs_aggregate
|
||||||
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
|
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
|
||||||
ORDER_PATTERN, SelectInfo, INNER, LOUTER)
|
ORDER_PATTERN, INNER, LOUTER)
|
||||||
from django.db.models.sql.datastructures import (
|
from django.db.models.sql.datastructures import (
|
||||||
EmptyResultSet, Empty, MultiJoin, Join, BaseTable)
|
EmptyResultSet, Empty, MultiJoin, Join, BaseTable)
|
||||||
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
|
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
|
||||||
|
@ -46,7 +46,7 @@ class RawQuery(object):
|
||||||
A single raw SQL query
|
A single raw SQL query
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sql, using, params=None):
|
def __init__(self, sql, using, params=None, context=None):
|
||||||
self.params = params or ()
|
self.params = params or ()
|
||||||
self.sql = sql
|
self.sql = sql
|
||||||
self.using = using
|
self.using = using
|
||||||
|
@ -57,9 +57,10 @@ class RawQuery(object):
|
||||||
self.low_mark, self.high_mark = 0, None # Used for offset/limit
|
self.low_mark, self.high_mark = 0, None # Used for offset/limit
|
||||||
self.extra_select = {}
|
self.extra_select = {}
|
||||||
self.annotation_select = {}
|
self.annotation_select = {}
|
||||||
|
self.context = context or {}
|
||||||
|
|
||||||
def clone(self, using):
|
def clone(self, using):
|
||||||
return RawQuery(self.sql, using, params=self.params)
|
return RawQuery(self.sql, using, params=self.params, context=self.context.copy())
|
||||||
|
|
||||||
def get_columns(self):
|
def get_columns(self):
|
||||||
if self.cursor is None:
|
if self.cursor is None:
|
||||||
|
@ -122,20 +123,23 @@ class Query(object):
|
||||||
self.standard_ordering = True
|
self.standard_ordering = True
|
||||||
self.used_aliases = set()
|
self.used_aliases = set()
|
||||||
self.filter_is_sticky = False
|
self.filter_is_sticky = False
|
||||||
self.included_inherited_models = {}
|
|
||||||
|
|
||||||
# SQL-related attributes
|
# SQL-related attributes
|
||||||
# Select and related select clauses as SelectInfo instances.
|
# Select and related select clauses are expressions to use in the
|
||||||
|
# SELECT clause of the query.
|
||||||
# The select is used for cases where we want to set up the select
|
# The select is used for cases where we want to set up the select
|
||||||
# clause to contain other than default fields (values(), annotate(),
|
# clause to contain other than default fields (values(), subqueries...)
|
||||||
# subqueries...)
|
# Note that annotations go to annotations dictionary.
|
||||||
self.select = []
|
self.select = []
|
||||||
# The related_select_cols is used for columns needed for
|
|
||||||
# select_related - this is populated in the compile stage.
|
|
||||||
self.related_select_cols = []
|
|
||||||
self.tables = [] # Aliases in the order they are created.
|
self.tables = [] # Aliases in the order they are created.
|
||||||
self.where = where()
|
self.where = where()
|
||||||
self.where_class = where
|
self.where_class = where
|
||||||
|
# The group_by attribute can have one of the following forms:
|
||||||
|
# - None: no group by at all in the query
|
||||||
|
# - A list of expressions: group by (at least) those expressions.
|
||||||
|
# String refs are also allowed for now.
|
||||||
|
# - True: group by all select fields of the model
|
||||||
|
# See compiler.get_group_by() for details.
|
||||||
self.group_by = None
|
self.group_by = None
|
||||||
self.having = where()
|
self.having = where()
|
||||||
self.order_by = []
|
self.order_by = []
|
||||||
|
@ -174,6 +178,8 @@ class Query(object):
|
||||||
# load.
|
# load.
|
||||||
self.deferred_loading = (set(), True)
|
self.deferred_loading = (set(), True)
|
||||||
|
|
||||||
|
self.context = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def extra(self):
|
def extra(self):
|
||||||
if self._extra is None:
|
if self._extra is None:
|
||||||
|
@ -254,14 +260,14 @@ class Query(object):
|
||||||
obj.default_cols = self.default_cols
|
obj.default_cols = self.default_cols
|
||||||
obj.default_ordering = self.default_ordering
|
obj.default_ordering = self.default_ordering
|
||||||
obj.standard_ordering = self.standard_ordering
|
obj.standard_ordering = self.standard_ordering
|
||||||
obj.included_inherited_models = self.included_inherited_models.copy()
|
|
||||||
obj.select = self.select[:]
|
obj.select = self.select[:]
|
||||||
obj.related_select_cols = []
|
|
||||||
obj.tables = self.tables[:]
|
obj.tables = self.tables[:]
|
||||||
obj.where = self.where.clone()
|
obj.where = self.where.clone()
|
||||||
obj.where_class = self.where_class
|
obj.where_class = self.where_class
|
||||||
if self.group_by is None:
|
if self.group_by is None:
|
||||||
obj.group_by = None
|
obj.group_by = None
|
||||||
|
elif self.group_by is True:
|
||||||
|
obj.group_by = True
|
||||||
else:
|
else:
|
||||||
obj.group_by = self.group_by[:]
|
obj.group_by = self.group_by[:]
|
||||||
obj.having = self.having.clone()
|
obj.having = self.having.clone()
|
||||||
|
@ -272,7 +278,6 @@ class Query(object):
|
||||||
obj.select_for_update = self.select_for_update
|
obj.select_for_update = self.select_for_update
|
||||||
obj.select_for_update_nowait = self.select_for_update_nowait
|
obj.select_for_update_nowait = self.select_for_update_nowait
|
||||||
obj.select_related = self.select_related
|
obj.select_related = self.select_related
|
||||||
obj.related_select_cols = []
|
|
||||||
obj._annotations = self._annotations.copy() if self._annotations is not None else None
|
obj._annotations = self._annotations.copy() if self._annotations is not None else None
|
||||||
if self.annotation_select_mask is None:
|
if self.annotation_select_mask is None:
|
||||||
obj.annotation_select_mask = None
|
obj.annotation_select_mask = None
|
||||||
|
@ -310,8 +315,15 @@ class Query(object):
|
||||||
obj.__dict__.update(kwargs)
|
obj.__dict__.update(kwargs)
|
||||||
if hasattr(obj, '_setup_query'):
|
if hasattr(obj, '_setup_query'):
|
||||||
obj._setup_query()
|
obj._setup_query()
|
||||||
|
obj.context = self.context.copy()
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def add_context(self, key, value):
|
||||||
|
self.context[key] = value
|
||||||
|
|
||||||
|
def get_context(self, key, default=None):
|
||||||
|
return self.context.get(key, default)
|
||||||
|
|
||||||
def relabeled_clone(self, change_map):
|
def relabeled_clone(self, change_map):
|
||||||
clone = self.clone()
|
clone = self.clone()
|
||||||
clone.change_aliases(change_map)
|
clone.change_aliases(change_map)
|
||||||
|
@ -375,7 +387,8 @@ class Query(object):
|
||||||
# done in a subquery so that we are aggregating on the limit and/or
|
# done in a subquery so that we are aggregating on the limit and/or
|
||||||
# distinct results instead of applying the distinct and limit after the
|
# distinct results instead of applying the distinct and limit after the
|
||||||
# aggregation.
|
# aggregation.
|
||||||
if (self.group_by or has_limit or has_existing_annotations or self.distinct):
|
if (isinstance(self.group_by, list) or has_limit or has_existing_annotations or
|
||||||
|
self.distinct):
|
||||||
from django.db.models.sql.subqueries import AggregateQuery
|
from django.db.models.sql.subqueries import AggregateQuery
|
||||||
outer_query = AggregateQuery(self.model)
|
outer_query = AggregateQuery(self.model)
|
||||||
inner_query = self.clone()
|
inner_query = self.clone()
|
||||||
|
@ -383,7 +396,6 @@ class Query(object):
|
||||||
inner_query.clear_ordering(True)
|
inner_query.clear_ordering(True)
|
||||||
inner_query.select_for_update = False
|
inner_query.select_for_update = False
|
||||||
inner_query.select_related = False
|
inner_query.select_related = False
|
||||||
inner_query.related_select_cols = []
|
|
||||||
|
|
||||||
relabels = {t: 'subquery' for t in inner_query.tables}
|
relabels = {t: 'subquery' for t in inner_query.tables}
|
||||||
relabels[None] = 'subquery'
|
relabels[None] = 'subquery'
|
||||||
|
@ -407,26 +419,17 @@ class Query(object):
|
||||||
self.select = []
|
self.select = []
|
||||||
self.default_cols = False
|
self.default_cols = False
|
||||||
self._extra = {}
|
self._extra = {}
|
||||||
self.remove_inherited_models()
|
|
||||||
|
|
||||||
outer_query.clear_ordering(True)
|
outer_query.clear_ordering(True)
|
||||||
outer_query.clear_limits()
|
outer_query.clear_limits()
|
||||||
outer_query.select_for_update = False
|
outer_query.select_for_update = False
|
||||||
outer_query.select_related = False
|
outer_query.select_related = False
|
||||||
outer_query.related_select_cols = []
|
|
||||||
compiler = outer_query.get_compiler(using)
|
compiler = outer_query.get_compiler(using)
|
||||||
result = compiler.execute_sql(SINGLE)
|
result = compiler.execute_sql(SINGLE)
|
||||||
if result is None:
|
if result is None:
|
||||||
result = [None for q in outer_query.annotation_select.items()]
|
result = [None for q in outer_query.annotation_select.items()]
|
||||||
|
|
||||||
fields = [annotation.output_field
|
converters = compiler.get_converters(outer_query.annotation_select.values())
|
||||||
for alias, annotation in outer_query.annotation_select.items()]
|
|
||||||
converters = compiler.get_converters(fields)
|
|
||||||
for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()):
|
|
||||||
if position in converters:
|
|
||||||
converters[position][1].insert(0, annotation.convert_value)
|
|
||||||
else:
|
|
||||||
converters[position] = ([], [annotation.convert_value], annotation.output_field)
|
|
||||||
result = compiler.apply_converters(result, converters)
|
result = compiler.apply_converters(result, converters)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -476,7 +479,6 @@ class Query(object):
|
||||||
assert self.distinct_fields == rhs.distinct_fields, \
|
assert self.distinct_fields == rhs.distinct_fields, \
|
||||||
"Cannot combine queries with different distinct fields."
|
"Cannot combine queries with different distinct fields."
|
||||||
|
|
||||||
self.remove_inherited_models()
|
|
||||||
# Work out how to relabel the rhs aliases, if necessary.
|
# Work out how to relabel the rhs aliases, if necessary.
|
||||||
change_map = {}
|
change_map = {}
|
||||||
conjunction = (connector == AND)
|
conjunction = (connector == AND)
|
||||||
|
@ -545,13 +547,8 @@ class Query(object):
|
||||||
|
|
||||||
# Selection columns and extra extensions are those provided by 'rhs'.
|
# Selection columns and extra extensions are those provided by 'rhs'.
|
||||||
self.select = []
|
self.select = []
|
||||||
for col, field in rhs.select:
|
for col in rhs.select:
|
||||||
if isinstance(col, (list, tuple)):
|
self.add_select(col.relabeled_clone(change_map))
|
||||||
new_col = change_map.get(col[0], col[0]), col[1]
|
|
||||||
self.select.append(SelectInfo(new_col, field))
|
|
||||||
else:
|
|
||||||
new_col = col.relabeled_clone(change_map)
|
|
||||||
self.select.append(SelectInfo(new_col, field))
|
|
||||||
|
|
||||||
if connector == OR:
|
if connector == OR:
|
||||||
# It would be nice to be able to handle this, but the queries don't
|
# It would be nice to be able to handle this, but the queries don't
|
||||||
|
@ -661,17 +658,6 @@ class Query(object):
|
||||||
for model, values in six.iteritems(seen):
|
for model, values in six.iteritems(seen):
|
||||||
callback(target, model, values)
|
callback(target, model, values)
|
||||||
|
|
||||||
def deferred_to_columns_cb(self, target, model, fields):
|
|
||||||
"""
|
|
||||||
Callback used by deferred_to_columns(). The "target" parameter should
|
|
||||||
be a set instance.
|
|
||||||
"""
|
|
||||||
table = model._meta.db_table
|
|
||||||
if table not in target:
|
|
||||||
target[table] = set()
|
|
||||||
for field in fields:
|
|
||||||
target[table].add(field.column)
|
|
||||||
|
|
||||||
def table_alias(self, table_name, create=False):
|
def table_alias(self, table_name, create=False):
|
||||||
"""
|
"""
|
||||||
Returns a table alias for the given table_name and whether this is a
|
Returns a table alias for the given table_name and whether this is a
|
||||||
|
@ -788,10 +774,9 @@ class Query(object):
|
||||||
# "group by", "where" and "having".
|
# "group by", "where" and "having".
|
||||||
self.where.relabel_aliases(change_map)
|
self.where.relabel_aliases(change_map)
|
||||||
self.having.relabel_aliases(change_map)
|
self.having.relabel_aliases(change_map)
|
||||||
if self.group_by:
|
if isinstance(self.group_by, list):
|
||||||
self.group_by = [relabel_column(col) for col in self.group_by]
|
self.group_by = [relabel_column(col) for col in self.group_by]
|
||||||
self.select = [SelectInfo(relabel_column(s.col), s.field)
|
self.select = [col.relabeled_clone(change_map) for col in self.select]
|
||||||
for s in self.select]
|
|
||||||
if self._annotations:
|
if self._annotations:
|
||||||
self._annotations = OrderedDict(
|
self._annotations = OrderedDict(
|
||||||
(key, relabel_column(col)) for key, col in self._annotations.items())
|
(key, relabel_column(col)) for key, col in self._annotations.items())
|
||||||
|
@ -815,9 +800,6 @@ class Query(object):
|
||||||
if alias == old_alias:
|
if alias == old_alias:
|
||||||
self.tables[pos] = new_alias
|
self.tables[pos] = new_alias
|
||||||
break
|
break
|
||||||
for key, alias in self.included_inherited_models.items():
|
|
||||||
if alias in change_map:
|
|
||||||
self.included_inherited_models[key] = change_map[alias]
|
|
||||||
self.external_aliases = {change_map.get(alias, alias)
|
self.external_aliases = {change_map.get(alias, alias)
|
||||||
for alias in self.external_aliases}
|
for alias in self.external_aliases}
|
||||||
|
|
||||||
|
@ -930,28 +912,6 @@ class Query(object):
|
||||||
self.alias_map[alias] = join
|
self.alias_map[alias] = join
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
def setup_inherited_models(self):
|
|
||||||
"""
|
|
||||||
If the model that is the basis for this QuerySet inherits other models,
|
|
||||||
we need to ensure that those other models have their tables included in
|
|
||||||
the query.
|
|
||||||
|
|
||||||
We do this as a separate step so that subclasses know which
|
|
||||||
tables are going to be active in the query, without needing to compute
|
|
||||||
all the select columns (this method is called from pre_sql_setup(),
|
|
||||||
whereas column determination is a later part, and side-effect, of
|
|
||||||
as_sql()).
|
|
||||||
"""
|
|
||||||
opts = self.get_meta()
|
|
||||||
root_alias = self.tables[0]
|
|
||||||
seen = {None: root_alias}
|
|
||||||
|
|
||||||
for field in opts.fields:
|
|
||||||
model = field.model._meta.concrete_model
|
|
||||||
if model is not opts.model and model not in seen:
|
|
||||||
self.join_parent_model(opts, model, root_alias, seen)
|
|
||||||
self.included_inherited_models = seen
|
|
||||||
|
|
||||||
def join_parent_model(self, opts, model, alias, seen):
|
def join_parent_model(self, opts, model, alias, seen):
|
||||||
"""
|
"""
|
||||||
Makes sure the given 'model' is joined in the query. If 'model' isn't
|
Makes sure the given 'model' is joined in the query. If 'model' isn't
|
||||||
|
@ -969,7 +929,9 @@ class Query(object):
|
||||||
curr_opts = opts
|
curr_opts = opts
|
||||||
for int_model in chain:
|
for int_model in chain:
|
||||||
if int_model in seen:
|
if int_model in seen:
|
||||||
return seen[int_model]
|
curr_opts = int_model._meta
|
||||||
|
alias = seen[int_model]
|
||||||
|
continue
|
||||||
# Proxy model have elements in base chain
|
# Proxy model have elements in base chain
|
||||||
# with no parents, assign the new options
|
# with no parents, assign the new options
|
||||||
# object and skip to the next base in that
|
# object and skip to the next base in that
|
||||||
|
@ -984,23 +946,13 @@ class Query(object):
|
||||||
alias = seen[int_model] = joins[-1]
|
alias = seen[int_model] = joins[-1]
|
||||||
return alias or seen[None]
|
return alias or seen[None]
|
||||||
|
|
||||||
def remove_inherited_models(self):
|
|
||||||
"""
|
|
||||||
Undoes the effects of setup_inherited_models(). Should be called
|
|
||||||
whenever select columns (self.select) are set explicitly.
|
|
||||||
"""
|
|
||||||
for key, alias in self.included_inherited_models.items():
|
|
||||||
if key:
|
|
||||||
self.unref_alias(alias)
|
|
||||||
self.included_inherited_models = {}
|
|
||||||
|
|
||||||
def add_aggregate(self, aggregate, model, alias, is_summary):
|
def add_aggregate(self, aggregate, model, alias, is_summary):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"add_aggregate() is deprecated. Use add_annotation() instead.",
|
"add_aggregate() is deprecated. Use add_annotation() instead.",
|
||||||
RemovedInDjango20Warning, stacklevel=2)
|
RemovedInDjango20Warning, stacklevel=2)
|
||||||
self.add_annotation(aggregate, alias, is_summary)
|
self.add_annotation(aggregate, alias, is_summary)
|
||||||
|
|
||||||
def add_annotation(self, annotation, alias, is_summary):
|
def add_annotation(self, annotation, alias, is_summary=False):
|
||||||
"""
|
"""
|
||||||
Adds a single annotation expression to the Query
|
Adds a single annotation expression to the Query
|
||||||
"""
|
"""
|
||||||
|
@ -1011,6 +963,7 @@ class Query(object):
|
||||||
|
|
||||||
def prepare_lookup_value(self, value, lookups, can_reuse):
|
def prepare_lookup_value(self, value, lookups, can_reuse):
|
||||||
# Default lookup if none given is exact.
|
# Default lookup if none given is exact.
|
||||||
|
used_joins = []
|
||||||
if len(lookups) == 0:
|
if len(lookups) == 0:
|
||||||
lookups = ['exact']
|
lookups = ['exact']
|
||||||
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
|
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
|
||||||
|
@ -1026,7 +979,9 @@ class Query(object):
|
||||||
RemovedInDjango19Warning, stacklevel=2)
|
RemovedInDjango19Warning, stacklevel=2)
|
||||||
value = value()
|
value = value()
|
||||||
elif hasattr(value, 'resolve_expression'):
|
elif hasattr(value, 'resolve_expression'):
|
||||||
|
pre_joins = self.alias_refcount.copy()
|
||||||
value = value.resolve_expression(self, reuse=can_reuse)
|
value = value.resolve_expression(self, reuse=can_reuse)
|
||||||
|
used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
|
||||||
# Subqueries need to use a different set of aliases than the
|
# Subqueries need to use a different set of aliases than the
|
||||||
# outer query. Call bump_prefix to change aliases of the inner
|
# outer query. Call bump_prefix to change aliases of the inner
|
||||||
# query (the value).
|
# query (the value).
|
||||||
|
@ -1044,7 +999,7 @@ class Query(object):
|
||||||
lookups[-1] == 'exact' and value == ''):
|
lookups[-1] == 'exact' and value == ''):
|
||||||
value = True
|
value = True
|
||||||
lookups[-1] = 'isnull'
|
lookups[-1] = 'isnull'
|
||||||
return value, lookups
|
return value, lookups, used_joins
|
||||||
|
|
||||||
def solve_lookup_type(self, lookup):
|
def solve_lookup_type(self, lookup):
|
||||||
"""
|
"""
|
||||||
|
@ -1173,8 +1128,7 @@ class Query(object):
|
||||||
|
|
||||||
# Work out the lookup type and remove it from the end of 'parts',
|
# Work out the lookup type and remove it from the end of 'parts',
|
||||||
# if necessary.
|
# if necessary.
|
||||||
value, lookups = self.prepare_lookup_value(value, lookups, can_reuse)
|
value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse)
|
||||||
used_joins = getattr(value, '_used_joins', [])
|
|
||||||
|
|
||||||
clause = self.where_class()
|
clause = self.where_class()
|
||||||
if reffed_aggregate:
|
if reffed_aggregate:
|
||||||
|
@ -1223,7 +1177,7 @@ class Query(object):
|
||||||
# handle Expressions as annotations
|
# handle Expressions as annotations
|
||||||
col = targets[0]
|
col = targets[0]
|
||||||
else:
|
else:
|
||||||
col = Col(alias, targets[0], field)
|
col = targets[0].get_col(alias, field)
|
||||||
condition = self.build_lookup(lookups, col, value)
|
condition = self.build_lookup(lookups, col, value)
|
||||||
if not condition:
|
if not condition:
|
||||||
# Backwards compat for custom lookups
|
# Backwards compat for custom lookups
|
||||||
|
@ -1258,7 +1212,7 @@ class Query(object):
|
||||||
# <=>
|
# <=>
|
||||||
# NOT (col IS NOT NULL AND col = someval).
|
# NOT (col IS NOT NULL AND col = someval).
|
||||||
lookup_class = targets[0].get_lookup('isnull')
|
lookup_class = targets[0].get_lookup('isnull')
|
||||||
clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND)
|
clause.add(lookup_class(targets[0].get_col(alias, sources[0]), False), AND)
|
||||||
return clause, used_joins if not require_outer else ()
|
return clause, used_joins if not require_outer else ()
|
||||||
|
|
||||||
def add_filter(self, filter_clause):
|
def add_filter(self, filter_clause):
|
||||||
|
@ -1535,7 +1489,7 @@ class Query(object):
|
||||||
self.unref_alias(joins.pop())
|
self.unref_alias(joins.pop())
|
||||||
return targets, joins[-1], joins
|
return targets, joins[-1], joins
|
||||||
|
|
||||||
def resolve_ref(self, name, allow_joins, reuse, summarize):
|
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False):
|
||||||
if not allow_joins and LOOKUP_SEP in name:
|
if not allow_joins and LOOKUP_SEP in name:
|
||||||
raise FieldError("Joined field references are not permitted in this query")
|
raise FieldError("Joined field references are not permitted in this query")
|
||||||
if name in self.annotations:
|
if name in self.annotations:
|
||||||
|
@ -1558,8 +1512,7 @@ class Query(object):
|
||||||
"isn't supported")
|
"isn't supported")
|
||||||
if reuse is not None:
|
if reuse is not None:
|
||||||
reuse.update(join_list)
|
reuse.update(join_list)
|
||||||
col = Col(join_list[-1], targets[0], sources[0])
|
col = targets[0].get_col(join_list[-1], sources[0])
|
||||||
col._used_joins = join_list
|
|
||||||
return col
|
return col
|
||||||
|
|
||||||
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
|
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
|
||||||
|
@ -1588,26 +1541,28 @@ class Query(object):
|
||||||
# Try to have as simple as possible subquery -> trim leading joins from
|
# Try to have as simple as possible subquery -> trim leading joins from
|
||||||
# the subquery.
|
# the subquery.
|
||||||
trimmed_prefix, contains_louter = query.trim_start(names_with_path)
|
trimmed_prefix, contains_louter = query.trim_start(names_with_path)
|
||||||
query.remove_inherited_models()
|
|
||||||
|
|
||||||
# Add extra check to make sure the selected field will not be null
|
# Add extra check to make sure the selected field will not be null
|
||||||
# since we are adding an IN <subquery> clause. This prevents the
|
# since we are adding an IN <subquery> clause. This prevents the
|
||||||
# database from tripping over IN (...,NULL,...) selects and returning
|
# database from tripping over IN (...,NULL,...) selects and returning
|
||||||
# nothing
|
# nothing
|
||||||
alias, col = query.select[0].col
|
col = query.select[0]
|
||||||
if self.is_nullable(query.select[0].field):
|
select_field = col.field
|
||||||
lookup_class = query.select[0].field.get_lookup('isnull')
|
alias = col.alias
|
||||||
lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False)
|
if self.is_nullable(select_field):
|
||||||
|
lookup_class = select_field.get_lookup('isnull')
|
||||||
|
lookup = lookup_class(select_field.get_col(alias), False)
|
||||||
query.where.add(lookup, AND)
|
query.where.add(lookup, AND)
|
||||||
if alias in can_reuse:
|
if alias in can_reuse:
|
||||||
select_field = query.select[0].field
|
|
||||||
pk = select_field.model._meta.pk
|
pk = select_field.model._meta.pk
|
||||||
# Need to add a restriction so that outer query's filters are in effect for
|
# Need to add a restriction so that outer query's filters are in effect for
|
||||||
# the subquery, too.
|
# the subquery, too.
|
||||||
query.bump_prefix(self)
|
query.bump_prefix(self)
|
||||||
lookup_class = select_field.get_lookup('exact')
|
lookup_class = select_field.get_lookup('exact')
|
||||||
lookup = lookup_class(Col(query.select[0].col[0], pk, pk),
|
# Note that the query.select[0].alias is different from alias
|
||||||
Col(alias, pk, pk))
|
# due to bump_prefix above.
|
||||||
|
lookup = lookup_class(pk.get_col(query.select[0].alias),
|
||||||
|
pk.get_col(alias))
|
||||||
query.where.add(lookup, AND)
|
query.where.add(lookup, AND)
|
||||||
query.external_aliases.add(alias)
|
query.external_aliases.add(alias)
|
||||||
|
|
||||||
|
@ -1687,6 +1642,14 @@ class Query(object):
|
||||||
"""
|
"""
|
||||||
self.select = []
|
self.select = []
|
||||||
|
|
||||||
|
def add_select(self, col):
|
||||||
|
self.default_cols = False
|
||||||
|
self.select.append(col)
|
||||||
|
|
||||||
|
def set_select(self, cols):
|
||||||
|
self.default_cols = False
|
||||||
|
self.select = cols
|
||||||
|
|
||||||
def add_distinct_fields(self, *field_names):
|
def add_distinct_fields(self, *field_names):
|
||||||
"""
|
"""
|
||||||
Adds and resolves the given fields to the query's "distinct on" clause.
|
Adds and resolves the given fields to the query's "distinct on" clause.
|
||||||
|
@ -1710,7 +1673,7 @@ class Query(object):
|
||||||
name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m)
|
name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m)
|
||||||
targets, final_alias, joins = self.trim_joins(targets, joins, path)
|
targets, final_alias, joins = self.trim_joins(targets, joins, path)
|
||||||
for target in targets:
|
for target in targets:
|
||||||
self.select.append(SelectInfo((final_alias, target.column), target))
|
self.add_select(target.get_col(final_alias))
|
||||||
except MultiJoin:
|
except MultiJoin:
|
||||||
raise FieldError("Invalid field name: '%s'" % name)
|
raise FieldError("Invalid field name: '%s'" % name)
|
||||||
except FieldError:
|
except FieldError:
|
||||||
|
@ -1723,7 +1686,6 @@ class Query(object):
|
||||||
+ list(self.annotation_select))
|
+ list(self.annotation_select))
|
||||||
raise FieldError("Cannot resolve keyword %r into field. "
|
raise FieldError("Cannot resolve keyword %r into field. "
|
||||||
"Choices are: %s" % (name, ", ".join(names)))
|
"Choices are: %s" % (name, ", ".join(names)))
|
||||||
self.remove_inherited_models()
|
|
||||||
|
|
||||||
def add_ordering(self, *ordering):
|
def add_ordering(self, *ordering):
|
||||||
"""
|
"""
|
||||||
|
@ -1766,7 +1728,7 @@ class Query(object):
|
||||||
"""
|
"""
|
||||||
self.group_by = []
|
self.group_by = []
|
||||||
|
|
||||||
for col, _ in self.select:
|
for col in self.select:
|
||||||
self.group_by.append(col)
|
self.group_by.append(col)
|
||||||
|
|
||||||
if self._annotations:
|
if self._annotations:
|
||||||
|
@ -1789,7 +1751,6 @@ class Query(object):
|
||||||
for part in field.split(LOOKUP_SEP):
|
for part in field.split(LOOKUP_SEP):
|
||||||
d = d.setdefault(part, {})
|
d = d.setdefault(part, {})
|
||||||
self.select_related = field_dict
|
self.select_related = field_dict
|
||||||
self.related_select_cols = []
|
|
||||||
|
|
||||||
def add_extra(self, select, select_params, where, params, tables, order_by):
|
def add_extra(self, select, select_params, where, params, tables, order_by):
|
||||||
"""
|
"""
|
||||||
|
@ -1897,7 +1858,7 @@ class Query(object):
|
||||||
"""
|
"""
|
||||||
Callback used by get_deferred_field_names().
|
Callback used by get_deferred_field_names().
|
||||||
"""
|
"""
|
||||||
target[model] = set(f.name for f in fields)
|
target[model] = {f.attname for f in fields}
|
||||||
|
|
||||||
def set_aggregate_mask(self, names):
|
def set_aggregate_mask(self, names):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -2041,7 +2002,7 @@ class Query(object):
|
||||||
if self.alias_refcount[table] > 0:
|
if self.alias_refcount[table] > 0:
|
||||||
self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table)
|
self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table)
|
||||||
break
|
break
|
||||||
self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields]
|
self.set_select([f.get_col(select_alias) for f in select_fields])
|
||||||
return trimmed_prefix, contains_louter
|
return trimmed_prefix, contains_louter
|
||||||
|
|
||||||
def is_nullable(self, field):
|
def is_nullable(self, field):
|
||||||
|
|
|
@ -5,7 +5,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import connections
|
from django.db import connections
|
||||||
from django.db.models.query_utils import Q
|
from django.db.models.query_utils import Q
|
||||||
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
|
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
|
||||||
from django.db.models.sql.query import Query
|
from django.db.models.sql.query import Query
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ class DeleteQuery(Query):
|
||||||
else:
|
else:
|
||||||
innerq.clear_select_clause()
|
innerq.clear_select_clause()
|
||||||
innerq.select = [
|
innerq.select = [
|
||||||
SelectInfo((self.get_initial_alias(), pk.column), None)
|
pk.get_col(self.get_initial_alias())
|
||||||
]
|
]
|
||||||
values = innerq
|
values = innerq
|
||||||
self.where = self.where_class()
|
self.where = self.where_class()
|
||||||
|
|
|
@ -483,7 +483,7 @@ instances::
|
||||||
class HandField(models.Field):
|
class HandField(models.Field):
|
||||||
# ...
|
# ...
|
||||||
|
|
||||||
def from_db_value(self, value, connection):
|
def from_db_value(self, value, connection, context):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
return parse_hand(value)
|
return parse_hand(value)
|
||||||
|
|
|
@ -399,7 +399,7 @@ calling the appropriate methods on the wrapped expression.
|
||||||
clone.expression = self.expression.relabeled_clone(change_map)
|
clone.expression = self.expression.relabeled_clone(change_map)
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
.. method:: convert_value(self, value, connection)
|
.. method:: convert_value(self, value, connection, context)
|
||||||
|
|
||||||
A hook allowing the expression to coerce ``value`` into a more
|
A hook allowing the expression to coerce ``value`` into a more
|
||||||
appropriate type.
|
appropriate type.
|
||||||
|
|
|
@ -1670,7 +1670,7 @@ Field API reference
|
||||||
|
|
||||||
When loading data, :meth:`from_db_value` is used:
|
When loading data, :meth:`from_db_value` is used:
|
||||||
|
|
||||||
.. method:: from_db_value(value, connection)
|
.. method:: from_db_value(value, connection, context)
|
||||||
|
|
||||||
.. versionadded:: 1.8
|
.. versionadded:: 1.8
|
||||||
|
|
||||||
|
|
|
@ -679,7 +679,7 @@ class BaseAggregateTestCase(TestCase):
|
||||||
# the only "ORDER BY" clause present in the query.
|
# the only "ORDER BY" clause present in the query.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
re.findall(r'order by (\w+)', qstr),
|
re.findall(r'order by (\w+)', qstr),
|
||||||
[', '.join(forced_ordering).lower()]
|
[', '.join(f[1][0] for f in forced_ordering).lower()]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.assertNotIn('order by', qstr)
|
self.assertNotIn('order by', qstr)
|
||||||
|
|
|
@ -490,9 +490,10 @@ class AggregationTests(TestCase):
|
||||||
|
|
||||||
# Regression for #15709 - Ensure each group_by field only exists once
|
# Regression for #15709 - Ensure each group_by field only exists once
|
||||||
# per query
|
# per query
|
||||||
qs = Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by()
|
qstr = str(Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by().query)
|
||||||
grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping([], [])
|
# Check that there is just one GROUP BY clause (zero commas means at
|
||||||
self.assertEqual(len(grouping), 1)
|
# most one clause)
|
||||||
|
self.assertEqual(qstr[qstr.index('GROUP BY'):].count(', '), 0)
|
||||||
|
|
||||||
def test_duplicate_alias(self):
|
def test_duplicate_alias(self):
|
||||||
# Regression for #11256 - duplicating a default alias raises ValueError.
|
# Regression for #11256 - duplicating a default alias raises ValueError.
|
||||||
|
@ -924,14 +925,11 @@ class AggregationTests(TestCase):
|
||||||
|
|
||||||
# There should only be one GROUP BY clause, for the `id` column.
|
# There should only be one GROUP BY clause, for the `id` column.
|
||||||
# `name` and `age` should not be grouped on.
|
# `name` and `age` should not be grouped on.
|
||||||
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], [])
|
_, _, group_by = results.query.get_compiler(using='default').pre_sql_setup()
|
||||||
self.assertEqual(len(grouping), 1)
|
self.assertEqual(len(group_by), 1)
|
||||||
assert 'id' in grouping[0]
|
self.assertIn('id', group_by[0][0])
|
||||||
assert 'name' not in grouping[0]
|
self.assertNotIn('name', group_by[0][0])
|
||||||
assert 'age' not in grouping[0]
|
self.assertNotIn('age', group_by[0][0])
|
||||||
|
|
||||||
# The query group_by property should also only show the `id`.
|
|
||||||
self.assertEqual(results.query.group_by, [('aggregation_regress_author', 'id')])
|
|
||||||
|
|
||||||
# Ensure that we get correct results.
|
# Ensure that we get correct results.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -953,14 +951,11 @@ class AggregationTests(TestCase):
|
||||||
def test_aggregate_duplicate_columns_only(self):
|
def test_aggregate_duplicate_columns_only(self):
|
||||||
# Works with only() too.
|
# Works with only() too.
|
||||||
results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set'))
|
results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set'))
|
||||||
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], [])
|
_, _, grouping = results.query.get_compiler(using='default').pre_sql_setup()
|
||||||
self.assertEqual(len(grouping), 1)
|
self.assertEqual(len(grouping), 1)
|
||||||
assert 'id' in grouping[0]
|
self.assertIn('id', grouping[0][0])
|
||||||
assert 'name' not in grouping[0]
|
self.assertNotIn('name', grouping[0][0])
|
||||||
assert 'age' not in grouping[0]
|
self.assertNotIn('age', grouping[0][0])
|
||||||
|
|
||||||
# The query group_by property should also only show the `id`.
|
|
||||||
self.assertEqual(results.query.group_by, [('aggregation_regress_author', 'id')])
|
|
||||||
|
|
||||||
# Ensure that we get correct results.
|
# Ensure that we get correct results.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -983,14 +978,11 @@ class AggregationTests(TestCase):
|
||||||
# And select_related()
|
# And select_related()
|
||||||
results = Book.objects.select_related('contact').annotate(
|
results = Book.objects.select_related('contact').annotate(
|
||||||
num_authors=Count('authors'))
|
num_authors=Count('authors'))
|
||||||
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], [])
|
_, _, grouping = results.query.get_compiler(using='default').pre_sql_setup()
|
||||||
self.assertEqual(len(grouping), 1)
|
self.assertEqual(len(grouping), 1)
|
||||||
assert 'id' in grouping[0]
|
self.assertIn('id', grouping[0][0])
|
||||||
assert 'name' not in grouping[0]
|
self.assertNotIn('name', grouping[0][0])
|
||||||
assert 'contact' not in grouping[0]
|
self.assertNotIn('contact', grouping[0][0])
|
||||||
|
|
||||||
# The query group_by property should also only show the `id`.
|
|
||||||
self.assertEqual(results.query.group_by, [('aggregation_regress_book', 'id')])
|
|
||||||
|
|
||||||
# Ensure that we get correct results.
|
# Ensure that we get correct results.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|
|
@ -43,7 +43,7 @@ class MyAutoField(models.CharField):
|
||||||
value = MyWrapper(value)
|
value = MyWrapper(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def from_db_value(self, value, connection):
|
def from_db_value(self, value, connection, context):
|
||||||
if not value:
|
if not value:
|
||||||
return
|
return
|
||||||
return MyWrapper(value)
|
return MyWrapper(value)
|
||||||
|
|
|
@ -96,3 +96,11 @@ class Request(models.Model):
|
||||||
request2 = models.CharField(default='request2', max_length=1000)
|
request2 = models.CharField(default='request2', max_length=1000)
|
||||||
request3 = models.CharField(default='request3', max_length=1000)
|
request3 = models.CharField(default='request3', max_length=1000)
|
||||||
request4 = models.CharField(default='request4', max_length=1000)
|
request4 = models.CharField(default='request4', max_length=1000)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(models.Model):
|
||||||
|
text = models.TextField()
|
||||||
|
|
||||||
|
|
||||||
|
class Derived(Base):
|
||||||
|
other_text = models.TextField()
|
||||||
|
|
|
@ -12,7 +12,7 @@ from django.test import TestCase, override_settings
|
||||||
from .models import (
|
from .models import (
|
||||||
ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, SimpleItem, Feature,
|
ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, SimpleItem, Feature,
|
||||||
ItemAndSimpleItem, OneToOneItem, SpecialFeature, Location, Request,
|
ItemAndSimpleItem, OneToOneItem, SpecialFeature, Location, Request,
|
||||||
ProxyRelated,
|
ProxyRelated, Derived, Base,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,6 +145,15 @@ class DeferRegressionTest(TestCase):
|
||||||
list(SimpleItem.objects.annotate(Count('feature')).only('name')),
|
list(SimpleItem.objects.annotate(Count('feature')).only('name')),
|
||||||
list)
|
list)
|
||||||
|
|
||||||
|
def test_ticket_23270(self):
|
||||||
|
Derived.objects.create(text="foo", other_text="bar")
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
obj = Base.objects.select_related("derived").defer("text")[0]
|
||||||
|
self.assertIsInstance(obj.derived, Derived)
|
||||||
|
self.assertEqual("bar", obj.derived.other_text)
|
||||||
|
self.assertNotIn("text", obj.__dict__)
|
||||||
|
self.assertEqual(1, obj.derived.base_ptr_id)
|
||||||
|
|
||||||
def test_only_and_defer_usage_on_proxy_models(self):
|
def test_only_and_defer_usage_on_proxy_models(self):
|
||||||
# Regression for #15790 - only() broken for proxy models
|
# Regression for #15790 - only() broken for proxy models
|
||||||
proxy = Proxy.objects.create(name="proxy", value=42)
|
proxy = Proxy.objects.create(name="proxy", value=42)
|
||||||
|
|
|
@ -18,7 +18,7 @@ class CashField(models.DecimalField):
|
||||||
kwargs['decimal_places'] = 2
|
kwargs['decimal_places'] = 2
|
||||||
super(CashField, self).__init__(**kwargs)
|
super(CashField, self).__init__(**kwargs)
|
||||||
|
|
||||||
def from_db_value(self, value, connection):
|
def from_db_value(self, value, connection, context):
|
||||||
cash = Cash(value)
|
cash = Cash(value)
|
||||||
cash.vendor = connection.vendor
|
cash.vendor = connection.vendor
|
||||||
return cash
|
return cash
|
||||||
|
|
|
@ -3563,8 +3563,10 @@ class Ticket20955Tests(TestCase):
|
||||||
# version's queries.
|
# version's queries.
|
||||||
task_get.creator.staffuser.staff
|
task_get.creator.staffuser.staff
|
||||||
task_get.owner.staffuser.staff
|
task_get.owner.staffuser.staff
|
||||||
task_select_related = Task.objects.select_related(
|
qs = Task.objects.select_related(
|
||||||
'creator__staffuser__staff', 'owner__staffuser__staff').get(pk=task.pk)
|
'creator__staffuser__staff', 'owner__staffuser__staff')
|
||||||
|
self.assertEqual(str(qs.query).count(' JOIN '), 6)
|
||||||
|
task_select_related = qs.get(pk=task.pk)
|
||||||
with self.assertNumQueries(0):
|
with self.assertNumQueries(0):
|
||||||
self.assertEqual(task_select_related.creator.staffuser.staff,
|
self.assertEqual(task_select_related.creator.staffuser.staff,
|
||||||
task_get.creator.staffuser.staff)
|
task_get.creator.staffuser.staff)
|
||||||
|
|
|
@ -83,6 +83,7 @@ class ReverseSelectRelatedTestCase(TestCase):
|
||||||
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
|
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
|
||||||
self.assertEqual(stat.advanceduserstat.posts, 200)
|
self.assertEqual(stat.advanceduserstat.posts, 200)
|
||||||
self.assertEqual(stat.user.username, 'bob')
|
self.assertEqual(stat.user.username, 'bob')
|
||||||
|
with self.assertNumQueries(1):
|
||||||
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
|
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
|
||||||
|
|
||||||
def test_nullable_relation(self):
|
def test_nullable_relation(self):
|
||||||
|
|
|
@ -112,7 +112,7 @@ class TeamField(models.CharField):
|
||||||
return value
|
return value
|
||||||
return Team(value)
|
return Team(value)
|
||||||
|
|
||||||
def from_db_value(self, value, connection):
|
def from_db_value(self, value, connection, context):
|
||||||
return Team(value)
|
return Team(value)
|
||||||
|
|
||||||
def value_to_string(self, obj):
|
def value_to_string(self, obj):
|
||||||
|
|
Loading…
Reference in New Issue