diff --git a/django/contrib/gis/db/backends/postgis/schema.py b/django/contrib/gis/db/backends/postgis/schema.py index 181826789be..417bec6d426 100644 --- a/django/contrib/gis/db/backends/postgis/schema.py +++ b/django/contrib/gis/db/backends/postgis/schema.py @@ -10,36 +10,27 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): sql_alter_geometry_column_not_null = "ALTER TABLE %(table)s ALTER COLUMN %(column)s SET NOT NULL" sql_add_spatial_index = "CREATE INDEX %(index)s ON %(table)s USING %(index_type)s (%(column)s %(ops)s)" + def __init__(self, *args, **kwargs): + super(PostGISSchemaEditor, self).__init__(*args, **kwargs) + self.geometry_sql = [] + def geo_quote_name(self, name): return self.connection.ops.geo_quote_name(name) - def create_model(self, model): + def column_sql(self, model, field, include_default=False): from django.contrib.gis.db.models.fields import GeometryField - # Do model creation first - super(PostGISSchemaEditor, self).create_model(model) - # Now add any spatial field SQL - sqls = [] - for field in model._meta.local_fields: - if isinstance(field, GeometryField): - sqls.extend(self.spatial_field_sql(model, field)) - for sql in sqls: - self.execute(sql) - - def spatial_field_sql(self, model, field): - """ - Takes a GeometryField and returns a list of SQL to execute to - create its spatial indexes. - """ - output = [] + if not isinstance(field, GeometryField): + return super(PostGISSchemaEditor, self).column_sql(model, field, include_default) if field.geography or self.connection.ops.geometry: # Geography and Geometry (PostGIS 2.0+) columns are # created normally. - pass + column_sql = super(PostGISSchemaEditor, self).column_sql(model, field, include_default) else: + column_sql = None, None # Geometry columns are created by the `AddGeometryColumn` # stored procedure. - output.append( + self.geometry_sql.append( self.sql_add_geometry_column % { "table": self.geo_quote_name(model._meta.db_table), "column": self.geo_quote_name(field.column), @@ -48,8 +39,9 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): "dim": field.dim, } ) + if not field.null: - output.append( + self.geometry_sql.append( self.sql_alter_geometry_column_not_null % { "table": self.quote_name(model._meta.db_table), "column": self.quote_name(field.column), @@ -72,7 +64,7 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): index_ops = '' else: index_ops = self.geom_index_ops - output.append( + self.geometry_sql.append( self.sql_add_spatial_index % { "index": self.quote_name('%s_%s_id' % (model._meta.db_table, field.column)), "table": self.quote_name(model._meta.db_table), @@ -81,5 +73,18 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): "ops": index_ops, } ) + return column_sql - return output + def create_model(self, model): + super(PostGISSchemaEditor, self).create_model(model) + # Create geometry columns + for sql in self.geometry_sql: + self.execute(sql) + self.geometry_sql = [] + + def add_field(self, model, field): + super(PostGISSchemaEditor, self).add_field(model, field) + # Create geometry columns + for sql in self.geometry_sql: + self.execute(sql) + self.geometry_sql = [] diff --git a/django/contrib/gis/tests/gis_migrations/test_operations.py b/django/contrib/gis/tests/gis_migrations/test_operations.py new file mode 100644 index 00000000000..e9c40575759 --- /dev/null +++ b/django/contrib/gis/tests/gis_migrations/test_operations.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals + +from unittest import skipUnless + +from django.contrib.gis.tests.utils import HAS_SPATIAL_DB +from django.db import connection, migrations, models +from django.db.migrations.migration import Migration +from django.db.migrations.state import ProjectState +from django.test import TransactionTestCase + +if HAS_SPATIAL_DB: + from django.contrib.gis.db.models import fields + + +@skipUnless(HAS_SPATIAL_DB, "Spatial db is required.") +class OperationTests(TransactionTestCase): + available_apps = ["django.contrib.gis.tests.gis_migrations"] + + def get_table_description(self, table): + with connection.cursor() as cursor: + return connection.introspection.get_table_description(cursor, table) + + def assertColumnExists(self, table, column): + self.assertIn(column, [c.name for c in self.get_table_description(table)]) + + def apply_operations(self, app_label, project_state, operations): + migration = Migration('name', app_label) + migration.operations = operations + with connection.schema_editor() as editor: + return migration.apply(project_state, editor) + + def set_up_test_model(self): + operations = [migrations.CreateModel( + "Neighborhood", + [ + ("id", models.AutoField(primary_key=True)), + ('name', models.CharField(max_length=100, unique=True)), + ('geom', fields.MultiPolygonField(srid=4326, null=True)), + ], + )] + return self.apply_operations('gis', ProjectState(), operations) + + def test_add_gis_field(self): + """ + Tests the AddField operation with a GIS-enabled column. + """ + project_state = self.set_up_test_model() + operation = migrations.AddField( + "Neighborhood", + "path", + fields.LineStringField(srid=4326, null=True, blank=True), + ) + new_state = project_state.clone() + operation.state_forwards("gis", new_state) + with connection.schema_editor() as editor: + operation.database_forwards("gis", editor, project_state, new_state) + self.assertColumnExists("gis_neighborhood", "path") + + # Test GeometryColumns when available + try: + from django.contrib.gis.models import GeometryColumns + except ImportError: + return + self.assertEqual( + GeometryColumns.objects.filter(**{GeometryColumns.table_name_col(): "gis_neighborhood"}).count(), + 2 + )