diff --git a/django/contrib/gis/db/backends/postgis/schema.py b/django/contrib/gis/db/backends/postgis/schema.py index 417bec6d426..9f2029728c6 100644 --- a/django/contrib/gis/db/backends/postgis/schema.py +++ b/django/contrib/gis/db/backends/postgis/schema.py @@ -7,6 +7,7 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): geom_index_ops_nd = 'GIST_GEOMETRY_OPS_ND' sql_add_geometry_column = "SELECT AddGeometryColumn(%(table)s, %(column)s, %(srid)s, %(geom_type)s, %(dim)s)" + sql_drop_geometry_column = "SELECT DropGeometryColumn(%(table)s, %(column)s)" sql_alter_geometry_column_not_null = "ALTER TABLE %(table)s ALTER COLUMN %(column)s SET NOT NULL" sql_add_spatial_index = "CREATE INDEX %(index)s ON %(table)s USING %(index_type)s (%(column)s %(ops)s)" @@ -88,3 +89,15 @@ class PostGISSchemaEditor(DatabaseSchemaEditor): for sql in self.geometry_sql: self.execute(sql) self.geometry_sql = [] + + def remove_field(self, model, field): + from django.contrib.gis.db.models.fields import GeometryField + if not isinstance(field, GeometryField) or self.connection.ops.spatial_version < (2, 0): + super(PostGISSchemaEditor, self).remove_field(model, field) + + self.execute( + self.sql_drop_geometry_column % { + "table": self.geo_quote_name(model._meta.db_table), + "column": self.geo_quote_name(field.column), + } + ) diff --git a/django/contrib/gis/tests/gis_migrations/test_operations.py b/django/contrib/gis/tests/gis_migrations/test_operations.py index e9c40575759..2655b2c61cc 100644 --- a/django/contrib/gis/tests/gis_migrations/test_operations.py +++ b/django/contrib/gis/tests/gis_migrations/test_operations.py @@ -10,6 +10,11 @@ from django.test import TransactionTestCase if HAS_SPATIAL_DB: from django.contrib.gis.db.models import fields + try: + from django.contrib.gis.models import GeometryColumns + HAS_GEOMETRY_COLUMNS = True + except ImportError: + HAS_GEOMETRY_COLUMNS = False @skipUnless(HAS_SPATIAL_DB, "Spatial db is required.") @@ -23,6 +28,9 @@ class OperationTests(TransactionTestCase): def assertColumnExists(self, table, column): self.assertIn(column, [c.name for c in self.get_table_description(table)]) + def assertColumnNotExists(self, table, column): + self.assertNotIn(column, [c.name for c in self.get_table_description(table)]) + def apply_operations(self, app_label, project_state, operations): migration = Migration('name', app_label) migration.operations = operations @@ -30,6 +38,17 @@ class OperationTests(TransactionTestCase): return migration.apply(project_state, editor) def set_up_test_model(self): + # Delete the tables if they already exist + with connection.cursor() as cursor: + try: + cursor.execute("DROP TABLE %s" % connection.ops.quote_name("gis_neighborhood")) + except: + pass + else: + if HAS_GEOMETRY_COLUMNS: + cursor.execute("DELETE FROM geometry_columns WHERE %s = %%s" % ( + GeometryColumns.table_name_col(),), ["gis_neighborhood"]) + operations = [migrations.CreateModel( "Neighborhood", [ @@ -57,11 +76,27 @@ class OperationTests(TransactionTestCase): self.assertColumnExists("gis_neighborhood", "path") # Test GeometryColumns when available - try: - from django.contrib.gis.models import GeometryColumns - except ImportError: - return - self.assertEqual( - GeometryColumns.objects.filter(**{GeometryColumns.table_name_col(): "gis_neighborhood"}).count(), - 2 - ) + if HAS_GEOMETRY_COLUMNS: + self.assertEqual( + GeometryColumns.objects.filter(**{GeometryColumns.table_name_col(): "gis_neighborhood"}).count(), + 2 + ) + + def test_remove_gis_field(self): + """ + Tests the RemoveField operation with a GIS-enabled column. + """ + project_state = self.set_up_test_model() + operation = migrations.RemoveField("Neighborhood", "geom") + new_state = project_state.clone() + operation.state_forwards("gis", new_state) + with connection.schema_editor() as editor: + operation.database_forwards("gis", editor, project_state, new_state) + self.assertColumnNotExists("gis_neighborhood", "geom") + + # Test GeometryColumns when available + if HAS_GEOMETRY_COLUMNS: + self.assertEqual( + GeometryColumns.objects.filter(**{GeometryColumns.table_name_col(): "gis_neighborhood"}).count(), + 0 + )