Merged ManyRelatedObjectsDescriptor and ReverseManyRelatedObjectsDescriptor

and made all "many" related objects descriptors inherit from
ForeignRelatedObjectsDescriptor.
This commit is contained in:
Loic Bistuer 2015-02-16 16:03:34 +07:00
parent d652906aeb
commit c5a77721e2
4 changed files with 124 additions and 224 deletions

View File

@ -8,9 +8,12 @@ from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction
from django.db.models import DO_NOTHING, signals
from django.db.models.base import ModelBase
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.fields.related import (
ForeignObject, ForeignObjectRel, ForeignRelatedObjectsDescriptor,
)
from django.db.models.query_utils import PathInfo
from django.utils.encoding import python_2_unicode_compatible, smart_text
from django.utils.functional import cached_property
@python_2_unicode_compatible
@ -358,7 +361,7 @@ class GenericRelation(ForeignObject):
# Save a reference to which model this class is on for future use
self.model = cls
# Add the descriptor for the relation
setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self, self.for_concrete_model))
setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self.rel))
def set_attributes_from_rel(self):
pass
@ -393,7 +396,7 @@ class GenericRelation(ForeignObject):
})
class ReverseGenericRelatedObjectsDescriptor(object):
class ReverseGenericRelatedObjectsDescriptor(ForeignRelatedObjectsDescriptor):
"""
This class provides the functionality that makes the related-object
managers available as attributes on a model class, for fields that have
@ -402,87 +405,57 @@ class ReverseGenericRelatedObjectsDescriptor(object):
"article.publications", the publications attribute is a
ReverseGenericRelatedObjectsDescriptor instance.
"""
def __init__(self, field, for_concrete_model=True):
self.field = field
self.for_concrete_model = for_concrete_model
def __get__(self, instance, instance_type=None):
if instance is None:
return self
# Dynamically create a class that subclasses the related model's
# default manager.
rel_model = self.field.rel.to
superclass = rel_model._default_manager.__class__
RelatedManager = create_generic_related_manager(superclass)
qn = connection.ops.quote_name
content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(
instance, for_concrete_model=self.for_concrete_model)
join_cols = self.field.get_joining_columns(reverse_join=True)[0]
manager = RelatedManager(
model=rel_model,
instance=instance,
source_col_name=qn(join_cols[0]),
target_col_name=qn(join_cols[1]),
content_type=content_type,
content_type_field_name=self.field.content_type_field_name,
object_id_field_name=self.field.object_id_field_name,
prefetch_cache_name=self.field.attname,
@cached_property
def related_manager_cls(self):
return create_generic_related_manager(
self.rel.to._default_manager.__class__,
self.rel,
)
return manager
def __set__(self, instance, value):
manager = self.__get__(instance)
manager.set(value)
def create_generic_related_manager(superclass):
def create_generic_related_manager(superclass, rel):
"""
Factory function for a manager that subclasses 'superclass' (which is a
Manager) and adds behavior for generic related objects.
"""
class GenericRelatedObjectManager(superclass):
def __init__(self, model=None, instance=None, symmetrical=None,
source_col_name=None, target_col_name=None, content_type=None,
content_type_field_name=None, object_id_field_name=None,
prefetch_cache_name=None):
def __init__(self, instance=None):
super(GenericRelatedObjectManager, self).__init__()
self.model = model
self.content_type = content_type
self.symmetrical = symmetrical
self.instance = instance
self.source_col_name = source_col_name
self.target_col_name = target_col_name
self.content_type_field_name = content_type_field_name
self.object_id_field_name = object_id_field_name
self.prefetch_cache_name = prefetch_cache_name
self.pk_val = self.instance._get_pk_val()
self.model = rel.to
content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(
instance, for_concrete_model=rel.field.for_concrete_model)
self.content_type = content_type
qn = connection.ops.quote_name
join_cols = rel.field.get_joining_columns(reverse_join=True)[0]
self.source_col_name = qn(join_cols[0])
self.target_col_name = qn(join_cols[1])
self.content_type_field_name = rel.field.content_type_field_name
self.object_id_field_name = rel.field.object_id_field_name
self.prefetch_cache_name = rel.field.attname
self.pk_val = instance._get_pk_val()
self.core_filters = {
'%s__pk' % content_type_field_name: content_type.id,
'%s' % object_id_field_name: instance._get_pk_val(),
'%s__pk' % self.content_type_field_name: content_type.id,
self.object_id_field_name: self.pk_val,
}
def __call__(self, **kwargs):
# We use **kwargs rather than a kwarg argument to enforce the
# `manager='manager_name'` syntax.
manager = getattr(self.model, kwargs.pop('manager'))
manager_class = create_generic_related_manager(manager.__class__)
return manager_class(
model=self.model,
instance=self.instance,
symmetrical=self.symmetrical,
source_col_name=self.source_col_name,
target_col_name=self.target_col_name,
content_type=self.content_type,
content_type_field_name=self.content_type_field_name,
object_id_field_name=self.object_id_field_name,
prefetch_cache_name=self.prefetch_cache_name,
)
manager_class = create_generic_related_manager(manager.__class__, rel)
return manager_class(instance=self.instance)
do_not_call_in_templates = True
def __str__(self):

View File

@ -678,25 +678,28 @@ class ReverseSingleRelatedObjectDescriptor(object):
setattr(value, self.field.rel.get_cache_name(), instance)
def create_foreign_related_manager(superclass, rel_field, rel_model):
def create_foreign_related_manager(superclass, rel):
class RelatedManager(superclass):
def __init__(self, instance):
super(RelatedManager, self).__init__()
self.instance = instance
self.core_filters = {rel_field.name: instance}
self.model = rel_model
self.model = rel.related_model
self.field = rel.field
self.core_filters = {self.field.name: instance}
def __call__(self, **kwargs):
# We use **kwargs rather than a kwarg argument to enforce the
# `manager='manager_name'` syntax.
manager = getattr(self.model, kwargs.pop('manager'))
manager_class = create_foreign_related_manager(manager.__class__, rel_field, rel_model)
manager_class = create_foreign_related_manager(manager.__class__, rel)
return manager_class(self.instance)
do_not_call_in_templates = True
def get_queryset(self):
try:
return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
return self.instance._prefetched_objects_cache[self.field.related_query_name()]
except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance)
empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
@ -705,11 +708,11 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
if self._db:
qs = qs.using(self._db)
qs = qs.filter(**self.core_filters)
for field in rel_field.foreign_related_fields:
for field in self.field.foreign_related_fields:
val = getattr(self.instance, field.attname)
if val is None or (val == '' and empty_strings_as_null):
return qs.none()
qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
qs._known_related_objects = {self.field: {self.instance.pk: self.instance}}
return qs
def get_prefetch_queryset(self, instances, queryset=None):
@ -719,18 +722,18 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db)
rel_obj_attr = rel_field.get_local_related_value
instance_attr = rel_field.get_foreign_related_value
rel_obj_attr = self.field.get_local_related_value
instance_attr = self.field.get_foreign_related_value
instances_dict = {instance_attr(inst): inst for inst in instances}
query = {'%s__in' % rel_field.name: instances}
query = {'%s__in' % self.field.name: instances}
queryset = queryset.filter(**query)
# Since we just bypassed this class' get_queryset(), we must manage
# the reverse relation manually.
for rel_obj in queryset:
instance = instances_dict[rel_obj_attr(rel_obj)]
setattr(rel_obj, rel_field.name, instance)
cache_name = rel_field.related_query_name()
setattr(rel_obj, self.field.name, instance)
cache_name = self.field.related_query_name()
return queryset, rel_obj_attr, instance_attr, False, cache_name
def add(self, *objs):
@ -741,42 +744,42 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
if not isinstance(obj, self.model):
raise TypeError("'%s' instance expected, got %r" %
(self.model._meta.object_name, obj))
setattr(obj, rel_field.name, self.instance)
setattr(obj, self.field.name, self.instance)
obj.save()
add.alters_data = True
def create(self, **kwargs):
kwargs[rel_field.name] = self.instance
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).create(**kwargs)
create.alters_data = True
def get_or_create(self, **kwargs):
kwargs[rel_field.name] = self.instance
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
get_or_create.alters_data = True
def update_or_create(self, **kwargs):
kwargs[rel_field.name] = self.instance
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).update_or_create(**kwargs)
update_or_create.alters_data = True
# remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel_field.null:
if rel.field.null:
def remove(self, *objs, **kwargs):
if not objs:
return
bulk = kwargs.pop('bulk', True)
val = rel_field.get_foreign_related_value(self.instance)
val = self.field.get_foreign_related_value(self.instance)
old_ids = set()
for obj in objs:
# Is obj actually part of this descriptor set?
if rel_field.get_local_related_value(obj) == val:
if self.field.get_local_related_value(obj) == val:
old_ids.add(obj.pk)
else:
raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
raise self.field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
self._clear(self.filter(pk__in=old_ids), bulk)
remove.alters_data = True
@ -790,12 +793,12 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
queryset = queryset.using(db)
if bulk:
# `QuerySet.update()` is intrinsically atomic.
queryset.update(**{rel_field.name: None})
queryset.update(**{self.field.name: None})
else:
with transaction.atomic(using=db, savepoint=False):
for obj in queryset:
setattr(obj, rel_field.name, None)
obj.save(update_fields=[rel_field.name])
setattr(obj, self.field.name, None)
obj.save(update_fields=[self.field.name])
_clear.alters_data = True
def set(self, objs, **kwargs):
@ -805,7 +808,7 @@ def create_foreign_related_manager(superclass, rel_field, rel_model):
clear = kwargs.pop('clear', False)
if rel_field.null:
if self.field.null:
db = router.db_for_write(self.model, instance=self.instance)
with transaction.atomic(using=db, savepoint=False):
if clear:
@ -835,8 +838,16 @@ class ForeignRelatedObjectsDescriptor(object):
# multiple "remote" values and have a ForeignKey pointed at them by
# some other model. In the example "poll.choice_set", the choice_set
# attribute is a ForeignRelatedObjectsDescriptor instance.
def __init__(self, related):
self.related = related # RelatedObject instance
def __init__(self, rel):
self.rel = rel
self.field = rel.field
@cached_property
def related_manager_cls(self):
return create_foreign_related_manager(
self.rel.related_model._default_manager.__class__,
self.rel,
)
def __get__(self, instance, instance_type=None):
if instance is None:
@ -848,49 +859,47 @@ class ForeignRelatedObjectsDescriptor(object):
manager = self.__get__(instance)
manager.set(value)
@cached_property
def related_manager_cls(self):
# Dynamically create a class that subclasses the related model's default
# manager.
return create_foreign_related_manager(
self.related.related_model._default_manager.__class__,
self.related.field,
self.related.related_model,
)
def create_many_related_manager(superclass, rel, reverse):
def create_many_related_manager(superclass, rel):
"""Creates a manager that subclasses 'superclass' (which is a Manager)
and adds behavior for many-to-many related objects."""
class ManyRelatedManager(superclass):
def __init__(self, model=None, query_field_name=None, instance=None, symmetrical=None,
source_field_name=None, target_field_name=None, reverse=False,
through=None, prefetch_cache_name=None):
def __init__(self, instance=None):
super(ManyRelatedManager, self).__init__()
self.model = model
self.query_field_name = query_field_name
source_field = through._meta.get_field(source_field_name)
source_related_fields = source_field.related_fields
self.core_filters = {}
for lh_field, rh_field in source_related_fields:
self.core_filters['%s__%s' % (query_field_name, rh_field.name)] = getattr(instance, rh_field.attname)
self.instance = instance
self.symmetrical = symmetrical
self.source_field = source_field
self.target_field = through._meta.get_field(target_field_name)
self.source_field_name = source_field_name
self.target_field_name = target_field_name
if not reverse:
self.model = rel.to
self.query_field_name = rel.field.related_query_name()
self.prefetch_cache_name = rel.field.name
self.source_field_name = rel.field.m2m_field_name()
self.target_field_name = rel.field.m2m_reverse_field_name()
self.symmetrical = rel.symmetrical
else:
self.model = rel.related_model
self.query_field_name = rel.field.name
self.prefetch_cache_name = rel.field.related_query_name()
self.source_field_name = rel.field.m2m_reverse_field_name()
self.target_field_name = rel.field.m2m_field_name()
self.symmetrical = False
self.through = rel.through
self.reverse = reverse
self.through = through
self.prefetch_cache_name = prefetch_cache_name
self.related_val = source_field.get_foreign_related_value(instance)
self.source_field = self.through._meta.get_field(self.source_field_name)
self.target_field = self.through._meta.get_field(self.target_field_name)
self.core_filters = {}
for lh_field, rh_field in self.source_field.related_fields:
self.core_filters['%s__%s' % (self.query_field_name, rh_field.name)] = getattr(instance, rh_field.attname)
self.related_val = self.source_field.get_foreign_related_value(instance)
if None in self.related_val:
raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' %
(instance, source_field_name))
(instance, self.source_field_name))
# Even if this relation is not to pk, we require still pk value.
# The wish is that the instance has been already saved to DB,
# although having a pk value isn't a guarantee of that.
@ -903,18 +912,8 @@ def create_many_related_manager(superclass, rel):
# We use **kwargs rather than a kwarg argument to enforce the
# `manager='manager_name'` syntax.
manager = getattr(self.model, kwargs.pop('manager'))
manager_class = create_many_related_manager(manager.__class__, rel)
return manager_class(
model=self.model,
query_field_name=self.query_field_name,
instance=self.instance,
symmetrical=self.symmetrical,
source_field_name=self.source_field_name,
target_field_name=self.target_field_name,
reverse=self.reverse,
through=self.through,
prefetch_cache_name=self.prefetch_cache_name,
)
manager_class = create_many_related_manager(manager.__class__, rel, reverse)
return manager_class(instance=self.instance)
do_not_call_in_templates = True
def _build_remove_filters(self, removed_vals):
@ -1198,107 +1197,38 @@ def create_many_related_manager(superclass, rel):
return ManyRelatedManager
class ManyRelatedObjectsDescriptor(object):
# This class provides the functionality that makes the related-object
# managers available as attributes on a model class, for fields that have
# multiple "remote" values and have a ManyToManyField pointed at them by
# some other model (rather than having a ManyToManyField themselves).
# In the example "publication.article_set", the article_set attribute is a
# ManyRelatedObjectsDescriptor instance.
def __init__(self, related):
self.related = related # RelatedObject instance
@cached_property
def related_manager_cls(self):
# Dynamically create a class that subclasses the related
# model's default manager.
return create_many_related_manager(
self.related.related_model._default_manager.__class__,
self.related.field.rel
)
def __get__(self, instance, instance_type=None):
if instance is None:
return self
rel_model = self.related.related_model
manager = self.related_manager_cls(
model=rel_model,
query_field_name=self.related.field.name,
prefetch_cache_name=self.related.field.related_query_name(),
instance=instance,
symmetrical=False,
source_field_name=self.related.field.m2m_reverse_field_name(),
target_field_name=self.related.field.m2m_field_name(),
reverse=True,
through=self.related.field.rel.through,
)
return manager
def __set__(self, instance, value):
manager = self.__get__(instance)
manager.set(value)
class ManyRelatedObjectsDescriptor(ForeignRelatedObjectsDescriptor):
class ReverseManyRelatedObjectsDescriptor(object):
# This class provides the functionality that makes the related-object
# managers available as attributes on a model class, for fields that have
# multiple "remote" values and have a ManyToManyField defined in their
# model (rather than having another model pointed *at* them).
# In the example "article.publications", the publications attribute is a
# ReverseManyRelatedObjectsDescriptor instance.
def __init__(self, m2m_field):
self.field = m2m_field
def __init__(self, rel, reverse=False):
super(ManyRelatedObjectsDescriptor, self).__init__(rel)
self.reverse = reverse
@property
def through(self):
# through is provided so that you have easy access to the through
# model (Book.authors.through) for inlines, etc. This is done as
# a property to ensure that the fully resolved value is returned.
return self.field.rel.through
return self.rel.through
@cached_property
def related_manager_cls(self):
model = self.rel.related_model if self.reverse else self.rel.to
# Dynamically create a class that subclasses the related model's
# default manager.
return create_many_related_manager(
self.field.rel.to._default_manager.__class__,
self.field.rel
model._default_manager.__class__,
self.rel,
reverse=self.reverse,
)
def __get__(self, instance, instance_type=None):
if instance is None:
return self
manager = self.related_manager_cls(
model=self.field.rel.to,
query_field_name=self.field.related_query_name(),
prefetch_cache_name=self.field.name,
instance=instance,
symmetrical=self.field.rel.symmetrical,
source_field_name=self.field.m2m_field_name(),
target_field_name=self.field.m2m_reverse_field_name(),
reverse=False,
through=self.field.rel.through,
)
return manager
def __set__(self, instance, value):
if not self.field.rel.through._meta.auto_created:
opts = self.field.rel.through._meta
raise AttributeError(
"Cannot set values on a ManyToManyField which specifies an "
"intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
)
manager = self.__get__(instance)
manager.set(value)
class ForeignObjectRel(object):
# Field flags
auto_created = True
concrete = False
@ -1505,10 +1435,6 @@ class ManyToManyRel(ForeignObjectRel):
self.symmetrical = symmetrical
self.db_constraint = db_constraint
def is_hidden(self):
"Should the related object be hidden?"
return self.related_name is not None and self.related_name[-1] == '+'
def get_related_field(self):
"""
Returns the field in the 'to' object to which this relationship is tied.
@ -2599,7 +2525,8 @@ class ManyToManyField(RelatedField):
self.rel.through = create_many_to_many_intermediary_model(self, cls)
# Add the descriptor for the m2m relation
setattr(cls, self.name, ReverseManyRelatedObjectsDescriptor(self))
# Add the descriptor for the m2m relation.
setattr(cls, self.name, ManyRelatedObjectsDescriptor(self.rel, reverse=False))
# Set up the accessor for the m2m table name for the relation
self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
@ -2615,7 +2542,7 @@ class ManyToManyField(RelatedField):
# Internal M2Ms (i.e., those with a related name ending with '+')
# and swapped models don't get a related descriptor.
if not self.rel.is_hidden() and not related.related_model._meta.swapped:
setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(related))
setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(self.rel, reverse=True))
# Set up the accessors for the column names on the m2m table
self.m2m_column_name = curry(self._get_m2m_attr, related, 'column')

View File

@ -424,7 +424,7 @@ class ManyToManyTests(TestCase):
def test_reverse_assign_with_queryset(self):
# Ensure that querysets used in M2M assignments are pre-evaluated
# so their value isn't affected by the clearing operation in
# ReverseManyRelatedObjectsDescriptor.__set__. Refs #19816.
# ManyRelatedObjectsDescriptor.__set__. Refs #19816.
self.p1.article_set = [self.a1, self.a2]
qs = self.p1.article_set.filter(headline='Django lets you build Web apps easily')

View File

@ -1,6 +1,6 @@
from django.db.models.fields.related import (
RECURSIVE_RELATIONSHIP_CONSTANT, ManyToManyField, ManyToManyRel,
RelatedField, ReverseManyRelatedObjectsDescriptor,
RECURSIVE_RELATIONSHIP_CONSTANT, ManyRelatedObjectsDescriptor,
ManyToManyField, ManyToManyRel, RelatedField,
create_many_to_many_intermediary_model,
)
from django.utils.functional import curry
@ -40,7 +40,7 @@ class CustomManyToManyField(RelatedField):
super(CustomManyToManyField, self).contribute_to_class(cls, name, **kwargs)
if not self.rel.through and not cls._meta.abstract and not cls._meta.swapped:
self.rel.through = create_many_to_many_intermediary_model(self, cls)
setattr(cls, self.name, ReverseManyRelatedObjectsDescriptor(self))
setattr(cls, self.name, ManyRelatedObjectsDescriptor(self.rel))
self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
def get_internal_type(self):