Fixed #24020 -- Refactored SQL compiler to use expressions
Refactored compiler SELECT, GROUP BY and ORDER BY generation. While there, also refactored select_related() implementation (get_cached_row() and get_klass_info() are now gone!). Made get_db_converters() method work on expressions instead of internal_type. This allows the backend converters to target specific expressions if need be. Added query.context, this can be used to set per-query state. Also changed the signature of database converters. They now accept context as an argument.
This commit is contained in:
parent
b8abfe141b
commit
0c7633178f
|
@ -10,7 +10,6 @@ from django.db.models import signals, DO_NOTHING
|
|||
from django.db.models.base import ModelBase
|
||||
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
|
||||
from django.db.models.query_utils import PathInfo
|
||||
from django.db.models.expressions import Col
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.utils.encoding import smart_text, python_2_unicode_compatible
|
||||
|
||||
|
@ -367,7 +366,7 @@ class GenericRelation(ForeignObject):
|
|||
field = self.rel.to._meta.get_field(self.content_type_field_name)
|
||||
contenttype_pk = self.get_content_type().pk
|
||||
cond = where_class()
|
||||
lookup = field.get_lookup('exact')(Col(remote_alias, field, field), contenttype_pk)
|
||||
lookup = field.get_lookup('exact')(field.get_col(remote_alias), contenttype_pk)
|
||||
cond.add(lookup, 'AND')
|
||||
return cond
|
||||
|
||||
|
|
|
@ -158,10 +158,10 @@ class BaseSpatialOperations(object):
|
|||
|
||||
# Default conversion functions for aggregates; will be overridden if implemented
|
||||
# for the spatial backend.
|
||||
def convert_extent(self, box):
|
||||
def convert_extent(self, box, srid):
|
||||
raise NotImplementedError('Aggregate extent not implemented for this spatial backend.')
|
||||
|
||||
def convert_extent3d(self, box):
|
||||
def convert_extent3d(self, box, srid):
|
||||
raise NotImplementedError('Aggregate 3D extent not implemented for this spatial backend.')
|
||||
|
||||
def convert_geom(self, geom_val, geom_field):
|
||||
|
|
|
@ -7,7 +7,6 @@ from django.contrib.gis.db.backends.utils import SpatialOperator
|
|||
|
||||
class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
|
||||
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
|
||||
mysql = True
|
||||
name = 'mysql'
|
||||
select = 'AsText(%s)'
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
from django.contrib.gis.db.models.sql.compiler import GeoSQLCompiler as BaseGeoSQLCompiler
|
||||
from django.db.backends.oracle import compiler
|
||||
|
||||
SQLCompiler = compiler.SQLCompiler
|
||||
|
||||
|
||||
class GeoSQLCompiler(BaseGeoSQLCompiler, SQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
|
||||
pass
|
|
@ -52,7 +52,6 @@ class SDORelate(SpatialOperator):
|
|||
|
||||
|
||||
class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
compiler_module = "django.contrib.gis.db.backends.oracle.compiler"
|
||||
|
||||
name = 'oracle'
|
||||
oracle = True
|
||||
|
@ -111,8 +110,9 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
def geo_quote_name(self, name):
|
||||
return super(OracleOperations, self).geo_quote_name(name).upper()
|
||||
|
||||
def get_db_converters(self, internal_type):
|
||||
converters = super(OracleOperations, self).get_db_converters(internal_type)
|
||||
def get_db_converters(self, expression):
|
||||
converters = super(OracleOperations, self).get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
geometry_fields = (
|
||||
'PointField', 'GeometryField', 'LineStringField',
|
||||
'PolygonField', 'MultiPointField', 'MultiLineStringField',
|
||||
|
@ -121,14 +121,23 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
)
|
||||
if internal_type in geometry_fields:
|
||||
converters.append(self.convert_textfield_value)
|
||||
if hasattr(expression.output_field, 'geom_type'):
|
||||
converters.append(self.convert_geometry)
|
||||
return converters
|
||||
|
||||
def convert_extent(self, clob):
|
||||
def convert_geometry(self, value, expression, context):
|
||||
if value:
|
||||
value = Geometry(value)
|
||||
if 'transformed_srid' in context:
|
||||
value.srid = context['transformed_srid']
|
||||
return value
|
||||
|
||||
def convert_extent(self, clob, srid):
|
||||
if clob:
|
||||
# Generally, Oracle returns a polygon for the extent -- however,
|
||||
# it can return a single point if there's only one Point in the
|
||||
# table.
|
||||
ext_geom = Geometry(clob.read())
|
||||
ext_geom = Geometry(clob.read(), srid)
|
||||
gtype = str(ext_geom.geom_type)
|
||||
if gtype == 'Polygon':
|
||||
# Construct the 4-tuple from the coordinates in the polygon.
|
||||
|
@ -226,7 +235,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
else:
|
||||
sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
|
||||
sql_function = getattr(self, agg_name)
|
||||
return self.select % sql_template, sql_function
|
||||
return sql_template, sql_function
|
||||
|
||||
# Routines for getting the OGC-compliant models.
|
||||
def geometry_columns(self):
|
||||
|
|
|
@ -44,7 +44,6 @@ class PostGISDistanceOperator(PostGISOperator):
|
|||
|
||||
|
||||
class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
|
||||
name = 'postgis'
|
||||
postgis = True
|
||||
geography = True
|
||||
|
@ -188,7 +187,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
agg_name = aggregate.__class__.__name__
|
||||
return agg_name in self.valid_aggregates
|
||||
|
||||
def convert_extent(self, box):
|
||||
def convert_extent(self, box, srid):
|
||||
"""
|
||||
Returns a 4-tuple extent for the `Extent` aggregate by converting
|
||||
the bounding box text returned by PostGIS (`box` argument), for
|
||||
|
@ -199,7 +198,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
xmax, ymax = map(float, ur.split())
|
||||
return (xmin, ymin, xmax, ymax)
|
||||
|
||||
def convert_extent3d(self, box3d):
|
||||
def convert_extent3d(self, box3d, srid):
|
||||
"""
|
||||
Returns a 6-tuple extent for the `Extent3D` aggregate by converting
|
||||
the 3d bounding-box text returned by PostGIS (`box3d` argument), for
|
||||
|
|
|
@ -14,7 +14,6 @@ from django.utils.functional import cached_property
|
|||
|
||||
|
||||
class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
||||
compiler_module = 'django.contrib.gis.db.models.sql.compiler'
|
||||
name = 'spatialite'
|
||||
spatialite = True
|
||||
version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
|
||||
|
@ -131,11 +130,11 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
agg_name = aggregate.__class__.__name__
|
||||
return agg_name in self.valid_aggregates
|
||||
|
||||
def convert_extent(self, box):
|
||||
def convert_extent(self, box, srid):
|
||||
"""
|
||||
Convert the polygon data received from Spatialite to min/max values.
|
||||
"""
|
||||
shell = Geometry(box).shell
|
||||
shell = Geometry(box, srid).shell
|
||||
xmin, ymin = shell[0][:2]
|
||||
xmax, ymax = shell[2][:2]
|
||||
return (xmin, ymin, xmax, ymax)
|
||||
|
@ -256,7 +255,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
agg_name = agg_name.lower()
|
||||
if agg_name == 'union':
|
||||
agg_name += 'agg'
|
||||
sql_template = self.select % '%(function)s(%(expressions)s)'
|
||||
sql_template = '%(function)s(%(expressions)s)'
|
||||
sql_function = getattr(self, agg_name)
|
||||
return sql_template, sql_function
|
||||
|
||||
|
@ -268,3 +267,16 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations):
|
|||
def spatial_ref_sys(self):
|
||||
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys
|
||||
return SpatialiteSpatialRefSys
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super(SpatiaLiteOperations, self).get_db_converters(expression)
|
||||
if hasattr(expression.output_field, 'geom_type'):
|
||||
converters.append(self.convert_geometry)
|
||||
return converters
|
||||
|
||||
def convert_geometry(self, value, expression, context):
|
||||
if value:
|
||||
value = Geometry(value)
|
||||
if 'transformed_srid' in context:
|
||||
value.srid = context['transformed_srid']
|
||||
return value
|
||||
|
|
|
@ -28,7 +28,7 @@ class GeoAggregate(Aggregate):
|
|||
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
|
||||
return c
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
def convert_value(self, value, connection, context):
|
||||
return connection.ops.convert_geom(value, self.output_field)
|
||||
|
||||
|
||||
|
@ -43,8 +43,8 @@ class Extent(GeoAggregate):
|
|||
def __init__(self, expression, **extra):
|
||||
super(Extent, self).__init__(expression, output_field=ExtentField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
return connection.ops.convert_extent(value)
|
||||
def convert_value(self, value, connection, context):
|
||||
return connection.ops.convert_extent(value, context.get('transformed_srid'))
|
||||
|
||||
|
||||
class Extent3D(GeoAggregate):
|
||||
|
@ -54,8 +54,8 @@ class Extent3D(GeoAggregate):
|
|||
def __init__(self, expression, **extra):
|
||||
super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
return connection.ops.convert_extent3d(value)
|
||||
def convert_value(self, value, connection, context):
|
||||
return connection.ops.convert_extent3d(value, context.get('transformed_srid'))
|
||||
|
||||
|
||||
class MakeLine(GeoAggregate):
|
||||
|
|
|
@ -42,7 +42,30 @@ def get_srid_info(srid, connection):
|
|||
return _srid_cache[connection.alias][srid]
|
||||
|
||||
|
||||
class GeometryField(Field):
|
||||
class GeoSelectFormatMixin(object):
|
||||
def select_format(self, compiler, sql, params):
|
||||
"""
|
||||
Returns the selection format string, depending on the requirements
|
||||
of the spatial backend. For example, Oracle and MySQL require custom
|
||||
selection formats in order to retrieve geometries in OGC WKT. For all
|
||||
other fields a simple '%s' format string is returned.
|
||||
"""
|
||||
connection = compiler.connection
|
||||
srid = compiler.query.get_context('transformed_srid')
|
||||
if srid:
|
||||
sel_fmt = '%s(%%s, %s)' % (connection.ops.transform, srid)
|
||||
else:
|
||||
sel_fmt = '%s'
|
||||
if connection.ops.select:
|
||||
# This allows operations to be done on fields in the SELECT,
|
||||
# overriding their values -- used by the Oracle and MySQL
|
||||
# spatial backends to get database values as WKT, and by the
|
||||
# `transform` method.
|
||||
sel_fmt = connection.ops.select % sel_fmt
|
||||
return sel_fmt % sql, params
|
||||
|
||||
|
||||
class GeometryField(GeoSelectFormatMixin, Field):
|
||||
"The base GIS field -- maps to the OpenGIS Specification Geometry type."
|
||||
|
||||
# The OpenGIS Geometry name.
|
||||
|
@ -196,7 +219,7 @@ class GeometryField(Field):
|
|||
else:
|
||||
return geom
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
def from_db_value(self, value, connection, context):
|
||||
if value and not isinstance(value, Geometry):
|
||||
value = Geometry(value)
|
||||
return value
|
||||
|
@ -337,7 +360,7 @@ class GeometryCollectionField(GeometryField):
|
|||
description = _("Geometry collection")
|
||||
|
||||
|
||||
class ExtentField(Field):
|
||||
class ExtentField(GeoSelectFormatMixin, Field):
|
||||
"Used as a return value from an extent aggregate"
|
||||
|
||||
description = _("Extent Aggregate Field")
|
||||
|
|
|
@ -1,9 +1,16 @@
|
|||
from django.db import connections
|
||||
from django.db.models.expressions import RawSQL
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
from django.contrib.gis.db.models import aggregates
|
||||
from django.contrib.gis.db.models.fields import get_srid_info, PointField, LineStringField
|
||||
from django.contrib.gis.db.models.sql import AreaField, DistanceField, GeomField, GeoQuery, GMLField
|
||||
from django.contrib.gis.db.models.fields import (
|
||||
get_srid_info, LineStringField, GeometryField, PointField,
|
||||
)
|
||||
from django.contrib.gis.db.models.lookups import GISLookup
|
||||
from django.contrib.gis.db.models.sql import (
|
||||
AreaField, DistanceField, GeomField, GMLField,
|
||||
)
|
||||
from django.contrib.gis.geometry.backend import Geometry
|
||||
from django.contrib.gis.measure import Area, Distance
|
||||
|
||||
|
@ -13,11 +20,6 @@ from django.utils import six
|
|||
class GeoQuerySet(QuerySet):
|
||||
"The Geographic QuerySet."
|
||||
|
||||
### Methods overloaded from QuerySet ###
|
||||
def __init__(self, model=None, query=None, using=None, hints=None):
|
||||
super(GeoQuerySet, self).__init__(model=model, query=query, using=using, hints=hints)
|
||||
self.query = query or GeoQuery(self.model)
|
||||
|
||||
### GeoQuerySet Methods ###
|
||||
def area(self, tolerance=0.05, **kwargs):
|
||||
"""
|
||||
|
@ -26,7 +28,8 @@ class GeoQuerySet(QuerySet):
|
|||
"""
|
||||
# Performing setup here rather than in `_spatial_attribute` so that
|
||||
# we can get the units for `AreaField`.
|
||||
procedure_args, geo_field = self._spatial_setup('area', field_name=kwargs.get('field_name', None))
|
||||
procedure_args, geo_field = self._spatial_setup(
|
||||
'area', field_name=kwargs.get('field_name', None))
|
||||
s = {'procedure_args': procedure_args,
|
||||
'geo_field': geo_field,
|
||||
'setup': False,
|
||||
|
@ -378,24 +381,8 @@ class GeoQuerySet(QuerySet):
|
|||
if not isinstance(srid, six.integer_types):
|
||||
raise TypeError('An integer SRID must be provided.')
|
||||
field_name = kwargs.get('field_name', None)
|
||||
tmp, geo_field = self._spatial_setup('transform', field_name=field_name)
|
||||
|
||||
# Getting the selection SQL for the given geographic field.
|
||||
field_col = self._geocol_select(geo_field, field_name)
|
||||
|
||||
# Why cascading substitutions? Because spatial backends like
|
||||
# Oracle and MySQL already require a function call to convert to text, thus
|
||||
# when there's also a transformation we need to cascade the substitutions.
|
||||
# For example, 'SDO_UTIL.TO_WKTGEOMETRY(SDO_CS.TRANSFORM( ... )'
|
||||
geo_col = self.query.custom_select.get(geo_field, field_col)
|
||||
|
||||
# Setting the key for the field's column with the custom SELECT SQL to
|
||||
# override the geometry column returned from the database.
|
||||
custom_sel = '%s(%s, %s)' % (connections[self.db].ops.transform, geo_col, srid)
|
||||
# TODO: Should we have this as an alias?
|
||||
# custom_sel = '(%s(%s, %s)) AS %s' % (SpatialBackend.transform, geo_col, srid, qn(geo_field.name))
|
||||
self.query.transformed_srid = srid # So other GeoQuerySet methods
|
||||
self.query.custom_select[geo_field] = custom_sel
|
||||
self._spatial_setup('transform', field_name=field_name)
|
||||
self.query.add_context('transformed_srid', srid)
|
||||
return self._clone()
|
||||
|
||||
def union(self, geom, **kwargs):
|
||||
|
@ -433,7 +420,7 @@ class GeoQuerySet(QuerySet):
|
|||
|
||||
# Is there a geographic field in the model to perform this
|
||||
# operation on?
|
||||
geo_field = self.query._geo_field(field_name)
|
||||
geo_field = self._geo_field(field_name)
|
||||
if not geo_field:
|
||||
raise TypeError('%s output only available on GeometryFields.' % func)
|
||||
|
||||
|
@ -454,7 +441,7 @@ class GeoQuerySet(QuerySet):
|
|||
returning their result to the caller of the function.
|
||||
"""
|
||||
# Getting the field the geographic aggregate will be called on.
|
||||
geo_field = self.query._geo_field(field_name)
|
||||
geo_field = self._geo_field(field_name)
|
||||
if not geo_field:
|
||||
raise TypeError('%s aggregate only available on GeometryFields.' % aggregate.name)
|
||||
|
||||
|
@ -509,12 +496,12 @@ class GeoQuerySet(QuerySet):
|
|||
settings.setdefault('select_params', [])
|
||||
|
||||
connection = connections[self.db]
|
||||
backend = connection.ops
|
||||
|
||||
# Performing setup for the spatial column, unless told not to.
|
||||
if settings.get('setup', True):
|
||||
default_args, geo_field = self._spatial_setup(att, desc=settings['desc'], field_name=field_name,
|
||||
geo_field_type=settings.get('geo_field_type', None))
|
||||
default_args, geo_field = self._spatial_setup(
|
||||
att, desc=settings['desc'], field_name=field_name,
|
||||
geo_field_type=settings.get('geo_field_type', None))
|
||||
for k, v in six.iteritems(default_args):
|
||||
settings['procedure_args'].setdefault(k, v)
|
||||
else:
|
||||
|
@ -544,18 +531,19 @@ class GeoQuerySet(QuerySet):
|
|||
|
||||
# If the result of this function needs to be converted.
|
||||
if settings.get('select_field', False):
|
||||
sel_fld = settings['select_field']
|
||||
if isinstance(sel_fld, GeomField) and backend.select:
|
||||
self.query.custom_select[model_att] = backend.select
|
||||
select_field = settings['select_field']
|
||||
if connection.ops.oracle:
|
||||
sel_fld.empty_strings_allowed = False
|
||||
self.query.extra_select_fields[model_att] = sel_fld
|
||||
select_field.empty_strings_allowed = False
|
||||
else:
|
||||
select_field = Field()
|
||||
|
||||
# Finally, setting the extra selection attribute with
|
||||
# the format string expanded with the stored procedure
|
||||
# arguments.
|
||||
return self.extra(select={model_att: fmt % settings['procedure_args']},
|
||||
select_params=settings['select_params'])
|
||||
self.query.add_annotation(
|
||||
RawSQL(fmt % settings['procedure_args'], settings['select_params'], select_field),
|
||||
model_att)
|
||||
return self
|
||||
|
||||
def _distance_attribute(self, func, geom=None, tolerance=0.05, spheroid=False, **kwargs):
|
||||
"""
|
||||
|
@ -616,8 +604,9 @@ class GeoQuerySet(QuerySet):
|
|||
else:
|
||||
# Getting whether this field is in units of degrees since the field may have
|
||||
# been transformed via the `transform` GeoQuerySet method.
|
||||
if self.query.transformed_srid:
|
||||
u, unit_name, s = get_srid_info(self.query.transformed_srid, connection)
|
||||
srid = self.query.get_context('transformed_srid')
|
||||
if srid:
|
||||
u, unit_name, s = get_srid_info(srid, connection)
|
||||
geodetic = unit_name.lower() in geo_field.geodetic_units
|
||||
|
||||
if geodetic and not connection.features.supports_distance_geodetic:
|
||||
|
@ -627,20 +616,20 @@ class GeoQuerySet(QuerySet):
|
|||
)
|
||||
|
||||
if distance:
|
||||
if self.query.transformed_srid:
|
||||
if srid:
|
||||
# Setting the `geom_args` flag to false because we want to handle
|
||||
# transformation SQL here, rather than the way done by default
|
||||
# (which will transform to the original SRID of the field rather
|
||||
# than to what was transformed to).
|
||||
geom_args = False
|
||||
procedure_fmt = '%s(%%(geo_col)s, %s)' % (backend.transform, self.query.transformed_srid)
|
||||
if geom.srid is None or geom.srid == self.query.transformed_srid:
|
||||
procedure_fmt = '%s(%%(geo_col)s, %s)' % (backend.transform, srid)
|
||||
if geom.srid is None or geom.srid == srid:
|
||||
# If the geom parameter srid is None, it is assumed the coordinates
|
||||
# are in the transformed units. A placeholder is used for the
|
||||
# geometry parameter. `GeomFromText` constructor is also needed
|
||||
# to wrap geom placeholder for SpatiaLite.
|
||||
if backend.spatialite:
|
||||
procedure_fmt += ', %s(%%%%s, %s)' % (backend.from_text, self.query.transformed_srid)
|
||||
procedure_fmt += ', %s(%%%%s, %s)' % (backend.from_text, srid)
|
||||
else:
|
||||
procedure_fmt += ', %%s'
|
||||
else:
|
||||
|
@ -649,10 +638,11 @@ class GeoQuerySet(QuerySet):
|
|||
# SpatiaLite also needs geometry placeholder wrapped in `GeomFromText`
|
||||
# constructor.
|
||||
if backend.spatialite:
|
||||
procedure_fmt += ', %s(%s(%%%%s, %s), %s)' % (backend.transform, backend.from_text,
|
||||
geom.srid, self.query.transformed_srid)
|
||||
procedure_fmt += (', %s(%s(%%%%s, %s), %s)' % (
|
||||
backend.transform, backend.from_text,
|
||||
geom.srid, srid))
|
||||
else:
|
||||
procedure_fmt += ', %s(%%%%s, %s)' % (backend.transform, self.query.transformed_srid)
|
||||
procedure_fmt += ', %s(%%%%s, %s)' % (backend.transform, srid)
|
||||
else:
|
||||
# `transform()` was not used on this GeoQuerySet.
|
||||
procedure_fmt = '%(geo_col)s,%(geom)s'
|
||||
|
@ -743,22 +733,57 @@ class GeoQuerySet(QuerySet):
|
|||
column. Takes into account if the geographic field is in a
|
||||
ForeignKey relation to the current model.
|
||||
"""
|
||||
compiler = self.query.get_compiler(self.db)
|
||||
opts = self.model._meta
|
||||
if geo_field not in opts.fields:
|
||||
# Is this operation going to be on a related geographic field?
|
||||
# If so, it'll have to be added to the select related information
|
||||
# (e.g., if 'location__point' was given as the field name).
|
||||
# Note: the operation really is defined as "must add select related!"
|
||||
self.query.add_select_related([field_name])
|
||||
compiler = self.query.get_compiler(self.db)
|
||||
# Call pre_sql_setup() so that compiler.select gets populated.
|
||||
compiler.pre_sql_setup()
|
||||
for (rel_table, rel_col), field in self.query.related_select_cols:
|
||||
if field == geo_field:
|
||||
return compiler._field_column(geo_field, rel_table)
|
||||
raise ValueError("%r not in self.query.related_select_cols" % geo_field)
|
||||
for col, _, _ in compiler.select:
|
||||
if col.output_field == geo_field:
|
||||
return col.as_sql(compiler, compiler.connection)[0]
|
||||
raise ValueError("%r not in compiler's related_select_cols" % geo_field)
|
||||
elif geo_field not in opts.local_fields:
|
||||
# This geographic field is inherited from another model, so we have to
|
||||
# use the db table for the _parent_ model instead.
|
||||
parent_model = geo_field.model._meta.concrete_model
|
||||
return self.query.get_compiler(self.db)._field_column(geo_field, parent_model._meta.db_table)
|
||||
return self._field_column(compiler, geo_field, parent_model._meta.db_table)
|
||||
else:
|
||||
return self.query.get_compiler(self.db)._field_column(geo_field)
|
||||
return self._field_column(compiler, geo_field)
|
||||
|
||||
# Private API utilities, subject to change.
|
||||
def _geo_field(self, field_name=None):
|
||||
"""
|
||||
Returns the first Geometry field encountered or the one specified via
|
||||
the `field_name` keyword. The `field_name` may be a string specifying
|
||||
the geometry field on this GeoQuerySet's model, or a lookup string
|
||||
to a geometry field via a ForeignKey relation.
|
||||
"""
|
||||
if field_name is None:
|
||||
# Incrementing until the first geographic field is found.
|
||||
for field in self.model._meta.fields:
|
||||
if isinstance(field, GeometryField):
|
||||
return field
|
||||
return False
|
||||
else:
|
||||
# Otherwise, check by the given field name -- which may be
|
||||
# a lookup to a _related_ geographic field.
|
||||
return GISLookup._check_geo_field(self.model._meta, field_name)
|
||||
|
||||
def _field_column(self, compiler, field, table_alias=None, column=None):
|
||||
"""
|
||||
Helper function that returns the database column for the given field.
|
||||
The table and column are returned (quoted) in the proper format, e.g.,
|
||||
`"geoapp_city"."point"`. If `table_alias` is not specified, the
|
||||
database table associated with the model of this `GeoQuerySet` will be
|
||||
used. If `column` is specified, it will be used instead of the value
|
||||
in `field.column`.
|
||||
"""
|
||||
if table_alias is None:
|
||||
table_alias = compiler.query.get_meta().db_table
|
||||
return "%s.%s" % (compiler.quote_name_unless_alias(table_alias),
|
||||
compiler.connection.ops.quote_name(column or field.column))
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField, GMLField
|
||||
from django.contrib.gis.db.models.sql.query import GeoQuery
|
||||
|
||||
__all__ = [
|
||||
'AreaField', 'DistanceField', 'GeomField', 'GMLField', 'GeoQuery',
|
||||
'AreaField', 'DistanceField', 'GeomField', 'GMLField'
|
||||
]
|
||||
|
|
|
@ -1,240 +0,0 @@
|
|||
from django.db.backends.utils import truncate_name
|
||||
from django.db.models.sql import compiler
|
||||
from django.utils import six
|
||||
|
||||
SQLCompiler = compiler.SQLCompiler
|
||||
|
||||
|
||||
class GeoSQLCompiler(compiler.SQLCompiler):
|
||||
|
||||
def get_columns(self, with_aliases=False):
|
||||
"""
|
||||
Return the list of columns to use in the select statement. If no
|
||||
columns have been specified, returns all columns relating to fields in
|
||||
the model.
|
||||
|
||||
If 'with_aliases' is true, any column names that are duplicated
|
||||
(without the table names) are given unique aliases. This is needed in
|
||||
some cases to avoid ambiguity with nested queries.
|
||||
|
||||
This routine is overridden from Query to handle customized selection of
|
||||
geometry columns.
|
||||
"""
|
||||
qn = self.quote_name_unless_alias
|
||||
qn2 = self.connection.ops.quote_name
|
||||
result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
|
||||
for alias, col in six.iteritems(self.query.extra_select)]
|
||||
params = []
|
||||
aliases = set(self.query.extra_select.keys())
|
||||
if with_aliases:
|
||||
col_aliases = aliases.copy()
|
||||
else:
|
||||
col_aliases = set()
|
||||
if self.query.select:
|
||||
only_load = self.deferred_to_columns()
|
||||
# This loop customized for GeoQuery.
|
||||
for col, field in self.query.select:
|
||||
if isinstance(col, (list, tuple)):
|
||||
alias, column = col
|
||||
table = self.query.alias_map[alias].table_name
|
||||
if table in only_load and column not in only_load[table]:
|
||||
continue
|
||||
r = self.get_field_select(field, alias, column)
|
||||
if with_aliases:
|
||||
if col[1] in col_aliases:
|
||||
c_alias = 'Col%d' % len(col_aliases)
|
||||
result.append('%s AS %s' % (r, c_alias))
|
||||
aliases.add(c_alias)
|
||||
col_aliases.add(c_alias)
|
||||
else:
|
||||
result.append('%s AS %s' % (r, qn2(col[1])))
|
||||
aliases.add(r)
|
||||
col_aliases.add(col[1])
|
||||
else:
|
||||
result.append(r)
|
||||
aliases.add(r)
|
||||
col_aliases.add(col[1])
|
||||
else:
|
||||
col_sql, col_params = col.as_sql(self, self.connection)
|
||||
result.append(col_sql)
|
||||
params.extend(col_params)
|
||||
|
||||
if hasattr(col, 'alias'):
|
||||
aliases.add(col.alias)
|
||||
col_aliases.add(col.alias)
|
||||
|
||||
elif self.query.default_cols:
|
||||
cols, new_aliases = self.get_default_columns(with_aliases,
|
||||
col_aliases)
|
||||
result.extend(cols)
|
||||
aliases.update(new_aliases)
|
||||
|
||||
max_name_length = self.connection.ops.max_name_length()
|
||||
for alias, annotation in self.query.annotation_select.items():
|
||||
agg_sql, agg_params = self.compile(annotation)
|
||||
if alias is None:
|
||||
result.append(agg_sql)
|
||||
else:
|
||||
result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
|
||||
params.extend(agg_params)
|
||||
|
||||
# This loop customized for GeoQuery.
|
||||
for (table, col), field in self.query.related_select_cols:
|
||||
r = self.get_field_select(field, table, col)
|
||||
if with_aliases and col in col_aliases:
|
||||
c_alias = 'Col%d' % len(col_aliases)
|
||||
result.append('%s AS %s' % (r, c_alias))
|
||||
aliases.add(c_alias)
|
||||
col_aliases.add(c_alias)
|
||||
else:
|
||||
result.append(r)
|
||||
aliases.add(r)
|
||||
col_aliases.add(col)
|
||||
|
||||
self._select_aliases = aliases
|
||||
return result, params
|
||||
|
||||
def get_default_columns(self, with_aliases=False, col_aliases=None,
|
||||
start_alias=None, opts=None, as_pairs=False, from_parent=None):
|
||||
"""
|
||||
Computes the default columns for selecting every field in the base
|
||||
model. Will sometimes be called to pull in related models (e.g. via
|
||||
select_related), in which case "opts" and "start_alias" will be given
|
||||
to provide a starting point for the traversal.
|
||||
|
||||
Returns a list of strings, quoted appropriately for use in SQL
|
||||
directly, as well as a set of aliases used in the select statement (if
|
||||
'as_pairs' is True, returns a list of (alias, col_name) pairs instead
|
||||
of strings as the first component and None as the second component).
|
||||
|
||||
This routine is overridden from Query to handle customized selection of
|
||||
geometry columns.
|
||||
"""
|
||||
result = []
|
||||
if opts is None:
|
||||
opts = self.query.get_meta()
|
||||
aliases = set()
|
||||
only_load = self.deferred_to_columns()
|
||||
seen = self.query.included_inherited_models.copy()
|
||||
if start_alias:
|
||||
seen[None] = start_alias
|
||||
for field in opts.concrete_fields:
|
||||
model = field.model._meta.concrete_model
|
||||
if model is opts.model:
|
||||
model = None
|
||||
if from_parent and model is not None and issubclass(from_parent, model):
|
||||
# Avoid loading data for already loaded parents.
|
||||
continue
|
||||
alias = self.query.join_parent_model(opts, model, start_alias, seen)
|
||||
table = self.query.alias_map[alias].table_name
|
||||
if table in only_load and field.column not in only_load[table]:
|
||||
continue
|
||||
if as_pairs:
|
||||
result.append((alias, field))
|
||||
aliases.add(alias)
|
||||
continue
|
||||
# This part of the function is customized for GeoQuery. We
|
||||
# see if there was any custom selection specified in the
|
||||
# dictionary, and set up the selection format appropriately.
|
||||
field_sel = self.get_field_select(field, alias)
|
||||
if with_aliases and field.column in col_aliases:
|
||||
c_alias = 'Col%d' % len(col_aliases)
|
||||
result.append('%s AS %s' % (field_sel, c_alias))
|
||||
col_aliases.add(c_alias)
|
||||
aliases.add(c_alias)
|
||||
else:
|
||||
r = field_sel
|
||||
result.append(r)
|
||||
aliases.add(r)
|
||||
if with_aliases:
|
||||
col_aliases.add(field.column)
|
||||
return result, aliases
|
||||
|
||||
def get_converters(self, fields):
|
||||
converters = super(GeoSQLCompiler, self).get_converters(fields)
|
||||
for i, alias in enumerate(self.query.extra_select):
|
||||
field = self.query.extra_select_fields.get(alias)
|
||||
if field:
|
||||
backend_converters = self.connection.ops.get_db_converters(field.get_internal_type())
|
||||
converters[i] = (backend_converters, [field.from_db_value], field)
|
||||
return converters
|
||||
|
||||
#### Routines unique to GeoQuery ####
|
||||
def get_extra_select_format(self, alias):
|
||||
sel_fmt = '%s'
|
||||
if hasattr(self.query, 'custom_select') and alias in self.query.custom_select:
|
||||
sel_fmt = sel_fmt % self.query.custom_select[alias]
|
||||
return sel_fmt
|
||||
|
||||
def get_field_select(self, field, alias=None, column=None):
|
||||
"""
|
||||
Returns the SELECT SQL string for the given field. Figures out
|
||||
if any custom selection SQL is needed for the column The `alias`
|
||||
keyword may be used to manually specify the database table where
|
||||
the column exists, if not in the model associated with this
|
||||
`GeoQuery`. Similarly, `column` may be used to specify the exact
|
||||
column name, rather than using the `column` attribute on `field`.
|
||||
"""
|
||||
sel_fmt = self.get_select_format(field)
|
||||
if field in self.query.custom_select:
|
||||
field_sel = sel_fmt % self.query.custom_select[field]
|
||||
else:
|
||||
field_sel = sel_fmt % self._field_column(field, alias, column)
|
||||
return field_sel
|
||||
|
||||
def get_select_format(self, fld):
|
||||
"""
|
||||
Returns the selection format string, depending on the requirements
|
||||
of the spatial backend. For example, Oracle and MySQL require custom
|
||||
selection formats in order to retrieve geometries in OGC WKT. For all
|
||||
other fields a simple '%s' format string is returned.
|
||||
"""
|
||||
if self.connection.ops.select and hasattr(fld, 'geom_type'):
|
||||
# This allows operations to be done on fields in the SELECT,
|
||||
# overriding their values -- used by the Oracle and MySQL
|
||||
# spatial backends to get database values as WKT, and by the
|
||||
# `transform` method.
|
||||
sel_fmt = self.connection.ops.select
|
||||
|
||||
# Because WKT doesn't contain spatial reference information,
|
||||
# the SRID is prefixed to the returned WKT to ensure that the
|
||||
# transformed geometries have an SRID different than that of the
|
||||
# field -- this is only used by `transform` for Oracle and
|
||||
# SpatiaLite backends.
|
||||
if self.query.transformed_srid and (self.connection.ops.oracle or
|
||||
self.connection.ops.spatialite):
|
||||
sel_fmt = "'SRID=%d;'||%s" % (self.query.transformed_srid, sel_fmt)
|
||||
else:
|
||||
sel_fmt = '%s'
|
||||
return sel_fmt
|
||||
|
||||
# Private API utilities, subject to change.
|
||||
def _field_column(self, field, table_alias=None, column=None):
|
||||
"""
|
||||
Helper function that returns the database column for the given field.
|
||||
The table and column are returned (quoted) in the proper format, e.g.,
|
||||
`"geoapp_city"."point"`. If `table_alias` is not specified, the
|
||||
database table associated with the model of this `GeoQuery` will be
|
||||
used. If `column` is specified, it will be used instead of the value
|
||||
in `field.column`.
|
||||
"""
|
||||
if table_alias is None:
|
||||
table_alias = self.query.get_meta().db_table
|
||||
return "%s.%s" % (self.quote_name_unless_alias(table_alias),
|
||||
self.connection.ops.quote_name(column or field.column))
|
||||
|
||||
|
||||
class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
|
||||
pass
|
|
@ -3,6 +3,7 @@ This module holds simple classes to convert geospatial values from the
|
|||
database.
|
||||
"""
|
||||
|
||||
from django.contrib.gis.db.models.fields import GeoSelectFormatMixin
|
||||
from django.contrib.gis.geometry.backend import Geometry
|
||||
from django.contrib.gis.measure import Area, Distance
|
||||
|
||||
|
@ -10,13 +11,19 @@ from django.contrib.gis.measure import Area, Distance
|
|||
class BaseField(object):
|
||||
empty_strings_allowed = True
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return [self.from_db_value]
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
return sql, params
|
||||
|
||||
|
||||
class AreaField(BaseField):
|
||||
"Wrapper for Area values."
|
||||
def __init__(self, area_att):
|
||||
self.area_att = area_att
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
def from_db_value(self, value, connection, context):
|
||||
if value is not None:
|
||||
value = Area(**{self.area_att: value})
|
||||
return value
|
||||
|
@ -30,7 +37,7 @@ class DistanceField(BaseField):
|
|||
def __init__(self, distance_att):
|
||||
self.distance_att = distance_att
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
def from_db_value(self, value, connection, context):
|
||||
if value is not None:
|
||||
value = Distance(**{self.distance_att: value})
|
||||
return value
|
||||
|
@ -39,12 +46,15 @@ class DistanceField(BaseField):
|
|||
return 'DistanceField'
|
||||
|
||||
|
||||
class GeomField(BaseField):
|
||||
class GeomField(GeoSelectFormatMixin, BaseField):
|
||||
"""
|
||||
Wrapper for Geometry values. It is a lightweight alternative to
|
||||
using GeometryField (which requires an SQL query upon instantiation).
|
||||
"""
|
||||
def from_db_value(self, value, connection):
|
||||
# Hacky marker for get_db_converters()
|
||||
geom_type = None
|
||||
|
||||
def from_db_value(self, value, connection, context):
|
||||
if value is not None:
|
||||
value = Geometry(value)
|
||||
return value
|
||||
|
@ -61,5 +71,5 @@ class GMLField(BaseField):
|
|||
def get_internal_type(self):
|
||||
return 'GMLField'
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
def from_db_value(self, value, connection, context):
|
||||
return value
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
from django.db import connections
|
||||
from django.db.models.query import sql
|
||||
from django.db.models.sql.constants import QUERY_TERMS
|
||||
|
||||
from django.contrib.gis.db.models.fields import GeometryField
|
||||
from django.contrib.gis.db.models.lookups import GISLookup
|
||||
from django.contrib.gis.db.models import aggregates as gis_aggregates
|
||||
from django.contrib.gis.db.models.sql.conversion import GeomField
|
||||
|
||||
|
||||
class GeoQuery(sql.Query):
|
||||
"""
|
||||
A single spatial SQL query.
|
||||
"""
|
||||
# Overriding the valid query terms.
|
||||
query_terms = QUERY_TERMS | set(GeometryField.class_lookups.keys())
|
||||
|
||||
compiler = 'GeoSQLCompiler'
|
||||
|
||||
#### Methods overridden from the base Query class ####
|
||||
def __init__(self, model):
|
||||
super(GeoQuery, self).__init__(model)
|
||||
# The following attributes are customized for the GeoQuerySet.
|
||||
# The SpatialBackend classes contain backend-specific routines and functions.
|
||||
self.custom_select = {}
|
||||
self.transformed_srid = None
|
||||
self.extra_select_fields = {}
|
||||
|
||||
def clone(self, *args, **kwargs):
|
||||
obj = super(GeoQuery, self).clone(*args, **kwargs)
|
||||
# Customized selection dictionary and transformed srid flag have
|
||||
# to also be added to obj.
|
||||
obj.custom_select = self.custom_select.copy()
|
||||
obj.transformed_srid = self.transformed_srid
|
||||
obj.extra_select_fields = self.extra_select_fields.copy()
|
||||
return obj
|
||||
|
||||
def get_aggregation(self, using, force_subq=False):
|
||||
# Remove any aggregates marked for reduction from the subquery
|
||||
# and move them to the outer AggregateQuery.
|
||||
connection = connections[using]
|
||||
for alias, annotation in self.annotation_select.items():
|
||||
if isinstance(annotation, gis_aggregates.GeoAggregate):
|
||||
if not getattr(annotation, 'is_extent', False) or connection.ops.oracle:
|
||||
self.extra_select_fields[alias] = GeomField()
|
||||
return super(GeoQuery, self).get_aggregation(using, force_subq)
|
||||
|
||||
# Private API utilities, subject to change.
|
||||
def _geo_field(self, field_name=None):
|
||||
"""
|
||||
Returns the first Geometry field encountered; or specified via the
|
||||
`field_name` keyword. The `field_name` may be a string specifying
|
||||
the geometry field on this GeoQuery's model, or a lookup string
|
||||
to a geometry field via a ForeignKey relation.
|
||||
"""
|
||||
if field_name is None:
|
||||
# Incrementing until the first geographic field is found.
|
||||
for fld in self.model._meta.fields:
|
||||
if isinstance(fld, GeometryField):
|
||||
return fld
|
||||
return False
|
||||
else:
|
||||
# Otherwise, check by the given field name -- which may be
|
||||
# a lookup to a _related_ geographic field.
|
||||
return GISLookup._check_geo_field(self.model._meta, field_name)
|
|
@ -231,15 +231,6 @@ class RelatedGeoModelTest(TestCase):
|
|||
self.assertIn('Aurora', names)
|
||||
self.assertIn('Kecksburg', names)
|
||||
|
||||
def test11_geoquery_pickle(self):
|
||||
"Ensuring GeoQuery objects are unpickled correctly. See #10839."
|
||||
import pickle
|
||||
from django.contrib.gis.db.models.sql import GeoQuery
|
||||
qs = City.objects.all()
|
||||
q_str = pickle.dumps(qs.query)
|
||||
q = pickle.loads(q_str)
|
||||
self.assertEqual(GeoQuery, q.__class__)
|
||||
|
||||
# TODO: fix on Oracle -- get the following error because the SQL is ordered
|
||||
# by a geometry object, which Oracle apparently doesn't like:
|
||||
# ORA-22901: cannot compare nested table or VARRAY or LOB attributes of an object type
|
||||
|
|
|
@ -1262,7 +1262,7 @@ class BaseDatabaseOperations(object):
|
|||
second = timezone.make_aware(second, tz)
|
||||
return [first, second]
|
||||
|
||||
def get_db_converters(self, internal_type):
|
||||
def get_db_converters(self, expression):
|
||||
"""Get a list of functions needed to convert field data.
|
||||
|
||||
Some field types on some backends do not provide data in the correct
|
||||
|
@ -1270,7 +1270,7 @@ class BaseDatabaseOperations(object):
|
|||
"""
|
||||
return []
|
||||
|
||||
def convert_durationfield_value(self, value, field):
|
||||
def convert_durationfield_value(self, value, expression, context):
|
||||
if value is not None:
|
||||
value = str(decimal.Decimal(value) / decimal.Decimal(1000000))
|
||||
value = parse_duration(value)
|
||||
|
|
|
@ -302,7 +302,7 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
columns. If no ordering would otherwise be applied, we don't want any
|
||||
implicit sorting going on.
|
||||
"""
|
||||
return ["NULL"]
|
||||
return [(None, ("NULL", [], 'asc', False))]
|
||||
|
||||
def fulltext_search_sql(self, field_name):
|
||||
return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name
|
||||
|
@ -387,8 +387,9 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
return 'POW(%s)' % ','.join(sub_expressions)
|
||||
return super(DatabaseOperations, self).combine_expression(connector, sub_expressions)
|
||||
|
||||
def get_db_converters(self, internal_type):
|
||||
converters = super(DatabaseOperations, self).get_db_converters(internal_type)
|
||||
def get_db_converters(self, expression):
|
||||
converters = super(DatabaseOperations, self).get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type in ['BooleanField', 'NullBooleanField']:
|
||||
converters.append(self.convert_booleanfield_value)
|
||||
if internal_type == 'UUIDField':
|
||||
|
@ -397,17 +398,17 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
converters.append(self.convert_textfield_value)
|
||||
return converters
|
||||
|
||||
def convert_booleanfield_value(self, value, field):
|
||||
def convert_booleanfield_value(self, value, expression, context):
|
||||
if value in (0, 1):
|
||||
value = bool(value)
|
||||
return value
|
||||
|
||||
def convert_uuidfield_value(self, value, field):
|
||||
def convert_uuidfield_value(self, value, expression, context):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
||||
def convert_textfield_value(self, value, field):
|
||||
def convert_textfield_value(self, value, expression, context):
|
||||
if value is not None:
|
||||
value = force_text(value)
|
||||
return value
|
||||
|
|
|
@ -268,8 +268,9 @@ WHEN (new.%(col_name)s IS NULL)
|
|||
sql = field_name # Cast to DATE removes sub-second precision.
|
||||
return sql, []
|
||||
|
||||
def get_db_converters(self, internal_type):
|
||||
converters = super(DatabaseOperations, self).get_db_converters(internal_type)
|
||||
def get_db_converters(self, expression):
|
||||
converters = super(DatabaseOperations, self).get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type == 'TextField':
|
||||
converters.append(self.convert_textfield_value)
|
||||
elif internal_type == 'BinaryField':
|
||||
|
@ -285,28 +286,29 @@ WHEN (new.%(col_name)s IS NULL)
|
|||
converters.append(self.convert_empty_values)
|
||||
return converters
|
||||
|
||||
def convert_empty_values(self, value, field):
|
||||
def convert_empty_values(self, value, expression, context):
|
||||
# Oracle stores empty strings as null. We need to undo this in
|
||||
# order to adhere to the Django convention of using the empty
|
||||
# string instead of null, but only if the field accepts the
|
||||
# empty string.
|
||||
field = expression.output_field
|
||||
if value is None and field.empty_strings_allowed:
|
||||
value = ''
|
||||
if field.get_internal_type() == 'BinaryField':
|
||||
value = b''
|
||||
return value
|
||||
|
||||
def convert_textfield_value(self, value, field):
|
||||
def convert_textfield_value(self, value, expression, context):
|
||||
if isinstance(value, Database.LOB):
|
||||
value = force_text(value.read())
|
||||
return value
|
||||
|
||||
def convert_binaryfield_value(self, value, field):
|
||||
def convert_binaryfield_value(self, value, expression, context):
|
||||
if isinstance(value, Database.LOB):
|
||||
value = force_bytes(value.read())
|
||||
return value
|
||||
|
||||
def convert_booleanfield_value(self, value, field):
|
||||
def convert_booleanfield_value(self, value, expression, context):
|
||||
if value in (1, 0):
|
||||
value = bool(value)
|
||||
return value
|
||||
|
@ -314,16 +316,16 @@ WHEN (new.%(col_name)s IS NULL)
|
|||
# cx_Oracle always returns datetime.datetime objects for
|
||||
# DATE and TIMESTAMP columns, but Django wants to see a
|
||||
# python datetime.date, .time, or .datetime.
|
||||
def convert_datefield_value(self, value, field):
|
||||
def convert_datefield_value(self, value, expression, context):
|
||||
if isinstance(value, Database.Timestamp):
|
||||
return value.date()
|
||||
|
||||
def convert_timefield_value(self, value, field):
|
||||
def convert_timefield_value(self, value, expression, context):
|
||||
if isinstance(value, Database.Timestamp):
|
||||
value = value.time()
|
||||
return value
|
||||
|
||||
def convert_uuidfield_value(self, value, field):
|
||||
def convert_uuidfield_value(self, value, expression, context):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
|
|
@ -269,8 +269,9 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
|
||||
return six.text_type(value)
|
||||
|
||||
def get_db_converters(self, internal_type):
|
||||
converters = super(DatabaseOperations, self).get_db_converters(internal_type)
|
||||
def get_db_converters(self, expression):
|
||||
converters = super(DatabaseOperations, self).get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type == 'DateTimeField':
|
||||
converters.append(self.convert_datetimefield_value)
|
||||
elif internal_type == 'DateField':
|
||||
|
@ -283,25 +284,25 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
converters.append(self.convert_uuidfield_value)
|
||||
return converters
|
||||
|
||||
def convert_decimalfield_value(self, value, field):
|
||||
return backend_utils.typecast_decimal(field.format_number(value))
|
||||
def convert_decimalfield_value(self, value, expression, context):
|
||||
return backend_utils.typecast_decimal(expression.output_field.format_number(value))
|
||||
|
||||
def convert_datefield_value(self, value, field):
|
||||
def convert_datefield_value(self, value, expression, context):
|
||||
if value is not None and not isinstance(value, datetime.date):
|
||||
value = parse_date(value)
|
||||
return value
|
||||
|
||||
def convert_datetimefield_value(self, value, field):
|
||||
def convert_datetimefield_value(self, value, expression, context):
|
||||
if value is not None and not isinstance(value, datetime.datetime):
|
||||
value = parse_datetime_with_timezone_support(value)
|
||||
return value
|
||||
|
||||
def convert_timefield_value(self, value, field):
|
||||
def convert_timefield_value(self, value, expression, context):
|
||||
if value is not None and not isinstance(value, datetime.time):
|
||||
value = parse_time(value)
|
||||
return value
|
||||
|
||||
def convert_uuidfield_value(self, value, field):
|
||||
def convert_uuidfield_value(self, value, expression, context):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
|
|
@ -87,7 +87,7 @@ class Avg(Aggregate):
|
|||
def __init__(self, expression, **extra):
|
||||
super(Avg, self).__init__(expression, output_field=FloatField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
def convert_value(self, value, connection, context):
|
||||
if value is None:
|
||||
return value
|
||||
return float(value)
|
||||
|
@ -105,7 +105,7 @@ class Count(Aggregate):
|
|||
super(Count, self).__init__(
|
||||
expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
def convert_value(self, value, connection, context):
|
||||
if value is None:
|
||||
return 0
|
||||
return int(value)
|
||||
|
@ -128,7 +128,7 @@ class StdDev(Aggregate):
|
|||
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
|
||||
super(StdDev, self).__init__(expression, output_field=FloatField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
def convert_value(self, value, connection, context):
|
||||
if value is None:
|
||||
return value
|
||||
return float(value)
|
||||
|
@ -146,7 +146,7 @@ class Variance(Aggregate):
|
|||
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
|
||||
super(Variance, self).__init__(expression, output_field=FloatField(), **extra)
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
def convert_value(self, value, connection, context):
|
||||
if value is None:
|
||||
return value
|
||||
return float(value)
|
||||
|
|
|
@ -127,7 +127,7 @@ class ExpressionNode(CombinableMixin):
|
|||
is_summary = False
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return [self.convert_value]
|
||||
return [self.convert_value] + self.output_field.get_db_converters(connection)
|
||||
|
||||
def __init__(self, output_field=None):
|
||||
self._output_field = output_field
|
||||
|
@ -240,7 +240,7 @@ class ExpressionNode(CombinableMixin):
|
|||
raise FieldError(
|
||||
"Expression contains mixed types. You must set output_field")
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
def convert_value(self, value, connection, context):
|
||||
"""
|
||||
Expressions provide their own converters because users have the option
|
||||
of manually specifying the output_field which may be a different type
|
||||
|
@ -305,6 +305,8 @@ class ExpressionNode(CombinableMixin):
|
|||
return self
|
||||
|
||||
def get_group_by_cols(self):
|
||||
if not self.contains_aggregate:
|
||||
return [self]
|
||||
cols = []
|
||||
for source in self.get_source_expressions():
|
||||
cols.extend(source.get_group_by_cols())
|
||||
|
@ -490,6 +492,9 @@ class Value(ExpressionNode):
|
|||
return 'NULL', []
|
||||
return '%s', [self.value]
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return []
|
||||
|
||||
|
||||
class DurationValue(Value):
|
||||
def as_sql(self, compiler, connection):
|
||||
|
@ -499,6 +504,37 @@ class DurationValue(Value):
|
|||
return connection.ops.date_interval_sql(self.value)
|
||||
|
||||
|
||||
class RawSQL(ExpressionNode):
|
||||
def __init__(self, sql, params, output_field=None):
|
||||
if output_field is None:
|
||||
output_field = fields.Field()
|
||||
self.sql, self.params = sql, params
|
||||
super(RawSQL, self).__init__(output_field=output_field)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return '(%s)' % self.sql, self.params
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return [self]
|
||||
|
||||
|
||||
class Random(ExpressionNode):
|
||||
def __init__(self):
|
||||
super(Random, self).__init__(output_field=fields.FloatField())
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return connection.ops.random_function_sql(), []
|
||||
|
||||
|
||||
class ColIndexRef(ExpressionNode):
|
||||
def __init__(self, idx):
|
||||
self.idx = idx
|
||||
super(ColIndexRef, self).__init__()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return str(self.idx), []
|
||||
|
||||
|
||||
class Col(ExpressionNode):
|
||||
def __init__(self, alias, target, source=None):
|
||||
if source is None:
|
||||
|
@ -516,6 +552,9 @@ class Col(ExpressionNode):
|
|||
def get_group_by_cols(self):
|
||||
return [self]
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return self.output_field.get_db_converters(connection)
|
||||
|
||||
|
||||
class Ref(ExpressionNode):
|
||||
"""
|
||||
|
@ -537,7 +576,7 @@ class Ref(ExpressionNode):
|
|||
return self
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return "%s" % compiler.quote_name_unless_alias(self.refs), []
|
||||
return "%s" % connection.ops.quote_name(self.refs), []
|
||||
|
||||
def get_group_by_cols(self):
|
||||
return [self]
|
||||
|
@ -581,7 +620,7 @@ class Date(ExpressionNode):
|
|||
copy.lookup_type = self.lookup_type
|
||||
return copy
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
def convert_value(self, value, connection, context):
|
||||
if isinstance(value, datetime.datetime):
|
||||
value = value.date()
|
||||
return value
|
||||
|
@ -629,7 +668,7 @@ class DateTime(ExpressionNode):
|
|||
copy.tzname = self.tzname
|
||||
return copy
|
||||
|
||||
def convert_value(self, value, connection):
|
||||
def convert_value(self, value, connection, context):
|
||||
if settings.USE_TZ:
|
||||
if value is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -333,6 +333,28 @@ class Field(RegisterLookupMixin):
|
|||
]
|
||||
return []
|
||||
|
||||
def get_col(self, alias, source=None):
|
||||
if source is None:
|
||||
source = self
|
||||
if alias != self.model._meta.db_table or source != self:
|
||||
from django.db.models.expressions import Col
|
||||
return Col(alias, self, source)
|
||||
else:
|
||||
return self.cached_col
|
||||
|
||||
@cached_property
|
||||
def cached_col(self):
|
||||
from django.db.models.expressions import Col
|
||||
return Col(self.model._meta.db_table, self)
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
"""
|
||||
Custom format for select clauses. For example, GIS columns need to be
|
||||
selected as AsText(table.col) on MySQL as the table.col data can't be used
|
||||
by Django.
|
||||
"""
|
||||
return sql, params
|
||||
|
||||
def deconstruct(self):
|
||||
"""
|
||||
Returns enough information to recreate the field as a 4-tuple:
|
||||
|
|
|
@ -15,7 +15,6 @@ from django.db.models.fields import (AutoField, Field, IntegerField,
|
|||
from django.db.models.lookups import IsNull
|
||||
from django.db.models.query import QuerySet
|
||||
from django.db.models.query_utils import PathInfo
|
||||
from django.db.models.expressions import Col
|
||||
from django.utils.encoding import force_text, smart_text
|
||||
from django.utils import six
|
||||
from django.utils.deprecation import RemovedInDjango20Warning
|
||||
|
@ -1738,26 +1737,26 @@ class ForeignObject(RelatedField):
|
|||
[source.name for source in sources], raw_value),
|
||||
AND)
|
||||
elif lookup_type == 'isnull':
|
||||
root_constraint.add(IsNull(Col(alias, targets[0], sources[0]), raw_value), AND)
|
||||
root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND)
|
||||
elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
|
||||
and not is_multicolumn)):
|
||||
value = get_normalized_value(raw_value)
|
||||
for target, source, val in zip(targets, sources, value):
|
||||
lookup_class = target.get_lookup(lookup_type)
|
||||
root_constraint.add(
|
||||
lookup_class(Col(alias, target, source), val), AND)
|
||||
lookup_class(target.get_col(alias, source), val), AND)
|
||||
elif lookup_type in ['range', 'in'] and not is_multicolumn:
|
||||
values = [get_normalized_value(value) for value in raw_value]
|
||||
value = [val[0] for val in values]
|
||||
lookup_class = targets[0].get_lookup(lookup_type)
|
||||
root_constraint.add(lookup_class(Col(alias, targets[0], sources[0]), value), AND)
|
||||
root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND)
|
||||
elif lookup_type == 'in':
|
||||
values = [get_normalized_value(value) for value in raw_value]
|
||||
for value in values:
|
||||
value_constraint = constraint_class()
|
||||
for source, target, val in zip(sources, targets, value):
|
||||
lookup_class = target.get_lookup('exact')
|
||||
lookup = lookup_class(Col(alias, target, source), val)
|
||||
lookup = lookup_class(target.get_col(alias, source), val)
|
||||
value_constraint.add(lookup, AND)
|
||||
root_constraint.add(value_constraint, OR)
|
||||
else:
|
||||
|
|
|
@ -13,8 +13,7 @@ from django.db import (connections, router, transaction, IntegrityError,
|
|||
DJANGO_VERSION_PICKLE_KEY)
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields import AutoField, Empty
|
||||
from django.db.models.query_utils import (Q, select_related_descend,
|
||||
deferred_class_factory, InvalidQuery)
|
||||
from django.db.models.query_utils import Q, deferred_class_factory, InvalidQuery
|
||||
from django.db.models.deletion import Collector
|
||||
from django.db.models.sql.constants import CURSOR
|
||||
from django.db.models import sql
|
||||
|
@ -233,76 +232,34 @@ class QuerySet(object):
|
|||
An iterator over the results from applying this QuerySet to the
|
||||
database.
|
||||
"""
|
||||
fill_cache = False
|
||||
if connections[self.db].features.supports_select_related:
|
||||
fill_cache = self.query.select_related
|
||||
if isinstance(fill_cache, dict):
|
||||
requested = fill_cache
|
||||
else:
|
||||
requested = None
|
||||
max_depth = self.query.max_depth
|
||||
|
||||
extra_select = list(self.query.extra_select)
|
||||
annotation_select = list(self.query.annotation_select)
|
||||
|
||||
only_load = self.query.get_loaded_field_names()
|
||||
fields = self.model._meta.concrete_fields
|
||||
|
||||
load_fields = []
|
||||
# If only/defer clauses have been specified,
|
||||
# build the list of fields that are to be loaded.
|
||||
if only_load:
|
||||
for field in self.model._meta.concrete_fields:
|
||||
model = field.model._meta.model
|
||||
try:
|
||||
if field.name in only_load[model]:
|
||||
# Add a field that has been explicitly included
|
||||
load_fields.append(field.name)
|
||||
except KeyError:
|
||||
# Model wasn't explicitly listed in the only_load table
|
||||
# Therefore, we need to load all fields from this model
|
||||
load_fields.append(field.name)
|
||||
|
||||
skip = None
|
||||
if load_fields:
|
||||
# Some fields have been deferred, so we have to initialize
|
||||
# via keyword arguments.
|
||||
skip = set()
|
||||
init_list = []
|
||||
for field in fields:
|
||||
if field.name not in load_fields:
|
||||
skip.add(field.attname)
|
||||
else:
|
||||
init_list.append(field.attname)
|
||||
model_cls = deferred_class_factory(self.model, skip)
|
||||
else:
|
||||
model_cls = self.model
|
||||
init_list = [f.attname for f in fields]
|
||||
|
||||
# Cache db and model outside the loop
|
||||
db = self.db
|
||||
compiler = self.query.get_compiler(using=db)
|
||||
index_start = len(extra_select)
|
||||
annotation_start = index_start + len(init_list)
|
||||
|
||||
if fill_cache:
|
||||
klass_info = get_klass_info(model_cls, max_depth=max_depth,
|
||||
requested=requested, only_load=only_load)
|
||||
for row in compiler.results_iter():
|
||||
if fill_cache:
|
||||
obj, _ = get_cached_row(row, index_start, db, klass_info,
|
||||
offset=len(annotation_select))
|
||||
else:
|
||||
obj = model_cls.from_db(db, init_list, row[index_start:annotation_start])
|
||||
|
||||
if extra_select:
|
||||
for i, k in enumerate(extra_select):
|
||||
setattr(obj, k, row[i])
|
||||
|
||||
# Add the annotations to the model
|
||||
if annotation_select:
|
||||
for i, annotation in enumerate(annotation_select):
|
||||
setattr(obj, annotation, row[i + annotation_start])
|
||||
# Execute the query. This will also fill compiler.select, klass_info,
|
||||
# and annotations.
|
||||
results = compiler.execute_sql()
|
||||
select, klass_info, annotation_col_map = (compiler.select, compiler.klass_info,
|
||||
compiler.annotation_col_map)
|
||||
if klass_info is None:
|
||||
return
|
||||
model_cls = klass_info['model']
|
||||
select_fields = klass_info['select_fields']
|
||||
model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
|
||||
init_list = [f[0].output_field.attname
|
||||
for f in select[model_fields_start:model_fields_end]]
|
||||
if len(init_list) != len(model_cls._meta.concrete_fields):
|
||||
init_set = set(init_list)
|
||||
skip = [f.attname for f in model_cls._meta.concrete_fields
|
||||
if f.attname not in init_set]
|
||||
model_cls = deferred_class_factory(model_cls, skip)
|
||||
related_populators = get_related_populators(klass_info, select, db)
|
||||
for row in compiler.results_iter(results):
|
||||
obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end])
|
||||
if related_populators:
|
||||
for rel_populator in related_populators:
|
||||
rel_populator.populate(row, obj)
|
||||
if annotation_col_map:
|
||||
for attr_name, col_pos in annotation_col_map.items():
|
||||
setattr(obj, attr_name, row[col_pos])
|
||||
|
||||
# Add the known related objects to the model, if there are any
|
||||
if self._known_related_objects:
|
||||
|
@ -1032,11 +989,8 @@ class QuerySet(object):
|
|||
"""
|
||||
Prepare the query for computing a result that contains aggregate annotations.
|
||||
"""
|
||||
opts = self.model._meta
|
||||
if self.query.group_by is None:
|
||||
field_names = [f.attname for f in opts.concrete_fields]
|
||||
self.query.add_fields(field_names, False)
|
||||
self.query.set_group_by()
|
||||
self.query.group_by = True
|
||||
|
||||
def _prepare(self):
|
||||
return self
|
||||
|
@ -1135,9 +1089,11 @@ class ValuesQuerySet(QuerySet):
|
|||
Called by the _clone() method after initializing the rest of the
|
||||
instance.
|
||||
"""
|
||||
if self.query.group_by is True:
|
||||
self.query.add_fields([f.attname for f in self.model._meta.concrete_fields], False)
|
||||
self.query.set_group_by()
|
||||
self.query.clear_deferred_loading()
|
||||
self.query.clear_select_fields()
|
||||
|
||||
if self._fields:
|
||||
self.extra_names = []
|
||||
self.annotation_names = []
|
||||
|
@ -1246,11 +1202,12 @@ class ValuesQuerySet(QuerySet):
|
|||
|
||||
class ValuesListQuerySet(ValuesQuerySet):
|
||||
def iterator(self):
|
||||
compiler = self.query.get_compiler(self.db)
|
||||
if self.flat and len(self._fields) == 1:
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
for row in compiler.results_iter():
|
||||
yield row[0]
|
||||
elif not self.query.extra_select and not self.query.annotation_select:
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
for row in compiler.results_iter():
|
||||
yield tuple(row)
|
||||
else:
|
||||
# When extra(select=...) or an annotation is involved, the extra
|
||||
|
@ -1269,7 +1226,7 @@ class ValuesListQuerySet(ValuesQuerySet):
|
|||
else:
|
||||
fields = names
|
||||
|
||||
for row in self.query.get_compiler(self.db).results_iter():
|
||||
for row in compiler.results_iter():
|
||||
data = dict(zip(names, row))
|
||||
yield tuple(data[f] for f in fields)
|
||||
|
||||
|
@ -1281,244 +1238,6 @@ class ValuesListQuerySet(ValuesQuerySet):
|
|||
return clone
|
||||
|
||||
|
||||
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
||||
only_load=None, from_parent=None):
|
||||
"""
|
||||
Helper function that recursively returns an information for a klass, to be
|
||||
used in get_cached_row. It exists just to compute this information only
|
||||
once for entire queryset. Otherwise it would be computed for each row, which
|
||||
leads to poor performance on large querysets.
|
||||
|
||||
Arguments:
|
||||
* klass - the class to retrieve (and instantiate)
|
||||
* max_depth - the maximum depth to which a select_related()
|
||||
relationship should be explored.
|
||||
* cur_depth - the current depth in the select_related() tree.
|
||||
Used in recursive calls to determine if we should dig deeper.
|
||||
* requested - A dictionary describing the select_related() tree
|
||||
that is to be retrieved. keys are field names; values are
|
||||
dictionaries describing the keys on that related object that
|
||||
are themselves to be select_related().
|
||||
* only_load - if the query has had only() or defer() applied,
|
||||
this is the list of field names that will be returned. If None,
|
||||
the full field list for `klass` can be assumed.
|
||||
* from_parent - the parent model used to get to this model
|
||||
|
||||
Note that when travelling from parent to child, we will only load child
|
||||
fields which aren't in the parent.
|
||||
"""
|
||||
if max_depth and requested is None and cur_depth > max_depth:
|
||||
# We've recursed deeply enough; stop now.
|
||||
return None
|
||||
|
||||
if only_load:
|
||||
load_fields = only_load.get(klass) or set()
|
||||
# When we create the object, we will also be creating populating
|
||||
# all the parent classes, so traverse the parent classes looking
|
||||
# for fields that must be included on load.
|
||||
for parent in klass._meta.get_parent_list():
|
||||
fields = only_load.get(parent)
|
||||
if fields:
|
||||
load_fields.update(fields)
|
||||
else:
|
||||
load_fields = None
|
||||
|
||||
if load_fields:
|
||||
# Handle deferred fields.
|
||||
skip = set()
|
||||
init_list = []
|
||||
# Build the list of fields that *haven't* been requested
|
||||
for field in klass._meta.concrete_fields:
|
||||
model = field.model._meta.concrete_model
|
||||
if from_parent and model and issubclass(from_parent, model):
|
||||
# Avoid loading fields already loaded for parent model for
|
||||
# child models.
|
||||
continue
|
||||
elif field.name not in load_fields:
|
||||
skip.add(field.attname)
|
||||
else:
|
||||
init_list.append(field.attname)
|
||||
# Retrieve all the requested fields
|
||||
field_count = len(init_list)
|
||||
if skip:
|
||||
klass = deferred_class_factory(klass, skip)
|
||||
field_names = init_list
|
||||
else:
|
||||
field_names = ()
|
||||
else:
|
||||
# Load all fields on klass
|
||||
|
||||
field_count = len(klass._meta.concrete_fields)
|
||||
# Check if we need to skip some parent fields.
|
||||
if from_parent and len(klass._meta.local_concrete_fields) != len(klass._meta.concrete_fields):
|
||||
# Only load those fields which haven't been already loaded into
|
||||
# 'from_parent'.
|
||||
non_seen_models = [p for p in klass._meta.get_parent_list()
|
||||
if not issubclass(from_parent, p)]
|
||||
# Load local fields, too...
|
||||
non_seen_models.append(klass)
|
||||
field_names = [f.attname for f in klass._meta.concrete_fields
|
||||
if f.model in non_seen_models]
|
||||
field_count = len(field_names)
|
||||
# Try to avoid populating field_names variable for performance reasons.
|
||||
# If field_names variable is set, we use **kwargs based model init
|
||||
# which is slower than normal init.
|
||||
if field_count == len(klass._meta.concrete_fields):
|
||||
field_names = ()
|
||||
|
||||
restricted = requested is not None
|
||||
|
||||
related_fields = []
|
||||
for f in klass._meta.fields:
|
||||
if select_related_descend(f, restricted, requested, load_fields):
|
||||
if restricted:
|
||||
next = requested[f.name]
|
||||
else:
|
||||
next = None
|
||||
klass_info = get_klass_info(f.rel.to, max_depth=max_depth, cur_depth=cur_depth + 1,
|
||||
requested=next, only_load=only_load)
|
||||
related_fields.append((f, klass_info))
|
||||
|
||||
reverse_related_fields = []
|
||||
if restricted:
|
||||
for o in klass._meta.related_objects:
|
||||
if o.field.unique and select_related_descend(o.field, restricted, requested,
|
||||
only_load.get(o.related_model), reverse=True):
|
||||
next = requested[o.field.related_query_name()]
|
||||
parent = klass if issubclass(o.related_model, klass) else None
|
||||
klass_info = get_klass_info(o.related_model, max_depth=max_depth, cur_depth=cur_depth + 1,
|
||||
requested=next, only_load=only_load, from_parent=parent)
|
||||
reverse_related_fields.append((o.field, klass_info))
|
||||
if field_names:
|
||||
pk_idx = field_names.index(klass._meta.pk.attname)
|
||||
else:
|
||||
meta = klass._meta
|
||||
pk_idx = meta.concrete_fields.index(meta.pk)
|
||||
|
||||
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
|
||||
|
||||
|
||||
def reorder_for_init(model, field_names, values):
|
||||
"""
|
||||
Reorders given field names and values for those fields
|
||||
to be in the same order as model.__init__() expects to find them.
|
||||
"""
|
||||
new_names, new_values = [], []
|
||||
for f in model._meta.concrete_fields:
|
||||
if f.attname not in field_names:
|
||||
continue
|
||||
new_names.append(f.attname)
|
||||
new_values.append(values[field_names.index(f.attname)])
|
||||
assert len(new_names) == len(field_names)
|
||||
return new_names, new_values
|
||||
|
||||
|
||||
def get_cached_row(row, index_start, using, klass_info, offset=0,
|
||||
parent_data=()):
|
||||
"""
|
||||
Helper function that recursively returns an object with the specified
|
||||
related attributes already populated.
|
||||
|
||||
This method may be called recursively to populate deep select_related()
|
||||
clauses.
|
||||
|
||||
Arguments:
|
||||
* row - the row of data returned by the database cursor
|
||||
* index_start - the index of the row at which data for this
|
||||
object is known to start
|
||||
* offset - the number of additional fields that are known to
|
||||
exist in row for `klass`. This usually means the number of
|
||||
annotated results on `klass`.
|
||||
* using - the database alias on which the query is being executed.
|
||||
* klass_info - result of the get_klass_info function
|
||||
* parent_data - parent model data in format (field, value). Used
|
||||
to populate the non-local fields of child models.
|
||||
"""
|
||||
if klass_info is None:
|
||||
return None
|
||||
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
|
||||
|
||||
fields = row[index_start:index_start + field_count]
|
||||
# If the pk column is None (or the equivalent '' in the case the
|
||||
# connection interprets empty strings as nulls), then the related
|
||||
# object must be non-existent - set the relation to None.
|
||||
if (fields[pk_idx] is None or
|
||||
(connections[using].features.interprets_empty_strings_as_nulls and
|
||||
fields[pk_idx] == '')):
|
||||
obj = None
|
||||
elif field_names:
|
||||
values = list(fields)
|
||||
parent_values = []
|
||||
parent_field_names = []
|
||||
for rel_field, value in parent_data:
|
||||
parent_field_names.append(rel_field.attname)
|
||||
parent_values.append(value)
|
||||
field_names, values = reorder_for_init(
|
||||
klass, parent_field_names + field_names,
|
||||
parent_values + values)
|
||||
obj = klass.from_db(using, field_names, values)
|
||||
else:
|
||||
field_names = [f.attname for f in klass._meta.concrete_fields]
|
||||
obj = klass.from_db(using, field_names, fields)
|
||||
# Instantiate related fields
|
||||
index_end = index_start + field_count + offset
|
||||
# Iterate over each related object, populating any
|
||||
# select_related() fields
|
||||
for f, klass_info in related_fields:
|
||||
# Recursively retrieve the data for the related object
|
||||
cached_row = get_cached_row(row, index_end, using, klass_info)
|
||||
# If the recursive descent found an object, populate the
|
||||
# descriptor caches relevant to the object
|
||||
if cached_row:
|
||||
rel_obj, index_end = cached_row
|
||||
if obj is not None:
|
||||
# If the base object exists, populate the
|
||||
# descriptor cache
|
||||
setattr(obj, f.get_cache_name(), rel_obj)
|
||||
if f.unique and rel_obj is not None:
|
||||
# If the field is unique, populate the
|
||||
# reverse descriptor cache on the related object
|
||||
setattr(rel_obj, f.rel.get_cache_name(), obj)
|
||||
|
||||
# Now do the same, but for reverse related objects.
|
||||
# Only handle the restricted case - i.e., don't do a depth
|
||||
# descent into reverse relations unless explicitly requested
|
||||
for f, klass_info in reverse_related_fields:
|
||||
# Transfer data from this object to childs.
|
||||
parent_data = []
|
||||
for rel_field in klass_info[0]._meta.fields:
|
||||
rel_model = rel_field.model._meta.concrete_model
|
||||
if rel_model == klass_info[0]._meta.model:
|
||||
rel_model = None
|
||||
if rel_model is not None and isinstance(obj, rel_model):
|
||||
parent_data.append((rel_field, getattr(obj, rel_field.attname)))
|
||||
# Recursively retrieve the data for the related object
|
||||
cached_row = get_cached_row(row, index_end, using, klass_info,
|
||||
parent_data=parent_data)
|
||||
# If the recursive descent found an object, populate the
|
||||
# descriptor caches relevant to the object
|
||||
if cached_row:
|
||||
rel_obj, index_end = cached_row
|
||||
if obj is not None:
|
||||
# populate the reverse descriptor cache
|
||||
setattr(obj, f.rel.get_cache_name(), rel_obj)
|
||||
if rel_obj is not None:
|
||||
# If the related object exists, populate
|
||||
# the descriptor cache.
|
||||
setattr(rel_obj, f.get_cache_name(), obj)
|
||||
# Populate related object caches using parent data.
|
||||
for rel_field, _ in parent_data:
|
||||
if rel_field.rel:
|
||||
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
|
||||
try:
|
||||
cached_obj = getattr(obj, rel_field.get_cache_name())
|
||||
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
|
||||
except AttributeError:
|
||||
# Related object hasn't been cached yet
|
||||
pass
|
||||
return obj, index_end
|
||||
|
||||
|
||||
class RawQuerySet(object):
|
||||
"""
|
||||
Provides an iterator which converts the results of raw SQL queries into
|
||||
|
@ -1569,7 +1288,9 @@ class RawQuerySet(object):
|
|||
else:
|
||||
model_cls = self.model
|
||||
fields = [self.model_fields.get(c, None) for c in self.columns]
|
||||
converters = compiler.get_converters(fields)
|
||||
converters = compiler.get_converters([
|
||||
f.get_col(f.model._meta.db_table) if f else None for f in fields
|
||||
])
|
||||
for values in query:
|
||||
if converters:
|
||||
values = compiler.apply_converters(values, converters)
|
||||
|
@ -1920,3 +1641,120 @@ def prefetch_one_level(instances, prefetcher, lookup, level):
|
|||
qs._prefetch_done = True
|
||||
obj._prefetched_objects_cache[cache_name] = qs
|
||||
return all_related_objects, additional_lookups
|
||||
|
||||
|
||||
class RelatedPopulator(object):
|
||||
"""
|
||||
RelatedPopulator is used for select_related() object instantiation.
|
||||
|
||||
The idea is that each select_related() model will be populated by a
|
||||
different RelatedPopulator instance. The RelatedPopulator instances get
|
||||
klass_info and select (computed in SQLCompiler) plus the used db as
|
||||
input for initialization. That data is used to compute which columns
|
||||
to use, how to instantiate the model, and how to populate the links
|
||||
between the objects.
|
||||
|
||||
The actual creation of the objects is done in populate() method. This
|
||||
method gets row and from_obj as input and populates the select_related()
|
||||
model instance.
|
||||
"""
|
||||
def __init__(self, klass_info, select, db):
|
||||
self.db = db
|
||||
# Pre-compute needed attributes. The attributes are:
|
||||
# - model_cls: the possibly deferred model class to instantiate
|
||||
# - either:
|
||||
# - cols_start, cols_end: usually the columns in the row are
|
||||
# in the same order model_cls.__init__ expects them, so we
|
||||
# can instantiate by model_cls(*row[cols_start:cols_end])
|
||||
# - reorder_for_init: When select_related descends to a child
|
||||
# class, then we want to reuse the already selected parent
|
||||
# data. However, in this case the parent data isn't necessarily
|
||||
# in the same order that Model.__init__ expects it to be, so
|
||||
# we have to reorder the parent data. The reorder_for_init
|
||||
# attribute contains a function used to reorder the field data
|
||||
# in the order __init__ expects it.
|
||||
# - pk_idx: the index of the primary key field in the reordered
|
||||
# model data. Used to check if a related object exists at all.
|
||||
# - init_list: the field attnames fetched from the database. For
|
||||
# deferred models this isn't the same as all attnames of the
|
||||
# model's fields.
|
||||
# - related_populators: a list of RelatedPopulator instances if
|
||||
# select_related() descends to related models from this model.
|
||||
# - cache_name, reverse_cache_name: the names to use for setattr
|
||||
# when assigning the fetched object to the from_obj. If the
|
||||
# reverse_cache_name is set, then we also set the reverse link.
|
||||
select_fields = klass_info['select_fields']
|
||||
from_parent = klass_info['from_parent']
|
||||
if not from_parent:
|
||||
self.cols_start = select_fields[0]
|
||||
self.cols_end = select_fields[-1] + 1
|
||||
self.init_list = [
|
||||
f[0].output_field.attname for f in select[self.cols_start:self.cols_end]
|
||||
]
|
||||
self.reorder_for_init = None
|
||||
else:
|
||||
model_init_attnames = [
|
||||
f.attname for f in klass_info['model']._meta.concrete_fields
|
||||
]
|
||||
reorder_map = []
|
||||
for idx in select_fields:
|
||||
field = select[idx][0].output_field
|
||||
init_pos = model_init_attnames.index(field.attname)
|
||||
reorder_map.append((init_pos, field.attname, idx))
|
||||
reorder_map.sort()
|
||||
self.init_list = [v[1] for v in reorder_map]
|
||||
pos_list = [row_pos for _, _, row_pos in reorder_map]
|
||||
|
||||
def reorder_for_init(row):
|
||||
return [row[row_pos] for row_pos in pos_list]
|
||||
self.reorder_for_init = reorder_for_init
|
||||
|
||||
self.model_cls = self.get_deferred_cls(klass_info, self.init_list)
|
||||
self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
|
||||
self.related_populators = get_related_populators(klass_info, select, self.db)
|
||||
field = klass_info['field']
|
||||
reverse = klass_info['reverse']
|
||||
self.reverse_cache_name = None
|
||||
if reverse:
|
||||
self.cache_name = field.rel.get_cache_name()
|
||||
self.reverse_cache_name = field.get_cache_name()
|
||||
else:
|
||||
self.cache_name = field.get_cache_name()
|
||||
if field.unique:
|
||||
self.reverse_cache_name = field.rel.get_cache_name()
|
||||
|
||||
def get_deferred_cls(self, klass_info, init_list):
|
||||
model_cls = klass_info['model']
|
||||
if len(init_list) != len(model_cls._meta.concrete_fields):
|
||||
init_set = set(init_list)
|
||||
skip = [
|
||||
f.attname for f in model_cls._meta.concrete_fields
|
||||
if f.attname not in init_set
|
||||
]
|
||||
model_cls = deferred_class_factory(model_cls, skip)
|
||||
return model_cls
|
||||
|
||||
def populate(self, row, from_obj):
|
||||
if self.reorder_for_init:
|
||||
obj_data = self.reorder_for_init(row)
|
||||
else:
|
||||
obj_data = row[self.cols_start:self.cols_end]
|
||||
if obj_data[self.pk_idx] is None:
|
||||
obj = None
|
||||
else:
|
||||
obj = self.model_cls.from_db(self.db, self.init_list, obj_data)
|
||||
if obj and self.related_populators:
|
||||
for rel_iter in self.related_populators:
|
||||
rel_iter.populate(row, obj)
|
||||
setattr(from_obj, self.cache_name, obj)
|
||||
if obj and self.reverse_cache_name:
|
||||
setattr(obj, self.reverse_cache_name, from_obj)
|
||||
|
||||
|
||||
def get_related_populators(klass_info, select, db):
|
||||
iterators = []
|
||||
related_klass_infos = klass_info.get('related_klass_infos', [])
|
||||
for rel_klass_info in related_klass_infos:
|
||||
rel_cls = RelatedPopulator(rel_klass_info, select, db)
|
||||
iterators.append(rel_cls)
|
||||
return iterators
|
||||
|
|
|
@ -170,7 +170,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
|
|||
if not restricted and field.null:
|
||||
return False
|
||||
if load_fields:
|
||||
if field.name not in load_fields:
|
||||
if field.attname not in load_fields:
|
||||
if restricted and field.name in requested:
|
||||
raise InvalidQuery("Field %s.%s cannot be both deferred"
|
||||
" and traversed using select_related"
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -2,7 +2,6 @@
|
|||
Constants specific to the SQL storage portion of the ORM.
|
||||
"""
|
||||
|
||||
from collections import namedtuple
|
||||
import re
|
||||
|
||||
# Valid query types (a set is used for speedy lookups). These are (currently)
|
||||
|
@ -21,9 +20,6 @@ GET_ITERATOR_CHUNK_SIZE = 100
|
|||
|
||||
# Namedtuples for sql.* internal use.
|
||||
|
||||
# Pairs of column clauses to select, and (possibly None) field for the clause.
|
||||
SelectInfo = namedtuple('SelectInfo', 'col field')
|
||||
|
||||
# How many results to expect from a cursor.execute call
|
||||
MULTI = 'multi'
|
||||
SINGLE = 'single'
|
||||
|
|
|
@ -21,7 +21,7 @@ from django.db.models.constants import LOOKUP_SEP
|
|||
from django.db.models.expressions import Col, Ref
|
||||
from django.db.models.query_utils import PathInfo, Q, refs_aggregate
|
||||
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
|
||||
ORDER_PATTERN, SelectInfo, INNER, LOUTER)
|
||||
ORDER_PATTERN, INNER, LOUTER)
|
||||
from django.db.models.sql.datastructures import (
|
||||
EmptyResultSet, Empty, MultiJoin, Join, BaseTable)
|
||||
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
|
||||
|
@ -46,7 +46,7 @@ class RawQuery(object):
|
|||
A single raw SQL query
|
||||
"""
|
||||
|
||||
def __init__(self, sql, using, params=None):
|
||||
def __init__(self, sql, using, params=None, context=None):
|
||||
self.params = params or ()
|
||||
self.sql = sql
|
||||
self.using = using
|
||||
|
@ -57,9 +57,10 @@ class RawQuery(object):
|
|||
self.low_mark, self.high_mark = 0, None # Used for offset/limit
|
||||
self.extra_select = {}
|
||||
self.annotation_select = {}
|
||||
self.context = context or {}
|
||||
|
||||
def clone(self, using):
|
||||
return RawQuery(self.sql, using, params=self.params)
|
||||
return RawQuery(self.sql, using, params=self.params, context=self.context.copy())
|
||||
|
||||
def get_columns(self):
|
||||
if self.cursor is None:
|
||||
|
@ -122,20 +123,23 @@ class Query(object):
|
|||
self.standard_ordering = True
|
||||
self.used_aliases = set()
|
||||
self.filter_is_sticky = False
|
||||
self.included_inherited_models = {}
|
||||
|
||||
# SQL-related attributes
|
||||
# Select and related select clauses as SelectInfo instances.
|
||||
# Select and related select clauses are expressions to use in the
|
||||
# SELECT clause of the query.
|
||||
# The select is used for cases where we want to set up the select
|
||||
# clause to contain other than default fields (values(), annotate(),
|
||||
# subqueries...)
|
||||
# clause to contain other than default fields (values(), subqueries...)
|
||||
# Note that annotations go to annotations dictionary.
|
||||
self.select = []
|
||||
# The related_select_cols is used for columns needed for
|
||||
# select_related - this is populated in the compile stage.
|
||||
self.related_select_cols = []
|
||||
self.tables = [] # Aliases in the order they are created.
|
||||
self.where = where()
|
||||
self.where_class = where
|
||||
# The group_by attribute can have one of the following forms:
|
||||
# - None: no group by at all in the query
|
||||
# - A list of expressions: group by (at least) those expressions.
|
||||
# String refs are also allowed for now.
|
||||
# - True: group by all select fields of the model
|
||||
# See compiler.get_group_by() for details.
|
||||
self.group_by = None
|
||||
self.having = where()
|
||||
self.order_by = []
|
||||
|
@ -174,6 +178,8 @@ class Query(object):
|
|||
# load.
|
||||
self.deferred_loading = (set(), True)
|
||||
|
||||
self.context = {}
|
||||
|
||||
@property
|
||||
def extra(self):
|
||||
if self._extra is None:
|
||||
|
@ -254,14 +260,14 @@ class Query(object):
|
|||
obj.default_cols = self.default_cols
|
||||
obj.default_ordering = self.default_ordering
|
||||
obj.standard_ordering = self.standard_ordering
|
||||
obj.included_inherited_models = self.included_inherited_models.copy()
|
||||
obj.select = self.select[:]
|
||||
obj.related_select_cols = []
|
||||
obj.tables = self.tables[:]
|
||||
obj.where = self.where.clone()
|
||||
obj.where_class = self.where_class
|
||||
if self.group_by is None:
|
||||
obj.group_by = None
|
||||
elif self.group_by is True:
|
||||
obj.group_by = True
|
||||
else:
|
||||
obj.group_by = self.group_by[:]
|
||||
obj.having = self.having.clone()
|
||||
|
@ -272,7 +278,6 @@ class Query(object):
|
|||
obj.select_for_update = self.select_for_update
|
||||
obj.select_for_update_nowait = self.select_for_update_nowait
|
||||
obj.select_related = self.select_related
|
||||
obj.related_select_cols = []
|
||||
obj._annotations = self._annotations.copy() if self._annotations is not None else None
|
||||
if self.annotation_select_mask is None:
|
||||
obj.annotation_select_mask = None
|
||||
|
@ -310,8 +315,15 @@ class Query(object):
|
|||
obj.__dict__.update(kwargs)
|
||||
if hasattr(obj, '_setup_query'):
|
||||
obj._setup_query()
|
||||
obj.context = self.context.copy()
|
||||
return obj
|
||||
|
||||
def add_context(self, key, value):
|
||||
self.context[key] = value
|
||||
|
||||
def get_context(self, key, default=None):
|
||||
return self.context.get(key, default)
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
clone = self.clone()
|
||||
clone.change_aliases(change_map)
|
||||
|
@ -375,7 +387,8 @@ class Query(object):
|
|||
# done in a subquery so that we are aggregating on the limit and/or
|
||||
# distinct results instead of applying the distinct and limit after the
|
||||
# aggregation.
|
||||
if (self.group_by or has_limit or has_existing_annotations or self.distinct):
|
||||
if (isinstance(self.group_by, list) or has_limit or has_existing_annotations or
|
||||
self.distinct):
|
||||
from django.db.models.sql.subqueries import AggregateQuery
|
||||
outer_query = AggregateQuery(self.model)
|
||||
inner_query = self.clone()
|
||||
|
@ -383,7 +396,6 @@ class Query(object):
|
|||
inner_query.clear_ordering(True)
|
||||
inner_query.select_for_update = False
|
||||
inner_query.select_related = False
|
||||
inner_query.related_select_cols = []
|
||||
|
||||
relabels = {t: 'subquery' for t in inner_query.tables}
|
||||
relabels[None] = 'subquery'
|
||||
|
@ -407,26 +419,17 @@ class Query(object):
|
|||
self.select = []
|
||||
self.default_cols = False
|
||||
self._extra = {}
|
||||
self.remove_inherited_models()
|
||||
|
||||
outer_query.clear_ordering(True)
|
||||
outer_query.clear_limits()
|
||||
outer_query.select_for_update = False
|
||||
outer_query.select_related = False
|
||||
outer_query.related_select_cols = []
|
||||
compiler = outer_query.get_compiler(using)
|
||||
result = compiler.execute_sql(SINGLE)
|
||||
if result is None:
|
||||
result = [None for q in outer_query.annotation_select.items()]
|
||||
|
||||
fields = [annotation.output_field
|
||||
for alias, annotation in outer_query.annotation_select.items()]
|
||||
converters = compiler.get_converters(fields)
|
||||
for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()):
|
||||
if position in converters:
|
||||
converters[position][1].insert(0, annotation.convert_value)
|
||||
else:
|
||||
converters[position] = ([], [annotation.convert_value], annotation.output_field)
|
||||
converters = compiler.get_converters(outer_query.annotation_select.values())
|
||||
result = compiler.apply_converters(result, converters)
|
||||
|
||||
return {
|
||||
|
@ -476,7 +479,6 @@ class Query(object):
|
|||
assert self.distinct_fields == rhs.distinct_fields, \
|
||||
"Cannot combine queries with different distinct fields."
|
||||
|
||||
self.remove_inherited_models()
|
||||
# Work out how to relabel the rhs aliases, if necessary.
|
||||
change_map = {}
|
||||
conjunction = (connector == AND)
|
||||
|
@ -545,13 +547,8 @@ class Query(object):
|
|||
|
||||
# Selection columns and extra extensions are those provided by 'rhs'.
|
||||
self.select = []
|
||||
for col, field in rhs.select:
|
||||
if isinstance(col, (list, tuple)):
|
||||
new_col = change_map.get(col[0], col[0]), col[1]
|
||||
self.select.append(SelectInfo(new_col, field))
|
||||
else:
|
||||
new_col = col.relabeled_clone(change_map)
|
||||
self.select.append(SelectInfo(new_col, field))
|
||||
for col in rhs.select:
|
||||
self.add_select(col.relabeled_clone(change_map))
|
||||
|
||||
if connector == OR:
|
||||
# It would be nice to be able to handle this, but the queries don't
|
||||
|
@ -661,17 +658,6 @@ class Query(object):
|
|||
for model, values in six.iteritems(seen):
|
||||
callback(target, model, values)
|
||||
|
||||
def deferred_to_columns_cb(self, target, model, fields):
|
||||
"""
|
||||
Callback used by deferred_to_columns(). The "target" parameter should
|
||||
be a set instance.
|
||||
"""
|
||||
table = model._meta.db_table
|
||||
if table not in target:
|
||||
target[table] = set()
|
||||
for field in fields:
|
||||
target[table].add(field.column)
|
||||
|
||||
def table_alias(self, table_name, create=False):
|
||||
"""
|
||||
Returns a table alias for the given table_name and whether this is a
|
||||
|
@ -788,10 +774,9 @@ class Query(object):
|
|||
# "group by", "where" and "having".
|
||||
self.where.relabel_aliases(change_map)
|
||||
self.having.relabel_aliases(change_map)
|
||||
if self.group_by:
|
||||
if isinstance(self.group_by, list):
|
||||
self.group_by = [relabel_column(col) for col in self.group_by]
|
||||
self.select = [SelectInfo(relabel_column(s.col), s.field)
|
||||
for s in self.select]
|
||||
self.select = [col.relabeled_clone(change_map) for col in self.select]
|
||||
if self._annotations:
|
||||
self._annotations = OrderedDict(
|
||||
(key, relabel_column(col)) for key, col in self._annotations.items())
|
||||
|
@ -815,9 +800,6 @@ class Query(object):
|
|||
if alias == old_alias:
|
||||
self.tables[pos] = new_alias
|
||||
break
|
||||
for key, alias in self.included_inherited_models.items():
|
||||
if alias in change_map:
|
||||
self.included_inherited_models[key] = change_map[alias]
|
||||
self.external_aliases = {change_map.get(alias, alias)
|
||||
for alias in self.external_aliases}
|
||||
|
||||
|
@ -930,28 +912,6 @@ class Query(object):
|
|||
self.alias_map[alias] = join
|
||||
return alias
|
||||
|
||||
def setup_inherited_models(self):
|
||||
"""
|
||||
If the model that is the basis for this QuerySet inherits other models,
|
||||
we need to ensure that those other models have their tables included in
|
||||
the query.
|
||||
|
||||
We do this as a separate step so that subclasses know which
|
||||
tables are going to be active in the query, without needing to compute
|
||||
all the select columns (this method is called from pre_sql_setup(),
|
||||
whereas column determination is a later part, and side-effect, of
|
||||
as_sql()).
|
||||
"""
|
||||
opts = self.get_meta()
|
||||
root_alias = self.tables[0]
|
||||
seen = {None: root_alias}
|
||||
|
||||
for field in opts.fields:
|
||||
model = field.model._meta.concrete_model
|
||||
if model is not opts.model and model not in seen:
|
||||
self.join_parent_model(opts, model, root_alias, seen)
|
||||
self.included_inherited_models = seen
|
||||
|
||||
def join_parent_model(self, opts, model, alias, seen):
|
||||
"""
|
||||
Makes sure the given 'model' is joined in the query. If 'model' isn't
|
||||
|
@ -969,7 +929,9 @@ class Query(object):
|
|||
curr_opts = opts
|
||||
for int_model in chain:
|
||||
if int_model in seen:
|
||||
return seen[int_model]
|
||||
curr_opts = int_model._meta
|
||||
alias = seen[int_model]
|
||||
continue
|
||||
# Proxy model have elements in base chain
|
||||
# with no parents, assign the new options
|
||||
# object and skip to the next base in that
|
||||
|
@ -984,23 +946,13 @@ class Query(object):
|
|||
alias = seen[int_model] = joins[-1]
|
||||
return alias or seen[None]
|
||||
|
||||
def remove_inherited_models(self):
|
||||
"""
|
||||
Undoes the effects of setup_inherited_models(). Should be called
|
||||
whenever select columns (self.select) are set explicitly.
|
||||
"""
|
||||
for key, alias in self.included_inherited_models.items():
|
||||
if key:
|
||||
self.unref_alias(alias)
|
||||
self.included_inherited_models = {}
|
||||
|
||||
def add_aggregate(self, aggregate, model, alias, is_summary):
|
||||
warnings.warn(
|
||||
"add_aggregate() is deprecated. Use add_annotation() instead.",
|
||||
RemovedInDjango20Warning, stacklevel=2)
|
||||
self.add_annotation(aggregate, alias, is_summary)
|
||||
|
||||
def add_annotation(self, annotation, alias, is_summary):
|
||||
def add_annotation(self, annotation, alias, is_summary=False):
|
||||
"""
|
||||
Adds a single annotation expression to the Query
|
||||
"""
|
||||
|
@ -1011,6 +963,7 @@ class Query(object):
|
|||
|
||||
def prepare_lookup_value(self, value, lookups, can_reuse):
|
||||
# Default lookup if none given is exact.
|
||||
used_joins = []
|
||||
if len(lookups) == 0:
|
||||
lookups = ['exact']
|
||||
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
|
||||
|
@ -1026,7 +979,9 @@ class Query(object):
|
|||
RemovedInDjango19Warning, stacklevel=2)
|
||||
value = value()
|
||||
elif hasattr(value, 'resolve_expression'):
|
||||
pre_joins = self.alias_refcount.copy()
|
||||
value = value.resolve_expression(self, reuse=can_reuse)
|
||||
used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)]
|
||||
# Subqueries need to use a different set of aliases than the
|
||||
# outer query. Call bump_prefix to change aliases of the inner
|
||||
# query (the value).
|
||||
|
@ -1044,7 +999,7 @@ class Query(object):
|
|||
lookups[-1] == 'exact' and value == ''):
|
||||
value = True
|
||||
lookups[-1] = 'isnull'
|
||||
return value, lookups
|
||||
return value, lookups, used_joins
|
||||
|
||||
def solve_lookup_type(self, lookup):
|
||||
"""
|
||||
|
@ -1173,8 +1128,7 @@ class Query(object):
|
|||
|
||||
# Work out the lookup type and remove it from the end of 'parts',
|
||||
# if necessary.
|
||||
value, lookups = self.prepare_lookup_value(value, lookups, can_reuse)
|
||||
used_joins = getattr(value, '_used_joins', [])
|
||||
value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse)
|
||||
|
||||
clause = self.where_class()
|
||||
if reffed_aggregate:
|
||||
|
@ -1223,7 +1177,7 @@ class Query(object):
|
|||
# handle Expressions as annotations
|
||||
col = targets[0]
|
||||
else:
|
||||
col = Col(alias, targets[0], field)
|
||||
col = targets[0].get_col(alias, field)
|
||||
condition = self.build_lookup(lookups, col, value)
|
||||
if not condition:
|
||||
# Backwards compat for custom lookups
|
||||
|
@ -1258,7 +1212,7 @@ class Query(object):
|
|||
# <=>
|
||||
# NOT (col IS NOT NULL AND col = someval).
|
||||
lookup_class = targets[0].get_lookup('isnull')
|
||||
clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND)
|
||||
clause.add(lookup_class(targets[0].get_col(alias, sources[0]), False), AND)
|
||||
return clause, used_joins if not require_outer else ()
|
||||
|
||||
def add_filter(self, filter_clause):
|
||||
|
@ -1535,7 +1489,7 @@ class Query(object):
|
|||
self.unref_alias(joins.pop())
|
||||
return targets, joins[-1], joins
|
||||
|
||||
def resolve_ref(self, name, allow_joins, reuse, summarize):
|
||||
def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False):
|
||||
if not allow_joins and LOOKUP_SEP in name:
|
||||
raise FieldError("Joined field references are not permitted in this query")
|
||||
if name in self.annotations:
|
||||
|
@ -1558,8 +1512,7 @@ class Query(object):
|
|||
"isn't supported")
|
||||
if reuse is not None:
|
||||
reuse.update(join_list)
|
||||
col = Col(join_list[-1], targets[0], sources[0])
|
||||
col._used_joins = join_list
|
||||
col = targets[0].get_col(join_list[-1], sources[0])
|
||||
return col
|
||||
|
||||
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
|
||||
|
@ -1588,26 +1541,28 @@ class Query(object):
|
|||
# Try to have as simple as possible subquery -> trim leading joins from
|
||||
# the subquery.
|
||||
trimmed_prefix, contains_louter = query.trim_start(names_with_path)
|
||||
query.remove_inherited_models()
|
||||
|
||||
# Add extra check to make sure the selected field will not be null
|
||||
# since we are adding an IN <subquery> clause. This prevents the
|
||||
# database from tripping over IN (...,NULL,...) selects and returning
|
||||
# nothing
|
||||
alias, col = query.select[0].col
|
||||
if self.is_nullable(query.select[0].field):
|
||||
lookup_class = query.select[0].field.get_lookup('isnull')
|
||||
lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False)
|
||||
col = query.select[0]
|
||||
select_field = col.field
|
||||
alias = col.alias
|
||||
if self.is_nullable(select_field):
|
||||
lookup_class = select_field.get_lookup('isnull')
|
||||
lookup = lookup_class(select_field.get_col(alias), False)
|
||||
query.where.add(lookup, AND)
|
||||
if alias in can_reuse:
|
||||
select_field = query.select[0].field
|
||||
pk = select_field.model._meta.pk
|
||||
# Need to add a restriction so that outer query's filters are in effect for
|
||||
# the subquery, too.
|
||||
query.bump_prefix(self)
|
||||
lookup_class = select_field.get_lookup('exact')
|
||||
lookup = lookup_class(Col(query.select[0].col[0], pk, pk),
|
||||
Col(alias, pk, pk))
|
||||
# Note that the query.select[0].alias is different from alias
|
||||
# due to bump_prefix above.
|
||||
lookup = lookup_class(pk.get_col(query.select[0].alias),
|
||||
pk.get_col(alias))
|
||||
query.where.add(lookup, AND)
|
||||
query.external_aliases.add(alias)
|
||||
|
||||
|
@ -1687,6 +1642,14 @@ class Query(object):
|
|||
"""
|
||||
self.select = []
|
||||
|
||||
def add_select(self, col):
|
||||
self.default_cols = False
|
||||
self.select.append(col)
|
||||
|
||||
def set_select(self, cols):
|
||||
self.default_cols = False
|
||||
self.select = cols
|
||||
|
||||
def add_distinct_fields(self, *field_names):
|
||||
"""
|
||||
Adds and resolves the given fields to the query's "distinct on" clause.
|
||||
|
@ -1710,7 +1673,7 @@ class Query(object):
|
|||
name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m)
|
||||
targets, final_alias, joins = self.trim_joins(targets, joins, path)
|
||||
for target in targets:
|
||||
self.select.append(SelectInfo((final_alias, target.column), target))
|
||||
self.add_select(target.get_col(final_alias))
|
||||
except MultiJoin:
|
||||
raise FieldError("Invalid field name: '%s'" % name)
|
||||
except FieldError:
|
||||
|
@ -1723,7 +1686,6 @@ class Query(object):
|
|||
+ list(self.annotation_select))
|
||||
raise FieldError("Cannot resolve keyword %r into field. "
|
||||
"Choices are: %s" % (name, ", ".join(names)))
|
||||
self.remove_inherited_models()
|
||||
|
||||
def add_ordering(self, *ordering):
|
||||
"""
|
||||
|
@ -1766,7 +1728,7 @@ class Query(object):
|
|||
"""
|
||||
self.group_by = []
|
||||
|
||||
for col, _ in self.select:
|
||||
for col in self.select:
|
||||
self.group_by.append(col)
|
||||
|
||||
if self._annotations:
|
||||
|
@ -1789,7 +1751,6 @@ class Query(object):
|
|||
for part in field.split(LOOKUP_SEP):
|
||||
d = d.setdefault(part, {})
|
||||
self.select_related = field_dict
|
||||
self.related_select_cols = []
|
||||
|
||||
def add_extra(self, select, select_params, where, params, tables, order_by):
|
||||
"""
|
||||
|
@ -1897,7 +1858,7 @@ class Query(object):
|
|||
"""
|
||||
Callback used by get_deferred_field_names().
|
||||
"""
|
||||
target[model] = set(f.name for f in fields)
|
||||
target[model] = {f.attname for f in fields}
|
||||
|
||||
def set_aggregate_mask(self, names):
|
||||
warnings.warn(
|
||||
|
@ -2041,7 +2002,7 @@ class Query(object):
|
|||
if self.alias_refcount[table] > 0:
|
||||
self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table)
|
||||
break
|
||||
self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields]
|
||||
self.set_select([f.get_col(select_alias) for f in select_fields])
|
||||
return trimmed_prefix, contains_louter
|
||||
|
||||
def is_nullable(self, field):
|
||||
|
|
|
@ -5,7 +5,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
|
|||
from django.core.exceptions import FieldError
|
||||
from django.db import connections
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
|
||||
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
|
||||
from django.db.models.sql.query import Query
|
||||
from django.utils import six
|
||||
|
||||
|
@ -71,7 +71,7 @@ class DeleteQuery(Query):
|
|||
else:
|
||||
innerq.clear_select_clause()
|
||||
innerq.select = [
|
||||
SelectInfo((self.get_initial_alias(), pk.column), None)
|
||||
pk.get_col(self.get_initial_alias())
|
||||
]
|
||||
values = innerq
|
||||
self.where = self.where_class()
|
||||
|
|
|
@ -483,7 +483,7 @@ instances::
|
|||
class HandField(models.Field):
|
||||
# ...
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
def from_db_value(self, value, connection, context):
|
||||
if value is None:
|
||||
return value
|
||||
return parse_hand(value)
|
||||
|
|
|
@ -399,7 +399,7 @@ calling the appropriate methods on the wrapped expression.
|
|||
clone.expression = self.expression.relabeled_clone(change_map)
|
||||
return clone
|
||||
|
||||
.. method:: convert_value(self, value, connection)
|
||||
.. method:: convert_value(self, value, connection, context)
|
||||
|
||||
A hook allowing the expression to coerce ``value`` into a more
|
||||
appropriate type.
|
||||
|
|
|
@ -1670,7 +1670,7 @@ Field API reference
|
|||
|
||||
When loading data, :meth:`from_db_value` is used:
|
||||
|
||||
.. method:: from_db_value(value, connection)
|
||||
.. method:: from_db_value(value, connection, context)
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
|
|
|
@ -679,7 +679,7 @@ class BaseAggregateTestCase(TestCase):
|
|||
# the only "ORDER BY" clause present in the query.
|
||||
self.assertEqual(
|
||||
re.findall(r'order by (\w+)', qstr),
|
||||
[', '.join(forced_ordering).lower()]
|
||||
[', '.join(f[1][0] for f in forced_ordering).lower()]
|
||||
)
|
||||
else:
|
||||
self.assertNotIn('order by', qstr)
|
||||
|
|
|
@ -490,9 +490,10 @@ class AggregationTests(TestCase):
|
|||
|
||||
# Regression for #15709 - Ensure each group_by field only exists once
|
||||
# per query
|
||||
qs = Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by()
|
||||
grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping([], [])
|
||||
self.assertEqual(len(grouping), 1)
|
||||
qstr = str(Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by().query)
|
||||
# Check that there is just one GROUP BY clause (zero commas means at
|
||||
# most one clause)
|
||||
self.assertEqual(qstr[qstr.index('GROUP BY'):].count(', '), 0)
|
||||
|
||||
def test_duplicate_alias(self):
|
||||
# Regression for #11256 - duplicating a default alias raises ValueError.
|
||||
|
@ -924,14 +925,11 @@ class AggregationTests(TestCase):
|
|||
|
||||
# There should only be one GROUP BY clause, for the `id` column.
|
||||
# `name` and `age` should not be grouped on.
|
||||
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], [])
|
||||
self.assertEqual(len(grouping), 1)
|
||||
assert 'id' in grouping[0]
|
||||
assert 'name' not in grouping[0]
|
||||
assert 'age' not in grouping[0]
|
||||
|
||||
# The query group_by property should also only show the `id`.
|
||||
self.assertEqual(results.query.group_by, [('aggregation_regress_author', 'id')])
|
||||
_, _, group_by = results.query.get_compiler(using='default').pre_sql_setup()
|
||||
self.assertEqual(len(group_by), 1)
|
||||
self.assertIn('id', group_by[0][0])
|
||||
self.assertNotIn('name', group_by[0][0])
|
||||
self.assertNotIn('age', group_by[0][0])
|
||||
|
||||
# Ensure that we get correct results.
|
||||
self.assertEqual(
|
||||
|
@ -953,14 +951,11 @@ class AggregationTests(TestCase):
|
|||
def test_aggregate_duplicate_columns_only(self):
|
||||
# Works with only() too.
|
||||
results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set'))
|
||||
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], [])
|
||||
_, _, grouping = results.query.get_compiler(using='default').pre_sql_setup()
|
||||
self.assertEqual(len(grouping), 1)
|
||||
assert 'id' in grouping[0]
|
||||
assert 'name' not in grouping[0]
|
||||
assert 'age' not in grouping[0]
|
||||
|
||||
# The query group_by property should also only show the `id`.
|
||||
self.assertEqual(results.query.group_by, [('aggregation_regress_author', 'id')])
|
||||
self.assertIn('id', grouping[0][0])
|
||||
self.assertNotIn('name', grouping[0][0])
|
||||
self.assertNotIn('age', grouping[0][0])
|
||||
|
||||
# Ensure that we get correct results.
|
||||
self.assertEqual(
|
||||
|
@ -983,14 +978,11 @@ class AggregationTests(TestCase):
|
|||
# And select_related()
|
||||
results = Book.objects.select_related('contact').annotate(
|
||||
num_authors=Count('authors'))
|
||||
grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], [])
|
||||
_, _, grouping = results.query.get_compiler(using='default').pre_sql_setup()
|
||||
self.assertEqual(len(grouping), 1)
|
||||
assert 'id' in grouping[0]
|
||||
assert 'name' not in grouping[0]
|
||||
assert 'contact' not in grouping[0]
|
||||
|
||||
# The query group_by property should also only show the `id`.
|
||||
self.assertEqual(results.query.group_by, [('aggregation_regress_book', 'id')])
|
||||
self.assertIn('id', grouping[0][0])
|
||||
self.assertNotIn('name', grouping[0][0])
|
||||
self.assertNotIn('contact', grouping[0][0])
|
||||
|
||||
# Ensure that we get correct results.
|
||||
self.assertEqual(
|
||||
|
|
|
@ -43,7 +43,7 @@ class MyAutoField(models.CharField):
|
|||
value = MyWrapper(value)
|
||||
return value
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
def from_db_value(self, value, connection, context):
|
||||
if not value:
|
||||
return
|
||||
return MyWrapper(value)
|
||||
|
|
|
@ -96,3 +96,11 @@ class Request(models.Model):
|
|||
request2 = models.CharField(default='request2', max_length=1000)
|
||||
request3 = models.CharField(default='request3', max_length=1000)
|
||||
request4 = models.CharField(default='request4', max_length=1000)
|
||||
|
||||
|
||||
class Base(models.Model):
|
||||
text = models.TextField()
|
||||
|
||||
|
||||
class Derived(Base):
|
||||
other_text = models.TextField()
|
||||
|
|
|
@ -12,7 +12,7 @@ from django.test import TestCase, override_settings
|
|||
from .models import (
|
||||
ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, SimpleItem, Feature,
|
||||
ItemAndSimpleItem, OneToOneItem, SpecialFeature, Location, Request,
|
||||
ProxyRelated,
|
||||
ProxyRelated, Derived, Base,
|
||||
)
|
||||
|
||||
|
||||
|
@ -145,6 +145,15 @@ class DeferRegressionTest(TestCase):
|
|||
list(SimpleItem.objects.annotate(Count('feature')).only('name')),
|
||||
list)
|
||||
|
||||
def test_ticket_23270(self):
|
||||
Derived.objects.create(text="foo", other_text="bar")
|
||||
with self.assertNumQueries(1):
|
||||
obj = Base.objects.select_related("derived").defer("text")[0]
|
||||
self.assertIsInstance(obj.derived, Derived)
|
||||
self.assertEqual("bar", obj.derived.other_text)
|
||||
self.assertNotIn("text", obj.__dict__)
|
||||
self.assertEqual(1, obj.derived.base_ptr_id)
|
||||
|
||||
def test_only_and_defer_usage_on_proxy_models(self):
|
||||
# Regression for #15790 - only() broken for proxy models
|
||||
proxy = Proxy.objects.create(name="proxy", value=42)
|
||||
|
|
|
@ -18,7 +18,7 @@ class CashField(models.DecimalField):
|
|||
kwargs['decimal_places'] = 2
|
||||
super(CashField, self).__init__(**kwargs)
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
def from_db_value(self, value, connection, context):
|
||||
cash = Cash(value)
|
||||
cash.vendor = connection.vendor
|
||||
return cash
|
||||
|
|
|
@ -3563,8 +3563,10 @@ class Ticket20955Tests(TestCase):
|
|||
# version's queries.
|
||||
task_get.creator.staffuser.staff
|
||||
task_get.owner.staffuser.staff
|
||||
task_select_related = Task.objects.select_related(
|
||||
'creator__staffuser__staff', 'owner__staffuser__staff').get(pk=task.pk)
|
||||
qs = Task.objects.select_related(
|
||||
'creator__staffuser__staff', 'owner__staffuser__staff')
|
||||
self.assertEqual(str(qs.query).count(' JOIN '), 6)
|
||||
task_select_related = qs.get(pk=task.pk)
|
||||
with self.assertNumQueries(0):
|
||||
self.assertEqual(task_select_related.creator.staffuser.staff,
|
||||
task_get.creator.staffuser.staff)
|
||||
|
|
|
@ -83,6 +83,7 @@ class ReverseSelectRelatedTestCase(TestCase):
|
|||
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
|
||||
self.assertEqual(stat.advanceduserstat.posts, 200)
|
||||
self.assertEqual(stat.user.username, 'bob')
|
||||
with self.assertNumQueries(1):
|
||||
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
|
||||
|
||||
def test_nullable_relation(self):
|
||||
|
|
|
@ -112,7 +112,7 @@ class TeamField(models.CharField):
|
|||
return value
|
||||
return Team(value)
|
||||
|
||||
def from_db_value(self, value, connection):
|
||||
def from_db_value(self, value, connection, context):
|
||||
return Team(value)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
|
|
Loading…
Reference in New Issue