Fixed #19385 again, now with real code changes

The commit of 266de5f9ae included only
tests, this time also code changes included...
This commit is contained in:
Anssi Kääriäinen 2013-03-24 18:40:40 +02:00
parent 266de5f9ae
commit 97774429ae
17 changed files with 653 additions and 426 deletions

View File

@ -8,10 +8,11 @@ from functools import partial
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db import connection from django.db import connection
from django.db.models import signals
from django.db import models, router, DEFAULT_DB_ALIAS from django.db import models, router, DEFAULT_DB_ALIAS
from django.db.models.fields.related import RelatedField, Field, ManyToManyRel from django.db.models import signals
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.related import PathInfo from django.db.models.related import PathInfo
from django.db.models.sql.where import Constraint
from django.forms import ModelForm from django.forms import ModelForm
from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance
from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets
@ -149,17 +150,14 @@ class GenericForeignKey(six.with_metaclass(RenameGenericForeignKeyMethods)):
setattr(instance, self.fk_field, fk) setattr(instance, self.fk_field, fk)
setattr(instance, self.cache_attr, value) setattr(instance, self.cache_attr, value)
class GenericRelation(RelatedField, Field): class GenericRelation(ForeignObject):
"""Provides an accessor to generic related objects (e.g. comments)""" """Provides an accessor to generic related objects (e.g. comments)"""
def __init__(self, to, **kwargs): def __init__(self, to, **kwargs):
kwargs['verbose_name'] = kwargs.get('verbose_name', None) kwargs['verbose_name'] = kwargs.get('verbose_name', None)
kwargs['rel'] = GenericRel(to, kwargs['rel'] = GenericRel(
related_name=kwargs.pop('related_name', None), self, to, related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None), limit_choices_to=kwargs.pop('limit_choices_to', None),)
symmetrical=kwargs.pop('symmetrical', True))
# Override content-type/object-id field names on the related class # Override content-type/object-id field names on the related class
self.object_id_field_name = kwargs.pop("object_id_field", "object_id") self.object_id_field_name = kwargs.pop("object_id_field", "object_id")
self.content_type_field_name = kwargs.pop("content_type_field", "content_type") self.content_type_field_name = kwargs.pop("content_type_field", "content_type")
@ -167,47 +165,44 @@ class GenericRelation(RelatedField, Field):
kwargs['blank'] = True kwargs['blank'] = True
kwargs['editable'] = False kwargs['editable'] = False
kwargs['serialize'] = False kwargs['serialize'] = False
Field.__init__(self, **kwargs) # This construct is somewhat of an abuse of ForeignObject. This field
# represents a relation from pk to object_id field. But, this relation
# isn't direct, the join is generated reverse along foreign key. So,
# the from_field is object_id field, to_field is pk because of the
# reverse join.
super(GenericRelation, self).__init__(
to, to_fields=[],
from_fields=[self.object_id_field_name], **kwargs)
def get_path_info(self): def resolve_related_fields(self):
from_field = self.model._meta.pk self.to_fields = [self.model._meta.pk.name]
return [(self.rel.to._meta.get_field_by_name(self.object_id_field_name)[0],
self.model._meta.pk)]
def get_reverse_path_info(self):
opts = self.rel.to._meta opts = self.rel.to._meta
target = opts.get_field_by_name(self.object_id_field_name)[0] target = opts.get_field_by_name(self.object_id_field_name)[0]
# Note that we are using different field for the join_field return [PathInfo(self.model._meta, opts, (target,), self.rel, True, False)]
# than from_field or to_field. This is a hack, but we need the
# GenericRelation to generate the extra SQL.
return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)],
opts, target, self)
def get_choices_default(self): def get_choices_default(self):
return Field.get_choices(self, include_blank=False) return super(GenericRelation, self).get_choices(include_blank=False)
def value_to_string(self, obj): def value_to_string(self, obj):
qs = getattr(obj, self.name).all() qs = getattr(obj, self.name).all()
return smart_text([instance._get_pk_val() for instance in qs]) return smart_text([instance._get_pk_val() for instance in qs])
def m2m_db_table(self): def get_joining_columns(self, reverse_join=False):
return self.rel.to._meta.db_table if not reverse_join:
# This error message is meant for the user, and from user
def m2m_column_name(self): # perspective this is a reverse join along the GenericRelation.
return self.object_id_field_name raise ValueError('Joining in reverse direction not allowed.')
return super(GenericRelation, self).get_joining_columns(reverse_join)
def m2m_reverse_name(self):
return self.rel.to._meta.pk.column
def m2m_target_field_name(self):
return self.model._meta.pk.name
def m2m_reverse_target_field_name(self):
return self.rel.to._meta.pk.name
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
super(GenericRelation, self).contribute_to_class(cls, name) super(GenericRelation, self).contribute_to_class(cls, name, virtual_only=True)
# Save a reference to which model this class is on for future use # Save a reference to which model this class is on for future use
self.model = cls self.model = cls
# Add the descriptor for the relation
# Add the descriptor for the m2m relation
setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self)) setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self))
def contribute_to_related_class(self, cls, related): def contribute_to_related_class(self, cls, related):
@ -219,21 +214,18 @@ class GenericRelation(RelatedField, Field):
def get_internal_type(self): def get_internal_type(self):
return "ManyToManyField" return "ManyToManyField"
def db_type(self, connection):
# Since we're simulating a ManyToManyField, in effect, best return the
# same db_type as well.
return None
def get_content_type(self): def get_content_type(self):
""" """
Returns the content type associated with this field's model. Returns the content type associated with this field's model.
""" """
return ContentType.objects.get_for_model(self.model) return ContentType.objects.get_for_model(self.model)
def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias): def get_extra_restriction(self, where_class, alias, remote_alias):
extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column field = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0]
contenttype = self.get_content_type().pk contenttype_pk = self.get_content_type().pk
return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype] cond = where_class()
cond.add((Constraint(remote_alias, field.column, field), 'exact', contenttype_pk), 'AND')
return cond
def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
""" """
@ -273,12 +265,12 @@ class ReverseGenericRelatedObjectsDescriptor(object):
qn = connection.ops.quote_name qn = connection.ops.quote_name
content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(instance) content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(instance)
join_cols = self.field.get_joining_columns(reverse_join=True)[0]
manager = RelatedManager( manager = RelatedManager(
model = rel_model, model = rel_model,
instance = instance, instance = instance,
symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model), source_col_name = qn(join_cols[0]),
source_col_name = qn(self.field.m2m_column_name()), target_col_name = qn(join_cols[1]),
target_col_name = qn(self.field.m2m_reverse_name()),
content_type = content_type, content_type = content_type,
content_type_field_name = self.field.content_type_field_name, content_type_field_name = self.field.content_type_field_name,
object_id_field_name = self.field.object_id_field_name, object_id_field_name = self.field.object_id_field_name,
@ -378,14 +370,10 @@ def create_generic_related_manager(superclass):
return GenericRelatedObjectManager return GenericRelatedObjectManager
class GenericRel(ManyToManyRel): class GenericRel(ForeignObjectRel):
def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True):
self.to = to def __init__(self, field, to, related_name=None, limit_choices_to=None):
self.related_name = related_name super(GenericRel, self).__init__(field, to, related_name, limit_choices_to)
self.limit_choices_to = limit_choices_to or {}
self.symmetrical = symmetrical
self.multiple = True
self.through = None
class BaseGenericInlineFormSet(BaseModelFormSet): class BaseGenericInlineFormSet(BaseModelFormSet):
""" """

View File

@ -153,8 +153,16 @@ def get_validation_errors(outfile, app=None):
continue continue
# Make sure the related field specified by a ForeignKey is unique # Make sure the related field specified by a ForeignKey is unique
if not f.rel.to._meta.get_field(f.rel.field_name).unique: if f.requires_unique_target:
e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.rel.field_name, f.rel.to.__name__)) if len(f.foreign_related_fields) > 1:
has_unique_field = False
for rel_field in f.foreign_related_fields:
has_unique_field = has_unique_field or rel_field.unique
if not has_unique_field:
e.add(opts, "Field combination '%s' under model '%s' must have a unique=True constraint" % (','.join([rel_field.name for rel_field in f.foreign_related_fields]), f.rel.to.__name__))
else:
if not f.foreign_related_fields[0].unique:
e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.foreign_related_fields[0].name, f.rel.to.__name__))
rel_opts = f.rel.to._meta rel_opts = f.rel.to._meta
rel_name = f.related.get_accessor_name() rel_name = f.related.get_accessor_name()

View File

@ -17,6 +17,12 @@ class SQLCompiler(compiler.SQLCompiler):
values.append(value) values.append(value)
return row[:index_extra_select] + tuple(values) return row[:index_extra_select] + tuple(values)
def as_subquery_condition(self, alias, columns):
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
sql, params = self.as_sql()
return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
pass pass

View File

@ -8,7 +8,7 @@ from django.db.models.aggregates import *
from django.db.models.fields import * from django.db.models.fields import *
from django.db.models.fields.subclassing import SubfieldBase from django.db.models.fields.subclassing import SubfieldBase
from django.db.models.fields.files import FileField, ImageField from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel from django.db.models.fields.related import ForeignKey, ForeignObject, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel
from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError
from django.db.models import signals from django.db.models import signals
from django.utils.decorators import wraps from django.utils.decorators import wraps

View File

@ -10,7 +10,7 @@ from django.conf import settings
from django.core.exceptions import (ObjectDoesNotExist, from django.core.exceptions import (ObjectDoesNotExist,
MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS) MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS)
from django.db.models.fields import AutoField, FieldDoesNotExist from django.db.models.fields import AutoField, FieldDoesNotExist
from django.db.models.fields.related import (ManyToOneRel, from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel,
OneToOneField, add_lazy_relation) OneToOneField, add_lazy_relation)
from django.db import (router, transaction, DatabaseError, from django.db import (router, transaction, DatabaseError,
DEFAULT_DB_ALIAS) DEFAULT_DB_ALIAS)
@ -333,12 +333,12 @@ class Model(six.with_metaclass(ModelBase)):
# The reason for the kwargs check is that standard iterator passes in by # The reason for the kwargs check is that standard iterator passes in by
# args, and instantiation for iteration is 33% faster. # args, and instantiation for iteration is 33% faster.
args_len = len(args) args_len = len(args)
if args_len > len(self._meta.fields): if args_len > len(self._meta.concrete_fields):
# Daft, but matches old exception sans the err msg. # Daft, but matches old exception sans the err msg.
raise IndexError("Number of args exceeds number of fields") raise IndexError("Number of args exceeds number of fields")
fields_iter = iter(self._meta.fields)
if not kwargs: if not kwargs:
fields_iter = iter(self._meta.concrete_fields)
# The ordering of the zip calls matter - zip throws StopIteration # The ordering of the zip calls matter - zip throws StopIteration
# when an iter throws it. So if the first iter throws it, the second # when an iter throws it. So if the first iter throws it, the second
# is *not* consumed. We rely on this, so don't change the order # is *not* consumed. We rely on this, so don't change the order
@ -347,6 +347,7 @@ class Model(six.with_metaclass(ModelBase)):
setattr(self, field.attname, val) setattr(self, field.attname, val)
else: else:
# Slower, kwargs-ready version. # Slower, kwargs-ready version.
fields_iter = iter(self._meta.fields)
for val, field in zip(args, fields_iter): for val, field in zip(args, fields_iter):
setattr(self, field.attname, val) setattr(self, field.attname, val)
kwargs.pop(field.name, None) kwargs.pop(field.name, None)
@ -363,11 +364,12 @@ class Model(six.with_metaclass(ModelBase)):
# data-descriptor object (DeferredAttribute) without triggering its # data-descriptor object (DeferredAttribute) without triggering its
# __get__ method. # __get__ method.
if (field.attname not in kwargs and if (field.attname not in kwargs and
isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)): (isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)
or field.column is None)):
# This field will be populated on request. # This field will be populated on request.
continue continue
if kwargs: if kwargs:
if isinstance(field.rel, ManyToOneRel): if isinstance(field.rel, ForeignObjectRel):
try: try:
# Assume object instance was passed in. # Assume object instance was passed in.
rel_obj = kwargs.pop(field.name) rel_obj = kwargs.pop(field.name)
@ -394,6 +396,7 @@ class Model(six.with_metaclass(ModelBase)):
val = field.get_default() val = field.get_default()
else: else:
val = field.get_default() val = field.get_default()
if is_related_object: if is_related_object:
# If we are passed a related instance, set it using the # If we are passed a related instance, set it using the
# field.name instead of field.attname (e.g. "user" instead of # field.name instead of field.attname (e.g. "user" instead of
@ -528,7 +531,7 @@ class Model(six.with_metaclass(ModelBase)):
# automatically do a "update_fields" save on the loaded fields. # automatically do a "update_fields" save on the loaded fields.
elif not force_insert and self._deferred and using == self._state.db: elif not force_insert and self._deferred and using == self._state.db:
field_names = set() field_names = set()
for field in self._meta.fields: for field in self._meta.concrete_fields:
if not field.primary_key and not hasattr(field, 'through'): if not field.primary_key and not hasattr(field, 'through'):
field_names.add(field.attname) field_names.add(field.attname)
deferred_fields = [ deferred_fields = [
@ -614,7 +617,7 @@ class Model(six.with_metaclass(ModelBase)):
for a single table. for a single table.
""" """
meta = cls._meta meta = cls._meta
non_pks = [f for f in meta.local_fields if not f.primary_key] non_pks = [f for f in meta.local_concrete_fields if not f.primary_key]
if update_fields: if update_fields:
non_pks = [f for f in non_pks non_pks = [f for f in non_pks
@ -652,7 +655,7 @@ class Model(six.with_metaclass(ModelBase)):
**{field.name: getattr(self, field.attname)}).count() **{field.name: getattr(self, field.attname)}).count()
self._order = order_value self._order = order_value
fields = meta.local_fields fields = meta.local_concrete_fields
if not pk_set: if not pk_set:
fields = [f for f in fields if not isinstance(f, AutoField)] fields = [f for f in fields if not isinstance(f, AutoField)]

View File

@ -1,4 +1,3 @@
from functools import wraps
from operator import attrgetter from operator import attrgetter
from django.db import connections, transaction, IntegrityError from django.db import connections, transaction, IntegrityError
@ -196,17 +195,13 @@ class Collector(object):
self.fast_deletes.append(sub_objs) self.fast_deletes.append(sub_objs)
elif sub_objs: elif sub_objs:
field.rel.on_delete(self, field, sub_objs, self.using) field.rel.on_delete(self, field, sub_objs, self.using)
for field in model._meta.virtual_fields:
# TODO This entire block is only needed as a special case to if hasattr(field, 'bulk_related_objects'):
# support cascade-deletes for GenericRelation. It should be # Its something like generic foreign key.
# removed/fixed when the ORM gains a proper abstraction for virtual sub_objs = field.bulk_related_objects(new_objs, self.using)
# or composite fields, and GFKs are reworked to fit into that.
for relation in model._meta.many_to_many:
if not relation.rel.through:
sub_objs = relation.bulk_related_objects(new_objs, self.using)
self.collect(sub_objs, self.collect(sub_objs,
source=model, source=model,
source_attr=relation.rel.related_name, source_attr=field.rel.related_name,
nullable=True) nullable=True)
def related_objects(self, related, objs): def related_objects(self, related, objs):

View File

@ -292,10 +292,13 @@ class Field(object):
if self.verbose_name is None and self.name: if self.verbose_name is None and self.name:
self.verbose_name = self.name.replace('_', ' ') self.verbose_name = self.name.replace('_', ' ')
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name, virtual_only=False):
self.set_attributes_from_name(name) self.set_attributes_from_name(name)
self.model = cls self.model = cls
cls._meta.add_field(self) if virtual_only:
cls._meta.add_virtual_field(self)
else:
cls._meta.add_field(self)
if self.choices: if self.choices:
setattr(cls, 'get_%s_display' % self.name, setattr(cls, 'get_%s_display' % self.name,
curry(cls._get_FIELD_display, field=self)) curry(cls._get_FIELD_display, field=self))

View File

@ -7,7 +7,6 @@ from django.db.models.fields import (AutoField, Field, IntegerField,
PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist)
from django.db.models.related import RelatedObject, PathInfo from django.db.models.related import RelatedObject, PathInfo
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.db.models.query_utils import QueryWrapper
from django.db.models.deletion import CASCADE from django.db.models.deletion import CASCADE
from django.utils.encoding import smart_text from django.utils.encoding import smart_text
from django.utils import six from django.utils import six
@ -93,22 +92,27 @@ signals.class_prepared.connect(do_pending_lookups)
#HACK #HACK
class RelatedField(object): class RelatedField(Field):
def contribute_to_class(self, cls, name): def db_type(self, connection):
'''By default related field will not have a column
as it relates columns to another table'''
return None
def contribute_to_class(self, cls, name, virtual_only=False):
sup = super(RelatedField, self) sup = super(RelatedField, self)
# Store the opts for related_query_name() # Store the opts for related_query_name()
self.opts = cls._meta self.opts = cls._meta
if hasattr(sup, 'contribute_to_class'): if hasattr(sup, 'contribute_to_class'):
sup.contribute_to_class(cls, name) sup.contribute_to_class(cls, name, virtual_only=virtual_only)
if not cls._meta.abstract and self.rel.related_name: if not cls._meta.abstract and self.rel.related_name:
self.rel.related_name = self.rel.related_name % { related_name = self.rel.related_name % {
'class': cls.__name__.lower(), 'class': cls.__name__.lower(),
'app_label': cls._meta.app_label.lower(), 'app_label': cls._meta.app_label.lower()
} }
self.rel.related_name = related_name
other = self.rel.to other = self.rel.to
if isinstance(other, six.string_types) or other._meta.pk is None: if isinstance(other, six.string_types) or other._meta.pk is None:
def resolve_related_class(field, model, cls): def resolve_related_class(field, model, cls):
@ -122,7 +126,6 @@ class RelatedField(object):
self.name = self.name or (self.rel.to._meta.model_name + '_' + self.rel.to._meta.pk.name) self.name = self.name or (self.rel.to._meta.model_name + '_' + self.rel.to._meta.pk.name)
if self.verbose_name is None: if self.verbose_name is None:
self.verbose_name = self.rel.to._meta.verbose_name self.verbose_name = self.rel.to._meta.verbose_name
self.rel.field_name = self.rel.field_name or self.rel.to._meta.pk.name
def do_related_class(self, other, cls): def do_related_class(self, other, cls):
self.set_attributes_from_rel() self.set_attributes_from_rel()
@ -130,94 +133,6 @@ class RelatedField(object):
if not cls._meta.abstract: if not cls._meta.abstract:
self.contribute_to_related_class(other, self.related) self.contribute_to_related_class(other, self.related)
def get_prep_lookup(self, lookup_type, value):
if hasattr(value, 'prepare'):
return value.prepare()
if hasattr(value, '_prepare'):
return value._prepare()
# FIXME: lt and gt are explicitly allowed to make
# get_(next/prev)_by_date work; other lookups are not allowed since that
# gets messy pretty quick. This is a good candidate for some refactoring
# in the future.
if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']:
return self._pk_trace(value, 'get_prep_lookup', lookup_type)
if lookup_type in ('range', 'in'):
return [self._pk_trace(v, 'get_prep_lookup', lookup_type) for v in value]
elif lookup_type == 'isnull':
return []
raise TypeError("Related Field has invalid lookup: %s" % lookup_type)
def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
if not prepared:
value = self.get_prep_lookup(lookup_type, value)
if hasattr(value, 'get_compiler'):
value = value.get_compiler(connection=connection)
if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'):
# If the value has a relabeled_clone method it means the
# value will be handled later on.
if hasattr(value, 'relabeled_clone'):
return value
if hasattr(value, 'as_sql'):
sql, params = value.as_sql()
else:
sql, params = value._as_sql(connection=connection)
return QueryWrapper(('(%s)' % sql), params)
# FIXME: lt and gt are explicitly allowed to make
# get_(next/prev)_by_date work; other lookups are not allowed since that
# gets messy pretty quick. This is a good candidate for some refactoring
# in the future.
if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']:
return [self._pk_trace(value, 'get_db_prep_lookup', lookup_type,
connection=connection, prepared=prepared)]
if lookup_type in ('range', 'in'):
return [self._pk_trace(v, 'get_db_prep_lookup', lookup_type,
connection=connection, prepared=prepared)
for v in value]
elif lookup_type == 'isnull':
return []
raise TypeError("Related Field has invalid lookup: %s" % lookup_type)
def _pk_trace(self, value, prep_func, lookup_type, **kwargs):
# Value may be a primary key, or an object held in a relation.
# If it is an object, then we need to get the primary key value for
# that object. In certain conditions (especially one-to-one relations),
# the primary key may itself be an object - so we need to keep drilling
# down until we hit a value that can be used for a comparison.
v = value
# In the case of an FK to 'self', this check allows to_field to be used
# for both forwards and reverse lookups across the FK. (For normal FKs,
# it's only relevant for forward lookups).
if isinstance(v, self.rel.to):
field_name = getattr(self.rel, "field_name", None)
else:
field_name = None
try:
while True:
if field_name is None:
field_name = v._meta.pk.name
v = getattr(v, field_name)
field_name = None
except AttributeError:
pass
except exceptions.ObjectDoesNotExist:
v = None
field = self
while field.rel:
if hasattr(field.rel, 'field_name'):
field = field.rel.to._meta.get_field(field.rel.field_name)
else:
field = field.rel.to._meta.pk
if lookup_type in ('range', 'in'):
v = [v]
v = getattr(field, prep_func)(lookup_type, v, **kwargs)
if isinstance(v, list):
v = v[0]
return v
def related_query_name(self): def related_query_name(self):
# This method defines the name that can be used to identify this # This method defines the name that can be used to identify this
# related object in a table-spanning query. It uses the lower-cased # related object in a table-spanning query. It uses the lower-cased
@ -254,8 +169,8 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri
rel_obj_attr = attrgetter(self.related.field.attname) rel_obj_attr = attrgetter(self.related.field.attname)
instance_attr = lambda obj: obj._get_pk_val() instance_attr = lambda obj: obj._get_pk_val()
instances_dict = dict((instance_attr(inst), inst) for inst in instances) instances_dict = dict((instance_attr(inst), inst) for inst in instances)
params = {'%s__pk__in' % self.related.field.name: list(instances_dict)} query = {'%s__in' % self.related.field.name: instances}
qs = self.get_queryset(instance=instances[0]).filter(**params) qs = self.get_query_set(instance=instances[0]).filter(**query)
# Since we're going to assign directly in the cache, # Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually. # we must manage the reverse relation cache manually.
rel_obj_cache_name = self.related.field.get_cache_name() rel_obj_cache_name = self.related.field.get_cache_name()
@ -274,7 +189,9 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri
if related_pk is None: if related_pk is None:
rel_obj = None rel_obj = None
else: else:
params = {'%s__pk' % self.related.field.name: related_pk} params = {}
for lh_field, rh_field in self.related.field.related_fields:
params['%s__%s' % (self.related.field.name, rh_field.name)] = getattr(instance, rh_field.attname)
try: try:
rel_obj = self.get_queryset(instance=instance).get(**params) rel_obj = self.get_queryset(instance=instance).get(**params)
except self.related.model.DoesNotExist: except self.related.model.DoesNotExist:
@ -314,13 +231,14 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri
raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' % raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' %
(value, instance._state.db, value._state.db)) (value, instance._state.db, value._state.db))
related_pk = getattr(instance, self.related.field.rel.get_related_field().attname) related_pk = tuple([getattr(instance, field.attname) for field in self.related.field.foreign_related_fields])
if related_pk is None: if None in related_pk:
raise ValueError('Cannot assign "%r": "%s" instance isn\'t saved in the database.' % raise ValueError('Cannot assign "%r": "%s" instance isn\'t saved in the database.' %
(value, instance._meta.object_name)) (value, instance._meta.object_name))
# Set the value of the related field to the value of the related object's related field # Set the value of the related field to the value of the related object's related field
setattr(value, self.related.field.attname, related_pk) for index, field in enumerate(self.related.field.local_related_fields):
setattr(value, field.attname, related_pk[index])
# Since we already know what the related object is, seed the related # Since we already know what the related object is, seed the related
# object caches now, too. This avoids another db hit if you get the # object caches now, too. This avoids another db hit if you get the
@ -352,16 +270,12 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec
else: else:
return QuerySet(self.field.rel.to).using(db) return QuerySet(self.field.rel.to).using(db)
def get_prefetch_queryset(self, instances): def get_prefetch_query_set(self, instances):
other_field = self.field.rel.get_related_field() rel_obj_attr = self.field.get_foreign_related_value
rel_obj_attr = attrgetter(other_field.attname) instance_attr = self.field.get_local_related_value
instance_attr = attrgetter(self.field.attname)
instances_dict = dict((instance_attr(inst), inst) for inst in instances) instances_dict = dict((instance_attr(inst), inst) for inst in instances)
if other_field.rel: query = {'%s__in' % self.field.related_query_name(): instances}
params = {'%s__pk__in' % self.field.rel.field_name: list(instances_dict)} qs = self.get_query_set(instance=instances[0]).filter(**query)
else:
params = {'%s__in' % self.field.rel.field_name: list(instances_dict)}
qs = self.get_queryset(instance=instances[0]).filter(**params)
# Since we're going to assign directly in the cache, # Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually. # we must manage the reverse relation cache manually.
if not self.field.rel.multiple: if not self.field.rel.multiple:
@ -377,16 +291,14 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec
try: try:
rel_obj = getattr(instance, self.cache_name) rel_obj = getattr(instance, self.cache_name)
except AttributeError: except AttributeError:
val = getattr(instance, self.field.attname) val = self.field.get_local_related_value(instance)
if val is None: if None in val:
rel_obj = None rel_obj = None
else: else:
other_field = self.field.rel.get_related_field() params = {rh_field.attname: getattr(instance, lh_field.attname)
if other_field.rel: for lh_field, rh_field in self.field.related_fields}
params = {'%s__%s' % (self.field.rel.field_name, other_field.rel.field_name): val} params.update(self.field.get_extra_descriptor_filter(instance))
else: qs = self.get_query_set(instance=instance)
params = {'%s__exact' % self.field.rel.field_name: val}
qs = self.get_queryset(instance=instance)
# Assuming the database enforces foreign keys, this won't fail. # Assuming the database enforces foreign keys, this won't fail.
rel_obj = qs.get(**params) rel_obj = qs.get(**params)
if not self.field.rel.multiple: if not self.field.rel.multiple:
@ -440,11 +352,11 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec
setattr(related, self.field.related.get_cache_name(), None) setattr(related, self.field.related.get_cache_name(), None)
# Set the value of the related field # Set the value of the related field
try: for lh_field, rh_field in self.field.related_fields:
val = getattr(value, self.field.rel.get_related_field().attname) try:
except AttributeError: setattr(instance, lh_field.attname, getattr(value, rh_field.attname))
val = None except AttributeError:
setattr(instance, self.field.attname, val) setattr(instance, lh_field.attname, None)
# Since we already know what the related object is, seed the related # Since we already know what the related object is, seed the related
# object caches now, too. This avoids another db hit if you get the # object caches now, too. This avoids another db hit if you get the
@ -487,15 +399,12 @@ class ForeignRelatedObjectsDescriptor(object):
superclass = self.related.model._default_manager.__class__ superclass = self.related.model._default_manager.__class__
rel_field = self.related.field rel_field = self.related.field
rel_model = self.related.model rel_model = self.related.model
attname = rel_field.rel.get_related_field().attname
class RelatedManager(superclass): class RelatedManager(superclass):
def __init__(self, instance): def __init__(self, instance):
super(RelatedManager, self).__init__() super(RelatedManager, self).__init__()
self.instance = instance self.instance = instance
self.core_filters = { self.core_filters= {'%s__exact' % rel_field.name: instance}
'%s__%s' % (rel_field.name, attname): getattr(instance, attname)
}
self.model = rel_model self.model = rel_model
def get_queryset(self): def get_queryset(self):
@ -504,20 +413,22 @@ class ForeignRelatedObjectsDescriptor(object):
except (AttributeError, KeyError): except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance) db = self._db or router.db_for_read(self.model, instance=self.instance)
qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters) qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters)
val = getattr(self.instance, attname) empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
if val is None or val == '' and connections[db].features.interprets_empty_strings_as_nulls: for field in rel_field.foreign_related_fields:
return qs.none() val = getattr(self.instance, field.attname)
if val is None or (val == '' and empty_strings_as_null):
return qs.none()
qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}} qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
return qs return qs
def get_prefetch_queryset(self, instances): def get_prefetch_queryset(self, instances):
rel_obj_attr = attrgetter(rel_field.attname) rel_obj_attr = rel_field.get_local_related_value
instance_attr = attrgetter(attname) instance_attr = rel_field.get_foreign_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances) instances_dict = dict((instance_attr(inst), inst) for inst in instances)
db = self._db or router.db_for_read(self.model, instance=instances[0]) db = self._db or router.db_for_read(self.model, instance=instances[0])
query = {'%s__%s__in' % (rel_field.name, attname): list(instances_dict)} query = {'%s__in' % rel_field.name: instances}
qs = super(RelatedManager, self).get_queryset().using(db).filter(**query) qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
# Since we just bypassed this class' get_queryset(), we must manage # Since we just bypassed this class' get_query_set(), we must manage
# the reverse relation manually. # the reverse relation manually.
for rel_obj in qs: for rel_obj in qs:
instance = instances_dict[rel_obj_attr(rel_obj)] instance = instances_dict[rel_obj_attr(rel_obj)]
@ -550,10 +461,10 @@ class ForeignRelatedObjectsDescriptor(object):
# remove() and clear() are only provided if the ForeignKey can have a value of null. # remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel_field.null: if rel_field.null:
def remove(self, *objs): def remove(self, *objs):
val = getattr(self.instance, attname) val = rel_field.get_foreign_related_value(self.instance)
for obj in objs: for obj in objs:
# Is obj actually part of this descriptor set? # Is obj actually part of this descriptor set?
if getattr(obj, rel_field.attname) == val: if rel_field.get_local_related_value(obj) == val:
setattr(obj, rel_field.name, None) setattr(obj, rel_field.name, None)
obj.save() obj.save()
else: else:
@ -577,16 +488,26 @@ def create_many_related_manager(superclass, rel):
super(ManyRelatedManager, self).__init__() super(ManyRelatedManager, self).__init__()
self.model = model self.model = model
self.query_field_name = query_field_name self.query_field_name = query_field_name
self.core_filters = {'%s__pk' % query_field_name: instance._get_pk_val()}
source_field = through._meta.get_field(source_field_name)
source_related_fields = source_field.related_fields
self.core_filters = {}
for lh_field, rh_field in source_related_fields:
self.core_filters['%s__%s' % (query_field_name, rh_field.name)] = getattr(instance, rh_field.attname)
self.instance = instance self.instance = instance
self.symmetrical = symmetrical self.symmetrical = symmetrical
self.source_field = source_field
self.source_field_name = source_field_name self.source_field_name = source_field_name
self.target_field_name = target_field_name self.target_field_name = target_field_name
self.reverse = reverse self.reverse = reverse
self.through = through self.through = through
self.prefetch_cache_name = prefetch_cache_name self.prefetch_cache_name = prefetch_cache_name
self._fk_val = self._get_fk_val(instance, source_field_name) self.related_val = source_field.get_foreign_related_value(instance)
if self._fk_val is None: # Used for single column related auto created models
self._fk_val = self.related_val[0]
if None in self.related_val:
raise ValueError('"%r" needs to have a value for field "%s" before ' raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' % 'this many-to-many relationship can be used.' %
(instance, source_field_name)) (instance, source_field_name))
@ -620,11 +541,9 @@ def create_many_related_manager(superclass, rel):
def get_prefetch_queryset(self, instances): def get_prefetch_queryset(self, instances):
instance = instances[0] instance = instances[0]
from django.db import connections
db = self._db or router.db_for_read(instance.__class__, instance=instance) db = self._db or router.db_for_read(instance.__class__, instance=instance)
query = {'%s__pk__in' % self.query_field_name: query = {'%s__in' % self.query_field_name: instances}
set(obj._get_pk_val() for obj in instances)} qs = super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query)
qs = super(ManyRelatedManager, self).get_queryset().using(db)._next_is_sticky().filter(**query)
# M2M: need to annotate the query in order to get the primary model # M2M: need to annotate the query in order to get the primary model
# that the secondary model was actually related to. We know that # that the secondary model was actually related to. We know that
@ -634,16 +553,14 @@ def create_many_related_manager(superclass, rel):
# For non-autocreated 'through' models, can't assume we are # For non-autocreated 'through' models, can't assume we are
# dealing with PK values. # dealing with PK values.
fk = self.through._meta.get_field(self.source_field_name) fk = self.through._meta.get_field(self.source_field_name)
source_col = fk.column
join_table = self.through._meta.db_table join_table = self.through._meta.db_table
connection = connections[db] connection = connections[db]
qn = connection.ops.quote_name qn = connection.ops.quote_name
qs = qs.extra(select={'_prefetch_related_val': qs = qs.extra(select={'_prefetch_related_val_%s' % f.attname:
'%s.%s' % (qn(join_table), qn(source_col))}) '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
select_attname = fk.rel.get_related_field().get_attname()
return (qs, return (qs,
attrgetter('_prefetch_related_val'), lambda result: tuple([getattr(result, '_prefetch_related_val_%s' % f.attname) for f in fk.local_related_fields]),
attrgetter(select_attname), lambda inst: tuple([getattr(inst, f.attname) for f in fk.foreign_related_fields]),
False, False,
self.prefetch_cache_name) self.prefetch_cache_name)
@ -795,7 +712,7 @@ def create_many_related_manager(superclass, rel):
instance=self.instance, reverse=self.reverse, instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db) model=self.model, pk_set=None, using=db)
self.through._default_manager.using(db).filter(**{ self.through._default_manager.using(db).filter(**{
source_field_name: self._fk_val source_field_name: self.related_val
}).delete() }).delete()
if self.reverse or source_field_name == self.source_field_name: if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are clearing the # Don't send the signal when we are clearing the
@ -918,19 +835,18 @@ class ReverseManyRelatedObjectsDescriptor(object):
manager.clear() manager.clear()
manager.add(*value) manager.add(*value)
class ForeignObjectRel(object):
class ManyToOneRel(object): def __init__(self, field, to, related_name=None, limit_choices_to=None,
def __init__(self, to, field_name, related_name=None, limit_choices_to=None, parent_link=False, on_delete=None):
parent_link=False, on_delete=None):
try: try:
to._meta to._meta
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
assert isinstance(to, six.string_types), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT assert isinstance(to, six.string_types), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT
self.to, self.field_name = to, field_name
self.field = field
self.to = to
self.related_name = related_name self.related_name = related_name
if limit_choices_to is None: self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
limit_choices_to = {}
self.limit_choices_to = limit_choices_to
self.multiple = True self.multiple = True
self.parent_link = parent_link self.parent_link = parent_link
self.on_delete = on_delete self.on_delete = on_delete
@ -939,6 +855,20 @@ class ManyToOneRel(object):
"Should the related object be hidden?" "Should the related object be hidden?"
return self.related_name and self.related_name[-1] == '+' return self.related_name and self.related_name[-1] == '+'
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
def get_extra_restriction(self, where_class, alias, related_alias):
return self.field.get_extra_restriction(where_class, related_alias, alias)
class ManyToOneRel(ForeignObjectRel):
def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None):
super(ManyToOneRel, self).__init__(
field, to, related_name=related_name, limit_choices_to=limit_choices_to,
parent_link=parent_link, on_delete=on_delete)
self.field_name = field_name
def get_related_field(self): def get_related_field(self):
""" """
Returns the Field in the 'to' object to which this relationship is Returns the Field in the 'to' object to which this relationship is
@ -952,9 +882,9 @@ class ManyToOneRel(object):
class OneToOneRel(ManyToOneRel): class OneToOneRel(ManyToOneRel):
def __init__(self, to, field_name, related_name=None, limit_choices_to=None, def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None,
parent_link=False, on_delete=None): parent_link=False, on_delete=None):
super(OneToOneRel, self).__init__(to, field_name, super(OneToOneRel, self).__init__(field, to, field_name,
related_name=related_name, limit_choices_to=limit_choices_to, related_name=related_name, limit_choices_to=limit_choices_to,
parent_link=parent_link, on_delete=on_delete parent_link=parent_link, on_delete=on_delete
) )
@ -963,7 +893,7 @@ class OneToOneRel(ManyToOneRel):
class ManyToManyRel(object): class ManyToManyRel(object):
def __init__(self, to, related_name=None, limit_choices_to=None, def __init__(self, to, related_name=None, limit_choices_to=None,
symmetrical=True, through=None, db_constraint=True): symmetrical=True, through=None, db_constraint=True):
if through and not db_constraint: if through and not db_constraint:
raise ValueError("Can't supply a through model and db_constraint=False") raise ValueError("Can't supply a through model and db_constraint=False")
self.to = to self.to = to
@ -989,7 +919,199 @@ class ManyToManyRel(object):
return self.to._meta.pk return self.to._meta.pk
class ForeignKey(RelatedField, Field): class ForeignObject(RelatedField):
requires_unique_target = True
generate_reverse_relation = True
def __init__(self, to, from_fields, to_fields, **kwargs):
self.from_fields = from_fields
self.to_fields = to_fields
if 'rel' not in kwargs:
kwargs['rel'] = ForeignObjectRel(
self, to,
related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None),
parent_link=kwargs.pop('parent_link', False),
on_delete=kwargs.pop('on_delete', CASCADE),
)
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
super(ForeignObject, self).__init__(**kwargs)
def resolve_related_fields(self):
if len(self.from_fields) < 1 or len(self.from_fields) != len(self.to_fields):
raise ValueError('Foreign Object from and to fields must be the same non-zero length')
related_fields = []
for index in range(len(self.from_fields)):
from_field_name = self.from_fields[index]
to_field_name = self.to_fields[index]
from_field = (self if from_field_name == 'self'
else self.opts.get_field_by_name(from_field_name)[0])
to_field = (self.rel.to._meta.pk if to_field_name is None
else self.rel.to._meta.get_field_by_name(to_field_name)[0])
related_fields.append((from_field, to_field))
return related_fields
@property
def related_fields(self):
if not hasattr(self, '_related_fields'):
self._related_fields = self.resolve_related_fields()
return self._related_fields
@property
def reverse_related_fields(self):
return [(rhs_field, lhs_field) for lhs_field, rhs_field in self.related_fields]
@property
def local_related_fields(self):
return tuple([lhs_field for lhs_field, rhs_field in self.related_fields])
@property
def foreign_related_fields(self):
return tuple([rhs_field for lhs_field, rhs_field in self.related_fields])
def get_local_related_value(self, instance):
return self.get_instance_value_for_fields(instance, self.local_related_fields)
def get_foreign_related_value(self, instance):
return self.get_instance_value_for_fields(instance, self.foreign_related_fields)
@staticmethod
def get_instance_value_for_fields(instance, fields):
return tuple([getattr(instance, field.attname) for field in fields])
def get_attname_column(self):
attname, column = super(ForeignObject, self).get_attname_column()
return attname, None
def get_joining_columns(self, reverse_join=False):
source = self.reverse_related_fields if reverse_join else self.related_fields
return tuple([(lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source])
def get_reverse_joining_columns(self):
return self.get_joining_columns(reverse_join=True)
def get_extra_descriptor_filter(self, instance):
"""
Returns an extra filter condition for related object fetching when
user does 'instance.fieldname', that is the extra filter is used in
the descriptor of the field.
The filter should be something usable in .filter(**kwargs) call, and
will be ANDed together with the joining columns condition.
A parallel method is get_extra_relation_restriction() which is used in
JOIN and subquery conditions.
"""
return {}
def get_extra_restriction(self, where_class, alias, related_alias):
"""
Returns a pair condition used for joining and subquery pushdown. The
condition is something that responds to as_sql(qn, connection) method.
Note that currently referring both the 'alias' and 'related_alias'
will not work in some conditions, like subquery pushdown.
A parallel method is get_extra_descriptor_filter() which is used in
instance.fieldname related object fetching.
"""
return None
def get_path_info(self):
"""
Get path from this field to the related model.
"""
opts = self.rel.to._meta
from_opts = self.model._meta
return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)]
def get_reverse_path_info(self):
"""
Get path from the related model to this field's model.
"""
opts = self.model._meta
from_opts = self.rel.to._meta
pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
return pathinfos
def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type,
raw_value):
from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR
root_constraint = constraint_class()
assert len(targets) == len(sources)
def get_normalized_value(value):
from django.db.models import Model
if isinstance(value, Model):
value_list = []
for source in sources:
# Account for one-to-one relations when sent a different model
while not isinstance(value, source.model):
source = source.rel.to._meta.get_field(source.rel.field_name)
value_list.append(getattr(value, source.attname))
return tuple(value_list)
elif not isinstance(value, tuple):
return (value,)
return value
is_multicolumn = len(self.related_fields) > 1
if (hasattr(raw_value, '_as_sql') or
hasattr(raw_value, 'get_compiler')):
root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets],
[source.name for source in sources], raw_value),
AND)
elif lookup_type == 'isnull':
root_constraint.add(
(Constraint(alias, targets[0].column, targets[0]), lookup_type, raw_value), AND)
elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte']
and not is_multicolumn)):
value = get_normalized_value(raw_value)
for index, source in enumerate(sources):
root_constraint.add(
(Constraint(alias, targets[index].column, sources[index]), lookup_type,
value[index]), AND)
elif lookup_type in ['range', 'in'] and not is_multicolumn:
values = [get_normalized_value(value) for value in raw_value]
value = [val[0] for val in values]
root_constraint.add(
(Constraint(alias, targets[0].column, sources[0]), lookup_type, value), AND)
elif lookup_type == 'in':
values = [get_normalized_value(value) for value in raw_value]
for value in values:
value_constraint = constraint_class()
for index, target in enumerate(targets):
value_constraint.add(
(Constraint(alias, target.column, sources[index]), 'exact', value[index]),
AND)
root_constraint.add(value_constraint, OR)
else:
raise TypeError('Related Field got invalid lookup: %s' % lookup_type)
return root_constraint
@property
def attnames(self):
return tuple([field.attname for field in self.local_related_fields])
def get_defaults(self):
return tuple([field.get_default() for field in self.local_related_fields])
def contribute_to_class(self, cls, name, virtual_only=False):
super(ForeignObject, self).contribute_to_class(cls, name, virtual_only=virtual_only)
setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
def contribute_to_related_class(self, cls, related):
# Internal FK's - i.e., those with a related name ending with '+' -
# and swapped models don't get a related descriptor.
if not self.rel.is_hidden() and not related.model._meta.swapped:
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
if self.rel.limit_choices_to:
cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to)
class ForeignKey(ForeignObject):
empty_strings_allowed = False empty_strings_allowed = False
default_error_messages = { default_error_messages = {
'invalid': _('Model %(model)s with pk %(pk)r does not exist.') 'invalid': _('Model %(model)s with pk %(pk)r does not exist.')
@ -999,7 +1121,7 @@ class ForeignKey(RelatedField, Field):
def __init__(self, to, to_field=None, rel_class=ManyToOneRel, def __init__(self, to, to_field=None, rel_class=ManyToOneRel,
db_constraint=True, **kwargs): db_constraint=True, **kwargs):
try: try:
to._meta.model_name to_name = to._meta.object_name.lower()
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
assert isinstance(to, six.string_types), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT) assert isinstance(to, six.string_types), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT)
else: else:
@ -1008,44 +1130,33 @@ class ForeignKey(RelatedField, Field):
# the to_field during FK construction. It won't be guaranteed to # the to_field during FK construction. It won't be guaranteed to
# be correct until contribute_to_class is called. Refs #12190. # be correct until contribute_to_class is called. Refs #12190.
to_field = to_field or (to._meta.pk and to._meta.pk.name) to_field = to_field or (to._meta.pk and to._meta.pk.name)
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
if 'db_index' not in kwargs: if 'db_index' not in kwargs:
kwargs['db_index'] = True kwargs['db_index'] = True
self.db_constraint = db_constraint self.db_constraint = db_constraint
kwargs['rel'] = rel_class(to, to_field,
kwargs['rel'] = rel_class(
self, to, to_field,
related_name=kwargs.pop('related_name', None), related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None), limit_choices_to=kwargs.pop('limit_choices_to', None),
parent_link=kwargs.pop('parent_link', False), parent_link=kwargs.pop('parent_link', False),
on_delete=kwargs.pop('on_delete', CASCADE), on_delete=kwargs.pop('on_delete', CASCADE),
) )
super(ForeignKey, self).__init__(**kwargs) super(ForeignKey, self).__init__(to, ['self'], [to_field], **kwargs)
def get_path_info(self): @property
""" def related_field(self):
Get path from this field to the related model. return self.foreign_related_fields[0]
"""
opts = self.rel.to._meta
target = self.rel.get_related_field()
from_opts = self.model._meta
return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self
def get_reverse_path_info(self): def get_reverse_path_info(self):
""" """
Get path from the related model to this field's model. Get path from the related model to this field's model.
""" """
opts = self.model._meta opts = self.model._meta
from_field = self.rel.get_related_field() from_opts = self.rel.to._meta
from_opts = from_field.model._meta pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)]
pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)] return pathinfos
if from_field.model is self.model:
# Recursive foreign key to self.
target = opts.get_field_by_name(
self.rel.field_name)[0]
else:
target = opts.pk
return pathinfos, opts, target, self
def validate(self, value, model_instance): def validate(self, value, model_instance):
if self.rel.parent_link: if self.rel.parent_link:
@ -1066,21 +1177,26 @@ class ForeignKey(RelatedField, Field):
def get_attname(self): def get_attname(self):
return '%s_id' % self.name return '%s_id' % self.name
def get_attname_column(self):
attname = self.get_attname()
column = self.db_column or attname
return attname, column
def get_validator_unique_lookup_type(self): def get_validator_unique_lookup_type(self):
return '%s__%s__exact' % (self.name, self.rel.get_related_field().name) return '%s__%s__exact' % (self.name, self.related_field.name)
def get_default(self): def get_default(self):
"Here we check if the default value is an object and return the to_field if so." "Here we check if the default value is an object and return the to_field if so."
field_default = super(ForeignKey, self).get_default() field_default = super(ForeignKey, self).get_default()
if isinstance(field_default, self.rel.to): if isinstance(field_default, self.rel.to):
return getattr(field_default, self.rel.get_related_field().attname) return getattr(field_default, self.related_field.attname)
return field_default return field_default
def get_db_prep_save(self, value, connection): def get_db_prep_save(self, value, connection):
if value == '' or value == None: if value == '' or value == None:
return None return None
else: else:
return self.rel.get_related_field().get_db_prep_save(value, return self.related_field.get_db_prep_save(value,
connection=connection) connection=connection)
def value_to_string(self, obj): def value_to_string(self, obj):
@ -1093,19 +1209,10 @@ class ForeignKey(RelatedField, Field):
choice_list = self.get_choices_default() choice_list = self.get_choices_default()
if len(choice_list) == 2: if len(choice_list) == 2:
return smart_text(choice_list[1][0]) return smart_text(choice_list[1][0])
return Field.value_to_string(self, obj) return super(ForeignKey, self).value_to_string(obj)
def contribute_to_class(self, cls, name):
super(ForeignKey, self).contribute_to_class(cls, name)
setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
def contribute_to_related_class(self, cls, related): def contribute_to_related_class(self, cls, related):
# Internal FK's - i.e., those with a related name ending with '+' - super(ForeignKey, self).contribute_to_related_class(cls, related)
# and swapped models don't get a related descriptor.
if not self.rel.is_hidden() and not related.model._meta.swapped:
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
if self.rel.limit_choices_to:
cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to)
if self.rel.field_name is None: if self.rel.field_name is None:
self.rel.field_name = cls._meta.pk.name self.rel.field_name = cls._meta.pk.name
@ -1130,7 +1237,7 @@ class ForeignKey(RelatedField, Field):
# in which case the column type is simply that of an IntegerField. # in which case the column type is simply that of an IntegerField.
# If the database needs similar types for key fields however, the only # If the database needs similar types for key fields however, the only
# thing we can do is making AutoField an IntegerField. # thing we can do is making AutoField an IntegerField.
rel_field = self.rel.get_related_field() rel_field = self.related_field
if (isinstance(rel_field, AutoField) or if (isinstance(rel_field, AutoField) or
(not connection.features.related_fields_match_type and (not connection.features.related_fields_match_type and
isinstance(rel_field, (PositiveIntegerField, isinstance(rel_field, (PositiveIntegerField,
@ -1212,7 +1319,7 @@ def create_many_to_many_intermediary_model(field, klass):
}) })
class ManyToManyField(RelatedField, Field): class ManyToManyField(RelatedField):
description = _("Many-to-many relationship") description = _("Many-to-many relationship")
def __init__(self, to, db_constraint=True, **kwargs): def __init__(self, to, db_constraint=True, **kwargs):
@ -1252,14 +1359,14 @@ class ManyToManyField(RelatedField, Field):
linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0] linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0]
linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0] linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
if direct: if direct:
join1infos, _, _, _ = linkfield1.get_reverse_path_info() join1infos = linkfield1.get_reverse_path_info()
join2infos, opts, target, final_field = linkfield2.get_path_info() join2infos = linkfield2.get_path_info()
else: else:
join1infos, _, _, _ = linkfield2.get_reverse_path_info() join1infos = linkfield2.get_reverse_path_info()
join2infos, opts, target, final_field = linkfield1.get_path_info() join2infos = linkfield1.get_path_info()
pathinfos.extend(join1infos) pathinfos.extend(join1infos)
pathinfos.extend(join2infos) pathinfos.extend(join2infos)
return pathinfos, opts, target, final_field return pathinfos
def get_path_info(self): def get_path_info(self):
return self._get_path_info(direct=True) return self._get_path_info(direct=True)
@ -1402,8 +1509,3 @@ class ManyToManyField(RelatedField, Field):
initial = initial() initial = initial()
defaults['initial'] = [i._get_pk_val() for i in initial] defaults['initial'] = [i._get_pk_val() for i in initial]
return super(ManyToManyField, self).formfield(**defaults) return super(ManyToManyField, self).formfield(**defaults)
def db_type(self, connection):
# A ManyToManyField is not represented by a single column,
# so return None.
return None

View File

@ -10,6 +10,7 @@ from django.db.models.fields import AutoField, FieldDoesNotExist
from django.db.models.fields.proxy import OrderWrt from django.db.models.fields.proxy import OrderWrt
from django.db.models.loading import get_models, app_cache_ready from django.db.models.loading import get_models, app_cache_ready
from django.utils import six from django.utils import six
from django.utils.functional import cached_property
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.encoding import force_text, smart_text, python_2_unicode_compatible from django.utils.encoding import force_text, smart_text, python_2_unicode_compatible
from django.utils.translation import activate, deactivate_all, get_language, string_concat from django.utils.translation import activate, deactivate_all, get_language, string_concat
@ -173,6 +174,22 @@ class Options(object):
if hasattr(self, '_field_cache'): if hasattr(self, '_field_cache'):
del self._field_cache del self._field_cache
del self._field_name_cache del self._field_name_cache
# The fields, concrete_fields and local_concrete_fields are
# implemented as cached properties for performance reasons.
# The attrs will not exists if the cached property isn't
# accessed yet, hence the try-excepts.
try:
del self.fields
except AttributeError:
pass
try:
del self.concrete_fields
except AttributeError:
pass
try:
del self.local_concrete_fields
except AttributeError:
pass
if hasattr(self, '_name_map'): if hasattr(self, '_name_map'):
del self._name_map del self._name_map
@ -245,7 +262,8 @@ class Options(object):
return None return None
swapped = property(_swapped) swapped = property(_swapped)
def _fields(self): @cached_property
def fields(self):
""" """
The getter for self.fields. This returns the list of field objects The getter for self.fields. This returns the list of field objects
available to this model (including through parent models). available to this model (including through parent models).
@ -258,7 +276,14 @@ class Options(object):
except AttributeError: except AttributeError:
self._fill_fields_cache() self._fill_fields_cache()
return self._field_name_cache return self._field_name_cache
fields = property(_fields)
@cached_property
def concrete_fields(self):
return [f for f in self.fields if f.column is not None]
@cached_property
def local_concrete_fields(self):
return [f for f in self.local_fields if f.column is not None]
def get_fields_with_model(self): def get_fields_with_model(self):
""" """
@ -272,6 +297,10 @@ class Options(object):
self._fill_fields_cache() self._fill_fields_cache()
return self._field_cache return self._field_cache
def get_concrete_fields_with_model(self):
return [(field, model) for field, model in self.get_fields_with_model() if
field.column is not None]
def _fill_fields_cache(self): def _fill_fields_cache(self):
cache = [] cache = []
for parent in self.parents: for parent in self.parents:
@ -377,6 +406,9 @@ class Options(object):
cache[f.name] = (f, model, True, True) cache[f.name] = (f, model, True, True)
for f, model in self.get_fields_with_model(): for f, model in self.get_fields_with_model():
cache[f.name] = (f, model, True, False) cache[f.name] = (f, model, True, False)
for f in self.virtual_fields:
if hasattr(f, 'related'):
cache[f.name] = (f.related, None if f.model == self.model else f.model, True, False)
if app_cache_ready(): if app_cache_ready():
self._name_map = cache self._name_map = cache
return cache return cache
@ -432,7 +464,7 @@ class Options(object):
for klass in get_models(include_auto_created=True, only_installed=False): for klass in get_models(include_auto_created=True, only_installed=False):
if not klass._meta.swapped: if not klass._meta.swapped:
for f in klass._meta.local_fields: for f in klass._meta.local_fields:
if f.rel and not isinstance(f.rel.to, six.string_types): if f.rel and not isinstance(f.rel.to, six.string_types) and f.generate_reverse_relation:
if self == f.rel.to._meta: if self == f.rel.to._meta:
cache[f.related] = None cache[f.related] = None
proxy_cache[f.related] = None proxy_cache[f.related] = None

View File

@ -261,13 +261,13 @@ class QuerySet(object):
only_load = self.query.get_loaded_field_names() only_load = self.query.get_loaded_field_names()
if not fill_cache: if not fill_cache:
fields = self.model._meta.fields fields = self.model._meta.concrete_fields
load_fields = [] load_fields = []
# If only/defer clauses have been specified, # If only/defer clauses have been specified,
# build the list of fields that are to be loaded. # build the list of fields that are to be loaded.
if only_load: if only_load:
for field, model in self.model._meta.get_fields_with_model(): for field, model in self.model._meta.get_concrete_fields_with_model():
if model is None: if model is None:
model = self.model model = self.model
try: try:
@ -280,7 +280,7 @@ class QuerySet(object):
load_fields.append(field.name) load_fields.append(field.name)
index_start = len(extra_select) index_start = len(extra_select)
aggregate_start = index_start + len(load_fields or self.model._meta.fields) aggregate_start = index_start + len(load_fields or self.model._meta.concrete_fields)
skip = None skip = None
if load_fields and not fill_cache: if load_fields and not fill_cache:
@ -312,7 +312,11 @@ class QuerySet(object):
if skip: if skip:
obj = model_cls(**dict(zip(init_list, row_data))) obj = model_cls(**dict(zip(init_list, row_data)))
else: else:
obj = model(*row_data) try:
obj = model(*row_data)
except IndexError:
import ipdb; ipdb.set_trace()
pass
# Store the source database of the object # Store the source database of the object
obj._state.db = db obj._state.db = db
@ -962,7 +966,7 @@ class QuerySet(object):
""" """
opts = self.model._meta opts = self.model._meta
if self.query.group_by is None: if self.query.group_by is None:
field_names = [f.attname for f in opts.fields] field_names = [f.attname for f in opts.concrete_fields]
self.query.add_fields(field_names, False) self.query.add_fields(field_names, False)
self.query.set_group_by() self.query.set_group_by()
@ -1055,7 +1059,7 @@ class ValuesQuerySet(QuerySet):
else: else:
# Default to all fields. # Default to all fields.
self.extra_names = None self.extra_names = None
self.field_names = [f.attname for f in self.model._meta.fields] self.field_names = [f.attname for f in self.model._meta.concrete_fields]
self.aggregate_names = None self.aggregate_names = None
self.query.select = [] self.query.select = []
@ -1266,7 +1270,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
skip = set() skip = set()
init_list = [] init_list = []
# Build the list of fields that *haven't* been requested # Build the list of fields that *haven't* been requested
for field, model in klass._meta.get_fields_with_model(): for field, model in klass._meta.get_concrete_fields_with_model():
if field.name not in load_fields: if field.name not in load_fields:
skip.add(field.attname) skip.add(field.attname)
elif from_parent and issubclass(from_parent, model.__class__): elif from_parent and issubclass(from_parent, model.__class__):
@ -1285,22 +1289,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
else: else:
# Load all fields on klass # Load all fields on klass
field_count = len(klass._meta.fields) field_count = len(klass._meta.concrete_fields)
# Check if we need to skip some parent fields. # Check if we need to skip some parent fields.
if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields): if from_parent and len(klass._meta.local_concrete_fields) != len(klass._meta.concrete_fields):
# Only load those fields which haven't been already loaded into # Only load those fields which haven't been already loaded into
# 'from_parent'. # 'from_parent'.
non_seen_models = [p for p in klass._meta.get_parent_list() non_seen_models = [p for p in klass._meta.get_parent_list()
if not issubclass(from_parent, p)] if not issubclass(from_parent, p)]
# Load local fields, too... # Load local fields, too...
non_seen_models.append(klass) non_seen_models.append(klass)
field_names = [f.attname for f in klass._meta.fields field_names = [f.attname for f in klass._meta.concrete_fields
if f.model in non_seen_models] if f.model in non_seen_models]
field_count = len(field_names) field_count = len(field_names)
# Try to avoid populating field_names variable for perfomance reasons. # Try to avoid populating field_names variable for perfomance reasons.
# If field_names variable is set, we use **kwargs based model init # If field_names variable is set, we use **kwargs based model init
# which is slower than normal init. # which is slower than normal init.
if field_count == len(klass._meta.fields): if field_count == len(klass._meta.concrete_fields):
field_names = () field_names = ()
restricted = requested is not None restricted = requested is not None

View File

@ -7,7 +7,7 @@ from django.db.models.fields import BLANK_CHOICE_DASH
# describe the relation in Model terms (model Options and Fields for both # describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation. # sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple('PathInfo', PathInfo = namedtuple('PathInfo',
'from_field to_field from_opts to_opts join_field ' 'from_opts to_opts target_fields join_field '
'm2m direct') 'm2m direct')
class RelatedObject(object): class RelatedObject(object):

View File

@ -2,10 +2,9 @@ import datetime
from django.conf import settings from django.conf import settings
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import transaction
from django.db.backends.util import truncate_name from django.db.backends.util import truncate_name
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend from django.db.models.query_utils import select_related_descend, QueryWrapper
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR, from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
GET_ITERATOR_CHUNK_SIZE, SelectInfo) GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.datastructures import EmptyResultSet
@ -33,7 +32,7 @@ class SQLCompiler(object):
# cleaned. We are not using a clone() of the query here. # cleaned. We are not using a clone() of the query here.
""" """
if not self.query.tables: if not self.query.tables:
self.query.join((None, self.query.model._meta.db_table, None, None)) self.query.join((None, self.query.model._meta.db_table, None))
if (not self.query.select and self.query.default_cols and not if (not self.query.select and self.query.default_cols and not
self.query.included_inherited_models): self.query.included_inherited_models):
self.query.setup_inherited_models() self.query.setup_inherited_models()
@ -273,7 +272,7 @@ class SQLCompiler(object):
# be used by local fields. # be used by local fields.
seen_models = {None: start_alias} seen_models = {None: start_alias}
for field, model in opts.get_fields_with_model(): for field, model in opts.get_concrete_fields_with_model():
if from_parent and model is not None and issubclass(from_parent, model): if from_parent and model is not None and issubclass(from_parent, model):
# Avoid loading data for already loaded parents. # Avoid loading data for already loaded parents.
continue continue
@ -314,9 +313,10 @@ class SQLCompiler(object):
for name in self.query.distinct_fields: for name in self.query.distinct_fields:
parts = name.split(LOOKUP_SEP) parts = name.split(LOOKUP_SEP)
field, col, alias, _, _ = self._setup_joins(parts, opts, None) field, cols, alias, _, _ = self._setup_joins(parts, opts, None)
col, alias = self._final_join_removal(col, alias) cols, alias = self._final_join_removal(cols, alias)
result.append("%s.%s" % (qn(alias), qn2(col))) for col in cols:
result.append("%s.%s" % (qn(alias), qn2(col)))
return result return result
@ -387,15 +387,16 @@ class SQLCompiler(object):
elif get_order_dir(field)[0] not in self.query.extra_select: elif get_order_dir(field)[0] not in self.query.extra_select:
# 'col' is of the form 'field' or 'field1__field2' or # 'col' is of the form 'field' or 'field1__field2' or
# '-field1__field2__field', etc. # '-field1__field2__field', etc.
for table, col, order in self.find_ordering_name(field, for table, cols, order in self.find_ordering_name(field,
self.query.model._meta, default_order=asc): self.query.model._meta, default_order=asc):
if (table, col) not in processed_pairs: for col in cols:
elt = '%s.%s' % (qn(table), qn2(col)) if (table, col) not in processed_pairs:
processed_pairs.add((table, col)) elt = '%s.%s' % (qn(table), qn2(col))
if distinct and elt not in select_aliases: processed_pairs.add((table, col))
ordering_aliases.append(elt) if distinct and elt not in select_aliases:
result.append('%s %s' % (elt, order)) ordering_aliases.append(elt)
group_by.append((elt, [])) result.append('%s %s' % (elt, order))
group_by.append((elt, []))
else: else:
elt = qn2(col) elt = qn2(col)
if distinct and col not in select_aliases: if distinct and col not in select_aliases:
@ -414,7 +415,7 @@ class SQLCompiler(object):
""" """
name, order = get_order_dir(name, default_order) name, order = get_order_dir(name, default_order)
pieces = name.split(LOOKUP_SEP) pieces = name.split(LOOKUP_SEP)
field, col, alias, joins, opts = self._setup_joins(pieces, opts, alias) field, cols, alias, joins, opts = self._setup_joins(pieces, opts, alias)
# If we get to this point and the field is a relation to another model, # If we get to this point and the field is a relation to another model,
# append the default ordering for that model. # append the default ordering for that model.
@ -432,8 +433,8 @@ class SQLCompiler(object):
results.extend(self.find_ordering_name(item, opts, alias, results.extend(self.find_ordering_name(item, opts, alias,
order, already_seen)) order, already_seen))
return results return results
col, alias = self._final_join_removal(col, alias) cols, alias = self._final_join_removal(cols, alias)
return [(alias, col, order)] return [(alias, cols, order)]
def _setup_joins(self, pieces, opts, alias): def _setup_joins(self, pieces, opts, alias):
""" """
@ -446,13 +447,13 @@ class SQLCompiler(object):
""" """
if not alias: if not alias:
alias = self.query.get_initial_alias() alias = self.query.get_initial_alias()
field, target, opts, joins, _ = self.query.setup_joins( field, targets, opts, joins, _ = self.query.setup_joins(
pieces, opts, alias) pieces, opts, alias)
# We will later on need to promote those joins that were added to the # We will later on need to promote those joins that were added to the
# query afresh above. # query afresh above.
joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2] joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2]
alias = joins[-1] alias = joins[-1]
col = target.column cols = [target.column for target in targets]
if not field.rel: if not field.rel:
# To avoid inadvertent trimming of a necessary alias, use the # To avoid inadvertent trimming of a necessary alias, use the
# refcount to show that we are referencing a non-relation field on # refcount to show that we are referencing a non-relation field on
@ -463,9 +464,9 @@ class SQLCompiler(object):
# Ordering or distinct must not affect the returned set, and INNER # Ordering or distinct must not affect the returned set, and INNER
# JOINS for nullable fields could do this. # JOINS for nullable fields could do this.
self.query.promote_joins(joins_to_promote) self.query.promote_joins(joins_to_promote)
return field, col, alias, joins, opts return field, cols, alias, joins, opts
def _final_join_removal(self, col, alias): def _final_join_removal(self, cols, alias):
""" """
A helper method for get_distinct and get_ordering. This method will A helper method for get_distinct and get_ordering. This method will
trim extra not-needed joins from the tail of the join chain. trim extra not-needed joins from the tail of the join chain.
@ -477,12 +478,14 @@ class SQLCompiler(object):
if alias: if alias:
while 1: while 1:
join = self.query.alias_map[alias] join = self.query.alias_map[alias]
if col != join.rhs_join_col: lhs_cols, rhs_cols = zip(*[(lhs_col, rhs_col) for lhs_col, rhs_col in join.join_cols])
if set(cols) != set(rhs_cols):
break break
cols = [lhs_cols[rhs_cols.index(col)] for col in cols]
self.query.unref_alias(alias) self.query.unref_alias(alias)
alias = join.lhs_alias alias = join.lhs_alias
col = join.lhs_join_col return cols, alias
return col, alias
def get_from_clause(self): def get_from_clause(self):
""" """
@ -504,22 +507,30 @@ class SQLCompiler(object):
if not self.query.alias_refcount[alias]: if not self.query.alias_refcount[alias]:
continue continue
try: try:
name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias] name, alias, join_type, lhs, join_cols, _, join_field = self.query.alias_map[alias]
except KeyError: except KeyError:
# Extra tables can end up in self.tables, but not in the # Extra tables can end up in self.tables, but not in the
# alias_map if they aren't in a join. That's OK. We skip them. # alias_map if they aren't in a join. That's OK. We skip them.
continue continue
alias_str = (alias != name and ' %s' % alias or '') alias_str = (alias != name and ' %s' % alias or '')
if join_type and not first: if join_type and not first:
if join_field and hasattr(join_field, 'get_extra_join_sql'): extra_cond = join_field.get_extra_restriction(
extra_cond, extra_params = join_field.get_extra_join_sql( self.query.where_class, alias, lhs)
self.connection, qn, lhs, alias) if extra_cond:
extra_sql, extra_params = extra_cond.as_sql(
qn, self.connection)
extra_sql = 'AND (%s)' % extra_sql
from_params.extend(extra_params) from_params.extend(extra_params)
else: else:
extra_cond = "" extra_sql = ""
result.append('%s %s%s ON (%s.%s = %s.%s%s)' % result.append('%s %s%s ON ('
(join_type, qn(name), alias_str, qn(lhs), % (join_type, qn(name), alias_str))
qn2(lhs_col), qn(alias), qn2(col), extra_cond)) for index, (lhs_col, rhs_col) in enumerate(join_cols):
if index != 0:
result.append(' AND ')
result.append('%s.%s = %s.%s' %
(qn(lhs), qn2(lhs_col), qn(alias), qn2(rhs_col)))
result.append('%s)' % extra_sql)
else: else:
connector = not first and ', ' or '' connector = not first and ', ' or ''
result.append('%s%s%s' % (connector, qn(name), alias_str)) result.append('%s%s%s' % (connector, qn(name), alias_str))
@ -545,7 +556,7 @@ class SQLCompiler(object):
select_cols = self.query.select + self.query.related_select_cols select_cols = self.query.select + self.query.related_select_cols
# Just the column, not the fields. # Just the column, not the fields.
select_cols = [s[0] for s in select_cols] select_cols = [s[0] for s in select_cols]
if (len(self.query.model._meta.fields) == len(self.query.select) if (len(self.query.model._meta.concrete_fields) == len(self.query.select)
and self.connection.features.allows_group_by_pk): and self.connection.features.allows_group_by_pk):
self.query.group_by = [ self.query.group_by = [
(self.query.model._meta.db_table, self.query.model._meta.pk.column) (self.query.model._meta.db_table, self.query.model._meta.pk.column)
@ -623,14 +634,13 @@ class SQLCompiler(object):
table = f.rel.to._meta.db_table table = f.rel.to._meta.db_table
promote = nullable or f.null promote = nullable or f.null
alias = self.query.join_parent_model(opts, model, root_alias, {}) alias = self.query.join_parent_model(opts, model, root_alias, {})
join_cols = f.get_joining_columns()
alias = self.query.join((alias, table, f.column, alias = self.query.join((alias, table, join_cols),
f.rel.get_related_field().column),
outer_if_first=promote, join_field=f) outer_if_first=promote, join_field=f)
columns, aliases = self.get_default_columns(start_alias=alias, columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True) opts=f.rel.to._meta, as_pairs=True)
self.query.related_select_cols.extend( self.query.related_select_cols.extend(
SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields)) SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.concrete_fields))
if restricted: if restricted:
next = requested.get(f.name, {}) next = requested.get(f.name, {})
else: else:
@ -653,7 +663,7 @@ class SQLCompiler(object):
alias = self.query.join_parent_model(opts, f.rel.to, root_alias, {}) alias = self.query.join_parent_model(opts, f.rel.to, root_alias, {})
table = model._meta.db_table table = model._meta.db_table
alias = self.query.join( alias = self.query.join(
(alias, table, f.rel.get_related_field().column, f.column), (alias, table, f.get_joining_columns(reverse_join=True)),
outer_if_first=True, join_field=f outer_if_first=True, join_field=f
) )
from_parent = (opts.model if issubclass(model, opts.model) from_parent = (opts.model if issubclass(model, opts.model)
@ -662,7 +672,7 @@ class SQLCompiler(object):
opts=model._meta, as_pairs=True, from_parent=from_parent) opts=model._meta, as_pairs=True, from_parent=from_parent)
self.query.related_select_cols.extend( self.query.related_select_cols.extend(
SelectInfo(col, field) for col, field SelectInfo(col, field) for col, field
in zip(columns, model._meta.fields)) in zip(columns, model._meta.concrete_fields))
next = requested.get(f.related_query_name(), {}) next = requested.get(f.related_query_name(), {})
# Use True here because we are looking at the _reverse_ side of # Use True here because we are looking at the _reverse_ side of
# the relation, which is always nullable. # the relation, which is always nullable.
@ -706,7 +716,7 @@ class SQLCompiler(object):
if self.query.select: if self.query.select:
fields = [f.field for f in self.query.select] fields = [f.field for f in self.query.select]
else: else:
fields = self.query.model._meta.fields fields = self.query.model._meta.concrete_fields
fields = fields + [f.field for f in self.query.related_select_cols] fields = fields + [f.field for f in self.query.related_select_cols]
# If the field was deferred, exclude it from being passed # If the field was deferred, exclude it from being passed
@ -776,6 +786,22 @@ class SQLCompiler(object):
return list(result) return list(result)
return result return result
def as_subquery_condition(self, alias, columns):
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
if len(columns) == 1:
sql, params = self.as_sql()
return '%s.%s IN (%s)' % (qn(alias), qn2(columns[0]), sql), params
for index, select_col in enumerate(self.query.select):
lhs = '%s.%s' % (qn(select_col.col[0]), qn2(select_col.col[1]))
rhs = '%s.%s' % (qn(alias), qn2(columns[index]))
self.query.where.add(
QueryWrapper('%s = %s' % (lhs, rhs), []), 'AND')
sql, params = self.as_sql()
return 'EXISTS (%s)' % sql, params
class SQLInsertCompiler(SQLCompiler): class SQLInsertCompiler(SQLCompiler):
def placeholder(self, field, val): def placeholder(self, field, val):

View File

@ -25,7 +25,7 @@ GET_ITERATOR_CHUNK_SIZE = 100
# dictionary in the Query class). # dictionary in the Query class).
JoinInfo = namedtuple('JoinInfo', JoinInfo = namedtuple('JoinInfo',
'table_name rhs_alias join_type lhs_alias ' 'table_name rhs_alias join_type lhs_alias '
'lhs_join_col rhs_join_col nullable join_field') 'join_cols nullable join_field')
# Pairs of column clauses to select, and (possibly None) field for the clause. # Pairs of column clauses to select, and (possibly None) field for the clause.
SelectInfo = namedtuple('SelectInfo', 'col field') SelectInfo = namedtuple('SelectInfo', 'col field')

View File

@ -55,13 +55,14 @@ class SQLEvaluator(object):
self.cols.append((node, query.aggregate_select[node.name])) self.cols.append((node, query.aggregate_select[node.name]))
else: else:
try: try:
field, source, opts, join_list, path = query.setup_joins( field, sources, opts, join_list, path = query.setup_joins(
field_list, query.get_meta(), field_list, query.get_meta(),
query.get_initial_alias(), self.reuse) query.get_initial_alias(), self.reuse)
target, _, join_list = query.trim_joins(source, join_list, path) targets, _, join_list = query.trim_joins(sources, join_list, path)
if self.reuse is not None: if self.reuse is not None:
self.reuse.update(join_list) self.reuse.update(join_list)
self.cols.append((node, (join_list[-1], target.column))) for t in targets:
self.cols.append((node, (join_list[-1], t.column)))
except FieldDoesNotExist: except FieldDoesNotExist:
raise FieldError("Cannot resolve keyword %r into field. " raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (self.name, "Choices are: %s" % (self.name,

View File

@ -452,13 +452,13 @@ class Query(object):
# Now, add the joins from rhs query into the new query (skipping base # Now, add the joins from rhs query into the new query (skipping base
# table). # table).
for alias in rhs.tables[1:]: for alias in rhs.tables[1:]:
table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias] table, _, join_type, lhs, join_cols, nullable, join_field = rhs.alias_map[alias]
promote = (join_type == self.LOUTER) promote = (join_type == self.LOUTER)
# If the left side of the join was already relabeled, use the # If the left side of the join was already relabeled, use the
# updated alias. # updated alias.
lhs = change_map.get(lhs, lhs) lhs = change_map.get(lhs, lhs)
new_alias = self.join( new_alias = self.join(
(lhs, table, lhs_col, col), reuse=reuse, (lhs, table, join_cols), reuse=reuse,
outer_if_first=not conjunction, nullable=nullable, outer_if_first=not conjunction, nullable=nullable,
join_field=join_field) join_field=join_field)
if promote: if promote:
@ -682,7 +682,7 @@ class Query(object):
aliases = list(aliases) aliases = list(aliases)
while aliases: while aliases:
alias = aliases.pop(0) alias = aliases.pop(0)
if self.alias_map[alias].rhs_join_col is None: if self.alias_map[alias].join_cols[0][1] is None:
# This is the base table (first FROM entry) - this table # This is the base table (first FROM entry) - this table
# isn't really joined at all in the query, so we should not # isn't really joined at all in the query, so we should not
# alter its join type. # alter its join type.
@ -818,7 +818,7 @@ class Query(object):
alias = self.tables[0] alias = self.tables[0]
self.ref_alias(alias) self.ref_alias(alias)
else: else:
alias = self.join((None, self.model._meta.db_table, None, None)) alias = self.join((None, self.model._meta.db_table, None))
return alias return alias
def count_active_tables(self): def count_active_tables(self):
@ -834,11 +834,12 @@ class Query(object):
""" """
Returns an alias for the join in 'connection', either reusing an Returns an alias for the join in 'connection', either reusing an
existing alias for that join or creating a new one. 'connection' is a existing alias for that join or creating a new one. 'connection' is a
tuple (lhs, table, lhs_col, col) where 'lhs' is either an existing tuple (lhs, table, join_cols) where 'lhs' is either an existing
table alias or a table name. The join correspods to the SQL equivalent table alias or a table name. 'join_cols' is a tuple of tuples containing
of:: columns to join on ((l_id1, r_id1), (l_id2, r_id2)). The join corresponds
to the SQL equivalent of::
lhs.lhs_col = table.col lhs.l_id1 = table.r_id1 AND lhs.l_id2 = table.r_id2
The 'reuse' parameter can be either None which means all joins The 'reuse' parameter can be either None which means all joins
(matching the connection) are reusable, or it can be a set containing (matching the connection) are reusable, or it can be a set containing
@ -855,7 +856,7 @@ class Query(object):
The 'join_field' is the field we are joining along (if any). The 'join_field' is the field we are joining along (if any).
""" """
lhs, table, lhs_col, col = connection lhs, table, join_cols = connection
assert lhs is None or join_field is not None assert lhs is None or join_field is not None
existing = self.join_map.get(connection, ()) existing = self.join_map.get(connection, ())
if reuse is None: if reuse is None:
@ -884,7 +885,7 @@ class Query(object):
join_type = self.LOUTER join_type = self.LOUTER
else: else:
join_type = self.INNER join_type = self.INNER
join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable, join = JoinInfo(table, alias, join_type, lhs, join_cols or ((None, None),), nullable,
join_field) join_field)
self.alias_map[alias] = join self.alias_map[alias] = join
if connection in self.join_map: if connection in self.join_map:
@ -941,7 +942,7 @@ class Query(object):
continue continue
link_field = int_opts.get_ancestor_link(int_model) link_field = int_opts.get_ancestor_link(int_model)
int_opts = int_model._meta int_opts = int_model._meta
connection = (alias, int_opts.db_table, link_field.column, int_opts.pk.column) connection = (alias, int_opts.db_table, link_field.get_joining_columns())
alias = seen[int_model] = self.join(connection, nullable=False, alias = seen[int_model] = self.join(connection, nullable=False,
join_field=link_field) join_field=link_field)
return alias or seen[None] return alias or seen[None]
@ -982,18 +983,20 @@ class Query(object):
# - this is an annotation over a model field # - this is an annotation over a model field
# then we need to explore the joins that are required. # then we need to explore the joins that are required.
field, source, opts, join_list, path = self.setup_joins( field, sources, opts, join_list, path = self.setup_joins(
field_list, opts, self.get_initial_alias()) field_list, opts, self.get_initial_alias())
# Process the join chain to see if it can be trimmed # Process the join chain to see if it can be trimmed
target, _, join_list = self.trim_joins(source, join_list, path) targets, _, join_list = self.trim_joins(sources, join_list, path)
# If the aggregate references a model or field that requires a join, # If the aggregate references a model or field that requires a join,
# those joins must be LEFT OUTER - empty join rows must be returned # those joins must be LEFT OUTER - empty join rows must be returned
# in order for zeros to be returned for those aggregates. # in order for zeros to be returned for those aggregates.
self.promote_joins(join_list, True) self.promote_joins(join_list, True)
col = (join_list[-1], target.column) col = targets[0].column
source = sources[0]
col = (join_list[-1], col)
else: else:
# The simplest cases. No joins required - # The simplest cases. No joins required -
# just reference the provided column alias. # just reference the provided column alias.
@ -1086,7 +1089,7 @@ class Query(object):
allow_many = not branch_negated allow_many = not branch_negated
try: try:
field, target, opts, join_list, path = self.setup_joins( field, sources, opts, join_list, path = self.setup_joins(
parts, opts, alias, can_reuse, allow_many, parts, opts, alias, can_reuse, allow_many,
allow_explicit_fk=True) allow_explicit_fk=True)
if can_reuse is not None: if can_reuse is not None:
@ -1106,13 +1109,19 @@ class Query(object):
# the far end (fewer tables in a query is better). Note that join # the far end (fewer tables in a query is better). Note that join
# promotion must happen before join trimming to have the join type # promotion must happen before join trimming to have the join type
# information available when reusing joins. # information available when reusing joins.
target, alias, join_list = self.trim_joins(target, join_list, path) targets, alias, join_list = self.trim_joins(sources, join_list, path)
clause.add((Constraint(alias, target.column, field), lookup_type, value),
AND) if hasattr(field, 'get_lookup_constraint'):
constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources,
lookup_type, value)
else:
constraint = (Constraint(alias, targets[0].column, field), lookup_type, value)
clause.add(constraint, AND)
if current_negated and (lookup_type != 'isnull' or value is False): if current_negated and (lookup_type != 'isnull' or value is False):
self.promote_joins(join_list) self.promote_joins(join_list)
if (lookup_type != 'isnull' and ( if (lookup_type != 'isnull' and (
self.is_nullable(target) or self.alias_map[join_list[-1]].join_type == self.LOUTER)): self.is_nullable(targets[0]) or
self.alias_map[join_list[-1]].join_type == self.LOUTER)):
# The condition added here will be SQL like this: # The condition added here will be SQL like this:
# NOT (col IS NOT NULL), where the first NOT is added in # NOT (col IS NOT NULL), where the first NOT is added in
# upper layers of code. The reason for addition is that if col # upper layers of code. The reason for addition is that if col
@ -1122,7 +1131,7 @@ class Query(object):
# (col IS NULL OR col != someval) # (col IS NULL OR col != someval)
# <=> # <=>
# NOT (col IS NOT NULL AND col = someval). # NOT (col IS NOT NULL AND col = someval).
clause.add((Constraint(alias, target.column, None), 'isnull', False), AND) clause.add((Constraint(alias, targets[0].column, None), 'isnull', False), AND)
return clause return clause
def add_filter(self, filter_clause): def add_filter(self, filter_clause):
@ -1272,22 +1281,26 @@ class Query(object):
opts = int_model._meta opts = int_model._meta
else: else:
final_field = opts.parents[int_model] final_field = opts.parents[int_model]
target = final_field.rel.get_related_field() targets = (final_field.rel.get_related_field(),)
opts = int_model._meta opts = int_model._meta
path.append(PathInfo(final_field, target, final_field.model._meta, path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True))
opts, final_field, False, True))
if hasattr(field, 'get_path_info'): if hasattr(field, 'get_path_info'):
pathinfos, opts, target, final_field = field.get_path_info() pathinfos = field.get_path_info()
if not allow_many: if not allow_many:
for inner_pos, p in enumerate(pathinfos): for inner_pos, p in enumerate(pathinfos):
if p.m2m: if p.m2m:
names_with_path.append((name, pathinfos[0:inner_pos + 1])) names_with_path.append((name, pathinfos[0:inner_pos + 1]))
raise MultiJoin(pos + 1, names_with_path) raise MultiJoin(pos + 1, names_with_path)
last = pathinfos[-1]
path.extend(pathinfos) path.extend(pathinfos)
final_field = last.join_field
opts = last.to_opts
targets = last.target_fields
names_with_path.append((name, pathinfos)) names_with_path.append((name, pathinfos))
else: else:
# Local non-relational field. # Local non-relational field.
final_field = target = field final_field = field
targets = (field,)
break break
if pos != len(names) - 1: if pos != len(names) - 1:
@ -1297,7 +1310,7 @@ class Query(object):
"the lookup type?" % (name, names[pos + 1])) "the lookup type?" % (name, names[pos + 1]))
else: else:
raise FieldError("Join on field %r not permitted." % name) raise FieldError("Join on field %r not permitted." % name)
return path, final_field, target return path, final_field, targets
def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True,
allow_explicit_fk=False): allow_explicit_fk=False):
@ -1330,7 +1343,7 @@ class Query(object):
""" """
joins = [alias] joins = [alias]
# First, generate the path for the names # First, generate the path for the names
path, final_field, target = self.names_to_path( path, final_field, targets = self.names_to_path(
names, opts, allow_many, allow_explicit_fk) names, opts, allow_many, allow_explicit_fk)
# Then, add the path to the query's joins. Note that we can't trim # Then, add the path to the query's joins. Note that we can't trim
# joins at this stage - we will need the information about join type # joins at this stage - we will need the information about join type
@ -1338,17 +1351,19 @@ class Query(object):
for pos, join in enumerate(path): for pos, join in enumerate(path):
opts = join.to_opts opts = join.to_opts
if join.direct: if join.direct:
nullable = self.is_nullable(join.from_field) nullable = self.is_nullable(join.join_field)
else: else:
nullable = True nullable = True
connection = alias, opts.db_table, join.from_field.column, join.to_field.column connection = alias, opts.db_table, join.join_field.get_joining_columns()
reuse = can_reuse if join.m2m else None reuse = can_reuse if join.m2m else None
alias = self.join(connection, reuse=reuse, alias = self.join(connection, reuse=reuse,
nullable=nullable, join_field=join.join_field) nullable=nullable, join_field=join.join_field)
joins.append(alias) joins.append(alias)
return final_field, target, opts, joins, path if hasattr(final_field, 'field'):
final_field = final_field.field
return final_field, targets, opts, joins, path
def trim_joins(self, target, joins, path): def trim_joins(self, targets, joins, path):
""" """
The 'target' parameter is the final field being joined to, 'joins' The 'target' parameter is the final field being joined to, 'joins'
is the full list of join aliases. The 'path' contain the PathInfos is the full list of join aliases. The 'path' contain the PathInfos
@ -1362,13 +1377,16 @@ class Query(object):
trimmed as we don't know if there is anything on the other side of trimmed as we don't know if there is anything on the other side of
the join. the join.
""" """
for info in reversed(path): for pos, info in enumerate(reversed(path)):
if info.to_field == target and info.direct: if len(joins) == 1 or not info.direct:
target = info.from_field
self.unref_alias(joins.pop())
else:
break break
return target, joins[-1], joins join_targets = set(t.column for t in info.join_field.foreign_related_fields)
cur_targets = set(t.column for t in targets)
if not cur_targets.issubset(join_targets):
break
targets = tuple(r[0] for r in info.join_field.related_fields if r[1].column in cur_targets)
self.unref_alias(joins.pop())
return targets, joins[-1], joins
def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
""" """
@ -1413,17 +1431,31 @@ class Query(object):
trimmed_prefix = [] trimmed_prefix = []
paths_in_prefix = trimmed_joins paths_in_prefix = trimmed_joins
for name, path in names_with_path: for name, path in names_with_path:
if paths_in_prefix - len(path) > 0: if paths_in_prefix - len(path) < 0:
trimmed_prefix.append(name)
paths_in_prefix -= len(path)
else:
trimmed_prefix.append(
path[paths_in_prefix - len(path)].from_field.name)
break break
trimmed_prefix.append(name)
paths_in_prefix -= len(path)
join_field = path[paths_in_prefix].join_field
# TODO: This should be made properly multicolumn
# join aware. It is likely better to not use build_filter
# at all, instead construct joins up to the correct point,
# then construct the needed equality constraint manually,
# or maybe using SubqueryConstraint would work, too.
# The foreign_related_fields attribute is right here, we
# don't ever split joins for direct case.
trimmed_prefix.append(
join_field.field.foreign_related_fields[0].name)
trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix) trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
return self.build_filter( condition = self.build_filter(
('%s__in' % trimmed_prefix, query), ('%s__in' % trimmed_prefix, query),
current_negated=True, branch_negated=True, can_reuse=can_reuse) current_negated=True, branch_negated=True, can_reuse=can_reuse)
# Intentionally leave the other alias as blank, if the condition
# refers it, things will break here.
extra_restriction = join_field.get_extra_restriction(
self.where_class, None, [t for t in query.tables if query.alias_refcount[t]][0])
if extra_restriction:
query.where.add(extra_restriction, 'AND')
return condition
def set_empty(self): def set_empty(self):
self.where = EmptyWhere() self.where = EmptyWhere()
@ -1502,20 +1534,17 @@ class Query(object):
try: try:
for name in field_names: for name in field_names:
field, target, u2, joins, u3 = self.setup_joins( field, targets, u2, joins, path = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, None, allow_m2m, name.split(LOOKUP_SEP), opts, alias, None, allow_m2m,
True) True)
final_alias = joins[-1]
col = target.column # Trim last join if possible
if len(joins) > 1: targets, final_alias, remaining_joins = self.trim_joins(targets, joins[-2:], path)
join = self.alias_map[final_alias] joins = joins[:-2] + remaining_joins
if col == join.rhs_join_col:
self.unref_alias(final_alias)
final_alias = join.lhs_alias
col = join.lhs_join_col
joins = joins[:-1]
self.promote_joins(joins[1:]) self.promote_joins(joins[1:])
self.select.append(SelectInfo((final_alias, col), field)) for target in targets:
self.select.append(SelectInfo((final_alias, target.column), target))
except MultiJoin: except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name) raise FieldError("Invalid field name: '%s'" % name)
except FieldError: except FieldError:
@ -1590,7 +1619,7 @@ class Query(object):
opts = self.model._meta opts = self.model._meta
if not self.select: if not self.select:
count = self.aggregates_module.Count( count = self.aggregates_module.Count(
(self.join((None, opts.db_table, None, None)), opts.pk.column), (self.join((None, opts.db_table, None)), opts.pk.column),
is_summary=True, distinct=True) is_summary=True, distinct=True)
else: else:
# Because of SQL portability issues, multi-column, distinct # Because of SQL portability issues, multi-column, distinct
@ -1792,22 +1821,27 @@ class Query(object):
in "WHERE somecol IN (subquery)". This construct is needed by in "WHERE somecol IN (subquery)". This construct is needed by
split_exclude(). split_exclude().
_""" _"""
join_pos = 0 all_paths = []
for _, paths in names_with_path: for _, paths in names_with_path:
for path in paths: all_paths.extend(paths)
peek = self.tables[join_pos + 1] direct_join = True
if self.alias_map[peek].join_type == self.LOUTER: for pos, path in enumerate(all_paths):
# Back up one level and break if self.alias_map[self.tables[pos + 1]].join_type == self.LOUTER:
select_alias = self.tables[join_pos] direct_join = False
select_field = path.from_field pos -= 1
break break
select_alias = self.tables[join_pos + 1] self.unref_alias(self.tables[pos])
select_field = path.to_field if path.direct:
self.unref_alias(self.tables[join_pos]) direct_join = not direct_join
join_pos += 1 join_side = 0 if direct_join else 1
self.select = [SelectInfo((select_alias, select_field.column), select_field)] select_alias = self.tables[pos + 1]
join_field = path.join_field
if hasattr(join_field, 'field'):
join_field = join_field.field
select_fields = [r[join_side] for r in join_field.related_fields]
self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields]
self.remove_inherited_models() self.remove_inherited_models()
return join_pos return pos
def is_nullable(self, field): def is_nullable(self, field):
""" """

View File

@ -382,3 +382,28 @@ class Constraint(object):
new.__class__ = self.__class__ new.__class__ = self.__class__
new.alias, new.col, new.field = change_map[self.alias], self.col, self.field new.alias, new.col, new.field = change_map[self.alias], self.col, self.field
return new return new
class SubqueryConstraint(object):
def __init__(self, alias, columns, targets, query_object):
self.alias = alias
self.columns = columns
self.targets = targets
self.query_object = query_object
def as_sql(self, qn, connection):
query = self.query_object
# QuerySet was sent
if hasattr(query, 'values'):
# as_sql should throw if we are using a
# connection on another database
query._as_sql(connection=connection)
query = query.values(*self.targets).query
query_compiler = query.get_compiler(connection=connection)
return query_compiler.as_subquery_condition(self.alias, self.columns)
def relabeled_clone(self, relabels):
return self.__class__(
relabels.get(self.alias, self.alias),
self.columns, self.query_object)

View File

@ -110,7 +110,7 @@ def model_to_dict(instance, fields=None, exclude=None):
from django.db.models.fields.related import ManyToManyField from django.db.models.fields.related import ManyToManyField
opts = instance._meta opts = instance._meta
data = {} data = {}
for f in opts.fields + opts.many_to_many: for f in opts.concrete_fields + opts.many_to_many:
if not f.editable: if not f.editable:
continue continue
if fields and not f.name in fields: if fields and not f.name in fields:
@ -149,7 +149,7 @@ def fields_for_model(model, fields=None, exclude=None, widgets=None, formfield_c
field_list = [] field_list = []
ignored = [] ignored = []
opts = model._meta opts = model._meta
for f in sorted(opts.fields + opts.many_to_many): for f in sorted(opts.concrete_fields + opts.many_to_many):
if not f.editable: if not f.editable:
continue continue
if fields is not None and not f.name in fields: if fields is not None and not f.name in fields: