Fixed #18782 -- Prevented sql_flush to flush views

Thanks rodolfo_3 for the report and the initial patch, and
Josh Smeaton, Shai Berger and Tim Graham for the reviews.
This commit is contained in:
Claude Paroz 2014-09-21 00:00:52 +02:00
parent b8cdc7dcc3
commit ed297061a6
3 changed files with 28 additions and 7 deletions

View File

@ -129,9 +129,9 @@ def sql_flush(style, connection, only_django=False, reset_sequences=True, allow_
models and are in INSTALLED_APPS will be included. models and are in INSTALLED_APPS will be included.
""" """
if only_django: if only_django:
tables = connection.introspection.django_table_names(only_existing=True) tables = connection.introspection.django_table_names(only_existing=True, include_views=False)
else: else:
tables = connection.introspection.table_names() tables = connection.introspection.table_names(include_views=False)
seqs = connection.introspection.sequence_list() if reset_sequences else () seqs = connection.introspection.sequence_list() if reset_sequences else ()
statements = connection.ops.sql_flush(style, tables, seqs, allow_cascade) statements = connection.ops.sql_flush(style, tables, seqs, allow_cascade)
return statements return statements

View File

@ -1274,17 +1274,20 @@ class BaseDatabaseIntrospection(object):
""" """
return self.table_name_converter(name) return self.table_name_converter(name)
def table_names(self, cursor=None): def table_names(self, cursor=None, include_views=False):
""" """
Returns a list of names of all tables that exist in the database. Returns a list of names of all tables that exist in the database.
The returned table list is sorted by Python's default sorting. We The returned table list is sorted by Python's default sorting. We
do NOT use database's ORDER BY here to avoid subtle differences do NOT use database's ORDER BY here to avoid subtle differences
in sorting order between databases. in sorting order between databases.
""" """
def get_names(cursor):
return sorted([ti.name for ti in self.get_table_list(cursor)
if include_views or ti.type == 't'])
if cursor is None: if cursor is None:
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
return sorted([ti.name for ti in self.get_table_list(cursor) if ti.type == 't']) return get_names(cursor)
return sorted([ti.name for ti in self.get_table_list(cursor) if ti.type == 't']) return get_names(cursor)
def get_table_list(self, cursor): def get_table_list(self, cursor):
""" """
@ -1293,7 +1296,7 @@ class BaseDatabaseIntrospection(object):
""" """
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_table_list() method') raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_table_list() method')
def django_table_names(self, only_existing=False): def django_table_names(self, only_existing=False, include_views=True):
""" """
Returns a list of all table names that have associated Django models and Returns a list of all table names that have associated Django models and
are in INSTALLED_APPS. are in INSTALLED_APPS.
@ -1312,7 +1315,7 @@ class BaseDatabaseIntrospection(object):
tables.update(f.m2m_db_table() for f in model._meta.local_many_to_many) tables.update(f.m2m_db_table() for f in model._meta.local_many_to_many)
tables = list(tables) tables = list(tables)
if only_existing: if only_existing:
existing_tables = self.table_names() existing_tables = self.table_names(include_views=include_views)
tables = [ tables = [
t t
for t in tables for t in tables

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import connection from django.db import connection
from django.db.utils import DatabaseError
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from .models import Reporter, Article from .models import Reporter, Article
@ -34,6 +35,23 @@ class IntrospectionTests(TestCase):
tl = connection.introspection.django_table_names(only_existing=False) tl = connection.introspection.django_table_names(only_existing=False)
self.assertIs(type(tl), list) self.assertIs(type(tl), list)
def test_table_names_with_views(self):
with connection.cursor() as cursor:
try:
cursor.execute(
'CREATE VIEW introspection_article_view AS SELECT headline '
'from introspection_article;')
except DatabaseError as e:
if 'insufficient privileges' in str(e):
self.skipTest("The test user has no CREATE VIEW privileges")
else:
raise
self.assertIn('introspection_article_view',
connection.introspection.table_names(include_views=True))
self.assertNotIn('introspection_article_view',
connection.introspection.table_names())
def test_installed_models(self): def test_installed_models(self):
tables = [Article._meta.db_table, Reporter._meta.db_table] tables = [Article._meta.db_table, Reporter._meta.db_table]
models = connection.introspection.installed_models(tables) models = connection.introspection.installed_models(tables)