diff --git a/django/contrib/gis/db/models/sql/query.py b/django/contrib/gis/db/models/sql/query.py index 928f209329..712316e434 100644 --- a/django/contrib/gis/db/models/sql/query.py +++ b/django/contrib/gis/db/models/sql/query.py @@ -278,40 +278,6 @@ class GeoQuery(sql.Query): return sel_fmt # Private API utilities, subject to change. - def _check_geo_field(self, model, name_param): - """ - Recursive utility routine for checking the given name parameter - on the given model. Initially, the name parameter is a string, - of the field on the given model e.g., 'point', 'the_geom'. - Related model field strings like 'address__point', may also be - used. - - If a GeometryField exists according to the given name parameter - it will be returned, otherwise returns False. - """ - if isinstance(name_param, basestring): - # This takes into account the situation where the name is a - # lookup to a related geographic field, e.g., 'address__point'. - name_param = name_param.split(sql.constants.LOOKUP_SEP) - name_param.reverse() # Reversing so list operates like a queue of related lookups. - elif not isinstance(name_param, list): - raise TypeError - try: - # Getting the name of the field for the model (by popping the first - # name from the `name_param` list created above). - fld, mod, direct, m2m = model._meta.get_field_by_name(name_param.pop()) - except (FieldDoesNotExist, IndexError): - return False - # TODO: ManyToManyField? - if isinstance(fld, GeometryField): - return fld # A-OK. - elif isinstance(fld, ForeignKey): - # ForeignKey encountered, return the output of this utility called - # on the _related_ model with the remaining name parameters. - return self._check_geo_field(fld.rel.to, name_param) # Recurse to check ForeignKey relation. - else: - return False - def _field_column(self, field, table_alias=None): """ Helper function that returns the database column for the given field. @@ -339,4 +305,4 @@ class GeoQuery(sql.Query): else: # Otherwise, check by the given field name -- which may be # a lookup to a _related_ geographic field. - return self._check_geo_field(self.model, field_name) + return GeoWhereNode._check_geo_field(self.model._meta, field_name) diff --git a/django/contrib/gis/db/models/sql/where.py b/django/contrib/gis/db/models/sql/where.py index 5db845f0b9..52cf5cd9b4 100644 --- a/django/contrib/gis/db/models/sql/where.py +++ b/django/contrib/gis/db/models/sql/where.py @@ -1,7 +1,11 @@ -import datetime +from django.db import connection from django.db.models.fields import Field +from django.db.models.sql.constants import LOOKUP_SEP +from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.where import WhereNode from django.contrib.gis.db.backend import get_geo_where_clause, SpatialBackend +from django.contrib.gis.db.models.fields import GeometryField +qn = connection.ops.quote_name class GeoAnnotation(object): """ @@ -37,9 +41,35 @@ class GeoWhereNode(WhereNode): # Not a geographic field, so call `WhereNode.add`. return super(GeoWhereNode, self).add(data, connector) else: - # `GeometryField.get_db_prep_lookup` returns a where clause - # substitution array in addition to the parameters. - where, params = field.get_db_prep_lookup(lookup_type, value) + if isinstance(value, SQLEvaluator): + # Getting the geographic field to compare with from the expression. + geo_fld = self._check_geo_field(value.opts, value.expression.name) + if not geo_fld: + raise ValueError('No geographic field found in expression.') + + # Get the SRID of the geometry field that the expression was meant + # to operate on -- it's needed to determine whether transformation + # SQL is necessary. + srid = geo_fld._srid + + # Getting the quoted representation of the geometry column that + # the expression is operating on. + geo_col = '%s.%s' % tuple(map(qn, value.cols[value.expression])) + + # If it's in a different SRID, we'll need to wrap in + # transformation SQL. + if not srid is None and srid != field._srid and SpatialBackend.transform: + placeholder = '%s(%%s, %s)' % (SpatialBackend.transform, field._srid) + else: + placeholder = '%s' + + # Setting these up as if we had called `field.get_db_prep_lookup()`. + where = [placeholder % geo_col] + params = () + else: + # `GeometryField.get_db_prep_lookup` returns a where clause + # substitution array in addition to the parameters. + where, params = field.get_db_prep_lookup(lookup_type, value) # The annotation will be a `GeoAnnotation` object that # will contain the necessary geometry field metadata for @@ -64,3 +94,42 @@ class GeoWhereNode(WhereNode): # If not a GeometryField, call the `make_atom` from the # base class. return super(GeoWhereNode, self).make_atom(child, qn) + + @classmethod + def _check_geo_field(cls, opts, lookup): + """ + Utility for checking the given lookup with the given model options. + The lookup is a string either specifying the geographic field, e.g. + 'point, 'the_geom', or a related lookup on a geographic field like + 'address__point'. + + If a GeometryField exists according to the given lookup on the model + options, it will be returned. Otherwise returns None. + """ + # This takes into account the situation where the lookup is a + # lookup to a related geographic field, e.g., 'address__point'. + field_list = lookup.split(LOOKUP_SEP) + + # Reversing so list operates like a queue of related lookups, + # and popping the top lookup. + field_list.reverse() + fld_name = field_list.pop() + + try: + geo_fld = opts.get_field(fld_name) + # If the field list is still around, then it means that the + # lookup was for a geometry field across a relationship -- + # thus we keep on getting the related model options and the + # model field associated with the next field in the list + # until there's no more left. + while len(field_list): + opts = geo_fld.rel.to._meta + geo_fld = opts.get_field(field_list.pop()) + except (FieldDoesNotExist, AttributeError): + return False + + # Finally, make sure we got a Geographic field and return. + if isinstance(geo_fld, GeometryField): + return geo_fld + else: + return False diff --git a/django/contrib/gis/tests/__init__.py b/django/contrib/gis/tests/__init__.py index 19d305e3bd..53674afa62 100644 --- a/django/contrib/gis/tests/__init__.py +++ b/django/contrib/gis/tests/__init__.py @@ -21,12 +21,7 @@ def geo_suite(): 'test_measure', ] if HAS_GDAL: - if oracle: - # TODO: There's a problem with `select_related` and GeoQuerySet on - # Oracle -- e.g., GeoModel.objects.distance(geom, field_name='fk__point') - # doesn't work so we don't test `relatedapp`. - test_models += ['distapp', 'layermap'] - elif postgis: + if oracle or postgis: test_models += ['distapp', 'layermap', 'relatedapp'] elif mysql: test_models += ['relatedapp', 'layermap'] diff --git a/django/contrib/gis/tests/relatedapp/models.py b/django/contrib/gis/tests/relatedapp/models.py index 5db6baa1f3..8ea1469b3e 100644 --- a/django/contrib/gis/tests/relatedapp/models.py +++ b/django/contrib/gis/tests/relatedapp/models.py @@ -20,3 +20,14 @@ class DirectoryEntry(models.Model): listing_text = models.CharField(max_length=50) location = models.ForeignKey(AugmentedLocation) objects = models.GeoManager() + +class Parcel(models.Model): + name = models.CharField(max_length=30) + city = models.ForeignKey(City) + center1 = models.PointField() + # Throwing a curveball w/`db_column` here. + center2 = models.PointField(srid=2276, db_column='mycenter') + border1 = models.PolygonField() + border2 = models.PolygonField(srid=2276) + objects = models.GeoManager() + def __unicode__(self): return self.name diff --git a/django/contrib/gis/tests/relatedapp/tests.py b/django/contrib/gis/tests/relatedapp/tests.py index 47d49c73e9..0cde7cc6a7 100644 --- a/django/contrib/gis/tests/relatedapp/tests.py +++ b/django/contrib/gis/tests/relatedapp/tests.py @@ -1,8 +1,10 @@ import os, unittest from django.contrib.gis.geos import * -from django.contrib.gis.tests.utils import no_mysql, postgis +from django.contrib.gis.db.backend import SpatialBackend +from django.contrib.gis.db.models import F, Extent, Union +from django.contrib.gis.tests.utils import no_mysql, no_oracle from django.conf import settings -from models import City, Location, DirectoryEntry +from models import City, Location, DirectoryEntry, Parcel cities = (('Aurora', 'TX', -97.516111, 33.058333), ('Roswell', 'NM', -104.528056, 33.387222), @@ -14,11 +16,10 @@ class RelatedGeoModelTest(unittest.TestCase): def test01_setup(self): "Setting up for related model tests." for name, state, lon, lat in cities: - loc = Location(point=Point(lon, lat)) - loc.save() - c = City(name=name, state=state, location=loc) - c.save() - + loc = Location.objects.create(point=Point(lon, lat)) + c = City.objects.create(name=name, state=state, location=loc) + + @no_oracle # TODO: Fix select_related() problems w/Oracle and pagination. def test02_select_related(self): "Testing `select_related` on geographic models (see #7126)." qs1 = City.objects.all() @@ -33,28 +34,21 @@ class RelatedGeoModelTest(unittest.TestCase): self.assertEqual(Point(lon, lat), c.location.point) @no_mysql + @no_oracle # Pagination problem is implicated in this test as well. def test03_transform_related(self): "Testing the `transform` GeoQuerySet method on related geographic models." # All the transformations are to state plane coordinate systems using # US Survey Feet (thus a tolerance of 0 implies error w/in 1 survey foot). - if postgis: + if SpatialBackend.postgis: tol = 3 - nqueries = 4 # +1 for `postgis_lib_version` else: tol = 0 - nqueries = 3 def check_pnt(ref, pnt): self.assertAlmostEqual(ref.x, pnt.x, tol) self.assertAlmostEqual(ref.y, pnt.y, tol) self.assertEqual(ref.srid, pnt.srid) - # Turning on debug so we can manually verify the number of SQL queries issued. - # DISABLED: the number of queries count testing mechanism is way too brittle. - #dbg = settings.DEBUG - #settings.DEBUG = True - from django.db import connection - # Each city transformed to the SRID of their state plane coordinate system. transformed = (('Kecksburg', 2272, 'POINT(1490553.98959621 314792.131023984)'), ('Roswell', 2257, 'POINT(481902.189077221 868477.766629735)'), @@ -63,40 +57,111 @@ class RelatedGeoModelTest(unittest.TestCase): for name, srid, wkt in transformed: # Doing this implicitly sets `select_related` select the location. + # TODO: Fix why this breaks on Oracle. qs = list(City.objects.filter(name=name).transform(srid, field_name='location__point')) check_pnt(GEOSGeometry(wkt, srid), qs[0].location.point) - #settings.DEBUG= dbg - - # Verifying the number of issued SQL queries. - #self.assertEqual(nqueries, len(connection.queries)) @no_mysql def test04_related_aggregate(self): "Testing the `extent` and `unionagg` GeoQuerySet aggregates on related geographic models." - if postgis: - # One for all locations, one that excludes Roswell. - all_extent = (-104.528060913086, 33.0583305358887,-79.4607315063477, 40.1847610473633) - txpa_extent = (-97.51611328125, 33.0583305358887,-79.4607315063477, 40.1847610473633) - e1 = City.objects.extent(field_name='location__point') - e2 = City.objects.exclude(name='Roswell').extent(field_name='location__point') - for ref, e in [(all_extent, e1), (txpa_extent, e2)]: - for ref_val, e_val in zip(ref, e): self.assertAlmostEqual(ref_val, e_val) - # The second union is for a query that has something in the WHERE clause. - ref_u1 = GEOSGeometry('MULTIPOINT(-104.528056 33.387222,-97.516111 33.058333,-79.460734 40.18476)', 4326) - ref_u2 = GEOSGeometry('MULTIPOINT(-97.516111 33.058333,-79.460734 40.18476)', 4326) + # This combines the Extent and Union aggregates into one query + aggs = City.objects.aggregate(Extent('location__point'), Union('location__point')) + + # One for all locations, one that excludes Roswell. + all_extent = (-104.528060913086, 33.0583305358887,-79.4607315063477, 40.1847610473633) + txpa_extent = (-97.51611328125, 33.0583305358887,-79.4607315063477, 40.1847610473633) + e1 = City.objects.extent(field_name='location__point') + e2 = City.objects.exclude(name='Roswell').extent(field_name='location__point') + e3 = aggs['location__point__extent'] + + # The tolerance value is to four decimal places because of differences + # between the Oracle and PostGIS spatial backends on the extent calculation. + tol = 4 + for ref, e in [(all_extent, e1), (txpa_extent, e2), (all_extent, e3)]: + for ref_val, e_val in zip(ref, e): self.assertAlmostEqual(ref_val, e_val, tol) + + # These are the points that are components of the aggregate geographic + # union that is returned. + p1 = Point(-104.528056, 33.387222) + p2 = Point(-97.516111, 33.058333) + p3 = Point(-79.460734, 40.18476) + + # Creating the reference union geometry depending on the spatial backend, + # as Oracle will have a different internal ordering of the component + # geometries than PostGIS. The second union aggregate is for a union + # query that includes limiting information in the WHERE clause (in other + # words a `.filter()` precedes the call to `.unionagg()`). + if SpatialBackend.oracle: + ref_u1 = MultiPoint(p3, p1, p2, srid=4326) + ref_u2 = MultiPoint(p3, p2, srid=4326) + else: + ref_u1 = MultiPoint(p1, p2, p3, srid=4326) + ref_u2 = MultiPoint(p2, p3, srid=4326) + u1 = City.objects.unionagg(field_name='location__point') u2 = City.objects.exclude(name='Roswell').unionagg(field_name='location__point') + u3 = aggs['location__point__union'] + self.assertEqual(ref_u1, u1) self.assertEqual(ref_u2, u2) + self.assertEqual(ref_u1, u3) def test05_select_related_fk_to_subclass(self): "Testing that calling select_related on a query over a model with an FK to a model subclass works" # Regression test for #9752. l = list(DirectoryEntry.objects.all().select_related()) - # TODO: Related tests for KML, GML, and distance lookups. + def test6_f_expressions(self): + "Testing F() expressions on GeometryFields." + # Constructing a dummy parcel border and getting the City instance for + # assigning the FK. + b1 = GEOSGeometry('POLYGON((-97.501205 33.052520,-97.501205 33.052576,-97.501150 33.052576,-97.501150 33.052520,-97.501205 33.052520))', srid=4326) + pcity = City.objects.get(name='Aurora') + + # First parcel has incorrect center point that is equal to the City; + # it also has a second border that is different from the first as a + # 100ft buffer around the City. + c1 = pcity.location.point + c2 = c1.transform(2276, clone=True) + b2 = c2.buffer(100) + p1 = Parcel.objects.create(name='P1', city=pcity, center1=c1, center2=c2, border1=b1, border2=b2) + + # Now creating a second Parcel where the borders are the same, just + # in different coordinate systems. The center points are also the + # the same (but in different coordinate systems), and this time they + # actually correspond to the centroid of the border. + c1 = b1.centroid + c2 = c1.transform(2276, clone=True) + p2 = Parcel.objects.create(name='P2', city=pcity, center1=c1, center2=c2, border1=b1, border2=b1) + + # Should return the second Parcel, which has the center within the + # border. + qs = Parcel.objects.filter(center1__within=F('border1')) + self.assertEqual(1, len(qs)) + self.assertEqual('P2', qs[0].name) + if not SpatialBackend.mysql: + # This time center2 is in a different coordinate system and needs + # to be wrapped in transformation SQL. + qs = Parcel.objects.filter(center2__within=F('border1')) + self.assertEqual(1, len(qs)) + self.assertEqual('P2', qs[0].name) + + # Should return the first Parcel, which has the center point equal + # to the point in the City ForeignKey. + qs = Parcel.objects.filter(center1=F('city__location__point')) + self.assertEqual(1, len(qs)) + self.assertEqual('P1', qs[0].name) + + if not SpatialBackend.mysql: + # This time the city column should be wrapped in transformation SQL. + qs = Parcel.objects.filter(border2__contains=F('city__location__point')) + self.assertEqual(1, len(qs)) + self.assertEqual('P1', qs[0].name) + + # TODO: Related tests for KML, GML, and distance lookups. + def suite(): s = unittest.TestSuite() s.addTest(unittest.makeSuite(RelatedGeoModelTest))