Simplified handling of GIS lookup params.

This commit is contained in:
Sergey Fedoseev 2017-07-20 19:08:55 +05:00 committed by Tim Graham
parent 9415fcfef6
commit 3b56f2191d
4 changed files with 33 additions and 71 deletions

View File

@ -42,17 +42,11 @@ class PostGISOperator(SpatialOperator):
return super().as_sql(connection, lookup, template_params, *args) return super().as_sql(connection, lookup, template_params, *args)
def check_raster(self, lookup, template_params): def check_raster(self, lookup, template_params):
# Get rhs value. spheroid = lookup.rhs_params and lookup.rhs_params[-1] == 'spheroid'
if isinstance(lookup.rhs, (tuple, list)):
rhs_val = lookup.rhs[0]
spheroid = lookup.rhs[-1] == 'spheroid'
else:
rhs_val = lookup.rhs
spheroid = False
# Check which input is a raster. # Check which input is a raster.
lhs_is_raster = lookup.lhs.field.geom_type == 'RASTER' lhs_is_raster = lookup.lhs.field.geom_type == 'RASTER'
rhs_is_raster = isinstance(rhs_val, GDALRaster) rhs_is_raster = isinstance(lookup.rhs, GDALRaster)
# Look for band indices and inject them if provided. # Look for band indices and inject them if provided.
if lookup.band_lhs is not None and lhs_is_raster: if lookup.band_lhs is not None and lhs_is_raster:
@ -90,7 +84,7 @@ class PostGISDistanceOperator(PostGISOperator):
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection): if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
template_params = self.check_raster(lookup, template_params) template_params = self.check_raster(lookup, template_params)
sql_template = self.sql_template sql_template = self.sql_template
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid': if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid':
template_params.update({ template_params.update({
'op': self.op, 'op': self.op,
'func': connection.ops.spatial_function_name('DistanceSpheroid'), 'func': connection.ops.spatial_function_name('DistanceSpheroid'),

View File

@ -6,7 +6,6 @@ from django.contrib.gis.db.models.proxy import SpatialProxy
from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.geometry.backend import Geometry, GeometryException from django.contrib.gis.geometry.backend import Geometry, GeometryException
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db.models.expressions import Expression
from django.db.models.fields import Field from django.db.models.fields import Field
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -181,24 +180,7 @@ class BaseSpatialField(Field):
raise ValueError("Couldn't create spatial object from lookup value '%s'." % value) raise ValueError("Couldn't create spatial object from lookup value '%s'." % value)
def get_prep_value(self, value): def get_prep_value(self, value):
""" obj = super().get_prep_value(value)
Spatial lookup values are either a parameter that is (or may be
converted to) a geometry or raster, or a sequence of lookup values
that begins with a geometry or raster. Set up the geometry or raster
value properly and preserves any other lookup parameters.
"""
value = super().get_prep_value(value)
# For IsValid lookups, boolean values are allowed.
if isinstance(value, (Expression, bool)):
return value
elif isinstance(value, (tuple, list)):
obj = value[0]
seq_value = True
else:
obj = value
seq_value = False
# When the input is not a geometry or raster, attempt to construct one # When the input is not a geometry or raster, attempt to construct one
# from the given string input. # from the given string input.
if isinstance(obj, Geometry): if isinstance(obj, Geometry):
@ -221,13 +203,7 @@ class BaseSpatialField(Field):
# Assigning the SRID value. # Assigning the SRID value.
obj.srid = self.get_srid(obj) obj.srid = self.get_srid(obj)
return obj
if seq_value:
lookup_val = [obj]
lookup_val.extend(value[1:])
return tuple(lookup_val)
else:
return obj
class GeometryField(GeoSelectFormatMixin, BaseSpatialField): class GeometryField(GeoSelectFormatMixin, BaseSpatialField):

View File

@ -18,9 +18,21 @@ class GISLookup(Lookup):
band_rhs = None band_rhs = None
band_lhs = None band_lhs = None
def __init__(self, *args, **kwargs): def __init__(self, lhs, rhs):
super().__init__(*args, **kwargs) rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
super().__init__(lhs, rhs)
self.template_params = {} self.template_params = {}
self.process_rhs_params()
def process_rhs_params(self):
if self.rhs_params:
# Check if a band index was passed in the query argument.
if len(self.rhs_params) == (2 if self.lookup_name == 'relate' else 1):
self.process_band_indices()
elif len(self.rhs_params) > 1:
raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
elif isinstance(self.lhs, RasterBandTransform):
self.process_band_indices(only_lhs=True)
def process_band_indices(self, only_lhs=False): def process_band_indices(self, only_lhs=False):
""" """
@ -39,20 +51,11 @@ class GISLookup(Lookup):
else: else:
self.band_lhs = 1 self.band_lhs = 1
self.band_rhs = self.rhs[1] self.band_rhs, *self.rhs_params = self.rhs_params
if len(self.rhs) == 1:
self.rhs = self.rhs[0]
else:
self.rhs = (self.rhs[0], ) + self.rhs[2:]
def get_db_prep_lookup(self, value, connection): def get_db_prep_lookup(self, value, connection):
# get_db_prep_lookup is called by process_rhs from super class # get_db_prep_lookup is called by process_rhs from super class
if isinstance(value, (tuple, list)): return ('%s', [connection.ops.Adapter(value)] + (self.rhs_params or []))
# First param is assumed to be the geometric object
params = [connection.ops.Adapter(value[0])] + list(value)[1:]
else:
params = [connection.ops.Adapter(value)]
return ('%s', params)
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
if isinstance(self.rhs, Query): if isinstance(self.rhs, Query):
@ -72,16 +75,6 @@ class GISLookup(Lookup):
return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, [] return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, []
elif isinstance(self.rhs, Expression): elif isinstance(self.rhs, Expression):
raise ValueError('Complex expressions not supported for spatial fields.') raise ValueError('Complex expressions not supported for spatial fields.')
elif isinstance(self.rhs, (list, tuple)):
geom = self.rhs[0]
# Check if a band index was passed in the query argument.
if ((len(self.rhs) == 2 and not self.lookup_name == 'relate') or
(len(self.rhs) == 3 and self.lookup_name == 'relate')):
self.process_band_indices()
elif len(self.rhs) > 2:
raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
elif isinstance(self.lhs, RasterBandTransform):
self.process_band_indices(only_lhs=True)
rhs, rhs_params = super().process_rhs(compiler, connection) rhs, rhs_params = super().process_rhs(compiler, connection)
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler)
@ -275,14 +268,14 @@ class RelateLookup(GISLookup):
pattern_regex = re.compile(r'^[012TF\*]{9}$') pattern_regex = re.compile(r'^[012TF\*]{9}$')
def get_db_prep_lookup(self, value, connection): def get_db_prep_lookup(self, value, connection):
if len(value) != 2: if len(self.rhs_params) != 1:
raise ValueError('relate must be passed a two-tuple') raise ValueError('relate must be passed a two-tuple')
# Check the pattern argument # Check the pattern argument
backend_op = connection.ops.gis_operators[self.lookup_name] backend_op = connection.ops.gis_operators[self.lookup_name]
if hasattr(backend_op, 'check_relate_argument'): if hasattr(backend_op, 'check_relate_argument'):
backend_op.check_relate_argument(value[1]) backend_op.check_relate_argument(self.rhs_params[0])
else: else:
pattern = value[1] pattern = self.rhs_params[0]
if not isinstance(pattern, str) or not self.pattern_regex.match(pattern): if not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
raise ValueError('Invalid intersection matrix pattern "%s".' % pattern) raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
return super().get_db_prep_lookup(value, connection) return super().get_db_prep_lookup(value, connection)
@ -302,20 +295,20 @@ class DistanceLookupBase(GISLookup):
distance = True distance = True
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s' sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
def process_rhs(self, compiler, connection): def process_rhs_params(self):
if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 4: if not 1 <= len(self.rhs_params) <= 3:
raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name) raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name)
elif len(self.rhs) == 4 and not self.rhs[3] == 'spheroid': elif len(self.rhs_params) == 3 and self.rhs_params[2] != 'spheroid':
raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.") raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.")
# Check if the second parameter is a band index. # Check if the second parameter is a band index.
if len(self.rhs) > 2 and not self.rhs[2] == 'spheroid': if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
self.process_band_indices() self.process_band_indices()
params = [connection.ops.Adapter(self.rhs[0])] def process_rhs(self, compiler, connection):
params = [connection.ops.Adapter(self.rhs)]
# Getting the distance parameter in the units of the field. # Getting the distance parameter in the units of the field.
dist_param = self.rhs[1] dist_param = self.rhs_params[0]
if hasattr(dist_param, 'resolve_expression'): if hasattr(dist_param, 'resolve_expression'):
dist_param = dist_param.resolve_expression(compiler.query) dist_param = dist_param.resolve_expression(compiler.query)
sql, expr_params = compiler.compile(dist_param) sql, expr_params = compiler.compile(dist_param)
@ -323,7 +316,7 @@ class DistanceLookupBase(GISLookup):
params.extend(expr_params) params.extend(expr_params)
else: else:
params += connection.ops.get_distance( params += connection.ops.get_distance(
self.lhs.output_field, (dist_param,) + self.rhs[2:], self.lhs.output_field, self.rhs_params,
self.lookup_name, self.lookup_name,
) )
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler) rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)

View File

@ -271,10 +271,9 @@ class RasterFieldTest(TransactionTestCase):
def test_lookup_input_tuple_too_long(self): def test_lookup_input_tuple_too_long(self):
rast = GDALRaster(json.loads(JSON_RASTER)) rast = GDALRaster(json.loads(JSON_RASTER))
qs = RasterModel.objects.filter(rast__bbcontains=(rast, 1, 2))
msg = 'Tuple too long for lookup bbcontains.' msg = 'Tuple too long for lookup bbcontains.'
with self.assertRaisesMessage(ValueError, msg): with self.assertRaisesMessage(ValueError, msg):
qs.count() RasterModel.objects.filter(rast__bbcontains=(rast, 1, 2))
def test_lookup_input_band_not_allowed(self): def test_lookup_input_band_not_allowed(self):
rast = GDALRaster(json.loads(JSON_RASTER)) rast = GDALRaster(json.loads(JSON_RASTER))