Fixed #17002 -- Allowed using a ManyToManyField through model that inherits another.

This commit is contained in:
InvalidInterrupt 2016-08-18 15:42:11 -07:00 committed by Tim Graham
parent b5f0b3478d
commit 98359109eb
7 changed files with 147 additions and 17 deletions

View File

@ -263,6 +263,7 @@ answer newbie questions, and generally made Django that much better:
Gabriel Grant <g@briel.ca>
Gabriel Hurley <gabriel@strikeawe.com>
gandalf@owca.info
Garry Lawrence
Garry Polley <garrympolley@gmail.com>
Garth Kidd <http://www.deadlybloodyserious.com/>
Gary Wilson <gary.wilson@gmail.com>

View File

@ -1529,7 +1529,21 @@ class ManyToManyField(RelatedField):
else:
join1infos = linkfield2.get_reverse_path_info()
join2infos = linkfield1.get_path_info()
# Get join infos between the last model of join 1 and the first model
# of join 2. Assume the only reason these may differ is due to model
# inheritance.
join1_final = join1infos[-1].to_opts
join2_initial = join2infos[0].from_opts
if join1_final is join2_initial:
intermediate_infos = []
elif issubclass(join1_final.model, join2_initial.model):
intermediate_infos = join1_final.get_path_to_parent(join2_initial.model)
else:
intermediate_infos = join2_initial.get_path_from_parent(join1_final.model)
pathinfos.extend(join1infos)
pathinfos.extend(intermediate_infos)
pathinfos.extend(join2infos)
return pathinfos

View File

@ -895,7 +895,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
# For non-autocreated 'through' models, can't assume we are
# dealing with PK values.
fk = self.through._meta.get_field(self.source_field_name)
join_table = self.through._meta.db_table
join_table = fk.model._meta.db_table
connection = connections[queryset.db]
qn = connection.ops.quote_name
queryset = queryset.extra(select={

View File

@ -14,6 +14,7 @@ from django.db.models import Manager
from django.db.models.fields import AutoField
from django.db.models.fields.proxy import OrderWrt
from django.db.models.fields.related import OneToOneField
from django.db.models.query_utils import PathInfo
from django.utils import six
from django.utils.datastructures import ImmutableList, OrderedSet
from django.utils.deprecation import (
@ -670,6 +671,50 @@ class Options(object):
# links
return self.parents[parent] or parent_link
def get_path_to_parent(self, parent):
"""
Return a list of PathInfos containing the path from the current
model to the parent model, or an empty list if parent is not a
parent of the current model.
"""
if self.model is parent:
return []
# Skip the chain of proxy to the concrete proxied model.
proxied_model = self.concrete_model
path = []
opts = self
for int_model in self.get_base_chain(parent):
if int_model is proxied_model:
opts = int_model._meta
else:
final_field = opts.parents[int_model]
targets = (final_field.remote_field.get_related_field(),)
opts = int_model._meta
path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True))
return path
def get_path_from_parent(self, parent):
"""
Return a list of PathInfos containing the path from the parent
model to the current model, or an empty list if parent is not a
parent of the current model.
"""
if self.model is parent:
return []
model = self.concrete_model
# Get a reversed base chain including both the current and parent
# models.
chain = model._meta.get_base_chain(parent)
chain.reverse()
chain.append(model)
# Construct a list of the PathInfos between models in chain.
path = []
for i, ancestor in enumerate(chain[:-1]):
child = chain[i + 1]
link = child._meta.get_ancestor_link(ancestor)
path.extend(link.get_reverse_path_info())
return path
def _populate_directed_relation_graph(self):
"""
This method is used by each model to find its reverse objects. As this

View File

@ -19,7 +19,7 @@ from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Col, Ref
from django.db.models.fields.related_lookups import MultiColSource
from django.db.models.query_utils import (
PathInfo, Q, check_rel_lookup_compatibility, refs_expression,
Q, check_rel_lookup_compatibility, refs_expression,
)
from django.db.models.sql.constants import (
INNER, LOUTER, ORDER_DIR, ORDER_PATTERN, QUERY_TERMS, SINGLE,
@ -1342,21 +1342,11 @@ class Query(object):
# field lives in parent, but we are currently in one of its
# children)
if model is not opts.model:
# The field lives on a base class of the current model.
# Skip the chain of proxy to the concrete proxied model
proxied_model = opts.concrete_model
for int_model in opts.get_base_chain(model):
if int_model is proxied_model:
opts = int_model._meta
else:
final_field = opts.parents[int_model]
targets = (final_field.remote_field.get_related_field(),)
opts = int_model._meta
path.append(PathInfo(final_field.model._meta, opts, targets, final_field, False, True))
cur_names_with_path[1].append(
PathInfo(final_field.model._meta, opts, targets, final_field, False, True)
)
path_to_parent = opts.get_path_to_parent(model)
if path_to_parent:
path.extend(path_to_parent)
cur_names_with_path[1].extend(path_to_parent)
opts = path_to_parent[-1].to_opts
if hasattr(field, 'get_path_info'):
pathinfos = field.get_path_info()
if not allow_many:

View File

@ -94,3 +94,32 @@ class CarDriver(models.Model):
def __str__(self):
return "pk=%s car=%s driver=%s" % (str(self.pk), self.car, self.driver)
# Through models using multi-table inheritance
class Event(models.Model):
name = models.CharField(max_length=50, unique=True)
people = models.ManyToManyField('Person', through='IndividualCompetitor')
special_people = models.ManyToManyField(
'Person',
through='ProxiedIndividualCompetitor',
related_name='special_event_set',
)
teams = models.ManyToManyField('Group', through='CompetingTeam')
class Competitor(models.Model):
event = models.ForeignKey(Event, models.CASCADE)
class IndividualCompetitor(Competitor):
person = models.ForeignKey(Person, models.CASCADE)
class CompetingTeam(Competitor):
team = models.ForeignKey(Group, models.CASCADE)
class ProxiedIndividualCompetitor(IndividualCompetitor):
class Meta:
proxy = True

View File

@ -0,0 +1,51 @@
from __future__ import unicode_literals
from django.test import TestCase
from .models import (
CompetingTeam, Event, Group, IndividualCompetitor, Membership, Person,
)
class MultiTableTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.alice = Person.objects.create(name='Alice')
cls.bob = Person.objects.create(name='Bob')
cls.chris = Person.objects.create(name='Chris')
cls.dan = Person.objects.create(name='Dan')
cls.team_alpha = Group.objects.create(name='Alpha')
Membership.objects.create(person=cls.alice, group=cls.team_alpha)
Membership.objects.create(person=cls.bob, group=cls.team_alpha)
cls.event = Event.objects.create(name='Exposition Match')
IndividualCompetitor.objects.create(event=cls.event, person=cls.chris)
IndividualCompetitor.objects.create(event=cls.event, person=cls.dan)
CompetingTeam.objects.create(event=cls.event, team=cls.team_alpha)
def test_m2m_query(self):
result = self.event.teams.all()
self.assertCountEqual(result, [self.team_alpha])
def test_m2m_reverse_query(self):
result = self.chris.event_set.all()
self.assertCountEqual(result, [self.event])
def test_m2m_query_proxied(self):
result = self.event.special_people.all()
self.assertCountEqual(result, [self.chris, self.dan])
def test_m2m_reverse_query_proxied(self):
result = self.chris.special_event_set.all()
self.assertCountEqual(result, [self.event])
def test_m2m_prefetch_proxied(self):
result = Event.objects.filter(name='Exposition Match').prefetch_related('special_people')
with self.assertNumQueries(2):
self.assertCountEqual(result, [self.event])
self.assertEqual(sorted([p.name for p in result[0].special_people.all()]), ['Chris', 'Dan'])
def test_m2m_prefetch_reverse_proxied(self):
result = Person.objects.filter(name='Dan').prefetch_related('special_event_set')
with self.assertNumQueries(2):
self.assertCountEqual(result, [self.dan])
self.assertEqual([event.name for event in result[0].special_event_set.all()], ['Exposition Match'])