Refs #29898 -- Made ProjectState encapsulate alterations in relations registry.

Thanks Simon Charette and Chris Jerdonek for reviews.

Co-authored-by: Mariusz Felisiak <felisiak.mariusz@gmail.com>
This commit is contained in:
Manav Agarwal 2021-08-23 06:52:23 +02:00 committed by Mariusz Felisiak
parent d7394cfa13
commit 196a99da5d
2 changed files with 474 additions and 20 deletions

View File

@ -98,18 +98,34 @@ class ProjectState:
self.real_apps = real_apps
self.is_delayed = False
# {remote_model_key: {model_key: [(field_name, field)]}}
self.relations = None
self._relations = None
@property
def relations(self):
if self._relations is None:
self.resolve_fields_and_relations()
return self._relations
def add_model(self, model_state):
app_label, model_name = model_state.app_label, model_state.name_lower
self.models[(app_label, model_name)] = model_state
model_key = model_state.app_label, model_state.name_lower
self.models[model_key] = model_state
if self._relations is not None:
self.resolve_model_relations(model_key)
if 'apps' in self.__dict__: # hasattr would cache the property
self.reload_model(app_label, model_name)
self.reload_model(*model_key)
def remove_model(self, app_label, model_name):
del self.models[app_label, model_name]
model_key = app_label, model_name
del self.models[model_key]
if self._relations is not None:
self._relations.pop(model_key, None)
# Call list() since _relations can change size during iteration.
for related_model_key, model_relations in list(self._relations.items()):
model_relations.pop(model_key, None)
if not model_relations:
del self._relations[related_model_key]
if 'apps' in self.__dict__: # hasattr would cache the property
self.apps.unregister_model(app_label, model_name)
self.apps.unregister_model(*model_key)
# Need to do this explicitly since unregister_model() doesn't clear
# the cache automatically (#24513)
self.apps.clear_cache()
@ -137,6 +153,14 @@ class ProjectState:
if changed_field:
model_state.fields[name] = changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
if self._relations is not None:
old_name_key = app_label, old_name_lower
new_name_key = app_label, new_name_lower
if old_name_key in self._relations:
self._relations[new_name_key] = self._relations.pop(old_name_key)
for model_relations in self._relations.values():
if old_name_key in model_relations:
model_relations[new_name_key] = model_relations.pop(old_name_key)
# Reload models related to old model before removing the old model.
self.reload_models(to_reload, delay=True)
# Remove the old model.
@ -187,17 +211,23 @@ class ProjectState:
field.default = NOT_PROVIDED
else:
field = field
self.models[app_label, model_name].fields[name] = field
model_key = app_label, model_name
self.models[model_key].fields[name] = field
if self._relations is not None:
self.resolve_model_field_relations(model_key, name, field)
# Delay rendering of relationships if it's not a relational field.
delay = not field.is_relation
self.reload_model(app_label, model_name, delay=delay)
self.reload_model(*model_key, delay=delay)
def remove_field(self, app_label, model_name, name):
model_state = self.models[app_label, model_name]
model_key = app_label, model_name
model_state = self.models[model_key]
old_field = model_state.fields.pop(name)
if self._relations is not None:
self.resolve_model_field_relations(model_key, name, old_field)
# Delay rendering of relationships if it's not a relational field.
delay = not old_field.is_relation
self.reload_model(app_label, model_name, delay=delay)
self.reload_model(*model_key, delay=delay)
def alter_field(self, app_label, model_name, name, field, preserve_default):
if not preserve_default:
@ -205,20 +235,30 @@ class ProjectState:
field.default = NOT_PROVIDED
else:
field = field
model_state = self.models[app_label, model_name]
model_state.fields[name] = field
model_key = app_label, model_name
fields = self.models[model_key].fields
if self._relations is not None:
old_field = fields.pop(name)
if old_field.is_relation:
self.resolve_model_field_relations(model_key, name, old_field)
fields[name] = field
if field.is_relation:
self.resolve_model_field_relations(model_key, name, field)
else:
fields[name] = field
# TODO: investigate if old relational fields must be reloaded or if
# it's sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
delay = (
not field.is_relation and
not field_is_referenced(self, (app_label, model_name), (name, field))
not field_is_referenced(self, model_key, (name, field))
)
self.reload_model(app_label, model_name, delay=delay)
self.reload_model(*model_key, delay=delay)
def rename_field(self, app_label, model_name, old_name, new_name):
model_state = self.models[app_label, model_name]
model_key = app_label, model_name
model_state = self.models[model_key]
# Rename the field.
fields = model_state.fields
try:
@ -246,7 +286,7 @@ class ProjectState:
]
# Fix to_fields to refer to the new field.
delay = True
references = get_references(self, (app_label, model_name), (old_name, found))
references = get_references(self, model_key, (old_name, found))
for *_, field, reference in references:
delay = False
if reference.to:
@ -258,7 +298,19 @@ class ProjectState:
new_name if to_field_name == old_name else to_field_name
for to_field_name in to_fields
])
self.reload_model(app_label, model_name, delay=delay)
if self._relations is not None:
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
for to_model in self._relations.values():
# It's safe to modify the same collection that is iterated
# because `break` is called right after.
for field_name, field in to_model[model_key]:
if field_name == old_name_lower:
field.name = new_name_lower
to_model[model_key].remove((old_name_lower, field))
to_model[model_key].append((new_name_lower, field))
break
self.reload_model(*model_key, delay=delay)
def _find_reload_model(self, app_label, model_name, delay=False):
if delay:
@ -352,7 +404,13 @@ class ProjectState:
remote_model_key = resolve_relation(model, *model_key)
if remote_model_key[0] not in self.real_apps and remote_model_key in concretes:
remote_model_key = concretes[remote_model_key]
self.relations[remote_model_key][model_key].append((field_name, field))
relations_to_remote_model = self._relations[remote_model_key]
if field_name in self.models[model_key].fields:
relations_to_remote_model[model_key].append((field_name, field))
else:
relations_to_remote_model[model_key].remove((field_name, field))
if not relations_to_remote_model[model_key]:
del relations_to_remote_model[model_key]
def resolve_model_field_relations(
self, model_key, field_name, field, concretes=None,
@ -387,14 +445,14 @@ class ProjectState:
field.name = field_name
# Resolve relations.
# {remote_model_key: {model_key: [(field_name, field)]}}
self.relations = defaultdict(partial(defaultdict, list))
self._relations = defaultdict(partial(defaultdict, list))
concretes, proxies = self._get_concrete_models_mapping_and_proxy_models()
for model_key in concretes:
self.resolve_model_relations(model_key, concretes)
for model_key in proxies:
self.relations[model_key] = self.relations[concretes[model_key]]
self._relations[model_key] = self._relations[concretes[model_key]]
def get_concrete_model_key(self, model):
concrete_models_mapping, _ = self._get_concrete_models_mapping_and_proxy_models()

View File

@ -1017,6 +1017,402 @@ class StateTests(SimpleTestCase):
self.assertEqual(list(choices_field.choices), choices)
class StateRelationsTests(SimpleTestCase):
def get_base_project_state(self):
new_apps = Apps()
class User(models.Model):
class Meta:
app_label = 'tests'
apps = new_apps
class Comment(models.Model):
text = models.TextField()
user = models.ForeignKey(User, models.CASCADE)
comments = models.ManyToManyField('self')
class Meta:
app_label = 'tests'
apps = new_apps
class Post(models.Model):
text = models.TextField()
authors = models.ManyToManyField(User)
class Meta:
app_label = 'tests'
apps = new_apps
project_state = ProjectState()
project_state.add_model(ModelState.from_model(User))
project_state.add_model(ModelState.from_model(Comment))
project_state.add_model(ModelState.from_model(Post))
return project_state
def test_relations_population(self):
tests = [
('add_model', [
ModelState(
app_label='migrations',
name='Tag',
fields=[('id', models.AutoField(primary_key=True))],
),
]),
('remove_model', ['tests', 'comment']),
('rename_model', ['tests', 'comment', 'opinion']),
('add_field', [
'tests',
'post',
'next_post',
models.ForeignKey('self', models.CASCADE),
True,
]),
('remove_field', ['tests', 'post', 'text']),
('rename_field', ['tests', 'comment', 'user', 'author']),
('alter_field', [
'tests',
'comment',
'user',
models.IntegerField(),
True,
]),
]
for method, args in tests:
with self.subTest(method=method):
project_state = self.get_base_project_state()
getattr(project_state, method)(*args)
# ProjectState's `_relations` are populated on `relations` access.
self.assertIsNone(project_state._relations)
self.assertEqual(project_state.relations, project_state._relations)
self.assertIsNotNone(project_state._relations)
def test_add_model(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
self.assertEqual(
list(project_state.relations['tests', 'comment']),
[('tests', 'comment')],
)
self.assertNotIn(('tests', 'post'), project_state.relations)
def test_add_model_no_relations(self):
project_state = ProjectState()
project_state.add_model(ModelState(
app_label='migrations',
name='Tag',
fields=[('id', models.AutoField(primary_key=True))],
))
self.assertEqual(project_state.relations, {})
def test_add_model_other_app(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
project_state.add_model(ModelState(
app_label='tests_other',
name='comment',
fields=[
('id', models.AutoField(primary_key=True)),
('user', models.ForeignKey('tests.user', models.CASCADE)),
],
))
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post'), ('tests_other', 'comment')],
)
def test_remove_model(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
self.assertEqual(
list(project_state.relations['tests', 'comment']),
[('tests', 'comment')],
)
project_state.remove_model('tests', 'comment')
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'post')],
)
self.assertNotIn(('tests', 'comment'), project_state.relations)
project_state.remove_model('tests', 'post')
self.assertEqual(project_state.relations, {})
project_state.remove_model('tests', 'user')
self.assertEqual(project_state.relations, {})
def test_rename_model(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
self.assertEqual(
list(project_state.relations['tests', 'comment']),
[('tests', 'comment')],
)
related_field = project_state.relations['tests', 'user']['tests', 'comment']
project_state.rename_model('tests', 'comment', 'opinion')
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'post'), ('tests', 'opinion')],
)
self.assertEqual(
list(project_state.relations['tests', 'opinion']),
[('tests', 'opinion')],
)
self.assertNotIn(('tests', 'comment'), project_state.relations)
self.assertEqual(
project_state.relations['tests', 'user']['tests', 'opinion'],
related_field,
)
project_state.rename_model('tests', 'user', 'author')
self.assertEqual(
list(project_state.relations['tests', 'author']),
[('tests', 'post'), ('tests', 'opinion')],
)
self.assertNotIn(('tests', 'user'), project_state.relations)
def test_rename_model_no_relations(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
related_field = project_state.relations['tests', 'user']['tests', 'post']
self.assertNotIn(('tests', 'post'), project_state.relations)
# Rename a model without relations.
project_state.rename_model('tests', 'post', 'blog')
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'blog')],
)
self.assertNotIn(('tests', 'blog'), project_state.relations)
self.assertEqual(
related_field,
project_state.relations['tests', 'user']['tests', 'blog'],
)
def test_add_field(self):
project_state = self.get_base_project_state()
self.assertNotIn(('tests', 'post'), project_state.relations)
# Add a self-referential foreign key.
new_field = models.ForeignKey('self', models.CASCADE)
project_state.add_field(
'tests', 'post', 'next_post', new_field, preserve_default=True,
)
self.assertEqual(
list(project_state.relations['tests', 'post']),
[('tests', 'post')],
)
self.assertEqual(
project_state.relations['tests', 'post']['tests', 'post'],
[('next_post', new_field)],
)
# Add a foreign key.
new_field = models.ForeignKey('tests.post', models.CASCADE)
project_state.add_field(
'tests', 'comment', 'post', new_field, preserve_default=True,
)
self.assertEqual(
list(project_state.relations['tests', 'post']),
[('tests', 'post'), ('tests', 'comment')],
)
self.assertEqual(
project_state.relations['tests', 'post']['tests', 'comment'],
[('post', new_field)],
)
def test_add_field_m2m_with_through(self):
project_state = self.get_base_project_state()
project_state.add_model(ModelState(
app_label='tests',
name='Tag',
fields=[('id', models.AutoField(primary_key=True))],
))
project_state.add_model(ModelState(
app_label='tests',
name='PostTag',
fields=[
('id', models.AutoField(primary_key=True)),
('post', models.ForeignKey('tests.post', models.CASCADE)),
('tag', models.ForeignKey('tests.tag', models.CASCADE)),
],
))
self.assertEqual(
list(project_state.relations['tests', 'post']),
[('tests', 'posttag')],
)
self.assertEqual(
list(project_state.relations['tests', 'tag']),
[('tests', 'posttag')],
)
# Add a many-to-many field with the through model.
new_field = models.ManyToManyField('tests.tag', through='tests.posttag')
project_state.add_field(
'tests', 'post', 'tags', new_field, preserve_default=True,
)
self.assertEqual(
list(project_state.relations['tests', 'post']),
[('tests', 'posttag')],
)
self.assertEqual(
list(project_state.relations['tests', 'tag']),
[('tests', 'posttag'), ('tests', 'post')],
)
self.assertEqual(
project_state.relations['tests', 'tag']['tests', 'post'],
[('tags', new_field)],
)
def test_remove_field(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
# Remove a many-to-many field.
project_state.remove_field('tests', 'post', 'authors')
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment')],
)
# Remove a foreign key.
project_state.remove_field('tests', 'comment', 'user')
self.assertEqual(project_state.relations['tests', 'user'], {})
def test_remove_field_no_relations(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
# Remove a non-relation field.
project_state.remove_field('tests', 'post', 'text')
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
def test_rename_field(self):
project_state = self.get_base_project_state()
field = project_state.models['tests', 'comment'].fields['user']
self.assertEqual(
project_state.relations['tests', 'user']['tests', 'comment'],
[('user', field)],
)
project_state.rename_field('tests', 'comment', 'user', 'author')
renamed_field = project_state.models['tests', 'comment'].fields['author']
self.assertEqual(
project_state.relations['tests', 'user']['tests', 'comment'],
[('author', renamed_field)],
)
self.assertEqual(field, renamed_field)
def test_rename_field_no_relations(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
# Rename a non-relation field.
project_state.rename_field('tests', 'post', 'text', 'description')
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
def test_alter_field(self):
project_state = self.get_base_project_state()
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
# Alter a foreign key to a non-relation field.
project_state.alter_field(
'tests', 'comment', 'user', models.IntegerField(), preserve_default=True,
)
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'post')],
)
# Alter a non-relation field to a many-to-many field.
m2m_field = models.ManyToManyField('tests.user')
project_state.alter_field(
'tests', 'comment', 'user', m2m_field, preserve_default=True,
)
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'post'), ('tests', 'comment')],
)
self.assertEqual(
project_state.relations['tests', 'user']['tests', 'comment'],
[('user', m2m_field)],
)
def test_alter_field_m2m_to_fk(self):
project_state = self.get_base_project_state()
project_state.add_model(ModelState(
app_label='tests_other',
name='user_other',
fields=[('id', models.AutoField(primary_key=True))],
))
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
self.assertNotIn(('tests_other', 'user_other'), project_state.relations)
# Alter a many-to-many field to a foreign key.
foreign_key = models.ForeignKey('tests_other.user_other', models.CASCADE)
project_state.alter_field(
'tests', 'post', 'authors', foreign_key, preserve_default=True,
)
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment')],
)
self.assertEqual(
list(project_state.relations['tests_other', 'user_other']),
[('tests', 'post')],
)
self.assertEqual(
project_state.relations['tests_other', 'user_other']['tests', 'post'],
[('authors', foreign_key)],
)
def test_many_relations_to_same_model(self):
project_state = self.get_base_project_state()
new_field = models.ForeignKey('tests.user', models.CASCADE)
project_state.add_field(
'tests', 'comment', 'reviewer', new_field, preserve_default=True,
)
self.assertEqual(
list(project_state.relations['tests', 'user']),
[('tests', 'comment'), ('tests', 'post')],
)
comment_rels = project_state.relations['tests', 'user']['tests', 'comment']
# Two foreign keys to the same model.
self.assertEqual(len(comment_rels), 2)
self.assertEqual(comment_rels[1], ('reviewer', new_field))
# Rename the second foreign key.
project_state.rename_field('tests', 'comment', 'reviewer', 'supervisor')
self.assertEqual(len(comment_rels), 2)
self.assertEqual(comment_rels[1], ('supervisor', new_field))
# Remove the first foreign key.
project_state.remove_field('tests', 'comment', 'user')
self.assertEqual(comment_rels, [('supervisor', new_field)])
class ModelStateTests(SimpleTestCase):
def test_custom_model_base(self):
state = ModelState.from_model(ModelWithCustomBase)