[1.7.x] Fixed #22474 -- Made migration recorder aware of multiple databases

Thanks Tim Graham for the review.
Backport of 7c54f8cce from master.
This commit is contained in:
Claude Paroz 2014-04-30 16:53:20 +02:00
parent bb5c7e4e8d
commit 1084456ac2
2 changed files with 21 additions and 7 deletions

View File

@ -1,5 +1,6 @@
from django.apps.registry import Apps
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
from django.utils.timezone import now
@ -16,6 +17,7 @@ class MigrationRecorder(object):
a row in the table always means a migration is applied.
"""
@python_2_unicode_compatible
class Migration(models.Model):
app = models.CharField(max_length=255)
name = models.CharField(max_length=255)
@ -26,9 +28,16 @@ class MigrationRecorder(object):
app_label = "migrations"
db_table = "django_migrations"
def __str__(self):
return "Migration %s for %s" % (self.name, self.app)
def __init__(self, connection):
self.connection = connection
@property
def migration_qs(self):
return self.Migration.objects.using(self.connection.alias)
def ensure_schema(self):
"""
Ensures the table exists and has the correct schema.
@ -46,25 +55,24 @@ class MigrationRecorder(object):
Returns a set of (app, name) of applied migrations.
"""
self.ensure_schema()
return set(tuple(x) for x in self.Migration.objects.values_list("app", "name"))
return set(tuple(x) for x in self.migration_qs.values_list("app", "name"))
def record_applied(self, app, name):
"""
Records that a migration was applied.
"""
self.ensure_schema()
self.Migration.objects.create(app=app, name=name)
self.migration_qs.create(app=app, name=name)
def record_unapplied(self, app, name):
"""
Records that a migration was unapplied.
"""
self.ensure_schema()
self.Migration.objects.filter(app=app, name=name).delete()
self.migration_qs.filter(app=app, name=name).delete()
@classmethod
def flush(cls):
def flush(self):
"""
Deletes all migration records. Useful if you're testing migrations.
"""
cls.Migration.objects.all().delete()
self.migration_qs.all().delete()

View File

@ -1,7 +1,7 @@
from unittest import skipIf
from django.test import TestCase, override_settings
from django.db import connection
from django.db import connection, connections
from django.db.migrations.loader import MigrationLoader, AmbiguityError
from django.db.migrations.recorder import MigrationRecorder
from django.utils import six
@ -26,6 +26,12 @@ class RecorderTests(TestCase):
recorder.applied_migrations(),
set([("myapp", "0432_ponies")]),
)
# That should not affect records of another database
recorder_other = MigrationRecorder(connections['other'])
self.assertEqual(
recorder_other.applied_migrations(),
set(),
)
recorder.record_unapplied("myapp", "0432_ponies")
self.assertEqual(
recorder.applied_migrations(),