253 lines
7.4 KiB
Python
253 lines
7.4 KiB
Python
import datetime
|
|
import json
|
|
|
|
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
|
|
|
|
from django.contrib.postgres import forms, lookups
|
|
from django.db import models
|
|
|
|
from .utils import AttributeSetter
|
|
|
|
__all__ = [
|
|
'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
|
|
'FloatRangeField', 'DateTimeRangeField', 'DateRangeField',
|
|
]
|
|
|
|
|
|
class RangeField(models.Field):
    """
    Base model field for range values.

    Concrete subclasses are expected to define ``base_field`` (the model field
    class used for each bound), ``range_type`` (the class used to build range
    values) and ``form_field`` (the default form field class).
    """
    empty_strings_allowed = False

    def __init__(self, *args, **kwargs):
        # Instantiating base_field here (rather than at class level) ensures
        # that its model matches the model for self.
        if hasattr(self, 'base_field'):
            bound_field_class = self.base_field
            self.base_field = bound_field_class()
        super().__init__(*args, **kwargs)

    @property
    def model(self):
        """Return the model kept in the instance dict; AttributeError if unset."""
        try:
            return self.__dict__['model']
        except KeyError:
            raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)

    @model.setter
    def model(self, model):
        """Store the model on this field and mirror it onto base_field."""
        self.__dict__['model'] = model
        self.base_field.model = model

    def get_prep_value(self, value):
        """
        Prepare ``value`` for the database.

        ``None`` and ``Range`` instances pass through unchanged; a list or
        tuple is turned into ``range_type(value[0], value[1])``; anything
        else is returned as is.
        """
        if value is None or isinstance(value, Range):
            return value
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def to_python(self, value):
        """
        Convert ``value`` to a range object.

        A string is treated as serialized JSON whose ``lower``/``upper``
        entries are converted by ``base_field.to_python``; a list or tuple is
        read as ``(lower, upper)``. Other values are returned unchanged.
        """
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        if not isinstance(value, str):
            return value
        # Assume we're deserializing.
        range_kwargs = json.loads(value)
        for bound in ('lower', 'upper'):
            if bound in range_kwargs:
                range_kwargs[bound] = self.base_field.to_python(range_kwargs[bound])
        return self.range_type(**range_kwargs)

    def set_attributes_from_name(self, name):
        """Apply the name to this field and to base_field."""
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    def value_to_string(self, obj):
        """
        Serialize the field value of ``obj`` to a JSON string.

        Returns ``None`` for a ``None`` value and ``{"empty": true}`` for an
        empty range. Otherwise the JSON object holds ``bounds`` plus
        ``lower``/``upper``, each bound serialized by
        ``base_field.value_to_string`` (or ``null`` when the bound is None).
        """
        range_value = self.value_from_object(obj)
        if range_value is None:
            return None
        if range_value.isempty:
            return json.dumps({"empty": True})

        bound_field = self.base_field
        serialized = {"bounds": range_value._bounds}
        for bound in ('lower', 'upper'):
            bound_value = getattr(range_value, bound)
            if bound_value is not None:
                # base_field.value_to_string() reads the value off an object,
                # so wrap the bound in a stand-in carrying that attribute.
                holder = AttributeSetter(bound_field.attname, bound_value)
                serialized[bound] = bound_field.value_to_string(holder)
            else:
                serialized[bound] = None
        return json.dumps(serialized)

    def formfield(self, **kwargs):
        """Build the form field, defaulting ``form_class`` to ``form_field``."""
        options = {'form_class': self.form_field, **kwargs}
        return super().formfield(**options)
|
|
|
|
|
|
class IntegerRangeField(RangeField):
    """Range of ``IntegerField`` values, stored in an ``int4range`` column."""
    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.IntegerField

    def db_type(self, connection):
        """Return the database column type for this field."""
        return 'int4range'
|
|
|
|
|
|
class BigIntegerRangeField(RangeField):
    """Range of ``BigIntegerField`` values, stored in an ``int8range`` column."""
    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.BigIntegerField

    def db_type(self, connection):
        """Return the database column type for this field."""
        return 'int8range'
|
|
|
|
|
|
class FloatRangeField(RangeField):
    """Range of ``FloatField`` values, stored in a ``numrange`` column."""
    form_field = forms.FloatRangeField
    range_type = NumericRange
    base_field = models.FloatField

    def db_type(self, connection):
        """Return the database column type for this field."""
        return 'numrange'
|
|
|
|
|
|
class DateTimeRangeField(RangeField):
    """Range of ``DateTimeField`` values, stored in a ``tstzrange`` column."""
    form_field = forms.DateTimeRangeField
    range_type = DateTimeTZRange
    base_field = models.DateTimeField

    def db_type(self, connection):
        """Return the database column type for this field."""
        return 'tstzrange'
|
|
|
|
|
|
class DateRangeField(RangeField):
    """Range of ``DateField`` values, stored in a ``daterange`` column."""
    form_field = forms.DateRangeField
    range_type = DateRange
    base_field = models.DateField

    def db_type(self, connection):
        """Return the database column type for this field."""
        return 'daterange'
|
|
|
|
|
|
# Generic containment/overlap lookups available on every range field.
for _range_lookup in (lookups.DataContains, lookups.ContainedBy, lookups.Overlap):
    RangeField.register_lookup(_range_lookup)
del _range_lookup
|
|
|
|
|
|
class DateTimeRangeContains(models.Lookup):
    """
    Lookup for Date/DateTimeRange containment to cast the rhs to the correct
    type.

    A plain ``date``/``datetime`` rhs is wrapped in a ``Value`` and, like any
    other element-typed expression, cast to the database type of the range's
    base field. A rhs expression whose output field is of the same range
    field class as the lhs is left uncast, because casting a range to its
    element type produces invalid SQL.
    """
    lookup_name = 'contains'

    def process_rhs(self, compiler, connection):
        """Wrap a date/datetime rhs in a typed ``Value`` before compiling it."""
        # Transform rhs value for db lookup.
        if isinstance(self.rhs, datetime.date):
            # datetime.datetime subclasses datetime.date, so pick the field
            # type by testing for the more specific class.
            output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
            value = models.Value(self.rhs, output_field=output_field)
            self.rhs = value.resolve_expression(compiler.query)
        return super().process_rhs(compiler, connection)

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` for ``lhs @> rhs`` with an optional rhs cast."""
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        # Cast the rhs if needed.
        cast_sql = ''
        if (
            isinstance(self.rhs, models.Expression) and
            self.rhs._output_field_or_none and
            # Skip the cast if rhs has a matching range type: range @> range
            # is valid as is, whereas range::<element type> is not.
            not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__)
        ):
            cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
            cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
        return '%s @> %s%s' % (lhs, rhs, cast_sql), params
|
|
|
|
|
|
# Date and datetime ranges use the casting variant of the 'contains' lookup.
for _temporal_range_field in (DateRangeField, DateTimeRangeField):
    _temporal_range_field.register_lookup(DateTimeRangeContains)
del _temporal_range_field
|
|
|
|
|
|
class RangeContainedBy(models.Lookup):
    """
    ``contained_by`` lookup for scalar fields: is the value inside a range?

    The rhs is cast to the range type matching the lhs column type (see
    ``type_mapping``), and the lhs is cast where its column type differs from
    that range's element type.
    """
    lookup_name = 'contained_by'
    # Maps the lhs column type to the range type the rhs is cast to.
    type_mapping = {
        # IntegerField subclasses stored as smallint inherit this lookup.
        'smallint': 'int4range',
        'integer': 'int4range',
        'bigint': 'int8range',
        'double precision': 'numrange',
        'date': 'daterange',
        'timestamp with time zone': 'tstzrange',
    }

    def as_sql(self, qn, connection):
        """
        Return ``(sql, params)`` for ``lhs <@ rhs::<range type>``.

        Raises KeyError if the lhs column type has no entry in
        ``type_mapping``.
        """
        field = self.lhs.output_field
        column_type = field.db_type(connection)
        range_type = self.type_mapping[column_type]
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        if isinstance(field, models.FloatField):
            # numrange has numeric elements, so the float lhs must be cast.
            lhs = '%s::numeric' % lhs
        elif column_type == 'smallint':
            # int4range has integer elements, so the smallint lhs must be cast.
            lhs = '%s::integer' % lhs
        params = lhs_params + rhs_params
        return '%s <@ %s::%s' % (lhs, rhs, range_type), params

    def get_prep_lookup(self):
        """Prepare the rhs using ``RangeField.get_prep_value``."""
        return RangeField().get_prep_value(self.rhs)
|
|
|
|
|
|
# Scalar field types that can be tested for containment in a range.
for _scalar_field in (
    models.DateField,
    models.DateTimeField,
    models.IntegerField,
    models.BigIntegerField,
    models.FloatField,
):
    _scalar_field.register_lookup(RangeContainedBy)
del _scalar_field
|
|
|
|
|
|
@RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup):
    """``fully_lt`` lookup, rendered with the ``<<`` operator."""
    operator = '<<'
    lookup_name = 'fully_lt'
|
|
|
|
|
|
@RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup):
    """``fully_gt`` lookup, rendered with the ``>>`` operator."""
    operator = '>>'
    lookup_name = 'fully_gt'
|
|
|
|
|
|
@RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup):
    """``not_lt`` lookup, rendered with the ``&>`` operator."""
    operator = '&>'
    lookup_name = 'not_lt'
|
|
|
|
|
|
@RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup):
    """``not_gt`` lookup, rendered with the ``&<`` operator."""
    operator = '&<'
    lookup_name = 'not_gt'
|
|
|
|
|
|
@RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup):
    """``adjacent_to`` lookup, rendered with the ``-|-`` operator."""
    operator = '-|-'
    lookup_name = 'adjacent_to'
|
|
|
|
|
|
@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform, rendered with the ``lower`` SQL function."""
    function = 'lower'
    lookup_name = 'startswith'

    @property
    def output_field(self):
        """The result is typed as the base field of the range being transformed."""
        range_field = self.lhs.output_field
        return range_field.base_field
|
|
|
|
|
|
@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform, rendered with the ``upper`` SQL function."""
    function = 'upper'
    lookup_name = 'endswith'

    @property
    def output_field(self):
        """The result is typed as the base field of the range being transformed."""
        range_field = self.lhs.output_field
        return range_field.base_field
|
|
|
|
|
|
@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform, rendered with the ``isempty`` SQL function."""
    function = 'isempty'
    lookup_name = 'isempty'
    # The transform yields a boolean result.
    output_field = models.BooleanField()
|