From 97774429aeb54df4c09895c07cd1b09e70201f7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Sun, 24 Mar 2013 18:40:40 +0200 Subject: [PATCH] Fixed #19385 again, now with real code changes The commit of 266de5f9ae9e9f2fbfaec3b7e4b5fb9941967801 included only tests, this time also code changes included... --- django/contrib/contenttypes/generic.py | 100 ++--- django/core/management/validation.py | 12 +- django/db/backends/mysql/compiler.py | 6 + django/db/models/__init__.py | 2 +- django/db/models/base.py | 19 +- django/db/models/deletion.py | 15 +- django/db/models/fields/__init__.py | 7 +- django/db/models/fields/related.py | 532 +++++++++++++++---------- django/db/models/options.py | 38 +- django/db/models/query.py | 26 +- django/db/models/related.py | 2 +- django/db/models/sql/compiler.py | 108 +++-- django/db/models/sql/constants.py | 2 +- django/db/models/sql/expressions.py | 7 +- django/db/models/sql/query.py | 174 ++++---- django/db/models/sql/where.py | 25 ++ django/forms/models.py | 4 +- 17 files changed, 653 insertions(+), 426 deletions(-) diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py index aa232ab1d5..3e132bb3a4 100644 --- a/django/contrib/contenttypes/generic.py +++ b/django/contrib/contenttypes/generic.py @@ -8,10 +8,11 @@ from functools import partial from django.core.exceptions import ObjectDoesNotExist from django.db import connection -from django.db.models import signals from django.db import models, router, DEFAULT_DB_ALIAS -from django.db.models.fields.related import RelatedField, Field, ManyToManyRel +from django.db.models import signals +from django.db.models.fields.related import ForeignObject, ForeignObjectRel from django.db.models.related import PathInfo +from django.db.models.sql.where import Constraint from django.forms import ModelForm from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets @@ -149,17 +150,14 @@ class GenericForeignKey(six.with_metaclass(RenameGenericForeignKeyMethods)): setattr(instance, self.fk_field, fk) setattr(instance, self.cache_attr, value) -class GenericRelation(RelatedField, Field): +class GenericRelation(ForeignObject): """Provides an accessor to generic related objects (e.g. comments)""" def __init__(self, to, **kwargs): kwargs['verbose_name'] = kwargs.get('verbose_name', None) - kwargs['rel'] = GenericRel(to, - related_name=kwargs.pop('related_name', None), - limit_choices_to=kwargs.pop('limit_choices_to', None), - symmetrical=kwargs.pop('symmetrical', True)) - - + kwargs['rel'] = GenericRel( + self, to, related_name=kwargs.pop('related_name', None), + limit_choices_to=kwargs.pop('limit_choices_to', None),) # Override content-type/object-id field names on the related class self.object_id_field_name = kwargs.pop("object_id_field", "object_id") self.content_type_field_name = kwargs.pop("content_type_field", "content_type") @@ -167,47 +165,44 @@ class GenericRelation(RelatedField, Field): kwargs['blank'] = True kwargs['editable'] = False kwargs['serialize'] = False - Field.__init__(self, **kwargs) + # This construct is somewhat of an abuse of ForeignObject. This field + # represents a relation from pk to object_id field. But, this relation + # isn't direct, the join is generated reverse along foreign key. So, + # the from_field is object_id field, to_field is pk because of the + # reverse join. + super(GenericRelation, self).__init__( + to, to_fields=[], + from_fields=[self.object_id_field_name], **kwargs) - def get_path_info(self): - from_field = self.model._meta.pk + def resolve_related_fields(self): + self.to_fields = [self.model._meta.pk.name] + return [(self.rel.to._meta.get_field_by_name(self.object_id_field_name)[0], + self.model._meta.pk)] + + def get_reverse_path_info(self): opts = self.rel.to._meta target = opts.get_field_by_name(self.object_id_field_name)[0] - # Note that we are using different field for the join_field - # than from_field or to_field. This is a hack, but we need the - # GenericRelation to generate the extra SQL. - return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)], - opts, target, self) + return [PathInfo(self.model._meta, opts, (target,), self.rel, True, False)] def get_choices_default(self): - return Field.get_choices(self, include_blank=False) + return super(GenericRelation, self).get_choices(include_blank=False) def value_to_string(self, obj): qs = getattr(obj, self.name).all() return smart_text([instance._get_pk_val() for instance in qs]) - def m2m_db_table(self): - return self.rel.to._meta.db_table - - def m2m_column_name(self): - return self.object_id_field_name - - def m2m_reverse_name(self): - return self.rel.to._meta.pk.column - - def m2m_target_field_name(self): - return self.model._meta.pk.name - - def m2m_reverse_target_field_name(self): - return self.rel.to._meta.pk.name + def get_joining_columns(self, reverse_join=False): + if not reverse_join: + # This error message is meant for the user, and from user + # perspective this is a reverse join along the GenericRelation. + raise ValueError('Joining in reverse direction not allowed.') + return super(GenericRelation, self).get_joining_columns(reverse_join) def contribute_to_class(self, cls, name): - super(GenericRelation, self).contribute_to_class(cls, name) - + super(GenericRelation, self).contribute_to_class(cls, name, virtual_only=True) # Save a reference to which model this class is on for future use self.model = cls - - # Add the descriptor for the m2m relation + # Add the descriptor for the relation setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self)) def contribute_to_related_class(self, cls, related): @@ -219,21 +214,18 @@ class GenericRelation(RelatedField, Field): def get_internal_type(self): return "ManyToManyField" - def db_type(self, connection): - # Since we're simulating a ManyToManyField, in effect, best return the - # same db_type as well. - return None - def get_content_type(self): """ Returns the content type associated with this field's model. """ return ContentType.objects.get_for_model(self.model) - def get_extra_join_sql(self, connection, qn, lhs_alias, rhs_alias): - extra_col = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0].column - contenttype = self.get_content_type().pk - return " AND %s.%s = %%s" % (qn(rhs_alias), qn(extra_col)), [contenttype] + def get_extra_restriction(self, where_class, alias, remote_alias): + field = self.rel.to._meta.get_field_by_name(self.content_type_field_name)[0] + contenttype_pk = self.get_content_type().pk + cond = where_class() + cond.add((Constraint(remote_alias, field.column, field), 'exact', contenttype_pk), 'AND') + return cond def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): """ @@ -273,12 +265,12 @@ class ReverseGenericRelatedObjectsDescriptor(object): qn = connection.ops.quote_name content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(instance) + join_cols = self.field.get_joining_columns(reverse_join=True)[0] manager = RelatedManager( model = rel_model, instance = instance, - symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model), - source_col_name = qn(self.field.m2m_column_name()), - target_col_name = qn(self.field.m2m_reverse_name()), + source_col_name = qn(join_cols[0]), + target_col_name = qn(join_cols[1]), content_type = content_type, content_type_field_name = self.field.content_type_field_name, object_id_field_name = self.field.object_id_field_name, @@ -378,14 +370,10 @@ def create_generic_related_manager(superclass): return GenericRelatedObjectManager -class GenericRel(ManyToManyRel): - def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True): - self.to = to - self.related_name = related_name - self.limit_choices_to = limit_choices_to or {} - self.symmetrical = symmetrical - self.multiple = True - self.through = None +class GenericRel(ForeignObjectRel): + + def __init__(self, field, to, related_name=None, limit_choices_to=None): + super(GenericRel, self).__init__(field, to, related_name, limit_choices_to) class BaseGenericInlineFormSet(BaseModelFormSet): """ diff --git a/django/core/management/validation.py b/django/core/management/validation.py index 587d3a0ad7..94d604346b 100644 --- a/django/core/management/validation.py +++ b/django/core/management/validation.py @@ -153,8 +153,16 @@ def get_validation_errors(outfile, app=None): continue # Make sure the related field specified by a ForeignKey is unique - if not f.rel.to._meta.get_field(f.rel.field_name).unique: - e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.rel.field_name, f.rel.to.__name__)) + if f.requires_unique_target: + if len(f.foreign_related_fields) > 1: + has_unique_field = False + for rel_field in f.foreign_related_fields: + has_unique_field = has_unique_field or rel_field.unique + if not has_unique_field: + e.add(opts, "Field combination '%s' under model '%s' must have a unique=True constraint" % (','.join([rel_field.name for rel_field in f.foreign_related_fields]), f.rel.to.__name__)) + else: + if not f.foreign_related_fields[0].unique: + e.add(opts, "Field '%s' under model '%s' must have a unique=True constraint." % (f.foreign_related_fields[0].name, f.rel.to.__name__)) rel_opts = f.rel.to._meta rel_name = f.related.get_accessor_name() diff --git a/django/db/backends/mysql/compiler.py b/django/db/backends/mysql/compiler.py index f4c5563eb2..50a085212b 100644 --- a/django/db/backends/mysql/compiler.py +++ b/django/db/backends/mysql/compiler.py @@ -17,6 +17,12 @@ class SQLCompiler(compiler.SQLCompiler): values.append(value) return row[:index_extra_select] + tuple(values) + def as_subquery_condition(self, alias, columns): + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + sql, params = self.as_sql() + return '(%s) IN (%s)' % (', '.join(['%s.%s' % (qn(alias), qn2(column)) for column in columns]), sql), params + class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): pass diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 2c91d50856..6c5ccd4bd2 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -8,7 +8,7 @@ from django.db.models.aggregates import * from django.db.models.fields import * from django.db.models.fields.subclassing import SubfieldBase from django.db.models.fields.files import FileField, ImageField -from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel +from django.db.models.fields.related import ForeignKey, ForeignObject, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel from django.db.models.deletion import CASCADE, PROTECT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING, ProtectedError from django.db.models import signals from django.utils.decorators import wraps diff --git a/django/db/models/base.py b/django/db/models/base.py index f3e3b76dd7..a2eee60c61 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -10,7 +10,7 @@ from django.conf import settings from django.core.exceptions import (ObjectDoesNotExist, MultipleObjectsReturned, FieldError, ValidationError, NON_FIELD_ERRORS) from django.db.models.fields import AutoField, FieldDoesNotExist -from django.db.models.fields.related import (ManyToOneRel, +from django.db.models.fields.related import (ForeignObjectRel, ManyToOneRel, OneToOneField, add_lazy_relation) from django.db import (router, transaction, DatabaseError, DEFAULT_DB_ALIAS) @@ -333,12 +333,12 @@ class Model(six.with_metaclass(ModelBase)): # The reason for the kwargs check is that standard iterator passes in by # args, and instantiation for iteration is 33% faster. args_len = len(args) - if args_len > len(self._meta.fields): + if args_len > len(self._meta.concrete_fields): # Daft, but matches old exception sans the err msg. raise IndexError("Number of args exceeds number of fields") - fields_iter = iter(self._meta.fields) if not kwargs: + fields_iter = iter(self._meta.concrete_fields) # The ordering of the zip calls matter - zip throws StopIteration # when an iter throws it. So if the first iter throws it, the second # is *not* consumed. We rely on this, so don't change the order @@ -347,6 +347,7 @@ class Model(six.with_metaclass(ModelBase)): setattr(self, field.attname, val) else: # Slower, kwargs-ready version. + fields_iter = iter(self._meta.fields) for val, field in zip(args, fields_iter): setattr(self, field.attname, val) kwargs.pop(field.name, None) @@ -363,11 +364,12 @@ class Model(six.with_metaclass(ModelBase)): # data-descriptor object (DeferredAttribute) without triggering its # __get__ method. if (field.attname not in kwargs and - isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute)): + (isinstance(self.__class__.__dict__.get(field.attname), DeferredAttribute) + or field.column is None)): # This field will be populated on request. continue if kwargs: - if isinstance(field.rel, ManyToOneRel): + if isinstance(field.rel, ForeignObjectRel): try: # Assume object instance was passed in. rel_obj = kwargs.pop(field.name) @@ -394,6 +396,7 @@ class Model(six.with_metaclass(ModelBase)): val = field.get_default() else: val = field.get_default() + if is_related_object: # If we are passed a related instance, set it using the # field.name instead of field.attname (e.g. "user" instead of @@ -528,7 +531,7 @@ class Model(six.with_metaclass(ModelBase)): # automatically do a "update_fields" save on the loaded fields. elif not force_insert and self._deferred and using == self._state.db: field_names = set() - for field in self._meta.fields: + for field in self._meta.concrete_fields: if not field.primary_key and not hasattr(field, 'through'): field_names.add(field.attname) deferred_fields = [ @@ -614,7 +617,7 @@ class Model(six.with_metaclass(ModelBase)): for a single table. """ meta = cls._meta - non_pks = [f for f in meta.local_fields if not f.primary_key] + non_pks = [f for f in meta.local_concrete_fields if not f.primary_key] if update_fields: non_pks = [f for f in non_pks @@ -652,7 +655,7 @@ class Model(six.with_metaclass(ModelBase)): **{field.name: getattr(self, field.attname)}).count() self._order = order_value - fields = meta.local_fields + fields = meta.local_concrete_fields if not pk_set: fields = [f for f in fields if not isinstance(f, AutoField)] diff --git a/django/db/models/deletion.py b/django/db/models/deletion.py index a04f05c73b..e0bfb9d879 100644 --- a/django/db/models/deletion.py +++ b/django/db/models/deletion.py @@ -1,4 +1,3 @@ -from functools import wraps from operator import attrgetter from django.db import connections, transaction, IntegrityError @@ -196,17 +195,13 @@ class Collector(object): self.fast_deletes.append(sub_objs) elif sub_objs: field.rel.on_delete(self, field, sub_objs, self.using) - - # TODO This entire block is only needed as a special case to - # support cascade-deletes for GenericRelation. It should be - # removed/fixed when the ORM gains a proper abstraction for virtual - # or composite fields, and GFKs are reworked to fit into that. - for relation in model._meta.many_to_many: - if not relation.rel.through: - sub_objs = relation.bulk_related_objects(new_objs, self.using) + for field in model._meta.virtual_fields: + if hasattr(field, 'bulk_related_objects'): + # Its something like generic foreign key. + sub_objs = field.bulk_related_objects(new_objs, self.using) self.collect(sub_objs, source=model, - source_attr=relation.rel.related_name, + source_attr=field.rel.related_name, nullable=True) def related_objects(self, related, objs): diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 142b33f6a7..a925f0c577 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -292,10 +292,13 @@ class Field(object): if self.verbose_name is None and self.name: self.verbose_name = self.name.replace('_', ' ') - def contribute_to_class(self, cls, name): + def contribute_to_class(self, cls, name, virtual_only=False): self.set_attributes_from_name(name) self.model = cls - cls._meta.add_field(self) + if virtual_only: + cls._meta.add_virtual_field(self) + else: + cls._meta.add_field(self) if self.choices: setattr(cls, 'get_%s_display' % self.name, curry(cls._get_FIELD_display, field=self)) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index ee1361779a..b7d68e9ce3 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -7,7 +7,6 @@ from django.db.models.fields import (AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) from django.db.models.related import RelatedObject, PathInfo from django.db.models.query import QuerySet -from django.db.models.query_utils import QueryWrapper from django.db.models.deletion import CASCADE from django.utils.encoding import smart_text from django.utils import six @@ -93,22 +92,27 @@ signals.class_prepared.connect(do_pending_lookups) #HACK -class RelatedField(object): - def contribute_to_class(self, cls, name): +class RelatedField(Field): + def db_type(self, connection): + '''By default related field will not have a column + as it relates columns to another table''' + return None + + def contribute_to_class(self, cls, name, virtual_only=False): sup = super(RelatedField, self) # Store the opts for related_query_name() self.opts = cls._meta if hasattr(sup, 'contribute_to_class'): - sup.contribute_to_class(cls, name) + sup.contribute_to_class(cls, name, virtual_only=virtual_only) if not cls._meta.abstract and self.rel.related_name: - self.rel.related_name = self.rel.related_name % { - 'class': cls.__name__.lower(), - 'app_label': cls._meta.app_label.lower(), - } - + related_name = self.rel.related_name % { + 'class': cls.__name__.lower(), + 'app_label': cls._meta.app_label.lower() + } + self.rel.related_name = related_name other = self.rel.to if isinstance(other, six.string_types) or other._meta.pk is None: def resolve_related_class(field, model, cls): @@ -122,7 +126,6 @@ class RelatedField(object): self.name = self.name or (self.rel.to._meta.model_name + '_' + self.rel.to._meta.pk.name) if self.verbose_name is None: self.verbose_name = self.rel.to._meta.verbose_name - self.rel.field_name = self.rel.field_name or self.rel.to._meta.pk.name def do_related_class(self, other, cls): self.set_attributes_from_rel() @@ -130,94 +133,6 @@ class RelatedField(object): if not cls._meta.abstract: self.contribute_to_related_class(other, self.related) - def get_prep_lookup(self, lookup_type, value): - if hasattr(value, 'prepare'): - return value.prepare() - if hasattr(value, '_prepare'): - return value._prepare() - # FIXME: lt and gt are explicitly allowed to make - # get_(next/prev)_by_date work; other lookups are not allowed since that - # gets messy pretty quick. This is a good candidate for some refactoring - # in the future. - if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']: - return self._pk_trace(value, 'get_prep_lookup', lookup_type) - if lookup_type in ('range', 'in'): - return [self._pk_trace(v, 'get_prep_lookup', lookup_type) for v in value] - elif lookup_type == 'isnull': - return [] - raise TypeError("Related Field has invalid lookup: %s" % lookup_type) - - def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): - if not prepared: - value = self.get_prep_lookup(lookup_type, value) - if hasattr(value, 'get_compiler'): - value = value.get_compiler(connection=connection) - if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): - # If the value has a relabeled_clone method it means the - # value will be handled later on. - if hasattr(value, 'relabeled_clone'): - return value - if hasattr(value, 'as_sql'): - sql, params = value.as_sql() - else: - sql, params = value._as_sql(connection=connection) - return QueryWrapper(('(%s)' % sql), params) - - # FIXME: lt and gt are explicitly allowed to make - # get_(next/prev)_by_date work; other lookups are not allowed since that - # gets messy pretty quick. This is a good candidate for some refactoring - # in the future. - if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']: - return [self._pk_trace(value, 'get_db_prep_lookup', lookup_type, - connection=connection, prepared=prepared)] - if lookup_type in ('range', 'in'): - return [self._pk_trace(v, 'get_db_prep_lookup', lookup_type, - connection=connection, prepared=prepared) - for v in value] - elif lookup_type == 'isnull': - return [] - raise TypeError("Related Field has invalid lookup: %s" % lookup_type) - - def _pk_trace(self, value, prep_func, lookup_type, **kwargs): - # Value may be a primary key, or an object held in a relation. - # If it is an object, then we need to get the primary key value for - # that object. In certain conditions (especially one-to-one relations), - # the primary key may itself be an object - so we need to keep drilling - # down until we hit a value that can be used for a comparison. - v = value - - # In the case of an FK to 'self', this check allows to_field to be used - # for both forwards and reverse lookups across the FK. (For normal FKs, - # it's only relevant for forward lookups). - if isinstance(v, self.rel.to): - field_name = getattr(self.rel, "field_name", None) - else: - field_name = None - try: - while True: - if field_name is None: - field_name = v._meta.pk.name - v = getattr(v, field_name) - field_name = None - except AttributeError: - pass - except exceptions.ObjectDoesNotExist: - v = None - - field = self - while field.rel: - if hasattr(field.rel, 'field_name'): - field = field.rel.to._meta.get_field(field.rel.field_name) - else: - field = field.rel.to._meta.pk - - if lookup_type in ('range', 'in'): - v = [v] - v = getattr(field, prep_func)(lookup_type, v, **kwargs) - if isinstance(v, list): - v = v[0] - return v - def related_query_name(self): # This method defines the name that can be used to identify this # related object in a table-spanning query. It uses the lower-cased @@ -254,8 +169,8 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri rel_obj_attr = attrgetter(self.related.field.attname) instance_attr = lambda obj: obj._get_pk_val() instances_dict = dict((instance_attr(inst), inst) for inst in instances) - params = {'%s__pk__in' % self.related.field.name: list(instances_dict)} - qs = self.get_queryset(instance=instances[0]).filter(**params) + query = {'%s__in' % self.related.field.name: instances} + qs = self.get_query_set(instance=instances[0]).filter(**query) # Since we're going to assign directly in the cache, # we must manage the reverse relation cache manually. rel_obj_cache_name = self.related.field.get_cache_name() @@ -274,7 +189,9 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri if related_pk is None: rel_obj = None else: - params = {'%s__pk' % self.related.field.name: related_pk} + params = {} + for lh_field, rh_field in self.related.field.related_fields: + params['%s__%s' % (self.related.field.name, rh_field.name)] = getattr(instance, rh_field.attname) try: rel_obj = self.get_queryset(instance=instance).get(**params) except self.related.model.DoesNotExist: @@ -314,13 +231,14 @@ class SingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjectDescri raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' % (value, instance._state.db, value._state.db)) - related_pk = getattr(instance, self.related.field.rel.get_related_field().attname) - if related_pk is None: + related_pk = tuple([getattr(instance, field.attname) for field in self.related.field.foreign_related_fields]) + if None in related_pk: raise ValueError('Cannot assign "%r": "%s" instance isn\'t saved in the database.' % (value, instance._meta.object_name)) # Set the value of the related field to the value of the related object's related field - setattr(value, self.related.field.attname, related_pk) + for index, field in enumerate(self.related.field.local_related_fields): + setattr(value, field.attname, related_pk[index]) # Since we already know what the related object is, seed the related # object caches now, too. This avoids another db hit if you get the @@ -352,16 +270,12 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec else: return QuerySet(self.field.rel.to).using(db) - def get_prefetch_queryset(self, instances): - other_field = self.field.rel.get_related_field() - rel_obj_attr = attrgetter(other_field.attname) - instance_attr = attrgetter(self.field.attname) + def get_prefetch_query_set(self, instances): + rel_obj_attr = self.field.get_foreign_related_value + instance_attr = self.field.get_local_related_value instances_dict = dict((instance_attr(inst), inst) for inst in instances) - if other_field.rel: - params = {'%s__pk__in' % self.field.rel.field_name: list(instances_dict)} - else: - params = {'%s__in' % self.field.rel.field_name: list(instances_dict)} - qs = self.get_queryset(instance=instances[0]).filter(**params) + query = {'%s__in' % self.field.related_query_name(): instances} + qs = self.get_query_set(instance=instances[0]).filter(**query) # Since we're going to assign directly in the cache, # we must manage the reverse relation cache manually. if not self.field.rel.multiple: @@ -377,16 +291,14 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec try: rel_obj = getattr(instance, self.cache_name) except AttributeError: - val = getattr(instance, self.field.attname) - if val is None: + val = self.field.get_local_related_value(instance) + if None in val: rel_obj = None else: - other_field = self.field.rel.get_related_field() - if other_field.rel: - params = {'%s__%s' % (self.field.rel.field_name, other_field.rel.field_name): val} - else: - params = {'%s__exact' % self.field.rel.field_name: val} - qs = self.get_queryset(instance=instance) + params = {rh_field.attname: getattr(instance, lh_field.attname) + for lh_field, rh_field in self.field.related_fields} + params.update(self.field.get_extra_descriptor_filter(instance)) + qs = self.get_query_set(instance=instance) # Assuming the database enforces foreign keys, this won't fail. rel_obj = qs.get(**params) if not self.field.rel.multiple: @@ -440,11 +352,11 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec setattr(related, self.field.related.get_cache_name(), None) # Set the value of the related field - try: - val = getattr(value, self.field.rel.get_related_field().attname) - except AttributeError: - val = None - setattr(instance, self.field.attname, val) + for lh_field, rh_field in self.field.related_fields: + try: + setattr(instance, lh_field.attname, getattr(value, rh_field.attname)) + except AttributeError: + setattr(instance, lh_field.attname, None) # Since we already know what the related object is, seed the related # object caches now, too. This avoids another db hit if you get the @@ -487,15 +399,12 @@ class ForeignRelatedObjectsDescriptor(object): superclass = self.related.model._default_manager.__class__ rel_field = self.related.field rel_model = self.related.model - attname = rel_field.rel.get_related_field().attname class RelatedManager(superclass): def __init__(self, instance): super(RelatedManager, self).__init__() self.instance = instance - self.core_filters = { - '%s__%s' % (rel_field.name, attname): getattr(instance, attname) - } + self.core_filters= {'%s__exact' % rel_field.name: instance} self.model = rel_model def get_queryset(self): @@ -504,20 +413,22 @@ class ForeignRelatedObjectsDescriptor(object): except (AttributeError, KeyError): db = self._db or router.db_for_read(self.model, instance=self.instance) qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters) - val = getattr(self.instance, attname) - if val is None or val == '' and connections[db].features.interprets_empty_strings_as_nulls: - return qs.none() + empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls + for field in rel_field.foreign_related_fields: + val = getattr(self.instance, field.attname) + if val is None or (val == '' and empty_strings_as_null): + return qs.none() qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}} return qs def get_prefetch_queryset(self, instances): - rel_obj_attr = attrgetter(rel_field.attname) - instance_attr = attrgetter(attname) + rel_obj_attr = rel_field.get_local_related_value + instance_attr = rel_field.get_foreign_related_value instances_dict = dict((instance_attr(inst), inst) for inst in instances) db = self._db or router.db_for_read(self.model, instance=instances[0]) - query = {'%s__%s__in' % (rel_field.name, attname): list(instances_dict)} - qs = super(RelatedManager, self).get_queryset().using(db).filter(**query) - # Since we just bypassed this class' get_queryset(), we must manage + query = {'%s__in' % rel_field.name: instances} + qs = super(RelatedManager, self).get_query_set().using(db).filter(**query) + # Since we just bypassed this class' get_query_set(), we must manage # the reverse relation manually. for rel_obj in qs: instance = instances_dict[rel_obj_attr(rel_obj)] @@ -550,10 +461,10 @@ class ForeignRelatedObjectsDescriptor(object): # remove() and clear() are only provided if the ForeignKey can have a value of null. if rel_field.null: def remove(self, *objs): - val = getattr(self.instance, attname) + val = rel_field.get_foreign_related_value(self.instance) for obj in objs: # Is obj actually part of this descriptor set? - if getattr(obj, rel_field.attname) == val: + if rel_field.get_local_related_value(obj) == val: setattr(obj, rel_field.name, None) obj.save() else: @@ -577,16 +488,26 @@ def create_many_related_manager(superclass, rel): super(ManyRelatedManager, self).__init__() self.model = model self.query_field_name = query_field_name - self.core_filters = {'%s__pk' % query_field_name: instance._get_pk_val()} + + source_field = through._meta.get_field(source_field_name) + source_related_fields = source_field.related_fields + + self.core_filters = {} + for lh_field, rh_field in source_related_fields: + self.core_filters['%s__%s' % (query_field_name, rh_field.name)] = getattr(instance, rh_field.attname) + self.instance = instance self.symmetrical = symmetrical + self.source_field = source_field self.source_field_name = source_field_name self.target_field_name = target_field_name self.reverse = reverse self.through = through self.prefetch_cache_name = prefetch_cache_name - self._fk_val = self._get_fk_val(instance, source_field_name) - if self._fk_val is None: + self.related_val = source_field.get_foreign_related_value(instance) + # Used for single column related auto created models + self._fk_val = self.related_val[0] + if None in self.related_val: raise ValueError('"%r" needs to have a value for field "%s" before ' 'this many-to-many relationship can be used.' % (instance, source_field_name)) @@ -620,11 +541,9 @@ def create_many_related_manager(superclass, rel): def get_prefetch_queryset(self, instances): instance = instances[0] - from django.db import connections db = self._db or router.db_for_read(instance.__class__, instance=instance) - query = {'%s__pk__in' % self.query_field_name: - set(obj._get_pk_val() for obj in instances)} - qs = super(ManyRelatedManager, self).get_queryset().using(db)._next_is_sticky().filter(**query) + query = {'%s__in' % self.query_field_name: instances} + qs = super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query) # M2M: need to annotate the query in order to get the primary model # that the secondary model was actually related to. We know that @@ -634,16 +553,14 @@ def create_many_related_manager(superclass, rel): # For non-autocreated 'through' models, can't assume we are # dealing with PK values. fk = self.through._meta.get_field(self.source_field_name) - source_col = fk.column join_table = self.through._meta.db_table connection = connections[db] qn = connection.ops.quote_name - qs = qs.extra(select={'_prefetch_related_val': - '%s.%s' % (qn(join_table), qn(source_col))}) - select_attname = fk.rel.get_related_field().get_attname() + qs = qs.extra(select={'_prefetch_related_val_%s' % f.attname: + '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields}) return (qs, - attrgetter('_prefetch_related_val'), - attrgetter(select_attname), + lambda result: tuple([getattr(result, '_prefetch_related_val_%s' % f.attname) for f in fk.local_related_fields]), + lambda inst: tuple([getattr(inst, f.attname) for f in fk.foreign_related_fields]), False, self.prefetch_cache_name) @@ -795,7 +712,7 @@ def create_many_related_manager(superclass, rel): instance=self.instance, reverse=self.reverse, model=self.model, pk_set=None, using=db) self.through._default_manager.using(db).filter(**{ - source_field_name: self._fk_val + source_field_name: self.related_val }).delete() if self.reverse or source_field_name == self.source_field_name: # Don't send the signal when we are clearing the @@ -918,19 +835,18 @@ class ReverseManyRelatedObjectsDescriptor(object): manager.clear() manager.add(*value) - -class ManyToOneRel(object): - def __init__(self, to, field_name, related_name=None, limit_choices_to=None, - parent_link=False, on_delete=None): +class ForeignObjectRel(object): + def __init__(self, field, to, related_name=None, limit_choices_to=None, + parent_link=False, on_delete=None): try: to._meta except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT assert isinstance(to, six.string_types), "'to' must be either a model, a model name or the string %r" % RECURSIVE_RELATIONSHIP_CONSTANT - self.to, self.field_name = to, field_name + + self.field = field + self.to = to self.related_name = related_name - if limit_choices_to is None: - limit_choices_to = {} - self.limit_choices_to = limit_choices_to + self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to self.multiple = True self.parent_link = parent_link self.on_delete = on_delete @@ -939,6 +855,20 @@ class ManyToOneRel(object): "Should the related object be hidden?" return self.related_name and self.related_name[-1] == '+' + def get_joining_columns(self): + return self.field.get_reverse_joining_columns() + + def get_extra_restriction(self, where_class, alias, related_alias): + return self.field.get_extra_restriction(where_class, related_alias, alias) + +class ManyToOneRel(ForeignObjectRel): + def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None, + parent_link=False, on_delete=None): + super(ManyToOneRel, self).__init__( + field, to, related_name=related_name, limit_choices_to=limit_choices_to, + parent_link=parent_link, on_delete=on_delete) + self.field_name = field_name + def get_related_field(self): """ Returns the Field in the 'to' object to which this relationship is @@ -952,9 +882,9 @@ class ManyToOneRel(object): class OneToOneRel(ManyToOneRel): - def __init__(self, to, field_name, related_name=None, limit_choices_to=None, - parent_link=False, on_delete=None): - super(OneToOneRel, self).__init__(to, field_name, + def __init__(self, field, to, field_name, related_name=None, limit_choices_to=None, + parent_link=False, on_delete=None): + super(OneToOneRel, self).__init__(field, to, field_name, related_name=related_name, limit_choices_to=limit_choices_to, parent_link=parent_link, on_delete=on_delete ) @@ -963,7 +893,7 @@ class OneToOneRel(ManyToOneRel): class ManyToManyRel(object): def __init__(self, to, related_name=None, limit_choices_to=None, - symmetrical=True, through=None, db_constraint=True): + symmetrical=True, through=None, db_constraint=True): if through and not db_constraint: raise ValueError("Can't supply a through model and db_constraint=False") self.to = to @@ -989,7 +919,199 @@ class ManyToManyRel(object): return self.to._meta.pk -class ForeignKey(RelatedField, Field): +class ForeignObject(RelatedField): + requires_unique_target = True + generate_reverse_relation = True + + def __init__(self, to, from_fields, to_fields, **kwargs): + self.from_fields = from_fields + self.to_fields = to_fields + + if 'rel' not in kwargs: + kwargs['rel'] = ForeignObjectRel( + self, to, + related_name=kwargs.pop('related_name', None), + limit_choices_to=kwargs.pop('limit_choices_to', None), + parent_link=kwargs.pop('parent_link', False), + on_delete=kwargs.pop('on_delete', CASCADE), + ) + kwargs['verbose_name'] = kwargs.get('verbose_name', None) + + super(ForeignObject, self).__init__(**kwargs) + + def resolve_related_fields(self): + if len(self.from_fields) < 1 or len(self.from_fields) != len(self.to_fields): + raise ValueError('Foreign Object from and to fields must be the same non-zero length') + related_fields = [] + for index in range(len(self.from_fields)): + from_field_name = self.from_fields[index] + to_field_name = self.to_fields[index] + from_field = (self if from_field_name == 'self' + else self.opts.get_field_by_name(from_field_name)[0]) + to_field = (self.rel.to._meta.pk if to_field_name is None + else self.rel.to._meta.get_field_by_name(to_field_name)[0]) + related_fields.append((from_field, to_field)) + return related_fields + + @property + def related_fields(self): + if not hasattr(self, '_related_fields'): + self._related_fields = self.resolve_related_fields() + return self._related_fields + + @property + def reverse_related_fields(self): + return [(rhs_field, lhs_field) for lhs_field, rhs_field in self.related_fields] + + @property + def local_related_fields(self): + return tuple([lhs_field for lhs_field, rhs_field in self.related_fields]) + + @property + def foreign_related_fields(self): + return tuple([rhs_field for lhs_field, rhs_field in self.related_fields]) + + def get_local_related_value(self, instance): + return self.get_instance_value_for_fields(instance, self.local_related_fields) + + def get_foreign_related_value(self, instance): + return self.get_instance_value_for_fields(instance, self.foreign_related_fields) + + @staticmethod + def get_instance_value_for_fields(instance, fields): + return tuple([getattr(instance, field.attname) for field in fields]) + + def get_attname_column(self): + attname, column = super(ForeignObject, self).get_attname_column() + return attname, None + + def get_joining_columns(self, reverse_join=False): + source = self.reverse_related_fields if reverse_join else self.related_fields + return tuple([(lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source]) + + def get_reverse_joining_columns(self): + return self.get_joining_columns(reverse_join=True) + + def get_extra_descriptor_filter(self, instance): + """ + Returns an extra filter condition for related object fetching when + user does 'instance.fieldname', that is the extra filter is used in + the descriptor of the field. + + The filter should be something usable in .filter(**kwargs) call, and + will be ANDed together with the joining columns condition. + + A parallel method is get_extra_relation_restriction() which is used in + JOIN and subquery conditions. + """ + return {} + + def get_extra_restriction(self, where_class, alias, related_alias): + """ + Returns a pair condition used for joining and subquery pushdown. The + condition is something that responds to as_sql(qn, connection) method. + + Note that currently referring both the 'alias' and 'related_alias' + will not work in some conditions, like subquery pushdown. + + A parallel method is get_extra_descriptor_filter() which is used in + instance.fieldname related object fetching. + """ + return None + + def get_path_info(self): + """ + Get path from this field to the related model. + """ + opts = self.rel.to._meta + from_opts = self.model._meta + return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True)] + + def get_reverse_path_info(self): + """ + Get path from the related model to this field's model. + """ + opts = self.model._meta + from_opts = self.rel.to._meta + pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] + return pathinfos + + def get_lookup_constraint(self, constraint_class, alias, targets, sources, lookup_type, + raw_value): + from django.db.models.sql.where import SubqueryConstraint, Constraint, AND, OR + root_constraint = constraint_class() + assert len(targets) == len(sources) + + def get_normalized_value(value): + + from django.db.models import Model + if isinstance(value, Model): + value_list = [] + for source in sources: + # Account for one-to-one relations when sent a different model + while not isinstance(value, source.model): + source = source.rel.to._meta.get_field(source.rel.field_name) + value_list.append(getattr(value, source.attname)) + return tuple(value_list) + elif not isinstance(value, tuple): + return (value,) + return value + + is_multicolumn = len(self.related_fields) > 1 + if (hasattr(raw_value, '_as_sql') or + hasattr(raw_value, 'get_compiler')): + root_constraint.add(SubqueryConstraint(alias, [target.column for target in targets], + [source.name for source in sources], raw_value), + AND) + elif lookup_type == 'isnull': + root_constraint.add( + (Constraint(alias, targets[0].column, targets[0]), lookup_type, raw_value), AND) + elif (lookup_type == 'exact' or (lookup_type in ['gt', 'lt', 'gte', 'lte'] + and not is_multicolumn)): + value = get_normalized_value(raw_value) + for index, source in enumerate(sources): + root_constraint.add( + (Constraint(alias, targets[index].column, sources[index]), lookup_type, + value[index]), AND) + elif lookup_type in ['range', 'in'] and not is_multicolumn: + values = [get_normalized_value(value) for value in raw_value] + value = [val[0] for val in values] + root_constraint.add( + (Constraint(alias, targets[0].column, sources[0]), lookup_type, value), AND) + elif lookup_type == 'in': + values = [get_normalized_value(value) for value in raw_value] + for value in values: + value_constraint = constraint_class() + for index, target in enumerate(targets): + value_constraint.add( + (Constraint(alias, target.column, sources[index]), 'exact', value[index]), + AND) + root_constraint.add(value_constraint, OR) + else: + raise TypeError('Related Field got invalid lookup: %s' % lookup_type) + return root_constraint + + @property + def attnames(self): + return tuple([field.attname for field in self.local_related_fields]) + + def get_defaults(self): + return tuple([field.get_default() for field in self.local_related_fields]) + + def contribute_to_class(self, cls, name, virtual_only=False): + super(ForeignObject, self).contribute_to_class(cls, name, virtual_only=virtual_only) + setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self)) + + def contribute_to_related_class(self, cls, related): + # Internal FK's - i.e., those with a related name ending with '+' - + # and swapped models don't get a related descriptor. + if not self.rel.is_hidden() and not related.model._meta.swapped: + setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) + if self.rel.limit_choices_to: + cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to) + + +class ForeignKey(ForeignObject): empty_strings_allowed = False default_error_messages = { 'invalid': _('Model %(model)s with pk %(pk)r does not exist.') @@ -999,7 +1121,7 @@ class ForeignKey(RelatedField, Field): def __init__(self, to, to_field=None, rel_class=ManyToOneRel, db_constraint=True, **kwargs): try: - to._meta.model_name + to_name = to._meta.object_name.lower() except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT assert isinstance(to, six.string_types), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT) else: @@ -1008,44 +1130,33 @@ class ForeignKey(RelatedField, Field): # the to_field during FK construction. It won't be guaranteed to # be correct until contribute_to_class is called. Refs #12190. to_field = to_field or (to._meta.pk and to._meta.pk.name) - kwargs['verbose_name'] = kwargs.get('verbose_name', None) if 'db_index' not in kwargs: kwargs['db_index'] = True self.db_constraint = db_constraint - kwargs['rel'] = rel_class(to, to_field, + + kwargs['rel'] = rel_class( + self, to, to_field, related_name=kwargs.pop('related_name', None), limit_choices_to=kwargs.pop('limit_choices_to', None), parent_link=kwargs.pop('parent_link', False), on_delete=kwargs.pop('on_delete', CASCADE), ) - super(ForeignKey, self).__init__(**kwargs) + super(ForeignKey, self).__init__(to, ['self'], [to_field], **kwargs) - def get_path_info(self): - """ - Get path from this field to the related model. - """ - opts = self.rel.to._meta - target = self.rel.get_related_field() - from_opts = self.model._meta - return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self + @property + def related_field(self): + return self.foreign_related_fields[0] def get_reverse_path_info(self): """ Get path from the related model to this field's model. """ opts = self.model._meta - from_field = self.rel.get_related_field() - from_opts = from_field.model._meta - pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)] - if from_field.model is self.model: - # Recursive foreign key to self. - target = opts.get_field_by_name( - self.rel.field_name)[0] - else: - target = opts.pk - return pathinfos, opts, target, self + from_opts = self.rel.to._meta + pathinfos = [PathInfo(from_opts, opts, (opts.pk,), self.rel, not self.unique, False)] + return pathinfos def validate(self, value, model_instance): if self.rel.parent_link: @@ -1066,21 +1177,26 @@ class ForeignKey(RelatedField, Field): def get_attname(self): return '%s_id' % self.name + def get_attname_column(self): + attname = self.get_attname() + column = self.db_column or attname + return attname, column + def get_validator_unique_lookup_type(self): - return '%s__%s__exact' % (self.name, self.rel.get_related_field().name) + return '%s__%s__exact' % (self.name, self.related_field.name) def get_default(self): "Here we check if the default value is an object and return the to_field if so." field_default = super(ForeignKey, self).get_default() if isinstance(field_default, self.rel.to): - return getattr(field_default, self.rel.get_related_field().attname) + return getattr(field_default, self.related_field.attname) return field_default def get_db_prep_save(self, value, connection): if value == '' or value == None: return None else: - return self.rel.get_related_field().get_db_prep_save(value, + return self.related_field.get_db_prep_save(value, connection=connection) def value_to_string(self, obj): @@ -1093,19 +1209,10 @@ class ForeignKey(RelatedField, Field): choice_list = self.get_choices_default() if len(choice_list) == 2: return smart_text(choice_list[1][0]) - return Field.value_to_string(self, obj) - - def contribute_to_class(self, cls, name): - super(ForeignKey, self).contribute_to_class(cls, name) - setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self)) + return super(ForeignKey, self).value_to_string(obj) def contribute_to_related_class(self, cls, related): - # Internal FK's - i.e., those with a related name ending with '+' - - # and swapped models don't get a related descriptor. - if not self.rel.is_hidden() and not related.model._meta.swapped: - setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) - if self.rel.limit_choices_to: - cls._meta.related_fkey_lookups.append(self.rel.limit_choices_to) + super(ForeignKey, self).contribute_to_related_class(cls, related) if self.rel.field_name is None: self.rel.field_name = cls._meta.pk.name @@ -1130,7 +1237,7 @@ class ForeignKey(RelatedField, Field): # in which case the column type is simply that of an IntegerField. # If the database needs similar types for key fields however, the only # thing we can do is making AutoField an IntegerField. - rel_field = self.rel.get_related_field() + rel_field = self.related_field if (isinstance(rel_field, AutoField) or (not connection.features.related_fields_match_type and isinstance(rel_field, (PositiveIntegerField, @@ -1212,7 +1319,7 @@ def create_many_to_many_intermediary_model(field, klass): }) -class ManyToManyField(RelatedField, Field): +class ManyToManyField(RelatedField): description = _("Many-to-many relationship") def __init__(self, to, db_constraint=True, **kwargs): @@ -1252,14 +1359,14 @@ class ManyToManyField(RelatedField, Field): linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0] linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0] if direct: - join1infos, _, _, _ = linkfield1.get_reverse_path_info() - join2infos, opts, target, final_field = linkfield2.get_path_info() + join1infos = linkfield1.get_reverse_path_info() + join2infos = linkfield2.get_path_info() else: - join1infos, _, _, _ = linkfield2.get_reverse_path_info() - join2infos, opts, target, final_field = linkfield1.get_path_info() + join1infos = linkfield2.get_reverse_path_info() + join2infos = linkfield1.get_path_info() pathinfos.extend(join1infos) pathinfos.extend(join2infos) - return pathinfos, opts, target, final_field + return pathinfos def get_path_info(self): return self._get_path_info(direct=True) @@ -1402,8 +1509,3 @@ class ManyToManyField(RelatedField, Field): initial = initial() defaults['initial'] = [i._get_pk_val() for i in initial] return super(ManyToManyField, self).formfield(**defaults) - - def db_type(self, connection): - # A ManyToManyField is not represented by a single column, - # so return None. - return None diff --git a/django/db/models/options.py b/django/db/models/options.py index a302e2d73a..acb5ff38bc 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -10,6 +10,7 @@ from django.db.models.fields import AutoField, FieldDoesNotExist from django.db.models.fields.proxy import OrderWrt from django.db.models.loading import get_models, app_cache_ready from django.utils import six +from django.utils.functional import cached_property from django.utils.datastructures import SortedDict from django.utils.encoding import force_text, smart_text, python_2_unicode_compatible from django.utils.translation import activate, deactivate_all, get_language, string_concat @@ -173,6 +174,22 @@ class Options(object): if hasattr(self, '_field_cache'): del self._field_cache del self._field_name_cache + # The fields, concrete_fields and local_concrete_fields are + # implemented as cached properties for performance reasons. + # The attrs will not exists if the cached property isn't + # accessed yet, hence the try-excepts. + try: + del self.fields + except AttributeError: + pass + try: + del self.concrete_fields + except AttributeError: + pass + try: + del self.local_concrete_fields + except AttributeError: + pass if hasattr(self, '_name_map'): del self._name_map @@ -245,7 +262,8 @@ class Options(object): return None swapped = property(_swapped) - def _fields(self): + @cached_property + def fields(self): """ The getter for self.fields. This returns the list of field objects available to this model (including through parent models). @@ -258,7 +276,14 @@ class Options(object): except AttributeError: self._fill_fields_cache() return self._field_name_cache - fields = property(_fields) + + @cached_property + def concrete_fields(self): + return [f for f in self.fields if f.column is not None] + + @cached_property + def local_concrete_fields(self): + return [f for f in self.local_fields if f.column is not None] def get_fields_with_model(self): """ @@ -272,6 +297,10 @@ class Options(object): self._fill_fields_cache() return self._field_cache + def get_concrete_fields_with_model(self): + return [(field, model) for field, model in self.get_fields_with_model() if + field.column is not None] + def _fill_fields_cache(self): cache = [] for parent in self.parents: @@ -377,6 +406,9 @@ class Options(object): cache[f.name] = (f, model, True, True) for f, model in self.get_fields_with_model(): cache[f.name] = (f, model, True, False) + for f in self.virtual_fields: + if hasattr(f, 'related'): + cache[f.name] = (f.related, None if f.model == self.model else f.model, True, False) if app_cache_ready(): self._name_map = cache return cache @@ -432,7 +464,7 @@ class Options(object): for klass in get_models(include_auto_created=True, only_installed=False): if not klass._meta.swapped: for f in klass._meta.local_fields: - if f.rel and not isinstance(f.rel.to, six.string_types): + if f.rel and not isinstance(f.rel.to, six.string_types) and f.generate_reverse_relation: if self == f.rel.to._meta: cache[f.related] = None proxy_cache[f.related] = None diff --git a/django/db/models/query.py b/django/db/models/query.py index 7ddd933772..aec62c5c20 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -261,13 +261,13 @@ class QuerySet(object): only_load = self.query.get_loaded_field_names() if not fill_cache: - fields = self.model._meta.fields + fields = self.model._meta.concrete_fields load_fields = [] # If only/defer clauses have been specified, # build the list of fields that are to be loaded. if only_load: - for field, model in self.model._meta.get_fields_with_model(): + for field, model in self.model._meta.get_concrete_fields_with_model(): if model is None: model = self.model try: @@ -280,7 +280,7 @@ class QuerySet(object): load_fields.append(field.name) index_start = len(extra_select) - aggregate_start = index_start + len(load_fields or self.model._meta.fields) + aggregate_start = index_start + len(load_fields or self.model._meta.concrete_fields) skip = None if load_fields and not fill_cache: @@ -312,7 +312,11 @@ class QuerySet(object): if skip: obj = model_cls(**dict(zip(init_list, row_data))) else: - obj = model(*row_data) + try: + obj = model(*row_data) + except IndexError: + import ipdb; ipdb.set_trace() + pass # Store the source database of the object obj._state.db = db @@ -962,7 +966,7 @@ class QuerySet(object): """ opts = self.model._meta if self.query.group_by is None: - field_names = [f.attname for f in opts.fields] + field_names = [f.attname for f in opts.concrete_fields] self.query.add_fields(field_names, False) self.query.set_group_by() @@ -1055,7 +1059,7 @@ class ValuesQuerySet(QuerySet): else: # Default to all fields. self.extra_names = None - self.field_names = [f.attname for f in self.model._meta.fields] + self.field_names = [f.attname for f in self.model._meta.concrete_fields] self.aggregate_names = None self.query.select = [] @@ -1266,7 +1270,7 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, skip = set() init_list = [] # Build the list of fields that *haven't* been requested - for field, model in klass._meta.get_fields_with_model(): + for field, model in klass._meta.get_concrete_fields_with_model(): if field.name not in load_fields: skip.add(field.attname) elif from_parent and issubclass(from_parent, model.__class__): @@ -1285,22 +1289,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, else: # Load all fields on klass - field_count = len(klass._meta.fields) + field_count = len(klass._meta.concrete_fields) # Check if we need to skip some parent fields. - if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields): + if from_parent and len(klass._meta.local_concrete_fields) != len(klass._meta.concrete_fields): # Only load those fields which haven't been already loaded into # 'from_parent'. non_seen_models = [p for p in klass._meta.get_parent_list() if not issubclass(from_parent, p)] # Load local fields, too... non_seen_models.append(klass) - field_names = [f.attname for f in klass._meta.fields + field_names = [f.attname for f in klass._meta.concrete_fields if f.model in non_seen_models] field_count = len(field_names) # Try to avoid populating field_names variable for perfomance reasons. # If field_names variable is set, we use **kwargs based model init # which is slower than normal init. - if field_count == len(klass._meta.fields): + if field_count == len(klass._meta.concrete_fields): field_names = () restricted = requested is not None diff --git a/django/db/models/related.py b/django/db/models/related.py index 53645bedb9..6c93076c48 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -7,7 +7,7 @@ from django.db.models.fields import BLANK_CHOICE_DASH # describe the relation in Model terms (model Options and Fields for both # sides of the relation. The join_field is the field backing the relation. PathInfo = namedtuple('PathInfo', - 'from_field to_field from_opts to_opts join_field ' + 'from_opts to_opts target_fields join_field ' 'm2m direct') class RelatedObject(object): diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 4711ea6e19..1f19131ba2 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -2,10 +2,9 @@ import datetime from django.conf import settings from django.core.exceptions import FieldError -from django.db import transaction from django.db.backends.util import truncate_name from django.db.models.constants import LOOKUP_SEP -from django.db.models.query_utils import select_related_descend +from django.db.models.query_utils import select_related_descend, QueryWrapper from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo) from django.db.models.sql.datastructures import EmptyResultSet @@ -33,7 +32,7 @@ class SQLCompiler(object): # cleaned. We are not using a clone() of the query here. """ if not self.query.tables: - self.query.join((None, self.query.model._meta.db_table, None, None)) + self.query.join((None, self.query.model._meta.db_table, None)) if (not self.query.select and self.query.default_cols and not self.query.included_inherited_models): self.query.setup_inherited_models() @@ -273,7 +272,7 @@ class SQLCompiler(object): # be used by local fields. seen_models = {None: start_alias} - for field, model in opts.get_fields_with_model(): + for field, model in opts.get_concrete_fields_with_model(): if from_parent and model is not None and issubclass(from_parent, model): # Avoid loading data for already loaded parents. continue @@ -314,9 +313,10 @@ class SQLCompiler(object): for name in self.query.distinct_fields: parts = name.split(LOOKUP_SEP) - field, col, alias, _, _ = self._setup_joins(parts, opts, None) - col, alias = self._final_join_removal(col, alias) - result.append("%s.%s" % (qn(alias), qn2(col))) + field, cols, alias, _, _ = self._setup_joins(parts, opts, None) + cols, alias = self._final_join_removal(cols, alias) + for col in cols: + result.append("%s.%s" % (qn(alias), qn2(col))) return result @@ -387,15 +387,16 @@ class SQLCompiler(object): elif get_order_dir(field)[0] not in self.query.extra_select: # 'col' is of the form 'field' or 'field1__field2' or # '-field1__field2__field', etc. - for table, col, order in self.find_ordering_name(field, + for table, cols, order in self.find_ordering_name(field, self.query.model._meta, default_order=asc): - if (table, col) not in processed_pairs: - elt = '%s.%s' % (qn(table), qn2(col)) - processed_pairs.add((table, col)) - if distinct and elt not in select_aliases: - ordering_aliases.append(elt) - result.append('%s %s' % (elt, order)) - group_by.append((elt, [])) + for col in cols: + if (table, col) not in processed_pairs: + elt = '%s.%s' % (qn(table), qn2(col)) + processed_pairs.add((table, col)) + if distinct and elt not in select_aliases: + ordering_aliases.append(elt) + result.append('%s %s' % (elt, order)) + group_by.append((elt, [])) else: elt = qn2(col) if distinct and col not in select_aliases: @@ -414,7 +415,7 @@ class SQLCompiler(object): """ name, order = get_order_dir(name, default_order) pieces = name.split(LOOKUP_SEP) - field, col, alias, joins, opts = self._setup_joins(pieces, opts, alias) + field, cols, alias, joins, opts = self._setup_joins(pieces, opts, alias) # If we get to this point and the field is a relation to another model, # append the default ordering for that model. @@ -432,8 +433,8 @@ class SQLCompiler(object): results.extend(self.find_ordering_name(item, opts, alias, order, already_seen)) return results - col, alias = self._final_join_removal(col, alias) - return [(alias, col, order)] + cols, alias = self._final_join_removal(cols, alias) + return [(alias, cols, order)] def _setup_joins(self, pieces, opts, alias): """ @@ -446,13 +447,13 @@ class SQLCompiler(object): """ if not alias: alias = self.query.get_initial_alias() - field, target, opts, joins, _ = self.query.setup_joins( + field, targets, opts, joins, _ = self.query.setup_joins( pieces, opts, alias) # We will later on need to promote those joins that were added to the # query afresh above. joins_to_promote = [j for j in joins if self.query.alias_refcount[j] < 2] alias = joins[-1] - col = target.column + cols = [target.column for target in targets] if not field.rel: # To avoid inadvertent trimming of a necessary alias, use the # refcount to show that we are referencing a non-relation field on @@ -463,9 +464,9 @@ class SQLCompiler(object): # Ordering or distinct must not affect the returned set, and INNER # JOINS for nullable fields could do this. self.query.promote_joins(joins_to_promote) - return field, col, alias, joins, opts + return field, cols, alias, joins, opts - def _final_join_removal(self, col, alias): + def _final_join_removal(self, cols, alias): """ A helper method for get_distinct and get_ordering. This method will trim extra not-needed joins from the tail of the join chain. @@ -477,12 +478,14 @@ class SQLCompiler(object): if alias: while 1: join = self.query.alias_map[alias] - if col != join.rhs_join_col: + lhs_cols, rhs_cols = zip(*[(lhs_col, rhs_col) for lhs_col, rhs_col in join.join_cols]) + if set(cols) != set(rhs_cols): break + + cols = [lhs_cols[rhs_cols.index(col)] for col in cols] self.query.unref_alias(alias) alias = join.lhs_alias - col = join.lhs_join_col - return col, alias + return cols, alias def get_from_clause(self): """ @@ -504,22 +507,30 @@ class SQLCompiler(object): if not self.query.alias_refcount[alias]: continue try: - name, alias, join_type, lhs, lhs_col, col, _, join_field = self.query.alias_map[alias] + name, alias, join_type, lhs, join_cols, _, join_field = self.query.alias_map[alias] except KeyError: # Extra tables can end up in self.tables, but not in the # alias_map if they aren't in a join. That's OK. We skip them. continue alias_str = (alias != name and ' %s' % alias or '') if join_type and not first: - if join_field and hasattr(join_field, 'get_extra_join_sql'): - extra_cond, extra_params = join_field.get_extra_join_sql( - self.connection, qn, lhs, alias) + extra_cond = join_field.get_extra_restriction( + self.query.where_class, alias, lhs) + if extra_cond: + extra_sql, extra_params = extra_cond.as_sql( + qn, self.connection) + extra_sql = 'AND (%s)' % extra_sql from_params.extend(extra_params) else: - extra_cond = "" - result.append('%s %s%s ON (%s.%s = %s.%s%s)' % - (join_type, qn(name), alias_str, qn(lhs), - qn2(lhs_col), qn(alias), qn2(col), extra_cond)) + extra_sql = "" + result.append('%s %s%s ON (' + % (join_type, qn(name), alias_str)) + for index, (lhs_col, rhs_col) in enumerate(join_cols): + if index != 0: + result.append(' AND ') + result.append('%s.%s = %s.%s' % + (qn(lhs), qn2(lhs_col), qn(alias), qn2(rhs_col))) + result.append('%s)' % extra_sql) else: connector = not first and ', ' or '' result.append('%s%s%s' % (connector, qn(name), alias_str)) @@ -545,7 +556,7 @@ class SQLCompiler(object): select_cols = self.query.select + self.query.related_select_cols # Just the column, not the fields. select_cols = [s[0] for s in select_cols] - if (len(self.query.model._meta.fields) == len(self.query.select) + if (len(self.query.model._meta.concrete_fields) == len(self.query.select) and self.connection.features.allows_group_by_pk): self.query.group_by = [ (self.query.model._meta.db_table, self.query.model._meta.pk.column) @@ -623,14 +634,13 @@ class SQLCompiler(object): table = f.rel.to._meta.db_table promote = nullable or f.null alias = self.query.join_parent_model(opts, model, root_alias, {}) - - alias = self.query.join((alias, table, f.column, - f.rel.get_related_field().column), + join_cols = f.get_joining_columns() + alias = self.query.join((alias, table, join_cols), outer_if_first=promote, join_field=f) columns, aliases = self.get_default_columns(start_alias=alias, opts=f.rel.to._meta, as_pairs=True) self.query.related_select_cols.extend( - SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields)) + SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.concrete_fields)) if restricted: next = requested.get(f.name, {}) else: @@ -653,7 +663,7 @@ class SQLCompiler(object): alias = self.query.join_parent_model(opts, f.rel.to, root_alias, {}) table = model._meta.db_table alias = self.query.join( - (alias, table, f.rel.get_related_field().column, f.column), + (alias, table, f.get_joining_columns(reverse_join=True)), outer_if_first=True, join_field=f ) from_parent = (opts.model if issubclass(model, opts.model) @@ -662,7 +672,7 @@ class SQLCompiler(object): opts=model._meta, as_pairs=True, from_parent=from_parent) self.query.related_select_cols.extend( SelectInfo(col, field) for col, field - in zip(columns, model._meta.fields)) + in zip(columns, model._meta.concrete_fields)) next = requested.get(f.related_query_name(), {}) # Use True here because we are looking at the _reverse_ side of # the relation, which is always nullable. @@ -706,7 +716,7 @@ class SQLCompiler(object): if self.query.select: fields = [f.field for f in self.query.select] else: - fields = self.query.model._meta.fields + fields = self.query.model._meta.concrete_fields fields = fields + [f.field for f in self.query.related_select_cols] # If the field was deferred, exclude it from being passed @@ -776,6 +786,22 @@ class SQLCompiler(object): return list(result) return result + def as_subquery_condition(self, alias, columns): + qn = self.quote_name_unless_alias + qn2 = self.connection.ops.quote_name + if len(columns) == 1: + sql, params = self.as_sql() + return '%s.%s IN (%s)' % (qn(alias), qn2(columns[0]), sql), params + + for index, select_col in enumerate(self.query.select): + lhs = '%s.%s' % (qn(select_col.col[0]), qn2(select_col.col[1])) + rhs = '%s.%s' % (qn(alias), qn2(columns[index])) + self.query.where.add( + QueryWrapper('%s = %s' % (lhs, rhs), []), 'AND') + + sql, params = self.as_sql() + return 'EXISTS (%s)' % sql, params + class SQLInsertCompiler(SQLCompiler): def placeholder(self, field, val): diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index 81bd646d69..904f7b2c8b 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -25,7 +25,7 @@ GET_ITERATOR_CHUNK_SIZE = 100 # dictionary in the Query class). JoinInfo = namedtuple('JoinInfo', 'table_name rhs_alias join_type lhs_alias ' - 'lhs_join_col rhs_join_col nullable join_field') + 'join_cols nullable join_field') # Pairs of column clauses to select, and (possibly None) field for the clause. SelectInfo = namedtuple('SelectInfo', 'col field') diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 389099161a..62adf79d87 100644 --- a/django/db/models/sql/expressions.py +++ b/django/db/models/sql/expressions.py @@ -55,13 +55,14 @@ class SQLEvaluator(object): self.cols.append((node, query.aggregate_select[node.name])) else: try: - field, source, opts, join_list, path = query.setup_joins( + field, sources, opts, join_list, path = query.setup_joins( field_list, query.get_meta(), query.get_initial_alias(), self.reuse) - target, _, join_list = query.trim_joins(source, join_list, path) + targets, _, join_list = query.trim_joins(sources, join_list, path) if self.reuse is not None: self.reuse.update(join_list) - self.cols.append((node, (join_list[-1], target.column))) + for t in targets: + self.cols.append((node, (join_list[-1], t.column))) except FieldDoesNotExist: raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (self.name, diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 2953d8cdaf..fb42cfc5db 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -452,13 +452,13 @@ class Query(object): # Now, add the joins from rhs query into the new query (skipping base # table). for alias in rhs.tables[1:]: - table, _, join_type, lhs, lhs_col, col, nullable, join_field = rhs.alias_map[alias] + table, _, join_type, lhs, join_cols, nullable, join_field = rhs.alias_map[alias] promote = (join_type == self.LOUTER) # If the left side of the join was already relabeled, use the # updated alias. lhs = change_map.get(lhs, lhs) new_alias = self.join( - (lhs, table, lhs_col, col), reuse=reuse, + (lhs, table, join_cols), reuse=reuse, outer_if_first=not conjunction, nullable=nullable, join_field=join_field) if promote: @@ -682,7 +682,7 @@ class Query(object): aliases = list(aliases) while aliases: alias = aliases.pop(0) - if self.alias_map[alias].rhs_join_col is None: + if self.alias_map[alias].join_cols[0][1] is None: # This is the base table (first FROM entry) - this table # isn't really joined at all in the query, so we should not # alter its join type. @@ -818,7 +818,7 @@ class Query(object): alias = self.tables[0] self.ref_alias(alias) else: - alias = self.join((None, self.model._meta.db_table, None, None)) + alias = self.join((None, self.model._meta.db_table, None)) return alias def count_active_tables(self): @@ -834,11 +834,12 @@ class Query(object): """ Returns an alias for the join in 'connection', either reusing an existing alias for that join or creating a new one. 'connection' is a - tuple (lhs, table, lhs_col, col) where 'lhs' is either an existing - table alias or a table name. The join correspods to the SQL equivalent - of:: + tuple (lhs, table, join_cols) where 'lhs' is either an existing + table alias or a table name. 'join_cols' is a tuple of tuples containing + columns to join on ((l_id1, r_id1), (l_id2, r_id2)). The join corresponds + to the SQL equivalent of:: - lhs.lhs_col = table.col + lhs.l_id1 = table.r_id1 AND lhs.l_id2 = table.r_id2 The 'reuse' parameter can be either None which means all joins (matching the connection) are reusable, or it can be a set containing @@ -855,7 +856,7 @@ class Query(object): The 'join_field' is the field we are joining along (if any). """ - lhs, table, lhs_col, col = connection + lhs, table, join_cols = connection assert lhs is None or join_field is not None existing = self.join_map.get(connection, ()) if reuse is None: @@ -884,7 +885,7 @@ class Query(object): join_type = self.LOUTER else: join_type = self.INNER - join = JoinInfo(table, alias, join_type, lhs, lhs_col, col, nullable, + join = JoinInfo(table, alias, join_type, lhs, join_cols or ((None, None),), nullable, join_field) self.alias_map[alias] = join if connection in self.join_map: @@ -941,7 +942,7 @@ class Query(object): continue link_field = int_opts.get_ancestor_link(int_model) int_opts = int_model._meta - connection = (alias, int_opts.db_table, link_field.column, int_opts.pk.column) + connection = (alias, int_opts.db_table, link_field.get_joining_columns()) alias = seen[int_model] = self.join(connection, nullable=False, join_field=link_field) return alias or seen[None] @@ -982,18 +983,20 @@ class Query(object): # - this is an annotation over a model field # then we need to explore the joins that are required. - field, source, opts, join_list, path = self.setup_joins( + field, sources, opts, join_list, path = self.setup_joins( field_list, opts, self.get_initial_alias()) # Process the join chain to see if it can be trimmed - target, _, join_list = self.trim_joins(source, join_list, path) + targets, _, join_list = self.trim_joins(sources, join_list, path) # If the aggregate references a model or field that requires a join, # those joins must be LEFT OUTER - empty join rows must be returned # in order for zeros to be returned for those aggregates. self.promote_joins(join_list, True) - col = (join_list[-1], target.column) + col = targets[0].column + source = sources[0] + col = (join_list[-1], col) else: # The simplest cases. No joins required - # just reference the provided column alias. @@ -1086,7 +1089,7 @@ class Query(object): allow_many = not branch_negated try: - field, target, opts, join_list, path = self.setup_joins( + field, sources, opts, join_list, path = self.setup_joins( parts, opts, alias, can_reuse, allow_many, allow_explicit_fk=True) if can_reuse is not None: @@ -1106,13 +1109,19 @@ class Query(object): # the far end (fewer tables in a query is better). Note that join # promotion must happen before join trimming to have the join type # information available when reusing joins. - target, alias, join_list = self.trim_joins(target, join_list, path) - clause.add((Constraint(alias, target.column, field), lookup_type, value), - AND) + targets, alias, join_list = self.trim_joins(sources, join_list, path) + + if hasattr(field, 'get_lookup_constraint'): + constraint = field.get_lookup_constraint(self.where_class, alias, targets, sources, + lookup_type, value) + else: + constraint = (Constraint(alias, targets[0].column, field), lookup_type, value) + clause.add(constraint, AND) if current_negated and (lookup_type != 'isnull' or value is False): self.promote_joins(join_list) if (lookup_type != 'isnull' and ( - self.is_nullable(target) or self.alias_map[join_list[-1]].join_type == self.LOUTER)): + self.is_nullable(targets[0]) or + self.alias_map[join_list[-1]].join_type == self.LOUTER)): # The condition added here will be SQL like this: # NOT (col IS NOT NULL), where the first NOT is added in # upper layers of code. The reason for addition is that if col @@ -1122,7 +1131,7 @@ class Query(object): # (col IS NULL OR col != someval) # <=> # NOT (col IS NOT NULL AND col = someval). - clause.add((Constraint(alias, target.column, None), 'isnull', False), AND) + clause.add((Constraint(alias, targets[0].column, None), 'isnull', False), AND) return clause def add_filter(self, filter_clause): @@ -1272,22 +1281,26 @@ class Query(object): opts = int_model._meta else: final_field = opts.parents[int_model] - target = final_field.rel.get_related_field() + targets = (final_field.rel.get_related_field(),) opts = int_model._meta - path.append(PathInfo(final_field, target, final_field.model._meta, - opts, final_field, False, True)) + path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True)) if hasattr(field, 'get_path_info'): - pathinfos, opts, target, final_field = field.get_path_info() + pathinfos = field.get_path_info() if not allow_many: for inner_pos, p in enumerate(pathinfos): if p.m2m: names_with_path.append((name, pathinfos[0:inner_pos + 1])) raise MultiJoin(pos + 1, names_with_path) + last = pathinfos[-1] path.extend(pathinfos) + final_field = last.join_field + opts = last.to_opts + targets = last.target_fields names_with_path.append((name, pathinfos)) else: # Local non-relational field. - final_field = target = field + final_field = field + targets = (field,) break if pos != len(names) - 1: @@ -1297,7 +1310,7 @@ class Query(object): "the lookup type?" % (name, names[pos + 1])) else: raise FieldError("Join on field %r not permitted." % name) - return path, final_field, target + return path, final_field, targets def setup_joins(self, names, opts, alias, can_reuse=None, allow_many=True, allow_explicit_fk=False): @@ -1330,7 +1343,7 @@ class Query(object): """ joins = [alias] # First, generate the path for the names - path, final_field, target = self.names_to_path( + path, final_field, targets = self.names_to_path( names, opts, allow_many, allow_explicit_fk) # Then, add the path to the query's joins. Note that we can't trim # joins at this stage - we will need the information about join type @@ -1338,17 +1351,19 @@ class Query(object): for pos, join in enumerate(path): opts = join.to_opts if join.direct: - nullable = self.is_nullable(join.from_field) + nullable = self.is_nullable(join.join_field) else: nullable = True - connection = alias, opts.db_table, join.from_field.column, join.to_field.column + connection = alias, opts.db_table, join.join_field.get_joining_columns() reuse = can_reuse if join.m2m else None alias = self.join(connection, reuse=reuse, nullable=nullable, join_field=join.join_field) joins.append(alias) - return final_field, target, opts, joins, path + if hasattr(final_field, 'field'): + final_field = final_field.field + return final_field, targets, opts, joins, path - def trim_joins(self, target, joins, path): + def trim_joins(self, targets, joins, path): """ The 'target' parameter is the final field being joined to, 'joins' is the full list of join aliases. The 'path' contain the PathInfos @@ -1362,13 +1377,16 @@ class Query(object): trimmed as we don't know if there is anything on the other side of the join. """ - for info in reversed(path): - if info.to_field == target and info.direct: - target = info.from_field - self.unref_alias(joins.pop()) - else: + for pos, info in enumerate(reversed(path)): + if len(joins) == 1 or not info.direct: break - return target, joins[-1], joins + join_targets = set(t.column for t in info.join_field.foreign_related_fields) + cur_targets = set(t.column for t in targets) + if not cur_targets.issubset(join_targets): + break + targets = tuple(r[0] for r in info.join_field.related_fields if r[1].column in cur_targets) + self.unref_alias(joins.pop()) + return targets, joins[-1], joins def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path): """ @@ -1413,17 +1431,31 @@ class Query(object): trimmed_prefix = [] paths_in_prefix = trimmed_joins for name, path in names_with_path: - if paths_in_prefix - len(path) > 0: - trimmed_prefix.append(name) - paths_in_prefix -= len(path) - else: - trimmed_prefix.append( - path[paths_in_prefix - len(path)].from_field.name) + if paths_in_prefix - len(path) < 0: break + trimmed_prefix.append(name) + paths_in_prefix -= len(path) + join_field = path[paths_in_prefix].join_field + # TODO: This should be made properly multicolumn + # join aware. It is likely better to not use build_filter + # at all, instead construct joins up to the correct point, + # then construct the needed equality constraint manually, + # or maybe using SubqueryConstraint would work, too. + # The foreign_related_fields attribute is right here, we + # don't ever split joins for direct case. + trimmed_prefix.append( + join_field.field.foreign_related_fields[0].name) trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix) - return self.build_filter( + condition = self.build_filter( ('%s__in' % trimmed_prefix, query), current_negated=True, branch_negated=True, can_reuse=can_reuse) + # Intentionally leave the other alias as blank, if the condition + # refers it, things will break here. + extra_restriction = join_field.get_extra_restriction( + self.where_class, None, [t for t in query.tables if query.alias_refcount[t]][0]) + if extra_restriction: + query.where.add(extra_restriction, 'AND') + return condition def set_empty(self): self.where = EmptyWhere() @@ -1502,20 +1534,17 @@ class Query(object): try: for name in field_names: - field, target, u2, joins, u3 = self.setup_joins( + field, targets, u2, joins, path = self.setup_joins( name.split(LOOKUP_SEP), opts, alias, None, allow_m2m, True) - final_alias = joins[-1] - col = target.column - if len(joins) > 1: - join = self.alias_map[final_alias] - if col == join.rhs_join_col: - self.unref_alias(final_alias) - final_alias = join.lhs_alias - col = join.lhs_join_col - joins = joins[:-1] + + # Trim last join if possible + targets, final_alias, remaining_joins = self.trim_joins(targets, joins[-2:], path) + joins = joins[:-2] + remaining_joins + self.promote_joins(joins[1:]) - self.select.append(SelectInfo((final_alias, col), field)) + for target in targets: + self.select.append(SelectInfo((final_alias, target.column), target)) except MultiJoin: raise FieldError("Invalid field name: '%s'" % name) except FieldError: @@ -1590,7 +1619,7 @@ class Query(object): opts = self.model._meta if not self.select: count = self.aggregates_module.Count( - (self.join((None, opts.db_table, None, None)), opts.pk.column), + (self.join((None, opts.db_table, None)), opts.pk.column), is_summary=True, distinct=True) else: # Because of SQL portability issues, multi-column, distinct @@ -1792,22 +1821,27 @@ class Query(object): in "WHERE somecol IN (subquery)". This construct is needed by split_exclude(). _""" - join_pos = 0 + all_paths = [] for _, paths in names_with_path: - for path in paths: - peek = self.tables[join_pos + 1] - if self.alias_map[peek].join_type == self.LOUTER: - # Back up one level and break - select_alias = self.tables[join_pos] - select_field = path.from_field - break - select_alias = self.tables[join_pos + 1] - select_field = path.to_field - self.unref_alias(self.tables[join_pos]) - join_pos += 1 - self.select = [SelectInfo((select_alias, select_field.column), select_field)] + all_paths.extend(paths) + direct_join = True + for pos, path in enumerate(all_paths): + if self.alias_map[self.tables[pos + 1]].join_type == self.LOUTER: + direct_join = False + pos -= 1 + break + self.unref_alias(self.tables[pos]) + if path.direct: + direct_join = not direct_join + join_side = 0 if direct_join else 1 + select_alias = self.tables[pos + 1] + join_field = path.join_field + if hasattr(join_field, 'field'): + join_field = join_field.field + select_fields = [r[join_side] for r in join_field.related_fields] + self.select = [SelectInfo((select_alias, f.column), f) for f in select_fields] self.remove_inherited_models() - return join_pos + return pos def is_nullable(self, field): """ diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 42682c342e..c738c914d1 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -382,3 +382,28 @@ class Constraint(object): new.__class__ = self.__class__ new.alias, new.col, new.field = change_map[self.alias], self.col, self.field return new + +class SubqueryConstraint(object): + def __init__(self, alias, columns, targets, query_object): + self.alias = alias + self.columns = columns + self.targets = targets + self.query_object = query_object + + def as_sql(self, qn, connection): + query = self.query_object + + # QuerySet was sent + if hasattr(query, 'values'): + # as_sql should throw if we are using a + # connection on another database + query._as_sql(connection=connection) + query = query.values(*self.targets).query + + query_compiler = query.get_compiler(connection=connection) + return query_compiler.as_subquery_condition(self.alias, self.columns) + + def relabeled_clone(self, relabels): + return self.__class__( + relabels.get(self.alias, self.alias), + self.columns, self.query_object) diff --git a/django/forms/models.py b/django/forms/models.py index 0672bafc47..39d753b1a6 100644 --- a/django/forms/models.py +++ b/django/forms/models.py @@ -110,7 +110,7 @@ def model_to_dict(instance, fields=None, exclude=None): from django.db.models.fields.related import ManyToManyField opts = instance._meta data = {} - for f in opts.fields + opts.many_to_many: + for f in opts.concrete_fields + opts.many_to_many: if not f.editable: continue if fields and not f.name in fields: @@ -149,7 +149,7 @@ def fields_for_model(model, fields=None, exclude=None, widgets=None, formfield_c field_list = [] ignored = [] opts = model._meta - for f in sorted(opts.fields + opts.many_to_many): + for f in sorted(opts.concrete_fields + opts.many_to_many): if not f.editable: continue if fields is not None and not f.name in fields: