Fixed #9871 -- Geometry objects are now returned in dictionaries and tuples returned by `values()` and `values_list()`, respectively; updated `GeoQuery` methods to be compatible with `defer()` and `only`; removed defunct `GeomSQL` class; and removed redundant logic from `Query.get_default_columns`.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10326 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Justin Bronn 2009-04-01 16:01:50 +00:00
parent f1c64816bb
commit 03de1fe5f4
5 changed files with 98 additions and 35 deletions

View File

@ -1,6 +1,6 @@
from django.core.exceptions import ImproperlyConfigured
from django.db import connection
from django.db.models.query import sql, QuerySet, Q
from django.db.models.query import QuerySet, Q, ValuesQuerySet, ValuesListQuerySet
from django.contrib.gis.db.backend import SpatialBackend
from django.contrib.gis.db.models import aggregates
@ -9,21 +9,28 @@ from django.contrib.gis.db.models.sql import AreaField, DistanceField, GeomField
from django.contrib.gis.measure import Area, Distance
from django.contrib.gis.models import get_srid_info
class GeomSQL(object):
"Simple wrapper object for geometric SQL."
def __init__(self, geo_sql):
self.sql = geo_sql
def as_sql(self, *args, **kwargs):
return self.sql
class GeoQuerySet(QuerySet):
"The Geographic QuerySet."
### Methods overloaded from QuerySet ###
def __init__(self, model=None, query=None):
super(GeoQuerySet, self).__init__(model=model, query=query)
self.query = query or GeoQuery(self.model, connection)
def values(self, *fields):
return self._clone(klass=GeoValuesQuerySet, setup=True, _fields=fields)
def values_list(self, *fields, **kwargs):
flat = kwargs.pop('flat', False)
if kwargs:
raise TypeError('Unexpected keyword arguments to values_list: %s'
% (kwargs.keys(),))
if flat and len(fields) > 1:
raise TypeError("'flat' is not valid when values_list is called with more than one field.")
return self._clone(klass=GeoValuesListQuerySet, setup=True, flat=flat,
_fields=fields)
### GeoQuerySet Methods ###
def area(self, tolerance=0.05, **kwargs):
"""
Returns the area of the geographic field in an `area` attribute on
@ -592,3 +599,14 @@ class GeoQuerySet(QuerySet):
return self.query._field_column(geo_field, parent_model._meta.db_table)
else:
return self.query._field_column(geo_field)
class GeoValuesQuerySet(ValuesQuerySet):
def __init__(self, *args, **kwargs):
super(GeoValuesQuerySet, self).__init__(*args, **kwargs)
# This flag tells `resolve_columns` to run the values through
# `convert_values`. This ensures that Geometry objects instead
# of string values are returned with `values()` or `values_list()`.
self.query.geo_values = True
class GeoValuesListQuerySet(GeoValuesQuerySet, ValuesListQuerySet):
pass

View File

@ -14,6 +14,8 @@ from django.contrib.gis.measure import Area, Distance
ALL_TERMS = sql.constants.QUERY_TERMS.copy()
ALL_TERMS.update(SpatialBackend.gis_terms)
TABLE_NAME = sql.constants.TABLE_NAME
class GeoQuery(sql.Query):
"""
A single spatial SQL query.
@ -64,10 +66,15 @@ class GeoQuery(sql.Query):
else:
col_aliases = set()
if self.select:
only_load = self.deferred_to_columns()
# This loop customized for GeoQuery.
for col, field in izip(self.select, self.select_fields):
if isinstance(col, (list, tuple)):
r = self.get_field_select(field, col[0])
alias, column = col
table = self.alias_map[alias][TABLE_NAME]
if table in only_load and col not in only_load[table]:
continue
r = self.get_field_select(field, alias)
if with_aliases:
if col[1] in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
@ -75,7 +82,7 @@ class GeoQuery(sql.Query):
aliases.add(c_alias)
col_aliases.add(c_alias)
else:
result.append('%s AS %s' % (r, col[1]))
result.append('%s AS %s' % (r, qn2(col[1])))
aliases.add(r)
col_aliases.add(col[1])
else:
@ -123,10 +130,14 @@ class GeoQuery(sql.Query):
start_alias=None, opts=None, as_pairs=False):
"""
Computes the default columns for selecting every field in the base
model.
model. Will sometimes be called to pull in related models (e.g. via
select_related), in which case "opts" and "start_alias" will be given
to provide a starting point for the traversal.
Returns a list of strings, quoted appropriately for use in SQL
directly, as well as a set of aliases used in the select statement.
directly, as well as a set of aliases used in the select statement (if
'as_pairs' is True, returns a list of (alias, col_name) pairs instead
of strings as the first component and None as the second component).
This routine is overridden from Query to handle customized selection of
geometry columns.
@ -134,22 +145,34 @@ class GeoQuery(sql.Query):
result = []
if opts is None:
opts = self.model._meta
if start_alias:
table_alias = start_alias
else:
table_alias = self.tables[0]
root_pk = opts.pk.column
seen = {None: table_alias}
aliases = set()
only_load = self.deferred_to_columns()
proxied_model = opts.proxy and opts.proxy_for_model or 0
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model():
if start_alias:
try:
alias = seen[model]
except KeyError:
alias = self.join((table_alias, model._meta.db_table,
root_pk, model._meta.pk.column))
if model is proxied_model:
alias = start_alias
else:
link_field = opts.get_ancestor_link(model)
alias = self.join((start_alias, model._meta.db_table,
link_field.column, model._meta.pk.column))
seen[model] = alias
else:
# If we're starting from the base model of the queryset, the
# aliases will have already been set up in pre_sql_setup(), so
# we can save time here.
alias = self.included_inherited_models[model]
table = self.alias_map[alias][TABLE_NAME]
if table in only_load and field.column not in only_load[table]:
continue
if as_pairs:
result.append((alias, field.column))
aliases.add(alias)
continue
# This part of the function is customized for GeoQuery. We
# see if there was any custom selection specified in the
@ -166,8 +189,6 @@ class GeoQuery(sql.Query):
aliases.add(r)
if with_aliases:
col_aliases.add(field.column)
if as_pairs:
return result, None
return result, aliases
def resolve_columns(self, row, fields=()):
@ -191,8 +212,8 @@ class GeoQuery(sql.Query):
# distance objects added by GeoQuerySet methods).
values = [self.convert_values(v, self.extra_select_fields.get(a, None))
for v, a in izip(row[rn_offset:index_start], aliases)]
if SpatialBackend.oracle:
# This is what happens normally in OracleQuery's `resolve_columns`.
if SpatialBackend.oracle or getattr(self, 'geo_values', False):
# We resolve the columns
for value, field in izip(row[index_start:], fields):
values.append(self.convert_values(value, field))
else:
@ -215,7 +236,7 @@ class GeoQuery(sql.Query):
value = Distance(**{field.distance_att : value})
elif isinstance(field, AreaField):
value = Area(**{field.area_att : value})
elif isinstance(field, GeomField) and value:
elif isinstance(field, (GeomField, GeometryField)) and value:
value = SpatialBackend.Geometry(value)
return value

View File

@ -2,15 +2,16 @@ from django.contrib.gis.db import models
from django.contrib.localflavor.us.models import USStateField
class Location(models.Model):
name = models.CharField(max_length=50)
point = models.PointField()
objects = models.GeoManager()
def __unicode__(self): return self.point.wkt
class City(models.Model):
name = models.CharField(max_length=50)
state = USStateField()
location = models.ForeignKey(Location)
objects = models.GeoManager()
def __unicode__(self): return self.name
class AugmentedLocation(Location):
extra_text = models.TextField(blank=True)

View File

@ -118,7 +118,7 @@ class RelatedGeoModelTest(unittest.TestCase):
# Regression test for #9752.
l = list(DirectoryEntry.objects.all().select_related())
def test6_f_expressions(self):
def test06_f_expressions(self):
"Testing F() expressions on GeometryFields."
# Constructing a dummy parcel border and getting the City instance for
# assigning the FK.
@ -166,6 +166,31 @@ class RelatedGeoModelTest(unittest.TestCase):
self.assertEqual(1, len(qs))
self.assertEqual('P1', qs[0].name)
def test07_values(self):
"Testing values() and values_list() and GeoQuerySets."
# GeoQuerySet and GeoValuesQuerySet, and GeoValuesListQuerySet respectively.
gqs = Location.objects.all()
gvqs = Location.objects.values()
gvlqs = Location.objects.values_list()
# Incrementing through each of the models, dictionaries, and tuples
# returned by the different types of GeoQuerySets.
for m, d, t in zip(gqs, gvqs, gvlqs):
# The values should be Geometry objects and not raw strings returned
# by the spatial database.
self.failUnless(isinstance(d['point'], SpatialBackend.Geometry))
self.failUnless(isinstance(t[1], SpatialBackend.Geometry))
self.assertEqual(m.point, d['point'])
self.assertEqual(m.point, t[1])
# Test disabled until #10572 is resolved.
#def test08_defer_only(self):
# "Testing defer() and only() on Geographic models."
# qs = Location.objects.all()
# def_qs = Location.objects.defer('point')
# for loc, def_loc in zip(qs, def_qs):
# self.assertEqual(loc.point, def_loc.point)
# TODO: Related tests for KML, GML, and distance lookups.
def suite():

View File

@ -784,8 +784,6 @@ class BaseQuery(object):
aliases.add(r)
if with_aliases:
col_aliases.add(field.column)
if as_pairs:
return result, aliases
return result, aliases
def get_from_clause(self):