diff --git a/django/db/models/base.py b/django/db/models/base.py index faed79cc2f..2f961a4393 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -84,9 +84,12 @@ class ModelBase(type): # Pass all attrs without a (Django-specific) contribute_to_class() # method to type.__new__() so that they're properly initialized # (i.e. __set_name__()). + contributable_attrs = {} for obj_name, obj in list(attrs.items()): - if not _has_contribute_to_class(obj): - new_attrs[obj_name] = attrs.pop(obj_name) + if _has_contribute_to_class(obj): + contributable_attrs[obj_name] = obj + else: + new_attrs[obj_name] = obj new_class = super_new(cls, name, bases, new_attrs, **kwargs) abstract = getattr(attr_meta, 'abstract', False) @@ -146,8 +149,9 @@ class ModelBase(type): if is_proxy and base_meta and base_meta.swapped: raise TypeError("%s cannot proxy the swapped model '%s'." % (name, base_meta.swapped)) - # Add all attributes to the class. - for obj_name, obj in attrs.items(): + # Add remaining attributes (those with a contribute_to_class() method) + # to the class. + for obj_name, obj in contributable_attrs.items(): new_class.add_to_class(obj_name, obj) # All the fields of any type declared on this model diff --git a/tests/model_regress/tests.py b/tests/model_regress/tests.py index e3977ee316..28eed87008 100644 --- a/tests/model_regress/tests.py +++ b/tests/model_regress/tests.py @@ -2,9 +2,10 @@ import datetime from operator import attrgetter from django.core.exceptions import ValidationError -from django.db import router +from django.db import models, router from django.db.models.sql import InsertQuery from django.test import TestCase, skipUnlessDBFeature +from django.test.utils import isolate_apps from django.utils.timezone import get_fixed_timezone from .models import ( @@ -217,6 +218,23 @@ class ModelTests(TestCase): m3 = Model3.objects.get(model2=1000) m3.model2 + @isolate_apps('model_regress') + def test_metaclass_can_access_attribute_dict(self): + """ + Model metaclasses have access to the class attribute dict in + __init__() (#30254). + """ + class HorseBase(models.base.ModelBase): + def __init__(cls, name, bases, attrs): + super(HorseBase, cls).__init__(name, bases, attrs) + cls.horns = (1 if 'magic' in attrs else 0) + + class Horse(models.Model, metaclass=HorseBase): + name = models.CharField(max_length=255) + magic = True + + self.assertEqual(Horse.horns, 1) + class ModelValidationTest(TestCase): def test_pk_validation(self):