From 05656f2388b1989c9e99e1ff2aae8b2e1c805af2 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 25 Sep 2013 13:47:46 +0100 Subject: [PATCH] Add equality support for Project/ModelState --- django/db/migrations/state.py | 21 ++++++++++++++++++ tests/migrations/test_state.py | 40 ++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/django/db/migrations/state.py b/django/db/migrations/state.py index 0e532d3fdc..25b4b2b102 100644 --- a/django/db/migrations/state.py +++ b/django/db/migrations/state.py @@ -59,6 +59,14 @@ class ProjectState(object): models[(model_state.app_label, model_state.name.lower())] = model_state return cls(models) + def __eq__(self, other): + if set(self.models.keys()) != set(other.models.keys()): + return False + return all(model == other.models[key] for key, model in self.models.items()) + + def __ne__(self, other): + return not (self == other) + class ModelState(object): """ @@ -167,3 +175,16 @@ class ModelState(object): if fname == name: return field raise ValueError("No field called %s on model %s" % (name, self.name)) + + def __eq__(self, other): + return ( + (self.app_label == other.app_label) and + (self.name == other.name) and + (len(self.fields) == len(other.fields)) and + all((k1 == k2 and (f1.deconstruct()[1:] == f2.deconstruct()[1:])) for (k1, f1), (k2, f2) in zip(self.fields, other.fields)) and + (self.options == other.options) and + (self.bases == other.bases) + ) + + def __ne__(self, other): + return not (self == other) diff --git a/tests/migrations/test_state.py b/tests/migrations/test_state.py index 4707349176..210897dc66 100644 --- a/tests/migrations/test_state.py +++ b/tests/migrations/test_state.py @@ -175,3 +175,43 @@ class StateTests(TestCase): project_state.add_model_state(ModelState.from_model(F)) with self.assertRaises(InvalidBasesError): project_state.render() + + def test_equality(self): + """ + Tests that == and != are implemented correctly. + """ + + # Test two things that should be equal + project_state = ProjectState() + project_state.add_model_state(ModelState( + "migrations", + "Tag", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=100)), + ("hidden", models.BooleanField()), + ], + {}, + None, + )) + other_state = project_state.clone() + self.assertEqual(project_state, project_state) + self.assertEqual(project_state, other_state) + self.assertEqual(project_state != project_state, False) + self.assertEqual(project_state != other_state, False) + + # Make a very small change (max_len 99) and see if that affects it + project_state = ProjectState() + project_state.add_model_state(ModelState( + "migrations", + "Tag", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=99)), + ("hidden", models.BooleanField()), + ], + {}, + None, + )) + self.assertNotEqual(project_state, other_state) + self.assertEqual(project_state == other_state, False)