Allowed database backends to specify data types for Cast().

A small refactor ahead of refs #28371.
This commit is contained in:
Mariusz Felisiak 2017-07-27 02:26:58 +02:00 committed by Tim Graham
parent ef9344b3a5
commit 8e41373c81
4 changed files with 21 additions and 20 deletions

View File

@ -32,6 +32,10 @@ class BaseDatabaseOperations:
'intersection': 'INTERSECT',
'difference': 'EXCEPT',
}
# Mapping of Field.get_internal_type() (typically the model field's class
# name) to the data type to use for the Cast() function, if different from
# DatabaseWrapper.data_types.
cast_data_types = {}
def __init__(self, connection):
self.connection = connection

View File

@ -15,6 +15,15 @@ class DatabaseOperations(BaseDatabaseOperations):
PositiveSmallIntegerField=(0, 65535),
PositiveIntegerField=(0, 4294967295),
)
cast_data_types = {
'CharField': 'char(%(max_length)s)',
'IntegerField': 'signed integer',
'BigIntegerField': 'signed integer',
'SmallIntegerField': 'signed integer',
'FloatField': 'signed',
'PositiveIntegerField': 'unsigned integer',
'PositiveSmallIntegerField': 'unsigned integer',
}
def date_extract_sql(self, lookup_type, field_name):
# http://dev.mysql.com/doc/mysql/en/date-and-time-functions.html

View File

@ -656,6 +656,13 @@ class Field(RegisterLookupMixin):
"""
return self.db_type(connection)
def cast_db_type(self, connection):
"""Return the data type to use in the Cast() function."""
db_type = connection.ops.cast_data_types.get(self.get_internal_type())
if db_type:
return db_type % self.db_type_parameters(connection)
return self.db_type(connection)
def db_parameters(self, connection):
"""
Extension of db_type(), providing a range of different return values

View File

@ -9,32 +9,13 @@ class Cast(Func):
function = 'CAST'
template = '%(function)s(%(expressions)s AS %(db_type)s)'
mysql_types = {
fields.CharField: 'char(%(max_length)s)',
fields.IntegerField: 'signed integer',
fields.BigIntegerField: 'signed integer',
fields.SmallIntegerField: 'signed integer',
fields.FloatField: 'signed',
fields.PositiveIntegerField: 'unsigned integer',
fields.PositiveSmallIntegerField: 'unsigned integer',
}
def __init__(self, expression, output_field):
super().__init__(expression, output_field=output_field)
def as_sql(self, compiler, connection, **extra_context):
if 'db_type' not in extra_context:
extra_context['db_type'] = self.output_field.db_type(connection)
extra_context['db_type'] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection):
extra_context = {}
output_field_class = type(self.output_field)
if output_field_class in self.mysql_types:
data = self.output_field.db_type_parameters(connection)
extra_context['db_type'] = self.mysql_types[output_field_class] % data
return self.as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection):
# CAST would be valid too, but the :: shortcut syntax is more readable.
return self.as_sql(compiler, connection, template='%(expressions)s::%(db_type)s')