[1.8.x] Fixed #24343 -- Ensure db converters are used for foreign keys.
Joint effort between myself, Josh, Anssi and Shai.
Conflicts:
django/db/models/query.py
tests/model_fields/models.py
Backport of 4755f8fc25
from master.
This commit is contained in:
parent
82f39bfb1a
commit
c54d73ae01
|
@ -585,10 +585,10 @@ class Random(ExpressionNode):
|
||||||
|
|
||||||
|
|
||||||
class Col(ExpressionNode):
|
class Col(ExpressionNode):
|
||||||
def __init__(self, alias, target, source=None):
|
def __init__(self, alias, target, output_field=None):
|
||||||
if source is None:
|
if output_field is None:
|
||||||
source = target
|
output_field = target
|
||||||
super(Col, self).__init__(output_field=source)
|
super(Col, self).__init__(output_field=output_field)
|
||||||
self.alias, self.target = alias, target
|
self.alias, self.target = alias, target
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
@ -606,7 +606,10 @@ class Col(ExpressionNode):
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
def get_db_converters(self, connection):
|
def get_db_converters(self, connection):
|
||||||
return self.output_field.get_db_converters(connection)
|
if self.target == self.output_field:
|
||||||
|
return self.output_field.get_db_converters(connection)
|
||||||
|
return (self.output_field.get_db_converters(connection) +
|
||||||
|
self.target.get_db_converters(connection))
|
||||||
|
|
||||||
|
|
||||||
class Ref(ExpressionNode):
|
class Ref(ExpressionNode):
|
||||||
|
|
|
@ -333,12 +333,12 @@ class Field(RegisterLookupMixin):
|
||||||
]
|
]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_col(self, alias, source=None):
|
def get_col(self, alias, output_field=None):
|
||||||
if source is None:
|
if output_field is None:
|
||||||
source = self
|
output_field = self
|
||||||
if alias != self.model._meta.db_table or source != self:
|
if alias != self.model._meta.db_table or output_field != self:
|
||||||
from django.db.models.expressions import Col
|
from django.db.models.expressions import Col
|
||||||
return Col(alias, self, source)
|
return Col(alias, self, output_field)
|
||||||
else:
|
else:
|
||||||
return self.cached_col
|
return self.cached_col
|
||||||
|
|
||||||
|
|
|
@ -1991,6 +1991,20 @@ class ForeignKey(ForeignObject):
|
||||||
def db_parameters(self, connection):
|
def db_parameters(self, connection):
|
||||||
return {"type": self.db_type(connection), "check": []}
|
return {"type": self.db_type(connection), "check": []}
|
||||||
|
|
||||||
|
def convert_empty_strings(self, value, connection, context):
|
||||||
|
if (not value) and isinstance(value, six.string_types):
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
def get_db_converters(self, connection):
|
||||||
|
converters = super(ForeignKey, self).get_db_converters(connection)
|
||||||
|
if connection.features.interprets_empty_strings_as_nulls:
|
||||||
|
converters += [self.convert_empty_strings]
|
||||||
|
return converters
|
||||||
|
|
||||||
|
def get_col(self, alias, output_field=None):
|
||||||
|
return super(ForeignKey, self).get_col(alias, output_field or self.related_field)
|
||||||
|
|
||||||
|
|
||||||
class OneToOneField(ForeignKey):
|
class OneToOneField(ForeignKey):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -243,7 +243,7 @@ class QuerySet(object):
|
||||||
model_cls = klass_info['model']
|
model_cls = klass_info['model']
|
||||||
select_fields = klass_info['select_fields']
|
select_fields = klass_info['select_fields']
|
||||||
model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
|
model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
|
||||||
init_list = [f[0].output_field.attname
|
init_list = [f[0].target.attname
|
||||||
for f in select[model_fields_start:model_fields_end]]
|
for f in select[model_fields_start:model_fields_end]]
|
||||||
if len(init_list) != len(model_cls._meta.concrete_fields):
|
if len(init_list) != len(model_cls._meta.concrete_fields):
|
||||||
init_set = set(init_list)
|
init_set = set(init_list)
|
||||||
|
@ -1699,7 +1699,7 @@ class RelatedPopulator(object):
|
||||||
self.cols_start = select_fields[0]
|
self.cols_start = select_fields[0]
|
||||||
self.cols_end = select_fields[-1] + 1
|
self.cols_end = select_fields[-1] + 1
|
||||||
self.init_list = [
|
self.init_list = [
|
||||||
f[0].output_field.attname for f in select[self.cols_start:self.cols_end]
|
f[0].target.attname for f in select[self.cols_start:self.cols_end]
|
||||||
]
|
]
|
||||||
self.reorder_for_init = None
|
self.reorder_for_init = None
|
||||||
else:
|
else:
|
||||||
|
@ -1708,7 +1708,7 @@ class RelatedPopulator(object):
|
||||||
]
|
]
|
||||||
reorder_map = []
|
reorder_map = []
|
||||||
for idx in select_fields:
|
for idx in select_fields:
|
||||||
field = select[idx][0].output_field
|
field = select[idx][0].target
|
||||||
init_pos = model_init_attnames.index(field.attname)
|
init_pos = model_init_attnames.index(field.attname)
|
||||||
reorder_map.append((init_pos, field.attname, idx))
|
reorder_map.append((init_pos, field.attname, idx))
|
||||||
reorder_map.sort()
|
reorder_map.sort()
|
||||||
|
|
|
@ -1542,7 +1542,7 @@ class Query(object):
|
||||||
# database from tripping over IN (...,NULL,...) selects and returning
|
# database from tripping over IN (...,NULL,...) selects and returning
|
||||||
# nothing
|
# nothing
|
||||||
col = query.select[0]
|
col = query.select[0]
|
||||||
select_field = col.field
|
select_field = col.target
|
||||||
alias = col.alias
|
alias = col.alias
|
||||||
if self.is_nullable(select_field):
|
if self.is_nullable(select_field):
|
||||||
lookup_class = select_field.get_lookup('isnull')
|
lookup_class = select_field.get_lookup('isnull')
|
||||||
|
|
|
@ -369,3 +369,7 @@ class NullableUUIDModel(models.Model):
|
||||||
|
|
||||||
class PrimaryKeyUUIDModel(models.Model):
|
class PrimaryKeyUUIDModel(models.Model):
|
||||||
id = models.UUIDField(primary_key=True, default=uuid.uuid4)
|
id = models.UUIDField(primary_key=True, default=uuid.uuid4)
|
||||||
|
|
||||||
|
|
||||||
|
class RelatedToUUIDModel(models.Model):
|
||||||
|
uuid_fk = models.ForeignKey('PrimaryKeyUUIDModel')
|
||||||
|
|
|
@ -5,7 +5,9 @@ from django.core import exceptions, serializers
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from .models import NullableUUIDModel, PrimaryKeyUUIDModel, UUIDModel
|
from .models import (
|
||||||
|
NullableUUIDModel, PrimaryKeyUUIDModel, RelatedToUUIDModel, UUIDModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSaveLoad(TestCase):
|
class TestSaveLoad(TestCase):
|
||||||
|
@ -121,3 +123,9 @@ class TestAsPrimaryKey(TestCase):
|
||||||
self.assertTrue(u1_found)
|
self.assertTrue(u1_found)
|
||||||
self.assertTrue(u2_found)
|
self.assertTrue(u2_found)
|
||||||
self.assertEqual(PrimaryKeyUUIDModel.objects.count(), 2)
|
self.assertEqual(PrimaryKeyUUIDModel.objects.count(), 2)
|
||||||
|
|
||||||
|
def test_underlying_field(self):
|
||||||
|
pk_model = PrimaryKeyUUIDModel.objects.create()
|
||||||
|
RelatedToUUIDModel.objects.create(uuid_fk=pk_model)
|
||||||
|
related = RelatedToUUIDModel.objects.get()
|
||||||
|
self.assertEqual(related.uuid_fk.pk, related.uuid_fk_id)
|
||||||
|
|
Loading…
Reference in New Issue