diff --git a/django/contrib/gis/db/backends/postgis/creation.py b/django/contrib/gis/db/backends/postgis/creation.py index 406dc4e487..4447b64db8 100644 --- a/django/contrib/gis/db/backends/postgis/creation.py +++ b/django/contrib/gis/db/backends/postgis/creation.py @@ -1,12 +1,23 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation +from django.utils.functional import cached_property + class PostGISCreation(DatabaseCreation): geom_index_type = 'GIST' geom_index_ops = 'GIST_GEOMETRY_OPS' geom_index_ops_nd = 'GIST_GEOMETRY_OPS_ND' + @cached_property + def template_postgis(self): + template_postgis = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis') + cursor = self.connection.cursor() + cursor.execute('SELECT 1 FROM pg_database WHERE datname = %s LIMIT 1;', (template_postgis,)) + if cursor.fetchone(): + return template_postgis + return None + def sql_indexes_for_field(self, model, f, style): "Return any spatial index creation SQL for the field." from django.contrib.gis.db.models.fields import GeometryField @@ -67,5 +78,19 @@ class PostGISCreation(DatabaseCreation): return output def sql_table_creation_suffix(self): - postgis_template = getattr(settings, 'POSTGIS_TEMPLATE', 'template_postgis') - return ' TEMPLATE %s' % self.connection.ops.quote_name(postgis_template) + if self.template_postgis is not None: + return ' TEMPLATE %s' % ( + self.connection.ops.quote_name(self.template_postgis),) + return '' + + def _create_test_db(self, verbosity, autoclobber): + test_database_name = super(PostGISCreation, self)._create_test_db(verbosity, autoclobber) + if self.template_postgis is None: + # Connect to the test database in order to create the postgis extension + self.connection.close() + self.connection.settings_dict["NAME"] = test_database_name + cursor = self.connection.cursor() + cursor.execute("CREATE EXTENSION postgis") + cursor.commit() + + return test_database_name