Fixed #9871 -- Geometry objects are now returned in dictionaries and tuples returned by `values()` and `values_list()`, respectively; updated `GeoQuery` methods to be compatible with `defer()` and `only`; removed defunct `GeomSQL` class; and removed redundant logic from `Query.get_default_columns`.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10326 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Justin Bronn 2009-04-01 16:01:50 +00:00
parent f1c64816bb
commit 03de1fe5f4
5 changed files with 98 additions and 35 deletions

View File

@ -1,6 +1,6 @@
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import connection from django.db import connection
from django.db.models.query import sql, QuerySet, Q from django.db.models.query import QuerySet, Q, ValuesQuerySet, ValuesListQuerySet
from django.contrib.gis.db.backend import SpatialBackend from django.contrib.gis.db.backend import SpatialBackend
from django.contrib.gis.db.models import aggregates from django.contrib.gis.db.models import aggregates
@ -9,21 +9,28 @@ from django.contrib.gis.db.models.sql import AreaField, DistanceField, GeomField
from django.contrib.gis.measure import Area, Distance from django.contrib.gis.measure import Area, Distance
from django.contrib.gis.models import get_srid_info from django.contrib.gis.models import get_srid_info
class GeomSQL(object):
"Simple wrapper object for geometric SQL."
def __init__(self, geo_sql):
self.sql = geo_sql
def as_sql(self, *args, **kwargs):
return self.sql
class GeoQuerySet(QuerySet): class GeoQuerySet(QuerySet):
"The Geographic QuerySet." "The Geographic QuerySet."
### Methods overloaded from QuerySet ###
def __init__(self, model=None, query=None): def __init__(self, model=None, query=None):
super(GeoQuerySet, self).__init__(model=model, query=query) super(GeoQuerySet, self).__init__(model=model, query=query)
self.query = query or GeoQuery(self.model, connection) self.query = query or GeoQuery(self.model, connection)
def values(self, *fields):
return self._clone(klass=GeoValuesQuerySet, setup=True, _fields=fields)
def values_list(self, *fields, **kwargs):
flat = kwargs.pop('flat', False)
if kwargs:
raise TypeError('Unexpected keyword arguments to values_list: %s'
% (kwargs.keys(),))
if flat and len(fields) > 1:
raise TypeError("'flat' is not valid when values_list is called with more than one field.")
return self._clone(klass=GeoValuesListQuerySet, setup=True, flat=flat,
_fields=fields)
### GeoQuerySet Methods ###
def area(self, tolerance=0.05, **kwargs): def area(self, tolerance=0.05, **kwargs):
""" """
Returns the area of the geographic field in an `area` attribute on Returns the area of the geographic field in an `area` attribute on
@ -592,3 +599,14 @@ class GeoQuerySet(QuerySet):
return self.query._field_column(geo_field, parent_model._meta.db_table) return self.query._field_column(geo_field, parent_model._meta.db_table)
else: else:
return self.query._field_column(geo_field) return self.query._field_column(geo_field)
class GeoValuesQuerySet(ValuesQuerySet):
def __init__(self, *args, **kwargs):
super(GeoValuesQuerySet, self).__init__(*args, **kwargs)
# This flag tells `resolve_columns` to run the values through
# `convert_values`. This ensures that Geometry objects instead
# of string values are returned with `values()` or `values_list()`.
self.query.geo_values = True
class GeoValuesListQuerySet(GeoValuesQuerySet, ValuesListQuerySet):
pass

View File

@ -14,6 +14,8 @@ from django.contrib.gis.measure import Area, Distance
ALL_TERMS = sql.constants.QUERY_TERMS.copy() ALL_TERMS = sql.constants.QUERY_TERMS.copy()
ALL_TERMS.update(SpatialBackend.gis_terms) ALL_TERMS.update(SpatialBackend.gis_terms)
TABLE_NAME = sql.constants.TABLE_NAME
class GeoQuery(sql.Query): class GeoQuery(sql.Query):
""" """
A single spatial SQL query. A single spatial SQL query.
@ -64,10 +66,15 @@ class GeoQuery(sql.Query):
else: else:
col_aliases = set() col_aliases = set()
if self.select: if self.select:
only_load = self.deferred_to_columns()
# This loop customized for GeoQuery. # This loop customized for GeoQuery.
for col, field in izip(self.select, self.select_fields): for col, field in izip(self.select, self.select_fields):
if isinstance(col, (list, tuple)): if isinstance(col, (list, tuple)):
r = self.get_field_select(field, col[0]) alias, column = col
table = self.alias_map[alias][TABLE_NAME]
if table in only_load and col not in only_load[table]:
continue
r = self.get_field_select(field, alias)
if with_aliases: if with_aliases:
if col[1] in col_aliases: if col[1] in col_aliases:
c_alias = 'Col%d' % len(col_aliases) c_alias = 'Col%d' % len(col_aliases)
@ -75,7 +82,7 @@ class GeoQuery(sql.Query):
aliases.add(c_alias) aliases.add(c_alias)
col_aliases.add(c_alias) col_aliases.add(c_alias)
else: else:
result.append('%s AS %s' % (r, col[1])) result.append('%s AS %s' % (r, qn2(col[1])))
aliases.add(r) aliases.add(r)
col_aliases.add(col[1]) col_aliases.add(col[1])
else: else:
@ -123,10 +130,14 @@ class GeoQuery(sql.Query):
start_alias=None, opts=None, as_pairs=False): start_alias=None, opts=None, as_pairs=False):
""" """
Computes the default columns for selecting every field in the base Computes the default columns for selecting every field in the base
model. model. Will sometimes be called to pull in related models (e.g. via
select_related), in which case "opts" and "start_alias" will be given
to provide a starting point for the traversal.
Returns a list of strings, quoted appropriately for use in SQL Returns a list of strings, quoted appropriately for use in SQL
directly, as well as a set of aliases used in the select statement. directly, as well as a set of aliases used in the select statement (if
'as_pairs' is True, returns a list of (alias, col_name) pairs instead
of strings as the first component and None as the second component).
This routine is overridden from Query to handle customized selection of This routine is overridden from Query to handle customized selection of
geometry columns. geometry columns.
@ -134,22 +145,34 @@ class GeoQuery(sql.Query):
result = [] result = []
if opts is None: if opts is None:
opts = self.model._meta opts = self.model._meta
if start_alias:
table_alias = start_alias
else:
table_alias = self.tables[0]
root_pk = opts.pk.column
seen = {None: table_alias}
aliases = set() aliases = set()
only_load = self.deferred_to_columns()
proxied_model = opts.proxy and opts.proxy_for_model or 0
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model(): for field, model in opts.get_fields_with_model():
if start_alias:
try: try:
alias = seen[model] alias = seen[model]
except KeyError: except KeyError:
alias = self.join((table_alias, model._meta.db_table, if model is proxied_model:
root_pk, model._meta.pk.column)) alias = start_alias
else:
link_field = opts.get_ancestor_link(model)
alias = self.join((start_alias, model._meta.db_table,
link_field.column, model._meta.pk.column))
seen[model] = alias seen[model] = alias
else:
# If we're starting from the base model of the queryset, the
# aliases will have already been set up in pre_sql_setup(), so
# we can save time here.
alias = self.included_inherited_models[model]
table = self.alias_map[alias][TABLE_NAME]
if table in only_load and field.column not in only_load[table]:
continue
if as_pairs: if as_pairs:
result.append((alias, field.column)) result.append((alias, field.column))
aliases.add(alias)
continue continue
# This part of the function is customized for GeoQuery. We # This part of the function is customized for GeoQuery. We
# see if there was any custom selection specified in the # see if there was any custom selection specified in the
@ -166,8 +189,6 @@ class GeoQuery(sql.Query):
aliases.add(r) aliases.add(r)
if with_aliases: if with_aliases:
col_aliases.add(field.column) col_aliases.add(field.column)
if as_pairs:
return result, None
return result, aliases return result, aliases
def resolve_columns(self, row, fields=()): def resolve_columns(self, row, fields=()):
@ -191,8 +212,8 @@ class GeoQuery(sql.Query):
# distance objects added by GeoQuerySet methods). # distance objects added by GeoQuerySet methods).
values = [self.convert_values(v, self.extra_select_fields.get(a, None)) values = [self.convert_values(v, self.extra_select_fields.get(a, None))
for v, a in izip(row[rn_offset:index_start], aliases)] for v, a in izip(row[rn_offset:index_start], aliases)]
if SpatialBackend.oracle: if SpatialBackend.oracle or getattr(self, 'geo_values', False):
# This is what happens normally in OracleQuery's `resolve_columns`. # We resolve the columns
for value, field in izip(row[index_start:], fields): for value, field in izip(row[index_start:], fields):
values.append(self.convert_values(value, field)) values.append(self.convert_values(value, field))
else: else:
@ -215,7 +236,7 @@ class GeoQuery(sql.Query):
value = Distance(**{field.distance_att : value}) value = Distance(**{field.distance_att : value})
elif isinstance(field, AreaField): elif isinstance(field, AreaField):
value = Area(**{field.area_att : value}) value = Area(**{field.area_att : value})
elif isinstance(field, GeomField) and value: elif isinstance(field, (GeomField, GeometryField)) and value:
value = SpatialBackend.Geometry(value) value = SpatialBackend.Geometry(value)
return value return value

View File

@ -2,15 +2,16 @@ from django.contrib.gis.db import models
from django.contrib.localflavor.us.models import USStateField from django.contrib.localflavor.us.models import USStateField
class Location(models.Model): class Location(models.Model):
name = models.CharField(max_length=50)
point = models.PointField() point = models.PointField()
objects = models.GeoManager() objects = models.GeoManager()
def __unicode__(self): return self.point.wkt
class City(models.Model): class City(models.Model):
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
state = USStateField() state = USStateField()
location = models.ForeignKey(Location) location = models.ForeignKey(Location)
objects = models.GeoManager() objects = models.GeoManager()
def __unicode__(self): return self.name
class AugmentedLocation(Location): class AugmentedLocation(Location):
extra_text = models.TextField(blank=True) extra_text = models.TextField(blank=True)

View File

@ -118,7 +118,7 @@ class RelatedGeoModelTest(unittest.TestCase):
# Regression test for #9752. # Regression test for #9752.
l = list(DirectoryEntry.objects.all().select_related()) l = list(DirectoryEntry.objects.all().select_related())
def test6_f_expressions(self): def test06_f_expressions(self):
"Testing F() expressions on GeometryFields." "Testing F() expressions on GeometryFields."
# Constructing a dummy parcel border and getting the City instance for # Constructing a dummy parcel border and getting the City instance for
# assigning the FK. # assigning the FK.
@ -166,6 +166,31 @@ class RelatedGeoModelTest(unittest.TestCase):
self.assertEqual(1, len(qs)) self.assertEqual(1, len(qs))
self.assertEqual('P1', qs[0].name) self.assertEqual('P1', qs[0].name)
def test07_values(self):
"Testing values() and values_list() and GeoQuerySets."
# GeoQuerySet and GeoValuesQuerySet, and GeoValuesListQuerySet respectively.
gqs = Location.objects.all()
gvqs = Location.objects.values()
gvlqs = Location.objects.values_list()
# Incrementing through each of the models, dictionaries, and tuples
# returned by the different types of GeoQuerySets.
for m, d, t in zip(gqs, gvqs, gvlqs):
# The values should be Geometry objects and not raw strings returned
# by the spatial database.
self.failUnless(isinstance(d['point'], SpatialBackend.Geometry))
self.failUnless(isinstance(t[1], SpatialBackend.Geometry))
self.assertEqual(m.point, d['point'])
self.assertEqual(m.point, t[1])
# Test disabled until #10572 is resolved.
#def test08_defer_only(self):
# "Testing defer() and only() on Geographic models."
# qs = Location.objects.all()
# def_qs = Location.objects.defer('point')
# for loc, def_loc in zip(qs, def_qs):
# self.assertEqual(loc.point, def_loc.point)
# TODO: Related tests for KML, GML, and distance lookups. # TODO: Related tests for KML, GML, and distance lookups.
def suite(): def suite():

View File

@ -784,8 +784,6 @@ class BaseQuery(object):
aliases.add(r) aliases.add(r)
if with_aliases: if with_aliases:
col_aliases.add(field.column) col_aliases.add(field.column)
if as_pairs:
return result, aliases
return result, aliases return result, aliases
def get_from_clause(self): def get_from_clause(self):