diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index 07d47becf8..fba8e0c222 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -10,7 +10,6 @@ from django.db.models import signals, DO_NOTHING from django.db.models.base import ModelBase from django.db.models.fields.related import ForeignObject, ForeignObjectRel from django.db.models.query_utils import PathInfo -from django.db.models.expressions import Col from django.contrib.contenttypes.models import ContentType from django.utils.encoding import smart_text, python_2_unicode_compatible @@ -367,7 +366,7 @@ class GenericRelation(ForeignObject): field = self.rel.to._meta.get_field(self.content_type_field_name) contenttype_pk = self.get_content_type().pk cond = where_class() - lookup = field.get_lookup('exact')(Col(remote_alias, field, field), contenttype_pk) + lookup = field.get_lookup('exact')(field.get_col(remote_alias), contenttype_pk) cond.add(lookup, 'AND') return cond diff --git a/django/contrib/gis/db/backends/base.py b/django/contrib/gis/db/backends/base.py index 9e9612cc0f..8d8c2e7b7c 100644 --- a/django/contrib/gis/db/backends/base.py +++ b/django/contrib/gis/db/backends/base.py @@ -158,10 +158,10 @@ class BaseSpatialOperations(object): # Default conversion functions for aggregates; will be overridden if implemented # for the spatial backend. - def convert_extent(self, box): + def convert_extent(self, box, srid): raise NotImplementedError('Aggregate extent not implemented for this spatial backend.') - def convert_extent3d(self, box): + def convert_extent3d(self, box, srid): raise NotImplementedError('Aggregate 3D extent not implemented for this spatial backend.') def convert_geom(self, geom_val, geom_field): diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index ccbe0542f4..a64f0cc60d 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -7,7 +7,6 @@ from django.contrib.gis.db.backends.utils import SpatialOperator class MySQLOperations(DatabaseOperations, BaseSpatialOperations): - compiler_module = 'django.contrib.gis.db.models.sql.compiler' mysql = True name = 'mysql' select = 'AsText(%s)' diff --git a/django/contrib/gis/db/backends/oracle/compiler.py b/django/contrib/gis/db/backends/oracle/compiler.py deleted file mode 100644 index a78765b901..0000000000 --- a/django/contrib/gis/db/backends/oracle/compiler.py +++ /dev/null @@ -1,24 +0,0 @@ -from django.contrib.gis.db.models.sql.compiler import GeoSQLCompiler as BaseGeoSQLCompiler -from django.db.backends.oracle import compiler - -SQLCompiler = compiler.SQLCompiler - - -class GeoSQLCompiler(BaseGeoSQLCompiler, SQLCompiler): - pass - - -class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler): - pass - - -class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler): - pass - - -class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler): - pass - - -class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler): - pass diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index c9e3e8ee31..4c37a501e8 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -52,7 +52,6 @@ class SDORelate(SpatialOperator): class OracleOperations(DatabaseOperations, BaseSpatialOperations): - compiler_module = "django.contrib.gis.db.backends.oracle.compiler" name = 'oracle' oracle = True @@ -111,8 +110,9 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): def geo_quote_name(self, name): return super(OracleOperations, self).geo_quote_name(name).upper() - def get_db_converters(self, internal_type): - converters = super(OracleOperations, self).get_db_converters(internal_type) + def get_db_converters(self, expression): + converters = super(OracleOperations, self).get_db_converters(expression) + internal_type = expression.output_field.get_internal_type() geometry_fields = ( 'PointField', 'GeometryField', 'LineStringField', 'PolygonField', 'MultiPointField', 'MultiLineStringField', @@ -121,14 +121,23 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): ) if internal_type in geometry_fields: converters.append(self.convert_textfield_value) + if hasattr(expression.output_field, 'geom_type'): + converters.append(self.convert_geometry) return converters - def convert_extent(self, clob): + def convert_geometry(self, value, expression, context): + if value: + value = Geometry(value) + if 'transformed_srid' in context: + value.srid = context['transformed_srid'] + return value + + def convert_extent(self, clob, srid): if clob: # Generally, Oracle returns a polygon for the extent -- however, # it can return a single point if there's only one Point in the # table. - ext_geom = Geometry(clob.read()) + ext_geom = Geometry(clob.read(), srid) gtype = str(ext_geom.geom_type) if gtype == 'Polygon': # Construct the 4-tuple from the coordinates in the polygon. @@ -226,7 +235,7 @@ class OracleOperations(DatabaseOperations, BaseSpatialOperations): else: sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' sql_function = getattr(self, agg_name) - return self.select % sql_template, sql_function + return sql_template, sql_function # Routines for getting the OGC-compliant models. def geometry_columns(self): diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 2cd088c59d..1f6d2ef4f3 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -44,7 +44,6 @@ class PostGISDistanceOperator(PostGISOperator): class PostGISOperations(DatabaseOperations, BaseSpatialOperations): - compiler_module = 'django.contrib.gis.db.models.sql.compiler' name = 'postgis' postgis = True geography = True @@ -188,7 +187,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): agg_name = aggregate.__class__.__name__ return agg_name in self.valid_aggregates - def convert_extent(self, box): + def convert_extent(self, box, srid): """ Returns a 4-tuple extent for the `Extent` aggregate by converting the bounding box text returned by PostGIS (`box` argument), for @@ -199,7 +198,7 @@ class PostGISOperations(DatabaseOperations, BaseSpatialOperations): xmax, ymax = map(float, ur.split()) return (xmin, ymin, xmax, ymax) - def convert_extent3d(self, box3d): + def convert_extent3d(self, box3d, srid): """ Returns a 6-tuple extent for the `Extent3D` aggregate by converting the 3d bounding-box text returned by PostGIS (`box3d` argument), for diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 8e9bdd57d2..e94d53a15b 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -14,7 +14,6 @@ from django.utils.functional import cached_property class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): - compiler_module = 'django.contrib.gis.db.models.sql.compiler' name = 'spatialite' spatialite = True version_regex = re.compile(r'^(?P\d)\.(?P\d)\.(?P\d+)') @@ -131,11 +130,11 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): agg_name = aggregate.__class__.__name__ return agg_name in self.valid_aggregates - def convert_extent(self, box): + def convert_extent(self, box, srid): """ Convert the polygon data received from Spatialite to min/max values. """ - shell = Geometry(box).shell + shell = Geometry(box, srid).shell xmin, ymin = shell[0][:2] xmax, ymax = shell[2][:2] return (xmin, ymin, xmax, ymax) @@ -256,7 +255,7 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): agg_name = agg_name.lower() if agg_name == 'union': agg_name += 'agg' - sql_template = self.select % '%(function)s(%(expressions)s)' + sql_template = '%(function)s(%(expressions)s)' sql_function = getattr(self, agg_name) return sql_template, sql_function @@ -268,3 +267,16 @@ class SpatiaLiteOperations(DatabaseOperations, BaseSpatialOperations): def spatial_ref_sys(self): from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys return SpatialiteSpatialRefSys + + def get_db_converters(self, expression): + converters = super(SpatiaLiteOperations, self).get_db_converters(expression) + if hasattr(expression.output_field, 'geom_type'): + converters.append(self.convert_geometry) + return converters + + def convert_geometry(self, value, expression, context): + if value: + value = Geometry(value) + if 'transformed_srid' in context: + value.srid = context['transformed_srid'] + return value diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 0cf0a8b266..c9476d0f02 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -28,7 +28,7 @@ class GeoAggregate(Aggregate): raise ValueError('Geospatial aggregates only allowed on geometry fields.') return c - def convert_value(self, value, connection): + def convert_value(self, value, connection, context): return connection.ops.convert_geom(value, self.output_field) @@ -43,8 +43,8 @@ class Extent(GeoAggregate): def __init__(self, expression, **extra): super(Extent, self).__init__(expression, output_field=ExtentField(), **extra) - def convert_value(self, value, connection): - return connection.ops.convert_extent(value) + def convert_value(self, value, connection, context): + return connection.ops.convert_extent(value, context.get('transformed_srid')) class Extent3D(GeoAggregate): @@ -54,8 +54,8 @@ class Extent3D(GeoAggregate): def __init__(self, expression, **extra): super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra) - def convert_value(self, value, connection): - return connection.ops.convert_extent3d(value) + def convert_value(self, value, connection, context): + return connection.ops.convert_extent3d(value, context.get('transformed_srid')) class MakeLine(GeoAggregate): diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index f2c91aa8ab..927118c0ac 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -42,7 +42,30 @@ def get_srid_info(srid, connection): return _srid_cache[connection.alias][srid] -class GeometryField(Field): +class GeoSelectFormatMixin(object): + def select_format(self, compiler, sql, params): + """ + Returns the selection format string, depending on the requirements + of the spatial backend. For example, Oracle and MySQL require custom + selection formats in order to retrieve geometries in OGC WKT. For all + other fields a simple '%s' format string is returned. + """ + connection = compiler.connection + srid = compiler.query.get_context('transformed_srid') + if srid: + sel_fmt = '%s(%%s, %s)' % (connection.ops.transform, srid) + else: + sel_fmt = '%s' + if connection.ops.select: + # This allows operations to be done on fields in the SELECT, + # overriding their values -- used by the Oracle and MySQL + # spatial backends to get database values as WKT, and by the + # `transform` method. + sel_fmt = connection.ops.select % sel_fmt + return sel_fmt % sql, params + + +class GeometryField(GeoSelectFormatMixin, Field): "The base GIS field -- maps to the OpenGIS Specification Geometry type." # The OpenGIS Geometry name. @@ -196,7 +219,7 @@ class GeometryField(Field): else: return geom - def from_db_value(self, value, connection): + def from_db_value(self, value, connection, context): if value and not isinstance(value, Geometry): value = Geometry(value) return value @@ -337,7 +360,7 @@ class GeometryCollectionField(GeometryField): description = _("Geometry collection") -class ExtentField(Field): +class ExtentField(GeoSelectFormatMixin, Field): "Used as a return value from an extent aggregate" description = _("Extent Aggregate Field") diff --git a/django/contrib/gis/db/models/query.py b/django/contrib/gis/db/models/query.py index c05688dec9..55f286ff85 100644 --- a/django/contrib/gis/db/models/query.py +++ b/django/contrib/gis/db/models/query.py @@ -1,9 +1,16 @@ from django.db import connections +from django.db.models.expressions import RawSQL +from django.db.models.fields import Field from django.db.models.query import QuerySet from django.contrib.gis.db.models import aggregates -from django.contrib.gis.db.models.fields import get_srid_info, PointField, LineStringField -from django.contrib.gis.db.models.sql import AreaField, DistanceField, GeomField, GeoQuery, GMLField +from django.contrib.gis.db.models.fields import ( + get_srid_info, LineStringField, GeometryField, PointField, +) +from django.contrib.gis.db.models.lookups import GISLookup +from django.contrib.gis.db.models.sql import ( + AreaField, DistanceField, GeomField, GMLField, +) from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import Area, Distance @@ -13,11 +20,6 @@ from django.utils import six class GeoQuerySet(QuerySet): "The Geographic QuerySet." - ### Methods overloaded from QuerySet ### - def __init__(self, model=None, query=None, using=None, hints=None): - super(GeoQuerySet, self).__init__(model=model, query=query, using=using, hints=hints) - self.query = query or GeoQuery(self.model) - ### GeoQuerySet Methods ### def area(self, tolerance=0.05, **kwargs): """ @@ -26,7 +28,8 @@ class GeoQuerySet(QuerySet): """ # Performing setup here rather than in `_spatial_attribute` so that # we can get the units for `AreaField`. - procedure_args, geo_field = self._spatial_setup('area', field_name=kwargs.get('field_name', None)) + procedure_args, geo_field = self._spatial_setup( + 'area', field_name=kwargs.get('field_name', None)) s = {'procedure_args': procedure_args, 'geo_field': geo_field, 'setup': False, @@ -378,24 +381,8 @@ class GeoQuerySet(QuerySet): if not isinstance(srid, six.integer_types): raise TypeError('An integer SRID must be provided.') field_name = kwargs.get('field_name', None) - tmp, geo_field = self._spatial_setup('transform', field_name=field_name) - - # Getting the selection SQL for the given geographic field. - field_col = self._geocol_select(geo_field, field_name) - - # Why cascading substitutions? Because spatial backends like - # Oracle and MySQL already require a function call to convert to text, thus - # when there's also a transformation we need to cascade the substitutions. - # For example, 'SDO_UTIL.TO_WKTGEOMETRY(SDO_CS.TRANSFORM( ... )' - geo_col = self.query.custom_select.get(geo_field, field_col) - - # Setting the key for the field's column with the custom SELECT SQL to - # override the geometry column returned from the database. - custom_sel = '%s(%s, %s)' % (connections[self.db].ops.transform, geo_col, srid) - # TODO: Should we have this as an alias? - # custom_sel = '(%s(%s, %s)) AS %s' % (SpatialBackend.transform, geo_col, srid, qn(geo_field.name)) - self.query.transformed_srid = srid # So other GeoQuerySet methods - self.query.custom_select[geo_field] = custom_sel + self._spatial_setup('transform', field_name=field_name) + self.query.add_context('transformed_srid', srid) return self._clone() def union(self, geom, **kwargs): @@ -433,7 +420,7 @@ class GeoQuerySet(QuerySet): # Is there a geographic field in the model to perform this # operation on? - geo_field = self.query._geo_field(field_name) + geo_field = self._geo_field(field_name) if not geo_field: raise TypeError('%s output only available on GeometryFields.' % func) @@ -454,7 +441,7 @@ class GeoQuerySet(QuerySet): returning their result to the caller of the function. """ # Getting the field the geographic aggregate will be called on. - geo_field = self.query._geo_field(field_name) + geo_field = self._geo_field(field_name) if not geo_field: raise TypeError('%s aggregate only available on GeometryFields.' % aggregate.name) @@ -509,12 +496,12 @@ class GeoQuerySet(QuerySet): settings.setdefault('select_params', []) connection = connections[self.db] - backend = connection.ops # Performing setup for the spatial column, unless told not to. if settings.get('setup', True): - default_args, geo_field = self._spatial_setup(att, desc=settings['desc'], field_name=field_name, - geo_field_type=settings.get('geo_field_type', None)) + default_args, geo_field = self._spatial_setup( + att, desc=settings['desc'], field_name=field_name, + geo_field_type=settings.get('geo_field_type', None)) for k, v in six.iteritems(default_args): settings['procedure_args'].setdefault(k, v) else: @@ -544,18 +531,19 @@ class GeoQuerySet(QuerySet): # If the result of this function needs to be converted. if settings.get('select_field', False): - sel_fld = settings['select_field'] - if isinstance(sel_fld, GeomField) and backend.select: - self.query.custom_select[model_att] = backend.select + select_field = settings['select_field'] if connection.ops.oracle: - sel_fld.empty_strings_allowed = False - self.query.extra_select_fields[model_att] = sel_fld + select_field.empty_strings_allowed = False + else: + select_field = Field() # Finally, setting the extra selection attribute with # the format string expanded with the stored procedure # arguments. - return self.extra(select={model_att: fmt % settings['procedure_args']}, - select_params=settings['select_params']) + self.query.add_annotation( + RawSQL(fmt % settings['procedure_args'], settings['select_params'], select_field), + model_att) + return self def _distance_attribute(self, func, geom=None, tolerance=0.05, spheroid=False, **kwargs): """ @@ -616,8 +604,9 @@ class GeoQuerySet(QuerySet): else: # Getting whether this field is in units of degrees since the field may have # been transformed via the `transform` GeoQuerySet method. - if self.query.transformed_srid: - u, unit_name, s = get_srid_info(self.query.transformed_srid, connection) + srid = self.query.get_context('transformed_srid') + if srid: + u, unit_name, s = get_srid_info(srid, connection) geodetic = unit_name.lower() in geo_field.geodetic_units if geodetic and not connection.features.supports_distance_geodetic: @@ -627,20 +616,20 @@ class GeoQuerySet(QuerySet): ) if distance: - if self.query.transformed_srid: + if srid: # Setting the `geom_args` flag to false because we want to handle # transformation SQL here, rather than the way done by default # (which will transform to the original SRID of the field rather # than to what was transformed to). geom_args = False - procedure_fmt = '%s(%%(geo_col)s, %s)' % (backend.transform, self.query.transformed_srid) - if geom.srid is None or geom.srid == self.query.transformed_srid: + procedure_fmt = '%s(%%(geo_col)s, %s)' % (backend.transform, srid) + if geom.srid is None or geom.srid == srid: # If the geom parameter srid is None, it is assumed the coordinates # are in the transformed units. A placeholder is used for the # geometry parameter. `GeomFromText` constructor is also needed # to wrap geom placeholder for SpatiaLite. if backend.spatialite: - procedure_fmt += ', %s(%%%%s, %s)' % (backend.from_text, self.query.transformed_srid) + procedure_fmt += ', %s(%%%%s, %s)' % (backend.from_text, srid) else: procedure_fmt += ', %%s' else: @@ -649,10 +638,11 @@ class GeoQuerySet(QuerySet): # SpatiaLite also needs geometry placeholder wrapped in `GeomFromText` # constructor. if backend.spatialite: - procedure_fmt += ', %s(%s(%%%%s, %s), %s)' % (backend.transform, backend.from_text, - geom.srid, self.query.transformed_srid) + procedure_fmt += (', %s(%s(%%%%s, %s), %s)' % ( + backend.transform, backend.from_text, + geom.srid, srid)) else: - procedure_fmt += ', %s(%%%%s, %s)' % (backend.transform, self.query.transformed_srid) + procedure_fmt += ', %s(%%%%s, %s)' % (backend.transform, srid) else: # `transform()` was not used on this GeoQuerySet. procedure_fmt = '%(geo_col)s,%(geom)s' @@ -743,22 +733,57 @@ class GeoQuerySet(QuerySet): column. Takes into account if the geographic field is in a ForeignKey relation to the current model. """ + compiler = self.query.get_compiler(self.db) opts = self.model._meta if geo_field not in opts.fields: # Is this operation going to be on a related geographic field? # If so, it'll have to be added to the select related information # (e.g., if 'location__point' was given as the field name). + # Note: the operation really is defined as "must add select related!" self.query.add_select_related([field_name]) - compiler = self.query.get_compiler(self.db) + # Call pre_sql_setup() so that compiler.select gets populated. compiler.pre_sql_setup() - for (rel_table, rel_col), field in self.query.related_select_cols: - if field == geo_field: - return compiler._field_column(geo_field, rel_table) - raise ValueError("%r not in self.query.related_select_cols" % geo_field) + for col, _, _ in compiler.select: + if col.output_field == geo_field: + return col.as_sql(compiler, compiler.connection)[0] + raise ValueError("%r not in compiler's related_select_cols" % geo_field) elif geo_field not in opts.local_fields: # This geographic field is inherited from another model, so we have to # use the db table for the _parent_ model instead. parent_model = geo_field.model._meta.concrete_model - return self.query.get_compiler(self.db)._field_column(geo_field, parent_model._meta.db_table) + return self._field_column(compiler, geo_field, parent_model._meta.db_table) else: - return self.query.get_compiler(self.db)._field_column(geo_field) + return self._field_column(compiler, geo_field) + + # Private API utilities, subject to change. + def _geo_field(self, field_name=None): + """ + Returns the first Geometry field encountered or the one specified via + the `field_name` keyword. The `field_name` may be a string specifying + the geometry field on this GeoQuerySet's model, or a lookup string + to a geometry field via a ForeignKey relation. + """ + if field_name is None: + # Incrementing until the first geographic field is found. + for field in self.model._meta.fields: + if isinstance(field, GeometryField): + return field + return False + else: + # Otherwise, check by the given field name -- which may be + # a lookup to a _related_ geographic field. + return GISLookup._check_geo_field(self.model._meta, field_name) + + def _field_column(self, compiler, field, table_alias=None, column=None): + """ + Helper function that returns the database column for the given field. + The table and column are returned (quoted) in the proper format, e.g., + `"geoapp_city"."point"`. If `table_alias` is not specified, the + database table associated with the model of this `GeoQuerySet` will be + used. If `column` is specified, it will be used instead of the value + in `field.column`. + """ + if table_alias is None: + table_alias = compiler.query.get_meta().db_table + return "%s.%s" % (compiler.quote_name_unless_alias(table_alias), + compiler.connection.ops.quote_name(column or field.column)) diff --git a/django/contrib/gis/db/models/sql/__init__.py b/django/contrib/gis/db/models/sql/__init__.py index fc965c5c67..a00e6ba38e 100644 --- a/django/contrib/gis/db/models/sql/__init__.py +++ b/django/contrib/gis/db/models/sql/__init__.py @@ -1,6 +1,5 @@ from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField, GMLField -from django.contrib.gis.db.models.sql.query import GeoQuery __all__ = [ - 'AreaField', 'DistanceField', 'GeomField', 'GMLField', 'GeoQuery', + 'AreaField', 'DistanceField', 'GeomField', 'GMLField' ] diff --git a/django/contrib/gis/db/models/sql/compiler.py b/django/contrib/gis/db/models/sql/compiler.py deleted file mode 100644 index 1501c98136..0000000000 --- a/django/contrib/gis/db/models/sql/compiler.py +++ /dev/null @@ -1,240 +0,0 @@ -from django.db.backends.utils import truncate_name -from django.db.models.sql import compiler -from django.utils import six - -SQLCompiler = compiler.SQLCompiler - - -class GeoSQLCompiler(compiler.SQLCompiler): - - def get_columns(self, with_aliases=False): - """ - Return the list of columns to use in the select statement. If no - columns have been specified, returns all columns relating to fields in - the model. - - If 'with_aliases' is true, any column names that are duplicated - (without the table names) are given unique aliases. This is needed in - some cases to avoid ambiguity with nested queries. - - This routine is overridden from Query to handle customized selection of - geometry columns. - """ - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias)) - for alias, col in six.iteritems(self.query.extra_select)] - params = [] - aliases = set(self.query.extra_select.keys()) - if with_aliases: - col_aliases = aliases.copy() - else: - col_aliases = set() - if self.query.select: - only_load = self.deferred_to_columns() - # This loop customized for GeoQuery. - for col, field in self.query.select: - if isinstance(col, (list, tuple)): - alias, column = col - table = self.query.alias_map[alias].table_name - if table in only_load and column not in only_load[table]: - continue - r = self.get_field_select(field, alias, column) - if with_aliases: - if col[1] in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) - else: - result.append('%s AS %s' % (r, qn2(col[1]))) - aliases.add(r) - col_aliases.add(col[1]) - else: - result.append(r) - aliases.add(r) - col_aliases.add(col[1]) - else: - col_sql, col_params = col.as_sql(self, self.connection) - result.append(col_sql) - params.extend(col_params) - - if hasattr(col, 'alias'): - aliases.add(col.alias) - col_aliases.add(col.alias) - - elif self.query.default_cols: - cols, new_aliases = self.get_default_columns(with_aliases, - col_aliases) - result.extend(cols) - aliases.update(new_aliases) - - max_name_length = self.connection.ops.max_name_length() - for alias, annotation in self.query.annotation_select.items(): - agg_sql, agg_params = self.compile(annotation) - if alias is None: - result.append(agg_sql) - else: - result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length)))) - params.extend(agg_params) - - # This loop customized for GeoQuery. - for (table, col), field in self.query.related_select_cols: - r = self.get_field_select(field, table, col) - if with_aliases and col in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) - else: - result.append(r) - aliases.add(r) - col_aliases.add(col) - - self._select_aliases = aliases - return result, params - - def get_default_columns(self, with_aliases=False, col_aliases=None, - start_alias=None, opts=None, as_pairs=False, from_parent=None): - """ - Computes the default columns for selecting every field in the base - model. Will sometimes be called to pull in related models (e.g. via - select_related), in which case "opts" and "start_alias" will be given - to provide a starting point for the traversal. - - Returns a list of strings, quoted appropriately for use in SQL - directly, as well as a set of aliases used in the select statement (if - 'as_pairs' is True, returns a list of (alias, col_name) pairs instead - of strings as the first component and None as the second component). - - This routine is overridden from Query to handle customized selection of - geometry columns. - """ - result = [] - if opts is None: - opts = self.query.get_meta() - aliases = set() - only_load = self.deferred_to_columns() - seen = self.query.included_inherited_models.copy() - if start_alias: - seen[None] = start_alias - for field in opts.concrete_fields: - model = field.model._meta.concrete_model - if model is opts.model: - model = None - if from_parent and model is not None and issubclass(from_parent, model): - # Avoid loading data for already loaded parents. - continue - alias = self.query.join_parent_model(opts, model, start_alias, seen) - table = self.query.alias_map[alias].table_name - if table in only_load and field.column not in only_load[table]: - continue - if as_pairs: - result.append((alias, field)) - aliases.add(alias) - continue - # This part of the function is customized for GeoQuery. We - # see if there was any custom selection specified in the - # dictionary, and set up the selection format appropriately. - field_sel = self.get_field_select(field, alias) - if with_aliases and field.column in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (field_sel, c_alias)) - col_aliases.add(c_alias) - aliases.add(c_alias) - else: - r = field_sel - result.append(r) - aliases.add(r) - if with_aliases: - col_aliases.add(field.column) - return result, aliases - - def get_converters(self, fields): - converters = super(GeoSQLCompiler, self).get_converters(fields) - for i, alias in enumerate(self.query.extra_select): - field = self.query.extra_select_fields.get(alias) - if field: - backend_converters = self.connection.ops.get_db_converters(field.get_internal_type()) - converters[i] = (backend_converters, [field.from_db_value], field) - return converters - - #### Routines unique to GeoQuery #### - def get_extra_select_format(self, alias): - sel_fmt = '%s' - if hasattr(self.query, 'custom_select') and alias in self.query.custom_select: - sel_fmt = sel_fmt % self.query.custom_select[alias] - return sel_fmt - - def get_field_select(self, field, alias=None, column=None): - """ - Returns the SELECT SQL string for the given field. Figures out - if any custom selection SQL is needed for the column The `alias` - keyword may be used to manually specify the database table where - the column exists, if not in the model associated with this - `GeoQuery`. Similarly, `column` may be used to specify the exact - column name, rather than using the `column` attribute on `field`. - """ - sel_fmt = self.get_select_format(field) - if field in self.query.custom_select: - field_sel = sel_fmt % self.query.custom_select[field] - else: - field_sel = sel_fmt % self._field_column(field, alias, column) - return field_sel - - def get_select_format(self, fld): - """ - Returns the selection format string, depending on the requirements - of the spatial backend. For example, Oracle and MySQL require custom - selection formats in order to retrieve geometries in OGC WKT. For all - other fields a simple '%s' format string is returned. - """ - if self.connection.ops.select and hasattr(fld, 'geom_type'): - # This allows operations to be done on fields in the SELECT, - # overriding their values -- used by the Oracle and MySQL - # spatial backends to get database values as WKT, and by the - # `transform` method. - sel_fmt = self.connection.ops.select - - # Because WKT doesn't contain spatial reference information, - # the SRID is prefixed to the returned WKT to ensure that the - # transformed geometries have an SRID different than that of the - # field -- this is only used by `transform` for Oracle and - # SpatiaLite backends. - if self.query.transformed_srid and (self.connection.ops.oracle or - self.connection.ops.spatialite): - sel_fmt = "'SRID=%d;'||%s" % (self.query.transformed_srid, sel_fmt) - else: - sel_fmt = '%s' - return sel_fmt - - # Private API utilities, subject to change. - def _field_column(self, field, table_alias=None, column=None): - """ - Helper function that returns the database column for the given field. - The table and column are returned (quoted) in the proper format, e.g., - `"geoapp_city"."point"`. If `table_alias` is not specified, the - database table associated with the model of this `GeoQuery` will be - used. If `column` is specified, it will be used instead of the value - in `field.column`. - """ - if table_alias is None: - table_alias = self.query.get_meta().db_table - return "%s.%s" % (self.quote_name_unless_alias(table_alias), - self.connection.ops.quote_name(column or field.column)) - - -class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler): - pass - - -class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler): - pass - - -class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler): - pass - - -class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler): - pass diff --git a/django/contrib/gis/db/models/sql/conversion.py b/django/contrib/gis/db/models/sql/conversion.py index 98112b3285..19e7b5bfbd 100644 --- a/django/contrib/gis/db/models/sql/conversion.py +++ b/django/contrib/gis/db/models/sql/conversion.py @@ -3,6 +3,7 @@ This module holds simple classes to convert geospatial values from the database. """ +from django.contrib.gis.db.models.fields import GeoSelectFormatMixin from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import Area, Distance @@ -10,13 +11,19 @@ from django.contrib.gis.measure import Area, Distance class BaseField(object): empty_strings_allowed = True + def get_db_converters(self, connection): + return [self.from_db_value] + + def select_format(self, compiler, sql, params): + return sql, params + class AreaField(BaseField): "Wrapper for Area values." def __init__(self, area_att): self.area_att = area_att - def from_db_value(self, value, connection): + def from_db_value(self, value, connection, context): if value is not None: value = Area(**{self.area_att: value}) return value @@ -30,7 +37,7 @@ class DistanceField(BaseField): def __init__(self, distance_att): self.distance_att = distance_att - def from_db_value(self, value, connection): + def from_db_value(self, value, connection, context): if value is not None: value = Distance(**{self.distance_att: value}) return value @@ -39,12 +46,15 @@ class DistanceField(BaseField): return 'DistanceField' -class GeomField(BaseField): +class GeomField(GeoSelectFormatMixin, BaseField): """ Wrapper for Geometry values. It is a lightweight alternative to using GeometryField (which requires an SQL query upon instantiation). """ - def from_db_value(self, value, connection): + # Hacky marker for get_db_converters() + geom_type = None + + def from_db_value(self, value, connection, context): if value is not None: value = Geometry(value) return value @@ -61,5 +71,5 @@ class GMLField(BaseField): def get_internal_type(self): return 'GMLField' - def from_db_value(self, value, connection): + def from_db_value(self, value, connection, context): return value diff --git a/django/contrib/gis/db/models/sql/query.py b/django/contrib/gis/db/models/sql/query.py deleted file mode 100644 index 9c071b85fc..0000000000 --- a/django/contrib/gis/db/models/sql/query.py +++ /dev/null @@ -1,65 +0,0 @@ -from django.db import connections -from django.db.models.query import sql -from django.db.models.sql.constants import QUERY_TERMS - -from django.contrib.gis.db.models.fields import GeometryField -from django.contrib.gis.db.models.lookups import GISLookup -from django.contrib.gis.db.models import aggregates as gis_aggregates -from django.contrib.gis.db.models.sql.conversion import GeomField - - -class GeoQuery(sql.Query): - """ - A single spatial SQL query. - """ - # Overriding the valid query terms. - query_terms = QUERY_TERMS | set(GeometryField.class_lookups.keys()) - - compiler = 'GeoSQLCompiler' - - #### Methods overridden from the base Query class #### - def __init__(self, model): - super(GeoQuery, self).__init__(model) - # The following attributes are customized for the GeoQuerySet. - # The SpatialBackend classes contain backend-specific routines and functions. - self.custom_select = {} - self.transformed_srid = None - self.extra_select_fields = {} - - def clone(self, *args, **kwargs): - obj = super(GeoQuery, self).clone(*args, **kwargs) - # Customized selection dictionary and transformed srid flag have - # to also be added to obj. - obj.custom_select = self.custom_select.copy() - obj.transformed_srid = self.transformed_srid - obj.extra_select_fields = self.extra_select_fields.copy() - return obj - - def get_aggregation(self, using, force_subq=False): - # Remove any aggregates marked for reduction from the subquery - # and move them to the outer AggregateQuery. - connection = connections[using] - for alias, annotation in self.annotation_select.items(): - if isinstance(annotation, gis_aggregates.GeoAggregate): - if not getattr(annotation, 'is_extent', False) or connection.ops.oracle: - self.extra_select_fields[alias] = GeomField() - return super(GeoQuery, self).get_aggregation(using, force_subq) - - # Private API utilities, subject to change. - def _geo_field(self, field_name=None): - """ - Returns the first Geometry field encountered; or specified via the - `field_name` keyword. The `field_name` may be a string specifying - the geometry field on this GeoQuery's model, or a lookup string - to a geometry field via a ForeignKey relation. - """ - if field_name is None: - # Incrementing until the first geographic field is found. - for fld in self.model._meta.fields: - if isinstance(fld, GeometryField): - return fld - return False - else: - # Otherwise, check by the given field name -- which may be - # a lookup to a _related_ geographic field. - return GISLookup._check_geo_field(self.model._meta, field_name) diff --git a/django/contrib/gis/tests/relatedapp/tests.py b/django/contrib/gis/tests/relatedapp/tests.py index a19e317a14..5beab038fd 100644 --- a/django/contrib/gis/tests/relatedapp/tests.py +++ b/django/contrib/gis/tests/relatedapp/tests.py @@ -231,15 +231,6 @@ class RelatedGeoModelTest(TestCase): self.assertIn('Aurora', names) self.assertIn('Kecksburg', names) - def test11_geoquery_pickle(self): - "Ensuring GeoQuery objects are unpickled correctly. See #10839." - import pickle - from django.contrib.gis.db.models.sql import GeoQuery - qs = City.objects.all() - q_str = pickle.dumps(qs.query) - q = pickle.loads(q_str) - self.assertEqual(GeoQuery, q.__class__) - # TODO: fix on Oracle -- get the following error because the SQL is ordered # by a geometry object, which Oracle apparently doesn't like: # ORA-22901: cannot compare nested table or VARRAY or LOB attributes of an object type diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 3c99819218..6c675cb482 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -1262,7 +1262,7 @@ class BaseDatabaseOperations(object): second = timezone.make_aware(second, tz) return [first, second] - def get_db_converters(self, internal_type): + def get_db_converters(self, expression): """Get a list of functions needed to convert field data. Some field types on some backends do not provide data in the correct @@ -1270,7 +1270,7 @@ class BaseDatabaseOperations(object): """ return [] - def convert_durationfield_value(self, value, field): + def convert_durationfield_value(self, value, expression, context): if value is not None: value = str(decimal.Decimal(value) / decimal.Decimal(1000000)) value = parse_duration(value) diff --git a/django/db/backends/mysql/base.py b/django/db/backends/mysql/base.py index 3b258d35f3..d4117ec724 100644 --- a/django/db/backends/mysql/base.py +++ b/django/db/backends/mysql/base.py @@ -302,7 +302,7 @@ class DatabaseOperations(BaseDatabaseOperations): columns. If no ordering would otherwise be applied, we don't want any implicit sorting going on. """ - return ["NULL"] + return [(None, ("NULL", [], 'asc', False))] def fulltext_search_sql(self, field_name): return 'MATCH (%s) AGAINST (%%s IN BOOLEAN MODE)' % field_name @@ -387,8 +387,9 @@ class DatabaseOperations(BaseDatabaseOperations): return 'POW(%s)' % ','.join(sub_expressions) return super(DatabaseOperations, self).combine_expression(connector, sub_expressions) - def get_db_converters(self, internal_type): - converters = super(DatabaseOperations, self).get_db_converters(internal_type) + def get_db_converters(self, expression): + converters = super(DatabaseOperations, self).get_db_converters(expression) + internal_type = expression.output_field.get_internal_type() if internal_type in ['BooleanField', 'NullBooleanField']: converters.append(self.convert_booleanfield_value) if internal_type == 'UUIDField': @@ -397,17 +398,17 @@ class DatabaseOperations(BaseDatabaseOperations): converters.append(self.convert_textfield_value) return converters - def convert_booleanfield_value(self, value, field): + def convert_booleanfield_value(self, value, expression, context): if value in (0, 1): value = bool(value) return value - def convert_uuidfield_value(self, value, field): + def convert_uuidfield_value(self, value, expression, context): if value is not None: value = uuid.UUID(value) return value - def convert_textfield_value(self, value, field): + def convert_textfield_value(self, value, expression, context): if value is not None: value = force_text(value) return value diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index ea08ff466c..c684ddd842 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -268,8 +268,9 @@ WHEN (new.%(col_name)s IS NULL) sql = field_name # Cast to DATE removes sub-second precision. return sql, [] - def get_db_converters(self, internal_type): - converters = super(DatabaseOperations, self).get_db_converters(internal_type) + def get_db_converters(self, expression): + converters = super(DatabaseOperations, self).get_db_converters(expression) + internal_type = expression.output_field.get_internal_type() if internal_type == 'TextField': converters.append(self.convert_textfield_value) elif internal_type == 'BinaryField': @@ -285,28 +286,29 @@ WHEN (new.%(col_name)s IS NULL) converters.append(self.convert_empty_values) return converters - def convert_empty_values(self, value, field): + def convert_empty_values(self, value, expression, context): # Oracle stores empty strings as null. We need to undo this in # order to adhere to the Django convention of using the empty # string instead of null, but only if the field accepts the # empty string. + field = expression.output_field if value is None and field.empty_strings_allowed: value = '' if field.get_internal_type() == 'BinaryField': value = b'' return value - def convert_textfield_value(self, value, field): + def convert_textfield_value(self, value, expression, context): if isinstance(value, Database.LOB): value = force_text(value.read()) return value - def convert_binaryfield_value(self, value, field): + def convert_binaryfield_value(self, value, expression, context): if isinstance(value, Database.LOB): value = force_bytes(value.read()) return value - def convert_booleanfield_value(self, value, field): + def convert_booleanfield_value(self, value, expression, context): if value in (1, 0): value = bool(value) return value @@ -314,16 +316,16 @@ WHEN (new.%(col_name)s IS NULL) # cx_Oracle always returns datetime.datetime objects for # DATE and TIMESTAMP columns, but Django wants to see a # python datetime.date, .time, or .datetime. - def convert_datefield_value(self, value, field): + def convert_datefield_value(self, value, expression, context): if isinstance(value, Database.Timestamp): return value.date() - def convert_timefield_value(self, value, field): + def convert_timefield_value(self, value, expression, context): if isinstance(value, Database.Timestamp): value = value.time() return value - def convert_uuidfield_value(self, value, field): + def convert_uuidfield_value(self, value, expression, context): if value is not None: value = uuid.UUID(value) return value diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index 4e0cf0c9aa..801351336d 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -269,8 +269,9 @@ class DatabaseOperations(BaseDatabaseOperations): return six.text_type(value) - def get_db_converters(self, internal_type): - converters = super(DatabaseOperations, self).get_db_converters(internal_type) + def get_db_converters(self, expression): + converters = super(DatabaseOperations, self).get_db_converters(expression) + internal_type = expression.output_field.get_internal_type() if internal_type == 'DateTimeField': converters.append(self.convert_datetimefield_value) elif internal_type == 'DateField': @@ -283,25 +284,25 @@ class DatabaseOperations(BaseDatabaseOperations): converters.append(self.convert_uuidfield_value) return converters - def convert_decimalfield_value(self, value, field): - return backend_utils.typecast_decimal(field.format_number(value)) + def convert_decimalfield_value(self, value, expression, context): + return backend_utils.typecast_decimal(expression.output_field.format_number(value)) - def convert_datefield_value(self, value, field): + def convert_datefield_value(self, value, expression, context): if value is not None and not isinstance(value, datetime.date): value = parse_date(value) return value - def convert_datetimefield_value(self, value, field): + def convert_datetimefield_value(self, value, expression, context): if value is not None and not isinstance(value, datetime.datetime): value = parse_datetime_with_timezone_support(value) return value - def convert_timefield_value(self, value, field): + def convert_timefield_value(self, value, expression, context): if value is not None and not isinstance(value, datetime.time): value = parse_time(value) return value - def convert_uuidfield_value(self, value, field): + def convert_uuidfield_value(self, value, expression, context): if value is not None: value = uuid.UUID(value) return value diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index c68378c7da..4b57890cf8 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -87,7 +87,7 @@ class Avg(Aggregate): def __init__(self, expression, **extra): super(Avg, self).__init__(expression, output_field=FloatField(), **extra) - def convert_value(self, value, connection): + def convert_value(self, value, connection, context): if value is None: return value return float(value) @@ -105,7 +105,7 @@ class Count(Aggregate): super(Count, self).__init__( expression, distinct='DISTINCT ' if distinct else '', output_field=IntegerField(), **extra) - def convert_value(self, value, connection): + def convert_value(self, value, connection, context): if value is None: return 0 return int(value) @@ -128,7 +128,7 @@ class StdDev(Aggregate): self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP' super(StdDev, self).__init__(expression, output_field=FloatField(), **extra) - def convert_value(self, value, connection): + def convert_value(self, value, connection, context): if value is None: return value return float(value) @@ -146,7 +146,7 @@ class Variance(Aggregate): self.function = 'VAR_SAMP' if sample else 'VAR_POP' super(Variance, self).__init__(expression, output_field=FloatField(), **extra) - def convert_value(self, value, connection): + def convert_value(self, value, connection, context): if value is None: return value return float(value) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 9fe3a35fee..969b8d3acd 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -127,7 +127,7 @@ class ExpressionNode(CombinableMixin): is_summary = False def get_db_converters(self, connection): - return [self.convert_value] + return [self.convert_value] + self.output_field.get_db_converters(connection) def __init__(self, output_field=None): self._output_field = output_field @@ -240,7 +240,7 @@ class ExpressionNode(CombinableMixin): raise FieldError( "Expression contains mixed types. You must set output_field") - def convert_value(self, value, connection): + def convert_value(self, value, connection, context): """ Expressions provide their own converters because users have the option of manually specifying the output_field which may be a different type @@ -305,6 +305,8 @@ class ExpressionNode(CombinableMixin): return self def get_group_by_cols(self): + if not self.contains_aggregate: + return [self] cols = [] for source in self.get_source_expressions(): cols.extend(source.get_group_by_cols()) @@ -490,6 +492,9 @@ class Value(ExpressionNode): return 'NULL', [] return '%s', [self.value] + def get_group_by_cols(self): + return [] + class DurationValue(Value): def as_sql(self, compiler, connection): @@ -499,6 +504,37 @@ class DurationValue(Value): return connection.ops.date_interval_sql(self.value) +class RawSQL(ExpressionNode): + def __init__(self, sql, params, output_field=None): + if output_field is None: + output_field = fields.Field() + self.sql, self.params = sql, params + super(RawSQL, self).__init__(output_field=output_field) + + def as_sql(self, compiler, connection): + return '(%s)' % self.sql, self.params + + def get_group_by_cols(self): + return [self] + + +class Random(ExpressionNode): + def __init__(self): + super(Random, self).__init__(output_field=fields.FloatField()) + + def as_sql(self, compiler, connection): + return connection.ops.random_function_sql(), [] + + +class ColIndexRef(ExpressionNode): + def __init__(self, idx): + self.idx = idx + super(ColIndexRef, self).__init__() + + def as_sql(self, compiler, connection): + return str(self.idx), [] + + class Col(ExpressionNode): def __init__(self, alias, target, source=None): if source is None: @@ -516,6 +552,9 @@ class Col(ExpressionNode): def get_group_by_cols(self): return [self] + def get_db_converters(self, connection): + return self.output_field.get_db_converters(connection) + class Ref(ExpressionNode): """ @@ -537,7 +576,7 @@ class Ref(ExpressionNode): return self def as_sql(self, compiler, connection): - return "%s" % compiler.quote_name_unless_alias(self.refs), [] + return "%s" % connection.ops.quote_name(self.refs), [] def get_group_by_cols(self): return [self] @@ -581,7 +620,7 @@ class Date(ExpressionNode): copy.lookup_type = self.lookup_type return copy - def convert_value(self, value, connection): + def convert_value(self, value, connection, context): if isinstance(value, datetime.datetime): value = value.date() return value @@ -629,7 +668,7 @@ class DateTime(ExpressionNode): copy.tzname = self.tzname return copy - def convert_value(self, value, connection): + def convert_value(self, value, connection, context): if settings.USE_TZ: if value is None: raise ValueError( diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 5d34137bf2..adf975cfa0 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -333,6 +333,28 @@ class Field(RegisterLookupMixin): ] return [] + def get_col(self, alias, source=None): + if source is None: + source = self + if alias != self.model._meta.db_table or source != self: + from django.db.models.expressions import Col + return Col(alias, self, source) + else: + return self.cached_col + + @cached_property + def cached_col(self): + from django.db.models.expressions import Col + return Col(self.model._meta.db_table, self) + + def select_format(self, compiler, sql, params): + """ + Custom format for select clauses. For example, GIS columns need to be + selected as AsText(table.col) on MySQL as the table.col data can't be used + by Django. + """ + return sql, params + def deconstruct(self): """ Returns enough information to recreate the field as a 4-tuple: diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 9ef6c1350a..e6ecd98537 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -15,7 +15,6 @@ from django.db.models.fields import (AutoField, Field, IntegerField, from django.db.models.lookups import IsNull from django.db.models.query import QuerySet from django.db.models.query_utils import PathInfo -from django.db.models.expressions import Col from django.utils.encoding import force_text, smart_text from django.utils import six from django.utils.deprecation import RemovedInDjango20Warning @@ -1738,26 +1737,26 @@ class ForeignObject(RelatedField): [source.name for source in sources], raw_value), AND) elif lookup_type == 'isnull': - root_constraint.add(IsNull(Col(alias, targets[0], sources[0]), raw_value), AND) + root_constraint.add(IsNull(targets[0].get_col(alias, sources[0]), raw_value), AND) elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte'] and not is_multicolumn)): value = get_normalized_value(raw_value) for target, source, val in zip(targets, sources, value): lookup_class = target.get_lookup(lookup_type) root_constraint.add( - lookup_class(Col(alias, target, source), val), AND) + lookup_class(target.get_col(alias, source), val), AND) elif lookup_type in ['range', 'in'] and not is_multicolumn: values = [get_normalized_value(value) for value in raw_value] value = [val[0] for val in values] lookup_class = targets[0].get_lookup(lookup_type) - root_constraint.add(lookup_class(Col(alias, targets[0], sources[0]), value), AND) + root_constraint.add(lookup_class(targets[0].get_col(alias, sources[0]), value), AND) elif lookup_type == 'in': values = [get_normalized_value(value) for value in raw_value] for value in values: value_constraint = constraint_class() for source, target, val in zip(sources, targets, value): lookup_class = target.get_lookup('exact') - lookup = lookup_class(Col(alias, target, source), val) + lookup = lookup_class(target.get_col(alias, source), val) value_constraint.add(lookup, AND) root_constraint.add(value_constraint, OR) else: diff --git a/django/db/models/query.py b/django/db/models/query.py index 7a447f0313..6774feef0c 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -13,8 +13,7 @@ from django.db import (connections, router, transaction, IntegrityError, DJANGO_VERSION_PICKLE_KEY) from django.db.models.constants import LOOKUP_SEP from django.db.models.fields import AutoField, Empty -from django.db.models.query_utils import (Q, select_related_descend, - deferred_class_factory, InvalidQuery) +from django.db.models.query_utils import Q, deferred_class_factory, InvalidQuery from django.db.models.deletion import Collector from django.db.models.sql.constants import CURSOR from django.db.models import sql @@ -233,76 +232,34 @@ class QuerySet(object): An iterator over the results from applying this QuerySet to the database. """ - fill_cache = False - if connections[self.db].features.supports_select_related: - fill_cache = self.query.select_related - if isinstance(fill_cache, dict): - requested = fill_cache - else: - requested = None - max_depth = self.query.max_depth - - extra_select = list(self.query.extra_select) - annotation_select = list(self.query.annotation_select) - - only_load = self.query.get_loaded_field_names() - fields = self.model._meta.concrete_fields - - load_fields = [] - # If only/defer clauses have been specified, - # build the list of fields that are to be loaded. - if only_load: - for field in self.model._meta.concrete_fields: - model = field.model._meta.model - try: - if field.name in only_load[model]: - # Add a field that has been explicitly included - load_fields.append(field.name) - except KeyError: - # Model wasn't explicitly listed in the only_load table - # Therefore, we need to load all fields from this model - load_fields.append(field.name) - - skip = None - if load_fields: - # Some fields have been deferred, so we have to initialize - # via keyword arguments. - skip = set() - init_list = [] - for field in fields: - if field.name not in load_fields: - skip.add(field.attname) - else: - init_list.append(field.attname) - model_cls = deferred_class_factory(self.model, skip) - else: - model_cls = self.model - init_list = [f.attname for f in fields] - - # Cache db and model outside the loop db = self.db compiler = self.query.get_compiler(using=db) - index_start = len(extra_select) - annotation_start = index_start + len(init_list) - - if fill_cache: - klass_info = get_klass_info(model_cls, max_depth=max_depth, - requested=requested, only_load=only_load) - for row in compiler.results_iter(): - if fill_cache: - obj, _ = get_cached_row(row, index_start, db, klass_info, - offset=len(annotation_select)) - else: - obj = model_cls.from_db(db, init_list, row[index_start:annotation_start]) - - if extra_select: - for i, k in enumerate(extra_select): - setattr(obj, k, row[i]) - - # Add the annotations to the model - if annotation_select: - for i, annotation in enumerate(annotation_select): - setattr(obj, annotation, row[i + annotation_start]) + # Execute the query. This will also fill compiler.select, klass_info, + # and annotations. + results = compiler.execute_sql() + select, klass_info, annotation_col_map = (compiler.select, compiler.klass_info, + compiler.annotation_col_map) + if klass_info is None: + return + model_cls = klass_info['model'] + select_fields = klass_info['select_fields'] + model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 + init_list = [f[0].output_field.attname + for f in select[model_fields_start:model_fields_end]] + if len(init_list) != len(model_cls._meta.concrete_fields): + init_set = set(init_list) + skip = [f.attname for f in model_cls._meta.concrete_fields + if f.attname not in init_set] + model_cls = deferred_class_factory(model_cls, skip) + related_populators = get_related_populators(klass_info, select, db) + for row in compiler.results_iter(results): + obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end]) + if related_populators: + for rel_populator in related_populators: + rel_populator.populate(row, obj) + if annotation_col_map: + for attr_name, col_pos in annotation_col_map.items(): + setattr(obj, attr_name, row[col_pos]) # Add the known related objects to the model, if there are any if self._known_related_objects: @@ -1032,11 +989,8 @@ class QuerySet(object): """ Prepare the query for computing a result that contains aggregate annotations. """ - opts = self.model._meta if self.query.group_by is None: - field_names = [f.attname for f in opts.concrete_fields] - self.query.add_fields(field_names, False) - self.query.set_group_by() + self.query.group_by = True def _prepare(self): return self @@ -1135,9 +1089,11 @@ class ValuesQuerySet(QuerySet): Called by the _clone() method after initializing the rest of the instance. """ + if self.query.group_by is True: + self.query.add_fields([f.attname for f in self.model._meta.concrete_fields], False) + self.query.set_group_by() self.query.clear_deferred_loading() self.query.clear_select_fields() - if self._fields: self.extra_names = [] self.annotation_names = [] @@ -1246,11 +1202,12 @@ class ValuesQuerySet(QuerySet): class ValuesListQuerySet(ValuesQuerySet): def iterator(self): + compiler = self.query.get_compiler(self.db) if self.flat and len(self._fields) == 1: - for row in self.query.get_compiler(self.db).results_iter(): + for row in compiler.results_iter(): yield row[0] elif not self.query.extra_select and not self.query.annotation_select: - for row in self.query.get_compiler(self.db).results_iter(): + for row in compiler.results_iter(): yield tuple(row) else: # When extra(select=...) or an annotation is involved, the extra @@ -1269,7 +1226,7 @@ class ValuesListQuerySet(ValuesQuerySet): else: fields = names - for row in self.query.get_compiler(self.db).results_iter(): + for row in compiler.results_iter(): data = dict(zip(names, row)) yield tuple(data[f] for f in fields) @@ -1281,244 +1238,6 @@ class ValuesListQuerySet(ValuesQuerySet): return clone -def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, - only_load=None, from_parent=None): - """ - Helper function that recursively returns an information for a klass, to be - used in get_cached_row. It exists just to compute this information only - once for entire queryset. Otherwise it would be computed for each row, which - leads to poor performance on large querysets. - - Arguments: - * klass - the class to retrieve (and instantiate) - * max_depth - the maximum depth to which a select_related() - relationship should be explored. - * cur_depth - the current depth in the select_related() tree. - Used in recursive calls to determine if we should dig deeper. - * requested - A dictionary describing the select_related() tree - that is to be retrieved. keys are field names; values are - dictionaries describing the keys on that related object that - are themselves to be select_related(). - * only_load - if the query has had only() or defer() applied, - this is the list of field names that will be returned. If None, - the full field list for `klass` can be assumed. - * from_parent - the parent model used to get to this model - - Note that when travelling from parent to child, we will only load child - fields which aren't in the parent. - """ - if max_depth and requested is None and cur_depth > max_depth: - # We've recursed deeply enough; stop now. - return None - - if only_load: - load_fields = only_load.get(klass) or set() - # When we create the object, we will also be creating populating - # all the parent classes, so traverse the parent classes looking - # for fields that must be included on load. - for parent in klass._meta.get_parent_list(): - fields = only_load.get(parent) - if fields: - load_fields.update(fields) - else: - load_fields = None - - if load_fields: - # Handle deferred fields. - skip = set() - init_list = [] - # Build the list of fields that *haven't* been requested - for field in klass._meta.concrete_fields: - model = field.model._meta.concrete_model - if from_parent and model and issubclass(from_parent, model): - # Avoid loading fields already loaded for parent model for - # child models. - continue - elif field.name not in load_fields: - skip.add(field.attname) - else: - init_list.append(field.attname) - # Retrieve all the requested fields - field_count = len(init_list) - if skip: - klass = deferred_class_factory(klass, skip) - field_names = init_list - else: - field_names = () - else: - # Load all fields on klass - - field_count = len(klass._meta.concrete_fields) - # Check if we need to skip some parent fields. - if from_parent and len(klass._meta.local_concrete_fields) != len(klass._meta.concrete_fields): - # Only load those fields which haven't been already loaded into - # 'from_parent'. - non_seen_models = [p for p in klass._meta.get_parent_list() - if not issubclass(from_parent, p)] - # Load local fields, too... - non_seen_models.append(klass) - field_names = [f.attname for f in klass._meta.concrete_fields - if f.model in non_seen_models] - field_count = len(field_names) - # Try to avoid populating field_names variable for performance reasons. - # If field_names variable is set, we use **kwargs based model init - # which is slower than normal init. - if field_count == len(klass._meta.concrete_fields): - field_names = () - - restricted = requested is not None - - related_fields = [] - for f in klass._meta.fields: - if select_related_descend(f, restricted, requested, load_fields): - if restricted: - next = requested[f.name] - else: - next = None - klass_info = get_klass_info(f.rel.to, max_depth=max_depth, cur_depth=cur_depth + 1, - requested=next, only_load=only_load) - related_fields.append((f, klass_info)) - - reverse_related_fields = [] - if restricted: - for o in klass._meta.related_objects: - if o.field.unique and select_related_descend(o.field, restricted, requested, - only_load.get(o.related_model), reverse=True): - next = requested[o.field.related_query_name()] - parent = klass if issubclass(o.related_model, klass) else None - klass_info = get_klass_info(o.related_model, max_depth=max_depth, cur_depth=cur_depth + 1, - requested=next, only_load=only_load, from_parent=parent) - reverse_related_fields.append((o.field, klass_info)) - if field_names: - pk_idx = field_names.index(klass._meta.pk.attname) - else: - meta = klass._meta - pk_idx = meta.concrete_fields.index(meta.pk) - - return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx - - -def reorder_for_init(model, field_names, values): - """ - Reorders given field names and values for those fields - to be in the same order as model.__init__() expects to find them. - """ - new_names, new_values = [], [] - for f in model._meta.concrete_fields: - if f.attname not in field_names: - continue - new_names.append(f.attname) - new_values.append(values[field_names.index(f.attname)]) - assert len(new_names) == len(field_names) - return new_names, new_values - - -def get_cached_row(row, index_start, using, klass_info, offset=0, - parent_data=()): - """ - Helper function that recursively returns an object with the specified - related attributes already populated. - - This method may be called recursively to populate deep select_related() - clauses. - - Arguments: - * row - the row of data returned by the database cursor - * index_start - the index of the row at which data for this - object is known to start - * offset - the number of additional fields that are known to - exist in row for `klass`. This usually means the number of - annotated results on `klass`. - * using - the database alias on which the query is being executed. - * klass_info - result of the get_klass_info function - * parent_data - parent model data in format (field, value). Used - to populate the non-local fields of child models. - """ - if klass_info is None: - return None - klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info - - fields = row[index_start:index_start + field_count] - # If the pk column is None (or the equivalent '' in the case the - # connection interprets empty strings as nulls), then the related - # object must be non-existent - set the relation to None. - if (fields[pk_idx] is None or - (connections[using].features.interprets_empty_strings_as_nulls and - fields[pk_idx] == '')): - obj = None - elif field_names: - values = list(fields) - parent_values = [] - parent_field_names = [] - for rel_field, value in parent_data: - parent_field_names.append(rel_field.attname) - parent_values.append(value) - field_names, values = reorder_for_init( - klass, parent_field_names + field_names, - parent_values + values) - obj = klass.from_db(using, field_names, values) - else: - field_names = [f.attname for f in klass._meta.concrete_fields] - obj = klass.from_db(using, field_names, fields) - # Instantiate related fields - index_end = index_start + field_count + offset - # Iterate over each related object, populating any - # select_related() fields - for f, klass_info in related_fields: - # Recursively retrieve the data for the related object - cached_row = get_cached_row(row, index_end, using, klass_info) - # If the recursive descent found an object, populate the - # descriptor caches relevant to the object - if cached_row: - rel_obj, index_end = cached_row - if obj is not None: - # If the base object exists, populate the - # descriptor cache - setattr(obj, f.get_cache_name(), rel_obj) - if f.unique and rel_obj is not None: - # If the field is unique, populate the - # reverse descriptor cache on the related object - setattr(rel_obj, f.rel.get_cache_name(), obj) - - # Now do the same, but for reverse related objects. - # Only handle the restricted case - i.e., don't do a depth - # descent into reverse relations unless explicitly requested - for f, klass_info in reverse_related_fields: - # Transfer data from this object to childs. - parent_data = [] - for rel_field in klass_info[0]._meta.fields: - rel_model = rel_field.model._meta.concrete_model - if rel_model == klass_info[0]._meta.model: - rel_model = None - if rel_model is not None and isinstance(obj, rel_model): - parent_data.append((rel_field, getattr(obj, rel_field.attname))) - # Recursively retrieve the data for the related object - cached_row = get_cached_row(row, index_end, using, klass_info, - parent_data=parent_data) - # If the recursive descent found an object, populate the - # descriptor caches relevant to the object - if cached_row: - rel_obj, index_end = cached_row - if obj is not None: - # populate the reverse descriptor cache - setattr(obj, f.rel.get_cache_name(), rel_obj) - if rel_obj is not None: - # If the related object exists, populate - # the descriptor cache. - setattr(rel_obj, f.get_cache_name(), obj) - # Populate related object caches using parent data. - for rel_field, _ in parent_data: - if rel_field.rel: - setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) - try: - cached_obj = getattr(obj, rel_field.get_cache_name()) - setattr(rel_obj, rel_field.get_cache_name(), cached_obj) - except AttributeError: - # Related object hasn't been cached yet - pass - return obj, index_end - - class RawQuerySet(object): """ Provides an iterator which converts the results of raw SQL queries into @@ -1569,7 +1288,9 @@ class RawQuerySet(object): else: model_cls = self.model fields = [self.model_fields.get(c, None) for c in self.columns] - converters = compiler.get_converters(fields) + converters = compiler.get_converters([ + f.get_col(f.model._meta.db_table) if f else None for f in fields + ]) for values in query: if converters: values = compiler.apply_converters(values, converters) @@ -1920,3 +1641,120 @@ def prefetch_one_level(instances, prefetcher, lookup, level): qs._prefetch_done = True obj._prefetched_objects_cache[cache_name] = qs return all_related_objects, additional_lookups + + +class RelatedPopulator(object): + """ + RelatedPopulator is used for select_related() object instantiation. + + The idea is that each select_related() model will be populated by a + different RelatedPopulator instance. The RelatedPopulator instances get + klass_info and select (computed in SQLCompiler) plus the used db as + input for initialization. That data is used to compute which columns + to use, how to instantiate the model, and how to populate the links + between the objects. + + The actual creation of the objects is done in populate() method. This + method gets row and from_obj as input and populates the select_related() + model instance. + """ + def __init__(self, klass_info, select, db): + self.db = db + # Pre-compute needed attributes. The attributes are: + # - model_cls: the possibly deferred model class to instantiate + # - either: + # - cols_start, cols_end: usually the columns in the row are + # in the same order model_cls.__init__ expects them, so we + # can instantiate by model_cls(*row[cols_start:cols_end]) + # - reorder_for_init: When select_related descends to a child + # class, then we want to reuse the already selected parent + # data. However, in this case the parent data isn't necessarily + # in the same order that Model.__init__ expects it to be, so + # we have to reorder the parent data. The reorder_for_init + # attribute contains a function used to reorder the field data + # in the order __init__ expects it. + # - pk_idx: the index of the primary key field in the reordered + # model data. Used to check if a related object exists at all. + # - init_list: the field attnames fetched from the database. For + # deferred models this isn't the same as all attnames of the + # model's fields. + # - related_populators: a list of RelatedPopulator instances if + # select_related() descends to related models from this model. + # - cache_name, reverse_cache_name: the names to use for setattr + # when assigning the fetched object to the from_obj. If the + # reverse_cache_name is set, then we also set the reverse link. + select_fields = klass_info['select_fields'] + from_parent = klass_info['from_parent'] + if not from_parent: + self.cols_start = select_fields[0] + self.cols_end = select_fields[-1] + 1 + self.init_list = [ + f[0].output_field.attname for f in select[self.cols_start:self.cols_end] + ] + self.reorder_for_init = None + else: + model_init_attnames = [ + f.attname for f in klass_info['model']._meta.concrete_fields + ] + reorder_map = [] + for idx in select_fields: + field = select[idx][0].output_field + init_pos = model_init_attnames.index(field.attname) + reorder_map.append((init_pos, field.attname, idx)) + reorder_map.sort() + self.init_list = [v[1] for v in reorder_map] + pos_list = [row_pos for _, _, row_pos in reorder_map] + + def reorder_for_init(row): + return [row[row_pos] for row_pos in pos_list] + self.reorder_for_init = reorder_for_init + + self.model_cls = self.get_deferred_cls(klass_info, self.init_list) + self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) + self.related_populators = get_related_populators(klass_info, select, self.db) + field = klass_info['field'] + reverse = klass_info['reverse'] + self.reverse_cache_name = None + if reverse: + self.cache_name = field.rel.get_cache_name() + self.reverse_cache_name = field.get_cache_name() + else: + self.cache_name = field.get_cache_name() + if field.unique: + self.reverse_cache_name = field.rel.get_cache_name() + + def get_deferred_cls(self, klass_info, init_list): + model_cls = klass_info['model'] + if len(init_list) != len(model_cls._meta.concrete_fields): + init_set = set(init_list) + skip = [ + f.attname for f in model_cls._meta.concrete_fields + if f.attname not in init_set + ] + model_cls = deferred_class_factory(model_cls, skip) + return model_cls + + def populate(self, row, from_obj): + if self.reorder_for_init: + obj_data = self.reorder_for_init(row) + else: + obj_data = row[self.cols_start:self.cols_end] + if obj_data[self.pk_idx] is None: + obj = None + else: + obj = self.model_cls.from_db(self.db, self.init_list, obj_data) + if obj and self.related_populators: + for rel_iter in self.related_populators: + rel_iter.populate(row, obj) + setattr(from_obj, self.cache_name, obj) + if obj and self.reverse_cache_name: + setattr(obj, self.reverse_cache_name, from_obj) + + +def get_related_populators(klass_info, select, db): + iterators = [] + related_klass_infos = klass_info.get('related_klass_infos', []) + for rel_klass_info in related_klass_infos: + rel_cls = RelatedPopulator(rel_klass_info, select, db) + iterators.append(rel_cls) + return iterators diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index e8b6cfb8c1..3eae98ee65 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -170,7 +170,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa if not restricted and field.null: return False if load_fields: - if field.name not in load_fields: + if field.attname not in load_fields: if restricted and field.name in requested: raise InvalidQuery("Field %s.%s cannot be both deferred" " and traversed using select_related" diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 1c0b99e897..7d9b3ee609 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -2,16 +2,15 @@ from itertools import chain import warnings from django.core.exceptions import FieldError -from django.db.backends.utils import truncate_name from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import RawSQL, Ref, Random, ColIndexRef from django.db.models.query_utils import select_related_descend, QueryWrapper from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS, - ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo) + ORDER_DIR, GET_ITERATOR_CHUNK_SIZE) from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.query import get_order_dir, Query from django.db.transaction import TransactionManagementError from django.db.utils import DatabaseError -from django.utils import six from django.utils.deprecation import RemovedInDjango20Warning from django.utils.six.moves import zip @@ -22,28 +21,268 @@ class SQLCompiler(object): self.connection = connection self.using = using self.quote_cache = {'*': '*'} - # When ordering a queryset with distinct on a column not part of the - # select set, the ordering column needs to be added to the select - # clause. This information is needed both in SQL construction and - # masking away the ordering selects from the returned row. - self.ordering_aliases = [] - self.ordering_params = [] + # The select, klass_info, and annotations are needed by QuerySet.iterator() + # these are set as a side-effect of executing the query. Note that we calculate + # separately a list of extra select columns needed for grammatical correctness + # of the query, but these columns are not included in self.select. + self.select = None + self.annotation_col_map = None + self.klass_info = None + + def setup_query(self): + if all(self.query.alias_refcount[a] == 0 for a in self.query.tables): + self.query.get_initial_alias() + self.select, self.klass_info, self.annotation_col_map = self.get_select() + self.col_count = len(self.select) def pre_sql_setup(self): """ Does any necessary class setup immediately prior to producing SQL. This is for things that can't necessarily be done in __init__ because we might not have all the pieces in place at that time. - # TODO: after the query has been executed, the altered state should be - # cleaned. We are not using a clone() of the query here. """ - if not self.query.tables: - self.query.get_initial_alias() - if (not self.query.select and self.query.default_cols and not - self.query.included_inherited_models): - self.query.setup_inherited_models() - if self.query.select_related and not self.query.related_select_cols: - self.fill_related_selections() + self.setup_query() + order_by = self.get_order_by() + extra_select = self.get_extra_select(order_by, self.select) + group_by = self.get_group_by(self.select + extra_select, order_by) + return extra_select, order_by, group_by + + def get_group_by(self, select, order_by): + """ + Returns a list of 2-tuples of form (sql, params). + + The logic of what exactly the GROUP BY clause contains is hard + to describe in other words than "if it passes the test suite, + then it is correct". + """ + # Some examples: + # SomeModel.objects.annotate(Count('somecol')) + # GROUP BY: all fields of the model + # + # SomeModel.objects.values('name').annotate(Count('somecol')) + # GROUP BY: name + # + # SomeModel.objects.annotate(Count('somecol')).values('name') + # GROUP BY: all cols of the model + # + # SomeModel.objects.values('name', 'pk').annotate(Count('somecol')).values('pk') + # GROUP BY: name, pk + # + # SomeModel.objects.values('name').annotate(Count('somecol')).values('pk') + # GROUP BY: name, pk + # + # In fact, the self.query.group_by is the minimal set to GROUP BY. It + # can't be ever restricted to a smaller set, but additional columns in + # HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately + # the end result is that it is impossible to force the query to have + # a chosen GROUP BY clause - you can almost do this by using the form: + # .values(*wanted_cols).annotate(AnAggregate()) + # but any later annotations, extra selects, values calls that + # refer some column outside of the wanted_cols, order_by, or even + # filter calls can alter the GROUP BY clause. + + # The query.group_by is either None (no GROUP BY at all), True + # (group by select fields), or a list of expressions to be added + # to the group by. + if self.query.group_by is None: + return [] + expressions = [] + if self.query.group_by is not True: + # If the group by is set to a list (by .values() call most likely), + # then we need to add everything in it to the GROUP BY clause. + # Backwards compatibility hack for setting query.group_by. Remove + # when we have public API way of forcing the GROUP BY clause. + # Converts string references to expressions. + for expr in self.query.group_by: + if not hasattr(expr, 'as_sql'): + expressions.append(self.query.resolve_ref(expr)) + else: + expressions.append(expr) + # Note that even if the group_by is set, it is only the minimal + # set to group by. So, we need to add cols in select, order_by, and + # having into the select in any case. + for expr, _, _ in select: + cols = expr.get_group_by_cols() + for col in cols: + expressions.append(col) + for expr, _ in order_by: + if expr.contains_aggregate: + continue + # We can skip References to select clause, as all expressions in + # the select clause are already part of the group by. + if isinstance(expr, Ref): + continue + expressions.append(expr) + having = self.query.having.get_group_by_cols() + for expr in having: + expressions.append(expr) + result = [] + seen = set() + expressions = self.collapse_group_by(expressions, having) + + for expr in expressions: + sql, params = self.compile(expr) + if (sql, tuple(params)) not in seen: + result.append((sql, params)) + seen.add((sql, tuple(params))) + return result + + def collapse_group_by(self, expressions, having): + # If the DB can group by primary key, then group by the primary key of + # query's main model. Note that for PostgreSQL the GROUP BY clause must + # include the primary key of every table, but for MySQL it is enough to + # have the main table's primary key. Currently only the MySQL form is + # implemented. + # MySQLism: however, columns in HAVING clause must be added to the + # GROUP BY. + if self.connection.features.allows_group_by_pk: + # The logic here is: if the main model's primary key is in the + # query, then set new_expressions to that field. If that happens, + # then also add having expressions to group by. + pk = None + for expr in expressions: + if (expr.output_field.primary_key and + getattr(expr.output_field, 'model') == self.query.model): + pk = expr + if pk: + expressions = [pk] + [expr for expr in expressions if expr in having] + return expressions + + def get_select(self): + """ + Returns three values: + - a list of 3-tuples of (expression, (sql, params), alias) + - a klass_info structure, + - a dictionary of annotations + + The (sql, params) is what the expression will produce, and alias is the + "AS alias" for the column (possibly None). + + The klass_info structure contains the following information: + - Which model to instantiate + - Which columns for that model are present in the query (by + position of the select clause). + - related_klass_infos: [f, klass_info] to descent into + + The annotations is a dictionary of {'attname': column position} values. + """ + select = [] + klass_info = None + annotations = {} + select_idx = 0 + for alias, (sql, params) in self.query.extra_select.items(): + annotations[alias] = select_idx + select.append((RawSQL(sql, params), alias)) + select_idx += 1 + assert not (self.query.select and self.query.default_cols) + if self.query.default_cols: + select_list = [] + for c in self.get_default_columns(): + select_list.append(select_idx) + select.append((c, None)) + select_idx += 1 + klass_info = { + 'model': self.query.model, + 'select_fields': select_list, + } + # self.query.select is a special case. These columns never go to + # any model. + for col in self.query.select: + select.append((col, None)) + select_idx += 1 + for alias, annotation in self.query.annotation_select.items(): + annotations[alias] = select_idx + select.append((annotation, alias)) + select_idx += 1 + + if self.query.select_related: + related_klass_infos = self.get_related_selections(select) + klass_info['related_klass_infos'] = related_klass_infos + + def get_select_from_parent(klass_info): + for ki in klass_info['related_klass_infos']: + if ki['from_parent']: + ki['select_fields'] = (klass_info['select_fields'] + + ki['select_fields']) + get_select_from_parent(ki) + get_select_from_parent(klass_info) + + ret = [] + for col, alias in select: + ret.append((col, self.compile(col, select_format=True), alias)) + return ret, klass_info, annotations + + def get_order_by(self): + """ + Returns a list of 2-tuples of form (expr, (sql, params)) for the + ORDER BY clause. + + The order_by clause can alter the select clause (for example it + can add aliases to clauses that do not yet have one, or it can + add totally new select clauses). + """ + if self.query.extra_order_by: + ordering = self.query.extra_order_by + elif not self.query.default_ordering: + ordering = self.query.order_by + else: + ordering = (self.query.order_by or self.query.get_meta().ordering or []) + if self.query.standard_ordering: + asc, desc = ORDER_DIR['ASC'] + else: + asc, desc = ORDER_DIR['DESC'] + + order_by = [] + for pos, field in enumerate(ordering): + if field == '?': + order_by.append((Random(), asc, False)) + continue + if isinstance(field, int): + if field < 0: + field = -field + int_ord = desc + order_by.append((ColIndexRef(field), int_ord, True)) + continue + col, order = get_order_dir(field, asc) + if col in self.query.annotation_select: + order_by.append((Ref(col, self.query.annotation_select[col]), order, True)) + continue + if '.' in field: + # This came in through an extra(order_by=...) addition. Pass it + # on verbatim. + table, col = col.split('.', 1) + expr = RawSQL('%s.%s' % (self.quote_name_unless_alias(table), col), []) + order_by.append((expr, order, False)) + continue + if not self.query._extra or get_order_dir(field)[0] not in self.query._extra: + # 'col' is of the form 'field' or 'field1__field2' or + # '-field1__field2__field', etc. + order_by.extend(self.find_ordering_name(field, self.query.get_meta(), + default_order=asc)) + else: + if col not in self.query.extra_select: + order_by.append((RawSQL(*self.query.extra[col]), order, False)) + else: + order_by.append((Ref(col, RawSQL(*self.query.extra[col])), + order, True)) + result = [] + seen = set() + for expr, order, is_ref in order_by: + sql, params = self.compile(expr) + if (sql, tuple(params)) in seen: + continue + seen.add((sql, tuple(params))) + result.append((expr, (sql, params, order, is_ref))) + return result + + def get_extra_select(self, order_by, select): + extra_select = [] + select_sql = [t[1] for t in select] + if self.query.distinct and not self.query.distinct_fields: + for expr, (sql, params, _, is_ref) in order_by: + if not is_ref and (sql, params) not in select_sql: + extra_select.append((expr, (sql, params), None)) + return extra_select def __call__(self, name): """ @@ -72,13 +311,15 @@ class SQLCompiler(object): self.quote_cache[name] = r return r - def compile(self, node): - vendor_impl = getattr( - node, 'as_' + self.connection.vendor, None) + def compile(self, node, select_format=False): + vendor_impl = getattr(node, 'as_' + self.connection.vendor, None) if vendor_impl: - return vendor_impl(self, self.connection) + sql, params = vendor_impl(self, self.connection) else: - return node.as_sql(self, self.connection) + sql, params = node.as_sql(self, self.connection) + if select_format: + return node.output_field.select_format(self, sql, params) + return sql, params def as_sql(self, with_limits=True, with_col_aliases=False): """ @@ -88,91 +329,102 @@ class SQLCompiler(object): If 'with_limits' is False, any limit/offset information is not included in the query. """ - if with_limits and self.query.low_mark == self.query.high_mark: - return '', () - - self.pre_sql_setup() # After executing the query, we must get rid of any joins the query # setup created. So, take note of alias counts before the query ran. # However we do not want to get rid of stuff done in pre_sql_setup(), # as the pre_sql_setup will modify query state in a way that forbids # another run of it. refcounts_before = self.query.alias_refcount.copy() - out_cols, s_params = self.get_columns(with_col_aliases) - ordering, o_params, ordering_group_by = self.get_ordering() + try: + extra_select, order_by, group_by = self.pre_sql_setup() + if with_limits and self.query.low_mark == self.query.high_mark: + return '', () + distinct_fields = self.get_distinct() - distinct_fields = self.get_distinct() + # This must come after 'select', 'ordering', and 'distinct' -- see + # docstring of get_from_clause() for details. + from_, f_params = self.get_from_clause() - # This must come after 'select', 'ordering' and 'distinct' -- see - # docstring of get_from_clause() for details. - from_, f_params = self.get_from_clause() + where, w_params = self.compile(self.query.where) + having, h_params = self.compile(self.query.having) + params = [] + result = ['SELECT'] - where, w_params = self.compile(self.query.where) - having, h_params = self.compile(self.query.having) - having_group_by = self.query.having.get_group_by_cols() - params = [] - for val in six.itervalues(self.query.extra_select): - params.extend(val[1]) + if self.query.distinct: + result.append(self.connection.ops.distinct_sql(distinct_fields)) - result = ['SELECT'] + out_cols = [] + col_idx = 1 + for _, (s_sql, s_params), alias in self.select + extra_select: + if alias: + s_sql = '%s AS %s' % (s_sql, self.connection.ops.quote_name(alias)) + elif with_col_aliases: + s_sql = '%s AS %s' % (s_sql, 'Col%d' % col_idx) + col_idx += 1 + params.extend(s_params) + out_cols.append(s_sql) - if self.query.distinct: - result.append(self.connection.ops.distinct_sql(distinct_fields)) + result.append(', '.join(out_cols)) - result.append(', '.join(out_cols + self.ordering_aliases)) - params.extend(s_params) - params.extend(self.ordering_params) + result.append('FROM') + result.extend(from_) + params.extend(f_params) - result.append('FROM') - result.extend(from_) - params.extend(f_params) + if where: + result.append('WHERE %s' % where) + params.extend(w_params) - if where: - result.append('WHERE %s' % where) - params.extend(w_params) + grouping = [] + for g_sql, g_params in group_by: + grouping.append(g_sql) + params.extend(g_params) + if grouping: + if distinct_fields: + raise NotImplementedError( + "annotate() + distinct(fields) is not implemented.") + if not order_by: + order_by = self.connection.ops.force_no_ordering() + result.append('GROUP BY %s' % ', '.join(grouping)) - grouping, gb_params = self.get_grouping(having_group_by, ordering_group_by) - if grouping: - if distinct_fields: - raise NotImplementedError( - "annotate() + distinct(fields) not implemented.") - if not ordering: - ordering = self.connection.ops.force_no_ordering() - result.append('GROUP BY %s' % ', '.join(grouping)) - params.extend(gb_params) + if having: + result.append('HAVING %s' % having) + params.extend(h_params) - if having: - result.append('HAVING %s' % having) - params.extend(h_params) + if order_by: + ordering = [] + for _, (o_sql, o_params, order, _) in order_by: + ordering.append('%s %s' % (o_sql, order)) + params.extend(o_params) + result.append('ORDER BY %s' % ', '.join(ordering)) - if ordering: - result.append('ORDER BY %s' % ', '.join(ordering)) - params.extend(o_params) + if with_limits: + if self.query.high_mark is not None: + result.append('LIMIT %d' % (self.query.high_mark - self.query.low_mark)) + if self.query.low_mark: + if self.query.high_mark is None: + val = self.connection.ops.no_limit_value() + if val: + result.append('LIMIT %d' % val) + result.append('OFFSET %d' % self.query.low_mark) - if with_limits: - if self.query.high_mark is not None: - result.append('LIMIT %d' % (self.query.high_mark - self.query.low_mark)) - if self.query.low_mark: - if self.query.high_mark is None: - val = self.connection.ops.no_limit_value() - if val: - result.append('LIMIT %d' % val) - result.append('OFFSET %d' % self.query.low_mark) + if self.query.select_for_update and self.connection.features.has_select_for_update: + if self.connection.get_autocommit(): + raise TransactionManagementError( + "select_for_update cannot be used outside of a transaction." + ) - if self.query.select_for_update and self.connection.features.has_select_for_update: - if self.connection.get_autocommit(): - raise TransactionManagementError("select_for_update cannot be used outside of a transaction.") + # If we've been asked for a NOWAIT query but the backend does + # not support it, raise a DatabaseError otherwise we could get + # an unexpected deadlock. + nowait = self.query.select_for_update_nowait + if nowait and not self.connection.features.has_select_for_update_nowait: + raise DatabaseError('NOWAIT is not supported on this database backend.') + result.append(self.connection.ops.for_update_sql(nowait=nowait)) - # If we've been asked for a NOWAIT query but the backend does not support it, - # raise a DatabaseError otherwise we could get an unexpected deadlock. - nowait = self.query.select_for_update_nowait - if nowait and not self.connection.features.has_select_for_update_nowait: - raise DatabaseError('NOWAIT is not supported on this database backend.') - result.append(self.connection.ops.for_update_sql(nowait=nowait)) - - # Finally do cleanup - get rid of the joins we created above. - self.query.reset_refcounts(refcounts_before) - return ' '.join(result), tuple(params) + return ' '.join(result), tuple(params) + finally: + # Finally do cleanup - get rid of the joins we created above. + self.query.reset_refcounts(refcounts_before) def as_nested_sql(self): """ @@ -189,90 +441,7 @@ class SQLCompiler(object): obj.clear_ordering(True) return obj.get_compiler(connection=self.connection).as_sql() - def get_columns(self, with_aliases=False): - """ - Returns the list of columns to use in the select statement, as well as - a list any extra parameters that need to be included. If no columns - have been specified, returns all columns relating to fields in the - model. - - If 'with_aliases' is true, any column names that are duplicated - (without the table names) are given unique aliases. This is needed in - some cases to avoid ambiguity with nested queries. - """ - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)] - params = [] - aliases = set(self.query.extra_select.keys()) - if with_aliases: - col_aliases = aliases.copy() - else: - col_aliases = set() - if self.query.select: - only_load = self.deferred_to_columns() - for col, _ in self.query.select: - if isinstance(col, (list, tuple)): - alias, column = col - table = self.query.alias_map[alias].table_name - if table in only_load and column not in only_load[table]: - continue - r = '%s.%s' % (qn(alias), qn(column)) - if with_aliases: - if col[1] in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) - else: - result.append('%s AS %s' % (r, qn2(col[1]))) - aliases.add(r) - col_aliases.add(col[1]) - else: - result.append(r) - aliases.add(r) - col_aliases.add(col[1]) - else: - col_sql, col_params = self.compile(col) - result.append(col_sql) - params.extend(col_params) - - if hasattr(col, 'alias'): - aliases.add(col.alias) - col_aliases.add(col.alias) - - elif self.query.default_cols: - cols, new_aliases = self.get_default_columns(with_aliases, - col_aliases) - result.extend(cols) - aliases.update(new_aliases) - - max_name_length = self.connection.ops.max_name_length() - for alias, annotation in self.query.annotation_select.items(): - agg_sql, agg_params = self.compile(annotation) - if alias is None: - result.append(agg_sql) - else: - result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length)))) - params.extend(agg_params) - - for (table, col), _ in self.query.related_select_cols: - r = '%s.%s' % (qn(table), qn(col)) - if with_aliases and col in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s AS %s' % (r, c_alias)) - aliases.add(c_alias) - col_aliases.add(c_alias) - else: - result.append(r) - aliases.add(r) - col_aliases.add(col) - - self._select_aliases = aliases - return result, params - - def get_default_columns(self, with_aliases=False, col_aliases=None, - start_alias=None, opts=None, as_pairs=False, from_parent=None): + def get_default_columns(self, start_alias=None, opts=None, from_parent=None): """ Computes the default columns for selecting every field in the base model. Will sometimes be called to pull in related models (e.g. via @@ -287,9 +456,6 @@ class SQLCompiler(object): result = [] if opts is None: opts = self.query.get_meta() - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - aliases = set() only_load = self.deferred_to_columns() if not start_alias: start_alias = self.query.get_initial_alias() @@ -304,38 +470,21 @@ class SQLCompiler(object): # will assign None if the field belongs to this model. if model == opts.model: model = None - if from_parent and model is not None and issubclass(from_parent, model): + if from_parent and model is not None and issubclass( + from_parent._meta.concrete_model, model._meta.concrete_model): # Avoid loading data for already loaded parents. + # We end up here in the case select_related() resolution + # proceeds from parent model to child model. In that case the + # parent model data is already present in the SELECT clause, + # and we want to avoid reloading the same data again. + continue + if field.model in only_load and field.attname not in only_load[field.model]: continue alias = self.query.join_parent_model(opts, model, start_alias, seen_models) - column = field.column - for seen_model, seen_alias in seen_models.items(): - if seen_model and seen_alias == alias: - ancestor_link = seen_model._meta.get_ancestor_link(model) - if ancestor_link: - column = ancestor_link.column - break - table = self.query.alias_map[alias].table_name - if table in only_load and column not in only_load[table]: - continue - if as_pairs: - result.append((alias, field)) - aliases.add(alias) - continue - if with_aliases and column in col_aliases: - c_alias = 'Col%d' % len(col_aliases) - result.append('%s.%s AS %s' % (qn(alias), - qn2(column), c_alias)) - col_aliases.add(c_alias) - aliases.add(c_alias) - else: - r = '%s.%s' % (qn(alias), qn2(column)) - result.append(r) - aliases.add(r) - if with_aliases: - col_aliases.add(column) - return result, aliases + column = field.get_col(alias) + result.append(column) + return result def get_distinct(self): """ @@ -357,107 +506,6 @@ class SQLCompiler(object): result.append("%s.%s" % (qn(alias), qn2(target.column))) return result - def get_ordering(self): - """ - Returns a tuple containing a list representing the SQL elements in the - "order by" clause, and the list of SQL elements that need to be added - to the GROUP BY clause as a result of the ordering. - - Also sets the ordering_aliases attribute on this instance to a list of - extra aliases needed in the select. - - Determining the ordering SQL can change the tables we need to include, - so this should be run *before* get_from_clause(). - """ - if self.query.extra_order_by: - ordering = self.query.extra_order_by - elif not self.query.default_ordering: - ordering = self.query.order_by - else: - ordering = (self.query.order_by - or self.query.get_meta().ordering - or []) - qn = self.quote_name_unless_alias - qn2 = self.connection.ops.quote_name - distinct = self.query.distinct - select_aliases = self._select_aliases - result = [] - group_by = [] - ordering_aliases = [] - if self.query.standard_ordering: - asc, desc = ORDER_DIR['ASC'] - else: - asc, desc = ORDER_DIR['DESC'] - - # It's possible, due to model inheritance, that normal usage might try - # to include the same field more than once in the ordering. We track - # the table/column pairs we use and discard any after the first use. - processed_pairs = set() - - params = [] - ordering_params = [] - # For plain DISTINCT queries any ORDER BY clause must appear - # in SELECT clause. - # http://www.postgresql.org/message-id/27009.1171559417@sss.pgh.pa.us - must_append_to_select = distinct and not self.query.distinct_fields - for pos, field in enumerate(ordering): - if field == '?': - result.append(self.connection.ops.random_function_sql()) - continue - if isinstance(field, int): - if field < 0: - order = desc - field = -field - else: - order = asc - result.append('%s %s' % (field, order)) - group_by.append((str(field), [])) - continue - col, order = get_order_dir(field, asc) - if col in self.query.annotation_select: - result.append('%s %s' % (qn(col), order)) - continue - if '.' in field: - # This came in through an extra(order_by=...) addition. Pass it - # on verbatim. - table, col = col.split('.', 1) - if (table, col) not in processed_pairs: - elt = '%s.%s' % (qn(table), col) - processed_pairs.add((table, col)) - if not must_append_to_select or elt in select_aliases: - result.append('%s %s' % (elt, order)) - group_by.append((elt, [])) - elif not self.query._extra or get_order_dir(field)[0] not in self.query._extra: - # 'col' is of the form 'field' or 'field1__field2' or - # '-field1__field2__field', etc. - for table, cols, order in self.find_ordering_name(field, - self.query.get_meta(), default_order=asc): - for col in cols: - if (table, col) not in processed_pairs: - elt = '%s.%s' % (qn(table), qn2(col)) - processed_pairs.add((table, col)) - if must_append_to_select and elt not in select_aliases: - ordering_aliases.append(elt) - result.append('%s %s' % (elt, order)) - group_by.append((elt, [])) - else: - elt = qn2(col) - if col not in self.query.extra_select: - if must_append_to_select: - sql = "(%s) AS %s" % (self.query.extra[col][0], elt) - ordering_aliases.append(sql) - ordering_params.extend(self.query.extra[col][1]) - result.append('%s %s' % (elt, order)) - else: - result.append("(%s) %s" % (self.query.extra[col][0], order)) - params.extend(self.query.extra[col][1]) - else: - result.append('%s %s' % (elt, order)) - group_by.append(self.query.extra[col]) - self.ordering_aliases = ordering_aliases - self.ordering_params = ordering_params - return result, params, group_by - def find_ordering_name(self, name, opts, alias=None, default_order='ASC', already_seen=None): """ @@ -487,11 +535,11 @@ class SQLCompiler(object): order, already_seen)) return results targets, alias, _ = self.query.trim_joins(targets, joins, path) - return [(alias, [t.column for t in targets], order)] + return [(t.get_col(alias), order, False) for t in targets] def _setup_joins(self, pieces, opts, alias): """ - A helper method for get_ordering and get_distinct. + A helper method for get_order_by and get_distinct. Note that get_ordering and get_distinct must produce same target columns on same input, as the prefixes of get_ordering and get_distinct @@ -538,67 +586,8 @@ class SQLCompiler(object): result.append(', %s' % self.quote_name_unless_alias(alias)) return result, params - def get_grouping(self, having_group_by, ordering_group_by): - """ - Returns a tuple representing the SQL elements in the "group by" clause. - """ - qn = self.quote_name_unless_alias - result, params = [], [] - if self.query.group_by is not None: - select_cols = self.query.select + self.query.related_select_cols - # Just the column, not the fields. - select_cols = [s[0] for s in select_cols] - if (len(self.query.get_meta().concrete_fields) == len(self.query.select) - and self.connection.features.allows_group_by_pk): - self.query.group_by = [ - (self.query.get_initial_alias(), self.query.get_meta().pk.column) - ] - select_cols = [] - seen = set() - cols = self.query.group_by + having_group_by + select_cols - for col in cols: - col_params = () - if isinstance(col, (list, tuple)): - sql = '%s.%s' % (qn(col[0]), qn(col[1])) - elif hasattr(col, 'as_sql'): - sql, col_params = self.compile(col) - else: - sql = '(%s)' % str(col) - if sql not in seen or col_params: - result.append(sql) - params.extend(col_params) - seen.add(sql) - - # Still, we need to add all stuff in ordering (except if the backend can - # group by just by PK). - if ordering_group_by and not self.connection.features.allows_group_by_pk: - for order, order_params in ordering_group_by: - # Even if we have seen the same SQL string, it might have - # different params, so, we add same SQL in "has params" case. - if order not in seen or order_params: - result.append(order) - params.extend(order_params) - seen.add(order) - - # Unconditionally add the extra_select items. - for extra_select, extra_params in self.query.extra_select.values(): - sql = '(%s)' % str(extra_select) - result.append(sql) - params.extend(extra_params) - # Finally, add needed group by cols from annotations - for annotation in self.query.annotation_select.values(): - cols = annotation.get_group_by_cols() - for col in cols: - sql, col_params = self.compile(col) - if sql not in seen or col_params: - result.append(sql) - seen.add(sql) - params.extend(col_params) - - return result, params - - def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1, - requested=None, restricted=None): + def get_related_selections(self, select, opts=None, root_alias=None, cur_depth=1, + requested=None, restricted=None): """ Fill in the information needed for a select_related query. The current depth is measured as the number of connections away from the root model @@ -613,14 +602,14 @@ class SQLCompiler(object): ) return chain(direct_choices, reverse_choices) + related_klass_infos = [] if not restricted and self.query.max_depth and cur_depth > self.query.max_depth: # We've recursed far enough; bail out. - return + return related_klass_infos if not opts: opts = self.query.get_meta() root_alias = self.query.get_initial_alias() - self.query.related_select_cols = [] only_load = self.query.get_loaded_field_names() # Setup for the case when only particular related fields should be @@ -633,6 +622,9 @@ class SQLCompiler(object): else: restricted = False + def get_related_klass_infos(klass_info, related_klass_infos): + klass_info['related_klass_infos'] = related_klass_infos + for f in opts.fields: field_model = f.model._meta.concrete_model fields_found.add(f.name) @@ -656,15 +648,25 @@ class SQLCompiler(object): if not select_related_descend(f, restricted, requested, only_load.get(field_model)): continue + klass_info = { + 'model': f.rel.to, + 'field': f, + 'reverse': False, + 'from_parent': False, + } + related_klass_infos.append(klass_info) + select_fields = [] _, _, _, joins, _ = self.query.setup_joins( [f.name], opts, root_alias) alias = joins[-1] - columns, _ = self.get_default_columns(start_alias=alias, - opts=f.rel.to._meta, as_pairs=True) - self.query.related_select_cols.extend( - SelectInfo((col[0], col[1].column), col[1]) for col in columns - ) - self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, next, restricted) + columns = self.get_default_columns(start_alias=alias, opts=f.rel.to._meta) + for col in columns: + select_fields.append(len(select)) + select.append((col, None)) + klass_info['select_fields'] = select_fields + next_klass_infos = self.get_related_selections( + select, f.rel.to._meta, alias, cur_depth + 1, next, restricted) + get_related_klass_infos(klass_info, next_klass_infos) if restricted: related_fields = [ @@ -682,16 +684,26 @@ class SQLCompiler(object): _, _, _, joins, _ = self.query.setup_joins([related_field_name], opts, root_alias) alias = joins[-1] - from_parent = (opts.model if issubclass(model, opts.model) - else None) - columns, _ = self.get_default_columns(start_alias=alias, - opts=model._meta, as_pairs=True, from_parent=from_parent) - self.query.related_select_cols.extend( - SelectInfo((col[0], col[1].column), col[1]) for col in columns) + from_parent = issubclass(model, opts.model) + klass_info = { + 'model': model, + 'field': f, + 'reverse': True, + 'from_parent': from_parent, + } + related_klass_infos.append(klass_info) + select_fields = [] + columns = self.get_default_columns( + start_alias=alias, opts=model._meta, from_parent=opts.model) + for col in columns: + select_fields.append(len(select)) + select.append((col, None)) + klass_info['select_fields'] = select_fields next = requested.get(f.related_query_name(), {}) - self.fill_related_selections(model._meta, alias, cur_depth + 1, - next, restricted) - + next_klass_infos = self.get_related_selections( + select, model._meta, alias, cur_depth + 1, + next, restricted) + get_related_klass_infos(klass_info, next_klass_infos) fields_not_found = set(requested.keys()).difference(fields_found) if fields_not_found: invalid_fields = ("'%s'" % s for s in fields_not_found) @@ -702,6 +714,7 @@ class SQLCompiler(object): ', '.join(_get_field_choices()) or '(none)', ) ) + return related_klass_infos def deferred_to_columns(self): """ @@ -710,22 +723,17 @@ class SQLCompiler(object): dictionary. """ columns = {} - self.query.deferred_to_data(columns, self.query.deferred_to_columns_cb) + self.query.deferred_to_data(columns, self.query.get_loaded_field_names_cb) return columns - def get_converters(self, fields): + def get_converters(self, expressions): converters = {} - index_extra_select = len(self.query.extra_select) - for i, field in enumerate(fields): - if field: - try: - output_field = field.output_field - except AttributeError: - output_field = field - backend_converters = self.connection.ops.get_db_converters(output_field.get_internal_type()) - field_converters = field.get_db_converters(self.connection) + for i, expression in enumerate(expressions): + if expression: + backend_converters = self.connection.ops.get_db_converters(expression) + field_converters = expression.get_db_converters(self.connection) if backend_converters or field_converters: - converters[index_extra_select + i] = (backend_converters, field_converters, output_field) + converters[i] = (backend_converters, field_converters, expression) return converters def apply_converters(self, row, converters): @@ -733,62 +741,23 @@ class SQLCompiler(object): for pos, (backend_converters, field_converters, field) in converters.items(): value = row[pos] for converter in backend_converters: - value = converter(value, field) + value = converter(value, field, self.query.context) for converter in field_converters: - value = converter(value, self.connection) + value = converter(value, self.connection, self.query.context) row[pos] = value return tuple(row) - def results_iter(self): + def results_iter(self, results=None): """ Returns an iterator over the results from executing this query. """ - fields = None converters = None - has_annotation_select = bool(self.query.annotation_select) - for rows in self.execute_sql(MULTI): + if results is None: + results = self.execute_sql(MULTI) + fields = [s[0] for s in self.select[0:self.col_count]] + converters = self.get_converters(fields) + for rows in results: for row in rows: - if fields is None: - # We only set this up here because - # related_select_cols isn't populated until - # execute_sql() has been called. - - # If the field was deferred, exclude it from being passed - # into `get_converters` because it wasn't selected. - only_load = self.deferred_to_columns() - - # This code duplicates the logic for the order of fields - # found in get_columns(). It would be nice to clean this up. - if self.query.select: - fields = [f.field for f in self.query.select] - elif self.query.default_cols: - fields = list(self.query.get_meta().concrete_fields) - else: - fields = [] - - if only_load: - # strip deferred fields - fields = [ - f for f in fields if - f.model._meta.db_table not in only_load or - f.column in only_load[f.model._meta.db_table] - ] - - # annotations come before the related cols - if has_annotation_select: - # extra is always at the start of the field list - fields = fields + [ - anno for alias, anno in self.query.annotation_select.items()] - - # add related fields - fields = fields + [ - # strip deferred - f.field for f in self.query.related_select_cols if - f.field.model._meta.db_table not in only_load or - f.field.column in only_load[f.field.model._meta.db_table] - ] - - converters = self.get_converters(fields) if converters: row = self.apply_converters(row, converters) yield row @@ -841,9 +810,10 @@ class SQLCompiler(object): return cursor if result_type == SINGLE: try: - if self.ordering_aliases: - return cursor.fetchone()[:-len(self.ordering_aliases)] - return cursor.fetchone() + val = cursor.fetchone() + if val: + return val[0:self.col_count] + return val finally: # done with the cursor cursor.close() @@ -851,13 +821,10 @@ class SQLCompiler(object): cursor.close() return - # The MULTI case. - if self.ordering_aliases: - result = order_modified_iter(cursor, len(self.ordering_aliases), - self.connection.features.empty_fetchmany_value) - else: - result = cursor_iter(cursor, - self.connection.features.empty_fetchmany_value) + result = cursor_iter( + cursor, self.connection.features.empty_fetchmany_value, + self.col_count + ) if not self.connection.features.can_use_chunked_reads: try: # If we are using non-chunked reads, we return the same data @@ -871,17 +838,16 @@ class SQLCompiler(object): def as_subquery_condition(self, alias, columns, compiler): qn = compiler.quote_name_unless_alias - inner_qn = self.quote_name_unless_alias qn2 = self.connection.ops.quote_name if len(columns) == 1: sql, params = self.as_sql() return '%s.%s IN (%s)' % (qn(alias), qn2(columns[0]), sql), params for index, select_col in enumerate(self.query.select): - lhs = '%s.%s' % (inner_qn(select_col.col[0]), qn2(select_col.col[1])) + lhs_sql, lhs_params = self.compile(select_col) rhs = '%s.%s' % (qn(alias), qn2(columns[index])) self.query.where.add( - QueryWrapper('%s = %s' % (lhs, rhs), []), 'AND') + QueryWrapper('%s = %s' % (lhs_sql, rhs), lhs_params), 'AND') sql, params = self.as_sql() return 'EXISTS (%s)' % sql, params @@ -1074,24 +1040,19 @@ class SQLUpdateCompiler(SQLCompiler): the id values to update at this point so that they don't change as a result of the progressive updates. """ - self.query.select_related = False - self.query.clear_ordering(True) - super(SQLUpdateCompiler, self).pre_sql_setup() + refcounts_before = self.query.alias_refcount.copy() + # Ensure base table is in the query + self.query.get_initial_alias() count = self.query.count_active_tables() if not self.query.related_updates and count == 1: return - - # We need to use a sub-select in the where clause to filter on things - # from other tables. query = self.query.clone(klass=Query) + query.select_related = False + query.clear_ordering(True) query._extra = {} query.select = [] query.add_fields([query.get_meta().pk.name]) - # Recheck the count - it is possible that fiddling with the select - # fields above removes tables from the query. Refs #18304. - count = query.count_active_tables() - if not self.query.related_updates and count == 1: - return + super(SQLUpdateCompiler, self).pre_sql_setup() must_pre_select = count > 1 and not self.connection.features.update_can_self_select @@ -1110,8 +1071,7 @@ class SQLUpdateCompiler(SQLCompiler): else: # The fast path. Filters and updates in one query. self.query.add_filter(('pk__in', query)) - for alias in self.query.tables[1:]: - self.query.alias_refcount[alias] = 0 + self.query.reset_refcounts(refcounts_before) class SQLAggregateCompiler(SQLCompiler): @@ -1130,6 +1090,7 @@ class SQLAggregateCompiler(SQLCompiler): agg_sql, agg_params = self.compile(annotation) sql.append(agg_sql) params.extend(agg_params) + self.col_count = len(self.query.annotation_select) sql = ', '.join(sql) params = tuple(params) @@ -1138,29 +1099,14 @@ class SQLAggregateCompiler(SQLCompiler): return sql, params -def cursor_iter(cursor, sentinel): +def cursor_iter(cursor, sentinel, col_count): """ Yields blocks of rows from a cursor and ensures the cursor is closed when done. """ try: for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - sentinel): - yield rows - finally: - cursor.close() - - -def order_modified_iter(cursor, trim, sentinel): - """ - Yields blocks of rows from a cursor. We use this iterator in the special - case when extra output columns have been added to support ordering - requirements. We must trim those extra columns before anything else can use - the results, since they're only needed to make the SQL valid. - """ - try: - for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)), - sentinel): - yield [r[:-trim] for r in rows] + sentinel): + yield [r[0:col_count] for r in rows] finally: cursor.close() diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index e0e3f10100..57857796b8 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -2,7 +2,6 @@ Constants specific to the SQL storage portion of the ORM. """ -from collections import namedtuple import re # Valid query types (a set is used for speedy lookups). These are (currently) @@ -21,9 +20,6 @@ GET_ITERATOR_CHUNK_SIZE = 100 # Namedtuples for sql.* internal use. -# Pairs of column clauses to select, and (possibly None) field for the clause. -SelectInfo = namedtuple('SelectInfo', 'col field') - # How many results to expect from a cursor.execute call MULTI = 'multi' SINGLE = 'single' diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 1d1dbd8162..a613d8eba4 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -21,7 +21,7 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Ref from django.db.models.query_utils import PathInfo, Q, refs_aggregate from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, - ORDER_PATTERN, SelectInfo, INNER, LOUTER) + ORDER_PATTERN, INNER, LOUTER) from django.db.models.sql.datastructures import ( EmptyResultSet, Empty, MultiJoin, Join, BaseTable) from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, @@ -46,7 +46,7 @@ class RawQuery(object): A single raw SQL query """ - def __init__(self, sql, using, params=None): + def __init__(self, sql, using, params=None, context=None): self.params = params or () self.sql = sql self.using = using @@ -57,9 +57,10 @@ class RawQuery(object): self.low_mark, self.high_mark = 0, None # Used for offset/limit self.extra_select = {} self.annotation_select = {} + self.context = context or {} def clone(self, using): - return RawQuery(self.sql, using, params=self.params) + return RawQuery(self.sql, using, params=self.params, context=self.context.copy()) def get_columns(self): if self.cursor is None: @@ -122,20 +123,23 @@ class Query(object): self.standard_ordering = True self.used_aliases = set() self.filter_is_sticky = False - self.included_inherited_models = {} # SQL-related attributes - # Select and related select clauses as SelectInfo instances. + # Select and related select clauses are expressions to use in the + # SELECT clause of the query. # The select is used for cases where we want to set up the select - # clause to contain other than default fields (values(), annotate(), - # subqueries...) + # clause to contain other than default fields (values(), subqueries...) + # Note that annotations go to annotations dictionary. self.select = [] - # The related_select_cols is used for columns needed for - # select_related - this is populated in the compile stage. - self.related_select_cols = [] self.tables = [] # Aliases in the order they are created. self.where = where() self.where_class = where + # The group_by attribute can have one of the following forms: + # - None: no group by at all in the query + # - A list of expressions: group by (at least) those expressions. + # String refs are also allowed for now. + # - True: group by all select fields of the model + # See compiler.get_group_by() for details. self.group_by = None self.having = where() self.order_by = [] @@ -174,6 +178,8 @@ class Query(object): # load. self.deferred_loading = (set(), True) + self.context = {} + @property def extra(self): if self._extra is None: @@ -254,14 +260,14 @@ class Query(object): obj.default_cols = self.default_cols obj.default_ordering = self.default_ordering obj.standard_ordering = self.standard_ordering - obj.included_inherited_models = self.included_inherited_models.copy() obj.select = self.select[:] - obj.related_select_cols = [] obj.tables = self.tables[:] obj.where = self.where.clone() obj.where_class = self.where_class if self.group_by is None: obj.group_by = None + elif self.group_by is True: + obj.group_by = True else: obj.group_by = self.group_by[:] obj.having = self.having.clone() @@ -272,7 +278,6 @@ class Query(object): obj.select_for_update = self.select_for_update obj.select_for_update_nowait = self.select_for_update_nowait obj.select_related = self.select_related - obj.related_select_cols = [] obj._annotations = self._annotations.copy() if self._annotations is not None else None if self.annotation_select_mask is None: obj.annotation_select_mask = None @@ -310,8 +315,15 @@ class Query(object): obj.__dict__.update(kwargs) if hasattr(obj, '_setup_query'): obj._setup_query() + obj.context = self.context.copy() return obj + def add_context(self, key, value): + self.context[key] = value + + def get_context(self, key, default=None): + return self.context.get(key, default) + def relabeled_clone(self, change_map): clone = self.clone() clone.change_aliases(change_map) @@ -375,7 +387,8 @@ class Query(object): # done in a subquery so that we are aggregating on the limit and/or # distinct results instead of applying the distinct and limit after the # aggregation. - if (self.group_by or has_limit or has_existing_annotations or self.distinct): + if (isinstance(self.group_by, list) or has_limit or has_existing_annotations or + self.distinct): from django.db.models.sql.subqueries import AggregateQuery outer_query = AggregateQuery(self.model) inner_query = self.clone() @@ -383,7 +396,6 @@ class Query(object): inner_query.clear_ordering(True) inner_query.select_for_update = False inner_query.select_related = False - inner_query.related_select_cols = [] relabels = {t: 'subquery' for t in inner_query.tables} relabels[None] = 'subquery' @@ -407,26 +419,17 @@ class Query(object): self.select = [] self.default_cols = False self._extra = {} - self.remove_inherited_models() outer_query.clear_ordering(True) outer_query.clear_limits() outer_query.select_for_update = False outer_query.select_related = False - outer_query.related_select_cols = [] compiler = outer_query.get_compiler(using) result = compiler.execute_sql(SINGLE) if result is None: result = [None for q in outer_query.annotation_select.items()] - fields = [annotation.output_field - for alias, annotation in outer_query.annotation_select.items()] - converters = compiler.get_converters(fields) - for position, (alias, annotation) in enumerate(outer_query.annotation_select.items()): - if position in converters: - converters[position][1].insert(0, annotation.convert_value) - else: - converters[position] = ([], [annotation.convert_value], annotation.output_field) + converters = compiler.get_converters(outer_query.annotation_select.values()) result = compiler.apply_converters(result, converters) return { @@ -476,7 +479,6 @@ class Query(object): assert self.distinct_fields == rhs.distinct_fields, \ "Cannot combine queries with different distinct fields." - self.remove_inherited_models() # Work out how to relabel the rhs aliases, if necessary. change_map = {} conjunction = (connector == AND) @@ -545,13 +547,8 @@ class Query(object): # Selection columns and extra extensions are those provided by 'rhs'. self.select = [] - for col, field in rhs.select: - if isinstance(col, (list, tuple)): - new_col = change_map.get(col[0], col[0]), col[1] - self.select.append(SelectInfo(new_col, field)) - else: - new_col = col.relabeled_clone(change_map) - self.select.append(SelectInfo(new_col, field)) + for col in rhs.select: + self.add_select(col.relabeled_clone(change_map)) if connector == OR: # It would be nice to be able to handle this, but the queries don't @@ -661,17 +658,6 @@ class Query(object): for model, values in six.iteritems(seen): callback(target, model, values) - def deferred_to_columns_cb(self, target, model, fields): - """ - Callback used by deferred_to_columns(). The "target" parameter should - be a set instance. - """ - table = model._meta.db_table - if table not in target: - target[table] = set() - for field in fields: - target[table].add(field.column) - def table_alias(self, table_name, create=False): """ Returns a table alias for the given table_name and whether this is a @@ -788,10 +774,9 @@ class Query(object): # "group by", "where" and "having". self.where.relabel_aliases(change_map) self.having.relabel_aliases(change_map) - if self.group_by: + if isinstance(self.group_by, list): self.group_by = [relabel_column(col) for col in self.group_by] - self.select = [SelectInfo(relabel_column(s.col), s.field) - for s in self.select] + self.select = [col.relabeled_clone(change_map) for col in self.select] if self._annotations: self._annotations = OrderedDict( (key, relabel_column(col)) for key, col in self._annotations.items()) @@ -815,9 +800,6 @@ class Query(object): if alias == old_alias: self.tables[pos] = new_alias break - for key, alias in self.included_inherited_models.items(): - if alias in change_map: - self.included_inherited_models[key] = change_map[alias] self.external_aliases = {change_map.get(alias, alias) for alias in self.external_aliases} @@ -930,28 +912,6 @@ class Query(object): self.alias_map[alias] = join return alias - def setup_inherited_models(self): - """ - If the model that is the basis for this QuerySet inherits other models, - we need to ensure that those other models have their tables included in - the query. - - We do this as a separate step so that subclasses know which - tables are going to be active in the query, without needing to compute - all the select columns (this method is called from pre_sql_setup(), - whereas column determination is a later part, and side-effect, of - as_sql()). - """ - opts = self.get_meta() - root_alias = self.tables[0] - seen = {None: root_alias} - - for field in opts.fields: - model = field.model._meta.concrete_model - if model is not opts.model and model not in seen: - self.join_parent_model(opts, model, root_alias, seen) - self.included_inherited_models = seen - def join_parent_model(self, opts, model, alias, seen): """ Makes sure the given 'model' is joined in the query. If 'model' isn't @@ -969,7 +929,9 @@ class Query(object): curr_opts = opts for int_model in chain: if int_model in seen: - return seen[int_model] + curr_opts = int_model._meta + alias = seen[int_model] + continue # Proxy model have elements in base chain # with no parents, assign the new options # object and skip to the next base in that @@ -984,23 +946,13 @@ class Query(object): alias = seen[int_model] = joins[-1] return alias or seen[None] - def remove_inherited_models(self): - """ - Undoes the effects of setup_inherited_models(). Should be called - whenever select columns (self.select) are set explicitly. - """ - for key, alias in self.included_inherited_models.items(): - if key: - self.unref_alias(alias) - self.included_inherited_models = {} - def add_aggregate(self, aggregate, model, alias, is_summary): warnings.warn( "add_aggregate() is deprecated. Use add_annotation() instead.", RemovedInDjango20Warning, stacklevel=2) self.add_annotation(aggregate, alias, is_summary) - def add_annotation(self, annotation, alias, is_summary): + def add_annotation(self, annotation, alias, is_summary=False): """ Adds a single annotation expression to the Query """ @@ -1011,6 +963,7 @@ class Query(object): def prepare_lookup_value(self, value, lookups, can_reuse): # Default lookup if none given is exact. + used_joins = [] if len(lookups) == 0: lookups = ['exact'] # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all @@ -1026,7 +979,9 @@ class Query(object): RemovedInDjango19Warning, stacklevel=2) value = value() elif hasattr(value, 'resolve_expression'): + pre_joins = self.alias_refcount.copy() value = value.resolve_expression(self, reuse=can_reuse) + used_joins = [k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)] # Subqueries need to use a different set of aliases than the # outer query. Call bump_prefix to change aliases of the inner # query (the value). @@ -1044,7 +999,7 @@ class Query(object): lookups[-1] == 'exact' and value == ''): value = True lookups[-1] = 'isnull' - return value, lookups + return value, lookups, used_joins def solve_lookup_type(self, lookup): """ @@ -1173,8 +1128,7 @@ class Query(object): # Work out the lookup type and remove it from the end of 'parts', # if necessary. - value, lookups = self.prepare_lookup_value(value, lookups, can_reuse) - used_joins = getattr(value, '_used_joins', []) + value, lookups, used_joins = self.prepare_lookup_value(value, lookups, can_reuse) clause = self.where_class() if reffed_aggregate: @@ -1223,7 +1177,7 @@ class Query(object): # handle Expressions as annotations col = targets[0] else: - col = Col(alias, targets[0], field) + col = targets[0].get_col(alias, field) condition = self.build_lookup(lookups, col, value) if not condition: # Backwards compat for custom lookups @@ -1258,7 +1212,7 @@ class Query(object): # <=> # NOT (col IS NOT NULL AND col = someval). lookup_class = targets[0].get_lookup('isnull') - clause.add(lookup_class(Col(alias, targets[0], sources[0]), False), AND) + clause.add(lookup_class(targets[0].get_col(alias, sources[0]), False), AND) return clause, used_joins if not require_outer else () def add_filter(self, filter_clause): @@ -1535,7 +1489,7 @@ class Query(object): self.unref_alias(joins.pop()) return targets, joins[-1], joins - def resolve_ref(self, name, allow_joins, reuse, summarize): + def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False): if not allow_joins and LOOKUP_SEP in name: raise FieldError("Joined field references are not permitted in this query") if name in self.annotations: @@ -1558,8 +1512,7 @@ class Query(object): "isn't supported") if reuse is not None: reuse.update(join_list) - col = Col(join_list[-1], targets[0], sources[0]) - col._used_joins = join_list + col = targets[0].get_col(join_list[-1], sources[0]) return col def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): @@ -1588,26 +1541,28 @@ class Query(object): # Try to have as simple as possible subquery -> trim leading joins from # the subquery. trimmed_prefix, contains_louter = query.trim_start(names_with_path) - query.remove_inherited_models() # Add extra check to make sure the selected field will not be null # since we are adding an IN clause. This prevents the # database from tripping over IN (...,NULL,...) selects and returning # nothing - alias, col = query.select[0].col - if self.is_nullable(query.select[0].field): - lookup_class = query.select[0].field.get_lookup('isnull') - lookup = lookup_class(Col(alias, query.select[0].field, query.select[0].field), False) + col = query.select[0] + select_field = col.field + alias = col.alias + if self.is_nullable(select_field): + lookup_class = select_field.get_lookup('isnull') + lookup = lookup_class(select_field.get_col(alias), False) query.where.add(lookup, AND) if alias in can_reuse: - select_field = query.select[0].field pk = select_field.model._meta.pk # Need to add a restriction so that outer query's filters are in effect for # the subquery, too. query.bump_prefix(self) lookup_class = select_field.get_lookup('exact') - lookup = lookup_class(Col(query.select[0].col[0], pk, pk), - Col(alias, pk, pk)) + # Note that the query.select[0].alias is different from alias + # due to bump_prefix above. + lookup = lookup_class(pk.get_col(query.select[0].alias), + pk.get_col(alias)) query.where.add(lookup, AND) query.external_aliases.add(alias) @@ -1687,6 +1642,14 @@ class Query(object): """ self.select = [] + def add_select(self, col): + self.default_cols = False + self.select.append(col) + + def set_select(self, cols): + self.default_cols = False + self.select = cols + def add_distinct_fields(self, *field_names): """ Adds and resolves the given fields to the query's "distinct on" clause. @@ -1710,7 +1673,7 @@ class Query(object): name.split(LOOKUP_SEP), opts, alias, allow_many=allow_m2m) targets, final_alias, joins = self.trim_joins(targets, joins, path) for target in targets: - self.select.append(SelectInfo((final_alias, target.column), target)) + self.add_select(target.get_col(final_alias)) except MultiJoin: raise FieldError("Invalid field name: '%s'" % name) except FieldError: @@ -1723,7 +1686,6 @@ class Query(object): + list(self.annotation_select)) raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) - self.remove_inherited_models() def add_ordering(self, *ordering): """ @@ -1766,7 +1728,7 @@ class Query(object): """ self.group_by = [] - for col, _ in self.select: + for col in self.select: self.group_by.append(col) if self._annotations: @@ -1789,7 +1751,6 @@ class Query(object): for part in field.split(LOOKUP_SEP): d = d.setdefault(part, {}) self.select_related = field_dict - self.related_select_cols = [] def add_extra(self, select, select_params, where, params, tables, order_by): """ @@ -1897,7 +1858,7 @@ class Query(object): """ Callback used by get_deferred_field_names(). """ - target[model] = set(f.name for f in fields) + target[model] = {f.attname for f in fields} def set_aggregate_mask(self, names): warnings.warn( @@ -2041,7 +2002,7 @@ class Query(object): if self.alias_refcount[table] > 0: self.alias_map[table] = BaseTable(self.alias_map[table].table_name, table) break - self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields] + self.set_select([f.get_col(select_alias) for f in select_fields]) return trimmed_prefix, contains_louter def is_nullable(self, field): diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index bae9f11c23..35ce71311d 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -5,7 +5,7 @@ Query subclasses which provide extra functionality beyond simple data retrieval. from django.core.exceptions import FieldError from django.db import connections from django.db.models.query_utils import Q -from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo +from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS from django.db.models.sql.query import Query from django.utils import six @@ -71,7 +71,7 @@ class DeleteQuery(Query): else: innerq.clear_select_clause() innerq.select = [ - SelectInfo((self.get_initial_alias(), pk.column), None) + pk.get_col(self.get_initial_alias()) ] values = innerq self.where = self.where_class() diff --git a/docs/howto/custom-model-fields.txt b/docs/howto/custom-model-fields.txt index 6474742222..f35cc41d3d 100644 --- a/docs/howto/custom-model-fields.txt +++ b/docs/howto/custom-model-fields.txt @@ -483,7 +483,7 @@ instances:: class HandField(models.Field): # ... - def from_db_value(self, value, connection): + def from_db_value(self, value, connection, context): if value is None: return value return parse_hand(value) diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index defcdb0217..126455aba4 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -399,7 +399,7 @@ calling the appropriate methods on the wrapped expression. clone.expression = self.expression.relabeled_clone(change_map) return clone - .. method:: convert_value(self, value, connection) + .. method:: convert_value(self, value, connection, context) A hook allowing the expression to coerce ``value`` into a more appropriate type. diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index bc214fae00..8604137afe 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -1670,7 +1670,7 @@ Field API reference When loading data, :meth:`from_db_value` is used: - .. method:: from_db_value(value, connection) + .. method:: from_db_value(value, connection, context) .. versionadded:: 1.8 diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 917137cce1..ec19fcfd53 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -679,7 +679,7 @@ class BaseAggregateTestCase(TestCase): # the only "ORDER BY" clause present in the query. self.assertEqual( re.findall(r'order by (\w+)', qstr), - [', '.join(forced_ordering).lower()] + [', '.join(f[1][0] for f in forced_ordering).lower()] ) else: self.assertNotIn('order by', qstr) diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index de18372f9c..8d2a006137 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -490,9 +490,10 @@ class AggregationTests(TestCase): # Regression for #15709 - Ensure each group_by field only exists once # per query - qs = Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by() - grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping([], []) - self.assertEqual(len(grouping), 1) + qstr = str(Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by().query) + # Check that there is just one GROUP BY clause (zero commas means at + # most one clause) + self.assertEqual(qstr[qstr.index('GROUP BY'):].count(', '), 0) def test_duplicate_alias(self): # Regression for #11256 - duplicating a default alias raises ValueError. @@ -924,14 +925,11 @@ class AggregationTests(TestCase): # There should only be one GROUP BY clause, for the `id` column. # `name` and `age` should not be grouped on. - grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], []) - self.assertEqual(len(grouping), 1) - assert 'id' in grouping[0] - assert 'name' not in grouping[0] - assert 'age' not in grouping[0] - - # The query group_by property should also only show the `id`. - self.assertEqual(results.query.group_by, [('aggregation_regress_author', 'id')]) + _, _, group_by = results.query.get_compiler(using='default').pre_sql_setup() + self.assertEqual(len(group_by), 1) + self.assertIn('id', group_by[0][0]) + self.assertNotIn('name', group_by[0][0]) + self.assertNotIn('age', group_by[0][0]) # Ensure that we get correct results. self.assertEqual( @@ -953,14 +951,11 @@ class AggregationTests(TestCase): def test_aggregate_duplicate_columns_only(self): # Works with only() too. results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set')) - grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], []) + _, _, grouping = results.query.get_compiler(using='default').pre_sql_setup() self.assertEqual(len(grouping), 1) - assert 'id' in grouping[0] - assert 'name' not in grouping[0] - assert 'age' not in grouping[0] - - # The query group_by property should also only show the `id`. - self.assertEqual(results.query.group_by, [('aggregation_regress_author', 'id')]) + self.assertIn('id', grouping[0][0]) + self.assertNotIn('name', grouping[0][0]) + self.assertNotIn('age', grouping[0][0]) # Ensure that we get correct results. self.assertEqual( @@ -983,14 +978,11 @@ class AggregationTests(TestCase): # And select_related() results = Book.objects.select_related('contact').annotate( num_authors=Count('authors')) - grouping, gb_params = results.query.get_compiler(using='default').get_grouping([], []) + _, _, grouping = results.query.get_compiler(using='default').pre_sql_setup() self.assertEqual(len(grouping), 1) - assert 'id' in grouping[0] - assert 'name' not in grouping[0] - assert 'contact' not in grouping[0] - - # The query group_by property should also only show the `id`. - self.assertEqual(results.query.group_by, [('aggregation_regress_book', 'id')]) + self.assertIn('id', grouping[0][0]) + self.assertNotIn('name', grouping[0][0]) + self.assertNotIn('contact', grouping[0][0]) # Ensure that we get correct results. self.assertEqual( diff --git a/tests/custom_pk/fields.py b/tests/custom_pk/fields.py index 1f3265952a..bf349545e5 100644 --- a/tests/custom_pk/fields.py +++ b/tests/custom_pk/fields.py @@ -43,7 +43,7 @@ class MyAutoField(models.CharField): value = MyWrapper(value) return value - def from_db_value(self, value, connection): + def from_db_value(self, value, connection, context): if not value: return return MyWrapper(value) diff --git a/tests/defer_regress/models.py b/tests/defer_regress/models.py index 58693c83db..a81b02058d 100644 --- a/tests/defer_regress/models.py +++ b/tests/defer_regress/models.py @@ -96,3 +96,11 @@ class Request(models.Model): request2 = models.CharField(default='request2', max_length=1000) request3 = models.CharField(default='request3', max_length=1000) request4 = models.CharField(default='request4', max_length=1000) + + +class Base(models.Model): + text = models.TextField() + + +class Derived(Base): + other_text = models.TextField() diff --git a/tests/defer_regress/tests.py b/tests/defer_regress/tests.py index f2a015a7a1..1222212d8d 100644 --- a/tests/defer_regress/tests.py +++ b/tests/defer_regress/tests.py @@ -12,7 +12,7 @@ from django.test import TestCase, override_settings from .models import ( ResolveThis, Item, RelatedItem, Child, Leaf, Proxy, SimpleItem, Feature, ItemAndSimpleItem, OneToOneItem, SpecialFeature, Location, Request, - ProxyRelated, + ProxyRelated, Derived, Base, ) @@ -145,6 +145,15 @@ class DeferRegressionTest(TestCase): list(SimpleItem.objects.annotate(Count('feature')).only('name')), list) + def test_ticket_23270(self): + Derived.objects.create(text="foo", other_text="bar") + with self.assertNumQueries(1): + obj = Base.objects.select_related("derived").defer("text")[0] + self.assertIsInstance(obj.derived, Derived) + self.assertEqual("bar", obj.derived.other_text) + self.assertNotIn("text", obj.__dict__) + self.assertEqual(1, obj.derived.base_ptr_id) + def test_only_and_defer_usage_on_proxy_models(self): # Regression for #15790 - only() broken for proxy models proxy = Proxy.objects.create(name="proxy", value=42) diff --git a/tests/from_db_value/models.py b/tests/from_db_value/models.py index 4cc9e62168..6a06c832ea 100644 --- a/tests/from_db_value/models.py +++ b/tests/from_db_value/models.py @@ -18,7 +18,7 @@ class CashField(models.DecimalField): kwargs['decimal_places'] = 2 super(CashField, self).__init__(**kwargs) - def from_db_value(self, value, connection): + def from_db_value(self, value, connection, context): cash = Cash(value) cash.vendor = connection.vendor return cash diff --git a/tests/queries/tests.py b/tests/queries/tests.py index 900149b873..92e5982e10 100644 --- a/tests/queries/tests.py +++ b/tests/queries/tests.py @@ -3563,8 +3563,10 @@ class Ticket20955Tests(TestCase): # version's queries. task_get.creator.staffuser.staff task_get.owner.staffuser.staff - task_select_related = Task.objects.select_related( - 'creator__staffuser__staff', 'owner__staffuser__staff').get(pk=task.pk) + qs = Task.objects.select_related( + 'creator__staffuser__staff', 'owner__staffuser__staff') + self.assertEqual(str(qs.query).count(' JOIN '), 6) + task_select_related = qs.get(pk=task.pk) with self.assertNumQueries(0): self.assertEqual(task_select_related.creator.staffuser.staff, task_get.creator.staffuser.staff) diff --git a/tests/select_related_onetoone/tests.py b/tests/select_related_onetoone/tests.py index a13ac7809a..a6c0f3a42b 100644 --- a/tests/select_related_onetoone/tests.py +++ b/tests/select_related_onetoone/tests.py @@ -83,6 +83,7 @@ class ReverseSelectRelatedTestCase(TestCase): stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) self.assertEqual(stat.advanceduserstat.posts, 200) self.assertEqual(stat.user.username, 'bob') + with self.assertNumQueries(1): self.assertEqual(stat.advanceduserstat.user.username, 'bob') def test_nullable_relation(self): diff --git a/tests/serializers/models.py b/tests/serializers/models.py index 2d1ab9758b..b2864b1c71 100644 --- a/tests/serializers/models.py +++ b/tests/serializers/models.py @@ -112,7 +112,7 @@ class TeamField(models.CharField): return value return Team(value) - def from_db_value(self, value, connection): + def from_db_value(self, value, connection, context): return Team(value) def value_to_string(self, obj):