Fixed #26720 -- Prevented invalid CreateModel optimizations of related fields.

This commit is contained in:
Simon Charette 2017-02-02 22:09:12 -05:00 committed by Tim Graham
parent a97845a823
commit ad82900ad9
3 changed files with 77 additions and 18 deletions

View File

@ -7,9 +7,10 @@ from .utils import is_referenced_by_foreign_key
class FieldOperation(Operation):
def __init__(self, model_name, name):
def __init__(self, model_name, name, field=None):
self.model_name = model_name
self.name = name
self.field = field
@cached_property
def model_name_lower(self):
@ -29,22 +30,42 @@ class FieldOperation(Operation):
name_lower = name.lower()
if name_lower == self.model_name_lower:
return True
field = getattr(self, 'field', None)
if field and field.remote_field:
remote_app_label, remote_model_name = self.model_to_key(field.remote_field.model)
if (remote_model_name == name_lower and app_label is None or
not remote_app_label or remote_app_label == app_label):
return True
through = getattr(field.remote_field, 'through', None)
if through and self.model_to_key(through) == (app_label, name_lower):
through_app_label, through_model_name = self.model_to_key(through)
if (through_model_name == name_lower and app_label is None or
not through_app_label or through_app_label == app_label):
if self.field:
if self.field.remote_field:
remote_app_label, remote_model_name = self.model_to_key(self.field.remote_field.model)
if (remote_model_name == name_lower and app_label is None or
not remote_app_label or remote_app_label == app_label):
return True
return False
through = getattr(self.field.remote_field, 'through', None)
if through and self.model_to_key(through) == (app_label, name_lower):
through_app_label, through_model_name = self.model_to_key(through)
if (through_model_name == name_lower and app_label is None or
not through_app_label or through_app_label == app_label):
return True
return False
return True
def references_field(self, model_name, name, app_label=None):
return self.references_model(model_name, app_label) and name.lower() == self.name_lower
if self.field:
model_name_lower = model_name.lower()
remote_field = self.field.remote_field
if remote_field:
remote_app_label, remote_model_name = self.model_to_key(remote_field.model)
if (remote_model_name == model_name_lower and
(app_label is None or not remote_app_label or remote_app_label == app_label)):
# TODO: Consider to_fields/from_fields.
return True
through = getattr(remote_field, 'through', None)
if through and self.model_to_key(through) == (app_label, model_name_lower):
through_app_label, through_model_name = self.model_to_key(through)
if (through_model_name == model_name_lower and
(app_label is None or not through_app_label or through_app_label == app_label) and
(remote_field.through_fields is None or name in remote_field.through_fields)):
return True
elif model_name_lower == self.model_name_lower and name == self.name:
return True
return False
return True
def reduce(self, operation, in_between, app_label=None):
return (
@ -57,9 +78,8 @@ class AddField(FieldOperation):
"""Add a field to a model."""
def __init__(self, model_name, name, field, preserve_default=True):
self.field = field
self.preserve_default = preserve_default
super().__init__(model_name, name)
super().__init__(model_name, name, field)
def deconstruct(self):
kwargs = {
@ -173,6 +193,12 @@ class RemoveField(FieldOperation):
def describe(self):
return "Remove field %s from %s" % (self.name, self.model_name)
def reduce(self, operation, in_between, app_label=None):
from .models import DeleteModel
if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower:
return [operation]
return super().reduce(operation, in_between, app_label=app_label)
class AlterField(FieldOperation):
"""
@ -181,9 +207,8 @@ class AlterField(FieldOperation):
"""
def __init__(self, model_name, name, field, preserve_default=True):
self.field = field
self.preserve_default = preserve_default
super().__init__(model_name, name)
super().__init__(model_name, name, field)
def deconstruct(self):
kwargs = {

View File

@ -225,6 +225,11 @@ class DeleteModel(ModelOperation):
if self.allow_migrate_model(schema_editor.connection.alias, model):
schema_editor.create_model(model)
def references_model(self, name, app_label=None):
# The deleted model could be referencing the specified model through
# related fields.
return True
def describe(self):
return "Delete model %s" % self.name

View File

@ -270,6 +270,35 @@ class OptimizerTests(SimpleTestCase):
app_label="testapp",
)
# This could be optimized a bit more but it generates a valid set of
# operations.
self.assertOptimizesTo(
[
migrations.CreateModel('Book', [('name', models.CharField(max_length=255))]),
migrations.CreateModel('Person', [('name', models.CharField(max_length=255))]),
migrations.AddField('book', 'author', models.ForeignKey('test_app.Person', models.CASCADE)),
migrations.CreateModel('Review', [('book', models.ForeignKey('test_app.Book', models.CASCADE))]),
migrations.CreateModel('Reviewer', [('name', models.CharField(max_length=255))]),
migrations.AddField('review', 'reviewer', models.ForeignKey('test_app.Reviewer', models.CASCADE)),
migrations.RemoveField('book', 'author'),
migrations.DeleteModel('Person'),
],
[
migrations.CreateModel('Person', [('name', models.CharField(max_length=255))]),
migrations.CreateModel('Book', [
('name', models.CharField(max_length=255)),
('author', models.ForeignKey('test_app.Person', models.CASCADE)),
]),
migrations.CreateModel('Reviewer', [('name', models.CharField(max_length=255))]),
migrations.CreateModel('Review', [
('book', models.ForeignKey('test_app.Book', models.CASCADE)),
('reviewer', models.ForeignKey('test_app.Reviewer', models.CASCADE)),
]),
migrations.RemoveField('book', 'author'),
migrations.DeleteModel('Person'),
],
)
def test_create_model_add_field(self):
"""
AddField should optimize into CreateModel.