Fixed #24343 -- Ensure db converters are used for foreign keys.

Joint effort between myself, Josh, Anssi and Shai.
This commit is contained in:
Marc Tamlyn 2015-02-14 19:37:12 +00:00
parent dbacbc729a
commit 4755f8fc25
7 changed files with 44 additions and 15 deletions

View File

@ -585,10 +585,10 @@ class Random(ExpressionNode):
class Col(ExpressionNode): class Col(ExpressionNode):
def __init__(self, alias, target, source=None): def __init__(self, alias, target, output_field=None):
if source is None: if output_field is None:
source = target output_field = target
super(Col, self).__init__(output_field=source) super(Col, self).__init__(output_field=output_field)
self.alias, self.target = alias, target self.alias, self.target = alias, target
def __repr__(self): def __repr__(self):
@ -606,7 +606,10 @@ class Col(ExpressionNode):
return [self] return [self]
def get_db_converters(self, connection): def get_db_converters(self, connection):
return self.output_field.get_db_converters(connection) if self.target == self.output_field:
return self.output_field.get_db_converters(connection)
return (self.output_field.get_db_converters(connection) +
self.target.get_db_converters(connection))
class Ref(ExpressionNode): class Ref(ExpressionNode):

View File

@ -330,12 +330,12 @@ class Field(RegisterLookupMixin):
] ]
return [] return []
def get_col(self, alias, source=None): def get_col(self, alias, output_field=None):
if source is None: if output_field is None:
source = self output_field = self
if alias != self.model._meta.db_table or source != self: if alias != self.model._meta.db_table or output_field != self:
from django.db.models.expressions import Col from django.db.models.expressions import Col
return Col(alias, self, source) return Col(alias, self, output_field)
else: else:
return self.cached_col return self.cached_col

View File

@ -2064,6 +2064,20 @@ class ForeignKey(ForeignObject):
def db_parameters(self, connection): def db_parameters(self, connection):
return {"type": self.db_type(connection), "check": []} return {"type": self.db_type(connection), "check": []}
def convert_empty_strings(self, value, connection, context):
if (not value) and isinstance(value, six.string_types):
return None
return value
def get_db_converters(self, connection):
converters = super(ForeignKey, self).get_db_converters(connection)
if connection.features.interprets_empty_strings_as_nulls:
converters += [self.convert_empty_strings]
return converters
def get_col(self, alias, output_field=None):
return super(ForeignKey, self).get_col(alias, output_field or self.related_field)
class OneToOneField(ForeignKey): class OneToOneField(ForeignKey):
""" """

View File

@ -57,7 +57,7 @@ class ModelIterator(BaseIterator):
model_cls = klass_info['model'] model_cls = klass_info['model']
select_fields = klass_info['select_fields'] select_fields = klass_info['select_fields']
model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
init_list = [f[0].output_field.attname init_list = [f[0].target.attname
for f in select[model_fields_start:model_fields_end]] for f in select[model_fields_start:model_fields_end]]
if len(init_list) != len(model_cls._meta.concrete_fields): if len(init_list) != len(model_cls._meta.concrete_fields):
init_set = set(init_list) init_set = set(init_list)
@ -1618,7 +1618,7 @@ class RelatedPopulator(object):
self.cols_start = select_fields[0] self.cols_start = select_fields[0]
self.cols_end = select_fields[-1] + 1 self.cols_end = select_fields[-1] + 1
self.init_list = [ self.init_list = [
f[0].output_field.attname for f in select[self.cols_start:self.cols_end] f[0].target.attname for f in select[self.cols_start:self.cols_end]
] ]
self.reorder_for_init = None self.reorder_for_init = None
else: else:
@ -1627,7 +1627,7 @@ class RelatedPopulator(object):
] ]
reorder_map = [] reorder_map = []
for idx in select_fields: for idx in select_fields:
field = select[idx][0].output_field field = select[idx][0].target
init_pos = model_init_attnames.index(field.attname) init_pos = model_init_attnames.index(field.attname)
reorder_map.append((init_pos, field.attname, idx)) reorder_map.append((init_pos, field.attname, idx))
reorder_map.sort() reorder_map.sort()

View File

@ -1458,7 +1458,7 @@ class Query(object):
# database from tripping over IN (...,NULL,...) selects and returning # database from tripping over IN (...,NULL,...) selects and returning
# nothing # nothing
col = query.select[0] col = query.select[0]
select_field = col.field select_field = col.target
alias = col.alias alias = col.alias
if self.is_nullable(select_field): if self.is_nullable(select_field):
lookup_class = select_field.get_lookup('isnull') lookup_class = select_field.get_lookup('isnull')

View File

@ -369,6 +369,10 @@ class PrimaryKeyUUIDModel(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4) id = models.UUIDField(primary_key=True, default=uuid.uuid4)
class RelatedToUUIDModel(models.Model):
uuid_fk = models.ForeignKey('PrimaryKeyUUIDModel')
############################################################################### ###############################################################################
# See ticket #24215. # See ticket #24215.

View File

@ -5,7 +5,9 @@ from django.core import exceptions, serializers
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from .models import NullableUUIDModel, PrimaryKeyUUIDModel, UUIDModel from .models import (
NullableUUIDModel, PrimaryKeyUUIDModel, RelatedToUUIDModel, UUIDModel,
)
class TestSaveLoad(TestCase): class TestSaveLoad(TestCase):
@ -121,3 +123,9 @@ class TestAsPrimaryKey(TestCase):
self.assertTrue(u1_found) self.assertTrue(u1_found)
self.assertTrue(u2_found) self.assertTrue(u2_found)
self.assertEqual(PrimaryKeyUUIDModel.objects.count(), 2) self.assertEqual(PrimaryKeyUUIDModel.objects.count(), 2)
def test_underlying_field(self):
pk_model = PrimaryKeyUUIDModel.objects.create()
RelatedToUUIDModel.objects.create(uuid_fk=pk_model)
related = RelatedToUUIDModel.objects.get()
self.assertEqual(related.uuid_fk.pk, related.uuid_fk_id)