From 3b56f2191df0a437740182d49efe3be16c4d0d58 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Thu, 20 Jul 2017 19:08:55 +0500 Subject: [PATCH] Simplified handling of GIS lookup params. --- .../gis/db/backends/postgis/operations.py | 12 +--- django/contrib/gis/db/models/fields.py | 28 +-------- django/contrib/gis/db/models/lookups.py | 61 ++++++++----------- tests/gis_tests/rasterapp/test_rasterfield.py | 3 +- 4 files changed, 33 insertions(+), 71 deletions(-) diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index da673312db..2d27be14ca 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -42,17 +42,11 @@ class PostGISOperator(SpatialOperator): return super().as_sql(connection, lookup, template_params, *args) def check_raster(self, lookup, template_params): - # Get rhs value. - if isinstance(lookup.rhs, (tuple, list)): - rhs_val = lookup.rhs[0] - spheroid = lookup.rhs[-1] == 'spheroid' - else: - rhs_val = lookup.rhs - spheroid = False + spheroid = lookup.rhs_params and lookup.rhs_params[-1] == 'spheroid' # Check which input is a raster. lhs_is_raster = lookup.lhs.field.geom_type == 'RASTER' - rhs_is_raster = isinstance(rhs_val, GDALRaster) + rhs_is_raster = isinstance(lookup.rhs, GDALRaster) # Look for band indices and inject them if provided. if lookup.band_lhs is not None and lhs_is_raster: @@ -90,7 +84,7 @@ class PostGISDistanceOperator(PostGISOperator): if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection): template_params = self.check_raster(lookup, template_params) sql_template = self.sql_template - if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid': + if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid': template_params.update({ 'op': self.op, 'func': connection.ops.spatial_function_name('DistanceSpheroid'), diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index a639f3c176..9f8c5a1e16 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -6,7 +6,6 @@ from django.contrib.gis.db.models.proxy import SpatialProxy from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.geometry.backend import Geometry, GeometryException from django.core.exceptions import ImproperlyConfigured -from django.db.models.expressions import Expression from django.db.models.fields import Field from django.utils.translation import gettext_lazy as _ @@ -181,24 +180,7 @@ class BaseSpatialField(Field): raise ValueError("Couldn't create spatial object from lookup value '%s'." % value) def get_prep_value(self, value): - """ - Spatial lookup values are either a parameter that is (or may be - converted to) a geometry or raster, or a sequence of lookup values - that begins with a geometry or raster. Set up the geometry or raster - value properly and preserves any other lookup parameters. - """ - value = super().get_prep_value(value) - - # For IsValid lookups, boolean values are allowed. - if isinstance(value, (Expression, bool)): - return value - elif isinstance(value, (tuple, list)): - obj = value[0] - seq_value = True - else: - obj = value - seq_value = False - + obj = super().get_prep_value(value) # When the input is not a geometry or raster, attempt to construct one # from the given string input. if isinstance(obj, Geometry): @@ -221,13 +203,7 @@ class BaseSpatialField(Field): # Assigning the SRID value. obj.srid = self.get_srid(obj) - - if seq_value: - lookup_val = [obj] - lookup_val.extend(value[1:]) - return tuple(lookup_val) - else: - return obj + return obj class GeometryField(GeoSelectFormatMixin, BaseSpatialField): diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index 42089f64d2..12caf80359 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -18,9 +18,21 @@ class GISLookup(Lookup): band_rhs = None band_lhs = None - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, lhs, rhs): + rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs] + super().__init__(lhs, rhs) self.template_params = {} + self.process_rhs_params() + + def process_rhs_params(self): + if self.rhs_params: + # Check if a band index was passed in the query argument. + if len(self.rhs_params) == (2 if self.lookup_name == 'relate' else 1): + self.process_band_indices() + elif len(self.rhs_params) > 1: + raise ValueError('Tuple too long for lookup %s.' % self.lookup_name) + elif isinstance(self.lhs, RasterBandTransform): + self.process_band_indices(only_lhs=True) def process_band_indices(self, only_lhs=False): """ @@ -39,20 +51,11 @@ class GISLookup(Lookup): else: self.band_lhs = 1 - self.band_rhs = self.rhs[1] - if len(self.rhs) == 1: - self.rhs = self.rhs[0] - else: - self.rhs = (self.rhs[0], ) + self.rhs[2:] + self.band_rhs, *self.rhs_params = self.rhs_params def get_db_prep_lookup(self, value, connection): # get_db_prep_lookup is called by process_rhs from super class - if isinstance(value, (tuple, list)): - # First param is assumed to be the geometric object - params = [connection.ops.Adapter(value[0])] + list(value)[1:] - else: - params = [connection.ops.Adapter(value)] - return ('%s', params) + return ('%s', [connection.ops.Adapter(value)] + (self.rhs_params or [])) def process_rhs(self, compiler, connection): if isinstance(self.rhs, Query): @@ -72,16 +75,6 @@ class GISLookup(Lookup): return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, [] elif isinstance(self.rhs, Expression): raise ValueError('Complex expressions not supported for spatial fields.') - elif isinstance(self.rhs, (list, tuple)): - geom = self.rhs[0] - # Check if a band index was passed in the query argument. - if ((len(self.rhs) == 2 and not self.lookup_name == 'relate') or - (len(self.rhs) == 3 and self.lookup_name == 'relate')): - self.process_band_indices() - elif len(self.rhs) > 2: - raise ValueError('Tuple too long for lookup %s.' % self.lookup_name) - elif isinstance(self.lhs, RasterBandTransform): - self.process_band_indices(only_lhs=True) rhs, rhs_params = super().process_rhs(compiler, connection) rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) @@ -275,14 +268,14 @@ class RelateLookup(GISLookup): pattern_regex = re.compile(r'^[012TF\*]{9}$') def get_db_prep_lookup(self, value, connection): - if len(value) != 2: + if len(self.rhs_params) != 1: raise ValueError('relate must be passed a two-tuple') # Check the pattern argument backend_op = connection.ops.gis_operators[self.lookup_name] if hasattr(backend_op, 'check_relate_argument'): - backend_op.check_relate_argument(value[1]) + backend_op.check_relate_argument(self.rhs_params[0]) else: - pattern = value[1] + pattern = self.rhs_params[0] if not isinstance(pattern, str) or not self.pattern_regex.match(pattern): raise ValueError('Invalid intersection matrix pattern "%s".' % pattern) return super().get_db_prep_lookup(value, connection) @@ -302,20 +295,20 @@ class DistanceLookupBase(GISLookup): distance = True sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s' - def process_rhs(self, compiler, connection): - if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 4: + def process_rhs_params(self): + if not 1 <= len(self.rhs_params) <= 3: raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name) - elif len(self.rhs) == 4 and not self.rhs[3] == 'spheroid': + elif len(self.rhs_params) == 3 and self.rhs_params[2] != 'spheroid': raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.") # Check if the second parameter is a band index. - if len(self.rhs) > 2 and not self.rhs[2] == 'spheroid': + if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid': self.process_band_indices() - params = [connection.ops.Adapter(self.rhs[0])] - + def process_rhs(self, compiler, connection): + params = [connection.ops.Adapter(self.rhs)] # Getting the distance parameter in the units of the field. - dist_param = self.rhs[1] + dist_param = self.rhs_params[0] if hasattr(dist_param, 'resolve_expression'): dist_param = dist_param.resolve_expression(compiler.query) sql, expr_params = compiler.compile(dist_param) @@ -323,7 +316,7 @@ class DistanceLookupBase(GISLookup): params.extend(expr_params) else: params += connection.ops.get_distance( - self.lhs.output_field, (dist_param,) + self.rhs[2:], + self.lhs.output_field, self.rhs_params, self.lookup_name, ) rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler) diff --git a/tests/gis_tests/rasterapp/test_rasterfield.py b/tests/gis_tests/rasterapp/test_rasterfield.py index 4fba577f5e..710305fe62 100644 --- a/tests/gis_tests/rasterapp/test_rasterfield.py +++ b/tests/gis_tests/rasterapp/test_rasterfield.py @@ -271,10 +271,9 @@ class RasterFieldTest(TransactionTestCase): def test_lookup_input_tuple_too_long(self): rast = GDALRaster(json.loads(JSON_RASTER)) - qs = RasterModel.objects.filter(rast__bbcontains=(rast, 1, 2)) msg = 'Tuple too long for lookup bbcontains.' with self.assertRaisesMessage(ValueError, msg): - qs.count() + RasterModel.objects.filter(rast__bbcontains=(rast, 1, 2)) def test_lookup_input_band_not_allowed(self): rast = GDALRaster(json.loads(JSON_RASTER))