diff --git a/django/contrib/contenttypes/generic.py b/django/contrib/contenttypes/generic.py index 6aff07e568..be7a5e5a22 100644 --- a/django/contrib/contenttypes/generic.py +++ b/django/contrib/contenttypes/generic.py @@ -11,6 +11,7 @@ from django.db import connection from django.db.models import signals from django.db import models, router, DEFAULT_DB_ALIAS from django.db.models.fields.related import RelatedField, Field, ManyToManyRel +from django.db.models.related import PathInfo from django.forms import ModelForm from django.forms.models import BaseModelFormSet, modelformset_factory, save_instance from django.contrib.admin.options import InlineModelAdmin, flatten_fieldsets @@ -160,6 +161,16 @@ class GenericRelation(RelatedField, Field): kwargs['serialize'] = False Field.__init__(self, **kwargs) + def get_path_info(self): + from_field = self.model._meta.pk + opts = self.rel.to._meta + target = opts.get_field_by_name(self.object_id_field_name)[0] + # Note that we are using different field for the join_field + # than from_field or to_field. This is a hack, but we need the + # GenericRelation to generate the extra SQL. + return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)], + opts, target, self) + def get_choices_default(self): return Field.get_choices(self, include_blank=False) diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 90fe69e23c..4b6a5b0aed 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -5,7 +5,7 @@ from django.db.backends import util from django.db.models import signals, get_model from django.db.models.fields import (AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) -from django.db.models.related import RelatedObject +from django.db.models.related import RelatedObject, PathInfo from django.db.models.query import QuerySet from django.db.models.query_utils import QueryWrapper from django.db.models.deletion import CASCADE @@ -16,7 +16,6 @@ from django.utils.functional import curry, cached_property from django.core import exceptions from django import forms - RECURSIVE_RELATIONSHIP_CONSTANT = 'self' pending_lookups = {} @@ -1004,6 +1003,31 @@ class ForeignKey(RelatedField, Field): ) Field.__init__(self, **kwargs) + def get_path_info(self): + """ + Get path from this field to the related model. + """ + opts = self.rel.to._meta + target = self.rel.get_related_field() + from_opts = self.model._meta + return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self + + def get_reverse_path_info(self): + """ + Get path from the related model to this field's model. + """ + opts = self.model._meta + from_field = self.rel.get_related_field() + from_opts = from_field.model._meta + pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)] + if from_field.model is self.model: + # Recursive foreign key to self. + target = opts.get_field_by_name( + self.rel.field_name)[0] + else: + target = opts.pk + return pathinfos, opts, target, self + def validate(self, value, model_instance): if self.rel.parent_link: return @@ -1198,6 +1222,30 @@ class ManyToManyField(RelatedField, Field): msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.') self.help_text = string_concat(self.help_text, ' ', msg) + def _get_path_info(self, direct=False): + """ + Called by both direct an indirect m2m traversal. + """ + pathinfos = [] + int_model = self.rel.through + linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0] + linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0] + if direct: + join1infos, _, _, _ = linkfield1.get_reverse_path_info() + join2infos, opts, target, final_field = linkfield2.get_path_info() + else: + join1infos, _, _, _ = linkfield2.get_reverse_path_info() + join2infos, opts, target, final_field = linkfield1.get_path_info() + pathinfos.extend(join1infos) + pathinfos.extend(join2infos) + return pathinfos, opts, target, final_field + + def get_path_info(self): + return self._get_path_info(direct=True) + + def get_reverse_path_info(self): + return self._get_path_info(direct=False) + def get_choices_default(self): return Field.get_choices(self, include_blank=False) diff --git a/django/db/models/related.py b/django/db/models/related.py index a0dcec7132..702853533d 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -1,6 +1,15 @@ +from collections import namedtuple + from django.utils.encoding import smart_text from django.db.models.fields import BLANK_CHOICE_DASH +# PathInfo is used when converting lookups (fk__somecol). The contents +# describe the relation in Model terms (model Options and Fields for both +# sides of the relation. The join_field is the field backing the relation. +PathInfo = namedtuple('PathInfo', + 'from_field to_field from_opts to_opts join_field ' + 'm2m direct') + class BoundRelatedObject(object): def __init__(self, related_object, field_mapping, original): self.relation = related_object @@ -67,3 +76,6 @@ class RelatedObject(object): def get_cache_name(self): return "_%s_cache" % self.get_accessor_name() + + def get_path_info(self): + return self.field.get_reverse_path_info() diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py index 9f82f426ed..1764db7fcc 100644 --- a/django/db/models/sql/constants.py +++ b/django/db/models/sql/constants.py @@ -26,12 +26,6 @@ JoinInfo = namedtuple('JoinInfo', 'table_name rhs_alias join_type lhs_alias ' 'lhs_join_col rhs_join_col nullable join_field') -# PathInfo is used when converting lookups (fk__somecol). The contents -# describe the join in Model terms (model Options and Fields for both -# sides of the join. The rel_field is the field we are joining along. -PathInfo = namedtuple('PathInfo', - 'from_field to_field from_opts to_opts join_field') - # Pairs of column clauses to select, and (possibly None) field for the clause. SelectInfo = namedtuple('SelectInfo', 'col field') diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 841452636b..e5833b2b51 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -18,9 +18,10 @@ from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import ExpressionNode from django.db.models.fields import FieldDoesNotExist from django.db.models.loading import get_model +from django.db.models.related import PathInfo from django.db.models.sql import aggregates as base_aggregates_module from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE, - ORDER_PATTERN, JoinInfo, SelectInfo, PathInfo) + ORDER_PATTERN, JoinInfo, SelectInfo) from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin from django.db.models.sql.expressions import SQLEvaluator from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode, @@ -1294,7 +1295,6 @@ class Query(object): contain the same value as the final field). """ path = [] - multijoin_pos = None for pos, name in enumerate(names): if name == 'pk': name = opts.pk.name @@ -1328,92 +1328,19 @@ class Query(object): target = final_field.rel.get_related_field() opts = int_model._meta path.append(PathInfo(final_field, target, final_field.model._meta, - opts, final_field)) - # We have five different cases to solve: foreign keys, reverse - # foreign keys, m2m fields (also reverse) and non-relational - # fields. We are mostly just using the related field API to - # fetch the from and to fields. The m2m fields are handled as - # two foreign keys, first one reverse, the second one direct. - if direct and not field.rel and not m2m: + opts, final_field, False, True)) + if hasattr(field, 'get_path_info'): + pathinfos, opts, target, final_field = field.get_path_info() + path.extend(pathinfos) + else: # Local non-relational field. final_field = target = field break - elif direct and not m2m: - # Foreign Key - opts = field.rel.to._meta - target = field.rel.get_related_field() - final_field = field - from_opts = field.model._meta - path.append(PathInfo(field, target, from_opts, opts, field)) - elif not direct and not m2m: - # Revere foreign key - final_field = to_field = field.field - opts = to_field.model._meta - from_field = to_field.rel.get_related_field() - from_opts = from_field.model._meta - path.append( - PathInfo(from_field, to_field, from_opts, opts, to_field)) - if from_field.model is to_field.model: - # Recursive foreign key to self. - target = opts.get_field_by_name( - field.field.rel.field_name)[0] - else: - target = opts.pk - elif direct and m2m: - if not field.rel.through: - # Gotcha! This is just a fake m2m field - a generic relation - # field). - from_field = opts.pk - opts = field.rel.to._meta - target = opts.get_field_by_name(field.object_id_field_name)[0] - final_field = field - # Note that we are using different field for the join_field - # than from_field or to_field. This is a hack, but we need the - # GenericRelation to generate the extra SQL. - path.append(PathInfo(from_field, target, field.model._meta, opts, - field)) - else: - # m2m field. We are travelling first to the m2m table along a - # reverse relation, then from m2m table to the target table. - from_field1 = opts.get_field_by_name( - field.m2m_target_field_name())[0] - opts = field.rel.through._meta - to_field1 = opts.get_field_by_name(field.m2m_field_name())[0] - path.append( - PathInfo(from_field1, to_field1, from_field1.model._meta, - opts, to_field1)) - final_field = from_field2 = opts.get_field_by_name( - field.m2m_reverse_field_name())[0] - opts = field.rel.to._meta - target = to_field2 = opts.get_field_by_name( - field.m2m_reverse_target_field_name())[0] - path.append( - PathInfo(from_field2, to_field2, from_field2.model._meta, - opts, from_field2)) - elif not direct and m2m: - # This one is just like above, except we are travelling the - # fields in opposite direction. - field = field.field - from_field1 = opts.get_field_by_name( - field.m2m_reverse_target_field_name())[0] - int_opts = field.rel.through._meta - to_field1 = int_opts.get_field_by_name( - field.m2m_reverse_field_name())[0] - path.append( - PathInfo(from_field1, to_field1, from_field1.model._meta, - int_opts, to_field1)) - final_field = from_field2 = int_opts.get_field_by_name( - field.m2m_field_name())[0] - opts = field.opts - target = to_field2 = opts.get_field_by_name( - field.m2m_target_field_name())[0] - path.append(PathInfo(from_field2, to_field2, from_field2.model._meta, - opts, from_field2)) - - if m2m and multijoin_pos is None: - multijoin_pos = pos - if not direct and not path[-1].to_field.unique and multijoin_pos is None: - multijoin_pos = pos + multijoin_pos = None + for m2mpos, pathinfo in enumerate(path): + if pathinfo.m2m: + multijoin_pos = m2mpos + break if pos != len(names) - 1: if pos == len(names) - 2: @@ -1463,16 +1390,15 @@ class Query(object): # joins at this stage - we will need the information about join type # of the trimmed joins. for pos, join in enumerate(path): - from_field, to_field, from_opts, opts, join_field = join - direct = join_field == from_field - if direct: - nullable = self.is_nullable(from_field) + opts = join.to_opts + if join.direct: + nullable = self.is_nullable(join.from_field) else: nullable = True - connection = alias, opts.db_table, from_field.column, to_field.column - reuse = None if direct or to_field.unique else can_reuse + connection = alias, opts.db_table, join.from_field.column, join.to_field.column + reuse = can_reuse if join.m2m else None alias = self.join(connection, reuse=reuse, - nullable=nullable, join_field=join_field) + nullable=nullable, join_field=join.join_field) joins.append(alias) return final_field, target, opts, joins, path