Fixed #28391 -- Fixed Cast() with CharField and max_length on MySQL.

Thanks Tim Graham for the review.
This commit is contained in:
Mariusz Felisiak 2017-07-17 21:12:27 +02:00 committed by GitHub
parent feeafdad02
commit 776cee9749
5 changed files with 20 additions and 5 deletions

View File

@ -234,6 +234,9 @@ class BaseDatabaseFeatures:
# Does the backend support indexing a TextField? # Does the backend support indexing a TextField?
supports_index_on_text_field = True supports_index_on_text_field = True
# Does the backend support CAST with precision?
supports_cast_with_precision = True
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection

View File

@ -30,6 +30,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_clone_databases = True can_clone_databases = True
supports_temporal_subtraction = True supports_temporal_subtraction = True
ignores_table_name_case = True ignores_table_name_case = True
supports_cast_with_precision = False
@cached_property @cached_property
def uses_savepoints(self): def uses_savepoints(self):

View File

@ -607,13 +607,16 @@ class Field(RegisterLookupMixin):
self.run_validators(value) self.run_validators(value)
return value return value
def db_type_parameters(self, connection):
return DictWrapper(self.__dict__, connection.ops.quote_name, 'qn_')
def db_check(self, connection): def db_check(self, connection):
""" """
Return the database column check constraint for this field, for the Return the database column check constraint for this field, for the
provided connection. Works the same way as db_type() for the case that provided connection. Works the same way as db_type() for the case that
get_internal_type() does not map to a preexisting model field. get_internal_type() does not map to a preexisting model field.
""" """
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_") data = self.db_type_parameters(connection)
try: try:
return connection.data_type_check_constraints[self.get_internal_type()] % data return connection.data_type_check_constraints[self.get_internal_type()] % data
except KeyError: except KeyError:
@ -639,7 +642,7 @@ class Field(RegisterLookupMixin):
# mapped to one of the built-in Django field types. In this case, you # mapped to one of the built-in Django field types. In this case, you
# can implement db_type() instead of get_internal_type() to specify # can implement db_type() instead of get_internal_type() to specify
# exactly which wacky database column type you want to use. # exactly which wacky database column type you want to use.
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_") data = self.db_type_parameters(connection)
try: try:
return connection.data_types[self.get_internal_type()] % data return connection.data_types[self.get_internal_type()] % data
except KeyError: except KeyError:

View File

@ -10,7 +10,7 @@ class Cast(Func):
template = '%(function)s(%(expressions)s AS %(db_type)s)' template = '%(function)s(%(expressions)s AS %(db_type)s)'
mysql_types = { mysql_types = {
fields.CharField: 'char', fields.CharField: 'char(%(max_length)s)',
fields.IntegerField: 'signed integer', fields.IntegerField: 'signed integer',
fields.BigIntegerField: 'signed integer', fields.BigIntegerField: 'signed integer',
fields.SmallIntegerField: 'signed integer', fields.SmallIntegerField: 'signed integer',
@ -31,7 +31,8 @@ class Cast(Func):
extra_context = {} extra_context = {}
output_field_class = type(self.output_field) output_field_class = type(self.output_field)
if output_field_class in self.mysql_types: if output_field_class in self.mysql_types:
extra_context['db_type'] = self.mysql_types[output_field_class] data = self.output_field.db_type_parameters(connection)
extra_context['db_type'] = self.mysql_types[output_field_class] % data
return self.as_sql(compiler, connection, **extra_context) return self.as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):

View File

@ -1,7 +1,7 @@
from django.db import models from django.db import models
from django.db.models.expressions import Value from django.db.models.expressions import Value
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.test import TestCase from django.test import TestCase, ignore_warnings, skipUnlessDBFeature
from .models import Author from .models import Author
@ -19,6 +19,13 @@ class CastTests(TestCase):
numbers = Author.objects.annotate(cast_string=Cast('age', models.CharField(max_length=255)),) numbers = Author.objects.annotate(cast_string=Cast('age', models.CharField(max_length=255)),)
self.assertEqual(numbers.get().cast_string, '1') self.assertEqual(numbers.get().cast_string, '1')
# Silence "Truncated incorrect CHAR(1) value: 'Bob'".
@ignore_warnings(module='django.db.backends.mysql.base')
@skipUnlessDBFeature('supports_cast_with_precision')
def test_cast_to_char_field_with_max_length(self):
names = Author.objects.annotate(cast_string=Cast('name', models.CharField(max_length=1)))
self.assertEqual(names.get().cast_string, 'B')
def test_cast_to_integer(self): def test_cast_to_integer(self):
for field_class in ( for field_class in (
models.IntegerField, models.IntegerField,