# django1/django/contrib/postgres/fields/ranges.py

# 253 lines
# 7.4 KiB
# Python
# Raw Normal View History

import datetime
import json
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range
from django.contrib.postgres import forms, lookups
from django.db import models
from .utils import AttributeSetter
# Public names exported by ``from ... import *``.
__all__ = [
    'RangeField',
    'IntegerRangeField',
    'BigIntegerRangeField',
    'FloatRangeField',
    'DateTimeRangeField',
    'DateRangeField',
]
class RangeField(models.Field):
    """
    Base model field for range values.

    The methods below read ``base_field`` (the model field class used for
    each bound), ``range_type`` (the range class built from the bounds) and
    ``form_field`` (the default form field class) from ``self``; this class
    does not define them itself.
    """
    empty_strings_allowed = False

    def __init__(self, *args, **kwargs):
        """Replace the class-level ``base_field`` class with an instance."""
        # Initializing base_field here ensures that its model matches the model for self.
        if hasattr(self, 'base_field'):
            self.base_field = self.base_field()
        super().__init__(*args, **kwargs)

    @property
    def model(self):
        """Return the model stored in the instance ``__dict__``."""
        try:
            return self.__dict__['model']
        except KeyError:
            message = "'%s' object has no attribute 'model'" % self.__class__.__name__
            raise AttributeError(message)

    @model.setter
    def model(self, model):
        """Store the model on this field and mirror it onto ``base_field``."""
        self.__dict__['model'] = model
        self.base_field.model = model

    def get_prep_value(self, value):
        """
        Build a ``range_type`` from a two-item list/tuple; return ``None``,
        ``Range`` instances and anything else unchanged.
        """
        if value is None or isinstance(value, Range):
            return value
        if isinstance(value, (list, tuple)):
            lower, upper = value[0], value[1]
            return self.range_type(lower, upper)
        return value

    def to_python(self, value):
        """
        Convert a JSON string or a list/tuple into a ``range_type`` instance.

        Other values are returned unchanged.
        """
        if isinstance(value, str):
            # Assume we're deserializing
            payload = json.loads(value)
            for bound_name in ('lower', 'upper'):
                if bound_name not in payload:
                    continue
                payload[bound_name] = self.base_field.to_python(payload[bound_name])
            return self.range_type(**payload)
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def set_attributes_from_name(self, name):
        """Apply the name-derived attributes to this field and ``base_field``."""
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    def value_to_string(self, obj):
        """
        Serialize the field value of ``obj`` to a JSON string.

        Returns ``None`` for a ``None`` value and ``{"empty": true}`` for an
        empty range; otherwise the JSON object holds ``bounds``, ``lower``
        and ``upper``, with non-null bounds rendered by ``base_field``.
        """
        rng = self.value_from_object(obj)
        if rng is None:
            return None
        if rng.isempty:
            return json.dumps({"empty": True})
        element_field = self.base_field
        serialized = {"bounds": rng._bounds}
        for bound_name in ('lower', 'upper'):
            bound = getattr(rng, bound_name)
            if bound is not None:
                # base_field.value_to_string() expects an object carrying the
                # value under the field's attname, so wrap the bare bound.
                holder = AttributeSetter(element_field.attname, bound)
                bound = element_field.value_to_string(holder)
            serialized[bound_name] = bound
        return json.dumps(serialized)

    def formfield(self, **kwargs):
        """Return a form field, defaulting ``form_class`` to ``self.form_field``."""
        kwargs.setdefault('form_class', self.form_field)
        return super().formfield(**kwargs)
class IntegerRangeField(RangeField):
    """Range of ``models.IntegerField`` values; ``db_type()`` is ``int4range``."""
    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.IntegerField

    def db_type(self, connection):
        """Return the column type name, independent of ``connection``."""
        return 'int4range'
class BigIntegerRangeField(RangeField):
    """Range of ``models.BigIntegerField`` values; ``db_type()`` is ``int8range``."""
    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.BigIntegerField

    def db_type(self, connection):
        """Return the column type name, independent of ``connection``."""
        return 'int8range'
class FloatRangeField(RangeField):
    """Range of ``models.FloatField`` values; ``db_type()`` is ``numrange``."""
    form_field = forms.FloatRangeField
    range_type = NumericRange
    base_field = models.FloatField

    def db_type(self, connection):
        """Return the column type name, independent of ``connection``."""
        return 'numrange'
class DateTimeRangeField(RangeField):
    """Range of ``models.DateTimeField`` values; ``db_type()`` is ``tstzrange``."""
    form_field = forms.DateTimeRangeField
    range_type = DateTimeTZRange
    base_field = models.DateTimeField

    def db_type(self, connection):
        """Return the column type name, independent of ``connection``."""
        return 'tstzrange'
class DateRangeField(RangeField):
    """Range of ``models.DateField`` values; ``db_type()`` is ``daterange``."""
    form_field = forms.DateRangeField
    range_type = DateRange
    base_field = models.DateField

    def db_type(self, connection):
        """Return the column type name, independent of ``connection``."""
        return 'daterange'
# Attach the DataContains, ContainedBy and Overlap lookups (imported from
# django.contrib.postgres) to RangeField.
RangeField.register_lookup(lookups.DataContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)
class DateTimeRangeContains(models.Lookup):
    """
    Lookup for Date/DateTimeRange containment to cast the rhs to the correct
    type.
    """
    lookup_name = 'contains'

    def process_rhs(self, compiler, connection):
        """
        Wrap a plain ``date``/``datetime`` rhs in a typed ``Value`` expression
        before delegating to the parent implementation.

        ``self.rhs`` is replaced by the resolved expression, so a repeated
        call skips the wrapping step.
        """
        # Transform rhs value for db lookup.
        if isinstance(self.rhs, datetime.date):
            # datetime.datetime subclasses datetime.date, hence the inner test.
            if isinstance(self.rhs, datetime.datetime):
                output_field = models.DateTimeField()
            else:
                output_field = models.DateField()
            value = models.Value(self.rhs, output_field=output_field)
            self.rhs = value.resolve_expression(compiler.query)
        return super().process_rhs(compiler, connection)

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)`` for ``lhs @> rhs``.

        A typed scalar rhs expression is cast to the database type of the
        range's base field. The cast is skipped when the rhs is already a
        range of the same field type as the lhs, because casting a range to
        its element type (e.g. ``daterange`` to ``date``) is invalid SQL.
        """
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = lhs_params + rhs_params
        # Cast the rhs if needed.
        cast_sql = ''
        if isinstance(self.rhs, models.Expression):
            rhs_field = self.rhs._output_field_or_none
            if rhs_field:
                lhs_field = self.lhs.output_field
                # Skip the cast if rhs has a matching range type.
                if not isinstance(rhs_field, lhs_field.__class__):
                    cast_internal_type = lhs_field.base_field.get_internal_type()
                    cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
        return '%s @> %s%s' % (lhs, rhs, cast_sql), params
# Register the casting ``contains`` lookup on the two date-based range fields.
DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains)
class RangeContainedBy(models.Lookup):
    """
    ``contained_by`` lookup for scalar fields: ``<scalar> <@ <range>``.

    The rhs is cast to the range type that ``type_mapping`` associates with
    the lhs field's database column type.
    """
    lookup_name = 'contained_by'
    # Scalar column type (as returned by Field.db_type()) -> range type used
    # for the rhs cast.
    type_mapping = {
        'smallint': 'int4range',
        'integer': 'int4range',
        'bigint': 'int8range',
        'double precision': 'numrange',
        'date': 'daterange',
        'timestamp with time zone': 'tstzrange',
    }

    def as_sql(self, qn, connection):
        """
        Return ``(sql, params)`` for the containment test.

        Raises ``KeyError`` when the lhs column type has no entry in
        ``type_mapping``.
        """
        field = self.lhs.output_field
        db_type = field.db_type(connection)
        range_type = self.type_mapping[db_type]
        # The lhs element type must match the range's subtype, so widen the
        # columns that have no range type of their own.
        if isinstance(field, models.FloatField):
            lhs_cast = '::numeric'
        elif db_type == 'smallint':
            lhs_cast = '::integer'
        else:
            lhs_cast = ''
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        params = lhs_params + rhs_params
        return '%s%s <@ %s::%s' % (lhs, lhs_cast, rhs, range_type), params

    def get_prep_lookup(self):
        """
        Prepare the rhs through a bare ``RangeField``.

        NOTE(review): a bare ``RangeField`` defines no ``range_type`` in this
        file, so a list/tuple rhs presumably fails here; pass a ``Range``
        instance instead. Confirm against callers.
        """
        return RangeField().get_prep_value(self.rhs)
# Make the ``contained_by`` lookup available on these non-range model fields.
models.DateField.register_lookup(RangeContainedBy)
models.DateTimeField.register_lookup(RangeContainedBy)
models.IntegerField.register_lookup(RangeContainedBy)
models.BigIntegerField.register_lookup(RangeContainedBy)
models.FloatField.register_lookup(RangeContainedBy)
@RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup):
    """``fully_lt`` range lookup, configured with the ``<<`` operator."""
    operator = '<<'
    lookup_name = 'fully_lt'
@RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup):
    """``fully_gt`` range lookup, configured with the ``>>`` operator."""
    operator = '>>'
    lookup_name = 'fully_gt'
@RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup):
    """``not_lt`` range lookup, configured with the ``&>`` operator."""
    operator = '&>'
    lookup_name = 'not_lt'
@RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup):
    """``not_gt`` range lookup, configured with the ``&<`` operator."""
    operator = '&<'
    lookup_name = 'not_gt'
@RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup):
    """``adjacent_to`` range lookup, configured with the ``-|-`` operator."""
    operator = '-|-'
    lookup_name = 'adjacent_to'
@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform, configured with the ``lower`` function."""
    function = 'lower'
    lookup_name = 'startswith'

    @property
    def output_field(self):
        """Return the ``base_field`` of the range field being transformed."""
        range_field = self.lhs.output_field
        return range_field.base_field
@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform, configured with the ``upper`` function."""
    function = 'upper'
    lookup_name = 'endswith'

    @property
    def output_field(self):
        """Return the ``base_field`` of the range field being transformed."""
        range_field = self.lhs.output_field
        return range_field.base_field
@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform, configured with the ``isempty`` function; output is boolean."""
    function = 'isempty'
    lookup_name = 'isempty'
    output_field = models.BooleanField()