Simplified DatabaseOperations.sql_flush() on Oracle and PostgreSQL.

Added early return to decrease an indentation level.
This commit is contained in:
Jon Dufresne 2020-04-17 10:44:27 +02:00 committed by Mariusz Felisiak
parent 8bcca47e83
commit 8005829bb9
2 changed files with 64 additions and 65 deletions

View File

@ -405,21 +405,22 @@ END;
return lru_cache(maxsize=512)(self.__foreign_key_constraints) return lru_cache(maxsize=512)(self.__foreign_key_constraints)
def sql_flush(self, style, tables, sequences, allow_cascade=False): def sql_flush(self, style, tables, sequences, allow_cascade=False):
if tables: if not tables:
return []
truncated_tables = {table.upper() for table in tables} truncated_tables = {table.upper() for table in tables}
constraints = set() constraints = set()
# Oracle's TRUNCATE CASCADE only works with ON DELETE CASCADE # Oracle's TRUNCATE CASCADE only works with ON DELETE CASCADE foreign
# foreign keys which Django doesn't define. Emulate the # keys which Django doesn't define. Emulate the PostgreSQL behavior
# PostgreSQL behavior which truncates all dependent tables by # which truncates all dependent tables by manually retrieving all
# manually retrieving all foreign key constraints and resolving # foreign key constraints and resolving dependencies.
# dependencies.
for table in tables: for table in tables:
for foreign_table, constraint in self._foreign_key_constraints(table, recursive=allow_cascade): for foreign_table, constraint in self._foreign_key_constraints(table, recursive=allow_cascade):
if allow_cascade: if allow_cascade:
truncated_tables.add(foreign_table) truncated_tables.add(foreign_table)
constraints.add((foreign_table, constraint)) constraints.add((foreign_table, constraint))
sql = [ sql = [
"%s %s %s %s %s %s %s %s;" % ( '%s %s %s %s %s %s %s %s;' % (
style.SQL_KEYWORD('ALTER'), style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'), style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)), style.SQL_FIELD(self.quote_name(table)),
@ -430,13 +431,13 @@ END;
style.SQL_KEYWORD('INDEX'), style.SQL_KEYWORD('INDEX'),
) for table, constraint in constraints ) for table, constraint in constraints
] + [ ] + [
"%s %s %s;" % ( '%s %s %s;' % (
style.SQL_KEYWORD('TRUNCATE'), style.SQL_KEYWORD('TRUNCATE'),
style.SQL_KEYWORD('TABLE'), style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)), style.SQL_FIELD(self.quote_name(table)),
) for table in truncated_tables ) for table in truncated_tables
] + [ ] + [
"%s %s %s %s %s %s;" % ( '%s %s %s %s %s %s;' % (
style.SQL_KEYWORD('ALTER'), style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'), style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)), style.SQL_FIELD(self.quote_name(table)),
@ -445,12 +446,10 @@ END;
style.SQL_FIELD(self.quote_name(constraint)), style.SQL_FIELD(self.quote_name(constraint)),
) for table, constraint in constraints ) for table, constraint in constraints
] ]
# Since we've just deleted all the rows, running our sequence # Since we've just deleted all the rows, running our sequence ALTER
# ALTER code will reset the sequence to 0. # code will reset the sequence to 0.
sql.extend(self.sequence_reset_by_name_sql(style, sequences)) sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql return sql
else:
return []
def sequence_reset_by_name_sql(self, style, sequences): def sequence_reset_by_name_sql(self, style, sequences):
sql = [] sql = []

View File

@ -118,12 +118,14 @@ class DatabaseOperations(BaseDatabaseOperations):
return "SET TIME ZONE %s" return "SET TIME ZONE %s"
def sql_flush(self, style, tables, sequences, allow_cascade=False): def sql_flush(self, style, tables, sequences, allow_cascade=False):
if tables: if not tables:
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows return []
# us to truncate tables referenced by a foreign key in any other
# table. # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
# to truncate tables referenced by a foreign key in any other table.
tables_sql = ', '.join( tables_sql = ', '.join(
style.SQL_FIELD(self.quote_name(table)) for table in tables) style.SQL_FIELD(self.quote_name(table)) for table in tables
)
if allow_cascade: if allow_cascade:
sql = ['%s %s %s;' % ( sql = ['%s %s %s;' % (
style.SQL_KEYWORD('TRUNCATE'), style.SQL_KEYWORD('TRUNCATE'),
@ -137,8 +139,6 @@ class DatabaseOperations(BaseDatabaseOperations):
)] )]
sql.extend(self.sequence_reset_by_name_sql(style, sequences)) sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql return sql
else:
return []
def sequence_reset_by_name_sql(self, style, sequences): def sequence_reset_by_name_sql(self, style, sequences):
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements