From 32d4db66b999cde6776d4be7f71528dab94916cc Mon Sep 17 00:00:00 2001 From: Marc Tamlyn Date: Fri, 20 Feb 2015 10:53:59 +0000 Subject: [PATCH] Update converters to take a consistent set of parameters. As suggested by Anssi. This has the slightly strange side effect of passing the expression to Expression.convert_value has the expression passed back to it, but it allows more complex patterns of expressions. --- .../contrib/gis/db/backends/oracle/operations.py | 2 +- .../gis/db/backends/spatialite/operations.py | 2 +- django/contrib/gis/db/models/aggregates.py | 6 +++--- django/contrib/gis/db/models/fields.py | 2 +- django/contrib/gis/db/models/sql/conversion.py | 8 ++++---- django/db/backends/base/operations.py | 2 +- django/db/backends/mysql/operations.py | 6 +++--- django/db/backends/oracle/operations.py | 14 +++++++------- django/db/backends/sqlite3/operations.py | 10 +++++----- django/db/models/aggregates.py | 8 ++++---- django/db/models/expressions.py | 6 +++--- django/db/models/fields/related.py | 2 +- django/db/models/sql/compiler.py | 10 ++++------ docs/howto/custom-model-fields.txt | 2 +- docs/ref/models/expressions.txt | 2 +- docs/ref/models/fields.txt | 2 +- tests/custom_pk/fields.py | 2 +- tests/from_db_value/models.py | 2 +- tests/serializers/models.py | 2 +- 19 files changed, 44 insertions(+), 46 deletions(-) diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index 9d40505632..8709202cd2 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -127,7 +127,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): converters.append(self.convert_geometry) return converters - def convert_geometry(self, value, expression, context): + def convert_geometry(self, value, expression, connection, context): if value: value = Geometry(value) if 'transformed_srid' in context: diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 73adb02b05..6c89849caf 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -262,7 +262,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): converters.append(self.convert_geometry) return converters - def convert_geometry(self, value, expression, context): + def convert_geometry(self, value, expression, connection, context): if value: value = Geometry(value) if 'transformed_srid' in context: diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 60954951e6..7458b2bdeb 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -29,7 +29,7 @@ class GeoAggregate(Aggregate): raise ValueError('Geospatial aggregates only allowed on geometry fields.') return c - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): return connection.ops.convert_geom(value, self.output_field) @@ -44,7 +44,7 @@ class Extent(GeoAggregate): def __init__(self, expression, **extra): super(Extent, self).__init__(expression, output_field=ExtentField(), **extra) - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): return connection.ops.convert_extent(value, context.get('transformed_srid')) @@ -55,7 +55,7 @@ class Extent3D(GeoAggregate): def __init__(self, expression, **extra): super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra) - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): return connection.ops.convert_extent3d(value, context.get('transformed_srid')) diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index df26b8b9f8..1d95ce7be7 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -219,7 +219,7 @@ class GeometryField(GeoSelectFormatMixin, Field): else: return geom - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): if value and not isinstance(value, Geometry): value = Geometry(value) return value diff --git a/django/contrib/gis/db/models/sql/conversion.py b/django/contrib/gis/db/models/sql/conversion.py index 19e7b5bfbd..28e601613e 100644 --- a/django/contrib/gis/db/models/sql/conversion.py +++ b/django/contrib/gis/db/models/sql/conversion.py @@ -23,7 +23,7 @@ class AreaField(BaseField): def __init__(self, area_att): self.area_att = area_att - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): if value is not None: value = Area(**{self.area_att: value}) return value @@ -37,7 +37,7 @@ class DistanceField(BaseField): def __init__(self, distance_att): self.distance_att = distance_att - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): if value is not None: value = Distance(**{self.distance_att: value}) return value @@ -54,7 +54,7 @@ class GeomField(GeoSelectFormatMixin, BaseField): # Hacky marker for get_db_converters() geom_type = None - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): if value is not None: value = Geometry(value) return value @@ -71,5 +71,5 @@ class GMLField(BaseField): def get_internal_type(self): return 'GMLField' - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): return value diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index aebe4312c8..9777a77e65 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -512,7 +512,7 @@ class BaseDatabaseOperations(object): """ return [] - def convert_durationfield_value(self, value, expression, context): + def convert_durationfield_value(self, value, expression, connection, context): if value is not None: value = str(decimal.Decimal(value) / decimal.Decimal(1000000)) value = parse_duration(value) diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index 7edc13b3ff..fd468cb182 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -184,17 +184,17 @@ class DatabaseOperations(BaseDatabaseOperations): converters.append(self.convert_textfield_value) return converters - def convert_booleanfield_value(self, value, expression, context): + def convert_booleanfield_value(self, value, expression, connection, context): if value in (0, 1): value = bool(value) return value - def convert_uuidfield_value(self, value, expression, context): + def convert_uuidfield_value(self, value, expression, connection, context): if value is not None: value = uuid.UUID(value) return value - def convert_textfield_value(self, value, expression, context): + def convert_textfield_value(self, value, expression, connection, context): if value is not None: value = force_text(value) return value diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 1ba49909d8..02fd6fcc7d 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -172,7 +172,7 @@ WHEN (new.%(col_name)s IS NULL) converters.append(self.convert_empty_values) return converters - def convert_empty_values(self, value, expression, context): + def convert_empty_values(self, value, expression, connection, context): # Oracle stores empty strings as null. We need to undo this in # order to adhere to the Django convention of using the empty # string instead of null, but only if the field accepts the @@ -184,17 +184,17 @@ WHEN (new.%(col_name)s IS NULL) value = b'' return value - def convert_textfield_value(self, value, expression, context): + def convert_textfield_value(self, value, expression, connection, context): if isinstance(value, Database.LOB): value = force_text(value.read()) return value - def convert_binaryfield_value(self, value, expression, context): + def convert_binaryfield_value(self, value, expression, connection, context): if isinstance(value, Database.LOB): value = force_bytes(value.read()) return value - def convert_booleanfield_value(self, value, expression, context): + def convert_booleanfield_value(self, value, expression, connection, context): if value in (1, 0): value = bool(value) return value @@ -202,16 +202,16 @@ WHEN (new.%(col_name)s IS NULL) # cx_Oracle always returns datetime.datetime objects for # DATE and TIMESTAMP columns, but Django wants to see a # python datetime.date, .time, or .datetime. - def convert_datefield_value(self, value, expression, context): + def convert_datefield_value(self, value, expression, connection, context): if isinstance(value, Database.Timestamp): return value.date() - def convert_timefield_value(self, value, expression, context): + def convert_timefield_value(self, value, expression, connection, context): if isinstance(value, Database.Timestamp): value = value.time() return value - def convert_uuidfield_value(self, value, expression, context): + def convert_uuidfield_value(self, value, expression, connection, context): if value is not None: value = uuid.UUID(value) return value diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 5d974f36df..f7e2c64da1 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -153,25 +153,25 @@ class DatabaseOperations(BaseDatabaseOperations): converters.append(self.convert_uuidfield_value) return converters - def convert_decimalfield_value(self, value, expression, context): + def convert_decimalfield_value(self, value, expression, connection, context): return backend_utils.typecast_decimal(expression.output_field.format_number(value)) - def convert_datefield_value(self, value, expression, context): + def convert_datefield_value(self, value, expression, connection, context): if value is not None and not isinstance(value, datetime.date): value = parse_date(value) return value - def convert_datetimefield_value(self, value, expression, context): + def convert_datetimefield_value(self, value, expression, connection, context): if value is not None and not isinstance(value, datetime.datetime): value = parse_datetime_with_timezone_support(value) return value - def convert_timefield_value(self, value, expression, context): + def convert_timefield_value(self, value, expression, connection, context): if value is not None and not isinstance(value, datetime.time): value = parse_time(value) return value - def convert_uuidfield_value(self, value, expression, context): + def convert_uuidfield_value(self, value, expression, connection, context): if value is not None: value = uuid.UUID(value) return value diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 01ab61f71a..b51fe5635a 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -77,7 +77,7 @@ class Avg(Aggregate): def __init__(self, expression, **extra): super(Avg, self).__init__(expression, output_field=FloatField(), **extra) - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): if value is None: return value return float(value) @@ -101,7 +101,7 @@ class Count(Aggregate): 'False' if self.extra['distinct'] == '' else 'True', ) - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): if value is None: return 0 return int(value) @@ -131,7 +131,7 @@ class StdDev(Aggregate): 'False' if self.function == 'STDDEV_POP' else 'True', ) - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): if value is None: return value return float(value) @@ -156,7 +156,7 @@ class Variance(Aggregate): 'False' if self.function == 'VAR_POP' else 'True', ) - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): if value is None: return value return float(value) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 2e41803a1a..00a6424886 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -252,7 +252,7 @@ class BaseExpression(object): raise FieldError( "Expression contains mixed types. You must set output_field") - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): """ Expressions provide their own converters because users have the option of manually specifying the output_field which may be a different type @@ -804,7 +804,7 @@ class Date(ExpressionNode): copy.lookup_type = self.lookup_type return copy - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): if isinstance(value, datetime.datetime): value = value.date() return value @@ -856,7 +856,7 @@ class DateTime(ExpressionNode): copy.tzname = self.tzname return copy - def convert_value(self, value, connection, context): + def convert_value(self, value, expression, connection, context): if settings.USE_TZ: if value is None: raise ValueError( diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 6be796521e..84c9e1987a 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -2064,7 +2064,7 @@ class ForeignKey(ForeignObject): def db_parameters(self, connection): return {"type": self.db_type(connection), "check": []} - def convert_empty_strings(self, value, connection, context): + def convert_empty_strings(self, value, expression, connection, context): if (not value) and isinstance(value, six.string_types): return None return value diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 3ca2d6d675..97022d554a 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -760,17 +760,15 @@ class SQLCompiler(object): backend_converters = self.connection.ops.get_db_converters(expression) field_converters = expression.get_db_converters(self.connection) if backend_converters or field_converters: - converters[i] = (backend_converters, field_converters, expression) + converters[i] = (backend_converters + field_converters, expression) return converters def apply_converters(self, row, converters): row = list(row) - for pos, (backend_converters, field_converters, field) in converters.items(): + for pos, (convs, expression) in converters.items(): value = row[pos] - for converter in backend_converters: - value = converter(value, field, self.query.context) - for converter in field_converters: - value = converter(value, self.connection, self.query.context) + for converter in convs: + value = converter(value, expression, self.connection, self.query.context) row[pos] = value return tuple(row) diff --git a/docs/howto/custom-model-fields.txt b/docs/howto/custom-model-fields.txt index 5c41d92bb6..35dfd4cb79 100644 --- a/docs/howto/custom-model-fields.txt +++ b/docs/howto/custom-model-fields.txt @@ -477,7 +477,7 @@ instances:: class HandField(models.Field): # ... - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): if value is None: return value return parse_hand(value) diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 380f4ced76..6632e4e94d 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -412,7 +412,7 @@ calling the appropriate methods on the wrapped expression. clone.expression = self.expression.relabeled_clone(change_map) return clone - .. method:: convert_value(self, value, connection, context) + .. method:: convert_value(self, value, expression, connection, context) A hook allowing the expression to coerce ``value`` into a more appropriate type. diff --git a/docs/ref/models/fields.txt b/docs/ref/models/fields.txt index ad7d099621..efce51369b 100644 --- a/docs/ref/models/fields.txt +++ b/docs/ref/models/fields.txt @@ -1630,7 +1630,7 @@ Field API reference When loading data, :meth:`from_db_value` is used: - .. method:: from_db_value(value, connection, context) + .. method:: from_db_value(value, expression, connection, context) .. versionadded:: 1.8 diff --git a/tests/custom_pk/fields.py b/tests/custom_pk/fields.py index bf349545e5..3779fc6c43 100644 --- a/tests/custom_pk/fields.py +++ b/tests/custom_pk/fields.py @@ -43,7 +43,7 @@ class MyAutoField(models.CharField): value = MyWrapper(value) return value - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): if not value: return return MyWrapper(value) diff --git a/tests/from_db_value/models.py b/tests/from_db_value/models.py index 6a06c832ea..aa62b1f567 100644 --- a/tests/from_db_value/models.py +++ b/tests/from_db_value/models.py @@ -18,7 +18,7 @@ class CashField(models.DecimalField): kwargs['decimal_places'] = 2 super(CashField, self).__init__(**kwargs) - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): cash = Cash(value) cash.vendor = connection.vendor return cash diff --git a/tests/serializers/models.py b/tests/serializers/models.py index b2864b1c71..08dc860821 100644 --- a/tests/serializers/models.py +++ b/tests/serializers/models.py @@ -112,7 +112,7 @@ class TeamField(models.CharField): return value return Team(value) - def from_db_value(self, value, connection, context): + def from_db_value(self, value, expression, connection, context): return Team(value) def value_to_string(self, obj):