Fixed #35704 -- Fixed reduction for AddIndex subclasses.

This commit is contained in:
Adam Johnson 2024-08-22 19:26:42 +01:00 committed by Sarah Boyce
parent ad7f8129f3
commit f5ddd54986
4 changed files with 90 additions and 39 deletions

View File

@ -979,7 +979,7 @@ class AddIndex(IndexOperation):
return [] return []
if isinstance(operation, RenameIndex) and self.index.name == operation.old_name: if isinstance(operation, RenameIndex) and self.index.name == operation.old_name:
self.index.name = operation.new_name self.index.name = operation.new_name
return [AddIndex(model_name=self.model_name, index=self.index)] return [self.__class__(model_name=self.model_name, index=self.index)]
return super().reduce(operation, app_label) return super().reduce(operation, app_label)

View File

@ -7,9 +7,11 @@ from importlib import import_module
from django.apps import apps from django.apps import apps
from django.db import connection, connections, migrations, models from django.db import connection, connections, migrations, models
from django.db.migrations.migration import Migration from django.db.migrations.migration import Migration
from django.db.migrations.optimizer import MigrationOptimizer
from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.serializer import serializer_factory
from django.db.migrations.state import ProjectState from django.db.migrations.state import ProjectState
from django.test import TransactionTestCase from django.test import SimpleTestCase, TransactionTestCase
from django.test.utils import extend_sys_path from django.test.utils import extend_sys_path
from django.utils.module_loading import module_dir from django.utils.module_loading import module_dir
@ -400,3 +402,38 @@ class OperationTestBase(MigrationTestBase):
) )
) )
return self.apply_operations(app_label, ProjectState(), operations) return self.apply_operations(app_label, ProjectState(), operations)
class OptimizerTestBase(SimpleTestCase):
"""Common functions to help test the optimizer."""
def optimize(self, operations, app_label):
"""
Handy shortcut for getting results + number of loops
"""
optimizer = MigrationOptimizer()
return optimizer.optimize(operations, app_label), optimizer._iterations
def serialize(self, value):
return serializer_factory(value).serialize()[0]
def assertOptimizesTo(
self, operations, expected, exact=None, less_than=None, app_label=None
):
result, iterations = self.optimize(operations, app_label or "migrations")
result = [self.serialize(f) for f in result]
expected = [self.serialize(f) for f in expected]
self.assertEqual(expected, result)
if exact is not None and iterations != exact:
raise self.failureException(
"Optimization did not take exactly %s iterations (it took %s)"
% (exact, iterations)
)
if less_than is not None and iterations >= less_than:
raise self.failureException(
"Optimization did not take less than %s iterations (it took %s)"
% (less_than, iterations)
)
def assertDoesNotOptimize(self, operations, **kwargs):
self.assertOptimizesTo(operations, operations, **kwargs)

View File

@ -1,49 +1,17 @@
from django.db import migrations, models from django.db import migrations, models
from django.db.migrations import operations from django.db.migrations import operations
from django.db.migrations.optimizer import MigrationOptimizer from django.db.migrations.optimizer import MigrationOptimizer
from django.db.migrations.serializer import serializer_factory
from django.db.models.functions import Abs from django.db.models.functions import Abs
from django.test import SimpleTestCase
from .models import EmptyManager, UnicodeModel from .models import EmptyManager, UnicodeModel
from .test_base import OptimizerTestBase
class OptimizerTests(SimpleTestCase): class OptimizerTests(OptimizerTestBase):
""" """
Tests the migration autodetector. Tests the migration optimizer.
""" """
def optimize(self, operations, app_label):
"""
Handy shortcut for getting results + number of loops
"""
optimizer = MigrationOptimizer()
return optimizer.optimize(operations, app_label), optimizer._iterations
def serialize(self, value):
return serializer_factory(value).serialize()[0]
def assertOptimizesTo(
self, operations, expected, exact=None, less_than=None, app_label=None
):
result, iterations = self.optimize(operations, app_label or "migrations")
result = [self.serialize(f) for f in result]
expected = [self.serialize(f) for f in expected]
self.assertEqual(expected, result)
if exact is not None and iterations != exact:
raise self.failureException(
"Optimization did not take exactly %s iterations (it took %s)"
% (exact, iterations)
)
if less_than is not None and iterations >= less_than:
raise self.failureException(
"Optimization did not take less than %s iterations (it took %s)"
% (less_than, iterations)
)
def assertDoesNotOptimize(self, operations, **kwargs):
self.assertOptimizesTo(operations, operations, **kwargs)
def test_none_app_label(self): def test_none_app_label(self):
optimizer = MigrationOptimizer() optimizer = MigrationOptimizer()
with self.assertRaisesMessage(TypeError, "app_label must be a str"): with self.assertRaisesMessage(TypeError, "app_label must be a str"):

View File

@ -1,8 +1,9 @@
import unittest import unittest
from migrations.test_base import OperationTestBase from migrations.test_base import OperationTestBase, OptimizerTestBase
from django.db import IntegrityError, NotSupportedError, connection, transaction from django.db import IntegrityError, NotSupportedError, connection, transaction
from django.db.migrations.operations import RemoveIndex, RenameIndex
from django.db.migrations.state import ProjectState from django.db.migrations.state import ProjectState
from django.db.migrations.writer import OperationWriter from django.db.migrations.writer import OperationWriter
from django.db.models import CheckConstraint, Index, Q, UniqueConstraint from django.db.models import CheckConstraint, Index, Q, UniqueConstraint
@ -30,7 +31,7 @@ except ImportError:
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
@modify_settings(INSTALLED_APPS={"append": "migrations"}) @modify_settings(INSTALLED_APPS={"append": "migrations"})
class AddIndexConcurrentlyTests(OperationTestBase): class AddIndexConcurrentlyTests(OptimizerTestBase, OperationTestBase):
app_label = "test_add_concurrently" app_label = "test_add_concurrently"
def test_requires_atomic_false(self): def test_requires_atomic_false(self):
@ -129,6 +130,51 @@ class AddIndexConcurrentlyTests(OperationTestBase):
) )
self.assertIndexNotExists(table_name, ["pink"]) self.assertIndexNotExists(table_name, ["pink"])
def test_reduce_add_remove_concurrently(self):
self.assertOptimizesTo(
[
AddIndexConcurrently(
"Pony",
Index(fields=["pink"], name="pony_pink_idx"),
),
RemoveIndex("Pony", "pony_pink_idx"),
],
[],
)
def test_reduce_add_remove(self):
self.assertOptimizesTo(
[
AddIndexConcurrently(
"Pony",
Index(fields=["pink"], name="pony_pink_idx"),
),
RemoveIndexConcurrently("Pony", "pony_pink_idx"),
],
[],
)
def test_reduce_add_rename(self):
self.assertOptimizesTo(
[
AddIndexConcurrently(
"Pony",
Index(fields=["pink"], name="pony_pink_idx"),
),
RenameIndex(
"Pony",
old_name="pony_pink_idx",
new_name="pony_pink_index",
),
],
[
AddIndexConcurrently(
"Pony",
Index(fields=["pink"], name="pony_pink_index"),
),
],
)
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.") @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests.")
@modify_settings(INSTALLED_APPS={"append": "migrations"}) @modify_settings(INSTALLED_APPS={"append": "migrations"})