Simplified handling of GIS lookup params.
This commit is contained in:
parent
9415fcfef6
commit
3b56f2191d
|
@ -42,17 +42,11 @@ class PostGISOperator(SpatialOperator):
|
||||||
return super().as_sql(connection, lookup, template_params, *args)
|
return super().as_sql(connection, lookup, template_params, *args)
|
||||||
|
|
||||||
def check_raster(self, lookup, template_params):
|
def check_raster(self, lookup, template_params):
|
||||||
# Get rhs value.
|
spheroid = lookup.rhs_params and lookup.rhs_params[-1] == 'spheroid'
|
||||||
if isinstance(lookup.rhs, (tuple, list)):
|
|
||||||
rhs_val = lookup.rhs[0]
|
|
||||||
spheroid = lookup.rhs[-1] == 'spheroid'
|
|
||||||
else:
|
|
||||||
rhs_val = lookup.rhs
|
|
||||||
spheroid = False
|
|
||||||
|
|
||||||
# Check which input is a raster.
|
# Check which input is a raster.
|
||||||
lhs_is_raster = lookup.lhs.field.geom_type == 'RASTER'
|
lhs_is_raster = lookup.lhs.field.geom_type == 'RASTER'
|
||||||
rhs_is_raster = isinstance(rhs_val, GDALRaster)
|
rhs_is_raster = isinstance(lookup.rhs, GDALRaster)
|
||||||
|
|
||||||
# Look for band indices and inject them if provided.
|
# Look for band indices and inject them if provided.
|
||||||
if lookup.band_lhs is not None and lhs_is_raster:
|
if lookup.band_lhs is not None and lhs_is_raster:
|
||||||
|
@ -90,7 +84,7 @@ class PostGISDistanceOperator(PostGISOperator):
|
||||||
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
|
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
|
||||||
template_params = self.check_raster(lookup, template_params)
|
template_params = self.check_raster(lookup, template_params)
|
||||||
sql_template = self.sql_template
|
sql_template = self.sql_template
|
||||||
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
|
if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid':
|
||||||
template_params.update({
|
template_params.update({
|
||||||
'op': self.op,
|
'op': self.op,
|
||||||
'func': connection.ops.spatial_function_name('DistanceSpheroid'),
|
'func': connection.ops.spatial_function_name('DistanceSpheroid'),
|
||||||
|
|
|
@ -6,7 +6,6 @@ from django.contrib.gis.db.models.proxy import SpatialProxy
|
||||||
from django.contrib.gis.gdal.error import GDALException
|
from django.contrib.gis.gdal.error import GDALException
|
||||||
from django.contrib.gis.geometry.backend import Geometry, GeometryException
|
from django.contrib.gis.geometry.backend import Geometry, GeometryException
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.db.models.expressions import Expression
|
|
||||||
from django.db.models.fields import Field
|
from django.db.models.fields import Field
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
|
@ -181,24 +180,7 @@ class BaseSpatialField(Field):
|
||||||
raise ValueError("Couldn't create spatial object from lookup value '%s'." % value)
|
raise ValueError("Couldn't create spatial object from lookup value '%s'." % value)
|
||||||
|
|
||||||
def get_prep_value(self, value):
|
def get_prep_value(self, value):
|
||||||
"""
|
obj = super().get_prep_value(value)
|
||||||
Spatial lookup values are either a parameter that is (or may be
|
|
||||||
converted to) a geometry or raster, or a sequence of lookup values
|
|
||||||
that begins with a geometry or raster. Set up the geometry or raster
|
|
||||||
value properly and preserves any other lookup parameters.
|
|
||||||
"""
|
|
||||||
value = super().get_prep_value(value)
|
|
||||||
|
|
||||||
# For IsValid lookups, boolean values are allowed.
|
|
||||||
if isinstance(value, (Expression, bool)):
|
|
||||||
return value
|
|
||||||
elif isinstance(value, (tuple, list)):
|
|
||||||
obj = value[0]
|
|
||||||
seq_value = True
|
|
||||||
else:
|
|
||||||
obj = value
|
|
||||||
seq_value = False
|
|
||||||
|
|
||||||
# When the input is not a geometry or raster, attempt to construct one
|
# When the input is not a geometry or raster, attempt to construct one
|
||||||
# from the given string input.
|
# from the given string input.
|
||||||
if isinstance(obj, Geometry):
|
if isinstance(obj, Geometry):
|
||||||
|
@ -221,12 +203,6 @@ class BaseSpatialField(Field):
|
||||||
|
|
||||||
# Assigning the SRID value.
|
# Assigning the SRID value.
|
||||||
obj.srid = self.get_srid(obj)
|
obj.srid = self.get_srid(obj)
|
||||||
|
|
||||||
if seq_value:
|
|
||||||
lookup_val = [obj]
|
|
||||||
lookup_val.extend(value[1:])
|
|
||||||
return tuple(lookup_val)
|
|
||||||
else:
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,21 @@ class GISLookup(Lookup):
|
||||||
band_rhs = None
|
band_rhs = None
|
||||||
band_lhs = None
|
band_lhs = None
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, lhs, rhs):
|
||||||
super().__init__(*args, **kwargs)
|
rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
|
||||||
|
super().__init__(lhs, rhs)
|
||||||
self.template_params = {}
|
self.template_params = {}
|
||||||
|
self.process_rhs_params()
|
||||||
|
|
||||||
|
def process_rhs_params(self):
|
||||||
|
if self.rhs_params:
|
||||||
|
# Check if a band index was passed in the query argument.
|
||||||
|
if len(self.rhs_params) == (2 if self.lookup_name == 'relate' else 1):
|
||||||
|
self.process_band_indices()
|
||||||
|
elif len(self.rhs_params) > 1:
|
||||||
|
raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
|
||||||
|
elif isinstance(self.lhs, RasterBandTransform):
|
||||||
|
self.process_band_indices(only_lhs=True)
|
||||||
|
|
||||||
def process_band_indices(self, only_lhs=False):
|
def process_band_indices(self, only_lhs=False):
|
||||||
"""
|
"""
|
||||||
|
@ -39,20 +51,11 @@ class GISLookup(Lookup):
|
||||||
else:
|
else:
|
||||||
self.band_lhs = 1
|
self.band_lhs = 1
|
||||||
|
|
||||||
self.band_rhs = self.rhs[1]
|
self.band_rhs, *self.rhs_params = self.rhs_params
|
||||||
if len(self.rhs) == 1:
|
|
||||||
self.rhs = self.rhs[0]
|
|
||||||
else:
|
|
||||||
self.rhs = (self.rhs[0], ) + self.rhs[2:]
|
|
||||||
|
|
||||||
def get_db_prep_lookup(self, value, connection):
|
def get_db_prep_lookup(self, value, connection):
|
||||||
# get_db_prep_lookup is called by process_rhs from super class
|
# get_db_prep_lookup is called by process_rhs from super class
|
||||||
if isinstance(value, (tuple, list)):
|
return ('%s', [connection.ops.Adapter(value)] + (self.rhs_params or []))
|
||||||
# First param is assumed to be the geometric object
|
|
||||||
params = [connection.ops.Adapter(value[0])] + list(value)[1:]
|
|
||||||
else:
|
|
||||||
params = [connection.ops.Adapter(value)]
|
|
||||||
return ('%s', params)
|
|
||||||
|
|
||||||
def process_rhs(self, compiler, connection):
|
def process_rhs(self, compiler, connection):
|
||||||
if isinstance(self.rhs, Query):
|
if isinstance(self.rhs, Query):
|
||||||
|
@ -72,16 +75,6 @@ class GISLookup(Lookup):
|
||||||
return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, []
|
return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, []
|
||||||
elif isinstance(self.rhs, Expression):
|
elif isinstance(self.rhs, Expression):
|
||||||
raise ValueError('Complex expressions not supported for spatial fields.')
|
raise ValueError('Complex expressions not supported for spatial fields.')
|
||||||
elif isinstance(self.rhs, (list, tuple)):
|
|
||||||
geom = self.rhs[0]
|
|
||||||
# Check if a band index was passed in the query argument.
|
|
||||||
if ((len(self.rhs) == 2 and not self.lookup_name == 'relate') or
|
|
||||||
(len(self.rhs) == 3 and self.lookup_name == 'relate')):
|
|
||||||
self.process_band_indices()
|
|
||||||
elif len(self.rhs) > 2:
|
|
||||||
raise ValueError('Tuple too long for lookup %s.' % self.lookup_name)
|
|
||||||
elif isinstance(self.lhs, RasterBandTransform):
|
|
||||||
self.process_band_indices(only_lhs=True)
|
|
||||||
|
|
||||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||||
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler)
|
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler)
|
||||||
|
@ -275,14 +268,14 @@ class RelateLookup(GISLookup):
|
||||||
pattern_regex = re.compile(r'^[012TF\*]{9}$')
|
pattern_regex = re.compile(r'^[012TF\*]{9}$')
|
||||||
|
|
||||||
def get_db_prep_lookup(self, value, connection):
|
def get_db_prep_lookup(self, value, connection):
|
||||||
if len(value) != 2:
|
if len(self.rhs_params) != 1:
|
||||||
raise ValueError('relate must be passed a two-tuple')
|
raise ValueError('relate must be passed a two-tuple')
|
||||||
# Check the pattern argument
|
# Check the pattern argument
|
||||||
backend_op = connection.ops.gis_operators[self.lookup_name]
|
backend_op = connection.ops.gis_operators[self.lookup_name]
|
||||||
if hasattr(backend_op, 'check_relate_argument'):
|
if hasattr(backend_op, 'check_relate_argument'):
|
||||||
backend_op.check_relate_argument(value[1])
|
backend_op.check_relate_argument(self.rhs_params[0])
|
||||||
else:
|
else:
|
||||||
pattern = value[1]
|
pattern = self.rhs_params[0]
|
||||||
if not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
|
if not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
|
||||||
raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
|
raise ValueError('Invalid intersection matrix pattern "%s".' % pattern)
|
||||||
return super().get_db_prep_lookup(value, connection)
|
return super().get_db_prep_lookup(value, connection)
|
||||||
|
@ -302,20 +295,20 @@ class DistanceLookupBase(GISLookup):
|
||||||
distance = True
|
distance = True
|
||||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
|
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
|
||||||
|
|
||||||
def process_rhs(self, compiler, connection):
|
def process_rhs_params(self):
|
||||||
if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 4:
|
if not 1 <= len(self.rhs_params) <= 3:
|
||||||
raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name)
|
raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name)
|
||||||
elif len(self.rhs) == 4 and not self.rhs[3] == 'spheroid':
|
elif len(self.rhs_params) == 3 and self.rhs_params[2] != 'spheroid':
|
||||||
raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.")
|
raise ValueError("For 4-element tuples the last argument must be the 'spheroid' directive.")
|
||||||
|
|
||||||
# Check if the second parameter is a band index.
|
# Check if the second parameter is a band index.
|
||||||
if len(self.rhs) > 2 and not self.rhs[2] == 'spheroid':
|
if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
|
||||||
self.process_band_indices()
|
self.process_band_indices()
|
||||||
|
|
||||||
params = [connection.ops.Adapter(self.rhs[0])]
|
def process_rhs(self, compiler, connection):
|
||||||
|
params = [connection.ops.Adapter(self.rhs)]
|
||||||
# Getting the distance parameter in the units of the field.
|
# Getting the distance parameter in the units of the field.
|
||||||
dist_param = self.rhs[1]
|
dist_param = self.rhs_params[0]
|
||||||
if hasattr(dist_param, 'resolve_expression'):
|
if hasattr(dist_param, 'resolve_expression'):
|
||||||
dist_param = dist_param.resolve_expression(compiler.query)
|
dist_param = dist_param.resolve_expression(compiler.query)
|
||||||
sql, expr_params = compiler.compile(dist_param)
|
sql, expr_params = compiler.compile(dist_param)
|
||||||
|
@ -323,7 +316,7 @@ class DistanceLookupBase(GISLookup):
|
||||||
params.extend(expr_params)
|
params.extend(expr_params)
|
||||||
else:
|
else:
|
||||||
params += connection.ops.get_distance(
|
params += connection.ops.get_distance(
|
||||||
self.lhs.output_field, (dist_param,) + self.rhs[2:],
|
self.lhs.output_field, self.rhs_params,
|
||||||
self.lookup_name,
|
self.lookup_name,
|
||||||
)
|
)
|
||||||
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
|
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
|
||||||
|
|
|
@ -271,10 +271,9 @@ class RasterFieldTest(TransactionTestCase):
|
||||||
|
|
||||||
def test_lookup_input_tuple_too_long(self):
|
def test_lookup_input_tuple_too_long(self):
|
||||||
rast = GDALRaster(json.loads(JSON_RASTER))
|
rast = GDALRaster(json.loads(JSON_RASTER))
|
||||||
qs = RasterModel.objects.filter(rast__bbcontains=(rast, 1, 2))
|
|
||||||
msg = 'Tuple too long for lookup bbcontains.'
|
msg = 'Tuple too long for lookup bbcontains.'
|
||||||
with self.assertRaisesMessage(ValueError, msg):
|
with self.assertRaisesMessage(ValueError, msg):
|
||||||
qs.count()
|
RasterModel.objects.filter(rast__bbcontains=(rast, 1, 2))
|
||||||
|
|
||||||
def test_lookup_input_band_not_allowed(self):
|
def test_lookup_input_band_not_allowed(self):
|
||||||
rast = GDALRaster(json.loads(JSON_RASTER))
|
rast = GDALRaster(json.loads(JSON_RASTER))
|
||||||
|
|
Loading…
Reference in New Issue