Fixed #18757, #14462, #21565 -- Reworked database-python type conversions

Complete rework of translating data values from database

Deprecation of SubfieldBase, removal of resolve_columns and
convert_values in favour of a more general converter based approach and
public API Field.from_db_value(). Now works seamlessly with aggregation,
.values() and raw queries.

Thanks to akaariai in particular for extensive advice and inspiration,
also to shaib, manfre and timograham for their reviews.
This commit is contained in:
Marc Tamlyn 2014-08-12 13:08:40 +01:00
parent 89559bcfb0
commit e9103402c0
35 changed files with 443 additions and 521 deletions

View File

@ -1,41 +0,0 @@
from django.contrib.gis.db.models.sql.compiler import GeoSQLCompiler as BaseGeoSQLCompiler
from django.db.backends.mysql import compiler
SQLCompiler = compiler.SQLCompiler
class GeoSQLCompiler(BaseGeoSQLCompiler, SQLCompiler):
def resolve_columns(self, row, fields=()):
"""
Integrate the cases handled both by the base GeoSQLCompiler and the
main MySQL compiler (converting 0/1 to True/False for boolean fields).
Refs #15169.
"""
row = BaseGeoSQLCompiler.resolve_columns(self, row, fields)
return SQLCompiler.resolve_columns(self, row, fields)
class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
pass
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
pass
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
pass
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
pass
class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
pass
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
pass

View File

@ -6,7 +6,7 @@ from django.contrib.gis.db.backends.base import BaseSpatialOperations
class MySQLOperations(DatabaseOperations, BaseSpatialOperations): class MySQLOperations(DatabaseOperations, BaseSpatialOperations):
compiler_module = 'django.contrib.gis.db.backends.mysql.compiler' compiler_module = 'django.contrib.gis.db.models.sql.compiler'
mysql = True mysql = True
name = 'mysql' name = 'mysql'
select = 'AsText(%s)' select = 'AsText(%s)'

View File

@ -197,6 +197,11 @@ class GeometryField(Field):
else: else:
return geom return geom
def from_db_value(self, value, connection):
if value is not None:
value = Geometry(value)
return value
def get_srid(self, geom): def get_srid(self, geom):
""" """
Returns the default SRID for the given geometry, taking into account Returns the default SRID for the given geometry, taking into account

View File

@ -1,5 +1,5 @@
from django.db import connections from django.db import connections
from django.db.models.query import QuerySet, ValuesQuerySet, ValuesListQuerySet from django.db.models.query import QuerySet
from django.contrib.gis.db.models import aggregates from django.contrib.gis.db.models import aggregates
from django.contrib.gis.db.models.fields import get_srid_info, PointField, LineStringField from django.contrib.gis.db.models.fields import get_srid_info, PointField, LineStringField
@ -18,19 +18,6 @@ class GeoQuerySet(QuerySet):
super(GeoQuerySet, self).__init__(model=model, query=query, using=using, hints=hints) super(GeoQuerySet, self).__init__(model=model, query=query, using=using, hints=hints)
self.query = query or GeoQuery(self.model) self.query = query or GeoQuery(self.model)
def values(self, *fields):
return self._clone(klass=GeoValuesQuerySet, setup=True, _fields=fields)
def values_list(self, *fields, **kwargs):
flat = kwargs.pop('flat', False)
if kwargs:
raise TypeError('Unexpected keyword arguments to values_list: %s'
% (list(kwargs),))
if flat and len(fields) > 1:
raise TypeError("'flat' is not valid when values_list is called with more than one field.")
return self._clone(klass=GeoValuesListQuerySet, setup=True, flat=flat,
_fields=fields)
### GeoQuerySet Methods ### ### GeoQuerySet Methods ###
def area(self, tolerance=0.05, **kwargs): def area(self, tolerance=0.05, **kwargs):
""" """
@ -767,16 +754,3 @@ class GeoQuerySet(QuerySet):
return self.query.get_compiler(self.db)._field_column(geo_field, parent_model._meta.db_table) return self.query.get_compiler(self.db)._field_column(geo_field, parent_model._meta.db_table)
else: else:
return self.query.get_compiler(self.db)._field_column(geo_field) return self.query.get_compiler(self.db)._field_column(geo_field)
class GeoValuesQuerySet(ValuesQuerySet):
def __init__(self, *args, **kwargs):
super(GeoValuesQuerySet, self).__init__(*args, **kwargs)
# This flag tells `resolve_columns` to run the values through
# `convert_values`. This ensures that Geometry objects instead
# of string values are returned with `values()` or `values_list()`.
self.query.geo_values = True
class GeoValuesListQuerySet(GeoValuesQuerySet, ValuesListQuerySet):
pass

View File

@ -1,12 +1,6 @@
import datetime from django.db.backends.utils import truncate_name
from django.conf import settings
from django.db.backends.utils import truncate_name, typecast_date, typecast_timestamp
from django.db.models.sql import compiler from django.db.models.sql import compiler
from django.db.models.sql.constants import MULTI
from django.utils import six from django.utils import six
from django.utils.six.moves import zip, zip_longest
from django.utils import timezone
SQLCompiler = compiler.SQLCompiler SQLCompiler = compiler.SQLCompiler
@ -153,38 +147,13 @@ class GeoSQLCompiler(compiler.SQLCompiler):
col_aliases.add(field.column) col_aliases.add(field.column)
return result, aliases return result, aliases
def resolve_columns(self, row, fields=()): def get_converters(self, fields):
""" converters = super(GeoSQLCompiler, self).get_converters(fields)
This routine is necessary so that distances and geometries returned for i, alias in enumerate(self.query.extra_select):
from extra selection SQL get resolved appropriately into Python field = self.query.extra_select_fields.get(alias)
objects. if field:
""" converters[i] = ([], [field.from_db_value], field)
values = [] return converters
aliases = list(self.query.extra_select)
# Have to set a starting row number offset that is used for
# determining the correct starting row index -- needed for
# doing pagination with Oracle.
rn_offset = 0
if self.connection.ops.oracle:
if self.query.high_mark is not None or self.query.low_mark:
rn_offset = 1
index_start = rn_offset + len(aliases)
# Converting any extra selection values (e.g., geometries and
# distance objects added by GeoQuerySet methods).
values = [self.query.convert_values(v,
self.query.extra_select_fields.get(a, None),
self.connection)
for v, a in zip(row[rn_offset:index_start], aliases)]
if self.connection.ops.oracle or getattr(self.query, 'geo_values', False):
# We resolve the rest of the columns if we're on Oracle or if
# the `geo_values` attribute is defined.
for value, field in zip_longest(row[index_start:], fields):
values.append(self.query.convert_values(value, field, self.connection))
else:
values.extend(row[index_start:])
return tuple(values)
#### Routines unique to GeoQuery #### #### Routines unique to GeoQuery ####
def get_extra_select_format(self, alias): def get_extra_select_format(self, alias):
@ -268,55 +237,8 @@ class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler): class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
""" pass
This is overridden for GeoDjango to properly cast date columns, since
`GeoQuery.resolve_columns` is used for spatial values.
See #14648, #16757.
"""
def results_iter(self):
if self.connection.ops.oracle:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
else:
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI):
for row in rows:
date = row[offset]
if self.connection.ops.oracle:
date = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
date = typecast_date(str(date))
if isinstance(date, datetime.datetime):
date = date.date()
yield date
class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler): class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
""" pass
This is overridden for GeoDjango to properly cast date columns, since
`GeoQuery.resolve_columns` is used for spatial values.
See #14648, #16757.
"""
def results_iter(self):
if self.connection.ops.oracle:
from django.db.models.fields import DateTimeField
fields = [DateTimeField()]
else:
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI):
for row in rows:
datetime = row[offset]
if self.connection.ops.oracle:
datetime = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
datetime = typecast_timestamp(str(datetime))
# Datetimes are artificially returned in UTC on databases that
# don't support time zone. Restore the zone used in the query.
if settings.USE_TZ:
datetime = datetime.replace(tzinfo=None)
datetime = timezone.make_aware(datetime, self.query.tzinfo)
yield datetime

View File

@ -1,32 +1,53 @@
""" """
This module holds simple classes used by GeoQuery.convert_values This module holds simple classes to convert geospatial values from the
to convert geospatial values from the database. database.
""" """
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Area, Distance
class BaseField(object): class BaseField(object):
empty_strings_allowed = True empty_strings_allowed = True
def get_internal_type(self):
"Overloaded method so OracleQuery.convert_values doesn't balk."
return None
class AreaField(BaseField): class AreaField(BaseField):
"Wrapper for Area values." "Wrapper for Area values."
def __init__(self, area_att): def __init__(self, area_att):
self.area_att = area_att self.area_att = area_att
def from_db_value(self, value, connection):
if value is not None:
value = Area(**{self.area_att: value})
return value
def get_internal_type(self):
return 'AreaField'
class DistanceField(BaseField): class DistanceField(BaseField):
"Wrapper for Distance values." "Wrapper for Distance values."
def __init__(self, distance_att): def __init__(self, distance_att):
self.distance_att = distance_att self.distance_att = distance_att
def from_db_value(self, value, connection):
if value is not None:
value = Distance(**{self.distance_att: value})
return value
def get_internal_type(self):
return 'DistanceField'
class GeomField(BaseField): class GeomField(BaseField):
""" """
Wrapper for Geometry values. It is a lightweight alternative to Wrapper for Geometry values. It is a lightweight alternative to
using GeometryField (which requires an SQL query upon instantiation). using GeometryField (which requires an SQL query upon instantiation).
""" """
pass def from_db_value(self, value, connection):
if value is not None:
value = Geometry(value)
return value
def get_internal_type(self):
return 'GeometryField'

View File

@ -5,9 +5,7 @@ from django.contrib.gis.db.models.constants import ALL_TERMS
from django.contrib.gis.db.models.fields import GeometryField from django.contrib.gis.db.models.fields import GeometryField
from django.contrib.gis.db.models.lookups import GISLookup from django.contrib.gis.db.models.lookups import GISLookup
from django.contrib.gis.db.models.sql import aggregates as gis_aggregates from django.contrib.gis.db.models.sql import aggregates as gis_aggregates
from django.contrib.gis.db.models.sql.conversion import AreaField, DistanceField, GeomField from django.contrib.gis.db.models.sql.conversion import GeomField
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Area, Distance
class GeoQuery(sql.Query): class GeoQuery(sql.Query):
@ -38,32 +36,6 @@ class GeoQuery(sql.Query):
obj.extra_select_fields = self.extra_select_fields.copy() obj.extra_select_fields = self.extra_select_fields.copy()
return obj return obj
def convert_values(self, value, field, connection):
"""
Using the same routines that Oracle does we can convert our
extra selection objects into Geometry and Distance objects.
TODO: Make converted objects 'lazy' for less overhead.
"""
if connection.ops.oracle:
# Running through Oracle's first.
value = super(GeoQuery, self).convert_values(value, field or GeomField(), connection)
if value is None:
# Output from spatial function is NULL (e.g., called
# function on a geometry field with NULL value).
pass
elif isinstance(field, DistanceField):
# Using the field's distance attribute, can instantiate
# `Distance` with the right context.
value = Distance(**{field.distance_att: value})
elif isinstance(field, AreaField):
value = Area(**{field.area_att: value})
elif isinstance(field, (GeomField, GeometryField)) and value:
value = Geometry(value)
elif field is not None:
return super(GeoQuery, self).convert_values(value, field, connection)
return value
def get_aggregation(self, using, force_subq=False): def get_aggregation(self, using, force_subq=False):
# Remove any aggregates marked for reduction from the subquery # Remove any aggregates marked for reduction from the subquery
# and move them to the outer AggregateQuery. # and move them to the outer AggregateQuery.

View File

@ -66,3 +66,18 @@ class MinusOneSRID(models.Model):
class Meta: class Meta:
app_label = 'geoapp' app_label = 'geoapp'
class NonConcreteField(models.IntegerField):
def db_type(self, connection):
return None
def get_attname_column(self):
attname, column = super(NonConcreteField, self).get_attname_column()
return attname, None
class NonConcreteModel(NamedModel):
non_concrete = NonConcreteField()
point = models.PointField(geography=True)

View File

@ -13,9 +13,7 @@ from django.utils import six
if HAS_GEOS: if HAS_GEOS:
from django.contrib.gis.geos import (fromstr, GEOSGeometry, from django.contrib.gis.geos import (fromstr, GEOSGeometry,
Point, LineString, LinearRing, Polygon, GeometryCollection) Point, LineString, LinearRing, Polygon, GeometryCollection)
from .models import Country, City, PennsylvaniaCity, State, Track, NonConcreteModel, Feature, MinusOneSRID
from .models import Country, City, PennsylvaniaCity, State, Track
from .models import Feature, MinusOneSRID
def postgis_bug_version(): def postgis_bug_version():
@ -754,10 +752,5 @@ class GeoQuerySetTest(TestCase):
self.assertEqual(None, qs.unionagg(field_name='point')) self.assertEqual(None, qs.unionagg(field_name='point'))
def test_non_concrete_field(self): def test_non_concrete_field(self):
pkfield = City._meta.get_field_by_name('id')[0] NonConcreteModel.objects.create(point=Point(0, 0), name='name')
orig_pkfield_col = pkfield.column list(NonConcreteModel.objects.all())
pkfield.column = None
try:
list(City.objects.all())
finally:
pkfield.column = orig_pkfield_col

View File

@ -71,3 +71,8 @@ class Article(SimpleModel):
class Book(SimpleModel): class Book(SimpleModel):
title = models.CharField(max_length=100) title = models.CharField(max_length=100)
author = models.ForeignKey(Author, related_name='books', null=True) author = models.ForeignKey(Author, related_name='books', null=True)
class Event(SimpleModel):
name = models.CharField(max_length=100)
when = models.DateTimeField()

View File

@ -4,13 +4,15 @@ from django.contrib.gis.geos import HAS_GEOS
from django.contrib.gis.tests.utils import no_oracle from django.contrib.gis.tests.utils import no_oracle
from django.db import connection from django.db import connection
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import override_settings
from django.utils import timezone
if HAS_GEOS: if HAS_GEOS:
from django.contrib.gis.db.models import Collect, Count, Extent, F, Union from django.contrib.gis.db.models import Collect, Count, Extent, F, Union
from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.geos import GEOSGeometry, Point, MultiPoint from django.contrib.gis.geos import GEOSGeometry, Point, MultiPoint
from .models import City, Location, DirectoryEntry, Parcel, Book, Author, Article from .models import City, Location, DirectoryEntry, Parcel, Book, Author, Article, Event
@skipUnlessDBFeature("gis_enabled") @skipUnlessDBFeature("gis_enabled")
@ -183,6 +185,12 @@ class RelatedGeoModelTest(TestCase):
self.assertEqual(m.point, d['point']) self.assertEqual(m.point, d['point'])
self.assertEqual(m.point, t[1]) self.assertEqual(m.point, t[1])
@override_settings(USE_TZ=True)
def test_07b_values(self):
"Testing values() and values_list() with aware datetime. See #21565."
Event.objects.create(name="foo", when=timezone.now())
list(Event.objects.values_list('when'))
def test08_defer_only(self): def test08_defer_only(self):
"Testing defer() and only() on Geographic models." "Testing defer() and only() on Geographic models."
qs = Location.objects.all() qs = Location.objects.all()

View File

@ -1190,20 +1190,13 @@ class BaseDatabaseOperations(object):
second = timezone.make_aware(second, tz) second = timezone.make_aware(second, tz)
return [first, second] return [first, second]
def convert_values(self, value, field): def get_db_converters(self, internal_type):
"""Get a list of functions needed to convert field data.
Some field types on some backends do not provide data in the correct
format, this is the hook for coverter functions.
""" """
Coerce the value returned by the database backend into a consistent type return []
that is compatible with the field type.
"""
if value is None or field is None:
return value
internal_type = field.get_internal_type()
if internal_type == 'FloatField':
return float(value)
elif (internal_type and (internal_type.endswith('IntegerField')
or internal_type == 'AutoField')):
return int(value)
return value
def check_aggregate_support(self, aggregate_func): def check_aggregate_support(self, aggregate_func):
"""Check that the backend supports the provided aggregate """Check that the backend supports the provided aggregate

View File

@ -394,6 +394,17 @@ class DatabaseOperations(BaseDatabaseOperations):
return 'POW(%s)' % ','.join(sub_expressions) return 'POW(%s)' % ','.join(sub_expressions)
return super(DatabaseOperations, self).combine_expression(connector, sub_expressions) return super(DatabaseOperations, self).combine_expression(connector, sub_expressions)
def get_db_converters(self, internal_type):
converters = super(DatabaseOperations, self).get_db_converters(internal_type)
if internal_type in ['BooleanField', 'NullBooleanField']:
converters.append(self.convert_booleanfield_value)
return converters
def convert_booleanfield_value(self, value, field):
if value in (0, 1):
value = bool(value)
return value
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql' vendor = 'mysql'

View File

@ -1,18 +1,7 @@
from django.db.models.sql import compiler from django.db.models.sql import compiler
from django.utils.six.moves import zip_longest
class SQLCompiler(compiler.SQLCompiler): class SQLCompiler(compiler.SQLCompiler):
def resolve_columns(self, row, fields=()):
values = []
index_extra_select = len(self.query.extra_select)
for value, field in zip_longest(row[index_extra_select:], fields):
if (field and field.get_internal_type() in ("BooleanField", "NullBooleanField") and
value in (0, 1)):
value = bool(value)
values.append(value)
return row[:index_extra_select] + tuple(values)
def as_subquery_condition(self, alias, columns, qn): def as_subquery_condition(self, alias, columns, qn):
qn2 = self.connection.ops.quote_name qn2 = self.connection.ops.quote_name
sql, params = self.as_sql() sql, params = self.as_sql()

View File

@ -17,6 +17,10 @@ class DatabaseValidation(BaseDatabaseValidation):
if getattr(field, 'rel', None) is None: if getattr(field, 'rel', None) is None:
field_type = field.db_type(connection) field_type = field.db_type(connection)
# Ignore any non-concrete fields
if field_type is None:
return errors
if (field_type.startswith('varchar') # Look for CharFields... if (field_type.startswith('varchar') # Look for CharFields...
and field.unique # ... that are unique and field.unique # ... that are unique
and (field.max_length is None or int(field.max_length) > 255)): and (field.max_length is None or int(field.max_length) > 255)):

View File

@ -250,49 +250,64 @@ WHEN (new.%(col_name)s IS NULL)
sql = field_name # Cast to DATE removes sub-second precision. sql = field_name # Cast to DATE removes sub-second precision.
return sql, [] return sql, []
def convert_values(self, value, field): def get_db_converters(self, internal_type):
if isinstance(value, Database.LOB): converters = super(DatabaseOperations, self).get_db_converters(internal_type)
value = value.read() if internal_type == 'TextField':
if field and field.get_internal_type() == 'TextField': converters.append(self.convert_textfield_value)
value = force_text(value) elif internal_type == 'BinaryField':
converters.append(self.convert_binaryfield_value)
elif internal_type in ['BooleanField', 'NullBooleanField']:
converters.append(self.convert_booleanfield_value)
elif internal_type == 'DecimalField':
converters.append(self.convert_decimalfield_value)
elif internal_type == 'DateField':
converters.append(self.convert_datefield_value)
elif internal_type == 'TimeField':
converters.append(self.convert_timefield_value)
converters.append(self.convert_empty_values)
return converters
def convert_empty_values(self, value, field):
# Oracle stores empty strings as null. We need to undo this in # Oracle stores empty strings as null. We need to undo this in
# order to adhere to the Django convention of using the empty # order to adhere to the Django convention of using the empty
# string instead of null, but only if the field accepts the # string instead of null, but only if the field accepts the
# empty string. # empty string.
if value is None and field and field.empty_strings_allowed: if value is None and field.empty_strings_allowed:
value = ''
if field.get_internal_type() == 'BinaryField': if field.get_internal_type() == 'BinaryField':
value = b'' value = b''
else: return value
value = ''
# Convert 1 or 0 to True or False def convert_textfield_value(self, value, field):
elif value in (1, 0) and field and field.get_internal_type() in ('BooleanField', 'NullBooleanField'): if isinstance(value, Database.LOB):
value = force_text(value.read())
return value
def convert_binaryfield_value(self, value, field):
if isinstance(value, Database.LOB):
value = force_bytes(value.read())
return value
def convert_booleanfield_value(self, value, field):
if value in (1, 0):
value = bool(value) value = bool(value)
# Force floats to the correct type return value
elif value is not None and field and field.get_internal_type() == 'FloatField':
value = float(value) def convert_decimalfield_value(self, value, field):
# Convert floats to decimals if value is not None:
elif value is not None and field and field.get_internal_type() == 'DecimalField':
value = backend_utils.typecast_decimal(field.format_number(value)) value = backend_utils.typecast_decimal(field.format_number(value))
return value
# cx_Oracle always returns datetime.datetime objects for # cx_Oracle always returns datetime.datetime objects for
# DATE and TIMESTAMP columns, but Django wants to see a # DATE and TIMESTAMP columns, but Django wants to see a
# python datetime.date, .time, or .datetime. We use the type # python datetime.date, .time, or .datetime.
# of the Field to determine which to cast to, but it's not def convert_datefield_value(self, value, field):
# always available. if isinstance(value, Database.Timestamp):
# As a workaround, we cast to date if all the time-related return value.date()
# values are 0, or to time if the date is 1/1/1900.
# This could be cleaned a bit by adding a method to the Field def convert_timefield_value(self, value, field):
# classes to normalize values from the database (the to_python if isinstance(value, Database.Timestamp):
# method is used for validation and isn't what we want here).
elif isinstance(value, Database.Timestamp):
if field and field.get_internal_type() == 'DateTimeField':
pass
elif field and field.get_internal_type() == 'DateField':
value = value.date()
elif field and field.get_internal_type() == 'TimeField' or (value.year == 1900 and value.month == value.day == 1):
value = value.time() value = value.time()
elif value.hour == value.minute == value.second == value.microsecond == 0:
value = value.date()
return value return value
def deferrable_sql(self): def deferrable_sql(self):

View File

@ -1,23 +1,7 @@
from django.db.models.sql import compiler from django.db.models.sql import compiler
from django.utils.six.moves import zip_longest
class SQLCompiler(compiler.SQLCompiler): class SQLCompiler(compiler.SQLCompiler):
def resolve_columns(self, row, fields=()):
# If this query has limit/offset information, then we expect the
# first column to be an extra "_RN" column that we need to throw
# away.
if self.query.high_mark is not None or self.query.low_mark:
rn_offset = 1
else:
rn_offset = 0
index_start = rn_offset + len(self.query.extra_select)
values = [self.query.convert_values(v, None, connection=self.connection)
for v in row[rn_offset:index_start]]
for value, field in zip_longest(row[index_start:], fields):
values.append(self.query.convert_values(value, field, connection=self.connection))
return tuple(values)
def as_sql(self, with_limits=True, with_col_aliases=False): def as_sql(self, with_limits=True, with_col_aliases=False):
""" """
Creates the SQL for this query. Returns the SQL string and list Creates the SQL for this query. Returns the SQL string and list
@ -48,7 +32,7 @@ class SQLCompiler(compiler.SQLCompiler):
high_where = '' high_where = ''
if self.query.high_mark is not None: if self.query.high_mark is not None:
high_where = 'WHERE ROWNUM <= %d' % (self.query.high_mark,) high_where = 'WHERE ROWNUM <= %d' % (self.query.high_mark,)
sql = 'SELECT * FROM (SELECT ROWNUM AS "_RN", "_SUB".* FROM (%s) "_SUB" %s) WHERE "_RN" > %d' % (sql, high_where, self.query.low_mark) sql = 'SELECT * FROM (SELECT "_SUB".*, ROWNUM AS "_RN" FROM (%s) "_SUB" %s) WHERE "_RN" > %d' % (sql, high_where, self.query.low_mark)
return sql, params return sql, params

View File

@ -263,27 +263,36 @@ class DatabaseOperations(BaseDatabaseOperations):
return six.text_type(value) return six.text_type(value)
def convert_values(self, value, field): def get_db_converters(self, internal_type):
"""SQLite returns floats when it should be returning decimals, converters = super(DatabaseOperations, self).get_db_converters(internal_type)
and gets dates and datetimes wrong. if internal_type == 'DateTimeField':
For consistency with other backends, coerce when required. converters.append(self.convert_datetimefield_value)
"""
if value is None:
return None
internal_type = field.get_internal_type()
if internal_type == 'DecimalField':
return backend_utils.typecast_decimal(field.format_number(value))
elif internal_type and internal_type.endswith('IntegerField') or internal_type == 'AutoField':
return int(value)
elif internal_type == 'DateField': elif internal_type == 'DateField':
return parse_date(value) converters.append(self.convert_datefield_value)
elif internal_type == 'DateTimeField':
return parse_datetime_with_timezone_support(value)
elif internal_type == 'TimeField': elif internal_type == 'TimeField':
return parse_time(value) converters.append(self.convert_timefield_value)
elif internal_type == 'DecimalField':
converters.append(self.convert_decimalfield_value)
return converters
# No field, or the field isn't known to be a decimal or integer def convert_decimalfield_value(self, value, field):
if value is not None:
value = backend_utils.typecast_decimal(field.format_number(value))
return value
def convert_datefield_value(self, value, field):
if value is not None and not isinstance(value, datetime.date):
value = parse_date(value)
return value
def convert_datetimefield_value(self, value, field):
if value is not None and not isinstance(value, datetime.datetime):
value = parse_datetime_with_timezone_support(value)
return value
def convert_timefield_value(self, value, field):
if value is not None and not isinstance(value, datetime.time):
value = parse_time(value)
return value return value
def bulk_insert_sql(self, fields, num_values): def bulk_insert_sql(self, fields, num_values):

View File

@ -558,6 +558,11 @@ class Field(RegisterLookupMixin):
def db_type_suffix(self, connection): def db_type_suffix(self, connection):
return connection.creation.data_types_suffix.get(self.get_internal_type()) return connection.creation.data_types_suffix.get(self.get_internal_type())
def get_db_converters(self, connection):
if hasattr(self, 'from_db_value'):
return [self.from_db_value]
return []
@property @property
def unique(self): def unique(self):
return self._unique or self.primary_key return self._unique or self.primary_key

View File

@ -7,6 +7,10 @@ to_python() and the other necessary methods and everything will work
seamlessly. seamlessly.
""" """
import warnings
from django.utils.deprecation import RemovedInDjango20Warning
class SubfieldBase(type): class SubfieldBase(type):
""" """
@ -14,6 +18,9 @@ class SubfieldBase(type):
has the descriptor protocol attached to it. has the descriptor protocol attached to it.
""" """
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
warnings.warn("SubfieldBase has been deprecated. Use Field.from_db_value instead.",
RemovedInDjango20Warning)
new_class = super(SubfieldBase, cls).__new__(cls, name, bases, attrs) new_class = super(SubfieldBase, cls).__new__(cls, name, bases, attrs)
new_class.contribute_to_class = make_contrib( new_class.contribute_to_class = make_contrib(
new_class, attrs.get('contribute_to_class') new_class, attrs.get('contribute_to_class')

View File

@ -1560,7 +1560,6 @@ class RawQuerySet(object):
compiler = connections[db].ops.compiler('SQLCompiler')( compiler = connections[db].ops.compiler('SQLCompiler')(
self.query, connections[db], db self.query, connections[db], db
) )
need_resolv_columns = hasattr(compiler, 'resolve_columns')
query = iter(self.query) query = iter(self.query)
@ -1578,11 +1577,11 @@ class RawQuerySet(object):
model_cls = deferred_class_factory(self.model, skip) model_cls = deferred_class_factory(self.model, skip)
else: else:
model_cls = self.model model_cls = self.model
if need_resolv_columns:
fields = [self.model_fields.get(c, None) for c in self.columns] fields = [self.model_fields.get(c, None) for c in self.columns]
converters = compiler.get_converters(fields)
for values in query: for values in query:
if need_resolv_columns: if converters:
values = compiler.resolve_columns(values, fields) values = compiler.apply_converters(values, converters)
# Associate fields to values # Associate fields to values
model_init_values = [values[pos] for pos in model_init_pos] model_init_values = [values[pos] for pos in model_init_pos]
instance = model_cls.from_db(db, model_init_names, model_init_values) instance = model_cls.from_db(db, model_init_names, model_init_values)

View File

@ -690,12 +690,34 @@ class SQLCompiler(object):
self.query.deferred_to_data(columns, self.query.deferred_to_columns_cb) self.query.deferred_to_data(columns, self.query.deferred_to_columns_cb)
return columns return columns
def get_converters(self, fields):
converters = {}
index_extra_select = len(self.query.extra_select)
for i, field in enumerate(fields):
if field:
backend_converters = self.connection.ops.get_db_converters(field.get_internal_type())
field_converters = field.get_db_converters(self.connection)
if backend_converters or field_converters:
converters[index_extra_select + i] = (backend_converters, field_converters, field)
return converters
def apply_converters(self, row, converters):
row = list(row)
for pos, (backend_converters, field_converters, field) in converters.items():
value = row[pos]
for converter in backend_converters:
value = converter(value, field)
for converter in field_converters:
value = converter(value, self.connection)
row[pos] = value
return tuple(row)
def results_iter(self): def results_iter(self):
""" """
Returns an iterator over the results from executing this query. Returns an iterator over the results from executing this query.
""" """
resolve_columns = hasattr(self, 'resolve_columns')
fields = None fields = None
converters = None
has_aggregate_select = bool(self.query.aggregate_select) has_aggregate_select = bool(self.query.aggregate_select)
for rows in self.execute_sql(MULTI): for rows in self.execute_sql(MULTI):
for row in rows: for row in rows:
@ -703,7 +725,6 @@ class SQLCompiler(object):
loaded_fields = self.query.get_loaded_field_names().get(self.query.model, set()) or self.query.select loaded_fields = self.query.get_loaded_field_names().get(self.query.model, set()) or self.query.select
aggregate_start = len(self.query.extra_select) + len(loaded_fields) aggregate_start = len(self.query.extra_select) + len(loaded_fields)
aggregate_end = aggregate_start + len(self.query.aggregate_select) aggregate_end = aggregate_start + len(self.query.aggregate_select)
if resolve_columns:
if fields is None: if fields is None:
# We only set this up here because # We only set this up here because
# related_select_cols isn't populated until # related_select_cols isn't populated until
@ -725,7 +746,7 @@ class SQLCompiler(object):
fields = fields + [f.field for f in self.query.related_select_cols] fields = fields + [f.field for f in self.query.related_select_cols]
# If the field was deferred, exclude it from being passed # If the field was deferred, exclude it from being passed
# into `resolve_columns` because it wasn't selected. # into `get_converters` because it wasn't selected.
only_load = self.deferred_to_columns() only_load = self.deferred_to_columns()
if only_load: if only_load:
fields = [f for f in fields if f.model._meta.db_table not in only_load or fields = [f for f in fields if f.model._meta.db_table not in only_load or
@ -735,7 +756,9 @@ class SQLCompiler(object):
fields = fields[:aggregate_start] + [ fields = fields[:aggregate_start] + [
None for x in range(0, aggregate_end - aggregate_start) None for x in range(0, aggregate_end - aggregate_start)
] + fields[aggregate_start:] ] + fields[aggregate_start:]
row = self.resolve_columns(row, fields) converters = self.get_converters(fields)
if converters:
row = self.apply_converters(row, converters)
if has_aggregate_select: if has_aggregate_select:
row = tuple(row[:aggregate_start]) + tuple( row = tuple(row[:aggregate_start]) + tuple(
@ -1092,22 +1115,13 @@ class SQLDateCompiler(SQLCompiler):
""" """
Returns an iterator over the results from executing this query. Returns an iterator over the results from executing this query.
""" """
resolve_columns = hasattr(self, 'resolve_columns')
if resolve_columns:
from django.db.models.fields import DateField from django.db.models.fields import DateField
fields = [DateField()] converters = self.get_converters([DateField()])
else:
from django.db.backends.utils import typecast_date
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.query.extra_select) offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI): for rows in self.execute_sql(MULTI):
for row in rows: for row in rows:
date = row[offset] date = self.apply_converters(row, converters)[offset]
if resolve_columns:
date = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
date = typecast_date(str(date))
if isinstance(date, datetime.datetime): if isinstance(date, datetime.datetime):
date = date.date() date = date.date()
yield date yield date
@ -1118,22 +1132,13 @@ class SQLDateTimeCompiler(SQLCompiler):
""" """
Returns an iterator over the results from executing this query. Returns an iterator over the results from executing this query.
""" """
resolve_columns = hasattr(self, 'resolve_columns')
if resolve_columns:
from django.db.models.fields import DateTimeField from django.db.models.fields import DateTimeField
fields = [DateTimeField()] converters = self.get_converters([DateTimeField()])
else:
from django.db.backends.utils import typecast_timestamp
needs_string_cast = self.connection.features.needs_datetime_string_cast
offset = len(self.query.extra_select) offset = len(self.query.extra_select)
for rows in self.execute_sql(MULTI): for rows in self.execute_sql(MULTI):
for row in rows: for row in rows:
datetime = row[offset] datetime = self.apply_converters(row, converters)[offset]
if resolve_columns:
datetime = self.resolve_columns(row, fields)[offset]
elif needs_string_cast:
datetime = typecast_timestamp(str(datetime))
# Datetimes are artificially returned in UTC on databases that # Datetimes are artificially returned in UTC on databases that
# don't support time zone. Restore the zone used in the query. # don't support time zone. Restore the zone used in the query.
if settings.USE_TZ: if settings.USE_TZ:

View File

@ -54,15 +54,6 @@ class RawQuery(object):
def clone(self, using): def clone(self, using):
return RawQuery(self.sql, using, params=self.params) return RawQuery(self.sql, using, params=self.params)
def convert_values(self, value, field, connection):
"""Convert the database-returned value into a type that is consistent
across database backends.
By default, this defers to the underlying backend operations, but
it can be overridden by Query classes for specific backends.
"""
return connection.ops.convert_values(value, field)
def get_columns(self): def get_columns(self):
if self.cursor is None: if self.cursor is None:
self._execute_query() self._execute_query()
@ -308,15 +299,6 @@ class Query(object):
obj._setup_query() obj._setup_query()
return obj return obj
def convert_values(self, value, field, connection):
"""Convert the database-returned value into a type that is consistent
across database backends.
By default, this defers to the underlying backend operations, but
it can be overridden by Query classes for specific backends.
"""
return connection.ops.convert_values(value, field)
def resolve_aggregate(self, value, aggregate, connection): def resolve_aggregate(self, value, aggregate, connection):
"""Resolve the value of aggregates returned by the database to """Resolve the value of aggregates returned by the database to
consistent (and reasonable) types. consistent (and reasonable) types.
@ -337,7 +319,13 @@ class Query(object):
return float(value) return float(value)
else: else:
# Return value depends on the type of the field being processed. # Return value depends on the type of the field being processed.
return self.convert_values(value, aggregate.field, connection) backend_converters = connection.ops.get_db_converters(aggregate.field.get_internal_type())
field_converters = aggregate.field.get_db_converters(connection)
for converter in backend_converters:
value = converter(value, aggregate.field)
for converter in field_converters:
value = converter(value, connection)
return value
def get_aggregation(self, using, force_subq=False): def get_aggregation(self, using, force_subq=False):
""" """

View File

@ -317,77 +317,6 @@ and reconstructing the field::
new_instance = MyField(*args, **kwargs) new_instance = MyField(*args, **kwargs)
self.assertEqual(my_field_instance.some_attribute, new_instance.some_attribute) self.assertEqual(my_field_instance.some_attribute, new_instance.some_attribute)
The ``SubfieldBase`` metaclass
------------------------------
.. class:: django.db.models.SubfieldBase
As we indicated in the introduction_, field subclasses are often needed for
two reasons: either to take advantage of a custom database column type, or to
handle complex Python types. Obviously, a combination of the two is also
possible. If you're only working with custom database column types and your
model fields appear in Python as standard Python types direct from the
database backend, you don't need to worry about this section.
If you're handling custom Python types, such as our ``Hand`` class, we need to
make sure that when Django initializes an instance of our model and assigns a
database value to our custom field attribute, we convert that value into the
appropriate Python object. The details of how this happens internally are a
little complex, but the code you need to write in your ``Field`` class is
simple: make sure your field subclass uses a special metaclass:
For example, on Python 2::
class HandField(models.Field):
description = "A hand of cards (bridge style)"
__metaclass__ = models.SubfieldBase
def __init__(self, *args, **kwargs):
...
On Python 3, in lieu of setting the ``__metaclass__`` attribute, add
``metaclass`` to the class definition::
class HandField(models.Field, metaclass=models.SubfieldBase):
...
If you want your code to work on Python 2 & 3, you can use
:func:`six.with_metaclass`::
from django.utils.six import with_metaclass
class HandField(with_metaclass(models.SubfieldBase, models.Field)):
...
This ensures that the :meth:`.to_python` method will always be called when the
attribute is initialized.
``ModelForm``\s and custom fields
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you use :class:`~django.db.models.SubfieldBase`, :meth:`.to_python` will be
called every time an instance of the field is assigned a value (in addition to
its usual call when retrieving the value from the database). This means that
whenever a value may be assigned to the field, you need to ensure that it will
be of the correct datatype, or that you handle any exceptions.
This is especially important if you use :doc:`ModelForms
</topics/forms/modelforms>`. When saving a ModelForm, Django will use
form values to instantiate model instances. However, if the cleaned
form data can't be used as valid input to the field, the normal form
validation process will break.
Therefore, you must ensure that the form field used to represent your
custom field performs whatever input validation and data cleaning is
necessary to convert user-provided form input into a
``to_python()``-compatible model field value. This may require writing a
custom form field, and/or implementing the :meth:`.formfield` method on
your field to return a form field class whose ``to_python()`` returns the
correct datatype.
Documenting your custom field Documenting your custom field
----------------------------- -----------------------------
@ -500,59 +429,79 @@ over this field. You are then responsible for creating the column in the right
table in some other way, of course, but this gives you a way to tell Django to table in some other way, of course, but this gives you a way to tell Django to
get out of the way. get out of the way.
.. _converting-database-values-to-python-objects: .. _converting-values-to-python-objects:
Converting database values to Python objects Converting values to Python objects
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. versionchanged:: 1.8
Historically, Django provided a metaclass called ``SubfieldBase`` which
always called :meth:`~Field.to_python` on assignment. This did not play
nicely with custom database transformations, aggregation, or values
queries, so it has been replaced with :meth:`~Field.from_db_value`.
If your custom :class:`~Field` class deals with data structures that are more If your custom :class:`~Field` class deals with data structures that are more
complex than strings, dates, integers or floats, then you'll need to override complex than strings, dates, integers, or floats, then you may need to override
:meth:`~Field.to_python`. As a general rule, the method should deal gracefully :meth:`~Field.from_db_value` and :meth:`~Field.to_python`.
with any of the following arguments:
If present for the field subclass, ``from_db_value()`` will be called in all
circumstances when the data is loaded from the database, including in
aggregates and :meth:`~django.db.models.query.QuerySet.values` calls.
``to_python()`` is called by deserialization and during the
:meth:`~django.db.models.Model.clean` method used from forms.
As a general rule, ``to_python()`` should deal gracefully with any of the
following arguments:
* An instance of the correct type (e.g., ``Hand`` in our ongoing example). * An instance of the correct type (e.g., ``Hand`` in our ongoing example).
* A string (e.g., from a deserializer). * A string
* Whatever the database returns for the column type you're using. * ``None`` (if the field allows ``null=True``)
In our ``HandField`` class, we're storing the data as a VARCHAR field in the In our ``HandField`` class, we're storing the data as a VARCHAR field in the
database, so we need to be able to process strings and ``Hand`` instances in database, so we need to be able to process strings and ``None`` in the
:meth:`.to_python`:: ``from_db_value()``. In ``to_python()``, we need to also handle ``Hand``
instances::
import re import re
from django.core.exceptions import ValidationError
from django.db import models
def parse_hand(hand_string):
"""Takes a string of cards and splits into a full hand."""
p1 = re.compile('.{26}')
p2 = re.compile('..')
args = [p2.findall(x) for x in p1.findall(hand_string)]
if len(args) != 4:
raise ValidationError("Invalid input for a Hand instance")
return Hand(*args)
class HandField(models.Field): class HandField(models.Field):
# ... # ...
def from_db_value(self, value, connection):
if value is None:
return value
return parse_hand(value)
def to_python(self, value): def to_python(self, value):
if isinstance(value, Hand): if isinstance(value, Hand):
return value return value
# The string case. if value is None:
p1 = re.compile('.{26}') return value
p2 = re.compile('..')
args = [p2.findall(x) for x in p1.findall(value)]
if len(args) != 4:
raise ValidationError("Invalid input for a Hand instance")
return Hand(*args)
Notice that we always return a ``Hand`` instance from this method. That's the return parse_hand(value)
Python object type we want to store in the model's attribute. If anything is
going wrong during value conversion, you should raise a
:exc:`~django.core.exceptions.ValidationError` exception.
**Remember:** If your custom field needs the :meth:`~Field.to_python` method to be Notice that we always return a ``Hand`` instance from these methods. That's the
called when it is created, you should be using `The SubfieldBase metaclass`_ Python object type we want to store in the model's attribute.
mentioned earlier. Otherwise :meth:`~Field.to_python` won't be called
automatically.
.. warning:: For ``to_python()``, if anything goes wrong during value conversion, you should
raise a :exc:`~django.core.exceptions.ValidationError` exception.
If your custom field allows ``null=True``, any field method that takes
``value`` as an argument, like :meth:`~Field.to_python` and
:meth:`~Field.get_prep_value`, should handle the case when ``value`` is
``None``.
.. _converting-python-objects-to-query-values: .. _converting-python-objects-to-query-values:

View File

@ -57,6 +57,8 @@ about each item can often be found in the release notes of two versions prior.
* The ``is_admin_site`` argument to * The ``is_admin_site`` argument to
``django.contrib.auth.views.password_reset()`` will be removed. ``django.contrib.auth.views.password_reset()`` will be removed.
* ``django.db.models.field.subclassing.SubfieldBase`` will be removed.
.. _deprecation-removed-in-1.9: .. _deprecation-removed-in-1.9:
1.9 1.9

View File

@ -1532,7 +1532,7 @@ Field API reference
``Field`` is an abstract class that represents a database table column. ``Field`` is an abstract class that represents a database table column.
Django uses fields to create the database table (:meth:`db_type`), to map Django uses fields to create the database table (:meth:`db_type`), to map
Python types to database (:meth:`get_prep_value`) and vice-versa Python types to database (:meth:`get_prep_value`) and vice-versa
(:meth:`to_python`), and to apply :doc:`/ref/models/lookups` (:meth:`from_db_value`), and to apply :doc:`/ref/models/lookups`
(:meth:`get_prep_lookup`). (:meth:`get_prep_lookup`).
A field is thus a fundamental piece in different Django APIs, notably, A field is thus a fundamental piece in different Django APIs, notably,
@ -1609,17 +1609,26 @@ Field API reference
See :ref:`converting-query-values-to-database-values` for usage. See :ref:`converting-query-values-to-database-values` for usage.
When loading data, :meth:`to_python` is used: When loading data, :meth:`from_db_value` is used:
.. method:: to_python(value) .. method:: from_db_value(value, connection)
Converts a value as returned by the database (or a serializer) to a .. versionadded:: 1.8
Python object. It is the reverse of :meth:`get_prep_value`.
The default implementation returns ``value``, which is the common case Converts a value as returned by the database to a Python object. It is
when the database backend already returns the correct Python type. the reverse of :meth:`get_prep_value`.
See :ref:`converting-database-values-to-python-objects` for usage. This method is not used for most built-in fields as the database
backend already returns the correct Python type, or the backend itself
does the conversion.
See :ref:`converting-values-to-python-objects` for usage.
.. note::
For performance reasons, ``from_db_value`` is not implemented as a
no-op on fields which do not require it (all Django fields).
Consequently you may not call ``super`` in your definition.
When saving, :meth:`pre_save` and :meth:`get_db_prep_save` are used: When saving, :meth:`pre_save` and :meth:`get_db_prep_save` are used:
@ -1644,15 +1653,6 @@ Field API reference
See :ref:`preprocessing-values-before-saving` for usage. See :ref:`preprocessing-values-before-saving` for usage.
Besides saving to the database, the field also needs to know how to
serialize its value (inverse of :meth:`to_python`):
.. method:: value_to_string(obj)
Converts ``obj`` to a string. Used to serialize the value of the field.
See :ref:`converting-model-field-to-serialization` for usage.
When a lookup is used on a field, the value may need to be "prepared". When a lookup is used on a field, the value may need to be "prepared".
Django exposes two methods for this: Django exposes two methods for this:
@ -1682,6 +1682,26 @@ Field API reference
``prepared`` describes whether the value has already been prepared with ``prepared`` describes whether the value has already been prepared with
:meth:`get_prep_lookup`. :meth:`get_prep_lookup`.
Fields often receive their values as a different type, either from
serialization or from forms.
.. method:: to_python(value)
Converts the value into the correct Python object. It acts as the
reverse of :meth:`value_to_string`, and is also called in
:meth:`~django.db.models.Model.clean`.
See :ref:`converting-values-to-python-objects` for usage.
Besides saving to the database, the field also needs to know how to
serialize its value:
.. method:: value_to_string(obj)
Converts ``obj`` to a string. Used to serialize the value of the field.
See :ref:`converting-model-field-to-serialization` for usage.
When using :class:`model forms <django.forms.ModelForm>`, the ``Field`` When using :class:`model forms <django.forms.ModelForm>`, the ``Field``
needs to know which form field it should be represented by: needs to know which form field it should be represented by:

View File

@ -736,3 +736,14 @@ also been deprecated.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
It's a legacy option that should no longer be necessary. It's a legacy option that should no longer be necessary.
``SubfieldBase``
~~~~~~~~~~~~~~~~
``django.db.models.fields.subclassing.SubfieldBase`` has been deprecated and
will be removed in Django 2.0. Historically, it was used to handle fields where
type conversion was needed when loading from the database, but it was not used
in ``.values()`` calls or in aggregates. It has been replaced with
:meth:`~django.db.models.Field.from_db_value`. Note that the new approach does
not call the :meth:`~django.db.models.Fields.to_python`` method on assignment
as was the case with ``SubfieldBase``.

View File

@ -894,18 +894,6 @@ class AggregationTests(TestCase):
lambda b: b.name lambda b: b.name
) )
def test_type_conversion(self):
# The database backend convert_values function should not try to covert
# CharFields to float. Refs #13844.
from django.db.models import CharField
from django.db import connection
testData = 'not_a_float_value'
testField = CharField()
self.assertEqual(
connection.ops.convert_values(testData, testField),
testData
)
def test_annotate_joins(self): def test_annotate_joins(self):
""" """
Test that the base table's join isn't promoted to LOUTER. This could Test that the base table's join isn't promoted to LOUTER. This could

View File

@ -20,8 +20,6 @@ from django.db.backends.signals import connection_created
from django.db.backends.postgresql_psycopg2 import version as pg_version from django.db.backends.postgresql_psycopg2 import version as pg_version
from django.db.backends.utils import format_number, CursorWrapper from django.db.backends.utils import format_number, CursorWrapper
from django.db.models import Sum, Avg, Variance, StdDev from django.db.models import Sum, Avg, Variance, StdDev
from django.db.models.fields import (AutoField, DateField, DateTimeField,
DecimalField, IntegerField, TimeField)
from django.db.models.sql.constants import CURSOR from django.db.models.sql.constants import CURSOR
from django.db.utils import ConnectionHandler from django.db.utils import ConnectionHandler
from django.test import (TestCase, TransactionTestCase, override_settings, from django.test import (TestCase, TransactionTestCase, override_settings,
@ -133,16 +131,6 @@ class SQLiteTests(TestCase):
self.assertRaises(NotImplementedError, self.assertRaises(NotImplementedError,
models.Item.objects.all().aggregate, aggregate('last_modified')) models.Item.objects.all().aggregate, aggregate('last_modified'))
def test_convert_values_to_handle_null_value(self):
from django.db.backends.sqlite3.base import DatabaseOperations
convert_values = DatabaseOperations(connection).convert_values
self.assertIsNone(convert_values(None, AutoField(primary_key=True)))
self.assertIsNone(convert_values(None, DateField()))
self.assertIsNone(convert_values(None, DateTimeField()))
self.assertIsNone(convert_values(None, DecimalField()))
self.assertIsNone(convert_values(None, IntegerField()))
self.assertIsNone(convert_values(None, TimeField()))
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
class PostgreSQLTests(TestCase): class PostgreSQLTests(TestCase):

View File

@ -23,7 +23,7 @@ class MyWrapper(object):
return self.value == other return self.value == other
class MyAutoField(six.with_metaclass(models.SubfieldBase, models.CharField)): class MyAutoField(models.CharField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['max_length'] = 10 kwargs['max_length'] = 10
@ -43,6 +43,11 @@ class MyAutoField(six.with_metaclass(models.SubfieldBase, models.CharField)):
value = MyWrapper(value) value = MyWrapper(value)
return value return value
def from_db_value(self, value, connection):
if not value:
return
return MyWrapper(value)
def get_db_prep_save(self, value, connection): def get_db_prep_save(self, value, connection):
if not value: if not value:
return return

View File

@ -2,13 +2,24 @@
Tests for field subclassing. Tests for field subclassing.
""" """
import warnings
from django.db import models from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.deprecation import RemovedInDjango20Warning
from .fields import Small, SmallField, SmallerField, JSONField from .fields import Small, SmallField, SmallerField, JSONField
from django.utils.encoding import python_2_unicode_compatible from django.utils.encoding import python_2_unicode_compatible
# Catch warning about subfieldbase -- remove in Django 2.0
warnings.filterwarnings(
'ignore',
'SubfieldBase has been deprecated. Use Field.from_db_value instead.',
RemovedInDjango20Warning
)
@python_2_unicode_compatible @python_2_unicode_compatible
class MyModel(models.Model): class MyModel(models.Model):
name = models.CharField(max_length=10) name = models.CharField(max_length=10)

View File

View File

@ -0,0 +1,32 @@
import decimal
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
class Cash(decimal.Decimal):
currency = 'USD'
def __str__(self):
s = super(Cash, self).__str__(self)
return '%s %s' % (s, self.currency)
class CashField(models.DecimalField):
def __init__(self, **kwargs):
kwargs['max_digits'] = 20
kwargs['decimal_places'] = 2
super(CashField, self).__init__(**kwargs)
def from_db_value(self, value, connection):
cash = Cash(value)
cash.vendor = connection.vendor
return cash
@python_2_unicode_compatible
class CashModel(models.Model):
cash = CashField()
def __str__(self):
return str(self.cash)

View File

@ -0,0 +1,30 @@
from django.db import connection
from django.db.models import Max
from django.test import TestCase
from .models import CashModel, Cash
class FromDBValueTest(TestCase):
def setUp(self):
CashModel.objects.create(cash='12.50')
def test_simple_load(self):
instance = CashModel.objects.get()
self.assertIsInstance(instance.cash, Cash)
def test_values(self):
values_list = CashModel.objects.values_list('cash', flat=True)
self.assertIsInstance(values_list[0], Cash)
def test_aggregation(self):
maximum = CashModel.objects.aggregate(m=Max('cash'))['m']
self.assertIsInstance(maximum, Cash)
def test_defer(self):
instance = CashModel.objects.defer('cash').get()
self.assertIsInstance(instance.cash, Cash)
def test_connection(self):
instance = CashModel.objects.get()
self.assertEqual(instance.cash.vendor, connection.vendor)

View File

@ -99,7 +99,7 @@ class Team(object):
return "%s" % self.title return "%s" % self.title
class TeamField(six.with_metaclass(models.SubfieldBase, models.CharField)): class TeamField(models.CharField):
def __init__(self): def __init__(self):
super(TeamField, self).__init__(max_length=100) super(TeamField, self).__init__(max_length=100)
@ -112,6 +112,9 @@ class TeamField(six.with_metaclass(models.SubfieldBase, models.CharField)):
return value return value
return Team(value) return Team(value)
def from_db_value(self, value, connection):
return Team(value)
def value_to_string(self, obj): def value_to_string(self, obj):
return self._get_val_from_obj(obj).to_string() return self._get_val_from_obj(obj).to_string()