Fixed #23745 -- Reused states as much as possible in migrations

Thanks Tim Graham and Markus Holtermann for the reviews.
This commit is contained in:
Claude Paroz 2014-11-05 20:53:39 +01:00
parent 2a9c4b4901
commit 1aa3e09c20
9 changed files with 137 additions and 62 deletions

View File

@ -98,13 +98,13 @@ class MigrationExecutor(object):
self.progress_callback("apply_start", migration, fake) self.progress_callback("apply_start", migration, fake)
if not fake: if not fake:
# Test to see if this is an already-applied initial migration # Test to see if this is an already-applied initial migration
if self.detect_soft_applied(state, migration): applied, state = self.detect_soft_applied(state, migration)
if applied:
fake = True fake = True
else: else:
# Alright, do it normally # Alright, do it normally
with self.connection.schema_editor() as schema_editor: with self.connection.schema_editor() as schema_editor:
project_state = self.loader.project_state((migration.app_label, migration.name), at_end=False) state = migration.apply(state, schema_editor)
migration.apply(project_state, schema_editor)
# For replacement migrations, record individual statuses # For replacement migrations, record individual statuses
if migration.replaces: if migration.replaces:
for app_label, name in migration.replaces: for app_label, name in migration.replaces:
@ -124,8 +124,7 @@ class MigrationExecutor(object):
self.progress_callback("unapply_start", migration, fake) self.progress_callback("unapply_start", migration, fake)
if not fake: if not fake:
with self.connection.schema_editor() as schema_editor: with self.connection.schema_editor() as schema_editor:
project_state = self.loader.project_state((migration.app_label, migration.name), at_end=False) state = migration.unapply(state, schema_editor)
migration.unapply(project_state, schema_editor)
# For replacement migrations, record individual statuses # For replacement migrations, record individual statuses
if migration.replaces: if migration.replaces:
for app_label, name in migration.replaces: for app_label, name in migration.replaces:
@ -143,12 +142,15 @@ class MigrationExecutor(object):
tables it would create exist. This is intended only for use tables it would create exist. This is intended only for use
on initial migrations (as it only looks for CreateModel). on initial migrations (as it only looks for CreateModel).
""" """
project_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
apps = project_state.apps
found_create_migration = False
# Bail if the migration isn't the first one in its app # Bail if the migration isn't the first one in its app
if [name for app, name in migration.dependencies if app == migration.app_label]: if [name for app, name in migration.dependencies if app == migration.app_label]:
return False return False, project_state
if project_state is None:
after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
else:
after_state = migration.mutate_state(project_state)
apps = after_state.apps
found_create_migration = False
# Make sure all create model are done # Make sure all create model are done
for operation in migration.operations: for operation in migration.operations:
if isinstance(operation, migrations.CreateModel): if isinstance(operation, migrations.CreateModel):
@ -158,8 +160,8 @@ class MigrationExecutor(object):
# main app cache, as it's not a direct dependency. # main app cache, as it's not a direct dependency.
model = global_apps.get_model(model._meta.swapped) model = global_apps.get_model(model._meta.swapped)
if model._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()): if model._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()):
return False return False, project_state
found_create_migration = True found_create_migration = True
# If we get this far and we found at least one CreateModel migration, # If we get this far and we found at least one CreateModel migration,
# the migration is considered implicitly applied. # the migration is considered implicitly applied.
return found_create_migration return found_create_migration, after_state

View File

@ -97,19 +97,17 @@ class Migration(object):
schema_editor.collected_sql.append("-- %s" % operation.describe()) schema_editor.collected_sql.append("-- %s" % operation.describe())
schema_editor.collected_sql.append("--") schema_editor.collected_sql.append("--")
continue continue
# Get the state after the operation has run # Save the state before the operation has run
new_state = project_state.clone() old_state = project_state.clone()
operation.state_forwards(self.app_label, new_state) operation.state_forwards(self.app_label, project_state)
# Run the operation # Run the operation
if not schema_editor.connection.features.can_rollback_ddl and operation.atomic: if not schema_editor.connection.features.can_rollback_ddl and operation.atomic:
# We're forcing a transaction on a non-transactional-DDL backend # We're forcing a transaction on a non-transactional-DDL backend
with atomic(schema_editor.connection.alias): with atomic(schema_editor.connection.alias):
operation.database_forwards(self.app_label, schema_editor, project_state, new_state) operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
else: else:
# Normal behaviour # Normal behaviour
operation.database_forwards(self.app_label, schema_editor, project_state, new_state) operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
# Switch states
project_state = new_state
return project_state return project_state
def unapply(self, project_state, schema_editor, collect_sql=False): def unapply(self, project_state, schema_editor, collect_sql=False):
@ -133,10 +131,9 @@ class Migration(object):
# If it's irreversible, error out # If it's irreversible, error out
if not operation.reversible: if not operation.reversible:
raise Migration.IrreversibleError("Operation %s in %s is not reversible" % (operation, self)) raise Migration.IrreversibleError("Operation %s in %s is not reversible" % (operation, self))
new_state = project_state.clone() old_state = project_state.clone()
operation.state_forwards(self.app_label, new_state) operation.state_forwards(self.app_label, project_state)
to_run.append((operation, project_state, new_state)) to_run.append((operation, old_state, project_state))
project_state = new_state
# Now run them in reverse # Now run them in reverse
to_run.reverse() to_run.reverse()
for operation, to_state, from_state in to_run: for operation, to_state, from_state in to_run:

View File

@ -38,6 +38,7 @@ class AddField(Operation):
else: else:
field = self.field field = self.field
state.models[app_label, self.model_name.lower()].fields.append((self.name, field)) state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
state.reload_model(app_label, self.model_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name) to_model = to_state.apps.get_model(app_label, self.model_name)
@ -94,6 +95,7 @@ class RemoveField(Operation):
if name != self.name: if name != self.name:
new_fields.append((name, instance)) new_fields.append((name, instance))
state.models[app_label, self.model_name.lower()].fields = new_fields state.models[app_label, self.model_name.lower()].fields = new_fields
state.reload_model(app_label, self.model_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name) from_model = from_state.apps.get_model(app_label, self.model_name)
@ -150,6 +152,7 @@ class AlterField(Operation):
state.models[app_label, self.model_name.lower()].fields = [ state.models[app_label, self.model_name.lower()].fields = [
(n, field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields (n, field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
] ]
state.reload_model(app_label, self.model_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name) to_model = to_state.apps.get_model(app_label, self.model_name)
@ -220,6 +223,7 @@ class RenameField(Operation):
[self.new_name if n == self.old_name else n for n in together] [self.new_name if n == self.old_name else n for n in together]
for together in options[option] for together in options[option]
] ]
state.reload_model(app_label, self.model_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name) to_model = to_state.apps.get_model(app_label, self.model_name)

View File

@ -39,14 +39,14 @@ class CreateModel(Operation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
state.models[app_label, self.name.lower()] = ModelState( state.add_model(ModelState(
app_label, app_label,
self.name, self.name,
list(self.fields), list(self.fields),
dict(self.options), dict(self.options),
tuple(self.bases), tuple(self.bases),
list(self.managers), list(self.managers),
) ))
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name) model = to_state.apps.get_model(app_label, self.name)
@ -98,7 +98,7 @@ class DeleteModel(Operation):
) )
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
del state.models[app_label, self.name.lower()] state.remove_model(app_label, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.name) model = from_state.apps.get_model(app_label, self.name)
@ -141,12 +141,13 @@ class RenameModel(Operation):
# Get all of the related objects we need to repoint # Get all of the related objects we need to repoint
apps = state.apps apps = state.apps
model = apps.get_model(app_label, self.old_name) model = apps.get_model(app_label, self.old_name)
model._meta.apps = apps
related_objects = model._meta.get_all_related_objects() related_objects = model._meta.get_all_related_objects()
related_m2m_objects = model._meta.get_all_related_many_to_many_objects() related_m2m_objects = model._meta.get_all_related_many_to_many_objects()
# Rename the model # Rename the model
state.models[app_label, self.new_name.lower()] = state.models[app_label, self.old_name.lower()] state.models[app_label, self.new_name.lower()] = state.models[app_label, self.old_name.lower()]
state.models[app_label, self.new_name.lower()].name = self.new_name state.models[app_label, self.new_name.lower()].name = self.new_name
del state.models[app_label, self.old_name.lower()] state.remove_model(app_label, self.old_name)
# Repoint the FKs and M2Ms pointing to us # Repoint the FKs and M2Ms pointing to us
for related_object in (related_objects + related_m2m_objects): for related_object in (related_objects + related_m2m_objects):
# Use the new related key for self referential related objects. # Use the new related key for self referential related objects.
@ -164,7 +165,8 @@ class RenameModel(Operation):
field.rel.to = "%s.%s" % (app_label, self.new_name) field.rel.to = "%s.%s" % (app_label, self.new_name)
new_fields.append((name, field)) new_fields.append((name, field))
state.models[related_key].fields = new_fields state.models[related_key].fields = new_fields
del state.apps # FIXME: this should be replaced by a logic in state (update_model?) state.reload_model(*related_key)
state.reload_model(app_label, self.new_name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name) new_model = to_state.apps.get_model(app_label, self.new_name)
@ -235,6 +237,7 @@ class AlterModelTable(Operation):
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
state.models[app_label, self.name.lower()].options["db_table"] = self.table state.models[app_label, self.name.lower()].options["db_table"] = self.table
state.reload_model(app_label, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name) new_model = to_state.apps.get_model(app_label, self.name)
@ -290,6 +293,7 @@ class AlterUniqueTogether(Operation):
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name.lower()] model_state = state.models[app_label, self.name.lower()]
model_state.options[self.option_name] = self.unique_together model_state.options[self.option_name] = self.unique_together
state.reload_model(app_label, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name) new_model = to_state.apps.get_model(app_label, self.name)
@ -337,6 +341,7 @@ class AlterIndexTogether(Operation):
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name.lower()] model_state = state.models[app_label, self.name.lower()]
model_state.options[self.option_name] = self.index_together model_state.options[self.option_name] = self.index_together
state.reload_model(app_label, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name) new_model = to_state.apps.get_model(app_label, self.name)
@ -381,6 +386,7 @@ class AlterOrderWithRespectTo(Operation):
def state_forwards(self, app_label, state): def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name.lower()] model_state = state.models[app_label, self.name.lower()]
model_state.options['order_with_respect_to'] = self.order_with_respect_to model_state.options['order_with_respect_to'] = self.order_with_respect_to
state.reload_model(app_label, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.name) to_model = to_state.apps.get_model(app_label, self.name)
@ -451,6 +457,7 @@ class AlterModelOptions(Operation):
for key in self.ALTER_OPTION_KEYS: for key in self.ALTER_OPTION_KEYS:
if key not in self.options and key in model_state.options: if key not in self.options and key in model_state.options:
del model_state.options[key] del model_state.options[key]
state.reload_model(app_label, self.name)
def database_forwards(self, app_label, schema_editor, from_state, to_state): def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass pass

View File

@ -1,4 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from collections import OrderedDict
import copy
from django.apps import AppConfig from django.apps import AppConfig
from django.apps.registry import Apps, apps as global_apps from django.apps.registry import Apps, apps as global_apps
@ -30,15 +32,52 @@ class ProjectState(object):
# Apps to include from main registry, usually unmigrated ones # Apps to include from main registry, usually unmigrated ones
self.real_apps = real_apps or [] self.real_apps = real_apps or []
def add_model_state(self, model_state): def add_model(self, model_state):
self.models[(model_state.app_label, model_state.name.lower())] = model_state app_label, model_name = model_state.app_label, model_state.name.lower()
self.models[(app_label, model_name)] = model_state
if 'apps' in self.__dict__: # hasattr would cache the property
self.reload_model(app_label, model_name)
def remove_model(self, app_label, model_name):
model_name = model_name.lower()
del self.models[app_label, model_name]
if 'apps' in self.__dict__: # hasattr would cache the property
self.apps.unregister_model(app_label, model_name)
def reload_model(self, app_label, model_name):
if 'apps' in self.__dict__: # hasattr would cache the property
# Get relations before reloading the models, as _meta.apps may change
model_name = model_name.lower()
try:
related_old = {
f.model for f in
self.apps.get_model(app_label, model_name)._meta.get_all_related_objects()
}
except LookupError:
related_old = set()
self._reload_one_model(app_label, model_name)
# Reload models if there are relations
model = self.apps.get_model(app_label, model_name)
related_m2m = {f.rel.to for f, _ in model._meta.get_m2m_with_model()}
for rel_model in related_old.union(related_m2m):
self._reload_one_model(rel_model._meta.app_label, rel_model._meta.model_name)
if related_m2m:
# Re-render this model after related models have been reloaded
self._reload_one_model(app_label, model_name)
def _reload_one_model(self, app_label, model_name):
self.apps.unregister_model(app_label, model_name)
self.models[app_label, model_name].render(self.apps)
def clone(self): def clone(self):
"Returns an exact copy of this ProjectState" "Returns an exact copy of this ProjectState"
return ProjectState( new_state = ProjectState(
models={k: v.clone() for k, v in self.models.items()}, models={k: v.clone() for k, v in self.models.items()},
real_apps=self.real_apps, real_apps=self.real_apps,
) )
if 'apps' in self.__dict__:
new_state.apps = self.apps.clone()
return new_state
@cached_property @cached_property
def apps(self): def apps(self):
@ -147,6 +186,31 @@ class StateApps(Apps):
else: else:
do_pending_lookups(model) do_pending_lookups(model)
def clone(self):
"""
Return a clone of this registry, mainly used by the migration framework.
"""
clone = StateApps([], {})
clone.all_models = copy.deepcopy(self.all_models)
clone.app_configs = copy.deepcopy(self.app_configs)
return clone
def register_model(self, app_label, model):
self.all_models[app_label][model._meta.model_name] = model
if app_label not in self.app_configs:
self.app_configs[app_label] = AppConfigStub(app_label)
self.app_configs[app_label].models = OrderedDict()
self.app_configs[app_label].models[model._meta.model_name] = model
self.clear_cache()
def unregister_model(self, app_label, model_name):
try:
del self.all_models[app_label][model_name]
del self.app_configs[app_label].models[model_name]
except KeyError:
pass
self.clear_cache()
class ModelState(object): class ModelState(object):
""" """
@ -368,7 +432,7 @@ class ModelState(object):
for mgr_name, manager in self.managers: for mgr_name, manager in self.managers:
body[mgr_name] = manager body[mgr_name] = manager
# Then, make a Model object # Then, make a Model object (apps.register_model is called in __new__)
return type( return type(
str(self.name), str(self.name),
bases, bases,

View File

@ -410,7 +410,7 @@ class AutodetectorTests(TestCase):
"Shortcut to make ProjectStates from lists of predefined models" "Shortcut to make ProjectStates from lists of predefined models"
project_state = ProjectState() project_state = ProjectState()
for model_state in model_states: for model_state in model_states:
project_state.add_model_state(model_state.clone()) project_state.add_model(model_state.clone())
return project_state return project_state
def test_arrange_for_graph(self): def test_arrange_for_graph(self):

View File

@ -223,7 +223,7 @@ class ExecutorTests(MigrationTestBase):
global_apps.get_app_config("migrations").models["author"] = migrations_apps.get_model("migrations", "author") global_apps.get_app_config("migrations").models["author"] = migrations_apps.get_model("migrations", "author")
try: try:
migration = executor.loader.get_migration("auth", "0001_initial") migration = executor.loader.get_migration("auth", "0001_initial")
self.assertEqual(executor.detect_soft_applied(None, migration), True) self.assertEqual(executor.detect_soft_applied(None, migration)[0], True)
finally: finally:
connection.introspection.table_names = old_table_names connection.introspection.table_names = old_table_names
del global_apps.get_app_config("migrations").models["author"] del global_apps.get_app_config("migrations").models["author"]

View File

@ -828,12 +828,13 @@ class OperationTests(OperationTestBase):
]) ])
self.assertTableExists("test_rmflmm_pony_stables") self.assertTableExists("test_rmflmm_pony_stables")
with_field_state = project_state.clone()
operations = [migrations.RemoveField("Pony", "stables")] operations = [migrations.RemoveField("Pony", "stables")]
self.apply_operations("test_rmflmm", project_state, operations=operations) project_state = self.apply_operations("test_rmflmm", project_state, operations=operations)
self.assertTableNotExists("test_rmflmm_pony_stables") self.assertTableNotExists("test_rmflmm_pony_stables")
# And test reversal # And test reversal
self.unapply_operations("test_rmflmm", project_state, operations=operations) self.unapply_operations("test_rmflmm", with_field_state, operations=operations)
self.assertTableExists("test_rmflmm_pony_stables") self.assertTableExists("test_rmflmm_pony_stables")
def test_remove_field_m2m_with_through(self): def test_remove_field_m2m_with_through(self):

View File

@ -159,7 +159,7 @@ class StateTests(TestCase):
Tests rendering a ProjectState into an Apps. Tests rendering a ProjectState into an Apps.
""" """
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState( project_state.add_model(ModelState(
app_label="migrations", app_label="migrations",
name="Tag", name="Tag",
fields=[ fields=[
@ -168,7 +168,7 @@ class StateTests(TestCase):
("hidden", models.BooleanField()), ("hidden", models.BooleanField()),
], ],
)) ))
project_state.add_model_state(ModelState( project_state.add_model(ModelState(
app_label="migrations", app_label="migrations",
name="SubTag", name="SubTag",
fields=[ fields=[
@ -187,7 +187,7 @@ class StateTests(TestCase):
base_mgr = models.Manager() base_mgr = models.Manager()
mgr1 = FoodManager('a', 'b') mgr1 = FoodManager('a', 'b')
mgr2 = FoodManager('x', 'y', c=3, d=4) mgr2 = FoodManager('x', 'y', c=3, d=4)
project_state.add_model_state(ModelState( project_state.add_model(ModelState(
app_label="migrations", app_label="migrations",
name="Food", name="Food",
fields=[ fields=[
@ -324,21 +324,21 @@ class StateTests(TestCase):
# Make a ProjectState and render it # Make a ProjectState and render it
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState.from_model(A)) project_state.add_model(ModelState.from_model(A))
project_state.add_model_state(ModelState.from_model(B)) project_state.add_model(ModelState.from_model(B))
project_state.add_model_state(ModelState.from_model(C)) project_state.add_model(ModelState.from_model(C))
project_state.add_model_state(ModelState.from_model(D)) project_state.add_model(ModelState.from_model(D))
project_state.add_model_state(ModelState.from_model(E)) project_state.add_model(ModelState.from_model(E))
project_state.add_model_state(ModelState.from_model(F)) project_state.add_model(ModelState.from_model(F))
final_apps = project_state.apps final_apps = project_state.apps
self.assertEqual(len(final_apps.get_models()), 6) self.assertEqual(len(final_apps.get_models()), 6)
# Now make an invalid ProjectState and make sure it fails # Now make an invalid ProjectState and make sure it fails
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState.from_model(A)) project_state.add_model(ModelState.from_model(A))
project_state.add_model_state(ModelState.from_model(B)) project_state.add_model(ModelState.from_model(B))
project_state.add_model_state(ModelState.from_model(C)) project_state.add_model(ModelState.from_model(C))
project_state.add_model_state(ModelState.from_model(F)) project_state.add_model(ModelState.from_model(F))
with self.assertRaises(InvalidBasesError): with self.assertRaises(InvalidBasesError):
project_state.apps project_state.apps
@ -358,8 +358,8 @@ class StateTests(TestCase):
# Make a ProjectState and render it # Make a ProjectState and render it
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState.from_model(A)) project_state.add_model(ModelState.from_model(A))
project_state.add_model_state(ModelState.from_model(B)) project_state.add_model(ModelState.from_model(B))
self.assertEqual(len(project_state.apps.get_models()), 2) self.assertEqual(len(project_state.apps.get_models()), 2)
def test_equality(self): def test_equality(self):
@ -369,7 +369,7 @@ class StateTests(TestCase):
# Test two things that should be equal # Test two things that should be equal
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState( project_state.add_model(ModelState(
"migrations", "migrations",
"Tag", "Tag",
[ [
@ -388,7 +388,7 @@ class StateTests(TestCase):
# Make a very small change (max_len 99) and see if that affects it # Make a very small change (max_len 99) and see if that affects it
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState( project_state.add_model(ModelState(
"migrations", "migrations",
"Tag", "Tag",
[ [
@ -428,20 +428,20 @@ class StateTests(TestCase):
# Make a valid ProjectState and render it # Make a valid ProjectState and render it
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState.from_model(Author)) project_state.add_model(ModelState.from_model(Author))
project_state.add_model_state(ModelState.from_model(Book)) project_state.add_model(ModelState.from_model(Book))
project_state.add_model_state(ModelState.from_model(Magazine)) project_state.add_model(ModelState.from_model(Magazine))
self.assertEqual(len(project_state.apps.get_models()), 3) self.assertEqual(len(project_state.apps.get_models()), 3)
# now make an invalid one with a ForeignKey # now make an invalid one with a ForeignKey
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState.from_model(Book)) project_state.add_model(ModelState.from_model(Book))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
project_state.apps project_state.apps
# and another with ManyToManyField # and another with ManyToManyField
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState.from_model(Magazine)) project_state.add_model(ModelState.from_model(Magazine))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
project_state.apps project_state.apps
@ -461,13 +461,13 @@ class StateTests(TestCase):
# If we just stick it into an empty state it should fail # If we just stick it into an empty state it should fail
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState.from_model(TestModel)) project_state.add_model(ModelState.from_model(TestModel))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
project_state.apps project_state.apps
# If we include the real app it should succeed # If we include the real app it should succeed
project_state = ProjectState(real_apps=["contenttypes"]) project_state = ProjectState(real_apps=["contenttypes"])
project_state.add_model_state(ModelState.from_model(TestModel)) project_state.add_model(ModelState.from_model(TestModel))
rendered_state = project_state.apps rendered_state = project_state.apps
self.assertEqual( self.assertEqual(
len([x for x in rendered_state.get_models() if x._meta.app_label == "migrations"]), len([x for x in rendered_state.get_models() if x._meta.app_label == "migrations"]),
@ -498,8 +498,8 @@ class StateTests(TestCase):
# Make a valid ProjectState and render it # Make a valid ProjectState and render it
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(ModelState.from_model(Author)) project_state.add_model(ModelState.from_model(Author))
project_state.add_model_state(ModelState.from_model(Book)) project_state.add_model(ModelState.from_model(Book))
self.assertEqual( self.assertEqual(
[name for name, field in project_state.models["migrations", "book"].fields], [name for name, field in project_state.models["migrations", "book"].fields],
["id", "author"], ["id", "author"],
@ -534,6 +534,6 @@ class ModelStateTests(TestCase):
self.assertEqual(repr(state), "<ModelState: 'app.Model'>") self.assertEqual(repr(state), "<ModelState: 'app.Model'>")
project_state = ProjectState() project_state = ProjectState()
project_state.add_model_state(state) project_state.add_model(state)
with self.assertRaisesMessage(InvalidBasesError, "Cannot resolve bases for [<ModelState: 'app.Model'>]"): with self.assertRaisesMessage(InvalidBasesError, "Cannot resolve bases for [<ModelState: 'app.Model'>]"):
project_state.apps project_state.apps